ultralytics 8.0.122
Fix torch.Tensor inference (#3363)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: krzysztof.gonia <4281421+kgonia@users.noreply.github.com>
@@ -6,6 +6,7 @@ import cv2
 import numpy as np
 import torch
 from PIL import Image
+from torchvision.transforms import ToTensor

 from ultralytics import RTDETR, YOLO
 from ultralytics.yolo.data.build import load_inference_source
@@ -70,7 +71,7 @@ def test_predict_img():
     # Test tensor inference
     im = cv2.imread(str(SOURCE))  # OpenCV
     t = cv2.resize(im, (32, 32))
-    t = torch.from_numpy(t.transpose((2, 0, 1)))
+    t = ToTensor()(t)
     t = torch.stack([t, t, t, t])
     results = model(t, visualize=True)
     assert len(results) == t.shape[0]
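For context, the updated test feeds a pre-batched torch.Tensor straight to the model and expects one result per image in the batch. A minimal sketch of that usage outside the test suite is shown below; the 'yolov8n.pt' checkpoint, the 'bus.jpg' path, and the 640x640 resize are illustrative assumptions, not values taken from this diff.

    import cv2
    import torch
    from torchvision.transforms import ToTensor

    from ultralytics import YOLO

    model = YOLO('yolov8n.pt')       # assumed checkpoint; any detection weights work

    im = cv2.imread('bus.jpg')       # HWC uint8 BGR array from OpenCV (assumed path)
    im = cv2.resize(im, (640, 640))  # illustrative size; the test itself uses 32x32
    t = ToTensor()(im)               # -> CHW float32, values scaled to [0, 1]
    batch = torch.stack([t, t])      # BCHW batch of two identical images

    results = model(batch)           # one Results object per image in the batch
    assert len(results) == batch.shape[0]

Swapping torch.from_numpy(t.transpose((2, 0, 1))) for ToTensor()(t) changes the dtype as well as the layout: from_numpy keeps uint8 values in 0-255, while ToTensor returns float32 values scaled to [0, 1], presumably the range the tensor-inference path being fixed here expects.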