`build_optimizer()` assigns all parameters (#2855)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Glenn Jocher authored 2 years ago, committed by GitHub
parent 441e67d330
commit 61fa5efe6d

@@ -618,15 +618,19 @@ class BaseTrainer:
         Returns:
             optimizer (torch.optim.Optimizer): the built optimizer
         """
         g = [], [], []  # optimizer parameter groups
         bn = tuple(v for k, v in nn.__dict__.items() if 'Norm' in k)  # normalization layers, i.e. BatchNorm2d()
-        for v in model.modules():
-            if hasattr(v, 'bias') and isinstance(v.bias, nn.Parameter):  # bias (no decay)
-                g[2].append(v.bias)
-            if isinstance(v, bn):  # weight (no decay)
-                g[1].append(v.weight)
-            elif hasattr(v, 'weight') and isinstance(v.weight, nn.Parameter):  # weight (with decay)
-                g[0].append(v.weight)
+        for module_name, module in model.named_modules():
+            for param_name, param in module.named_parameters(recurse=False):
+                fullname = f'{module_name}.{param_name}' if module_name else param_name
+                if 'bias' in fullname:  # bias (no decay)
+                    g[2].append(param)
+                elif isinstance(module, bn):  # weight (no decay)
+                    g[1].append(param)
+                else:  # weight (with decay)
+                    g[0].append(param)
 
         if name == 'Adam':
             optimizer = torch.optim.Adam(g[2], lr=lr, betas=(momentum, 0.999))  # adjust beta1 to momentum
