`ultralytics 8.0.143` add `Model` base class (#3934)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
single_channel
Laughing 1 year ago committed by GitHub
parent 3c787eb080
commit 1a0eb3f099
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -3,7 +3,7 @@ description: Explore the detailed guide on using the Ultralytics YOLO Engine Mod
keywords: Ultralytics, YOLO, engine model, documentation, guide, implementation, training, evaluation keywords: Ultralytics, YOLO, engine model, documentation, guide, implementation, training, evaluation
--- ---
## YOLO ## Model
--- ---
### ::: ultralytics.engine.model.YOLO ### ::: ultralytics.engine.model.Model
<br><br> <br><br>

@ -0,0 +1,9 @@
---
description: Discover the Ultralytics YOLO model class. Learn advanced techniques, tips, and tricks for training.
keywords: Ultralytics YOLO, YOLO, YOLO model, Model Training, Machine Learning, Deep Learning, Computer Vision
---
## YOLO
---
### ::: ultralytics.models.yolo.model.YOLO
<br><br>

@ -317,6 +317,7 @@ nav:
- predict: reference/models/yolo/detect/predict.md - predict: reference/models/yolo/detect/predict.md
- train: reference/models/yolo/detect/train.md - train: reference/models/yolo/detect/train.md
- val: reference/models/yolo/detect/val.md - val: reference/models/yolo/detect/val.md
- model: reference/models/yolo/model.md
- pose: - pose:
- predict: reference/models/yolo/pose/predict.md - predict: reference/models/yolo/pose/predict.md
- train: reference/models/yolo/pose/train.md - train: reference/models/yolo/pose/train.md

@ -1,10 +1,9 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license # Ultralytics YOLO 🚀, AGPL-3.0 license
__version__ = '8.0.142' __version__ = '8.0.143'
from ultralytics.engine.model import YOLO
from ultralytics.hub import start from ultralytics.hub import start
from ultralytics.models import RTDETR, SAM from ultralytics.models import RTDETR, SAM, YOLO
from ultralytics.models.fastsam import FastSAM from ultralytics.models.fastsam import FastSAM
from ultralytics.models.nas import NAS from ultralytics.models.nas import NAS
from ultralytics.utils import SETTINGS as settings from ultralytics.utils import SETTINGS as settings

@ -1,36 +1,23 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license # Ultralytics YOLO 🚀, AGPL-3.0 license
import inspect
import sys import sys
from pathlib import Path from pathlib import Path
from typing import Union from typing import Union
from ultralytics.cfg import get_cfg from ultralytics.cfg import get_cfg
from ultralytics.engine.exporter import Exporter from ultralytics.engine.exporter import Exporter
from ultralytics.models import yolo # noqa from ultralytics.nn.tasks import attempt_load_one_weight, guess_model_task, nn, yaml_model_load
from ultralytics.nn.tasks import (ClassificationModel, DetectionModel, PoseModel, SegmentationModel,
attempt_load_one_weight, guess_model_task, nn, yaml_model_load)
from ultralytics.utils import (DEFAULT_CFG, DEFAULT_CFG_DICT, DEFAULT_CFG_KEYS, LOGGER, RANK, ROOT, callbacks, from ultralytics.utils import (DEFAULT_CFG, DEFAULT_CFG_DICT, DEFAULT_CFG_KEYS, LOGGER, RANK, ROOT, callbacks,
is_git_dir, yaml_load) is_git_dir, yaml_load)
from ultralytics.utils.checks import check_file, check_imgsz, check_pip_update_available, check_yaml from ultralytics.utils.checks import check_file, check_imgsz, check_pip_update_available, check_yaml
from ultralytics.utils.downloads import GITHUB_ASSET_STEMS from ultralytics.utils.downloads import GITHUB_ASSET_STEMS
from ultralytics.utils.torch_utils import smart_inference_mode from ultralytics.utils.torch_utils import smart_inference_mode
# Map head to model, trainer, validator, and predictor classes
TASK_MAP = {
'classify': [
ClassificationModel, yolo.classify.ClassificationTrainer, yolo.classify.ClassificationValidator,
yolo.classify.ClassificationPredictor],
'detect':
[DetectionModel, yolo.detect.DetectionTrainer, yolo.detect.DetectionValidator, yolo.detect.DetectionPredictor],
'segment': [
SegmentationModel, yolo.segment.SegmentationTrainer, yolo.segment.SegmentationValidator,
yolo.segment.SegmentationPredictor],
'pose': [PoseModel, yolo.pose.PoseTrainer, yolo.pose.PoseValidator, yolo.pose.PosePredictor]}
class Model:
class YOLO:
""" """
YOLO (You Only Look Once) object detection model. A base model class to unify apis for all the models.
Args: Args:
model (str, Path): Path to the model file to load or create. model (str, Path): Path to the model file to load or create.
@ -81,13 +68,13 @@ class YOLO:
self.predictor = None # reuse predictor self.predictor = None # reuse predictor
self.model = None # model object self.model = None # model object
self.trainer = None # trainer object self.trainer = None # trainer object
self.task = None # task type
self.ckpt = None # if loaded from *.pt self.ckpt = None # if loaded from *.pt
self.cfg = None # if loaded from *.yaml self.cfg = None # if loaded from *.yaml
self.ckpt_path = None self.ckpt_path = None
self.overrides = {} # overrides for trainer object self.overrides = {} # overrides for trainer object
self.metrics = None # validation/training metrics self.metrics = None # validation/training metrics
self.session = None # HUB session self.session = None # HUB session
self.task = task # task type
model = str(model).strip() # strip spaces model = str(model).strip() # strip spaces
# Check if Ultralytics HUB model from https://hub.ultralytics.com # Check if Ultralytics HUB model from https://hub.ultralytics.com
@ -109,11 +96,6 @@ class YOLO:
"""Calls the 'predict' function with given arguments to perform object detection.""" """Calls the 'predict' function with given arguments to perform object detection."""
return self.predict(source, stream, **kwargs) return self.predict(source, stream, **kwargs)
def __getattr__(self, attr):
"""Raises error if object has no requested attribute."""
name = self.__class__.__name__
raise AttributeError(f"'{name}' object has no attribute '{attr}'. See valid attributes below.\n{self.__doc__}")
@staticmethod @staticmethod
def is_hub_model(model): def is_hub_model(model):
"""Check if the provided model is a HUB model.""" """Check if the provided model is a HUB model."""
@ -122,19 +104,21 @@ class YOLO:
[len(x) for x in model.split('_')] == [42, 20], # APIKEY_MODELID [len(x) for x in model.split('_')] == [42, 20], # APIKEY_MODELID
len(model) == 20 and not Path(model).exists() and all(x not in model for x in './\\'))) # MODELID len(model) == 20 and not Path(model).exists() and all(x not in model for x in './\\'))) # MODELID
def _new(self, cfg: str, task=None, verbose=True): def _new(self, cfg: str, task=None, model=None, verbose=True):
""" """
Initializes a new model and infers the task type from the model definitions. Initializes a new model and infers the task type from the model definitions.
Args: Args:
cfg (str): model configuration file cfg (str): model configuration file
task (str | None): model task task (str | None): model task
model (BaseModel): Customized model.
verbose (bool): display model info on load verbose (bool): display model info on load
""" """
cfg_dict = yaml_model_load(cfg) cfg_dict = yaml_model_load(cfg)
self.cfg = cfg self.cfg = cfg
self.task = task or guess_model_task(cfg_dict) self.task = task or guess_model_task(cfg_dict)
self.model = TASK_MAP[self.task][0](cfg_dict, verbose=verbose and RANK == -1) # build model model = model or self.smart_load('model')
self.model = model(cfg_dict, verbose=verbose and RANK == -1) # build model
self.overrides['model'] = self.cfg self.overrides['model'] = self.cfg
# Below added to allow export from yamls # Below added to allow export from yamls
@ -217,7 +201,7 @@ class YOLO:
self.model.fuse() self.model.fuse()
@smart_inference_mode() @smart_inference_mode()
def predict(self, source=None, stream=False, **kwargs): def predict(self, source=None, stream=False, predictor=None, **kwargs):
""" """
Perform prediction using the YOLO model. Perform prediction using the YOLO model.
@ -225,6 +209,7 @@ class YOLO:
source (str | int | PIL | np.ndarray): The source of the image to make predictions on. source (str | int | PIL | np.ndarray): The source of the image to make predictions on.
Accepts all source types accepted by the YOLO model. Accepts all source types accepted by the YOLO model.
stream (bool): Whether to stream the predictions or not. Defaults to False. stream (bool): Whether to stream the predictions or not. Defaults to False.
predictor (BasePredictor): Customized predictor.
**kwargs : Additional keyword arguments passed to the predictor. **kwargs : Additional keyword arguments passed to the predictor.
Check the 'configuration' section in the documentation for all available options. Check the 'configuration' section in the documentation for all available options.
@ -236,6 +221,8 @@ class YOLO:
LOGGER.warning(f"WARNING ⚠️ 'source' is missing. Using 'source={source}'.") LOGGER.warning(f"WARNING ⚠️ 'source' is missing. Using 'source={source}'.")
is_cli = (sys.argv[0].endswith('yolo') or sys.argv[0].endswith('ultralytics')) and any( is_cli = (sys.argv[0].endswith('yolo') or sys.argv[0].endswith('ultralytics')) and any(
x in sys.argv for x in ('predict', 'track', 'mode=predict', 'mode=track')) x in sys.argv for x in ('predict', 'track', 'mode=predict', 'mode=track'))
# Check prompts for SAM/FastSAM
prompts = kwargs.pop('prompts', None)
overrides = self.overrides.copy() overrides = self.overrides.copy()
overrides['conf'] = 0.25 overrides['conf'] = 0.25
overrides.update(kwargs) # prefer kwargs overrides.update(kwargs) # prefer kwargs
@ -245,12 +232,16 @@ class YOLO:
overrides['save'] = kwargs.get('save', False) # do not save by default if called in Python overrides['save'] = kwargs.get('save', False) # do not save by default if called in Python
if not self.predictor: if not self.predictor:
self.task = overrides.get('task') or self.task self.task = overrides.get('task') or self.task
self.predictor = TASK_MAP[self.task][3](overrides=overrides, _callbacks=self.callbacks) predictor = predictor or self.smart_load('predictor')
self.predictor = predictor(overrides=overrides, _callbacks=self.callbacks)
self.predictor.setup_model(model=self.model, verbose=is_cli) self.predictor.setup_model(model=self.model, verbose=is_cli)
else: # only update args if predictor is already setup else: # only update args if predictor is already setup
self.predictor.args = get_cfg(self.predictor.args, overrides) self.predictor.args = get_cfg(self.predictor.args, overrides)
if 'project' in overrides or 'name' in overrides: if 'project' in overrides or 'name' in overrides:
self.predictor.save_dir = self.predictor.get_save_dir() self.predictor.save_dir = self.predictor.get_save_dir()
# Set prompts for SAM/FastSAM.
# The guard must test the popped `prompts` value: the previous `if len and ...` tested the
# builtin `len`, which is always truthy, so `set_prompts(None)` ran on every call.
if prompts is not None and hasattr(self.predictor, 'set_prompts'):
    self.predictor.set_prompts(prompts)
