|
|
|
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
|
|
|
|
|
|
|
import ast
|
|
|
|
import contextlib
|
|
|
|
import json
|
|
|
|
import os
|
|
|
|
import platform
|
|
|
|
import zipfile
|
|
|
|
from collections import OrderedDict, namedtuple
|
|
|
|
from pathlib import Path
|
|
|
|
from urllib.parse import urlparse
|
|
|
|
|
|
|
|
import cv2
|
|
|
|
import numpy as np
|
|
|
|
import torch
|
|
|
|
import torch.nn as nn
|
|
|
|
from PIL import Image
|
|
|
|
|
|
|
|
from ultralytics.yolo.utils import ARM64, LINUX, LOGGER, ROOT, yaml_load
|
|
|
|
from ultralytics.yolo.utils.checks import check_requirements, check_suffix, check_version, check_yaml
|
|
|
|
from ultralytics.yolo.utils.downloads import attempt_download_asset, is_url
|
|
|
|
from ultralytics.yolo.utils.ops import xywh2xyxy
|
|
|
|
|
|
|
|
|
|
|
|
def check_class_names(names):
    """
    Check class names. Map imagenet class codes to human-readable names if required. Convert lists to dicts.

    Args:
        names (list | dict | Any): Class names, either a list (position -> name) or a dict whose keys are ints or
            int-like strings (e.g. '0'). Any other type is returned unchanged.

    Returns:
        (dict | Any): For list/dict input, a dict mapping int class index -> str class name (ImageNet 'n0...' codes
            replaced by human-readable names). Otherwise the input unchanged.

    Raises:
        KeyError: If the dict keys are not exactly the contiguous class indices 0..n-1.
    """
    if isinstance(names, list):  # names is a list
        names = dict(enumerate(names))  # convert to dict
    if isinstance(names, dict):
        # Convert 1) string keys to int, i.e. '0' to 0, and non-string values to strings, i.e. True to 'True'
        names = {int(k): str(v) for k, v in names.items()}
        if not names:  # nothing to validate or map; min()/max() would raise on an empty dict
            return names
        n = len(names)
        lo, hi = min(names), max(names)
        # n distinct int keys that all lie in [0, n-1] are exactly the indices 0..n-1
        if lo < 0 or hi >= n:
            raise KeyError(f'{n}-class dataset requires class indices 0-{n - 1}, but you have invalid class indices '
                           f'{lo}-{hi} defined in your dataset YAML.')
        if names[0].startswith('n0'):  # imagenet class codes, i.e. 'n01440764' (values are str after conversion)
            names_map = yaml_load(ROOT / 'datasets/ImageNet.yaml')['map']  # human-readable names
            names = {k: names_map[v] for k, v in names.items()}
    return names
