ultralytics 8.0.122 Fix `torch.Tensor` inference (#3363)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: krzysztof.gonia <4281421+kgonia@users.noreply.github.com>
This commit is contained in:
@ -295,11 +295,19 @@ class LoadPilAndNumpy:
|
||||
class LoadTensor:
|
||||
|
||||
def __init__(self, im0) -> None:
|
||||
self.im0 = im0
|
||||
self.bs = im0.shape[0]
|
||||
self.im0 = self._single_check(im0)
|
||||
self.bs = self.im0.shape[0]
|
||||
self.mode = 'image'
|
||||
self.paths = [getattr(im, 'filename', f'image{i}.jpg') for i, im in enumerate(im0)]
|
||||
|
||||
@staticmethod
|
||||
def _single_check(im):
|
||||
"""Validate and format an image to numpy array."""
|
||||
if len(im.shape) < 4:
|
||||
LOGGER.warning('WARNING ⚠️ torch.Tensor inputs should be BCHW format, i.e. shape(1,3,640,640).')
|
||||
im = im.unsqueeze(0)
|
||||
return im
|
||||
|
||||
def __iter__(self):
|
||||
"""Returns an iterator object."""
|
||||
self.count = 0
|
||||
|
@ -116,21 +116,23 @@ class BasePredictor:
|
||||
"""Prepares input image before inference.
|
||||
|
||||
Args:
|
||||
im (torch.Tensor | List(np.ndarray)): (N, 3, h, w) for tensor, [(h, w, 3) x N] for list.
|
||||
im (torch.Tensor | List(np.ndarray)): BCHW for tensor, [(HWC) x B] for list.
|
||||
"""
|
||||
if not isinstance(im, torch.Tensor):
|
||||
not_tensor = not isinstance(im, torch.Tensor)
|
||||
if not_tensor:
|
||||
im = np.stack(self.pre_transform(im))
|
||||
im = im[..., ::-1].transpose((0, 3, 1, 2)) # BGR to RGB, BHWC to BCHW, (n, 3, h, w)
|
||||
im = np.ascontiguousarray(im) # contiguous
|
||||
im = torch.from_numpy(im)
|
||||
# NOTE: assuming im with (b, 3, h, w) if it's a tensor
|
||||
|
||||
img = im.to(self.device)
|
||||
img = img.half() if self.model.fp16 else img.float() # uint8 to fp16/32
|
||||
img /= 255 # 0 - 255 to 0.0 - 1.0
|
||||
if not_tensor:
|
||||
img /= 255 # 0 - 255 to 0.0 - 1.0
|
||||
return img
|
||||
|
||||
def pre_transform(self, im):
|
||||
"""Pre-tranform input image before inference.
|
||||
"""Pre-transform input image before inference.
|
||||
|
||||
Args:
|
||||
im (List(np.ndarray)): (N, 3, h, w) for tensor, [(h, w, 3) x N] for list.
|
||||
@ -147,7 +149,7 @@ class BasePredictor:
|
||||
log_string = ''
|
||||
if len(im.shape) == 3:
|
||||
im = im[None] # expand for batch dim
|
||||
if self.source_type.webcam or self.source_type.from_img: # batch_size >= 1
|
||||
if self.source_type.webcam or self.source_type.from_img or self.source_type.tensor: # batch_size >= 1
|
||||
log_string += f'{idx}: '
|
||||
frame = self.dataset.count
|
||||
else:
|
||||
@ -159,10 +161,11 @@ class BasePredictor:
|
||||
log_string += result.verbose()
|
||||
|
||||
if self.args.save or self.args.show: # Add bbox to image
|
||||
plot_args = dict(line_width=self.args.line_width,
|
||||
boxes=self.args.boxes,
|
||||
conf=self.args.show_conf,
|
||||
labels=self.args.show_labels)
|
||||
plot_args = {
|
||||
'line_width': self.args.line_width,
|
||||
'boxes': self.args.boxes,
|
||||
'conf': self.args.show_conf,
|
||||
'labels': self.args.show_labels}
|
||||
if not self.args.retina_masks:
|
||||
plot_args['im_gpu'] = im[idx]
|
||||
self.plotted_img = result.plot(**plot_args)
|
||||
@ -214,17 +217,23 @@ class BasePredictor:
|
||||
# Setup model
|
||||
if not self.model:
|
||||
self.setup_model(model)
|
||||
|
||||
# Setup source every time predict is called
|
||||
self.setup_source(source if source is not None else self.args.source)
|
||||
|
||||
# Check if save_dir/ label file exists
|
||||
if self.args.save or self.args.save_txt:
|
||||
(self.save_dir / 'labels' if self.args.save_txt else self.save_dir).mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Warmup model
|
||||
if not self.done_warmup:
|
||||
self.model.warmup(imgsz=(1 if self.model.pt or self.model.triton else self.dataset.bs, 3, *self.imgsz))
|
||||
self.done_warmup = True
|
||||
|
||||
# Checks
|
||||
if self.source_type.tensor and (self.args.save or self.args.save_txt or self.args.show):
|
||||
LOGGER.warning("WARNING ⚠️ 'save', 'save_txt' and 'show' arguments not enabled for torch.Tensor inference.")
|
||||
|
||||
self.seen, self.windows, self.batch, profilers = 0, [], None, (ops.Profile(), ops.Profile(), ops.Profile())
|
||||
self.run_callbacks('on_predict_start')
|
||||
for batch in self.dataset:
|
||||
@ -255,11 +264,7 @@ class BasePredictor:
|
||||
'preprocess': profilers[0].dt * 1E3 / n,
|
||||
'inference': profilers[1].dt * 1E3 / n,
|
||||
'postprocess': profilers[2].dt * 1E3 / n}
|
||||
if self.source_type.tensor: # skip write, show and plot operations if input is raw tensor
|
||||
if self.args.save or self.args.save_txt or self.args.show:
|
||||
LOGGER.warning('WARNING ⚠️ save, save_txt and show argument not enabled for tensor inference.')
|
||||
continue
|
||||
p, im0 = path[i], im0s[i].copy()
|
||||
p, im0 = path[i], None if self.source_type.tensor else im0s[i].copy()
|
||||
p = Path(p)
|
||||
|
||||
if self.args.verbose or self.args.save or self.args.save_txt or self.args.show:
|
||||
@ -286,7 +291,7 @@ class BasePredictor:
|
||||
if self.args.verbose and self.seen:
|
||||
t = tuple(x.t / self.seen * 1E3 for x in profilers) # speeds per image
|
||||
LOGGER.info(f'Speed: %.1fms preprocess, %.1fms inference, %.1fms postprocess per image at shape '
|
||||
f'{(1, 3, *self.imgsz)}' % t)
|
||||
f'{(1, 3, *im.shape[2:])}' % t)
|
||||
if self.args.save or self.args.save_txt or self.args.save_crop:
|
||||
nl = len(list(self.save_dir.glob('labels/*.txt'))) # number of labels
|
||||
s = f"\n{nl} label{'s' * (nl > 1)} saved to {self.save_dir / 'labels'}" if self.args.save_txt else ''
|
||||
|
@ -198,6 +198,10 @@ class Results(SimpleClass):
|
||||
Returns:
|
||||
(numpy.ndarray): A numpy array of the annotated image.
|
||||
"""
|
||||
if img is None and isinstance(self.orig_img, torch.Tensor):
|
||||
LOGGER.warning('WARNING ⚠️ Results plotting is not supported for torch.Tensor image types.')
|
||||
return
|
||||
|
||||
# Deprecation warn TODO: remove in 8.2
|
||||
if 'show_conf' in kwargs:
|
||||
deprecation_warn('show_conf', 'conf')
|
||||
@ -305,7 +309,7 @@ class Results(SimpleClass):
|
||||
file_name (str | pathlib.Path): File name.
|
||||
"""
|
||||
if self.probs is not None:
|
||||
LOGGER.warning('Warning: Classify task do not support `save_crop`.')
|
||||
LOGGER.warning('WARNING ⚠️ Classify task do not support `save_crop`.')
|
||||
return
|
||||
if isinstance(save_dir, str):
|
||||
save_dir = Path(save_dir)
|
||||
|
Reference in New Issue
Block a user