|
|
|
@ -3,12 +3,13 @@
|
|
|
|
|
from pathlib import Path
|
|
|
|
|
|
|
|
|
|
from ultralytics import yolo # noqa
|
|
|
|
|
from ultralytics.nn.tasks import ClassificationModel, DetectionModel, SegmentationModel, attempt_load_one_weight
|
|
|
|
|
from ultralytics.nn.tasks import (ClassificationModel, DetectionModel, SegmentationModel, attempt_load_one_weight,
|
|
|
|
|
guess_model_task)
|
|
|
|
|
from ultralytics.yolo.cfg import get_cfg
|
|
|
|
|
from ultralytics.yolo.engine.exporter import Exporter
|
|
|
|
|
from ultralytics.yolo.utils import DEFAULT_CFG, LOGGER, callbacks, yaml_load
|
|
|
|
|
from ultralytics.yolo.utils.checks import check_yaml
|
|
|
|
|
from ultralytics.yolo.utils.torch_utils import guess_task_from_model_yaml, smart_inference_mode
|
|
|
|
|
from ultralytics.yolo.utils.torch_utils import smart_inference_mode
|
|
|
|
|
|
|
|
|
|
# Map head to model, trainer, validator, and predictor classes
|
|
|
|
|
MODEL_MAP = {
|
|
|
|
@ -73,9 +74,9 @@ class YOLO:
|
|
|
|
|
"""
|
|
|
|
|
cfg = check_yaml(cfg) # check YAML
|
|
|
|
|
cfg_dict = yaml_load(cfg, append_filename=True) # model dict
|
|
|
|
|
self.task = guess_task_from_model_yaml(cfg_dict)
|
|
|
|
|
self.task = guess_model_task(cfg_dict)
|
|
|
|
|
self.ModelClass, self.TrainerClass, self.ValidatorClass, self.PredictorClass = \
|
|
|
|
|
self._guess_ops_from_task(self.task)
|
|
|
|
|
self._assign_ops_from_task(self.task)
|
|
|
|
|
self.model = self.ModelClass(cfg_dict, verbose=verbose) # initialize
|
|
|
|
|
self.cfg = cfg
|
|
|
|
|
|
|
|
|
@ -92,7 +93,7 @@ class YOLO:
|
|
|
|
|
self.overrides = self.model.args
|
|
|
|
|
self._reset_ckpt_args(self.overrides)
|
|
|
|
|
self.ModelClass, self.TrainerClass, self.ValidatorClass, self.PredictorClass = \
|
|
|
|
|
self._guess_ops_from_task(self.task)
|
|
|
|
|
self._assign_ops_from_task(self.task)
|
|
|
|
|
|
|
|
|
|
def reset(self):
|
|
|
|
|
"""
|
|
|
|
@ -217,7 +218,7 @@ class YOLO:
|
|
|
|
|
"""
|
|
|
|
|
self.model.to(device)
|
|
|
|
|
|
|
|
|
|
def _guess_ops_from_task(self, task):
|
|
|
|
|
def _assign_ops_from_task(self, task):
|
|
|
|
|
model_class, train_lit, val_lit, pred_lit = MODEL_MAP[task]
|
|
|
|
|
# warning: eval is unsafe. Use with caution
|
|
|
|
|
trainer_class = eval(train_lit.replace("TYPE", f"{self.type}"))
|
|
|
|
|