`ultralytics 8.0.143` add `Model` base class (#3934)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>single_channel
parent
3c787eb080
commit
1a0eb3f099
@ -0,0 +1,9 @@
|
|||||||
|
---
|
||||||
|
description: Discover the Ultralytics YOLO model class. Learn advanced techniques, tips, and tricks for training.
|
||||||
|
keywords: Ultralytics YOLO, YOLO, YOLO model, Model Training, Machine Learning, Deep Learning, Computer Vision
|
||||||
|
---
|
||||||
|
|
||||||
|
## YOLO
|
||||||
|
---
|
||||||
|
### ::: ultralytics.models.yolo.model.YOLO
|
||||||
|
<br><br>
|
@ -1,4 +1,7 @@
|
|||||||
|
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
||||||
|
|
||||||
from .rtdetr import RTDETR
|
from .rtdetr import RTDETR
|
||||||
from .sam import SAM
|
from .sam import SAM
|
||||||
|
from .yolo import YOLO
|
||||||
|
|
||||||
__all__ = 'RTDETR', 'SAM' # allow simpler import
|
__all__ = 'YOLO', 'RTDETR', 'SAM' # allow simpler import
|
||||||
|
@ -1,111 +1,31 @@
|
|||||||
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
||||||
"""
|
|
||||||
FastSAM model interface.
|
|
||||||
|
|
||||||
Usage - Predict:
|
from pathlib import Path
|
||||||
from ultralytics import FastSAM
|
|
||||||
|
|
||||||
model = FastSAM('last.pt')
|
from ultralytics.engine.model import Model
|
||||||
results = model.predict('ultralytics/assets/bus.jpg')
|
|
||||||
"""
|
|
||||||
|
|
||||||
from ultralytics.cfg import get_cfg
|
|
||||||
from ultralytics.engine.exporter import Exporter
|
|
||||||
from ultralytics.engine.model import YOLO
|
|
||||||
from ultralytics.utils import DEFAULT_CFG, LOGGER, ROOT, is_git_dir
|
|
||||||
from ultralytics.utils.checks import check_imgsz
|
|
||||||
from ultralytics.utils.torch_utils import model_info, smart_inference_mode
|
|
||||||
|
|
||||||
from .predict import FastSAMPredictor
|
from .predict import FastSAMPredictor
|
||||||
|
from .val import FastSAMValidator
|
||||||
|
|
||||||
|
|
||||||
|
class FastSAM(Model):
|
||||||
|
"""
|
||||||
|
FastSAM model interface.
|
||||||
|
|
||||||
class FastSAM(YOLO):
|
Usage - Predict:
|
||||||
|
from ultralytics import FastSAM
|
||||||
|
|
||||||
|
model = FastSAM('last.pt')
|
||||||
|
results = model.predict('ultralytics/assets/bus.jpg')
|
||||||
|
"""
|
||||||
|
|
||||||
def __init__(self, model='FastSAM-x.pt'):
|
def __init__(self, model='FastSAM-x.pt'):
|
||||||
"""Call the __init__ method of the parent class (YOLO) with the updated default model"""
|
"""Call the __init__ method of the parent class (YOLO) with the updated default model"""
|
||||||
if model == 'FastSAM.pt':
|
if model == 'FastSAM.pt':
|
||||||
model = 'FastSAM-x.pt'
|
model = 'FastSAM-x.pt'
|
||||||
super().__init__(model=model)
|
assert Path(model).suffix != '.yaml', 'FastSAM models only support pre-trained models.'
|
||||||
# any additional initialization code for FastSAM
|
super().__init__(model=model, task='segment')
|
||||||
|
|
||||||
@smart_inference_mode()
|
|
||||||
def predict(self, source=None, stream=False, **kwargs):
|
|
||||||
"""
|
|
||||||
Perform prediction using the YOLO model.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
source (str | int | PIL | np.ndarray): The source of the image to make predictions on.
|
|
||||||
Accepts all source types accepted by the YOLO model.
|
|
||||||
stream (bool): Whether to stream the predictions or not. Defaults to False.
|
|
||||||
**kwargs : Additional keyword arguments passed to the predictor.
|
|
||||||
Check the 'configuration' section in the documentation for all available options.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
(List[ultralytics.engine.results.Results]): The prediction results.
|
|
||||||
"""
|
|
||||||
if source is None:
|
|
||||||
source = ROOT / 'assets' if is_git_dir() else 'https://ultralytics.com/images/bus.jpg'
|
|
||||||
LOGGER.warning(f"WARNING ⚠️ 'source' is missing. Using 'source={source}'.")
|
|
||||||
overrides = self.overrides.copy()
|
|
||||||
overrides['conf'] = 0.25
|
|
||||||
overrides.update(kwargs) # prefer kwargs
|
|
||||||
overrides['mode'] = kwargs.get('mode', 'predict')
|
|
||||||
assert overrides['mode'] in ['track', 'predict']
|
|
||||||
overrides['save'] = kwargs.get('save', False) # do not save by default if called in Python
|
|
||||||
self.predictor = FastSAMPredictor(overrides=overrides)
|
|
||||||
self.predictor.setup_model(model=self.model, verbose=False)
|
|
||||||
|
|
||||||
return self.predictor(source, stream=stream)
|
|
||||||
|
|
||||||
def train(self, **kwargs):
|
|
||||||
"""Function trains models but raises an error as FastSAM models do not support training."""
|
|
||||||
raise NotImplementedError("FastSAM models don't support training")
|
|
||||||
|
|
||||||
def val(self, **kwargs):
|
|
||||||
"""Run validation given dataset."""
|
|
||||||
overrides = dict(task='segment', mode='val')
|
|
||||||
overrides.update(kwargs) # prefer kwargs
|
|
||||||
args = get_cfg(cfg=DEFAULT_CFG, overrides=overrides)
|
|
||||||
args.imgsz = check_imgsz(args.imgsz, max_dim=1)
|
|
||||||
validator = FastSAM(args=args)
|
|
||||||
validator(model=self.model)
|
|
||||||
self.metrics = validator.metrics
|
|
||||||
return validator.metrics
|
|
||||||
|
|
||||||
@smart_inference_mode()
|
|
||||||
def export(self, **kwargs):
|
|
||||||
"""
|
|
||||||
Export model.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
**kwargs : Any other args accepted by the predictors. To see all args check 'configuration' section in docs
|
|
||||||
"""
|
|
||||||
overrides = dict(task='detect')
|
|
||||||
overrides.update(kwargs)
|
|
||||||
overrides['mode'] = 'export'
|
|
||||||
args = get_cfg(cfg=DEFAULT_CFG, overrides=overrides)
|
|
||||||
args.task = self.task
|
|
||||||
if args.imgsz == DEFAULT_CFG.imgsz:
|
|
||||||
args.imgsz = self.model.args['imgsz'] # use trained imgsz unless custom value is passed
|
|
||||||
if args.batch == DEFAULT_CFG.batch:
|
|
||||||
args.batch = 1 # default to 1 if not modified
|
|
||||||
return Exporter(overrides=args)(model=self.model)
|
|
||||||
|
|
||||||
def info(self, detailed=False, verbose=True):
|
|
||||||
"""
|
|
||||||
Logs model info.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
detailed (bool): Show detailed information about model.
|
|
||||||
verbose (bool): Controls verbosity.
|
|
||||||
"""
|
|
||||||
return model_info(self.model, detailed=detailed, verbose=verbose, imgsz=640)
|
|
||||||
|
|
||||||
def __call__(self, source=None, stream=False, **kwargs):
|
|
||||||
"""Calls the 'predict' function with given arguments to perform object detection."""
|
|
||||||
return self.predict(source, stream, **kwargs)
|
|
||||||
|
|
||||||
def __getattr__(self, attr):
|
@property
|
||||||
"""Raises error if object has no requested attribute."""
|
def task_map(self):
|
||||||
name = self.__class__.__name__
|
return {'segment': {'predictor': FastSAMPredictor, 'validator': FastSAMValidator}}
|
||||||
raise AttributeError(f"'{name}' object has no attribute '{attr}'. See valid attributes below.\n{self.__doc__}")
|
|
||||||
|
@ -0,0 +1,36 @@
|
|||||||
|
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
||||||
|
|
||||||
|
from ultralytics.engine.model import Model
|
||||||
|
from ultralytics.models import yolo # noqa
|
||||||
|
from ultralytics.nn.tasks import ClassificationModel, DetectionModel, PoseModel, SegmentationModel
|
||||||
|
|
||||||
|
|
||||||
|
class YOLO(Model):
|
||||||
|
"""
|
||||||
|
YOLO (You Only Look Once) object detection model.
|
||||||
|
"""
|
||||||
|
|
||||||
|
@property
|
||||||
|
def task_map(self):
|
||||||
|
"""Map head to model, trainer, validator, and predictor classes"""
|
||||||
|
return {
|
||||||
|
'classify': {
|
||||||
|
'model': ClassificationModel,
|
||||||
|
'trainer': yolo.classify.ClassificationTrainer,
|
||||||
|
'validator': yolo.classify.ClassificationValidator,
|
||||||
|
'predictor': yolo.classify.ClassificationPredictor, },
|
||||||
|
'detect': {
|
||||||
|
'model': DetectionModel,
|
||||||
|
'trainer': yolo.detect.DetectionTrainer,
|
||||||
|
'validator': yolo.detect.DetectionValidator,
|
||||||
|
'predictor': yolo.detect.DetectionPredictor, },
|
||||||
|
'segment': {
|
||||||
|
'model': SegmentationModel,
|
||||||
|
'trainer': yolo.segment.SegmentationTrainer,
|
||||||
|
'validator': yolo.segment.SegmentationValidator,
|
||||||
|
'predictor': yolo.segment.SegmentationPredictor, },
|
||||||
|
'pose': {
|
||||||
|
'model': PoseModel,
|
||||||
|
'trainer': yolo.pose.PoseTrainer,
|
||||||
|
'validator': yolo.pose.PoseValidator,
|
||||||
|
'predictor': yolo.pose.PosePredictor, }, }
|
Loading…
Reference in new issue