Merge model()
and model.predict()
(#146)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
@ -1,12 +1,10 @@
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
|
||||
from ultralytics import yolo # noqa
|
||||
from ultralytics.nn.tasks import ClassificationModel, DetectionModel, SegmentationModel, attempt_load_weights
|
||||
from ultralytics.yolo.configs import get_config
|
||||
from ultralytics.yolo.engine.exporter import Exporter
|
||||
from ultralytics.yolo.utils import DEFAULT_CONFIG, HELP_MSG, LOGGER, yaml_load
|
||||
from ultralytics.yolo.utils import DEFAULT_CONFIG, LOGGER, yaml_load
|
||||
from ultralytics.yolo.utils.checks import check_imgsz, check_yaml
|
||||
from ultralytics.yolo.utils.torch_utils import guess_task_from_head, smart_inference_mode
|
||||
|
||||
@ -55,6 +53,9 @@ class YOLO:
|
||||
# Load or create new YOLO model
|
||||
{'.pt': self._load, '.yaml': self._new}[Path(model).suffix](model)
|
||||
|
||||
def __call__(self, source):
|
||||
return self.predict(source)
|
||||
|
||||
def _new(self, cfg: str, verbose=True):
|
||||
"""
|
||||
Initializes a new model and infers the task type from the model definitions.
|
||||
@ -211,14 +212,6 @@ class YOLO:
|
||||
|
||||
return model_class, trainer_class, validator_class, predictor_class
|
||||
|
||||
@smart_inference_mode()
|
||||
def __call__(self, imgs):
|
||||
device = next(self.model.parameters()).device # get model device
|
||||
return self.model(imgs.to(device))
|
||||
|
||||
def forward(self, imgs):
|
||||
return self.__call__(imgs)
|
||||
|
||||
@staticmethod
|
||||
def _reset_ckpt_args(args):
|
||||
args.pop("device", None)
|
||||
|
Reference in New Issue
Block a user