`ultralytics 8.0.143` add `Model` base class (#3934)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
single_channel
Laughing 1 year ago committed by GitHub
parent 3c787eb080
commit 1a0eb3f099
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -3,7 +3,7 @@ description: Explore the detailed guide on using the Ultralytics YOLO Engine Mod
keywords: Ultralytics, YOLO, engine model, documentation, guide, implementation, training, evaluation
---
## YOLO
## Model
---
### ::: ultralytics.engine.model.YOLO
### ::: ultralytics.engine.model.Model
<br><br>

@ -0,0 +1,9 @@
---
description: Discover the Ultralytics YOLO model class. Learn advanced techniques, tips, and tricks for training.
keywords: Ultralytics YOLO, YOLO, YOLO model, Model Training, Machine Learning, Deep Learning, Computer Vision
---
## YOLO
---
### ::: ultralytics.models.yolo.model.YOLO
<br><br>

@ -317,6 +317,7 @@ nav:
- predict: reference/models/yolo/detect/predict.md
- train: reference/models/yolo/detect/train.md
- val: reference/models/yolo/detect/val.md
- model: reference/models/yolo/model.md
- pose:
- predict: reference/models/yolo/pose/predict.md
- train: reference/models/yolo/pose/train.md

@ -1,10 +1,9 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
__version__ = '8.0.142'
__version__ = '8.0.143'
from ultralytics.engine.model import YOLO
from ultralytics.hub import start
from ultralytics.models import RTDETR, SAM
from ultralytics.models import RTDETR, SAM, YOLO
from ultralytics.models.fastsam import FastSAM
from ultralytics.models.nas import NAS
from ultralytics.utils import SETTINGS as settings

@ -1,36 +1,23 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
import inspect
import sys
from pathlib import Path
from typing import Union
from ultralytics.cfg import get_cfg
from ultralytics.engine.exporter import Exporter
from ultralytics.models import yolo # noqa
from ultralytics.nn.tasks import (ClassificationModel, DetectionModel, PoseModel, SegmentationModel,
attempt_load_one_weight, guess_model_task, nn, yaml_model_load)
from ultralytics.nn.tasks import attempt_load_one_weight, guess_model_task, nn, yaml_model_load
from ultralytics.utils import (DEFAULT_CFG, DEFAULT_CFG_DICT, DEFAULT_CFG_KEYS, LOGGER, RANK, ROOT, callbacks,
is_git_dir, yaml_load)
from ultralytics.utils.checks import check_file, check_imgsz, check_pip_update_available, check_yaml
from ultralytics.utils.downloads import GITHUB_ASSET_STEMS
from ultralytics.utils.torch_utils import smart_inference_mode
# Map head to model, trainer, validator, and predictor classes
TASK_MAP = {
'classify': [
ClassificationModel, yolo.classify.ClassificationTrainer, yolo.classify.ClassificationValidator,
yolo.classify.ClassificationPredictor],
'detect':
[DetectionModel, yolo.detect.DetectionTrainer, yolo.detect.DetectionValidator, yolo.detect.DetectionPredictor],
'segment': [
SegmentationModel, yolo.segment.SegmentationTrainer, yolo.segment.SegmentationValidator,
yolo.segment.SegmentationPredictor],
'pose': [PoseModel, yolo.pose.PoseTrainer, yolo.pose.PoseValidator, yolo.pose.PosePredictor]}
class YOLO:
class Model:
"""
YOLO (You Only Look Once) object detection model.
A base model class to unify apis for all the models.
Args:
model (str, Path): Path to the model file to load or create.
@ -81,13 +68,13 @@ class YOLO:
self.predictor = None # reuse predictor
self.model = None # model object
self.trainer = None # trainer object
self.task = None # task type
self.ckpt = None # if loaded from *.pt
self.cfg = None # if loaded from *.yaml
self.ckpt_path = None
self.overrides = {} # overrides for trainer object
self.metrics = None # validation/training metrics
self.session = None # HUB session
self.task = task # task type
model = str(model).strip() # strip spaces
# Check if Ultralytics HUB model from https://hub.ultralytics.com
@ -109,11 +96,6 @@ class YOLO:
"""Calls the 'predict' function with given arguments to perform object detection."""
return self.predict(source, stream, **kwargs)
def __getattr__(self, attr):
"""Raises error if object has no requested attribute."""
name = self.__class__.__name__
raise AttributeError(f"'{name}' object has no attribute '{attr}'. See valid attributes below.\n{self.__doc__}")
@staticmethod
def is_hub_model(model):
"""Check if the provided model is a HUB model."""
@ -122,19 +104,21 @@ class YOLO:
[len(x) for x in model.split('_')] == [42, 20], # APIKEY_MODELID
len(model) == 20 and not Path(model).exists() and all(x not in model for x in './\\'))) # MODELID
def _new(self, cfg: str, task=None, verbose=True):
def _new(self, cfg: str, task=None, model=None, verbose=True):
"""
Initializes a new model and infers the task type from the model definitions.
Args:
cfg (str): model configuration file
task (str | None): model task
model (BaseModel): Customized model.
verbose (bool): display model info on load
"""
cfg_dict = yaml_model_load(cfg)
self.cfg = cfg
self.task = task or guess_model_task(cfg_dict)
self.model = TASK_MAP[self.task][0](cfg_dict, verbose=verbose and RANK == -1) # build model
model = model or self.smart_load('model')
self.model = model(cfg_dict, verbose=verbose and RANK == -1) # build model
self.overrides['model'] = self.cfg
# Below added to allow export from yamls
@ -217,7 +201,7 @@ class YOLO:
self.model.fuse()
@smart_inference_mode()
def predict(self, source=None, stream=False, **kwargs):
def predict(self, source=None, stream=False, predictor=None, **kwargs):
"""
Perform prediction using the YOLO model.
@ -225,6 +209,7 @@ class YOLO:
source (str | int | PIL | np.ndarray): The source of the image to make predictions on.
Accepts all source types accepted by the YOLO model.
stream (bool): Whether to stream the predictions or not. Defaults to False.
predictor (BasePredictor): Customized predictor.
**kwargs : Additional keyword arguments passed to the predictor.
Check the 'configuration' section in the documentation for all available options.
@ -236,6 +221,8 @@ class YOLO:
LOGGER.warning(f"WARNING ⚠️ 'source' is missing. Using 'source={source}'.")
is_cli = (sys.argv[0].endswith('yolo') or sys.argv[0].endswith('ultralytics')) and any(
x in sys.argv for x in ('predict', 'track', 'mode=predict', 'mode=track'))
# Check prompts for SAM/FastSAM
prompts = kwargs.pop('prompts', None)
overrides = self.overrides.copy()
overrides['conf'] = 0.25
overrides.update(kwargs) # prefer kwargs
@ -245,12 +232,16 @@ class YOLO:
overrides['save'] = kwargs.get('save', False) # do not save by default if called in Python
if not self.predictor:
self.task = overrides.get('task') or self.task
self.predictor = TASK_MAP[self.task][3](overrides=overrides, _callbacks=self.callbacks)
predictor = predictor or self.smart_load('predictor')
self.predictor = predictor(overrides=overrides, _callbacks=self.callbacks)
self.predictor.setup_model(model=self.model, verbose=is_cli)
else: # only update args if predictor is already setup
self.predictor.args = get_cfg(self.predictor.args, overrides)
if 'project' in overrides or 'name' in overrides:
self.predictor.save_dir = self.predictor.get_save_dir()
# Set prompts for SAM/FastSAM
if len and hasattr(self.predictor, 'set_prompts'):
self.predictor.set_prompts(prompts)
return self.predictor.predict_cli(source=source) if is_cli else self.predictor(source=source, stream=stream)
def track(self, source=None, stream=False, persist=False, **kwargs):
@ -277,12 +268,13 @@ class YOLO:
return self.predict(source=source, stream=stream, **kwargs)
@smart_inference_mode()
def val(self, data=None, **kwargs):
def val(self, data=None, validator=None, **kwargs):
"""
Validate a model on a given dataset.
Args:
data (str): The dataset to validate on. Accepts all formats accepted by yolo
validator (BaseValidator): Customized validator.
**kwargs : Any other args accepted by the validators. To see all args check 'configuration' section in docs
"""
overrides = self.overrides.copy()
@ -295,11 +287,12 @@ class YOLO:
self.task = args.task
else:
args.task = self.task
validator = validator or self.smart_load('validator')
if args.imgsz == DEFAULT_CFG.imgsz and not isinstance(self.model, (str, Path)):
args.imgsz = self.model.args['imgsz'] # use trained imgsz unless custom value is passed
args.imgsz = check_imgsz(args.imgsz, max_dim=1)
validator = TASK_MAP[self.task][2](args=args, _callbacks=self.callbacks)
validator = validator(args=args, _callbacks=self.callbacks)
validator(model=self.model)
self.metrics = validator.metrics
@ -349,11 +342,12 @@ class YOLO:
args.task = self.task
return Exporter(overrides=args, _callbacks=self.callbacks)(model=self.model)
def train(self, **kwargs):
def train(self, trainer=None, **kwargs):
"""
Trains the model on a given dataset.
Args:
trainer (BaseTrainer, optional): Customized trainer.
**kwargs (Any): Any number of arguments representing the training configuration.
"""
self._check_is_pytorch_model()
@ -373,7 +367,8 @@ class YOLO:
if overrides.get('resume'):
overrides['resume'] = self.ckpt_path
self.task = overrides.get('task') or self.task
self.trainer = TASK_MAP[self.task][1](overrides=overrides, _callbacks=self.callbacks)
trainer = trainer or self.smart_load('trainer')
self.trainer = trainer(overrides=overrides, _callbacks=self.callbacks)
if not overrides.get('resume'): # manually set model only if not resuming
self.trainer.model = self.trainer.get_model(weights=self.model if self.ckpt else None, cfg=self.model.yaml)
self.model = self.trainer.model
@ -442,3 +437,27 @@ class YOLO:
"""Reset all registered callbacks."""
for event in callbacks.default_callbacks.keys():
self.callbacks[event] = [callbacks.default_callbacks[event][0]]
def __getattr__(self, attr):
"""Raises error if object has no requested attribute."""
name = self.__class__.__name__
raise AttributeError(f"'{name}' object has no attribute '{attr}'. See valid attributes below.\n{self.__doc__}")
def smart_load(self, key):
"""Load model/trainer/validator/predictor."""
try:
return self.task_map[self.task][key]
except Exception:
name = self.__class__.__name__
mode = inspect.stack()[1][3] # get the function name.
raise NotImplementedError(
f'WARNING ⚠️ `{name}` model does not support `{mode}` mode for `{self.task}` task yet.')
@property
def task_map(self):
"""Map head to model, trainer, validator, and predictor classes
Returns:
task_map (dict)
"""
raise NotImplementedError('Please provide task map for your model!')

@ -1,4 +1,7 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
from .rtdetr import RTDETR
from .sam import SAM
from .yolo import YOLO
__all__ = 'RTDETR', 'SAM' # allow simpler import
__all__ = 'YOLO', 'RTDETR', 'SAM' # allow simpler import

@ -1,111 +1,31 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
"""
FastSAM model interface.
Usage - Predict:
from ultralytics import FastSAM
model = FastSAM('last.pt')
results = model.predict('ultralytics/assets/bus.jpg')
"""
from pathlib import Path
from ultralytics.cfg import get_cfg
from ultralytics.engine.exporter import Exporter
from ultralytics.engine.model import YOLO
from ultralytics.utils import DEFAULT_CFG, LOGGER, ROOT, is_git_dir
from ultralytics.utils.checks import check_imgsz
from ultralytics.utils.torch_utils import model_info, smart_inference_mode
from ultralytics.engine.model import Model
from .predict import FastSAMPredictor
from .val import FastSAMValidator
class FastSAM(YOLO):
def __init__(self, model='FastSAM-x.pt'):
"""Call the __init__ method of the parent class (YOLO) with the updated default model"""
if model == 'FastSAM.pt':
model = 'FastSAM-x.pt'
super().__init__(model=model)
# any additional initialization code for FastSAM
@smart_inference_mode()
def predict(self, source=None, stream=False, **kwargs):
"""
Perform prediction using the YOLO model.
Args:
source (str | int | PIL | np.ndarray): The source of the image to make predictions on.
Accepts all source types accepted by the YOLO model.
stream (bool): Whether to stream the predictions or not. Defaults to False.
**kwargs : Additional keyword arguments passed to the predictor.
Check the 'configuration' section in the documentation for all available options.
Returns:
(List[ultralytics.engine.results.Results]): The prediction results.
"""
if source is None:
source = ROOT / 'assets' if is_git_dir() else 'https://ultralytics.com/images/bus.jpg'
LOGGER.warning(f"WARNING ⚠️ 'source' is missing. Using 'source={source}'.")
overrides = self.overrides.copy()
overrides['conf'] = 0.25
overrides.update(kwargs) # prefer kwargs
overrides['mode'] = kwargs.get('mode', 'predict')
assert overrides['mode'] in ['track', 'predict']
overrides['save'] = kwargs.get('save', False) # do not save by default if called in Python
self.predictor = FastSAMPredictor(overrides=overrides)
self.predictor.setup_model(model=self.model, verbose=False)
return self.predictor(source, stream=stream)
def train(self, **kwargs):
"""Function trains models but raises an error as FastSAM models do not support training."""
raise NotImplementedError("FastSAM models don't support training")
def val(self, **kwargs):
"""Run validation given dataset."""
overrides = dict(task='segment', mode='val')
overrides.update(kwargs) # prefer kwargs
args = get_cfg(cfg=DEFAULT_CFG, overrides=overrides)
args.imgsz = check_imgsz(args.imgsz, max_dim=1)
validator = FastSAM(args=args)
validator(model=self.model)
self.metrics = validator.metrics
return validator.metrics
@smart_inference_mode()
def export(self, **kwargs):
class FastSAM(Model):
"""
Export model.
FastSAM model interface.
Args:
**kwargs : Any other args accepted by the predictors. To see all args check 'configuration' section in docs
"""
overrides = dict(task='detect')
overrides.update(kwargs)
overrides['mode'] = 'export'
args = get_cfg(cfg=DEFAULT_CFG, overrides=overrides)
args.task = self.task
if args.imgsz == DEFAULT_CFG.imgsz:
args.imgsz = self.model.args['imgsz'] # use trained imgsz unless custom value is passed
if args.batch == DEFAULT_CFG.batch:
args.batch = 1 # default to 1 if not modified
return Exporter(overrides=args)(model=self.model)
def info(self, detailed=False, verbose=True):
"""
Logs model info.
Usage - Predict:
from ultralytics import FastSAM
Args:
detailed (bool): Show detailed information about model.
verbose (bool): Controls verbosity.
model = FastSAM('last.pt')
results = model.predict('ultralytics/assets/bus.jpg')
"""
return model_info(self.model, detailed=detailed, verbose=verbose, imgsz=640)
def __call__(self, source=None, stream=False, **kwargs):
"""Calls the 'predict' function with given arguments to perform object detection."""
return self.predict(source, stream, **kwargs)
def __init__(self, model='FastSAM-x.pt'):
"""Call the __init__ method of the parent class (YOLO) with the updated default model"""
if model == 'FastSAM.pt':
model = 'FastSAM-x.pt'
assert Path(model).suffix != '.yaml', 'FastSAM models only support pre-trained models.'
super().__init__(model=model, task='segment')
def __getattr__(self, attr):
"""Raises error if object has no requested attribute."""
name = self.__class__.__name__
raise AttributeError(f"'{name}' object has no attribute '{attr}'. See valid attributes below.\n{self.__doc__}")
@property
def task_map(self):
return {'segment': {'predictor': FastSAMPredictor, 'validator': FastSAMValidator}}

@ -13,105 +13,36 @@ from pathlib import Path
import torch
from ultralytics.cfg import get_cfg
from ultralytics.engine.exporter import Exporter
from ultralytics.utils import DEFAULT_CFG, DEFAULT_CFG_DICT, LOGGER, ROOT, is_git_dir
from ultralytics.utils.checks import check_imgsz
from ultralytics.engine.model import Model
from ultralytics.utils.torch_utils import model_info, smart_inference_mode
from .predict import NASPredictor
from .val import NASValidator
class NAS:
class NAS(Model):
def __init__(self, model='yolo_nas_s.pt') -> None:
assert Path(model).suffix != '.yaml', 'YOLO-NAS models only support pre-trained models.'
super().__init__(model, task='detect')
@smart_inference_mode()
def _load(self, weights: str, task: str):
# Load or create new NAS model
import super_gradients
self.predictor = None
suffix = Path(model).suffix
suffix = Path(weights).suffix
if suffix == '.pt':
self._load(model)
self.model = torch.load(weights)
elif suffix == '':
self.model = super_gradients.training.models.get(model, pretrained_weights='coco')
self.task = 'detect'
self.model.args = DEFAULT_CFG_DICT # attach args to model
self.model = super_gradients.training.models.get(weights, pretrained_weights='coco')
# Standardize model
self.model.fuse = lambda verbose=True: self.model
self.model.stride = torch.tensor([32])
self.model.names = dict(enumerate(self.model._class_names))
self.model.is_fused = lambda: False # for info()
self.model.yaml = {} # for info()
self.model.pt_path = model # for export()
self.model.pt_path = weights # for export()
self.model.task = 'detect' # for export()
self.info()
@smart_inference_mode()
def _load(self, weights: str):
self.model = torch.load(weights)
@smart_inference_mode()
def predict(self, source=None, stream=False, **kwargs):
"""
Perform prediction using the YOLO model.
Args:
source (str | int | PIL | np.ndarray): The source of the image to make predictions on.
Accepts all source types accepted by the YOLO model.
stream (bool): Whether to stream the predictions or not. Defaults to False.
**kwargs : Additional keyword arguments passed to the predictor.
Check the 'configuration' section in the documentation for all available options.
Returns:
(List[ultralytics.engine.results.Results]): The prediction results.
"""
if source is None:
source = ROOT / 'assets' if is_git_dir() else 'https://ultralytics.com/images/bus.jpg'
LOGGER.warning(f"WARNING ⚠️ 'source' is missing. Using 'source={source}'.")
overrides = dict(conf=0.25, task='detect', mode='predict')
overrides.update(kwargs) # prefer kwargs
if not self.predictor:
self.predictor = NASPredictor(overrides=overrides)
self.predictor.setup_model(model=self.model)
else: # only update args if predictor is already setup
self.predictor.args = get_cfg(self.predictor.args, overrides)
return self.predictor(source, stream=stream)
def train(self, **kwargs):
"""Function trains models but raises an error as NAS models do not support training."""
raise NotImplementedError("NAS models don't support training")
def val(self, **kwargs):
"""Run validation given dataset."""
overrides = dict(task='detect', mode='val')
overrides.update(kwargs) # prefer kwargs
args = get_cfg(cfg=DEFAULT_CFG, overrides=overrides)
args.imgsz = check_imgsz(args.imgsz, max_dim=1)
validator = NASValidator(args=args)
validator(model=self.model)
self.metrics = validator.metrics
return validator.metrics
@smart_inference_mode()
def export(self, **kwargs):
"""
Export model.
Args:
**kwargs : Any other args accepted by the predictors. To see all args check 'configuration' section in docs
"""
overrides = dict(task='detect')
overrides.update(kwargs)
overrides['mode'] = 'export'
args = get_cfg(cfg=DEFAULT_CFG, overrides=overrides)
args.task = self.task
if args.imgsz == DEFAULT_CFG.imgsz:
args.imgsz = self.model.args['imgsz'] # use trained imgsz unless custom value is passed
if args.batch == DEFAULT_CFG.batch:
args.batch = 1 # default to 1 if not modified
return Exporter(overrides=args)(model=self.model)
def info(self, detailed=False, verbose=True):
"""
@ -123,11 +54,6 @@ class NAS:
"""
return model_info(self.model, detailed=detailed, verbose=verbose, imgsz=640)
def __call__(self, source=None, stream=False, **kwargs):
"""Calls the 'predict' function with given arguments to perform object detection."""
return self.predict(source, stream, **kwargs)
def __getattr__(self, attr):
"""Raises error if object has no requested attribute."""
name = self.__class__.__name__
raise AttributeError(f"'{name}' object has no attribute '{attr}'. See valid attributes below.\n{self.__doc__}")
@property
def task_map(self):
return {'detect': {'predictor': NASPredictor, 'validator': NASValidator}}

@ -2,172 +2,29 @@
"""
RT-DETR model interface
"""
from pathlib import Path
import torch.nn as nn
from ultralytics.cfg import get_cfg
from ultralytics.engine.exporter import Exporter
from ultralytics.nn.tasks import RTDETRDetectionModel, attempt_load_one_weight, yaml_model_load
from ultralytics.utils import DEFAULT_CFG, DEFAULT_CFG_DICT, LOGGER, RANK, ROOT, is_git_dir
from ultralytics.utils.checks import check_imgsz
from ultralytics.utils.torch_utils import model_info, smart_inference_mode
from ultralytics.engine.model import Model
from ultralytics.nn.tasks import RTDETRDetectionModel
from .predict import RTDETRPredictor
from .train import RTDETRTrainer
from .val import RTDETRValidator
class RTDETR:
class RTDETR(Model):
"""
RTDETR model interface.
"""
def __init__(self, model='rtdetr-l.pt') -> None:
if model and not model.endswith('.pt') and not model.endswith('.yaml'):
raise NotImplementedError('RT-DETR only supports creating from pt file or yaml file.')
# Load or create new YOLO model
self.predictor = None
self.ckpt = None
suffix = Path(model).suffix
if suffix == '.yaml':
self._new(model)
else:
self._load(model)
def _new(self, cfg: str, verbose=True):
cfg_dict = yaml_model_load(cfg)
self.cfg = cfg
self.task = 'detect'
self.model = RTDETRDetectionModel(cfg_dict, verbose=verbose) # build model
# Below added to allow export from YAMLs
self.model.args = DEFAULT_CFG_DICT # attach args to model
self.model.task = self.task
@smart_inference_mode()
def _load(self, weights: str):
self.model, self.ckpt = attempt_load_one_weight(weights)
self.model.args = DEFAULT_CFG_DICT # attach args to model
self.task = self.model.args['task']
@smart_inference_mode()
def load(self, weights='yolov8n.pt'):
"""
Transfers parameters with matching names and shapes from 'weights' to model.
"""
if isinstance(weights, (str, Path)):
weights, self.ckpt = attempt_load_one_weight(weights)
self.model.load(weights)
return self
@smart_inference_mode()
def predict(self, source=None, stream=False, **kwargs):
"""
Perform prediction using the YOLO model.
Args:
source (str | int | PIL | np.ndarray): The source of the image to make predictions on.
Accepts all source types accepted by the YOLO model.
stream (bool): Whether to stream the predictions or not. Defaults to False.
**kwargs : Additional keyword arguments passed to the predictor.
Check the 'configuration' section in the documentation for all available options.
Returns:
(List[ultralytics.engine.results.Results]): The prediction results.
"""
if source is None:
source = ROOT / 'assets' if is_git_dir() else 'https://ultralytics.com/images/bus.jpg'
LOGGER.warning(f"WARNING ⚠️ 'source' is missing. Using 'source={source}'.")
overrides = dict(conf=0.25, task='detect', mode='predict')
overrides.update(kwargs) # prefer kwargs
if not self.predictor:
self.predictor = RTDETRPredictor(overrides=overrides)
self.predictor.setup_model(model=self.model)
else: # only update args if predictor is already setup
self.predictor.args = get_cfg(self.predictor.args, overrides)
return self.predictor(source, stream=stream)
def train(self, **kwargs):
"""
Trains the model on a given dataset.
Args:
**kwargs (Any): Any number of arguments representing the training configuration.
"""
overrides = dict(task='detect', mode='train')
overrides.update(kwargs)
overrides['deterministic'] = False
if not overrides.get('data'):
raise AttributeError("Dataset required but missing, i.e. pass 'data=coco128.yaml'")
if overrides.get('resume'):
overrides['resume'] = self.ckpt_path
self.task = overrides.get('task') or self.task
self.trainer = RTDETRTrainer(overrides=overrides)
if not overrides.get('resume'): # manually set model only if not resuming
self.trainer.model = self.trainer.get_model(weights=self.model if self.ckpt else None, cfg=self.model.yaml)
self.model = self.trainer.model
self.trainer.train()
# Update model and cfg after training
if RANK in (-1, 0):
self.model, _ = attempt_load_one_weight(str(self.trainer.best))
self.overrides = self.model.args
self.metrics = getattr(self.trainer.validator, 'metrics', None) # TODO: no metrics returned by DDP
def val(self, **kwargs):
"""Run validation given dataset."""
overrides = dict(task='detect', mode='val')
overrides.update(kwargs) # prefer kwargs
args = get_cfg(cfg=DEFAULT_CFG, overrides=overrides)
args.imgsz = check_imgsz(args.imgsz, max_dim=1)
validator = RTDETRValidator(args=args)
validator(model=self.model)
self.metrics = validator.metrics
return validator.metrics
def info(self, verbose=True):
"""Get model info"""
return model_info(self.model, verbose=verbose)
def _check_is_pytorch_model(self):
"""
Raises TypeError is model is not a PyTorch model
"""
pt_str = isinstance(self.model, (str, Path)) and Path(self.model).suffix == '.pt'
pt_module = isinstance(self.model, nn.Module)
if not (pt_module or pt_str):
raise TypeError(f"model='{self.model}' must be a *.pt PyTorch model, but is a different type. "
f'PyTorch models can be used to train, val, predict and export, i.e. '
f"'yolo export model=yolov8n.pt', but exported formats like ONNX, TensorRT etc. only "
f"support 'predict' and 'val' modes, i.e. 'yolo predict model=yolov8n.onnx'.")
def fuse(self):
"""Fuse PyTorch Conv2d and BatchNorm2d layers."""
self._check_is_pytorch_model()
self.model.fuse()
@smart_inference_mode()
def export(self, **kwargs):
"""
Export model.
Args:
**kwargs : Any other args accepted by the predictors. To see all args check 'configuration' section in docs
"""
overrides = dict(task='detect')
overrides.update(kwargs)
overrides['mode'] = 'export'
args = get_cfg(cfg=DEFAULT_CFG, overrides=overrides)
args.task = self.task
if args.imgsz == DEFAULT_CFG.imgsz:
args.imgsz = self.model.args['imgsz'] # use trained imgsz unless custom value is passed
if args.batch == DEFAULT_CFG.batch:
args.batch = 1 # default to 1 if not modified
return Exporter(overrides=args)(model=self.model)
def __call__(self, source=None, stream=False, **kwargs):
"""Calls the 'predict' function with given arguments to perform object detection."""
return self.predict(source, stream, **kwargs)
def __getattr__(self, attr):
"""Raises error if object has no requested attribute."""
name = self.__class__.__name__
raise AttributeError(f"'{name}' object has no attribute '{attr}'. See valid attributes below.\n{self.__doc__}")
super().__init__(model=model, task='detect')
@property
def task_map(self):
return {
'detect': {
'predictor': RTDETRPredictor,
'validator': RTDETRValidator,
'trainer': RTDETRTrainer,
'model': RTDETRDetectionModel}}

@ -3,51 +3,38 @@
SAM model interface
"""
from ultralytics.cfg import get_cfg
from ultralytics.engine.model import Model
from ultralytics.utils.torch_utils import model_info
from .build import build_sam
from .predict import Predictor
class SAM:
class SAM(Model):
"""
SAM model interface.
"""
def __init__(self, model='sam_b.pt') -> None:
if model and not model.endswith('.pt') and not model.endswith('.pth'):
# Should raise AssertionError instead?
raise NotImplementedError('Segment anything prediction requires pre-trained checkpoint')
self.model = build_sam(model)
self.task = 'segment' # required
self.predictor = None # reuse predictor
super().__init__(model=model, task='segment')
def _load(self, weights: str, task=None):
self.model = build_sam(weights)
def predict(self, source, stream=False, bboxes=None, points=None, labels=None, **kwargs):
"""Predicts and returns segmentation masks for given image or video source."""
overrides = dict(conf=0.25, task='segment', mode='predict', imgsz=1024)
overrides.update(kwargs) # prefer kwargs
if not self.predictor:
self.predictor = Predictor(overrides=overrides)
self.predictor.setup_model(model=self.model)
else: # only update args if predictor is already setup
self.predictor.args = get_cfg(self.predictor.args, overrides)
return self.predictor(source, stream=stream, bboxes=bboxes, points=points, labels=labels)
def train(self, **kwargs):
"""Function trains models but raises an error as SAM models do not support training."""
raise NotImplementedError("SAM models don't support training")
def val(self, **kwargs):
"""Run validation given dataset."""
raise NotImplementedError("SAM models don't support validation")
kwargs.update(overrides)
prompts = dict(bboxes=bboxes, points=points, labels=labels)
super().predict(source, stream, prompts=prompts, **kwargs)
def __call__(self, source=None, stream=False, bboxes=None, points=None, labels=None, **kwargs):
"""Calls the 'predict' function with given arguments to perform object detection."""
return self.predict(source, stream, bboxes, points, labels, **kwargs)
def __getattr__(self, attr):
"""Raises error if object has no requested attribute."""
name = self.__class__.__name__
raise AttributeError(f"'{name}' object has no attribute '{attr}'. See valid attributes below.\n{self.__doc__}")
def info(self, detailed=False, verbose=True):
"""
Logs model info.
@ -57,3 +44,7 @@ class SAM:
verbose (bool): Controls verbosity.
"""
return model_info(self.model, detailed=detailed, verbose=verbose)
@property
def task_map(self):
return {'segment': {'predictor': Predictor}}

@ -28,6 +28,8 @@ class Predictor(BasePredictor):
# Args for set_image
self.im = None
self.features = None
# Args for set_prompts
self.prompts = {}
# Args for segment everything
self.segment_all = False
@ -92,6 +94,10 @@ class Predictor(BasePredictor):
of masks and H=W=256. These low resolution logits can be passed to
a subsequent iteration as mask input.
"""
# Get prompts from self.prompts first
bboxes = self.prompts.pop('bboxes', bboxes)
points = self.prompts.pop('points', points)
masks = self.prompts.pop('masks', masks)
if all(i is None for i in [bboxes, points, masks]):
return self.generate(im, *args, **kwargs)
return self.prompt_inference(im, bboxes, points, labels, masks, multimask_output)
@ -348,6 +354,10 @@ class Predictor(BasePredictor):
self.im = im
break
def set_prompts(self, prompts):
"""Set prompts in advance."""
self.prompts = prompts
def reset_image(self):
self.im = None
self.features = None

@ -2,4 +2,6 @@
from ultralytics.models.yolo import classify, detect, pose, segment
__all__ = 'classify', 'segment', 'detect', 'pose'
from .model import YOLO
__all__ = 'classify', 'segment', 'detect', 'pose', 'YOLO'

@ -1,4 +1,5 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
from copy import copy
import numpy as np

@ -0,0 +1,36 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
from ultralytics.engine.model import Model
from ultralytics.models import yolo # noqa
from ultralytics.nn.tasks import ClassificationModel, DetectionModel, PoseModel, SegmentationModel
class YOLO(Model):
"""
YOLO (You Only Look Once) object detection model.
"""
@property
def task_map(self):
"""Map head to model, trainer, validator, and predictor classes"""
return {
'classify': {
'model': ClassificationModel,
'trainer': yolo.classify.ClassificationTrainer,
'validator': yolo.classify.ClassificationValidator,
'predictor': yolo.classify.ClassificationPredictor, },
'detect': {
'model': DetectionModel,
'trainer': yolo.detect.DetectionTrainer,
'validator': yolo.detect.DetectionValidator,
'predictor': yolo.detect.DetectionPredictor, },
'segment': {
'model': SegmentationModel,
'trainer': yolo.segment.SegmentationTrainer,
'validator': yolo.segment.SegmentationValidator,
'predictor': yolo.segment.SegmentationPredictor, },
'pose': {
'model': PoseModel,
'trainer': yolo.pose.PoseTrainer,
'validator': yolo.pose.PoseValidator,
'predictor': yolo.pose.PosePredictor, }, }

@ -1,4 +1,5 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
from copy import copy
from ultralytics.models import yolo

Loading…
Cancel
Save