|
|
|
@ -33,9 +33,8 @@ class ClassificationTrainer(BaseTrainer):
|
|
|
|
|
if weights:
|
|
|
|
|
model.load(weights)
|
|
|
|
|
|
|
|
|
|
pretrained = self.args.pretrained
|
|
|
|
|
for m in model.modules():
|
|
|
|
|
if not pretrained and hasattr(m, 'reset_parameters'):
|
|
|
|
|
if not self.args.pretrained and hasattr(m, 'reset_parameters'):
|
|
|
|
|
m.reset_parameters()
|
|
|
|
|
if isinstance(m, torch.nn.Dropout) and self.args.dropout:
|
|
|
|
|
m.p = self.args.dropout # set dropout
|
|
|
|
@ -61,8 +60,7 @@ class ClassificationTrainer(BaseTrainer):
|
|
|
|
|
elif model.endswith('.yaml'):
|
|
|
|
|
self.model = self.get_model(cfg=model)
|
|
|
|
|
elif model in torchvision.models.__dict__:
|
|
|
|
|
pretrained = True
|
|
|
|
|
self.model = torchvision.models.__dict__[model](weights='IMAGENET1K_V1' if pretrained else None)
|
|
|
|
|
self.model = torchvision.models.__dict__[model](weights='IMAGENET1K_V1' if self.args.pretrained else None)
|
|
|
|
|
else:
|
|
|
|
|
FileNotFoundError(f'ERROR: model={model} not found locally or online. Please check model name.')
|
|
|
|
|
ClassificationModel.reshape_outputs(self.model, self.data['nc'])
|
|
|
|
|