ultralytics 8.0.122 Fix torch.Tensor inference (#3363)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: krzysztof.gonia <4281421+kgonia@users.noreply.github.com>
This commit is contained in:
Glenn Jocher
2023-06-25 01:36:07 +02:00
committed by GitHub
parent 51d8cfa9c3
commit 682c9ef70f
16 changed files with 471 additions and 154 deletions

View File

@ -6,6 +6,7 @@ import cv2
import numpy as np
import torch
from PIL import Image
from torchvision.transforms import ToTensor
from ultralytics import RTDETR, YOLO
from ultralytics.yolo.data.build import load_inference_source
@ -70,7 +71,7 @@ def test_predict_img():
# Test tensor inference
im = cv2.imread(str(SOURCE)) # OpenCV
t = cv2.resize(im, (32, 32))
t = torch.from_numpy(t.transpose((2, 0, 1)))
t = ToTensor()(t)
t = torch.stack([t, t, t, t])
results = model(t, visualize=True)
assert len(results) == t.shape[0]