Default classify training to `pretrained=True` (#3239)

single_channel
Glenn Jocher 1 year ago committed by GitHub
parent e78fb683f4
commit 15c90bd404
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -19,7 +19,7 @@ workers: 8 # (int) number of worker threads for data loading (per RANK if DDP)
project: # (str, optional) project name
name: # (str, optional) experiment name, results saved to 'project/name' directory
exist_ok: False # (bool) whether to overwrite existing experiment
pretrained: False # (bool) whether to use a pretrained model
pretrained: True # (bool) whether to use a pretrained model
optimizer: auto # (str) optimizer to use, choices=[SGD, Adam, Adamax, AdamW, NAdam, RAdam, RMSProp, auto]
verbose: True # (bool) whether to print verbose output
seed: 0 # (int) random seed for reproducibility

@ -33,9 +33,8 @@ class ClassificationTrainer(BaseTrainer):
if weights:
model.load(weights)
pretrained = self.args.pretrained
for m in model.modules():
if not pretrained and hasattr(m, 'reset_parameters'):
if not self.args.pretrained and hasattr(m, 'reset_parameters'):
m.reset_parameters()
if isinstance(m, torch.nn.Dropout) and self.args.dropout:
m.p = self.args.dropout # set dropout
@ -61,8 +60,7 @@ class ClassificationTrainer(BaseTrainer):
elif model.endswith('.yaml'):
self.model = self.get_model(cfg=model)
elif model in torchvision.models.__dict__:
pretrained = True
self.model = torchvision.models.__dict__[model](weights='IMAGENET1K_V1' if pretrained else None)
self.model = torchvision.models.__dict__[model](weights='IMAGENET1K_V1' if self.args.pretrained else None)
else:
FileNotFoundError(f'ERROR: model={model} not found locally or online. Please check model name.')
ClassificationModel.reshape_outputs(self.model, self.data['nc'])

Loading…
Cancel
Save