|
|
|
|
|
|
|
|
|
|
|
|
class AutoBackend(nn.Module):
    """
    Multi-backend inference wrapper for Ultralytics YOLO models.

    The backend is chosen from the weights path by `_model_type()`. At the end of `__init__` every local variable is
    copied onto the instance (`self.__dict__.update(locals())`). This is how `forward()` reaches backend handles that
    were created as locals, such as `self.session`, `self.interpreter`, `self.net` or `self.tf`.
    """

    def __init__(self,
                 weights='yolov8n.pt',
                 device=torch.device('cpu'),
                 dnn=False,
                 data=None,
                 fp16=False,
                 fuse=True,
                 verbose=True):
        """
        MultiBackend class for python inference on various platforms using Ultralytics YOLO.

        Args:
            weights (str | list | torch.nn.Module): The path to the weights file, a list of weight paths (PyTorch
                only), or an in-memory model. Default: 'yolov8n.pt'
            device (torch.device): The device to run the model on.
            dnn (bool): Use OpenCV DNN module for inference if True, defaults to False.
            data (str | Path | optional): Additional data.yaml file for class names.
            fp16 (bool): If True, use half precision. Default: False
            fuse (bool): Whether to fuse the model or not. Default: True
            verbose (bool): Whether to run in verbose mode or not. Default: True

        Supported formats and their naming conventions:
            | Format                | Suffix           |
            |-----------------------|------------------|
            | PyTorch               | *.pt             |
            | TorchScript           | *.torchscript    |
            | ONNX Runtime          | *.onnx           |
            | ONNX OpenCV DNN       | *.onnx dnn=True  |
            | OpenVINO              | *.xml            |
            | CoreML                | *.mlmodel        |
            | TensorRT              | *.engine         |
            | TensorFlow SavedModel | *_saved_model    |
            | TensorFlow GraphDef   | *.pb             |
            | TensorFlow Lite       | *.tflite         |
            | TensorFlow Edge TPU   | *_edgetpu.tflite |
            | PaddlePaddle          | *_paddle_model   |
            | ncnn                  | *_ncnn_model     |

        Raises:
            NotImplementedError: For TF.js models.
            TypeError: If the weights are not in a recognized format.
        """
        super().__init__()
        w = str(weights[0] if isinstance(weights, list) else weights)
        nn_module = isinstance(weights, torch.nn.Module)
        pt, jit, onnx, xml, engine, coreml, saved_model, pb, tflite, edgetpu, tfjs, paddle, ncnn, triton = \
            self._model_type(w)
        fp16 &= pt or jit or onnx or engine or nn_module or triton  # FP16
        nhwc = coreml or saved_model or pb or tflite or edgetpu  # BHWC formats (vs torch BCWH)
        stride = 32  # default stride
        model, metadata = None, None
        cuda = torch.cuda.is_available() and device.type != 'cpu'  # use CUDA
        if not (pt or triton or nn_module):
            w = attempt_download_asset(w)  # download if not local

        # NOTE: special case: in-memory pytorch model
        if nn_module:
            model = weights.to(device)
            model = model.fuse(verbose=verbose) if fuse else model
            if hasattr(model, 'kpt_shape'):
                kpt_shape = model.kpt_shape  # pose-only
            stride = max(int(model.stride.max()), 32)  # model stride
            names = model.module.names if hasattr(model, 'module') else model.names  # get class names
            model.half() if fp16 else model.float()
            self.model = model  # explicitly assign for to(), cpu(), cuda(), half()
            pt = True
        elif pt:  # PyTorch
            from ultralytics.nn.tasks import attempt_load_weights
            model = attempt_load_weights(weights if isinstance(weights, list) else w,
                                         device=device,
                                         inplace=True,
                                         fuse=fuse)
            if hasattr(model, 'kpt_shape'):
                kpt_shape = model.kpt_shape  # pose-only
            stride = max(int(model.stride.max()), 32)  # model stride
            names = model.module.names if hasattr(model, 'module') else model.names  # get class names
            model.half() if fp16 else model.float()
            self.model = model  # explicitly assign for to(), cpu(), cuda(), half()
        elif jit:  # TorchScript
            LOGGER.info(f'Loading {w} for TorchScript inference...')
            extra_files = {'config.txt': ''}  # model metadata
            model = torch.jit.load(w, _extra_files=extra_files, map_location=device)
            model.half() if fp16 else model.float()
            if extra_files['config.txt']:  # load metadata dict
                metadata = json.loads(extra_files['config.txt'], object_hook=lambda x: dict(x.items()))
        elif dnn:  # ONNX OpenCV DNN
            LOGGER.info(f'Loading {w} for ONNX OpenCV DNN inference...')
            check_requirements('opencv-python>=4.5.4')
            net = cv2.dnn.readNetFromONNX(w)
        elif onnx:  # ONNX Runtime
            LOGGER.info(f'Loading {w} for ONNX Runtime inference...')
            check_requirements(('onnx', 'onnxruntime-gpu' if cuda else 'onnxruntime'))
            import onnxruntime
            providers = ['CUDAExecutionProvider', 'CPUExecutionProvider'] if cuda else ['CPUExecutionProvider']
            session = onnxruntime.InferenceSession(w, providers=providers)
            output_names = [x.name for x in session.get_outputs()]
            metadata = session.get_modelmeta().custom_metadata_map  # metadata
        elif xml:  # OpenVINO
            LOGGER.info(f'Loading {w} for OpenVINO inference...')
            check_requirements('openvino')  # requires openvino-dev: https://pypi.org/project/openvino-dev/
            from openvino.runtime import Core, Layout, get_batch  # noqa
            ie = Core()
            w = Path(w)
            if not w.is_file():  # if not *.xml
                w = next(w.glob('*.xml'))  # get *.xml file from *_openvino_model dir
            network = ie.read_model(model=str(w), weights=w.with_suffix('.bin'))
            if network.get_parameters()[0].get_layout().empty:
                network.get_parameters()[0].set_layout(Layout('NCHW'))
            batch_dim = get_batch(network)
            if batch_dim.is_static:
                batch_size = batch_dim.get_length()
            executable_network = ie.compile_model(network, device_name='CPU')  # device_name="MYRIAD" for NCS2
            metadata = w.parent / 'metadata.yaml'
        elif engine:  # TensorRT
            LOGGER.info(f'Loading {w} for TensorRT inference...')
            try:
                import tensorrt as trt  # noqa https://developer.nvidia.com/nvidia-tensorrt-download
            except ImportError:
                if LINUX:
                    check_requirements('nvidia-tensorrt', cmds='-U --index-url https://pypi.ngc.nvidia.com')
                import tensorrt as trt  # noqa
            check_version(trt.__version__, '7.0.0', hard=True)  # require tensorrt>=7.0.0
            if device.type == 'cpu':
                device = torch.device('cuda:0')
            Binding = namedtuple('Binding', ('name', 'dtype', 'shape', 'data', 'ptr'))
            logger = trt.Logger(trt.Logger.INFO)
            # Read file: 4-byte little-endian metadata length, JSON metadata, then the serialized engine
            with open(w, 'rb') as f, trt.Runtime(logger) as runtime:
                meta_len = int.from_bytes(f.read(4), byteorder='little')  # read metadata length
                metadata = json.loads(f.read(meta_len).decode('utf-8'))  # read metadata
                model = runtime.deserialize_cuda_engine(f.read())  # read engine
            context = model.create_execution_context()
            bindings = OrderedDict()
            output_names = []
            fp16 = False  # default updated below
            dynamic = False
            for i in range(model.num_bindings):
                name = model.get_binding_name(i)
                dtype = trt.nptype(model.get_binding_dtype(i))
                if model.binding_is_input(i):
                    if -1 in tuple(model.get_binding_shape(i)):  # dynamic
                        dynamic = True
                        context.set_binding_shape(i, tuple(model.get_profile_shape(0, i)[2]))
                    if dtype == np.float16:
                        fp16 = True
                else:  # output
                    output_names.append(name)
                shape = tuple(context.get_binding_shape(i))
                im = torch.from_numpy(np.empty(shape, dtype=dtype)).to(device)
                bindings[name] = Binding(name, dtype, shape, im, int(im.data_ptr()))
            binding_addrs = OrderedDict((n, d.ptr) for n, d in bindings.items())
            batch_size = bindings['images'].shape[0]  # if dynamic, this is instead max batch size
        elif coreml:  # CoreML
            LOGGER.info(f'Loading {w} for CoreML inference...')
            import coremltools as ct
            model = ct.models.MLModel(w)
            metadata = dict(model.user_defined_metadata)
        elif saved_model:  # TF SavedModel
            LOGGER.info(f'Loading {w} for TensorFlow SavedModel inference...')
            import tensorflow as tf
            keras = False  # assume TF1 saved_model
            model = tf.keras.models.load_model(w) if keras else tf.saved_model.load(w)
            metadata = Path(w) / 'metadata.yaml'
        elif pb:  # GraphDef https://www.tensorflow.org/guide/migrate#a_graphpb_or_graphpbtxt
            LOGGER.info(f'Loading {w} for TensorFlow GraphDef inference...')
            import tensorflow as tf

            from ultralytics.yolo.engine.exporter import gd_outputs

            def wrap_frozen_graph(gd, inputs, outputs):
                """Wrap frozen graphs for deployment."""
                x = tf.compat.v1.wrap_function(lambda: tf.compat.v1.import_graph_def(gd, name=''), [])  # wrapped
                ge = x.graph.as_graph_element
                return x.prune(tf.nest.map_structure(ge, inputs), tf.nest.map_structure(ge, outputs))

            gd = tf.Graph().as_graph_def()  # TF GraphDef
            with open(w, 'rb') as f:
                gd.ParseFromString(f.read())
            frozen_func = wrap_frozen_graph(gd, inputs='x:0', outputs=gd_outputs(gd))
        elif tflite or edgetpu:  # https://www.tensorflow.org/lite/guide/python#install_tensorflow_lite_for_python
            try:  # https://coral.ai/docs/edgetpu/tflite-python/#update-existing-tf-lite-code-for-the-edge-tpu
                from tflite_runtime.interpreter import Interpreter, load_delegate
            except ImportError:
                import tensorflow as tf
                Interpreter, load_delegate = tf.lite.Interpreter, tf.lite.experimental.load_delegate
            if edgetpu:  # TF Edge TPU https://coral.ai/software/#edgetpu-runtime
                LOGGER.info(f'Loading {w} for TensorFlow Lite Edge TPU inference...')
                delegate = {
                    'Linux': 'libedgetpu.so.1',
                    'Darwin': 'libedgetpu.1.dylib',
                    'Windows': 'edgetpu.dll'}[platform.system()]
                interpreter = Interpreter(model_path=w, experimental_delegates=[load_delegate(delegate)])
            else:  # TFLite
                LOGGER.info(f'Loading {w} for TensorFlow Lite inference...')
                interpreter = Interpreter(model_path=w)  # load TFLite model
            interpreter.allocate_tensors()  # allocate
            input_details = interpreter.get_input_details()  # inputs
            output_details = interpreter.get_output_details()  # outputs
            # Load metadata: the .tflite file may also be a zip archive whose first member holds a dict literal
            with contextlib.suppress(zipfile.BadZipFile):
                with zipfile.ZipFile(w, 'r') as model:
                    meta_file = model.namelist()[0]
                    metadata = ast.literal_eval(model.read(meta_file).decode('utf-8'))
        elif tfjs:  # TF.js
            raise NotImplementedError('YOLOv8 TF.js inference is not currently supported.')
        elif paddle:  # PaddlePaddle
            LOGGER.info(f'Loading {w} for PaddlePaddle inference...')
            check_requirements('paddlepaddle-gpu' if cuda else 'paddlepaddle')
            import paddle.inference as pdi  # noqa
            w = Path(w)
            if not w.is_file():  # if not *.pdmodel
                w = next(w.rglob('*.pdmodel'))  # get *.pdmodel file from *_paddle_model dir
            config = pdi.Config(str(w), str(w.with_suffix('.pdiparams')))
            if cuda:
                config.enable_use_gpu(memory_pool_init_size_mb=2048, device_id=0)
            predictor = pdi.create_predictor(config)
            input_handle = predictor.get_input_handle(predictor.get_input_names()[0])
            output_names = predictor.get_output_names()
            metadata = w.parents[1] / 'metadata.yaml'
        elif ncnn:  # ncnn
            LOGGER.info(f'Loading {w} for ncnn inference...')
            check_requirements('git+https://github.com/Tencent/ncnn.git' if ARM64 else 'ncnn')  # requires NCNN
            import ncnn as pyncnn
            net = pyncnn.Net()
            net.opt.num_threads = os.cpu_count()
            net.opt.use_vulkan_compute = cuda
            w = Path(w)
            if not w.is_file():  # if not *.param
                w = next(w.glob('*.param'))  # get *.param file from *_ncnn_model dir
            net.load_param(str(w))
            net.load_model(str(w.with_suffix('.bin')))
            metadata = w.parent / 'metadata.yaml'
        elif triton:  # NVIDIA Triton Inference Server
            LOGGER.info('Triton Inference Server not supported...')
            # NOTE(review): `model` stays None here, so forward()/warmup() on a Triton URL will fail when calling it;
            # consider raising NotImplementedError like the TF.js branch until the TODO below is implemented.
            # TODO:
            #   check_requirements('tritonclient[all]')
            #   from utils.triton import TritonRemoteModel
            #   model = TritonRemoteModel(url=w)
            #   nhwc = model.runtime.startswith("tensorflow")
        else:
            from ultralytics.yolo.engine.exporter import export_formats
            raise TypeError(f"model='{w}' is not a supported model format. "
                            'See https://docs.ultralytics.com/modes/predict for help.'
                            f'\n\n{export_formats()}')

        # Load external metadata YAML
        if isinstance(metadata, (str, Path)) and Path(metadata).exists():
            metadata = yaml_load(metadata)
        # A Path to a missing metadata.yaml is truthy but is not a dict, so require an actual non-empty dict here;
        # otherwise the missing-metadata warning below is reached instead of crashing on `.items()`.
        if isinstance(metadata, dict) and metadata:
            self._parse_metadata(metadata)  # cast ints and safely parse stringified literals, in place
            stride = metadata['stride']
            task = metadata['task']
            batch = metadata['batch']
            imgsz = metadata['imgsz']
            names = metadata['names']
            kpt_shape = metadata.get('kpt_shape')
        elif not (pt or triton or nn_module):
            LOGGER.warning(f"WARNING ⚠️ Metadata not found for 'model={weights}'")

        # Check names
        if 'names' not in locals():  # names missing
            names = self._apply_default_class_names(data)
        names = check_class_names(names)

        # Expose every local (backend handles, flags such as self.onnx/self.fp16, names, stride, ...) as attributes
        self.__dict__.update(locals())  # assign all variables to self

    def forward(self, im, augment=False, visualize=False):
        """
        Runs inference on the YOLOv8 MultiBackend model.

        Args:
            im (torch.Tensor): The image tensor to perform inference on, in BCHW layout.
            augment (bool): whether to perform data augmentation during inference, defaults to False
            visualize (bool): whether to visualize the output predictions, defaults to False

        Returns:
            (torch.Tensor | list): The single backend output, or a list of outputs when the backend produced several.
                NumPy arrays are converted to tensors on `self.device`; other objects are passed through.
        """
        b, ch, h, w = im.shape  # batch, channel, height, width
        if self.fp16 and im.dtype != torch.float16:
            im = im.half()  # to FP16
        if self.nhwc:
            im = im.permute(0, 2, 3, 1)  # torch BCHW to numpy BHWC shape(1,320,192,3)

        if self.pt or self.nn_module:  # PyTorch
            y = self.model(im, augment=augment, visualize=visualize) if augment or visualize else self.model(im)
        elif self.jit:  # TorchScript
            y = self.model(im)
        elif self.dnn:  # ONNX OpenCV DNN
            im = im.cpu().numpy()  # torch to numpy
            self.net.setInput(im)
            y = self.net.forward()
        elif self.onnx:  # ONNX Runtime
            im = im.cpu().numpy()  # torch to numpy
            y = self.session.run(self.output_names, {self.session.get_inputs()[0].name: im})
        elif self.xml:  # OpenVINO
            im = im.cpu().numpy()  # FP32
            y = list(self.executable_network([im]).values())
        elif self.engine:  # TensorRT
            if self.dynamic and im.shape != self.bindings['images'].shape:
                i = self.model.get_binding_index('images')
                self.context.set_binding_shape(i, im.shape)  # reshape if dynamic
                self.bindings['images'] = self.bindings['images']._replace(shape=im.shape)
                for name in self.output_names:
                    i = self.model.get_binding_index(name)
                    self.bindings[name].data.resize_(tuple(self.context.get_binding_shape(i)))
            s = self.bindings['images'].shape
            assert im.shape == s, f"input size {im.shape} {'>' if self.dynamic else 'not equal to'} max model size {s}"
            self.binding_addrs['images'] = int(im.data_ptr())
            self.context.execute_v2(list(self.binding_addrs.values()))
            y = [self.bindings[x].data for x in sorted(self.output_names)]
        elif self.coreml:  # CoreML
            im = im[0].cpu().numpy()
            im_pil = Image.fromarray((im * 255).astype('uint8'))
            y = self.model.predict({'image': im_pil})  # coordinates are xywh normalized
            if 'confidence' in y:
                box = xywh2xyxy(y['coordinates'] * [[w, h, w, h]])  # xyxy pixels
                # `np.float` was an alias of builtin float and was removed in NumPy 1.24
                conf, cls = y['confidence'].max(1), y['confidence'].argmax(1).astype(float)
                y = np.concatenate((box, conf.reshape(-1, 1), cls.reshape(-1, 1)), 1)
            elif len(y) == 1:  # classification model
                y = list(y.values())
            elif len(y) == 2:  # segmentation model
                y = list(reversed(y.values()))  # reversed for segmentation models (pred, proto)
        elif self.paddle:  # PaddlePaddle
            im = im.cpu().numpy().astype(np.float32)
            self.input_handle.copy_from_cpu(im)
            self.predictor.run()
            y = [self.predictor.get_output_handle(x).copy_to_cpu() for x in self.output_names]
        elif self.ncnn:  # ncnn
            im = (im[0] * 255.).cpu().numpy().astype(np.uint8)
            im = np.ascontiguousarray(im.transpose(1, 2, 0))
            # NOTE(review): im.shape[:2] is (h, w); confirm the argument order ncnn's Mat.from_pixels expects for
            # non-square inputs.
            mat_in = self.pyncnn.Mat.from_pixels(im, self.pyncnn.Mat.PixelType.PIXEL_RGB, *im.shape[:2])
            mat_in.substract_mean_normalize([], [1 / 255.0, 1 / 255.0, 1 / 255.0])
            ex = self.net.create_extractor()
            input_names, output_names = self.net.input_names(), self.net.output_names()
            ex.input(input_names[0], mat_in)
            y = []
            for output_name in output_names:
                mat_out = self.pyncnn.Mat()
                ex.extract(output_name, mat_out)
                y.append(np.array(mat_out)[None])
        elif self.triton:  # NVIDIA Triton Inference Server
            y = self.model(im)
        else:  # TensorFlow (SavedModel, GraphDef, Lite, Edge TPU)
            im = im.cpu().numpy()
            if self.saved_model:  # SavedModel
                y = self.model(im, training=False) if self.keras else self.model(im)
                if not isinstance(y, list):
                    y = [y]
            elif self.pb:  # GraphDef
                y = self.frozen_func(x=self.tf.constant(im))
                if len(y) == 2 and len(self.names) == 999:  # segments and names not defined
                    ip, ib = (0, 1) if len(y[0].shape) == 4 else (1, 0)  # index of protos, boxes
                    nc = y[ib].shape[1] - y[ip].shape[3] - 4  # y = (1, 160, 160, 32), (1, 116, 8400)
                    self.names = {i: f'class{i}' for i in range(nc)}
            else:  # Lite or Edge TPU
                inp = self.input_details[0]
                int8 = inp['dtype'] == np.int8  # is TFLite quantized int8 model
                if int8:
                    scale, zero_point = inp['quantization']
                    im = (im / scale + zero_point).astype(np.int8)  # de-scale
                self.interpreter.set_tensor(inp['index'], im)
                self.interpreter.invoke()
                y = []
                for output in self.output_details:
                    x = self.interpreter.get_tensor(output['index'])
                    if int8:
                        scale, zero_point = output['quantization']
                        x = (x.astype(np.float32) - zero_point) * scale  # re-scale
                    y.append(x)
            # TF segment fixes: export is reversed vs ONNX export and protos are transposed
            if len(y) == 2:  # segment with (det, proto) output order reversed
                if len(y[1].shape) != 4:
                    y = list(reversed(y))  # should be y = (1, 116, 8400), (1, 160, 160, 32)
                y[1] = np.transpose(y[1], (0, 3, 1, 2))  # should be y = (1, 116, 8400), (1, 32, 160, 160)
            y = [x if isinstance(x, np.ndarray) else x.numpy() for x in y]

        if isinstance(y, (list, tuple)):
            return self.from_numpy(y[0]) if len(y) == 1 else [self.from_numpy(x) for x in y]
        else:
            return self.from_numpy(y)

    def from_numpy(self, x):
        """
        Convert a numpy array to a tensor.

        Args:
            x (np.ndarray | Any): The array to be converted. Non-ndarray inputs are returned unchanged.

        Returns:
            (torch.Tensor | Any): A copy of `x` as a tensor on `self.device`, or `x` itself if it is not an ndarray.
        """
        return torch.tensor(x).to(self.device) if isinstance(x, np.ndarray) else x

    def warmup(self, imgsz=(1, 3, 640, 640)):
        """
        Warm up the model by running one forward pass with a dummy input.

        Only PyTorch, TorchScript, ONNX, TensorRT, TF SavedModel/GraphDef, Triton and in-memory models are warmed up,
        and only on a non-CPU device (Triton is always warmed up). TorchScript gets two passes.

        Args:
            imgsz (tuple): The shape of the dummy input tensor in the format (batch_size, channels, height, width)

        Returns:
            (None): This method runs the forward pass and does not return any value.
        """
        warmup_types = self.pt, self.jit, self.onnx, self.engine, self.saved_model, self.pb, self.triton, self.nn_module
        if any(warmup_types) and (self.device.type != 'cpu' or self.triton):
            im = torch.empty(*imgsz, dtype=torch.half if self.fp16 else torch.float, device=self.device)  # input
            for _ in range(2 if self.jit else 1):
                self.forward(im)  # warmup

    @staticmethod
    def _parse_metadata(metadata):
        """
        Normalize a backend metadata dict in place.

        'stride' and 'batch' are cast to int. String values of 'imgsz', 'names' and 'kpt_shape' are parsed as Python
        literals; formats such as ONNX store every metadata value as a string, e.g. '[640, 640]'. Model files may be
        downloaded from untrusted sources, so `ast.literal_eval` is used instead of `eval`, which would execute
        arbitrary code embedded in the file.

        Args:
            metadata (dict): Metadata read from the model file or an external metadata YAML.

        Returns:
            (dict): The same dict object, normalized.

        Raises:
            ValueError: If a string value is not a valid Python literal.
        """
        for k, v in metadata.items():  # only existing keys are reassigned, so mutating during iteration is safe
            if k in ('stride', 'batch'):
                metadata[k] = int(v)
            elif k in ('imgsz', 'names', 'kpt_shape') and isinstance(v, str):
                metadata[k] = ast.literal_eval(v)
        return metadata

    @staticmethod
    def _apply_default_class_names(data):
        """Applies default class names to an input YAML file or returns numerical class names."""
        with contextlib.suppress(Exception):  # best-effort: any failure falls back to the defaults below
            return yaml_load(check_yaml(data))['names']
        return {i: f'class{i}' for i in range(999)}  # return default if above errors

    @staticmethod
    def _model_type(p='path/to/model.pt'):
        """
        This function takes a path to a model file and returns the model type.

        Args:
            p: path to the model file. Defaults to path/to/model.pt

        Returns:
            (list[bool]): One flag per export suffix, in `export_formats()` order, followed by a final Triton flag.
                `__init__` unpacks it as [pt, jit, onnx, xml, engine, coreml, saved_model, pb, tflite, edgetpu, tfjs,
                paddle, ncnn, triton].
        """
        # Return model type from model path, i.e. path='path/to/model.onnx' -> type=onnx
        from ultralytics.yolo.engine.exporter import export_formats
        sf = list(export_formats().Suffix)  # export suffixes
        if not is_url(p, check=False) and not isinstance(p, str):
            check_suffix(p, sf)  # checks
        url = urlparse(p)  # if url may be Triton inference server
        types = [s in Path(p).name for s in sf]
        types[8] &= not types[9]  # tflite &= not edgetpu
        triton = not any(types) and all([any(s in url.scheme for s in ['http', 'grpc']), url.netloc])
        return types + [triton]
|