Support prediction of list of sources, in-memory dataset and other improvements (#685)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
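The change lets a single predict call accept a heterogeneous list of sources. A minimal usage sketch, assuming a local image file named bus.jpg (the filename is illustrative; the URL and input types are taken from the tests below):

    from pathlib import Path

    import cv2
    import numpy as np
    from PIL import Image

    from ultralytics import YOLO

    model = YOLO('yolov8n.pt')
    sources = [
        'bus.jpg',                                    # filename (illustrative)
        Path('bus.jpg'),                              # pathlib.Path
        'https://ultralytics.com/images/zidane.jpg',  # URL
        cv2.imread('bus.jpg'),                        # OpenCV ndarray (BGR)
        Image.open('bus.jpg'),                        # PIL image
        np.zeros((320, 640, 3)),                      # raw numpy array
    ]
    results = model(sources)  # one result per source, so len(results) == 6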
@@ -3,10 +3,12 @@
from pathlib import Path

import cv2
import numpy as np
import torch
from PIL import Image

from ultralytics import YOLO
from ultralytics.yolo.data.build import load_inference_source
from ultralytics.yolo.utils import ROOT, SETTINGS

MODEL = Path(SETTINGS['weights_dir']) / 'yolov8n.pt'
@@ -40,6 +42,7 @@ def test_predict_dir():


def test_predict_img():

    model = YOLO(MODEL)
    img = Image.open(str(SOURCE))
    output = model(source=img, save=True, verbose=True)  # PIL
@@ -54,6 +57,16 @@ def test_predict_img():
    tens = torch.zeros(320, 640, 3)
    output = model(tens.numpy())
    assert len(output) == 1, "predict test failed"
    # test multiple source
    imgs = [
        SOURCE,  # filename
        Path(SOURCE),  # Path
        'https://ultralytics.com/images/zidane.jpg',  # URI
        cv2.imread(str(SOURCE)),  # OpenCV
        Image.open(SOURCE),  # PIL
        np.zeros((320, 640, 3))]  # numpy
    output = model(imgs)
    assert len(output) == 6, "predict test failed!"


def test_val():
@@ -129,3 +142,28 @@ def test_workflow():
    model.val()
    model.predict(SOURCE)
    model.export(format="onnx", opset=12)  # export a model to ONNX format


def test_predict_callback_and_setup():

    def on_predict_batch_end(predictor):
        # results -> List[batch_size]
        path, _, im0s, _, _ = predictor.batch
        # print('on_predict_batch_end', im0s[0].shape)
        bs = [predictor.bs for i in range(0, len(path))]
        predictor.results = zip(predictor.results, im0s, bs)

    model = YOLO("yolov8n.pt")
    model.add_callback("on_predict_batch_end", on_predict_batch_end)

    dataset = load_inference_source(source=SOURCE, transforms=model.transforms)
    bs = dataset.bs  # access predictor properties
    results = model.predict(dataset, stream=True)  # source already setup
    for _, (result, im0, bs) in enumerate(results):
        print('test_callback', im0.shape)
        print('test_callback', bs)
        boxes = result.boxes  # Boxes object for bbox outputs
        print(boxes)


test_predict_img()