ultralytics 8.0.122 Fix torch.Tensor inference (#3363)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: krzysztof.gonia <4281421+kgonia@users.noreply.github.com>
This commit is contained in:
Glenn Jocher
2023-06-25 01:36:07 +02:00
committed by GitHub
parent 51d8cfa9c3
commit 682c9ef70f
16 changed files with 471 additions and 154 deletions

View File

@ -295,11 +295,19 @@ class LoadPilAndNumpy:
class LoadTensor:
def __init__(self, im0) -> None:
self.im0 = im0
self.bs = im0.shape[0]
self.im0 = self._single_check(im0)
self.bs = self.im0.shape[0]
self.mode = 'image'
self.paths = [getattr(im, 'filename', f'image{i}.jpg') for i, im in enumerate(im0)]
@staticmethod
def _single_check(im):
"""Validate and format an image to numpy array."""
if len(im.shape) < 4:
LOGGER.warning('WARNING ⚠️ torch.Tensor inputs should be BCHW format, i.e. shape(1,3,640,640).')
im = im.unsqueeze(0)
return im
def __iter__(self):
"""Returns an iterator object."""
self.count = 0