@@ -618,15 +618,19 @@ class BaseTrainer:
         Returns:
             optimizer (torch.optim.Optimizer): the built optimizer
         """
         g = [], [], []  # optimizer parameter groups
         bn = tuple(v for k, v in nn.__dict__.items() if 'Norm' in k)  # normalization layers, i.e. BatchNorm2d()
-        for v in model.modules():
-            if hasattr(v, 'bias') and isinstance(v.bias, nn.Parameter):  # bias (no decay)
-                g[2].append(v.bias)
-            if isinstance(v, bn):  # weight (no decay)
-                g[1].append(v.weight)
-            elif hasattr(v, 'weight') and isinstance(v.weight, nn.Parameter):  # weight (with decay)
-                g[0].append(v.weight)
+        for module_name, module in model.named_modules():
+            for param_name, param in module.named_parameters(recurse=False):
+                fullname = f'{module_name}.{param_name}' if module_name else param_name
+                if 'bias' in fullname:  # bias (no decay)
+                    g[2].append(param)
+                elif isinstance(module, bn):  # weight (no decay)
+                    g[1].append(param)
+                else:  # weight (with decay)
+                    g[0].append(param)
 
         if name == 'Adam':
             optimizer = torch.optim.Adam(g[2], lr=lr, betas=(momentum, 0.999))  # adjust beta1 to momentum
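
The rewritten traversal visits each parameter exactly once via `named_parameters(recurse=False)` and routes it by its fully qualified name, so any parameter whose name contains `bias` is kept out of weight decay, and anything unmatched falls through to the decayed group instead of being silently skipped (the old `hasattr` checks only ever looked at attributes literally named `bias` and `weight`). A minimal sketch of the new grouping logic in isolation (the `group_params` helper and the toy model are illustrative, not part of the patch):

```python
import torch.nn as nn

def group_params(model: nn.Module):
    """Split parameters as in the patched loop: (decayed weights, norm weights, biases)."""
    g = [], [], []  # optimizer parameter groups
    bn = tuple(v for k, v in nn.__dict__.items() if 'Norm' in k)  # normalization layers
    for module_name, module in model.named_modules():
        for param_name, param in module.named_parameters(recurse=False):
            fullname = f'{module_name}.{param_name}' if module_name else param_name
            if 'bias' in fullname:  # bias (no decay)
                g[2].append(param)
            elif isinstance(module, bn):  # weight (no decay)
                g[1].append(param)
            else:  # weight (with decay)
                g[0].append(param)
    return g

model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8), nn.Flatten(), nn.Linear(8, 4))
g = group_params(model)
print([len(x) for x in g])  # [2, 1, 3]: conv+linear weights / BN weight / three biases
```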
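
Note the hunk constructs the optimizer from `g[2]` (the biases) only, mapping the SGD-style `momentum` hyperparameter onto Adam's `beta1`; the attachment of the other two groups lies outside the excerpt shown here. A plausible continuation, assuming a `decay` hyperparameter (the `add_param_group` calls below are an assumption, not visible in this diff):

```python
# Assumed continuation (not shown in the hunk): attach the remaining groups,
# applying weight decay only to g[0]. `decay` is a hypothetical hyperparameter.
optimizer.add_param_group({'params': g[0], 'weight_decay': decay})  # weights with decay
optimizer.add_param_group({'params': g[1], 'weight_decay': 0.0})    # norm weights, no decay
```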