New guess_model_task() function (#614)
				
					
				
			Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
		@ -1,6 +1,6 @@
 | 
				
			|||||||
# Ultralytics YOLO 🚀, GPL-3.0 license
 | 
					# Ultralytics YOLO 🚀, GPL-3.0 license
 | 
				
			||||||
 | 
					
 | 
				
			||||||
__version__ = "8.0.18"
 | 
					__version__ = "8.0.19"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
from ultralytics.yolo.engine.model import YOLO
 | 
					from ultralytics.yolo.engine.model import YOLO
 | 
				
			||||||
from ultralytics.yolo.utils import ops
 | 
					from ultralytics.yolo.utils import ops
 | 
				
			||||||
 | 
				
			|||||||
@ -251,7 +251,7 @@ class ClassificationModel(BaseModel):
 | 
				
			|||||||
                 ch=3,
 | 
					                 ch=3,
 | 
				
			||||||
                 nc=1000,
 | 
					                 nc=1000,
 | 
				
			||||||
                 cutoff=10,
 | 
					                 cutoff=10,
 | 
				
			||||||
                 verbose=True):  # yaml, model, number of classes, cutoff index
 | 
					                 verbose=True):  # yaml, model, channels, number of classes, cutoff index, verbose flag
 | 
				
			||||||
        super().__init__()
 | 
					        super().__init__()
 | 
				
			||||||
        self._from_detection_model(model, nc, cutoff) if model is not None else self._from_yaml(cfg, ch, nc, verbose)
 | 
					        self._from_detection_model(model, nc, cutoff) if model is not None else self._from_yaml(cfg, ch, nc, verbose)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -457,3 +457,53 @@ def parse_model(d, ch, verbose=True):  # model_dict, input_channels(3)
 | 
				
			|||||||
            ch = []
 | 
					            ch = []
 | 
				
			||||||
        ch.append(c2)
 | 
					        ch.append(c2)
 | 
				
			||||||
    return nn.Sequential(*layers), sorted(save)
 | 
					    return nn.Sequential(*layers), sorted(save)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def guess_model_task(model):
 | 
				
			||||||
 | 
					    """
 | 
				
			||||||
 | 
					    Guess the task of a PyTorch model from its architecture or configuration.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    Args:
 | 
				
			||||||
 | 
					        model (nn.Module) or (dict): PyTorch model or model configuration in YAML format.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    Returns:
 | 
				
			||||||
 | 
					        str: Task of the model ('detect', 'segment', 'classify').
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    Raises:
 | 
				
			||||||
 | 
					        SyntaxError: If the task of the model could not be determined.
 | 
				
			||||||
 | 
					    """
 | 
				
			||||||
 | 
					    cfg, task = None, None
 | 
				
			||||||
 | 
					    if isinstance(model, dict):
 | 
				
			||||||
 | 
					        cfg = model
 | 
				
			||||||
 | 
					    elif isinstance(model, nn.Module):  # PyTorch model
 | 
				
			||||||
 | 
					        for x in 'model.yaml', 'model.model.yaml', 'model.model.model.yaml':
 | 
				
			||||||
 | 
					            with contextlib.suppress(Exception):
 | 
				
			||||||
 | 
					                cfg = eval(x)
 | 
				
			||||||
 | 
					                break
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    # Guess from YAML dictionary
 | 
				
			||||||
 | 
					    if cfg:
 | 
				
			||||||
 | 
					        m = cfg["head"][-1][-2].lower()  # output module name
 | 
				
			||||||
 | 
					        if m in ["classify", "classifier", "cls", "fc"]:
 | 
				
			||||||
 | 
					            task = "classify"
 | 
				
			||||||
 | 
					        if m in ["detect"]:
 | 
				
			||||||
 | 
					            task = "detect"
 | 
				
			||||||
 | 
					        if m in ["segment"]:
 | 
				
			||||||
 | 
					            task = "segment"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    # Guess from PyTorch model
 | 
				
			||||||
 | 
					    if task is None and isinstance(model, nn.Module):
 | 
				
			||||||
 | 
					        for m in model.modules():
 | 
				
			||||||
 | 
					            if isinstance(m, Detect):
 | 
				
			||||||
 | 
					                task = "detect"
 | 
				
			||||||
 | 
					            elif isinstance(m, Segment):
 | 
				
			||||||
 | 
					                task = "segment"
 | 
				
			||||||
 | 
					            elif isinstance(m, Classify):
 | 
				
			||||||
 | 
					                task = "classify"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    # Unable to determine task from model
 | 
				
			||||||
 | 
					    if task is None:
 | 
				
			||||||
 | 
					        raise SyntaxError("YOLO is unable to automatically guess model task. Explicitly define task for your model, "
 | 
				
			||||||
 | 
					                          "i.e. 'task=detect', 'task=segment' or 'task=classify'.")
 | 
				
			||||||
 | 
					    else:
 | 
				
			||||||
 | 
					        return task
 | 
				
			||||||
 | 
				
			|||||||
