From 59d43356644dd0df056e8d25cd5be92e0a1db88f Mon Sep 17 00:00:00 2001 From: Glenn Jocher Date: Wed, 25 Jan 2023 02:24:28 +0100 Subject: [PATCH] New `guess_model_task()` function (#614) Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- ultralytics/__init__.py | 2 +- ultralytics/nn/tasks.py | 52 ++++++++++++++++++++++++++- ultralytics/yolo/engine/exporter.py | 6 ++-- ultralytics/yolo/engine/model.py | 13 +++---- ultralytics/yolo/utils/autobatch.py | 2 +- ultralytics/yolo/utils/torch_utils.py | 17 --------- 6 files changed, 63 insertions(+), 29 deletions(-) diff --git a/ultralytics/__init__.py b/ultralytics/__init__.py index ba4263f..f429865 100644 --- a/ultralytics/__init__.py +++ b/ultralytics/__init__.py @@ -1,6 +1,6 @@ # Ultralytics YOLO 🚀, GPL-3.0 license -__version__ = "8.0.18" +__version__ = "8.0.19" from ultralytics.yolo.engine.model import YOLO from ultralytics.yolo.utils import ops diff --git a/ultralytics/nn/tasks.py b/ultralytics/nn/tasks.py index 68e6f66..2f18597 100644 --- a/ultralytics/nn/tasks.py +++ b/ultralytics/nn/tasks.py @@ -251,7 +251,7 @@ class ClassificationModel(BaseModel): ch=3, nc=1000, cutoff=10, - verbose=True): # yaml, model, number of classes, cutoff index + verbose=True): # yaml, model, channels, number of classes, cutoff index, verbose flag super().__init__() self._from_detection_model(model, nc, cutoff) if model is not None else self._from_yaml(cfg, ch, nc, verbose) @@ -457,3 +457,53 @@ def parse_model(d, ch, verbose=True): # model_dict, input_channels(3) ch = [] ch.append(c2) return nn.Sequential(*layers), sorted(save) + + +def guess_model_task(model): + """ + Guess the task of a PyTorch model from its architecture or configuration. + + Args: + model (nn.Module) or (dict): PyTorch model or model configuration in YAML format. + + Returns: + str: Task of the model ('detect', 'segment', 'classify'). + + Raises: + SyntaxError: If the task of the model could not be determined. + """ + cfg, task = None, None + if isinstance(model, dict): + cfg = model + elif isinstance(model, nn.Module): # PyTorch model + for x in 'model.yaml', 'model.model.yaml', 'model.model.model.yaml': + with contextlib.suppress(Exception): + cfg = eval(x) + break + + # Guess from YAML dictionary + if cfg: + m = cfg["head"][-1][-2].lower() # output module name + if m in ["classify", "classifier", "cls", "fc"]: + task = "classify" + if m in ["detect"]: + task = "detect" + if m in ["segment"]: + task = "segment" + + # Guess from PyTorch model + if task is None and isinstance(model, nn.Module): + for m in model.modules(): + if isinstance(m, Detect): + task = "detect" + elif isinstance(m, Segment): + task = "segment" + elif isinstance(m, Classify): + task = "classify" + + # Unable to determine task from model + if task is None: + raise SyntaxError("YOLO is unable to automatically guess model task. Explicitly define task for your model, " + "i.e. 'task=detect', 'task=segment' or 'task=classify'.") + else: + return task diff --git a/ultralytics/yolo/engine/exporter.py b/ultralytics/yolo/engine/exporter.py index 698a908..8a50a0e 100644 --- a/ultralytics/yolo/engine/exporter.py +++ b/ultralytics/yolo/engine/exporter.py @@ -66,7 +66,7 @@ import torch import ultralytics from ultralytics.nn.modules import Detect, Segment -from ultralytics.nn.tasks import ClassificationModel, DetectionModel, SegmentationModel +from ultralytics.nn.tasks import ClassificationModel, DetectionModel, SegmentationModel, guess_model_task from ultralytics.yolo.cfg import get_cfg from ultralytics.yolo.data.dataloaders.stream_loaders import LoadImages from ultralytics.yolo.data.utils import check_det_dataset @@ -74,7 +74,7 @@ from ultralytics.yolo.utils import DEFAULT_CFG, LOGGER, callbacks, colorstr, get from ultralytics.yolo.utils.checks import check_imgsz, check_requirements, check_version, check_yaml from ultralytics.yolo.utils.files import file_size from ultralytics.yolo.utils.ops import Profile -from ultralytics.yolo.utils.torch_utils import guess_task_from_model_yaml, select_device, smart_inference_mode +from ultralytics.yolo.utils.torch_utils import select_device, smart_inference_mode MACOS = platform.system() == 'Darwin' # macOS environment @@ -235,7 +235,7 @@ class Exporter: # Finish f = [str(x) for x in f if x] # filter out '' and None if any(f): - task = guess_task_from_model_yaml(model) + task = guess_model_task(model) s = "-WARNING ⚠️ not yet supported for YOLOv8 exported models" LOGGER.info(f'\nExport complete ({time.time() - t:.1f}s)' f"\nResults saved to {colorstr('bold', file.parent.resolve())}" diff --git a/ultralytics/yolo/engine/model.py b/ultralytics/yolo/engine/model.py index ef65898..0a5664b 100644 --- a/ultralytics/yolo/engine/model.py +++ b/ultralytics/yolo/engine/model.py @@ -3,12 +3,13 @@ from pathlib import Path from ultralytics import yolo # noqa -from ultralytics.nn.tasks import ClassificationModel, DetectionModel, SegmentationModel, attempt_load_one_weight +from ultralytics.nn.tasks import (ClassificationModel, DetectionModel, SegmentationModel, attempt_load_one_weight, + guess_model_task) from ultralytics.yolo.cfg import get_cfg from ultralytics.yolo.engine.exporter import Exporter from ultralytics.yolo.utils import DEFAULT_CFG, LOGGER, callbacks, yaml_load from ultralytics.yolo.utils.checks import check_yaml -from ultralytics.yolo.utils.torch_utils import guess_task_from_model_yaml, smart_inference_mode +from ultralytics.yolo.utils.torch_utils import smart_inference_mode # Map head to model, trainer, validator, and predictor classes MODEL_MAP = { @@ -73,9 +74,9 @@ class YOLO: """ cfg = check_yaml(cfg) # check YAML cfg_dict = yaml_load(cfg, append_filename=True) # model dict - self.task = guess_task_from_model_yaml(cfg_dict) + self.task = guess_model_task(cfg_dict) self.ModelClass, self.TrainerClass, self.ValidatorClass, self.PredictorClass = \ - self._guess_ops_from_task(self.task) + self._assign_ops_from_task(self.task) self.model = self.ModelClass(cfg_dict, verbose=verbose) # initialize self.cfg = cfg @@ -92,7 +93,7 @@ class YOLO: self.overrides = self.model.args self._reset_ckpt_args(self.overrides) self.ModelClass, self.TrainerClass, self.ValidatorClass, self.PredictorClass = \ - self._guess_ops_from_task(self.task) + self._assign_ops_from_task(self.task) def reset(self): """ @@ -217,7 +218,7 @@ class YOLO: """ self.model.to(device) - def _guess_ops_from_task(self, task): + def _assign_ops_from_task(self, task): model_class, train_lit, val_lit, pred_lit = MODEL_MAP[task] # warning: eval is unsafe. Use with caution trainer_class = eval(train_lit.replace("TYPE", f"{self.type}")) diff --git a/ultralytics/yolo/utils/autobatch.py b/ultralytics/yolo/utils/autobatch.py index cac167d..fc4f6a2 100644 --- a/ultralytics/yolo/utils/autobatch.py +++ b/ultralytics/yolo/utils/autobatch.py @@ -1,6 +1,6 @@ # Ultralytics YOLO 🚀, GPL-3.0 license """ -Auto-batch utils +AutoBatch utils """ from copy import deepcopy diff --git a/ultralytics/yolo/utils/torch_utils.py b/ultralytics/yolo/utils/torch_utils.py index e8ea557..280ad4f 100644 --- a/ultralytics/yolo/utils/torch_utils.py +++ b/ultralytics/yolo/utils/torch_utils.py @@ -308,23 +308,6 @@ def strip_optimizer(f='best.pt', s=''): LOGGER.info(f"Optimizer stripped from {f},{f' saved as {s},' if s else ''} {mb:.1f}MB") -def guess_task_from_model_yaml(model): - try: - cfg = model if isinstance(model, dict) else model.yaml # model cfg dict - m = cfg["head"][-1][-2].lower() # output module name - task = None - if m in ["classify", "classifier", "cls", "fc"]: - task = "classify" - if m in ["detect"]: - task = "detect" - if m in ["segment"]: - task = "segment" - except Exception as e: - raise SyntaxError('Unknown task. Define task explicitly, i.e. task=detect when running your command. ' - 'Valid tasks are detect, segment, classify.') from e - return task - - def profile(input, ops, n=10, device=None): """ YOLOv8 speed/memory/FLOPs profiler Usage: