ultralytics 8.0.40 TensorRT metadata and Results visualizer (#1014)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Laughing <61612323+Laughing-q@users.noreply.github.com>
Co-authored-by: Bogdan Gheorghe <112427971+bogdan-galileo@users.noreply.github.com>
Co-authored-by: Ayush Chaurasia <ayush.chaurarsia@gmail.com>
Co-authored-by: Jaap van de Loosdrecht <jaap@vdlmv.nl>
Co-authored-by: Noobtoss <96134731+Noobtoss@users.noreply.github.com>
Co-authored-by: nerdyespresso <106761627+nerdyespresso@users.noreply.github.com>
This commit is contained in:
Glenn Jocher
2023-02-17 20:06:06 +01:00
committed by GitHub
parent e799592718
commit 9047d737f4
40 changed files with 576 additions and 280 deletions

View File

@ -24,9 +24,12 @@ def check_class_names(names):
# Check class names. Map imagenet class codes to human-readable names if required. Convert lists to dicts.
if isinstance(names, list): # names is a list
names = dict(enumerate(names)) # convert to dict
if isinstance(names[0], str) and names[0].startswith('n0'): # imagenet class codes, i.e. 'n01440764'
map = yaml_load(ROOT / 'yolo/data/datasets/ImageNet.yaml')['map'] # human-readable names
names = {k: map[v] for k, v in names.items()}
if isinstance(names, dict):
if not all(isinstance(k, int) for k in names.keys()): # convert string keys to int, i.e. '0' to 0
names = {int(k): v for k, v in names.items()}
if isinstance(names[0], str) and names[0].startswith('n0'): # imagenet class codes, i.e. 'n01440764'
map = yaml_load(ROOT / 'yolo/data/datasets/ImageNet.yaml')['map'] # human-readable names
names = {k: map[v] for k, v in names.items()}
return names
@ -129,7 +132,6 @@ class AutoBackend(nn.Module):
if batch_dim.is_static:
batch_size = batch_dim.get_length()
executable_network = ie.compile_model(network, device_name="CPU") # device_name="MYRIAD" for Intel NCS2
stride, names = self._load_metadata(Path(w).with_suffix('.yaml')) # load metadata
elif engine: # TensorRT
LOGGER.info(f'Loading {w} for TensorRT inference...')
import tensorrt as trt # https://developer.nvidia.com/nvidia-tensorrt-download
@ -138,7 +140,14 @@ class AutoBackend(nn.Module):
device = torch.device('cuda:0')
Binding = namedtuple('Binding', ('name', 'dtype', 'shape', 'data', 'ptr'))
logger = trt.Logger(trt.Logger.INFO)
# Read file
with open(w, 'rb') as f, trt.Runtime(logger) as runtime:
# Read metadata length
meta_len = int.from_bytes(f.read(4), byteorder='little')
# Read metadata
meta = json.loads(f.read(meta_len).decode('utf-8'))
stride, names = int(meta['stride']), meta['names']
# Read engine
model = runtime.deserialize_cuda_engine(f.read())
context = model.create_execution_context()
bindings = OrderedDict()
@ -216,7 +225,7 @@ class AutoBackend(nn.Module):
meta = ast.literal_eval(model.read(meta_file).decode("utf-8"))
stride, names = int(meta['stride']), meta['names']
elif tfjs: # TF.js
raise NotImplementedError('ERROR: YOLOv8 TF.js inference is not supported')
raise NotImplementedError('YOLOv8 TF.js inference is not supported')
elif paddle: # PaddlePaddle
LOGGER.info(f'Loading {w} for PaddlePaddle inference...')
check_requirements('paddlepaddle-gpu' if cuda else 'paddlepaddle')
@ -245,7 +254,16 @@ class AutoBackend(nn.Module):
"See https://docs.ultralytics.com/tasks/detection/#export for help."
f"\n\n{EXPORT_FORMATS_TABLE}")
# class names
# Load external metadata YAML
if xml or saved_model or paddle:
metadata = Path(w).parent / 'metadata.yaml'
if metadata.exists():
metadata = yaml_load(metadata)
stride, names = int(metadata['stride']), metadata['names'] # load metadata
else:
LOGGER.warning(f"WARNING ⚠️ Metadata not found at '{metadata}'")
# Check names
if 'names' not in locals(): # names missing
names = yaml_load(check_yaml(data))['names'] if data else {i: f'class{i}' for i in range(999)} # assign
names = check_class_names(names)
@ -340,7 +358,7 @@ class AutoBackend(nn.Module):
if len(self.output_details) == 2: # segment
y = [y[1], np.transpose(y[0], (0, 3, 1, 2))]
y = [x if isinstance(x, np.ndarray) else x.numpy() for x in y]
y[0][..., :4] *= [w, h, w, h] # xywh normalized to pixels
# y[0][..., :4] *= [w, h, w, h] # xywh normalized to pixels
if isinstance(y, (list, tuple)):
return self.from_numpy(y[0]) if len(y) == 1 else [self.from_numpy(x) for x in y]
@ -394,18 +412,3 @@ class AutoBackend(nn.Module):
types[8] &= not types[9] # tflite &= not edgetpu
triton = not any(types) and all([any(s in url.scheme for s in ["http", "grpc"]), url.netloc])
return types + [triton]
@staticmethod
def _load_metadata(f=Path('path/to/meta.yaml')):
"""
Loads the metadata from a yaml file
Args:
f: The path to the metadata file.
"""
# Load metadata from meta.yaml if it exists
if f.exists():
d = yaml_load(f)
return d['stride'], d['names'] # assign stride, names
return None, None

View File

@ -248,6 +248,9 @@ class SegmentationModel(DetectionModel):
def __init__(self, cfg='yolov8n-seg.yaml', ch=3, nc=None, verbose=True):
super().__init__(cfg, ch, nc, verbose)
def _forward_augment(self, x):
raise NotImplementedError("WARNING ⚠️ SegmentationModel has not supported augment inference yet!")
class ClassificationModel(BaseModel):
# YOLOv8 classification model