return self.predictor.predict_cli(source=source) if is_cli else self.predictor(source=source, stream=stream) return self.predictor.predict_cli(source=source) if is_cli else self.predictor(source=source, stream=stream)
def track(self, source=None, stream=False, persist=False, **kwargs): def track(self, source=None, stream=False, persist=False, **kwargs):
@ -277,12 +268,13 @@ class YOLO:
return self.predict(source=source, stream=stream, **kwargs) return self.predict(source=source, stream=stream, **kwargs)
@smart_inference_mode() @smart_inference_mode()
def val(self, data=None, **kwargs): def val(self, data=None, validator=None, **kwargs):
""" """
Validate a model on a given dataset. Validate a model on a given dataset.
Args: Args:
data (str): The dataset to validate on. Accepts all formats accepted by yolo data (str): The dataset to validate on. Accepts all formats accepted by yolo
validator (BaseValidator): Customized validator.
**kwargs : Any other args accepted by the validators. To see all args check 'configuration' section in docs **kwargs : Any other args accepted by the validators. To see all args check 'configuration' section in docs
""" """
overrides = self.overrides.copy() overrides = self.overrides.copy()
@ -295,11 +287,12 @@ class YOLO:
self.task = args.task self.task = args.task
else: else:
args.task = self.task args.task = self.task
validator = validator or self.smart_load('validator')
if args.imgsz == DEFAULT_CFG.imgsz and not isinstance(self.model, (str, Path)): if args.imgsz == DEFAULT_CFG.imgsz and not isinstance(self.model, (str, Path)):
args.imgsz = self.model.args['imgsz'] # use trained imgsz unless custom value is passed args.imgsz = self.model.args['imgsz'] # use trained imgsz unless custom value is passed
args.imgsz = check_imgsz(args.imgsz, max_dim=1) args.imgsz = check_imgsz(args.imgsz, max_dim=1)
validator = TASK_MAP[self.task][2](args=args, _callbacks=self.callbacks) validator = validator(args=args, _callbacks=self.callbacks)
validator(model=self.model) validator(model=self.model)
self.metrics = validator.metrics self.metrics = validator.metrics
@ -349,11 +342,12 @@ class YOLO:
args.task = self.task args.task = self.task
return Exporter(overrides=args, _callbacks=self.callbacks)(model=self.model) return Exporter(overrides=args, _callbacks=self.callbacks)(model=self.model)
def train(self, **kwargs): def train(self, trainer=None, **kwargs):
""" """
Trains the model on a given dataset. Trains the model on a given dataset.
Args: Args:
trainer (BaseTrainer, optional): Customized trainer.
**kwargs (Any): Any number of arguments representing the training configuration. **kwargs (Any): Any number of arguments representing the training configuration.
""" """
self._check_is_pytorch_model() self._check_is_pytorch_model()
@ -373,7 +367,8 @@ class YOLO:
if overrides.get('resume'): if overrides.get('resume'):
overrides['resume'] = self.ckpt_path overrides['resume'] = self.ckpt_path
self.task = overrides.get('task') or self.task self.task = overrides.get('task') or self.task
self.trainer = TASK_MAP[self.task][1](overrides=overrides, _callbacks=self.callbacks) trainer = trainer or self.smart_load('trainer')
self.trainer = trainer(overrides=overrides, _callbacks=self.callbacks)
if not overrides.get('resume'): # manually set model only if not resuming if not overrides.get('resume'): # manually set model only if not resuming
self.trainer.model = self.trainer.get_model(weights=self.model if self.ckpt else None, cfg=self.model.yaml) self.trainer.model = self.trainer.get_model(weights=self.model if self.ckpt else None, cfg=self.model.yaml)
self.model = self.trainer.model self.model = self.trainer.model
@ -442,3 +437,27 @@ class YOLO:
"""Reset all registered callbacks.""" """Reset all registered callbacks."""
for event in callbacks.default_callbacks.keys(): for event in callbacks.default_callbacks.keys():
self.callbacks[event] = [callbacks.default_callbacks[event][0]] self.callbacks[event] = [callbacks.default_callbacks[event][0]]
def __getattr__(self, attr):
"""
Raise an AttributeError for an attribute that normal lookup did not find.

The error message names the concrete class of the instance and appends that class's docstring.

Args:
attr (str): Name of the requested attribute.

Raises:
AttributeError: Always.
"""
name = self.__class__.__name__
raise AttributeError(f"'{name}' object has no attribute '{attr}'. See valid attributes below.\n{self.__doc__}")
def smart_load(self, key):
    """
    Return the class registered in `self.task_map` for the current task and the given component.

    Args:
        key (str): Component to look up, e.g. 'model', 'trainer', 'validator' or 'predictor'.

    Returns:
        The entry stored at `self.task_map[self.task][key]`.

    Raises:
        NotImplementedError: If the lookup fails for any reason (unknown task, component not provided
            for this task, or `task_map` not implemented). The original error is chained as `__cause__`.
    """
    try:
        return self.task_map[self.task][key]
    except Exception as e:
        name = self.__class__.__name__
        # Name of the calling function (e.g. 'train', 'val', 'predict') so the message reports the mode.
        mode = inspect.stack()[1].function
        raise NotImplementedError(
            f'WARNING ⚠️ `{name}` model does not support `{mode}` mode for `{self.task}` task yet.') from e
