ultralytics 8.0.122 Fix torch.Tensor inference (#3363)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: krzysztof.gonia <4281421+kgonia@users.noreply.github.com>
Glenn Jocher
2023-06-25 01:36:07 +02:00
committed by GitHub
parent 51d8cfa9c3
commit 682c9ef70f
16 changed files with 471 additions and 154 deletions

View File

@@ -295,11 +295,19 @@ class LoadPilAndNumpy:
class LoadTensor:
def __init__(self, im0) -> None:
self.im0 = im0
self.bs = im0.shape[0]
self.im0 = self._single_check(im0)
self.bs = self.im0.shape[0]
self.mode = 'image'
self.paths = [getattr(im, 'filename', f'image{i}.jpg') for i, im in enumerate(im0)]
@staticmethod
def _single_check(im):
"""Validate and format an image to numpy array."""
if len(im.shape) < 4:
LOGGER.warning('WARNING ⚠️ torch.Tensor inputs should be BCHW format, i.e. shape(1,3,640,640).')
im = im.unsqueeze(0)
return im
def __iter__(self):
"""Returns an iterator object."""
self.count = 0
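
Note: the new `_single_check` normalizes tensor inputs up front, so a 3-D CHW tensor now triggers the BCHW warning and gains a batch dimension instead of breaking batch-size bookkeeping downstream. A minimal usage sketch of the resulting behaviour, assuming a 'yolov8n.pt' checkpoint is available:

import torch
from ultralytics import YOLO

model = YOLO('yolov8n.pt')  # assumption: checkpoint available locally or downloadable

# BCHW float tensor, values already scaled to 0.0-1.0, passes straight through LoadTensor
results = model(torch.rand(1, 3, 640, 640))

# CHW tensor: _single_check logs the BCHW warning and unsqueezes it to shape (1, 3, 640, 640)
results = model(torch.rand(3, 640, 640))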

View File

@@ -116,21 +116,23 @@ class BasePredictor:
"""Prepares input image before inference.
Args:
im (torch.Tensor | List(np.ndarray)): (N, 3, h, w) for tensor, [(h, w, 3) x N] for list.
im (torch.Tensor | List(np.ndarray)): BCHW for tensor, [(HWC) x B] for list.
"""
if not isinstance(im, torch.Tensor):
not_tensor = not isinstance(im, torch.Tensor)
if not_tensor:
im = np.stack(self.pre_transform(im))
im = im[..., ::-1].transpose((0, 3, 1, 2)) # BGR to RGB, BHWC to BCHW, (n, 3, h, w)
im = np.ascontiguousarray(im) # contiguous
im = torch.from_numpy(im)
# NOTE: assuming im with (b, 3, h, w) if it's a tensor
img = im.to(self.device)
img = img.half() if self.model.fp16 else img.float() # uint8 to fp16/32
img /= 255 # 0 - 255 to 0.0 - 1.0
if not_tensor:
img /= 255 # 0 - 255 to 0.0 - 1.0
return img
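
Note: since `img /= 255` now runs only for non-tensor sources, a raw torch.Tensor is expected to arrive as a BCHW float already scaled to 0.0-1.0 (a uint8 tensor would be used unscaled). A sketch of preparing such an input from an OpenCV image; the file path is illustrative:

import cv2
import numpy as np
import torch

im = cv2.imread('bus.jpg')                                   # HWC, BGR, uint8; path is illustrative
im = cv2.resize(im, (640, 640))                              # tensor sources bypass pre_transform/letterboxing
im = np.ascontiguousarray(im[..., ::-1].transpose(2, 0, 1))  # BGR -> RGB, HWC -> CHW
im = torch.from_numpy(im).unsqueeze(0).float() / 255         # BCHW, 0.0-1.0; preprocess will not rescale it again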
def pre_transform(self, im):
"""Pre-tranform input image before inference.
"""Pre-transform input image before inference.
Args:
im (List(np.ndarray)): (N, 3, h, w) for tensor, [(h, w, 3) x N] for list.
@@ -147,7 +149,7 @@ class BasePredictor:
log_string = ''
if len(im.shape) == 3:
im = im[None] # expand for batch dim
if self.source_type.webcam or self.source_type.from_img: # batch_size >= 1
if self.source_type.webcam or self.source_type.from_img or self.source_type.tensor: # batch_size >= 1
log_string += f'{idx}: '
frame = self.dataset.count
else:
@@ -159,10 +161,11 @@ class BasePredictor:
log_string += result.verbose()
if self.args.save or self.args.show: # Add bbox to image
plot_args = dict(line_width=self.args.line_width,
boxes=self.args.boxes,
conf=self.args.show_conf,
labels=self.args.show_labels)
plot_args = {
'line_width': self.args.line_width,
'boxes': self.args.boxes,
'conf': self.args.show_conf,
'labels': self.args.show_labels}
if not self.args.retina_masks:
plot_args['im_gpu'] = im[idx]
self.plotted_img = result.plot(**plot_args)
@@ -214,17 +217,23 @@ class BasePredictor:
# Setup model
if not self.model:
self.setup_model(model)
# Setup source every time predict is called
self.setup_source(source if source is not None else self.args.source)
# Check if save_dir/ label file exists
if self.args.save or self.args.save_txt:
(self.save_dir / 'labels' if self.args.save_txt else self.save_dir).mkdir(parents=True, exist_ok=True)
# Warmup model
if not self.done_warmup:
self.model.warmup(imgsz=(1 if self.model.pt or self.model.triton else self.dataset.bs, 3, *self.imgsz))
self.done_warmup = True
# Checks
if self.source_type.tensor and (self.args.save or self.args.save_txt or self.args.show):
LOGGER.warning("WARNING ⚠️ 'save', 'save_txt' and 'show' arguments not enabled for torch.Tensor inference.")
self.seen, self.windows, self.batch, profilers = 0, [], None, (ops.Profile(), ops.Profile(), ops.Profile())
self.run_callbacks('on_predict_start')
for batch in self.dataset:
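
Note: hoisting the save/show check out of the per-image loop means the warning is logged once per call rather than once per image, and tensor sources simply skip writing and plotting. Sketch of the user-visible effect, again assuming a 'yolov8n.pt' checkpoint:

import torch
from ultralytics import YOLO

model = YOLO('yolov8n.pt')  # assumption: checkpoint available
# 'save', 'save_txt' and 'show' are ignored for tensor sources; a single warning is logged instead
results = model.predict(torch.rand(2, 3, 640, 640), save=True)
print(len(results))  # 2 -> one Results object per image in the batch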
@@ -255,11 +264,7 @@ class BasePredictor:
'preprocess': profilers[0].dt * 1E3 / n,
'inference': profilers[1].dt * 1E3 / n,
'postprocess': profilers[2].dt * 1E3 / n}
if self.source_type.tensor: # skip write, show and plot operations if input is raw tensor
if self.args.save or self.args.save_txt or self.args.show:
LOGGER.warning('WARNING ⚠️ save, save_txt and show argument not enabled for tensor inference.')
continue
p, im0 = path[i], im0s[i].copy()
p, im0 = path[i], None if self.source_type.tensor else im0s[i].copy()
p = Path(p)
if self.args.verbose or self.args.save or self.args.save_txt or self.args.show:
@@ -286,7 +291,7 @@ class BasePredictor:
if self.args.verbose and self.seen:
t = tuple(x.t / self.seen * 1E3 for x in profilers) # speeds per image
LOGGER.info(f'Speed: %.1fms preprocess, %.1fms inference, %.1fms postprocess per image at shape '
f'{(1, 3, *self.imgsz)}' % t)
f'{(1, 3, *im.shape[2:])}' % t)
if self.args.save or self.args.save_txt or self.args.save_crop:
nl = len(list(self.save_dir.glob('labels/*.txt'))) # number of labels
s = f"\n{nl} label{'s' * (nl > 1)} saved to {self.save_dir / 'labels'}" if self.args.save_txt else ''

View File

@@ -198,6 +198,10 @@ class Results(SimpleClass):
Returns:
(numpy.ndarray): A numpy array of the annotated image.
"""
if img is None and isinstance(self.orig_img, torch.Tensor):
LOGGER.warning('WARNING ⚠️ Results plotting is not supported for torch.Tensor image types.')
return
# Deprecation warn TODO: remove in 8.2
if 'show_conf' in kwargs:
deprecation_warn('show_conf', 'conf')
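
Note: plotting needs an HWC numpy image to draw on, so `plot()` now warns and returns None when the stored original image is a raw tensor; callers can pass their own image via the existing `img` argument instead. A sketch, where the blank canvas is purely illustrative:

import numpy as np
import torch
from ultralytics import YOLO

result = YOLO('yolov8n.pt')(torch.rand(1, 3, 640, 640))[0]  # assumption: checkpoint available

annotated = result.plot()  # warns and returns None because orig_img is a torch.Tensor
if annotated is None:
    canvas = np.zeros((640, 640, 3), dtype=np.uint8)  # any HWC BGR image to draw on
    annotated = result.plot(img=canvas)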
@@ -305,7 +309,7 @@ class Results(SimpleClass):
file_name (str | pathlib.Path): File name.
"""
if self.probs is not None:
LOGGER.warning('Warning: Classify task do not support `save_crop`.')
LOGGER.warning('WARNING ⚠️ Classify task do not support `save_crop`.')
return
if isinstance(save_dir, str):
save_dir = Path(save_dir)