ultralytics 8.0.143 add Model base class (#3934)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
This commit is contained in:
Laughing
2023-07-27 06:54:02 +08:00
committed by GitHub
parent 3c787eb080
commit 1a0eb3f099
15 changed files with 182 additions and 407 deletions

View File

@ -1,4 +1,7 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
from .rtdetr import RTDETR
from .sam import SAM
from .yolo import YOLO
__all__ = 'RTDETR', 'SAM' # allow simpler import
__all__ = 'YOLO', 'RTDETR', 'SAM' # allow simpler import

View File

@ -1,111 +1,31 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
"""
FastSAM model interface.
Usage - Predict:
from ultralytics import FastSAM
from pathlib import Path
model = FastSAM('last.pt')
results = model.predict('ultralytics/assets/bus.jpg')
"""
from ultralytics.cfg import get_cfg
from ultralytics.engine.exporter import Exporter
from ultralytics.engine.model import YOLO
from ultralytics.utils import DEFAULT_CFG, LOGGER, ROOT, is_git_dir
from ultralytics.utils.checks import check_imgsz
from ultralytics.utils.torch_utils import model_info, smart_inference_mode
from ultralytics.engine.model import Model
from .predict import FastSAMPredictor
from .val import FastSAMValidator
class FastSAM(YOLO):
class FastSAM(Model):
"""
FastSAM model interface.
Usage - Predict:
from ultralytics import FastSAM
model = FastSAM('last.pt')
results = model.predict('ultralytics/assets/bus.jpg')
"""
def __init__(self, model='FastSAM-x.pt'):
"""Call the __init__ method of the parent class (YOLO) with the updated default model"""
if model == 'FastSAM.pt':
model = 'FastSAM-x.pt'
super().__init__(model=model)
# any additional initialization code for FastSAM
assert Path(model).suffix != '.yaml', 'FastSAM models only support pre-trained models.'
super().__init__(model=model, task='segment')
@smart_inference_mode()
def predict(self, source=None, stream=False, **kwargs):
"""
Perform prediction using the YOLO model.
Args:
source (str | int | PIL | np.ndarray): The source of the image to make predictions on.
Accepts all source types accepted by the YOLO model.
stream (bool): Whether to stream the predictions or not. Defaults to False.
**kwargs : Additional keyword arguments passed to the predictor.
Check the 'configuration' section in the documentation for all available options.
Returns:
(List[ultralytics.engine.results.Results]): The prediction results.
"""
if source is None:
source = ROOT / 'assets' if is_git_dir() else 'https://ultralytics.com/images/bus.jpg'
LOGGER.warning(f"WARNING ⚠️ 'source' is missing. Using 'source={source}'.")
overrides = self.overrides.copy()
overrides['conf'] = 0.25
overrides.update(kwargs) # prefer kwargs
overrides['mode'] = kwargs.get('mode', 'predict')
assert overrides['mode'] in ['track', 'predict']
overrides['save'] = kwargs.get('save', False) # do not save by default if called in Python
self.predictor = FastSAMPredictor(overrides=overrides)
self.predictor.setup_model(model=self.model, verbose=False)
return self.predictor(source, stream=stream)
def train(self, **kwargs):
"""Function trains models but raises an error as FastSAM models do not support training."""
raise NotImplementedError("FastSAM models don't support training")
def val(self, **kwargs):
"""Run validation given dataset."""
overrides = dict(task='segment', mode='val')
overrides.update(kwargs) # prefer kwargs
args = get_cfg(cfg=DEFAULT_CFG, overrides=overrides)
args.imgsz = check_imgsz(args.imgsz, max_dim=1)
validator = FastSAM(args=args)
validator(model=self.model)
self.metrics = validator.metrics
return validator.metrics
@smart_inference_mode()
def export(self, **kwargs):
"""
Export model.
Args:
**kwargs : Any other args accepted by the predictors. To see all args check 'configuration' section in docs
"""
overrides = dict(task='detect')
overrides.update(kwargs)
overrides['mode'] = 'export'
args = get_cfg(cfg=DEFAULT_CFG, overrides=overrides)
args.task = self.task
if args.imgsz == DEFAULT_CFG.imgsz:
args.imgsz = self.model.args['imgsz'] # use trained imgsz unless custom value is passed
if args.batch == DEFAULT_CFG.batch:
args.batch = 1 # default to 1 if not modified
return Exporter(overrides=args)(model=self.model)
def info(self, detailed=False, verbose=True):
"""
Logs model info.
Args:
detailed (bool): Show detailed information about model.
verbose (bool): Controls verbosity.
"""
return model_info(self.model, detailed=detailed, verbose=verbose, imgsz=640)
def __call__(self, source=None, stream=False, **kwargs):
"""Calls the 'predict' function with given arguments to perform object detection."""
return self.predict(source, stream, **kwargs)
def __getattr__(self, attr):
"""Raises error if object has no requested attribute."""
name = self.__class__.__name__
raise AttributeError(f"'{name}' object has no attribute '{attr}'. See valid attributes below.\n{self.__doc__}")
@property
def task_map(self):
return {'segment': {'predictor': FastSAMPredictor, 'validator': FastSAMValidator}}

View File

@ -13,105 +13,36 @@ from pathlib import Path
import torch
from ultralytics.cfg import get_cfg
from ultralytics.engine.exporter import Exporter
from ultralytics.utils import DEFAULT_CFG, DEFAULT_CFG_DICT, LOGGER, ROOT, is_git_dir
from ultralytics.utils.checks import check_imgsz
from ultralytics.engine.model import Model
from ultralytics.utils.torch_utils import model_info, smart_inference_mode
from .predict import NASPredictor
from .val import NASValidator
class NAS:
class NAS(Model):
def __init__(self, model='yolo_nas_s.pt') -> None:
assert Path(model).suffix != '.yaml', 'YOLO-NAS models only support pre-trained models.'
super().__init__(model, task='detect')
@smart_inference_mode()
def _load(self, weights: str, task: str):
# Load or create new NAS model
import super_gradients
self.predictor = None
suffix = Path(model).suffix
suffix = Path(weights).suffix
if suffix == '.pt':
self._load(model)
self.model = torch.load(weights)
elif suffix == '':
self.model = super_gradients.training.models.get(model, pretrained_weights='coco')
self.task = 'detect'
self.model.args = DEFAULT_CFG_DICT # attach args to model
self.model = super_gradients.training.models.get(weights, pretrained_weights='coco')
# Standardize model
self.model.fuse = lambda verbose=True: self.model
self.model.stride = torch.tensor([32])
self.model.names = dict(enumerate(self.model._class_names))
self.model.is_fused = lambda: False # for info()
self.model.yaml = {} # for info()
self.model.pt_path = model # for export()
self.model.pt_path = weights # for export()
self.model.task = 'detect' # for export()
self.info()
@smart_inference_mode()
def _load(self, weights: str):
self.model = torch.load(weights)
@smart_inference_mode()
def predict(self, source=None, stream=False, **kwargs):
"""
Perform prediction using the YOLO model.
Args:
source (str | int | PIL | np.ndarray): The source of the image to make predictions on.
Accepts all source types accepted by the YOLO model.
stream (bool): Whether to stream the predictions or not. Defaults to False.
**kwargs : Additional keyword arguments passed to the predictor.
Check the 'configuration' section in the documentation for all available options.
Returns:
(List[ultralytics.engine.results.Results]): The prediction results.
"""
if source is None:
source = ROOT / 'assets' if is_git_dir() else 'https://ultralytics.com/images/bus.jpg'
LOGGER.warning(f"WARNING ⚠️ 'source' is missing. Using 'source={source}'.")
overrides = dict(conf=0.25, task='detect', mode='predict')
overrides.update(kwargs) # prefer kwargs
if not self.predictor:
self.predictor = NASPredictor(overrides=overrides)
self.predictor.setup_model(model=self.model)
else: # only update args if predictor is already setup
self.predictor.args = get_cfg(self.predictor.args, overrides)
return self.predictor(source, stream=stream)
def train(self, **kwargs):
"""Function trains models but raises an error as NAS models do not support training."""
raise NotImplementedError("NAS models don't support training")
def val(self, **kwargs):
"""Run validation given dataset."""
overrides = dict(task='detect', mode='val')
overrides.update(kwargs) # prefer kwargs
args = get_cfg(cfg=DEFAULT_CFG, overrides=overrides)
args.imgsz = check_imgsz(args.imgsz, max_dim=1)
validator = NASValidator(args=args)
validator(model=self.model)
self.metrics = validator.metrics
return validator.metrics
@smart_inference_mode()
def export(self, **kwargs):
"""
Export model.
Args:
**kwargs : Any other args accepted by the predictors. To see all args check 'configuration' section in docs
"""
overrides = dict(task='detect')
overrides.update(kwargs)
overrides['mode'] = 'export'
args = get_cfg(cfg=DEFAULT_CFG, overrides=overrides)
args.task = self.task
if args.imgsz == DEFAULT_CFG.imgsz:
args.imgsz = self.model.args['imgsz'] # use trained imgsz unless custom value is passed
if args.batch == DEFAULT_CFG.batch:
args.batch = 1 # default to 1 if not modified
return Exporter(overrides=args)(model=self.model)
def info(self, detailed=False, verbose=True):
"""
@ -123,11 +54,6 @@ class NAS:
"""
return model_info(self.model, detailed=detailed, verbose=verbose, imgsz=640)
def __call__(self, source=None, stream=False, **kwargs):
"""Calls the 'predict' function with given arguments to perform object detection."""
return self.predict(source, stream, **kwargs)
def __getattr__(self, attr):
"""Raises error if object has no requested attribute."""
name = self.__class__.__name__
raise AttributeError(f"'{name}' object has no attribute '{attr}'. See valid attributes below.\n{self.__doc__}")
@property
def task_map(self):
return {'detect': {'predictor': NASPredictor, 'validator': NASValidator}}

View File

@ -2,172 +2,29 @@
"""
RT-DETR model interface
"""
from pathlib import Path
import torch.nn as nn
from ultralytics.cfg import get_cfg
from ultralytics.engine.exporter import Exporter
from ultralytics.nn.tasks import RTDETRDetectionModel, attempt_load_one_weight, yaml_model_load
from ultralytics.utils import DEFAULT_CFG, DEFAULT_CFG_DICT, LOGGER, RANK, ROOT, is_git_dir
from ultralytics.utils.checks import check_imgsz
from ultralytics.utils.torch_utils import model_info, smart_inference_mode
from ultralytics.engine.model import Model
from ultralytics.nn.tasks import RTDETRDetectionModel
from .predict import RTDETRPredictor
from .train import RTDETRTrainer
from .val import RTDETRValidator
class RTDETR:
class RTDETR(Model):
"""
RTDETR model interface.
"""
def __init__(self, model='rtdetr-l.pt') -> None:
if model and not model.endswith('.pt') and not model.endswith('.yaml'):
raise NotImplementedError('RT-DETR only supports creating from pt file or yaml file.')
# Load or create new YOLO model
self.predictor = None
self.ckpt = None
suffix = Path(model).suffix
if suffix == '.yaml':
self._new(model)
else:
self._load(model)
super().__init__(model=model, task='detect')
def _new(self, cfg: str, verbose=True):
cfg_dict = yaml_model_load(cfg)
self.cfg = cfg
self.task = 'detect'
self.model = RTDETRDetectionModel(cfg_dict, verbose=verbose) # build model
# Below added to allow export from YAMLs
self.model.args = DEFAULT_CFG_DICT # attach args to model
self.model.task = self.task
@smart_inference_mode()
def _load(self, weights: str):
self.model, self.ckpt = attempt_load_one_weight(weights)
self.model.args = DEFAULT_CFG_DICT # attach args to model
self.task = self.model.args['task']
@smart_inference_mode()
def load(self, weights='yolov8n.pt'):
"""
Transfers parameters with matching names and shapes from 'weights' to model.
"""
if isinstance(weights, (str, Path)):
weights, self.ckpt = attempt_load_one_weight(weights)
self.model.load(weights)
return self
@smart_inference_mode()
def predict(self, source=None, stream=False, **kwargs):
"""
Perform prediction using the YOLO model.
Args:
source (str | int | PIL | np.ndarray): The source of the image to make predictions on.
Accepts all source types accepted by the YOLO model.
stream (bool): Whether to stream the predictions or not. Defaults to False.
**kwargs : Additional keyword arguments passed to the predictor.
Check the 'configuration' section in the documentation for all available options.
Returns:
(List[ultralytics.engine.results.Results]): The prediction results.
"""
if source is None:
source = ROOT / 'assets' if is_git_dir() else 'https://ultralytics.com/images/bus.jpg'
LOGGER.warning(f"WARNING ⚠️ 'source' is missing. Using 'source={source}'.")
overrides = dict(conf=0.25, task='detect', mode='predict')
overrides.update(kwargs) # prefer kwargs
if not self.predictor:
self.predictor = RTDETRPredictor(overrides=overrides)
self.predictor.setup_model(model=self.model)
else: # only update args if predictor is already setup
self.predictor.args = get_cfg(self.predictor.args, overrides)
return self.predictor(source, stream=stream)
def train(self, **kwargs):
"""
Trains the model on a given dataset.
Args:
**kwargs (Any): Any number of arguments representing the training configuration.
"""
overrides = dict(task='detect', mode='train')
overrides.update(kwargs)
overrides['deterministic'] = False
if not overrides.get('data'):
raise AttributeError("Dataset required but missing, i.e. pass 'data=coco128.yaml'")
if overrides.get('resume'):
overrides['resume'] = self.ckpt_path
self.task = overrides.get('task') or self.task
self.trainer = RTDETRTrainer(overrides=overrides)
if not overrides.get('resume'): # manually set model only if not resuming
self.trainer.model = self.trainer.get_model(weights=self.model if self.ckpt else None, cfg=self.model.yaml)
self.model = self.trainer.model
self.trainer.train()
# Update model and cfg after training
if RANK in (-1, 0):
self.model, _ = attempt_load_one_weight(str(self.trainer.best))
self.overrides = self.model.args
self.metrics = getattr(self.trainer.validator, 'metrics', None) # TODO: no metrics returned by DDP
def val(self, **kwargs):
"""Run validation given dataset."""
overrides = dict(task='detect', mode='val')
overrides.update(kwargs) # prefer kwargs
args = get_cfg(cfg=DEFAULT_CFG, overrides=overrides)
args.imgsz = check_imgsz(args.imgsz, max_dim=1)
validator = RTDETRValidator(args=args)
validator(model=self.model)
self.metrics = validator.metrics
return validator.metrics
def info(self, verbose=True):
"""Get model info"""
return model_info(self.model, verbose=verbose)
def _check_is_pytorch_model(self):
"""
Raises TypeError is model is not a PyTorch model
"""
pt_str = isinstance(self.model, (str, Path)) and Path(self.model).suffix == '.pt'
pt_module = isinstance(self.model, nn.Module)
if not (pt_module or pt_str):
raise TypeError(f"model='{self.model}' must be a *.pt PyTorch model, but is a different type. "
f'PyTorch models can be used to train, val, predict and export, i.e. '
f"'yolo export model=yolov8n.pt', but exported formats like ONNX, TensorRT etc. only "
f"support 'predict' and 'val' modes, i.e. 'yolo predict model=yolov8n.onnx'.")
def fuse(self):
"""Fuse PyTorch Conv2d and BatchNorm2d layers."""
self._check_is_pytorch_model()
self.model.fuse()
@smart_inference_mode()
def export(self, **kwargs):
"""
Export model.
Args:
**kwargs : Any other args accepted by the predictors. To see all args check 'configuration' section in docs
"""
overrides = dict(task='detect')
overrides.update(kwargs)
overrides['mode'] = 'export'
args = get_cfg(cfg=DEFAULT_CFG, overrides=overrides)
args.task = self.task
if args.imgsz == DEFAULT_CFG.imgsz:
args.imgsz = self.model.args['imgsz'] # use trained imgsz unless custom value is passed
if args.batch == DEFAULT_CFG.batch:
args.batch = 1 # default to 1 if not modified
return Exporter(overrides=args)(model=self.model)
def __call__(self, source=None, stream=False, **kwargs):
"""Calls the 'predict' function with given arguments to perform object detection."""
return self.predict(source, stream, **kwargs)
def __getattr__(self, attr):
"""Raises error if object has no requested attribute."""
name = self.__class__.__name__
raise AttributeError(f"'{name}' object has no attribute '{attr}'. See valid attributes below.\n{self.__doc__}")
@property
def task_map(self):
return {
'detect': {
'predictor': RTDETRPredictor,
'validator': RTDETRValidator,
'trainer': RTDETRTrainer,
'model': RTDETRDetectionModel}}

View File

@ -3,51 +3,38 @@
SAM model interface
"""
from ultralytics.cfg import get_cfg
from ultralytics.engine.model import Model
from ultralytics.utils.torch_utils import model_info
from .build import build_sam
from .predict import Predictor
class SAM:
class SAM(Model):
"""
SAM model interface.
"""
def __init__(self, model='sam_b.pt') -> None:
if model and not model.endswith('.pt') and not model.endswith('.pth'):
# Should raise AssertionError instead?
raise NotImplementedError('Segment anything prediction requires pre-trained checkpoint')
self.model = build_sam(model)
self.task = 'segment' # required
self.predictor = None # reuse predictor
super().__init__(model=model, task='segment')
def _load(self, weights: str, task=None):
self.model = build_sam(weights)
def predict(self, source, stream=False, bboxes=None, points=None, labels=None, **kwargs):
"""Predicts and returns segmentation masks for given image or video source."""
overrides = dict(conf=0.25, task='segment', mode='predict', imgsz=1024)
overrides.update(kwargs) # prefer kwargs
if not self.predictor:
self.predictor = Predictor(overrides=overrides)
self.predictor.setup_model(model=self.model)
else: # only update args if predictor is already setup
self.predictor.args = get_cfg(self.predictor.args, overrides)
return self.predictor(source, stream=stream, bboxes=bboxes, points=points, labels=labels)
def train(self, **kwargs):
"""Function trains models but raises an error as SAM models do not support training."""
raise NotImplementedError("SAM models don't support training")
def val(self, **kwargs):
"""Run validation given dataset."""
raise NotImplementedError("SAM models don't support validation")
kwargs.update(overrides)
prompts = dict(bboxes=bboxes, points=points, labels=labels)
super().predict(source, stream, prompts=prompts, **kwargs)
def __call__(self, source=None, stream=False, bboxes=None, points=None, labels=None, **kwargs):
"""Calls the 'predict' function with given arguments to perform object detection."""
return self.predict(source, stream, bboxes, points, labels, **kwargs)
def __getattr__(self, attr):
"""Raises error if object has no requested attribute."""
name = self.__class__.__name__
raise AttributeError(f"'{name}' object has no attribute '{attr}'. See valid attributes below.\n{self.__doc__}")
def info(self, detailed=False, verbose=True):
"""
Logs model info.
@ -57,3 +44,7 @@ class SAM:
verbose (bool): Controls verbosity.
"""
return model_info(self.model, detailed=detailed, verbose=verbose)
@property
def task_map(self):
return {'segment': {'predictor': Predictor}}

View File

@ -28,6 +28,8 @@ class Predictor(BasePredictor):
# Args for set_image
self.im = None
self.features = None
# Args for set_prompts
self.prompts = {}
# Args for segment everything
self.segment_all = False
@ -92,6 +94,10 @@ class Predictor(BasePredictor):
of masks and H=W=256. These low resolution logits can be passed to
a subsequent iteration as mask input.
"""
# Get prompts from self.prompts first
bboxes = self.prompts.pop('bboxes', bboxes)
points = self.prompts.pop('points', points)
masks = self.prompts.pop('masks', masks)
if all(i is None for i in [bboxes, points, masks]):
return self.generate(im, *args, **kwargs)
return self.prompt_inference(im, bboxes, points, labels, masks, multimask_output)
@ -348,6 +354,10 @@ class Predictor(BasePredictor):
self.im = im
break
def set_prompts(self, prompts):
"""Set prompts in advance."""
self.prompts = prompts
def reset_image(self):
self.im = None
self.features = None

View File

@ -2,4 +2,6 @@
from ultralytics.models.yolo import classify, detect, pose, segment
__all__ = 'classify', 'segment', 'detect', 'pose'
from .model import YOLO
__all__ = 'classify', 'segment', 'detect', 'pose', 'YOLO'

View File

@ -1,4 +1,5 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
from copy import copy
import numpy as np

View File

@ -0,0 +1,36 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
from ultralytics.engine.model import Model
from ultralytics.models import yolo # noqa
from ultralytics.nn.tasks import ClassificationModel, DetectionModel, PoseModel, SegmentationModel
class YOLO(Model):
"""
YOLO (You Only Look Once) object detection model.
"""
@property
def task_map(self):
"""Map head to model, trainer, validator, and predictor classes"""
return {
'classify': {
'model': ClassificationModel,
'trainer': yolo.classify.ClassificationTrainer,
'validator': yolo.classify.ClassificationValidator,
'predictor': yolo.classify.ClassificationPredictor, },
'detect': {
'model': DetectionModel,
'trainer': yolo.detect.DetectionTrainer,
'validator': yolo.detect.DetectionValidator,
'predictor': yolo.detect.DetectionPredictor, },
'segment': {
'model': SegmentationModel,
'trainer': yolo.segment.SegmentationTrainer,
'validator': yolo.segment.SegmentationValidator,
'predictor': yolo.segment.SegmentationPredictor, },
'pose': {
'model': PoseModel,
'trainer': yolo.pose.PoseTrainer,
'validator': yolo.pose.PoseValidator,
'predictor': yolo.pose.PosePredictor, }, }

View File

@ -1,4 +1,5 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
from copy import copy
from ultralytics.models import yolo