@property
def task_map(self):
"""Map head to model, trainer, validator, and predictor classes.

The base implementation always raises; subclasses override this property and return a dict keyed
by task name whose values are dicts keyed by component name ('model', 'trainer', 'validator',
'predictor'), as the FastSAM, NAS and RTDETR subclasses do.

Returns:
task_map (dict)

Raises:
NotImplementedError: Always, in this base implementation.
"""
raise NotImplementedError('Please provide task map for your model!')

@ -1,4 +1,7 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
from .rtdetr import RTDETR from .rtdetr import RTDETR
from .sam import SAM from .sam import SAM
from .yolo import YOLO
__all__ = 'RTDETR', 'SAM' # allow simpler import __all__ = 'YOLO', 'RTDETR', 'SAM' # allow simpler import

@ -1,4 +1,14 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license # Ultralytics YOLO 🚀, AGPL-3.0 license
from pathlib import Path
from ultralytics.engine.model import Model
from .predict import FastSAMPredictor
from .val import FastSAMValidator
class FastSAM(Model):
""" """
FastSAM model interface. FastSAM model interface.
@ -9,103 +19,13 @@ Usage - Predict:
results = model.predict('ultralytics/assets/bus.jpg') results = model.predict('ultralytics/assets/bus.jpg')
""" """
from ultralytics.cfg import get_cfg
from ultralytics.engine.exporter import Exporter
from ultralytics.engine.model import YOLO
from ultralytics.utils import DEFAULT_CFG, LOGGER, ROOT, is_git_dir
from ultralytics.utils.checks import check_imgsz
from ultralytics.utils.torch_utils import model_info, smart_inference_mode
from .predict import FastSAMPredictor
class FastSAM(YOLO):
def __init__(self, model='FastSAM-x.pt'): def __init__(self, model='FastSAM-x.pt'):
"""Call the __init__ method of the parent class (YOLO) with the updated default model""" """Call the __init__ method of the parent class (YOLO) with the updated default model"""
if model == 'FastSAM.pt': if model == 'FastSAM.pt':
model = 'FastSAM-x.pt' model = 'FastSAM-x.pt'
super().__init__(model=model) assert Path(model).suffix != '.yaml', 'FastSAM models only support pre-trained models.'
# any additional initialization code for FastSAM super().__init__(model=model, task='segment')
@smart_inference_mode()
def predict(self, source=None, stream=False, **kwargs):
"""
Perform prediction using the YOLO model.
Args:
source (str | int | PIL | np.ndarray): The source of the image to make predictions on.
Accepts all source types accepted by the YOLO model.
stream (bool): Whether to stream the predictions or not. Defaults to False.
**kwargs : Additional keyword arguments passed to the predictor.
Check the 'configuration' section in the documentation for all available options.
Returns:
(List[ultralytics.engine.results.Results]): The prediction results.
"""
if source is None:
source = ROOT / 'assets' if is_git_dir() else 'https://ultralytics.com/images/bus.jpg'
LOGGER.warning(f"WARNING ⚠️ 'source' is missing. Using 'source={source}'.")
overrides = self.overrides.copy()
overrides['conf'] = 0.25
overrides.update(kwargs) # prefer kwargs
overrides['mode'] = kwargs.get('mode', 'predict')
assert overrides['mode'] in ['track', 'predict']
overrides['save'] = kwargs.get('save', False) # do not save by default if called in Python
self.predictor = FastSAMPredictor(overrides=overrides)
self.predictor.setup_model(model=self.model, verbose=False)
return self.predictor(source, stream=stream)
def train(self, **kwargs):
"""Function trains models but raises an error as FastSAM models do not support training."""
raise NotImplementedError("FastSAM models don't support training")
def val(self, **kwargs):
"""Run validation given dataset."""
overrides = dict(task='segment', mode='val')
overrides.update(kwargs) # prefer kwargs
args = get_cfg(cfg=DEFAULT_CFG, overrides=overrides)
args.imgsz = check_imgsz(args.imgsz, max_dim=1)
validator = FastSAM(args=args)
validator(model=self.model)
self.metrics = validator.metrics
return validator.metrics
@smart_inference_mode()
def export(self, **kwargs):
"""
Export model.
Args:
**kwargs : Any other args accepted by the predictors. To see all args check 'configuration' section in docs
"""
overrides = dict(task='detect')
overrides.update(kwargs)
overrides['mode'] = 'export'
args = get_cfg(cfg=DEFAULT_CFG, overrides=overrides)
args.task = self.task
if args.imgsz == DEFAULT_CFG.imgsz:
args.imgsz = self.model.args['imgsz'] # use trained imgsz unless custom value is passed
if args.batch == DEFAULT_CFG.batch:
args.batch = 1 # default to 1 if not modified
return Exporter(overrides=args)(model=self.model)
def info(self, detailed=False, verbose=True):
"""
Logs model info.
Args:
detailed (bool): Show detailed information about model.
verbose (bool): Controls verbosity.
"""
return model_info(self.model, detailed=detailed, verbose=verbose, imgsz=640)
def __call__(self, source=None, stream=False, **kwargs):
"""Calls the 'predict' function with given arguments to perform object detection."""
return self.predict(source, stream, **kwargs)
def __getattr__(self, attr): @property
"""Raises error if object has no requested attribute.""" def task_map(self):
name = self.__class__.__name__ return {'segment': {'predictor': FastSAMPredictor, 'validator': FastSAMValidator}}
raise AttributeError(f"'{name}' object has no attribute '{attr}'. See valid attributes below.\n{self.__doc__}")

