Default classify training to pretrained=True (#3239)

This commit is contained in:
Glenn Jocher
2023-06-18 02:31:05 +02:00
committed by GitHub
parent e78fb683f4
commit 15c90bd404
2 changed files with 3 additions and 5 deletions

View File

@ -33,9 +33,8 @@ class ClassificationTrainer(BaseTrainer):
if weights:
model.load(weights)
pretrained = self.args.pretrained
for m in model.modules():
if not pretrained and hasattr(m, 'reset_parameters'):
if not self.args.pretrained and hasattr(m, 'reset_parameters'):
m.reset_parameters()
if isinstance(m, torch.nn.Dropout) and self.args.dropout:
m.p = self.args.dropout # set dropout
@ -61,8 +60,7 @@ class ClassificationTrainer(BaseTrainer):
elif model.endswith('.yaml'):
self.model = self.get_model(cfg=model)
elif model in torchvision.models.__dict__:
pretrained = True
self.model = torchvision.models.__dict__[model](weights='IMAGENET1K_V1' if pretrained else None)
self.model = torchvision.models.__dict__[model](weights='IMAGENET1K_V1' if self.args.pretrained else None)
else:
FileNotFoundError(f'ERROR: model={model} not found locally or online. Please check model name.')
ClassificationModel.reshape_outputs(self.model, self.data['nc'])