From 5381fc8a584e9fa3eabae4ac84571880244a0b32 Mon Sep 17 00:00:00 2001 From: Glenn Jocher Date: Sat, 24 Jun 2023 15:57:17 +0200 Subject: [PATCH] Warn `save` disabled for torch inference (#3361) --- ultralytics/yolo/data/dataloaders/stream_loaders.py | 9 +++++---- ultralytics/yolo/engine/predictor.py | 4 +++- 2 files changed, 8 insertions(+), 5 deletions(-) diff --git a/ultralytics/yolo/data/dataloaders/stream_loaders.py b/ultralytics/yolo/data/dataloaders/stream_loaders.py index 400cee6..ba18296 100644 --- a/ultralytics/yolo/data/dataloaders/stream_loaders.py +++ b/ultralytics/yolo/data/dataloaders/stream_loaders.py @@ -294,10 +294,11 @@ class LoadPilAndNumpy: class LoadTensor: - def __init__(self, imgs) -> None: - self.im0 = imgs - self.bs = imgs.shape[0] + def __init__(self, im0) -> None: + self.im0 = im0 + self.bs = im0.shape[0] self.mode = 'image' + self.paths = [getattr(im, 'filename', f'image{i}.jpg') for i, im in enumerate(im0)] def __iter__(self): """Returns an iterator object.""" @@ -309,7 +310,7 @@ class LoadTensor: if self.count == 1: raise StopIteration self.count += 1 - return None, self.im0, None, '' # self.paths, im, self.im0, None, '' + return self.paths, self.im0, None, '' def __len__(self): """Returns the batch size.""" diff --git a/ultralytics/yolo/engine/predictor.py b/ultralytics/yolo/engine/predictor.py index b71785c..8cb6b87 100644 --- a/ultralytics/yolo/engine/predictor.py +++ b/ultralytics/yolo/engine/predictor.py @@ -147,7 +147,6 @@ class BasePredictor: log_string = '' if len(im.shape) == 3: im = im[None] # expand for batch dim - self.seen += 1 if self.source_type.webcam or self.source_type.from_img: # batch_size >= 1 log_string += f'{idx}: ' frame = self.dataset.count @@ -251,11 +250,14 @@ class BasePredictor: # Visualize, save, write results n = len(im0s) for i in range(n): + self.seen += 1 self.results[i].speed = { 'preprocess': profilers[0].dt * 1E3 / n, 'inference': profilers[1].dt * 1E3 / n, 'postprocess': profilers[2].dt * 1E3 / n} if self.source_type.tensor: # skip write, show and plot operations if input is raw tensor + if self.args.save or self.args.save_txt or self.args.show: + LOGGER.warning('WARNING ⚠️ save, save_txt and show argument not enabled for tensor inference.') continue p, im0 = path[i], im0s[i].copy() p = Path(p)