Add torch.Tensor
checks and pip badges (#3368)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
@ -301,11 +301,23 @@ class LoadTensor:
|
||||
self.paths = [getattr(im, 'filename', f'image{i}.jpg') for i, im in enumerate(im0)]
|
||||
|
||||
@staticmethod
|
||||
def _single_check(im):
|
||||
"""Validate and format an image to numpy array."""
|
||||
if len(im.shape) < 4:
|
||||
LOGGER.warning('WARNING ⚠️ torch.Tensor inputs should be BCHW format, i.e. shape(1,3,640,640).')
|
||||
im = im.unsqueeze(0)
|
||||
def _single_check(im, stride=32):
|
||||
"""Validate and format an image to torch.Tensor."""
|
||||
s = f'WARNING ⚠️ torch.Tensor inputs should be BCHW i.e. shape(1, 3, 640, 640) ' \
|
||||
f'divisible by stride {stride}. Input shape{tuple(im.shape)} is incompatible.'
|
||||
if len(im.shape) != 4:
|
||||
if len(im.shape) == 3:
|
||||
LOGGER.warning(s)
|
||||
im = im.unsqueeze(0)
|
||||
else:
|
||||
raise ValueError(s)
|
||||
if im.shape[2] % stride or im.shape[3] % stride:
|
||||
raise ValueError(s)
|
||||
if im.max() > 1.0:
|
||||
LOGGER.warning(f'WARNING ⚠️ torch.Tensor inputs should be normalized 0.0-1.0 but max value is {im.max()}. '
|
||||
f'Dividing input by 255.')
|
||||
im = im.float() / 255.0
|
||||
|
||||
return im
|
||||
|
||||
def __iter__(self):
|
||||
|
Reference in New Issue
Block a user