@ -13,105 +13,36 @@ from pathlib import Path
import torch import torch
from ultralytics.cfg import get_cfg from ultralytics.engine.model import Model
from ultralytics.engine.exporter import Exporter
from ultralytics.utils import DEFAULT_CFG, DEFAULT_CFG_DICT, LOGGER, ROOT, is_git_dir
from ultralytics.utils.checks import check_imgsz
from ultralytics.utils.torch_utils import model_info, smart_inference_mode from ultralytics.utils.torch_utils import model_info, smart_inference_mode
from .predict import NASPredictor from .predict import NASPredictor
from .val import NASValidator from .val import NASValidator
class NAS: class NAS(Model):
def __init__(self, model='yolo_nas_s.pt') -> None: def __init__(self, model='yolo_nas_s.pt') -> None:
assert Path(model).suffix != '.yaml', 'YOLO-NAS models only support pre-trained models.'
super().__init__(model, task='detect')
@smart_inference_mode()
def _load(self, weights: str, task: str):
# Load or create new NAS model # Load or create new NAS model
import super_gradients import super_gradients
suffix = Path(weights).suffix
self.predictor = None
suffix = Path(model).suffix
if suffix == '.pt': if suffix == '.pt':
self._load(model) self.model = torch.load(weights)
elif suffix == '': elif suffix == '':
self.model = super_gradients.training.models.get(model, pretrained_weights='coco') self.model = super_gradients.training.models.get(weights, pretrained_weights='coco')
self.task = 'detect'
self.model.args = DEFAULT_CFG_DICT # attach args to model
# Standardize model # Standardize model
self.model.fuse = lambda verbose=True: self.model self.model.fuse = lambda verbose=True: self.model
self.model.stride = torch.tensor([32]) self.model.stride = torch.tensor([32])
self.model.names = dict(enumerate(self.model._class_names)) self.model.names = dict(enumerate(self.model._class_names))
self.model.is_fused = lambda: False # for info() self.model.is_fused = lambda: False # for info()
self.model.yaml = {} # for info() self.model.yaml = {} # for info()
self.model.pt_path = model # for export() self.model.pt_path = weights # for export()
self.model.task = 'detect' # for export() self.model.task = 'detect' # for export()
self.info()
@smart_inference_mode()
def _load(self, weights: str):
self.model = torch.load(weights)
@smart_inference_mode()
def predict(self, source=None, stream=False, **kwargs):
"""
Perform prediction using the YOLO model.
Args:
source (str | int | PIL | np.ndarray): The source of the image to make predictions on.
Accepts all source types accepted by the YOLO model.
stream (bool): Whether to stream the predictions or not. Defaults to False.
**kwargs : Additional keyword arguments passed to the predictor.
Check the 'configuration' section in the documentation for all available options.
Returns:
(List[ultralytics.engine.results.Results]): The prediction results.
"""
if source is None:
source = ROOT / 'assets' if is_git_dir() else 'https://ultralytics.com/images/bus.jpg'
LOGGER.warning(f"WARNING ⚠️ 'source' is missing. Using 'source={source}'.")
overrides = dict(conf=0.25, task='detect', mode='predict')
overrides.update(kwargs) # prefer kwargs
if not self.predictor:
self.predictor = NASPredictor(overrides=overrides)
self.predictor.setup_model(model=self.model)
else: # only update args if predictor is already setup
self.predictor.args = get_cfg(self.predictor.args, overrides)
return self.predictor(source, stream=stream)
def train(self, **kwargs):
"""Function trains models but raises an error as NAS models do not support training."""
raise NotImplementedError("NAS models don't support training")
def val(self, **kwargs):
"""Run validation given dataset."""
overrides = dict(task='detect', mode='val')
overrides.update(kwargs) # prefer kwargs
args = get_cfg(cfg=DEFAULT_CFG, overrides=overrides)
args.imgsz = check_imgsz(args.imgsz, max_dim=1)
validator = NASValidator(args=args)
validator(model=self.model)
self.metrics = validator.metrics
return validator.metrics
@smart_inference_mode()
def export(self, **kwargs):
"""
Export model.
Args:
**kwargs : Any other args accepted by the predictors. To see all args check 'configuration' section in docs
"""
overrides = dict(task='detect')
overrides.update(kwargs)
overrides['mode'] = 'export'
args = get_cfg(cfg=DEFAULT_CFG, overrides=overrides)
args.task = self.task
if args.imgsz == DEFAULT_CFG.imgsz:
args.imgsz = self.model.args['imgsz'] # use trained imgsz unless custom value is passed
if args.batch == DEFAULT_CFG.batch:
args.batch = 1 # default to 1 if not modified
return Exporter(overrides=args)(model=self.model)
def info(self, detailed=False, verbose=True): def info(self, detailed=False, verbose=True):
""" """
@ -123,11 +54,6 @@ class NAS:
""" """
return model_info(self.model, detailed=detailed, verbose=verbose, imgsz=640) return model_info(self.model, detailed=detailed, verbose=verbose, imgsz=640)
def __call__(self, source=None, stream=False, **kwargs): @property
"""Calls the 'predict' function with given arguments to perform object detection.""" def task_map(self):
return self.predict(source, stream, **kwargs) return {'detect': {'predictor': NASPredictor, 'validator': NASValidator}}
def __getattr__(self, attr):
"""Raises error if object has no requested attribute."""
name = self.__class__.__name__
raise AttributeError(f"'{name}' object has no attribute '{attr}'. See valid attributes below.\n{self.__doc__}")

