Merge model()
and model.predict()
(#146)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
@ -1,19 +1,17 @@
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
|
||||
from ultralytics import YOLO
|
||||
from ultralytics.yolo.utils import ROOT, SETTINGS
|
||||
|
||||
MODEL = Path(SETTINGS['weights_dir']) / 'yolov8n.pt'
|
||||
CFG = 'yolov8n.yaml'
|
||||
MODEL = Path(SETTINGS['weights_dir']) / 'yolov8n.pt'
|
||||
SOURCE = ROOT / 'assets/bus.jpg'
|
||||
|
||||
|
||||
def test_model_forward():
|
||||
model = YOLO(CFG)
|
||||
img = torch.rand(1, 3, 320, 320)
|
||||
model.forward(img)
|
||||
model(img)
|
||||
model.predict(SOURCE)
|
||||
model(SOURCE)
|
||||
|
||||
|
||||
def test_model_info():
|
||||
@ -43,15 +41,13 @@ def test_val():
|
||||
def test_train_scratch():
|
||||
model = YOLO(CFG)
|
||||
model.train(data="coco128.yaml", epochs=1, imgsz=32)
|
||||
img = torch.rand(1, 3, 320, 320)
|
||||
model(img)
|
||||
model(SOURCE)
|
||||
|
||||
|
||||
def test_train_pretrained():
|
||||
model = YOLO(MODEL)
|
||||
model.train(data="coco128.yaml", epochs=1, imgsz=32)
|
||||
img = torch.rand(1, 3, 320, 320)
|
||||
model(img)
|
||||
model(SOURCE)
|
||||
|
||||
|
||||
def test_export_torchscript():
|
||||
@ -100,11 +96,3 @@ def test_export_paddle():
|
||||
def test_all_model_yamls():
|
||||
for m in list((ROOT / 'yolo/v8/models').rglob('*.yaml')):
|
||||
YOLO(m.name)
|
||||
|
||||
|
||||
# def run_all_tests(): # do not name function test_...
|
||||
# pass
|
||||
#
|
||||
#
|
||||
# if __name__ == "__main__":
|
||||
# run_all_tests()
|
||||
|
Reference in New Issue
Block a user