Merge model() and model.predict() (#146)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
Glenn Jocher
2023-01-05 14:13:29 +01:00
committed by GitHub
parent 99275814f1
commit 46cb657b64
2 changed files with 10 additions and 29 deletions

View File

@ -1,19 +1,17 @@
from pathlib import Path
import torch
from ultralytics import YOLO
from ultralytics.yolo.utils import ROOT, SETTINGS
MODEL = Path(SETTINGS['weights_dir']) / 'yolov8n.pt'
CFG = 'yolov8n.yaml'
MODEL = Path(SETTINGS['weights_dir']) / 'yolov8n.pt'
SOURCE = ROOT / 'assets/bus.jpg'
def test_model_forward():
model = YOLO(CFG)
img = torch.rand(1, 3, 320, 320)
model.forward(img)
model(img)
model.predict(SOURCE)
model(SOURCE)
def test_model_info():
@ -43,15 +41,13 @@ def test_val():
def test_train_scratch():
model = YOLO(CFG)
model.train(data="coco128.yaml", epochs=1, imgsz=32)
img = torch.rand(1, 3, 320, 320)
model(img)
model(SOURCE)
def test_train_pretrained():
model = YOLO(MODEL)
model.train(data="coco128.yaml", epochs=1, imgsz=32)
img = torch.rand(1, 3, 320, 320)
model(img)
model(SOURCE)
def test_export_torchscript():
@ -100,11 +96,3 @@ def test_export_paddle():
def test_all_model_yamls():
for m in list((ROOT / 'yolo/v8/models').rglob('*.yaml')):
YOLO(m.name)
# def run_all_tests(): # do not name function test_...
# pass
#
#
# if __name__ == "__main__":
# run_all_tests()