ultralytics 8.0.122
Fix torch.Tensor
inference (#3363)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: krzysztof.gonia <4281421+kgonia@users.noreply.github.com>
This commit is contained in:
@ -295,11 +295,19 @@ class LoadPilAndNumpy:
|
||||
class LoadTensor:
|
||||
|
||||
def __init__(self, im0) -> None:
|
||||
self.im0 = im0
|
||||
self.bs = im0.shape[0]
|
||||
self.im0 = self._single_check(im0)
|
||||
self.bs = self.im0.shape[0]
|
||||
self.mode = 'image'
|
||||
self.paths = [getattr(im, 'filename', f'image{i}.jpg') for i, im in enumerate(im0)]
|
||||
|
||||
@staticmethod
|
||||
def _single_check(im):
|
||||
"""Validate and format an image to numpy array."""
|
||||
if len(im.shape) < 4:
|
||||
LOGGER.warning('WARNING ⚠️ torch.Tensor inputs should be BCHW format, i.e. shape(1,3,640,640).')
|
||||
im = im.unsqueeze(0)
|
||||
return im
|
||||
|
||||
def __iter__(self):
|
||||
"""Returns an iterator object."""
|
||||
self.count = 0
|
||||
|
Reference in New Issue
Block a user