|
|
@ -1,36 +1,23 @@
|
|
|
|
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
|
|
|
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import inspect
|
|
|
|
import sys
|
|
|
|
import sys
|
|
|
|
from pathlib import Path
|
|
|
|
from pathlib import Path
|
|
|
|
from typing import Union
|
|
|
|
from typing import Union
|
|
|
|
|
|
|
|
|
|
|
|
from ultralytics.cfg import get_cfg
|
|
|
|
from ultralytics.cfg import get_cfg
|
|
|
|
from ultralytics.engine.exporter import Exporter
|
|
|
|
from ultralytics.engine.exporter import Exporter
|
|
|
|
from ultralytics.models import yolo # noqa
|
|
|
|
from ultralytics.nn.tasks import attempt_load_one_weight, guess_model_task, nn, yaml_model_load
|
|
|
|
from ultralytics.nn.tasks import (ClassificationModel, DetectionModel, PoseModel, SegmentationModel,
|
|
|
|
|
|
|
|
attempt_load_one_weight, guess_model_task, nn, yaml_model_load)
|
|
|
|
|
|
|
|
from ultralytics.utils import (DEFAULT_CFG, DEFAULT_CFG_DICT, DEFAULT_CFG_KEYS, LOGGER, RANK, ROOT, callbacks,
|
|
|
|
from ultralytics.utils import (DEFAULT_CFG, DEFAULT_CFG_DICT, DEFAULT_CFG_KEYS, LOGGER, RANK, ROOT, callbacks,
|
|
|
|
is_git_dir, yaml_load)
|
|
|
|
is_git_dir, yaml_load)
|
|
|
|
from ultralytics.utils.checks import check_file, check_imgsz, check_pip_update_available, check_yaml
|
|
|
|
from ultralytics.utils.checks import check_file, check_imgsz, check_pip_update_available, check_yaml
|
|
|
|
from ultralytics.utils.downloads import GITHUB_ASSET_STEMS
|
|
|
|
from ultralytics.utils.downloads import GITHUB_ASSET_STEMS
|
|
|
|
from ultralytics.utils.torch_utils import smart_inference_mode
|
|
|
|
from ultralytics.utils.torch_utils import smart_inference_mode
|
|
|
|
|
|
|
|
|
|
|
|
# Map head to model, trainer, validator, and predictor classes
|
|
|
|
|
|
|
|
TASK_MAP = {
|
|
|
|
|
|
|
|
'classify': [
|
|
|
|
|
|
|
|
ClassificationModel, yolo.classify.ClassificationTrainer, yolo.classify.ClassificationValidator,
|
|
|
|
|
|
|
|
yolo.classify.ClassificationPredictor],
|
|
|
|
|
|
|
|
'detect':
|
|
|
|
|
|
|
|
[DetectionModel, yolo.detect.DetectionTrainer, yolo.detect.DetectionValidator, yolo.detect.DetectionPredictor],
|
|
|
|
|
|
|
|
'segment': [
|
|
|
|
|
|
|
|
SegmentationModel, yolo.segment.SegmentationTrainer, yolo.segment.SegmentationValidator,
|
|
|
|
|
|
|
|
yolo.segment.SegmentationPredictor],
|
|
|
|
|
|
|
|
'pose': [PoseModel, yolo.pose.PoseTrainer, yolo.pose.PoseValidator, yolo.pose.PosePredictor]}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class Model:
|
|
|
|
class YOLO:
|
|
|
|
|
|
|
|
"""
|
|
|
|
"""
|
|
|
|
YOLO (You Only Look Once) object detection model.
|
|
|
|
A base model class to unify apis for all the models.
|
|
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
Args:
|
|
|
|
model (str, Path): Path to the model file to load or create.
|
|
|
|
model (str, Path): Path to the model file to load or create.
|
|
|
@ -81,13 +68,13 @@ class YOLO:
|
|
|
|
self.predictor = None # reuse predictor
|
|
|
|
self.predictor = None # reuse predictor
|
|
|
|
self.model = None # model object
|
|
|
|
self.model = None # model object
|
|
|
|
self.trainer = None # trainer object
|
|
|
|
self.trainer = None # trainer object
|
|
|
|
self.task = None # task type
|
|
|
|
|
|
|
|
self.ckpt = None # if loaded from *.pt
|
|
|
|
self.ckpt = None # if loaded from *.pt
|
|
|
|
self.cfg = None # if loaded from *.yaml
|
|
|
|
self.cfg = None # if loaded from *.yaml
|
|
|
|
self.ckpt_path = None
|
|
|
|
self.ckpt_path = None
|
|
|
|
self.overrides = {} # overrides for trainer object
|
|
|
|
self.overrides = {} # overrides for trainer object
|
|
|
|
self.metrics = None # validation/training metrics
|
|
|
|
self.metrics = None # validation/training metrics
|
|
|
|
self.session = None # HUB session
|
|
|
|
self.session = None # HUB session
|
|
|
|
|
|
|
|
self.task = task # task type
|
|
|
|
model = str(model).strip() # strip spaces
|
|
|
|
model = str(model).strip() # strip spaces
|
|
|
|
|
|
|
|
|
|
|
|
# Check if Ultralytics HUB model from https://hub.ultralytics.com
|
|
|
|
# Check if Ultralytics HUB model from https://hub.ultralytics.com
|
|
|
@ -109,11 +96,6 @@ class YOLO:
|
|
|
|
"""Calls the 'predict' function with given arguments to perform object detection."""
|
|
|
|
"""Calls the 'predict' function with given arguments to perform object detection."""
|
|
|
|
return self.predict(source, stream, **kwargs)
|
|
|
|
return self.predict(source, stream, **kwargs)
|
|
|
|
|
|
|
|
|
|
|
|
def __getattr__(self, attr):
|
|
|
|
|
|
|
|
"""Raises error if object has no requested attribute."""
|
|
|
|
|
|
|
|
name = self.__class__.__name__
|
|
|
|
|
|
|
|
raise AttributeError(f"'{name}' object has no attribute '{attr}'. See valid attributes below.\n{self.__doc__}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@staticmethod
|
|
|
|
@staticmethod
|
|
|
|
def is_hub_model(model):
|
|
|
|
def is_hub_model(model):
|
|
|
|
"""Check if the provided model is a HUB model."""
|
|
|
|
"""Check if the provided model is a HUB model."""
|
|
|
@ -122,19 +104,21 @@ class YOLO:
|
|
|
|
[len(x) for x in model.split('_')] == [42, 20], # APIKEY_MODELID
|
|
|
|
[len(x) for x in model.split('_')] == [42, 20], # APIKEY_MODELID
|
|
|
|
len(model) == 20 and not Path(model).exists() and all(x not in model for x in './\\'))) # MODELID
|
|
|
|
len(model) == 20 and not Path(model).exists() and all(x not in model for x in './\\'))) # MODELID
|
|
|
|
|
|
|
|
|
|
|
|
def _new(self, cfg: str, task=None, verbose=True):
|
|
|
|
def _new(self, cfg: str, task=None, model=None, verbose=True):
|
|
|
|
"""
|
|
|
|
"""
|
|
|
|
Initializes a new model and infers the task type from the model definitions.
|
|
|
|
Initializes a new model and infers the task type from the model definitions.
|
|
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
Args:
|
|
|
|
cfg (str): model configuration file
|
|
|
|
cfg (str): model configuration file
|
|
|
|
task (str | None): model task
|
|
|
|
task (str | None): model task
|
|
|
|
|
|
|
|
model (BaseModel): Customized model.
|
|
|
|
verbose (bool): display model info on load
|
|
|
|
verbose (bool): display model info on load
|
|
|
|
"""
|
|
|
|
"""
|
|
|
|
cfg_dict = yaml_model_load(cfg)
|
|
|
|
cfg_dict = yaml_model_load(cfg)
|
|
|
|
self.cfg = cfg
|
|
|
|
self.cfg = cfg
|
|
|
|
self.task = task or guess_model_task(cfg_dict)
|
|
|
|
self.task = task or guess_model_task(cfg_dict)
|
|
|
|
self.model = TASK_MAP[self.task][0](cfg_dict, verbose=verbose and RANK == -1) # build model
|
|
|
|
model = model or self.smart_load('model')
|
|
|
|
|
|
|
|
self.model = model(cfg_dict, verbose=verbose and RANK == -1) # build model
|
|
|
|
self.overrides['model'] = self.cfg
|
|
|
|
self.overrides['model'] = self.cfg
|
|
|
|
|
|
|
|
|
|
|
|
# Below added to allow export from yamls
|
|
|
|
# Below added to allow export from yamls
|
|
|
@ -217,7 +201,7 @@ class YOLO:
|
|
|
|
self.model.fuse()
|
|
|
|
self.model.fuse()
|
|
|
|
|
|
|
|
|
|
|
|
@smart_inference_mode()
|
|
|
|
@smart_inference_mode()
|
|
|
|
def predict(self, source=None, stream=False, **kwargs):
|
|
|
|
def predict(self, source=None, stream=False, predictor=None, **kwargs):
|
|
|
|
"""
|
|
|
|
"""
|
|
|
|
Perform prediction using the YOLO model.
|
|
|
|
Perform prediction using the YOLO model.
|
|
|
|
|
|
|
|
|
|
|
@ -225,6 +209,7 @@ class YOLO:
|
|
|
|
source (str | int | PIL | np.ndarray): The source of the image to make predictions on.
|
|
|
|
source (str | int | PIL | np.ndarray): The source of the image to make predictions on.
|
|
|
|
Accepts all source types accepted by the YOLO model.
|
|
|
|
Accepts all source types accepted by the YOLO model.
|
|
|
|
stream (bool): Whether to stream the predictions or not. Defaults to False.
|
|
|
|
stream (bool): Whether to stream the predictions or not. Defaults to False.
|
|
|
|
|
|
|
|
predictor (BasePredictor): Customized predictor.
|
|
|
|
**kwargs : Additional keyword arguments passed to the predictor.
|
|
|
|
**kwargs : Additional keyword arguments passed to the predictor.
|
|
|
|
Check the 'configuration' section in the documentation for all available options.
|
|
|
|
Check the 'configuration' section in the documentation for all available options.
|
|
|
|
|
|
|
|
|
|
|
@ -236,6 +221,8 @@ class YOLO:
|
|
|
|
LOGGER.warning(f"WARNING ⚠️ 'source' is missing. Using 'source={source}'.")
|
|
|
|
LOGGER.warning(f"WARNING ⚠️ 'source' is missing. Using 'source={source}'.")
|
|
|
|
is_cli = (sys.argv[0].endswith('yolo') or sys.argv[0].endswith('ultralytics')) and any(
|
|
|
|
is_cli = (sys.argv[0].endswith('yolo') or sys.argv[0].endswith('ultralytics')) and any(
|
|
|
|
x in sys.argv for x in ('predict', 'track', 'mode=predict', 'mode=track'))
|
|
|
|
x in sys.argv for x in ('predict', 'track', 'mode=predict', 'mode=track'))
|
|
|
|
|
|
|
|
# Check prompts for SAM/FastSAM
|
|
|
|
|
|
|
|
prompts = kwargs.pop('prompts', None)
|
|
|
|
overrides = self.overrides.copy()
|
|
|
|
overrides = self.overrides.copy()
|
|
|
|
overrides['conf'] = 0.25
|
|
|
|
overrides['conf'] = 0.25
|
|
|
|
overrides.update(kwargs) # prefer kwargs
|
|
|
|
overrides.update(kwargs) # prefer kwargs
|
|
|
@ -245,12 +232,16 @@ class YOLO:
|
|
|
|
overrides['save'] = kwargs.get('save', False) # do not save by default if called in Python
|
|
|
|
overrides['save'] = kwargs.get('save', False) # do not save by default if called in Python
|
|
|
|
if not self.predictor:
|
|
|
|
if not self.predictor:
|
|
|
|
self.task = overrides.get('task') or self.task
|
|
|
|
self.task = overrides.get('task') or self.task
|
|
|
|
self.predictor = TASK_MAP[self.task][3](overrides=overrides, _callbacks=self.callbacks)
|
|
|
|
predictor = predictor or self.smart_load('predictor')
|
|
|
|
|
|
|
|
self.predictor = predictor(overrides=overrides, _callbacks=self.callbacks)
|
|
|
|
self.predictor.setup_model(model=self.model, verbose=is_cli)
|
|
|
|
self.predictor.setup_model(model=self.model, verbose=is_cli)
|
|
|
|
else: # only update args if predictor is already setup
|
|
|
|
else: # only update args if predictor is already setup
|
|
|
|
self.predictor.args = get_cfg(self.predictor.args, overrides)
|
|
|
|
self.predictor.args = get_cfg(self.predictor.args, overrides)
|
|
|
|
if 'project' in overrides or 'name' in overrides:
|
|
|
|
if 'project' in overrides or 'name' in overrides:
|
|
|
|
self.predictor.save_dir = self.predictor.get_save_dir()
|
|
|
|
self.predictor.save_dir = self.predictor.get_save_dir()
|
|
|
|
|
|
|
|
# Set prompts for SAM/FastSAM
|
|
|
|
|
|
|
|
if len and hasattr(self.predictor, 'set_prompts'):
|
|
|
|
|
|
|
|
self.predictor.set_prompts(prompts)
|
|
|
|
return self.predictor.predict_cli(source=source) if is_cli else self.predictor(source=source, stream=stream)
|
|
|
|
return self.predictor.predict_cli(source=source) if is_cli else self.predictor(source=source, stream=stream)
|
|
|
|
|
|
|
|
|
|
|
|
def track(self, source=None, stream=False, persist=False, **kwargs):
|
|
|
|
def track(self, source=None, stream=False, persist=False, **kwargs):
|
|
|
@ -277,12 +268,13 @@ class YOLO:
|
|
|
|
return self.predict(source=source, stream=stream, **kwargs)
|
|
|
|
return self.predict(source=source, stream=stream, **kwargs)
|
|
|
|
|
|
|
|
|
|
|
|
@smart_inference_mode()
|
|
|
|
@smart_inference_mode()
|
|
|
|
def val(self, data=None, **kwargs):
|
|
|
|
def val(self, data=None, validator=None, **kwargs):
|
|
|
|
"""
|
|
|
|
"""
|
|
|
|
Validate a model on a given dataset.
|
|
|
|
Validate a model on a given dataset.
|
|
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
Args:
|
|
|
|
data (str): The dataset to validate on. Accepts all formats accepted by yolo
|
|
|
|
data (str): The dataset to validate on. Accepts all formats accepted by yolo
|
|
|
|
|
|
|
|
validator (BaseValidator): Customized validator.
|
|
|
|
**kwargs : Any other args accepted by the validators. To see all args check 'configuration' section in docs
|
|
|
|
**kwargs : Any other args accepted by the validators. To see all args check 'configuration' section in docs
|
|
|
|
"""
|
|
|
|
"""
|
|
|
|
overrides = self.overrides.copy()
|
|
|
|
overrides = self.overrides.copy()
|
|
|
@ -295,11 +287,12 @@ class YOLO:
|
|
|
|
self.task = args.task
|
|
|
|
self.task = args.task
|
|
|
|
else:
|
|
|
|
else:
|
|
|
|
args.task = self.task
|
|
|
|
args.task = self.task
|
|
|
|
|
|
|
|
validator = validator or self.smart_load('validator')
|
|
|
|
if args.imgsz == DEFAULT_CFG.imgsz and not isinstance(self.model, (str, Path)):
|
|
|
|
if args.imgsz == DEFAULT_CFG.imgsz and not isinstance(self.model, (str, Path)):
|
|
|
|
args.imgsz = self.model.args['imgsz'] # use trained imgsz unless custom value is passed
|
|
|
|
args.imgsz = self.model.args['imgsz'] # use trained imgsz unless custom value is passed
|
|
|
|
args.imgsz = check_imgsz(args.imgsz, max_dim=1)
|
|
|
|
args.imgsz = check_imgsz(args.imgsz, max_dim=1)
|
|
|
|
|
|
|
|
|
|
|
|
validator = TASK_MAP[self.task][2](args=args, _callbacks=self.callbacks)
|
|
|
|
validator = validator(args=args, _callbacks=self.callbacks)
|
|
|
|
validator(model=self.model)
|
|
|
|
validator(model=self.model)
|
|
|
|
self.metrics = validator.metrics
|
|
|
|
self.metrics = validator.metrics
|
|
|
|
|
|
|
|
|
|
|
@ -349,11 +342,12 @@ class YOLO:
|
|
|
|
args.task = self.task
|
|
|
|
args.task = self.task
|
|
|
|
return Exporter(overrides=args, _callbacks=self.callbacks)(model=self.model)
|
|
|
|
return Exporter(overrides=args, _callbacks=self.callbacks)(model=self.model)
|
|
|
|
|
|
|
|
|
|
|
|
def train(self, **kwargs):
|
|
|
|
def train(self, trainer=None, **kwargs):
|
|
|
|
"""
|
|
|
|
"""
|
|
|
|
Trains the model on a given dataset.
|
|
|
|
Trains the model on a given dataset.
|
|
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
Args:
|
|
|
|
|
|
|
|
trainer (BaseTrainer, optional): Customized trainer.
|
|
|
|
**kwargs (Any): Any number of arguments representing the training configuration.
|
|
|
|
**kwargs (Any): Any number of arguments representing the training configuration.
|
|
|
|
"""
|
|
|
|
"""
|
|
|
|
self._check_is_pytorch_model()
|
|
|
|
self._check_is_pytorch_model()
|
|
|
@ -373,7 +367,8 @@ class YOLO:
|
|
|
|
if overrides.get('resume'):
|
|
|
|
if overrides.get('resume'):
|
|
|
|
overrides['resume'] = self.ckpt_path
|
|
|
|
overrides['resume'] = self.ckpt_path
|
|
|
|
self.task = overrides.get('task') or self.task
|
|
|
|
self.task = overrides.get('task') or self.task
|
|
|
|
self.trainer = TASK_MAP[self.task][1](overrides=overrides, _callbacks=self.callbacks)
|
|
|
|
trainer = trainer or self.smart_load('trainer')
|
|
|
|
|
|
|
|
self.trainer = trainer(overrides=overrides, _callbacks=self.callbacks)
|
|
|
|
if not overrides.get('resume'): # manually set model only if not resuming
|
|
|
|
if not overrides.get('resume'): # manually set model only if not resuming
|
|
|
|
self.trainer.model = self.trainer.get_model(weights=self.model if self.ckpt else None, cfg=self.model.yaml)
|
|
|
|
self.trainer.model = self.trainer.get_model(weights=self.model if self.ckpt else None, cfg=self.model.yaml)
|
|
|
|
self.model = self.trainer.model
|
|
|
|
self.model = self.trainer.model
|
|
|
@ -442,3 +437,27 @@ class YOLO:
|
|
|
|
"""Reset all registered callbacks."""
|
|
|
|
"""Reset all registered callbacks."""
|
|
|
|
for event in callbacks.default_callbacks.keys():
|
|
|
|
for event in callbacks.default_callbacks.keys():
|
|
|
|
self.callbacks[event] = [callbacks.default_callbacks[event][0]]
|
|
|
|
self.callbacks[event] = [callbacks.default_callbacks[event][0]]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def __getattr__(self, attr):
|
|
|
|
|
|
|
|
"""Raises error if object has no requested attribute."""
|
|
|
|
|
|
|
|
name = self.__class__.__name__
|
|
|
|
|
|
|
|
raise AttributeError(f"'{name}' object has no attribute '{attr}'. See valid attributes below.\n{self.__doc__}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def smart_load(self, key):
|
|
|
|
|
|
|
|
"""Load model/trainer/validator/predictor."""
|
|
|
|
|
|
|
|
try:
|
|
|
|
|
|
|
|
return self.task_map[self.task][key]
|
|
|
|
|
|
|
|
except Exception:
|
|
|
|
|
|
|
|
name = self.__class__.__name__
|
|
|
|
|
|
|
|
mode = inspect.stack()[1][3] # get the function name.
|
|
|
|
|
|
|
|
raise NotImplementedError(
|
|
|
|
|
|
|
|
f'WARNING ⚠️ `{name}` model does not support `{mode}` mode for `{self.task}` task yet.')
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@property
|
|
|
|
|
|
|
|
def task_map(self):
|
|
|
|
|
|
|
|
"""Map head to model, trainer, validator, and predictor classes
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
|
|
|
|
task_map (dict)
|
|
|
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
raise NotImplementedError('Please provide task map for your model!')
|
|
|
|