From 832ea56eb4780e6f81b87aab3d6b2b40a8aab1e0 Mon Sep 17 00:00:00 2001 From: Ayush Chaurasia Date: Tue, 15 Nov 2022 20:06:29 +0530 Subject: [PATCH] update model initialization design, supports custom data/num_classes (#44) Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- .github/workflows/ci.yaml | 2 +- .gitignore | 3 +- ultralytics/yolo/engine/trainer.py | 24 +++++++-------- ultralytics/yolo/utils/configs/default.yaml | 4 +-- ultralytics/yolo/utils/downloads.py | 5 --- ultralytics/yolo/utils/modeling/tasks.py | 34 +++++++++++++++++---- ultralytics/yolo/v8/classify/train.py | 21 +++++++------ ultralytics/yolo/v8/segment/train.py | 18 +++++++---- 8 files changed, 67 insertions(+), 44 deletions(-) diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index 581a7d0..ae23219 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -94,7 +94,7 @@ jobs: - name: Test segmentation shell: bash # for Windows compatibility run: | - python ultralytics/yolo/v8/segment/train.py cfg=yolov5n-seg.yaml data=coco128-seg.yaml epochs=1 img_size=64 + python ultralytics/yolo/v8/segment/train.py model=yolov5n-seg.yaml data=coco128-seg.yaml epochs=1 img_size=64 - name: Test classification shell: bash # for Windows compatibility run: | diff --git a/.gitignore b/.gitignore index 75b4369..ed06157 100644 --- a/.gitignore +++ b/.gitignore @@ -130,4 +130,5 @@ dmypy.json # datasets and projects datasets/ -ultralytics-yolo/ \ No newline at end of file +ultralytics-yolo/ +runs/ \ No newline at end of file diff --git a/ultralytics/yolo/engine/trainer.py b/ultralytics/yolo/engine/trainer.py index 758f227..c48da2b 100644 --- a/ultralytics/yolo/engine/trainer.py +++ b/ultralytics/yolo/engine/trainer.py @@ -63,10 +63,8 @@ class BaseTrainer: else: self.data = check_dataset(self.data) self.trainset, self.testset = self.get_dataset(self.data) - if self.args.cfg is not None: - self.model = self.load_cfg(check_file(self.args.cfg)) - if self.args.model is not None: - self.model = self.get_model(self.args.model, self.args.pretrained).to(self.device) + if self.args.model: + self.model = self.get_model(self.args.model, self.data) # epoch level metrics self.metrics = {} # handle metrics returned by validator @@ -261,20 +259,20 @@ class BaseTrainer: """ return data["train"], data["val"] - def get_model(self, model, pretrained): + def get_model(self, model: str, data: Dict): """ load/create/download model for any task """ - model = get_model(model) - for m in model.modules(): - if not pretrained and hasattr(m, 'reset_parameters'): - m.reset_parameters() - for p in model.parameters(): - p.requires_grad = True - + pretrained = False + if not str(model).endswith(".yaml"): + pretrained = True + weights = get_model(model) # rename this to something less confusing? + model = self.load_model(model_cfg=model if not pretrained else None, + weights=weights if pretrained else None, + data=self.data) return model - def load_cfg(self, cfg): + def load_model(self, model_cfg, weights, data): raise NotImplementedError("This task trainer doesn't support loading cfg files") def get_validator(self): diff --git a/ultralytics/yolo/utils/configs/default.yaml b/ultralytics/yolo/utils/configs/default.yaml index b85c63d..a1887bd 100644 --- a/ultralytics/yolo/utils/configs/default.yaml +++ b/ultralytics/yolo/utils/configs/default.yaml @@ -3,8 +3,7 @@ # Train settings ------------------------------------------------------------------------------------------------------- -model: null # i.e. yolov5s.pt -cfg: null # i.e. yolov5s.yaml +model: null # i.e. yolov5s.pt, yolo.yaml data: null # i.e. coco128.yaml epochs: 300 batch_size: 16 @@ -70,6 +69,7 @@ mosaic: 1.0 # image mosaic (probability) mixup: 0.0 # image mixup (probability) copy_paste: 0.0 # segment copy-paste (probability) label_smoothing: 0.0 +# anchors: 3 # Hydra configs -------------------------------------------------------------------------------------------------------- hydra: diff --git a/ultralytics/yolo/utils/downloads.py b/ultralytics/yolo/utils/downloads.py index 1b09a3e..71fa63d 100644 --- a/ultralytics/yolo/utils/downloads.py +++ b/ultralytics/yolo/utils/downloads.py @@ -140,8 +140,3 @@ def download(url, dir=Path.cwd(), unzip=True, delete=True, curl=False, threads=1 else: for u in [url] if isinstance(url, (str, Path)) else url: download_one(u, dir) - - -def get_model(model: str): - # check for local weights - pass diff --git a/ultralytics/yolo/utils/modeling/tasks.py b/ultralytics/yolo/utils/modeling/tasks.py index 0cbeb45..c6c82b5 100644 --- a/ultralytics/yolo/utils/modeling/tasks.py +++ b/ultralytics/yolo/utils/modeling/tasks.py @@ -66,7 +66,7 @@ class BaseModel(nn.Module): return self def load(self, weights): - # Force all tasks implement this function + # Force all tasks to implement this function raise NotImplementedError("This function needs to be implemented by derived classes!") @@ -169,10 +169,10 @@ class DetectionModel(BaseModel): mi.bias = torch.nn.Parameter(b.view(-1), requires_grad=True) def load(self, weights): - ckpt = torch.load(weights, map_location='cpu') # load checkpoint to CPU to avoid CUDA memory leak - csd = ckpt['model'].float().state_dict() # checkpoint state_dict as FP32 + csd = weights['model'].float().state_dict() # checkpoint state_dict as FP32 csd = intersect_state_dicts(csd, self.state_dict()) # intersect self.load_state_dict(csd, strict=False) # load + LOGGER.info(f'Transferred {len(csd)}/{len(self.model.state_dict())} items from {weights}') class SegmentationModel(DetectionModel): @@ -203,11 +203,33 @@ class ClassificationModel(BaseModel): self.nc = nc def _from_yaml(self, cfg): - # Create a YOLOv5 classification model from a *.yaml file + # TODO: Create a YOLOv5 classification model from a *.yaml file self.model = None def load(self, weights): - ckpt = torch.load(weights, map_location='cpu') # load checkpoint to CPU to avoid CUDA memory leak - csd = ckpt['model'].float().state_dict() # checkpoint state_dict as FP32 + model = weights["model"] if isinstance(weights, dict) else weights # torchvision models are not dicts + csd = model.float().state_dict() csd = intersect_state_dicts(csd, self.state_dict()) # intersect self.load_state_dict(csd, strict=False) # load + + @staticmethod + def reshape_outputs(model, nc): + # Update a TorchVision classification model to class count 'n' if required + from ultralytics.yolo.utils.modeling.modules import Classify + name, m = list((model.model if hasattr(model, 'model') else model).named_children())[-1] # last module + if isinstance(m, Classify): # YOLO Classify() head + if m.linear.out_features != nc: + m.linear = nn.Linear(m.linear.in_features, nc) + elif isinstance(m, nn.Linear): # ResNet, EfficientNet + if m.out_features != nc: + setattr(model, name, nn.Linear(m.in_features, nc)) + elif isinstance(m, nn.Sequential): + types = [type(x) for x in m] + if nn.Linear in types: + i = types.index(nn.Linear) # nn.Linear index + if m[i].out_features != nc: + m[i] = nn.Linear(m[i].in_features, nc) + elif nn.Conv2d in types: + i = types.index(nn.Conv2d) # nn.Conv2d index + if m[i].out_channels != nc: + m[i] = nn.Conv2d(m[i].in_channels, nc, m[i].kernel_size, m[i].stride, bias=m[i].bias) diff --git a/ultralytics/yolo/v8/classify/train.py b/ultralytics/yolo/v8/classify/train.py index 4027d87..4037b83 100644 --- a/ultralytics/yolo/v8/classify/train.py +++ b/ultralytics/yolo/v8/classify/train.py @@ -1,26 +1,27 @@ -import subprocess -import time -from pathlib import Path - import hydra import torch from ultralytics.yolo import v8 from ultralytics.yolo.data import build_classification_dataloader from ultralytics.yolo.engine.trainer import DEFAULT_CONFIG, BaseTrainer -from ultralytics.yolo.utils import colorstr -from ultralytics.yolo.utils.downloads import download -from ultralytics.yolo.utils.files import WorkingDirectory -from ultralytics.yolo.utils.torch_utils import LOCAL_RANK, torch_distributed_zero_first +from ultralytics.yolo.utils.modeling.tasks import ClassificationModel -# BaseTrainer python usage class ClassificationTrainer(BaseTrainer): + def load_model(self, model_cfg, weights, data): + # TODO: why treat clf models as unique. We should have clf yamls? + if weights and not weights.__class__.__name__.startswith("yolo"): # torchvision + model = weights + else: + model = ClassificationModel(model_cfg, weights, data["nc"]) + ClassificationModel.reshape_outputs(model, data["nc"]) + return model + def get_dataloader(self, dataset_path, batch_size=None, rank=0): return build_classification_dataloader(path=dataset_path, imgsz=self.args.img_size, - batch_size=self.args.batch_size, + batch_size=batch_size, rank=rank) def preprocess_batch(self, batch): diff --git a/ultralytics/yolo/v8/segment/train.py b/ultralytics/yolo/v8/segment/train.py index 548bae6..1dd64c8 100644 --- a/ultralytics/yolo/v8/segment/train.py +++ b/ultralytics/yolo/v8/segment/train.py @@ -10,12 +10,11 @@ import torch.nn.functional as F from ultralytics.yolo import v8 from ultralytics.yolo.data import build_dataloader from ultralytics.yolo.engine.trainer import DEFAULT_CONFIG, BaseTrainer -from ultralytics.yolo.utils.downloads import download -from ultralytics.yolo.utils.files import WorkingDirectory +from ultralytics.yolo.utils.anchors import check_anchors from ultralytics.yolo.utils.metrics import FocalLoss, bbox_iou, smooth_BCE from ultralytics.yolo.utils.modeling.tasks import SegmentationModel from ultralytics.yolo.utils.ops import crop_mask, xywh2xyxy -from ultralytics.yolo.utils.torch_utils import LOCAL_RANK, de_parallel, torch_distributed_zero_first +from ultralytics.yolo.utils.torch_utils import de_parallel # BaseTrainer python usage @@ -45,8 +44,15 @@ class SegmentationTrainer(BaseTrainer): batch["img"] = batch["img"].to(self.device, non_blocking=True).float() / 255 return batch - def load_cfg(self, cfg): - return SegmentationModel(cfg, nc=80) + def load_model(self, model_cfg, weights, data): + model = SegmentationModel(model_cfg if model_cfg else weights["model"].yaml, + ch=3, + nc=data["nc"], + anchors=self.args.get("anchors")) + check_anchors(model, self.args.anchor_t, self.args.img_size) + if weights: + model.load(weights) + return model def get_validator(self): return v8.segment.SegmentationValidator(self.test_loader, self.device, logger=self.console) @@ -232,7 +238,7 @@ class SegmentationTrainer(BaseTrainer): @hydra.main(version_base=None, config_path=DEFAULT_CONFIG.parent, config_name=DEFAULT_CONFIG.name) def train(cfg): - cfg.cfg = v8.ROOT / "models/yolov5n-seg.yaml" + cfg.model = v8.ROOT / "models/yolov5n-seg.yaml" cfg.data = cfg.data or "coco128-seg.yaml" # or yolo.ClassificationDataset("mnist") trainer = SegmentationTrainer(cfg) trainer.train()