From 46cb657b6495bcdd250cb1aef1494a3636ed5be5 Mon Sep 17 00:00:00 2001 From: Glenn Jocher Date: Thu, 5 Jan 2023 14:13:29 +0100 Subject: [PATCH] Merge `model()` and `model.predict()` (#146) Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- tests/test_python.py | 24 ++++++------------------ ultralytics/yolo/engine/model.py | 15 ++++----------- 2 files changed, 10 insertions(+), 29 deletions(-) diff --git a/tests/test_python.py b/tests/test_python.py index 571a466..4809022 100644 --- a/tests/test_python.py +++ b/tests/test_python.py @@ -1,19 +1,17 @@ from pathlib import Path -import torch - from ultralytics import YOLO from ultralytics.yolo.utils import ROOT, SETTINGS -MODEL = Path(SETTINGS['weights_dir']) / 'yolov8n.pt' CFG = 'yolov8n.yaml' +MODEL = Path(SETTINGS['weights_dir']) / 'yolov8n.pt' +SOURCE = ROOT / 'assets/bus.jpg' def test_model_forward(): model = YOLO(CFG) - img = torch.rand(1, 3, 320, 320) - model.forward(img) - model(img) + model.predict(SOURCE) + model(SOURCE) def test_model_info(): @@ -43,15 +41,13 @@ def test_val(): def test_train_scratch(): model = YOLO(CFG) model.train(data="coco128.yaml", epochs=1, imgsz=32) - img = torch.rand(1, 3, 320, 320) - model(img) + model(SOURCE) def test_train_pretrained(): model = YOLO(MODEL) model.train(data="coco128.yaml", epochs=1, imgsz=32) - img = torch.rand(1, 3, 320, 320) - model(img) + model(SOURCE) def test_export_torchscript(): @@ -100,11 +96,3 @@ def test_export_paddle(): def test_all_model_yamls(): for m in list((ROOT / 'yolo/v8/models').rglob('*.yaml')): YOLO(m.name) - - -# def run_all_tests(): # do not name function test_... -# pass -# -# -# if __name__ == "__main__": -# run_all_tests() diff --git a/ultralytics/yolo/engine/model.py b/ultralytics/yolo/engine/model.py index ddb1187..278b7e9 100644 --- a/ultralytics/yolo/engine/model.py +++ b/ultralytics/yolo/engine/model.py @@ -1,12 +1,10 @@ from pathlib import Path -import torch - from ultralytics import yolo # noqa from ultralytics.nn.tasks import ClassificationModel, DetectionModel, SegmentationModel, attempt_load_weights from ultralytics.yolo.configs import get_config from ultralytics.yolo.engine.exporter import Exporter -from ultralytics.yolo.utils import DEFAULT_CONFIG, HELP_MSG, LOGGER, yaml_load +from ultralytics.yolo.utils import DEFAULT_CONFIG, LOGGER, yaml_load from ultralytics.yolo.utils.checks import check_imgsz, check_yaml from ultralytics.yolo.utils.torch_utils import guess_task_from_head, smart_inference_mode @@ -55,6 +53,9 @@ class YOLO: # Load or create new YOLO model {'.pt': self._load, '.yaml': self._new}[Path(model).suffix](model) + def __call__(self, source): + return self.predict(source) + def _new(self, cfg: str, verbose=True): """ Initializes a new model and infers the task type from the model definitions. @@ -211,14 +212,6 @@ class YOLO: return model_class, trainer_class, validator_class, predictor_class - @smart_inference_mode() - def __call__(self, imgs): - device = next(self.model.parameters()).device # get model device - return self.model(imgs.to(device)) - - def forward(self, imgs): - return self.__call__(imgs) - @staticmethod def _reset_ckpt_args(args): args.pop("device", None)