@ -2,172 +2,29 @@
""" """
RT-DETR model interface RT-DETR model interface
""" """
from ultralytics.engine.model import Model
from pathlib import Path from ultralytics.nn.tasks import RTDETRDetectionModel
import torch.nn as nn
from ultralytics.cfg import get_cfg
from ultralytics.engine.exporter import Exporter
from ultralytics.nn.tasks import RTDETRDetectionModel, attempt_load_one_weight, yaml_model_load
from ultralytics.utils import DEFAULT_CFG, DEFAULT_CFG_DICT, LOGGER, RANK, ROOT, is_git_dir
from ultralytics.utils.checks import check_imgsz
from ultralytics.utils.torch_utils import model_info, smart_inference_mode
from .predict import RTDETRPredictor from .predict import RTDETRPredictor
from .train import RTDETRTrainer from .train import RTDETRTrainer
from .val import RTDETRValidator from .val import RTDETRValidator
class RTDETR: class RTDETR(Model):
"""
RTDETR model interface.
"""
def __init__(self, model='rtdetr-l.pt') -> None: def __init__(self, model='rtdetr-l.pt') -> None:
if model and not model.endswith('.pt') and not model.endswith('.yaml'): if model and not model.endswith('.pt') and not model.endswith('.yaml'):
raise NotImplementedError('RT-DETR only supports creating from pt file or yaml file.') raise NotImplementedError('RT-DETR only supports creating from pt file or yaml file.')
# Load or create new YOLO model super().__init__(model=model, task='detect')
self.predictor = None
self.ckpt = None @property
suffix = Path(model).suffix def task_map(self):
if suffix == '.yaml': return {
self._new(model) 'detect': {
else: 'predictor': RTDETRPredictor,
self._load(model) 'validator': RTDETRValidator,
'trainer': RTDETRTrainer,
def _new(self, cfg: str, verbose=True): 'model': RTDETRDetectionModel}}
cfg_dict = yaml_model_load(cfg)
self.cfg = cfg
self.task = 'detect'
self.model = RTDETRDetectionModel(cfg_dict, verbose=verbose) # build model
# Below added to allow export from YAMLs
self.model.args = DEFAULT_CFG_DICT # attach args to model
self.model.task = self.task
@smart_inference_mode()
def _load(self, weights: str):
self.model, self.ckpt = attempt_load_one_weight(weights)
self.model.args = DEFAULT_CFG_DICT # attach args to model
self.task = self.model.args['task']
@smart_inference_mode()
def load(self, weights='yolov8n.pt'):
"""
Transfers parameters with matching names and shapes from 'weights' to model.
"""
if isinstance(weights, (str, Path)):
weights, self.ckpt = attempt_load_one_weight(weights)
self.model.load(weights)
return self
@smart_inference_mode()
def predict(self, source=None, stream=False, **kwargs):
"""
Perform prediction using the YOLO model.
Args:
source (str | int | PIL | np.ndarray): The source of the image to make predictions on.
Accepts all source types accepted by the YOLO model.
stream (bool): Whether to stream the predictions or not. Defaults to False.
**kwargs : Additional keyword arguments passed to the predictor.
Check the 'configuration' section in the documentation for all available options.
Returns:
(List[ultralytics.engine.results.Results]): The prediction results.
"""
if source is None:
source = ROOT / 'assets' if is_git_dir() else 'https://ultralytics.com/images/bus.jpg'
LOGGER.warning(f"WARNING ⚠️ 'source' is missing. Using 'source={source}'.")
overrides = dict(conf=0.25, task='detect', mode='predict')
overrides.update(kwargs) # prefer kwargs
if not self.predictor:
self.predictor = RTDETRPredictor(overrides=overrides)
self.predictor.setup_model(model=self.model)
else: # only update args if predictor is already setup
self.predictor.args = get_cfg(self.predictor.args, overrides)
return self.predictor(source, stream=stream)
def train(self, **kwargs):
"""
Trains the model on a given dataset.
Args:
**kwargs (Any): Any number of arguments representing the training configuration.
"""
overrides = dict(task='detect', mode='train')
overrides.update(kwargs)
overrides['deterministic'] = False
if not overrides.get('data'):
raise AttributeError("Dataset required but missing, i.e. pass 'data=coco128.yaml'")
if overrides.get('resume'):
overrides['resume'] = self.ckpt_path
self.task = overrides.get('task') or self.task
self.trainer = RTDETRTrainer(overrides=overrides)
if not overrides.get('resume'): # manually set model only if not resuming
self.trainer.model = self.trainer.get_model(weights=self.model if self.ckpt else None, cfg=self.model.yaml)
self.model = self.trainer.model
self.trainer.train()
# Update model and cfg after training
if RANK in (-1, 0):
self.model, _ = attempt_load_one_weight(str(self.trainer.best))
self.overrides = self.model.args
self.metrics = getattr(self.trainer.validator, 'metrics', None) # TODO: no metrics returned by DDP
def val(self, **kwargs):
"""Run validation given dataset."""
overrides = dict(task='detect', mode='val')
overrides.update(kwargs) # prefer kwargs
args = get_cfg(cfg=DEFAULT_CFG, overrides=overrides)
args.imgsz = check_imgsz(args.imgsz, max_dim=1)
validator = RTDETRValidator(args=args)
validator(model=self.model)
self.metrics = validator.metrics
return validator.metrics
def info(self, verbose=True):
"""Get model info"""
return model_info(self.model, verbose=verbose)
def _check_is_pytorch_model(self):
"""
Raises TypeError is model is not a PyTorch model
"""
pt_str = isinstance(self.model, (str, Path)) and Path(self.model).suffix == '.pt'
pt_module = isinstance(self.model, nn.Module)
if not (pt_module or pt_str):
raise TypeError(f"model='{self.model}' must be a *.pt PyTorch model, but is a different type. "
f'PyTorch models can be used to train, val, predict and export, i.e. '
f"'yolo export model=yolov8n.pt', but exported formats like ONNX, TensorRT etc. only "
f"support 'predict' and 'val' modes, i.e. 'yolo predict model=yolov8n.onnx'.")
def fuse(self):
"""Fuse PyTorch Conv2d and BatchNorm2d layers."""
self._check_is_pytorch_model()
self.model.fuse()
@smart_inference_mode()
def export(self, **kwargs):
"""
Export model.
Args:
**kwargs : Any other args accepted by the predictors. To see all args check 'configuration' section in docs
"""
overrides = dict(task='detect')
overrides.update(kwargs)
overrides['mode'] = 'export'
args = get_cfg(cfg=DEFAULT_CFG, overrides=overrides)
args.task = self.task
if args.imgsz == DEFAULT_CFG.imgsz:
args.imgsz = self.model.args['imgsz'] # use trained imgsz unless custom value is passed
if args.batch == DEFAULT_CFG.batch:
args.batch = 1 # default to 1 if not modified
return Exporter(overrides=args)(model=self.model)
def __call__(self, source=None, stream=False, **kwargs):
    """
    Make the instance callable as a shortcut for ``predict``.

    Args:
        source: Passed positionally to ``self.predict``.
        stream: Passed positionally to ``self.predict``.
        **kwargs: Forwarded unchanged to ``self.predict``.

    Returns:
        Whatever ``self.predict`` returns.
    """
    predictions = self.predict(source, stream, **kwargs)
    return predictions