@ -66,7 +66,7 @@ import torch
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
import ultralytics
 | 
					import ultralytics
 | 
				
			||||||
from ultralytics.nn.modules import Detect, Segment
 | 
					from ultralytics.nn.modules import Detect, Segment
 | 
				
			||||||
from ultralytics.nn.tasks import ClassificationModel, DetectionModel, SegmentationModel
 | 
					from ultralytics.nn.tasks import ClassificationModel, DetectionModel, SegmentationModel, guess_model_task
 | 
				
			||||||
from ultralytics.yolo.cfg import get_cfg
 | 
					from ultralytics.yolo.cfg import get_cfg
 | 
				
			||||||
from ultralytics.yolo.data.dataloaders.stream_loaders import LoadImages
 | 
					from ultralytics.yolo.data.dataloaders.stream_loaders import LoadImages
 | 
				
			||||||
from ultralytics.yolo.data.utils import check_det_dataset
 | 
					from ultralytics.yolo.data.utils import check_det_dataset
 | 
				
			||||||
@ -74,7 +74,7 @@ from ultralytics.yolo.utils import DEFAULT_CFG, LOGGER, callbacks, colorstr, get
 | 
				
			|||||||
from ultralytics.yolo.utils.checks import check_imgsz, check_requirements, check_version, check_yaml
 | 
					from ultralytics.yolo.utils.checks import check_imgsz, check_requirements, check_version, check_yaml
 | 
				
			||||||
from ultralytics.yolo.utils.files import file_size
 | 
					from ultralytics.yolo.utils.files import file_size
 | 
				
			||||||
from ultralytics.yolo.utils.ops import Profile
 | 
					from ultralytics.yolo.utils.ops import Profile
 | 
				
			||||||
from ultralytics.yolo.utils.torch_utils import guess_task_from_model_yaml, select_device, smart_inference_mode
 | 
					from ultralytics.yolo.utils.torch_utils import select_device, smart_inference_mode
 | 
				
			||||||
 | 
					
 | 
				
			||||||
