Support prediction of list of sources, in-memory dataset and other improvements (#685)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
Ayush Chaurasia
2023-01-29 05:44:26 +05:30
committed by GitHub
parent a5410ed79e
commit 0609561549
9 changed files with 174 additions and 73 deletions

View File

@ -3,10 +3,12 @@
from pathlib import Path
import cv2
import numpy as np
import torch
from PIL import Image
from ultralytics import YOLO
from ultralytics.yolo.data.build import load_inference_source
from ultralytics.yolo.utils import ROOT, SETTINGS
MODEL = Path(SETTINGS['weights_dir']) / 'yolov8n.pt'
@ -40,6 +42,7 @@ def test_predict_dir():
def test_predict_img():
model = YOLO(MODEL)
img = Image.open(str(SOURCE))
output = model(source=img, save=True, verbose=True) # PIL
@ -54,6 +57,16 @@ def test_predict_img():
tens = torch.zeros(320, 640, 3)
output = model(tens.numpy())
assert len(output) == 1, "predict test failed"
# test multiple source
imgs = [
SOURCE, # filename
Path(SOURCE), # Path
'https://ultralytics.com/images/zidane.jpg', # URI
cv2.imread(str(SOURCE)), # OpenCV
Image.open(SOURCE), # PIL
np.zeros((320, 640, 3))] # numpy
output = model(imgs)
assert len(output) == 6, "predict test failed!"
def test_val():
@ -129,3 +142,28 @@ def test_workflow():
model.val()
model.predict(SOURCE)
model.export(format="onnx", opset=12) # export a model to ONNX format
def test_predict_callback_and_setup():
def on_predict_batch_end(predictor):
# results -> List[batch_size]
path, _, im0s, _, _ = predictor.batch
# print('on_predict_batch_end', im0s[0].shape)
bs = [predictor.bs for i in range(0, len(path))]
predictor.results = zip(predictor.results, im0s, bs)
model = YOLO("yolov8n.pt")
model.add_callback("on_predict_batch_end", on_predict_batch_end)
dataset = load_inference_source(source=SOURCE, transforms=model.transforms)
bs = dataset.bs # access predictor properties
results = model.predict(dataset, stream=True) # source already setup
for _, (result, im0, bs) in enumerate(results):
print('test_callback', im0.shape)
print('test_callback', bs)
boxes = result.boxes # Boxes object for bbox outputs
print(boxes)
test_predict_img()