def __getattr__(self, attr):
    """
    Raise a descriptive ``AttributeError`` for the requested attribute.

    The message names the instance's class and the requested attribute, and
    appends the instance's ``__doc__``.

    Args:
        attr: Name of the attribute that was requested.

    Raises:
        AttributeError: Always.
    """
    name = self.__class__.__name__
    message = f"'{name}' object has no attribute '{attr}'. See valid attributes below.\n{self.__doc__}"
    raise AttributeError(message)

@ -3,51 +3,38 @@
SAM model interface SAM model interface
""" """
from ultralytics.cfg import get_cfg from ultralytics.engine.model import Model
from ultralytics.utils.torch_utils import model_info from ultralytics.utils.torch_utils import model_info
from .build import build_sam from .build import build_sam
from .predict import Predictor from .predict import Predictor
class SAM: class SAM(Model):
"""
SAM model interface.
"""
def __init__(self, model='sam_b.pt') -> None: def __init__(self, model='sam_b.pt') -> None:
if model and not model.endswith('.pt') and not model.endswith('.pth'): if model and not model.endswith('.pt') and not model.endswith('.pth'):
# Should raise AssertionError instead? # Should raise AssertionError instead?
raise NotImplementedError('Segment anything prediction requires pre-trained checkpoint') raise NotImplementedError('Segment anything prediction requires pre-trained checkpoint')
self.model = build_sam(model) super().__init__(model=model, task='segment')
self.task = 'segment' # required
self.predictor = None # reuse predictor def _load(self, weights: str, task=None):
self.model = build_sam(weights)
def predict(self, source, stream=False, bboxes=None, points=None, labels=None, **kwargs): def predict(self, source, stream=False, bboxes=None, points=None, labels=None, **kwargs):
"""Predicts and returns segmentation masks for given image or video source.""" """Predicts and returns segmentation masks for given image or video source."""
overrides = dict(conf=0.25, task='segment', mode='predict', imgsz=1024) overrides = dict(conf=0.25, task='segment', mode='predict', imgsz=1024)
overrides.update(kwargs) # prefer kwargs kwargs.update(overrides)
if not self.predictor: prompts = dict(bboxes=bboxes, points=points, labels=labels)
self.predictor = Predictor(overrides=overrides) super().predict(source, stream, prompts=prompts, **kwargs)
self.predictor.setup_model(model=self.model)
else: # only update args if predictor is already setup
self.predictor.args = get_cfg(self.predictor.args, overrides)
return self.predictor(source, stream=stream, bboxes=bboxes, points=points, labels=labels)
def train(self, **kwargs):
"""Function trains models but raises an error as SAM models do not support training."""
raise NotImplementedError("SAM models don't support training")
def val(self, **kwargs):
"""Run validation given dataset."""
raise NotImplementedError("SAM models don't support validation")
def __call__(self, source=None, stream=False, bboxes=None, points=None, labels=None, **kwargs): def __call__(self, source=None, stream=False, bboxes=None, points=None, labels=None, **kwargs):
"""Calls the 'predict' function with given arguments to perform object detection.""" """Calls the 'predict' function with given arguments to perform object detection."""
return self.predict(source, stream, bboxes, points, labels, **kwargs) return self.predict(source, stream, bboxes, points, labels, **kwargs)
def __getattr__(self, attr):
"""Raises error if object has no requested attribute."""
name = self.__class__.__name__
raise AttributeError(f"'{name}' object has no attribute '{attr}'. See valid attributes below.\n{self.__doc__}")
def info(self, detailed=False, verbose=True): def info(self, detailed=False, verbose=True):
""" """
Logs model info. Logs model info.
@ -57,3 +44,7 @@ class SAM:
verbose (bool): Controls verbosity. verbose (bool): Controls verbosity.
""" """
return model_info(self.model, detailed=detailed, verbose=verbose) return model_info(self.model, detailed=detailed, verbose=verbose)
@property
def task_map(self):
return {'segment': {'predictor': Predictor}}

