From 15c90bd40473c9d9a1cf82390bf7b94f9a88c3e8 Mon Sep 17 00:00:00 2001 From: Glenn Jocher Date: Sun, 18 Jun 2023 02:31:05 +0200 Subject: [PATCH] Default classify training to `pretrained=True` (#3239) --- ultralytics/yolo/cfg/default.yaml | 2 +- ultralytics/yolo/v8/classify/train.py | 6 ++---- 2 files changed, 3 insertions(+), 5 deletions(-) diff --git a/ultralytics/yolo/cfg/default.yaml b/ultralytics/yolo/cfg/default.yaml index 24c9b17..62202b5 100644 --- a/ultralytics/yolo/cfg/default.yaml +++ b/ultralytics/yolo/cfg/default.yaml @@ -19,7 +19,7 @@ workers: 8 # (int) number of worker threads for data loading (per RANK if DDP) project: # (str, optional) project name name: # (str, optional) experiment name, results saved to 'project/name' directory exist_ok: False # (bool) whether to overwrite existing experiment -pretrained: False # (bool) whether to use a pretrained model +pretrained: True # (bool) whether to use a pretrained model optimizer: auto # (str) optimizer to use, choices=[SGD, Adam, Adamax, AdamW, NAdam, RAdam, RMSProp, auto] verbose: True # (bool) whether to print verbose output seed: 0 # (int) random seed for reproducibility diff --git a/ultralytics/yolo/v8/classify/train.py b/ultralytics/yolo/v8/classify/train.py index 2949644..72feb55 100644 --- a/ultralytics/yolo/v8/classify/train.py +++ b/ultralytics/yolo/v8/classify/train.py @@ -33,9 +33,8 @@ class ClassificationTrainer(BaseTrainer): if weights: model.load(weights) - pretrained = self.args.pretrained for m in model.modules(): - if not pretrained and hasattr(m, 'reset_parameters'): + if not self.args.pretrained and hasattr(m, 'reset_parameters'): m.reset_parameters() if isinstance(m, torch.nn.Dropout) and self.args.dropout: m.p = self.args.dropout # set dropout @@ -61,8 +60,7 @@ class ClassificationTrainer(BaseTrainer): elif model.endswith('.yaml'): self.model = self.get_model(cfg=model) elif model in torchvision.models.__dict__: - pretrained = True - self.model = torchvision.models.__dict__[model](weights='IMAGENET1K_V1' if pretrained else None) + self.model = torchvision.models.__dict__[model](weights='IMAGENET1K_V1' if self.args.pretrained else None) else: FileNotFoundError(f'ERROR: model={model} not found locally or online. Please check model name.') ClassificationModel.reshape_outputs(self.model, self.data['nc'])