Warn `save` disabled for torch inference (#3361)

single_channel
Glenn Jocher 2 years ago committed by GitHub
parent 2ebd808b69
commit 5381fc8a58
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -294,10 +294,11 @@ class LoadPilAndNumpy:
class LoadTensor:
def __init__(self, imgs) -> None:
self.im0 = imgs
self.bs = imgs.shape[0]
def __init__(self, im0) -> None:
self.im0 = im0
self.bs = im0.shape[0]
self.mode = 'image'
self.paths = [getattr(im, 'filename', f'image{i}.jpg') for i, im in enumerate(im0)]
def __iter__(self):
"""Returns an iterator object."""
@ -309,7 +310,7 @@ class LoadTensor:
if self.count == 1:
raise StopIteration
self.count += 1
return None, self.im0, None, '' # self.paths, im, self.im0, None, ''
return self.paths, self.im0, None, ''
def __len__(self):
"""Returns the batch size."""

@ -147,7 +147,6 @@ class BasePredictor:
log_string = ''
if len(im.shape) == 3:
im = im[None] # expand for batch dim
self.seen += 1
if self.source_type.webcam or self.source_type.from_img: # batch_size >= 1
log_string += f'{idx}: '
frame = self.dataset.count
@ -251,11 +250,14 @@ class BasePredictor:
# Visualize, save, write results
n = len(im0s)
for i in range(n):
self.seen += 1
self.results[i].speed = {
'preprocess': profilers[0].dt * 1E3 / n,
'inference': profilers[1].dt * 1E3 / n,
'postprocess': profilers[2].dt * 1E3 / n}
if self.source_type.tensor: # skip write, show and plot operations if input is raw tensor
if self.args.save or self.args.save_txt or self.args.show:
LOGGER.warning('WARNING ⚠️ save, save_txt and show argument not enabled for tensor inference.')
continue
p, im0 = path[i], im0s[i].copy()
p = Path(p)

Loading…
Cancel
Save