@ -28,6 +28,8 @@ class Predictor(BasePredictor):
# Args for set_image # Args for set_image
self.im = None self.im = None
self.features = None self.features = None
# Args for set_prompts
self.prompts = {}
# Args for segment everything # Args for segment everything
self.segment_all = False self.segment_all = False
@ -92,6 +94,10 @@ class Predictor(BasePredictor):
of masks and H=W=256. These low resolution logits can be passed to of masks and H=W=256. These low resolution logits can be passed to
a subsequent iteration as mask input. a subsequent iteration as mask input.
""" """
# Get prompts from self.prompts first
bboxes = self.prompts.pop('bboxes', bboxes)
points = self.prompts.pop('points', points)
masks = self.prompts.pop('masks', masks)
if all(i is None for i in [bboxes, points, masks]): if all(i is None for i in [bboxes, points, masks]):
return self.generate(im, *args, **kwargs) return self.generate(im, *args, **kwargs)
return self.prompt_inference(im, bboxes, points, labels, masks, multimask_output) return self.prompt_inference(im, bboxes, points, labels, masks, multimask_output)
@ -348,6 +354,10 @@ class Predictor(BasePredictor):
self.im = im self.im = im
break break
def set_prompts(self, prompts):
"""Set prompts in advance."""
self.prompts = prompts
def reset_image(self): def reset_image(self):
self.im = None self.im = None
self.features = None self.features = None