MACOS = platform.system() == 'Darwin'  # macOS environment
 | 
					MACOS = platform.system() == 'Darwin'  # macOS environment
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -235,7 +235,7 @@ class Exporter:
 | 
				
			|||||||
        # Finish
 | 
					        # Finish
 | 
				
			||||||
        f = [str(x) for x in f if x]  # filter out '' and None
 | 
					        f = [str(x) for x in f if x]  # filter out '' and None
 | 
				
			||||||
        if any(f):
 | 
					        if any(f):
 | 
				
			||||||
            task = guess_task_from_model_yaml(model)
 | 
					            task = guess_model_task(model)
 | 
				
			||||||
            s = "-WARNING ⚠️ not yet supported for YOLOv8 exported models"
 | 
					            s = "-WARNING ⚠️ not yet supported for YOLOv8 exported models"
 | 
				
			||||||
            LOGGER.info(f'\nExport complete ({time.time() - t:.1f}s)'
 | 
					            LOGGER.info(f'\nExport complete ({time.time() - t:.1f}s)'
 | 
				
			||||||
                        f"\nResults saved to {colorstr('bold', file.parent.resolve())}"
 | 
					                        f"\nResults saved to {colorstr('bold', file.parent.resolve())}"
 | 
				
			||||||
 | 
				
			|||||||
@ -3,12 +3,13 @@
 | 
				
			|||||||
from pathlib import Path
 | 
					from pathlib import Path
 | 
				
			||||||
 | 
					
 | 
				
			||||||
from ultralytics import yolo  # noqa
 | 
					from ultralytics import yolo  # noqa
 | 
				
			||||||
from ultralytics.nn.tasks import ClassificationModel, DetectionModel, SegmentationModel, attempt_load_one_weight
 | 
					from ultralytics.nn.tasks import (ClassificationModel, DetectionModel, SegmentationModel, attempt_load_one_weight,
 | 
				
			||||||
 | 
					                                  guess_model_task)
 | 
				
			||||||
from ultralytics.yolo.cfg import get_cfg
 | 
					from ultralytics.yolo.cfg import get_cfg
 | 
				
			||||||
from ultralytics.yolo.engine.exporter import Exporter
 | 
					from ultralytics.yolo.engine.exporter import Exporter
 | 
				
			||||||
from ultralytics.yolo.utils import DEFAULT_CFG, LOGGER, callbacks, yaml_load
 | 
					from ultralytics.yolo.utils import DEFAULT_CFG, LOGGER, callbacks, yaml_load
 | 
				
			||||||
from ultralytics.yolo.utils.checks import check_yaml
 | 
					from ultralytics.yolo.utils.checks import check_yaml
 | 
				
			||||||
from ultralytics.yolo.utils.torch_utils import guess_task_from_model_yaml, smart_inference_mode
 | 
					from ultralytics.yolo.utils.torch_utils import smart_inference_mode
 | 
				
			||||||
 | 
					
 | 
				
			||||||
# Map head to model, trainer, validator, and predictor classes
 | 
					# Map head to model, trainer, validator, and predictor classes
 | 
				
			||||||
MODEL_MAP = {
 | 
					MODEL_MAP = {
 | 
				
			||||||
@ -73,9 +74,9 @@ class YOLO:
 | 
				
			|||||||
        """
 | 
					        """
 | 
				
			||||||
        cfg = check_yaml(cfg)  # check YAML
 | 
					        cfg = check_yaml(cfg)  # check YAML
 | 
				
			||||||
        cfg_dict = yaml_load(cfg, append_filename=True)  # model dict
 | 
					        cfg_dict = yaml_load(cfg, append_filename=True)  # model dict
 | 
				
			||||||
        self.task = guess_task_from_model_yaml(cfg_dict)
 | 
					        self.task = guess_model_task(cfg_dict)
 | 
				
			||||||
        self.ModelClass, self.TrainerClass, self.ValidatorClass, self.PredictorClass = \
 | 
					        self.ModelClass, self.TrainerClass, self.ValidatorClass, self.PredictorClass = \
 | 
				
			||||||
            self._guess_ops_from_task(self.task)
 | 
					            self._assign_ops_from_task(self.task)
 | 
				
			||||||
        self.model = self.ModelClass(cfg_dict, verbose=verbose)  # initialize
 | 
					        self.model = self.ModelClass(cfg_dict, verbose=verbose)  # initialize
 | 
				
			||||||
        self.cfg = cfg
 | 
					        self.cfg = cfg
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -92,7 +93,7 @@ class YOLO:
 | 
				
			|||||||
        self.overrides = self.model.args
 | 
					        self.overrides = self.model.args
 | 
				
			||||||
        self._reset_ckpt_args(self.overrides)
 | 
					        self._reset_ckpt_args(self.overrides)
 | 
				
			||||||
        self.ModelClass, self.TrainerClass, self.ValidatorClass, self.PredictorClass = \
 | 
					        self.ModelClass, self.TrainerClass, self.ValidatorClass, self.PredictorClass = \
 | 
				
			||||||
            self._guess_ops_from_task(self.task)
 | 
					            self._assign_ops_from_task(self.task)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def reset(self):
 | 
					    def reset(self):
 | 
				
			||||||
        """
 | 
					        """
 | 
				
			||||||
@ -217,7 +218,7 @@ class YOLO:
 | 
				
			|||||||
        """
 | 
					        """
 | 
				
			||||||
        self.model.to(device)
 | 
					        self.model.to(device)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def _guess_ops_from_task(self, task):
 | 
					    def _assign_ops_from_task(self, task):
 | 
				
			||||||
        model_class, train_lit, val_lit, pred_lit = MODEL_MAP[task]
 | 
					        model_class, train_lit, val_lit, pred_lit = MODEL_MAP[task]
 | 
				
			||||||
        # warning: eval is unsafe. Use with caution
 | 
					        # warning: eval is unsafe. Use with caution
 | 
				
			||||||
        trainer_class = eval(train_lit.replace("TYPE", f"{self.type}"))
 | 
					        trainer_class = eval(train_lit.replace("TYPE", f"{self.type}"))
 | 
				
			||||||
 | 
				
			|||||||
@ -1,6 +1,6 @@
 | 
				
			|||||||
# Ultralytics YOLO 🚀, GPL-3.0 license
 | 
					# Ultralytics YOLO 🚀, GPL-3.0 license
 | 
				
			||||||
"""
 | 
					"""
 | 
				
			||||||
Auto-batch utils
 | 
					AutoBatch utils
 | 
				
			||||||
"""
 | 
					"""
 | 
				
			||||||
 | 
					
 | 
				
			||||||
from copy import deepcopy
 | 
					from copy import deepcopy
 | 
				
			||||||
 | 
				
			|||||||
@ -308,23 +308,6 @@ def strip_optimizer(f='best.pt', s=''):
 | 
				
			|||||||
    LOGGER.info(f"Optimizer stripped from {f},{f' saved as {s},' if s else ''} {mb:.1f}MB")
 | 
					    LOGGER.info(f"Optimizer stripped from {f},{f' saved as {s},' if s else ''} {mb:.1f}MB")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def guess_task_from_model_yaml(model):
 | 
					 | 
				
			||||||
    try:
 | 
					 | 
				
			||||||
        cfg = model if isinstance(model, dict) else model.yaml  # model cfg dict
 | 
					 | 
				
			||||||
        m = cfg["head"][-1][-2].lower()  # output module name
 | 
					 | 
				
			||||||
        task = None
 | 
					 | 
				
			||||||
        if m in ["classify", "classifier", "cls", "fc"]:
 | 
					 | 
				
			||||||
            task = "classify"
 | 
					 | 
				
			||||||
        if m in ["detect"]:
 | 
					 | 
				
			||||||
            task = "detect"
 | 
					 | 
				
			||||||
        if m in ["segment"]:
 | 
					 | 
				
			||||||
            task = "segment"
 | 
					 | 
				
			||||||
    except Exception as e:
 | 
					 | 
				
			||||||
        raise SyntaxError('Unknown task. Define task explicitly, i.e. task=detect when running your command. '
 | 
					 | 
				
			||||||
                          'Valid tasks are detect, segment, classify.') from e
 | 
					 | 
				
			||||||
    return task
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
def profile(input, ops, n=10, device=None):
 | 
					def profile(input, ops, n=10, device=None):
 | 
				
			||||||
    """ YOLOv8 speed/memory/FLOPs profiler
 | 
					    """ YOLOv8 speed/memory/FLOPs profiler
 | 
				
			||||||
    Usage:
 | 
					    Usage:
 | 
				
			||||||
 | 
				
			|||||||
		Reference in New Issue
	
	Block a user