@ -2,4 +2,6 @@
from ultralytics.models.yolo import classify, detect, pose, segment from ultralytics.models.yolo import classify, detect, pose, segment
__all__ = 'classify', 'segment', 'detect', 'pose' from .model import YOLO
__all__ = 'classify', 'segment', 'detect', 'pose', 'YOLO'

@ -1,4 +1,5 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license # Ultralytics YOLO 🚀, AGPL-3.0 license
from copy import copy from copy import copy
import numpy as np import numpy as np

@ -0,0 +1,36 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
from ultralytics.engine.model import Model
from ultralytics.models import yolo # noqa
from ultralytics.nn.tasks import ClassificationModel, DetectionModel, PoseModel, SegmentationModel
class YOLO(Model):
    """
    YOLO (You Only Look Once) object detection model.
    """

    @property
    def task_map(self):
        """
        Return the per-task lookup of model, trainer, validator and predictor classes.

        The mapping is keyed by task name ('classify', 'detect', 'segment',
        'pose'); each value is a dict with the keys 'model', 'trainer',
        'validator' and 'predictor'.
        """
        classify_entry = dict(
            model=ClassificationModel,
            trainer=yolo.classify.ClassificationTrainer,
            validator=yolo.classify.ClassificationValidator,
            predictor=yolo.classify.ClassificationPredictor,
        )
        detect_entry = dict(
            model=DetectionModel,
            trainer=yolo.detect.DetectionTrainer,
            validator=yolo.detect.DetectionValidator,
            predictor=yolo.detect.DetectionPredictor,
        )
        segment_entry = dict(
            model=SegmentationModel,
            trainer=yolo.segment.SegmentationTrainer,
            validator=yolo.segment.SegmentationValidator,
            predictor=yolo.segment.SegmentationPredictor,
        )
        pose_entry = dict(
            model=PoseModel,
            trainer=yolo.pose.PoseTrainer,
            validator=yolo.pose.PoseValidator,
            predictor=yolo.pose.PosePredictor,
        )
        return dict(
            classify=classify_entry,
            detect=detect_entry,
            segment=segment_entry,
            pose=pose_entry,
        )

@ -1,4 +1,5 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license # Ultralytics YOLO 🚀, AGPL-3.0 license
from copy import copy from copy import copy
from ultralytics.models import yolo from ultralytics.models import yolo

Loading…
Cancel
Save