ultralytics 8.0.40 TensorRT metadata and Results visualizer (#1014)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Laughing <61612323+Laughing-q@users.noreply.github.com>
Co-authored-by: Bogdan Gheorghe <112427971+bogdan-galileo@users.noreply.github.com>
Co-authored-by: Ayush Chaurasia <ayush.chaurarsia@gmail.com>
Co-authored-by: Jaap van de Loosdrecht <jaap@vdlmv.nl>
Co-authored-by: Noobtoss <96134731+Noobtoss@users.noreply.github.com>
Co-authored-by: nerdyespresso <106761627+nerdyespresso@users.noreply.github.com>
This commit is contained in:
Glenn Jocher
2023-02-17 20:06:06 +01:00
committed by GitHub
parent e799592718
commit 9047d737f4
40 changed files with 576 additions and 280 deletions

View File

@ -18,8 +18,8 @@ TensorFlow.js | `tfjs` | yolov8n_web_model/
PaddlePaddle | `paddle` | yolov8n_paddle_model/
Requirements:
$ pip install -r requirements.txt coremltools onnx onnx-simplifier onnxruntime openvino-dev tensorflow-cpu # CPU
$ pip install -r requirements.txt coremltools onnx onnx-simplifier onnxruntime-gpu openvino-dev tensorflow # GPU
$ pip install -r requirements.txt coremltools onnx onnxsim onnxruntime openvino-dev tensorflow-cpu # CPU
$ pip install -r requirements.txt coremltools onnx onnxsim onnxruntime-gpu openvino-dev tensorflow # GPU
Python:
from ultralytics import YOLO
@ -69,13 +69,14 @@ from ultralytics.nn.tasks import DetectionModel, SegmentationModel
from ultralytics.yolo.cfg import get_cfg
from ultralytics.yolo.data.dataloaders.stream_loaders import LoadImages
from ultralytics.yolo.data.utils import IMAGENET_MEAN, IMAGENET_STD, check_det_dataset
from ultralytics.yolo.utils import DEFAULT_CFG, LOGGER, __version__, callbacks, colorstr, get_default_args, yaml_save
from ultralytics.yolo.utils import (DEFAULT_CFG, LINUX, LOGGER, MACOS, WINDOWS, __version__, callbacks, colorstr,
get_default_args, yaml_save)
from ultralytics.yolo.utils.checks import check_imgsz, check_requirements, check_version, check_yaml
from ultralytics.yolo.utils.files import file_size
from ultralytics.yolo.utils.ops import Profile
from ultralytics.yolo.utils.torch_utils import get_latest_opset, select_device, smart_inference_mode
MACOS = platform.system() == 'Darwin' # macOS environment
CUDA = torch.cuda.is_available()
def export_formats():
@ -229,27 +230,24 @@ class Exporter:
if coreml: # CoreML
f[4], _ = self._export_coreml()
if any((saved_model, pb, tflite, edgetpu, tfjs)): # TensorFlow formats
LOGGER.warning('WARNING ⚠️ YOLOv8 TensorFlow export support is still under development. '
LOGGER.warning('WARNING ⚠️ YOLOv8 TensorFlow export is still under development. '
'Please consider contributing to the effort if you have TF expertise. Thank you!')
nms = False
f[5], s_model = self._export_saved_model(nms=nms or self.args.agnostic_nms or tfjs,
agnostic_nms=self.args.agnostic_nms or tfjs)
debug = False
if debug:
if pb or tfjs: # pb prerequisite to tfjs
f[6], _ = self._export_pb(s_model)
if tflite or edgetpu:
f[7], _ = self._export_tflite(s_model,
int8=self.args.int8 or edgetpu,
data=self.args.data,
nms=nms,
agnostic_nms=self.args.agnostic_nms)
if edgetpu:
f[8], _ = self._export_edgetpu()
self._add_tflite_metadata(f[8] or f[7])
if tfjs:
f[9], _ = self._export_tfjs()
if pb or tfjs: # pb prerequisite to tfjs
f[6], _ = self._export_pb(s_model)
if tflite or edgetpu:
f[7] = str(Path(f[5]) / (self.file.stem + '_float16.tflite'))
# f[7], _ = self._export_tflite(s_model,
# int8=self.args.int8 or edgetpu,
# data=self.args.data,
# nms=nms,
# agnostic_nms=self.args.agnostic_nms)
if edgetpu:
f[8], _ = self._export_edgetpu(tflite_model=f[7])
if tfjs:
f[9], _ = self._export_tfjs()
if paddle: # PaddlePaddle
f[10], _ = self._export_paddle()
@ -258,13 +256,14 @@ class Exporter:
if any(f):
f = str(Path(f[-1]))
square = self.imgsz[0] == self.imgsz[1]
s = f"WARNING ⚠️ non-PyTorch val requires square images, 'imgsz={self.imgsz}' will not work. Use " \
f"export 'imgsz={max(self.imgsz)}' if val is required." if not square else ''
s = '' if square else f"WARNING ⚠️ non-PyTorch val requires square images, 'imgsz={self.imgsz}' will not " \
f"work. Use export 'imgsz={max(self.imgsz)}' if val is required."
imgsz = self.imgsz[0] if square else str(self.imgsz)[1:-1].replace(' ', '')
data = f"data={self.args.data}" if model.task == 'segment' and format == 'pb' else ''
LOGGER.info(
f'\nExport complete ({time.time() - t:.1f}s)'
f"\nResults saved to {colorstr('bold', file.parent.resolve())}"
f"\nPredict: yolo task={model.task} mode=predict model={f} imgsz={imgsz}"
f"\nPredict: yolo task={model.task} mode=predict model={f} imgsz={imgsz} {data}"
f"\nValidate: yolo task={model.task} mode=val model={f} imgsz={imgsz} data={self.args.data} {s}"
f"\nVisualize: https://netron.app")
@ -335,7 +334,7 @@ class Exporter:
check_requirements('onnxsim')
import onnxsim
LOGGER.info(f'{prefix} simplifying with onnx-simplifier {onnxsim.__version__}...')
LOGGER.info(f'{prefix} simplifying with onnxsim {onnxsim.__version__}...')
subprocess.run(f'onnxsim {f} {f}', shell=True)
except Exception as e:
LOGGER.info(f'{prefix} simplifier failure: {e}')
@ -358,7 +357,7 @@ class Exporter:
framework="onnx",
compress_to_fp16=self.args.half) # export
ov.serialize(ov_model, f_ov) # save
yaml_save(Path(f) / self.file.with_suffix('.yaml').name, self.metadata) # add metadata.yaml
yaml_save(Path(f) / 'metadata.yaml', self.metadata) # add metadata.yaml
return f, None
@try_export
@ -372,7 +371,7 @@ class Exporter:
f = str(self.file).replace(self.file.suffix, f'_paddle_model{os.sep}')
pytorch2paddle(module=self.model, save_dir=f, jit_type='trace', input_examples=[self.im]) # export
yaml_save(Path(f) / self.file.with_suffix('.yaml').name, self.metadata) # add metadata.yaml
yaml_save(Path(f) / 'metadata.yaml', self.metadata) # add metadata.yaml
return f, None
@try_export
@ -436,7 +435,7 @@ class Exporter:
try:
import tensorrt as trt # noqa
except ImportError:
if platform.system() == 'Linux':
if LINUX:
check_requirements('nvidia-tensorrt', cmds='-U --index-url https://pypi.ngc.nvidia.com')
import tensorrt as trt # noqa
@ -482,8 +481,16 @@ class Exporter:
f'{prefix} building FP{16 if builder.platform_has_fast_fp16 and self.args.half else 32} engine as {f}')
if builder.platform_has_fast_fp16 and self.args.half:
config.set_flag(trt.BuilderFlag.FP16)
# Write file
with builder.build_engine(network, config) as engine, open(f, 'wb') as t:
# Metadata
meta = json.dumps(self.metadata)
t.write(len(meta).to_bytes(4, byteorder='little', signed=True))
t.write(meta.encode())
# Model
t.write(engine.serialize())
return f, None
@try_export
@ -500,10 +507,10 @@ class Exporter:
try:
import tensorflow as tf # noqa
except ImportError:
check_requirements(f"tensorflow{'' if torch.cuda.is_available() else '-macos' if MACOS else '-cpu'}")
check_requirements(f"tensorflow{'' if CUDA else '-macos' if MACOS else '-cpu' if LINUX else ''}")
import tensorflow as tf # noqa
check_requirements(("onnx", "onnx2tf", "sng4onnx", "onnxsim", "onnx_graphsurgeon", "tflite_support"),
cmds="--extra-index-url https://pypi.ngc.nvidia.com ")
cmds="--extra-index-url https://pypi.ngc.nvidia.com")
LOGGER.info(f'\n{prefix} starting export with tensorflow {tf.__version__}...')
f = str(self.file).replace(self.file.suffix, '_saved_model')
@ -514,10 +521,11 @@ class Exporter:
# Export to TF SavedModel
subprocess.run(f'onnx2tf -i {onnx} -o {f} --non_verbose', shell=True)
yaml_save(Path(f) / 'metadata.yaml', self.metadata) # add metadata.yaml
# Add TFLite metadata
for tflite_file in Path(f).rglob('*.tflite'):
self._add_tflite_metadata(tflite_file)
for file in Path(f).rglob('*.tflite'):
self._add_tflite_metadata(file)
# Load saved_model
keras_model = tf.saved_model.load(f, tags=None, options=None)
@ -537,7 +545,7 @@ class Exporter:
try:
import tensorflow as tf # noqa
except ImportError:
check_requirements(f"tensorflow{'' if torch.cuda.is_available() else '-macos' if MACOS else '-cpu'}")
check_requirements(f"tensorflow{'' if CUDA else '-macos' if MACOS else '-cpu' if LINUX else ''}")
import tensorflow as tf # noqa
# from models.tf import TFModel
from tensorflow.python.framework.convert_to_constants import convert_variables_to_constants_v2 # noqa
@ -628,11 +636,11 @@ class Exporter:
return f, None
@try_export
def _export_edgetpu(self, prefix=colorstr('Edge TPU:')):
def _export_edgetpu(self, tflite_model='', prefix=colorstr('Edge TPU:')):
# YOLOv8 Edge TPU export https://coral.ai/docs/edgetpu/models-intro/
cmd = 'edgetpu_compiler --version'
help_url = 'https://coral.ai/docs/edgetpu/compiler/'
assert platform.system() == 'Linux', f'export only supported on Linux. See {help_url}'
assert LINUX, f'export only supported on Linux. See {help_url}'
if subprocess.run(f'{cmd} >/dev/null', shell=True).returncode != 0:
LOGGER.info(f'\n{prefix} export requires Edge TPU compiler. Attempting install from {help_url}')
sudo = subprocess.run('sudo --version >/dev/null', shell=True).returncode == 0 # sudo installed on system
@ -646,11 +654,11 @@ class Exporter:
ver = subprocess.run(cmd, shell=True, capture_output=True, check=True).stdout.decode().split()[-1]
LOGGER.info(f'\n{prefix} starting export with Edge TPU compiler {ver}...')
f = str(self.file).replace(self.file.suffix, '-int8_edgetpu.tflite') # Edge TPU model
f_tfl = str(self.file).replace(self.file.suffix, '-int8.tflite') # TFLite model
f = str(tflite_model).replace('.tflite', '_edgetpu.tflite') # Edge TPU model
cmd = f"edgetpu_compiler -s -d -k 10 --out_dir {self.file.parent} {f_tfl}"
cmd = f"edgetpu_compiler -s -d -k 10 --out_dir {self.file.parent} {tflite_model}"
subprocess.run(cmd.split(), check=True)
self._add_tflite_metadata(f)
return f, None
@try_export
@ -681,6 +689,7 @@ class Exporter:
f_json.read_text(),
)
j.write(subst)
yaml_save(Path(f) / 'metadata.yaml', self.metadata) # add metadata.yaml
return f, None
def _add_tflite_metadata(self, file):
@ -736,14 +745,6 @@ class Exporter:
populator.populate()
tmp_file.unlink()
# TODO Rename this here and in `_add_tflite_metadata`
def _extracted_from__add_tflite_metadata_15(self, _metadata_fb, arg1, arg2):
# Creates input info.
result = _metadata_fb.TensorMetadataT()
result.name = arg1
result.description = arg2
return result
def _pipeline_coreml(self, model, prefix=colorstr('CoreML Pipeline:')):
# YOLOv8 CoreML pipeline
import coremltools as ct # noqa

View File

@ -42,6 +42,7 @@ class YOLO:
model (str, Path): model to load or create
type (str): Type/version of models to use. Defaults to "v8".
"""
self._reset_callbacks()
self.type = type
self.ModelClass = None # model class
self.TrainerClass = None # trainer class
@ -307,3 +308,8 @@ class YOLO:
for arg in 'augment', 'verbose', 'project', 'name', 'exist_ok', 'resume', 'batch', 'epochs', 'cache', \
'save_json', 'half', 'v5loader', 'device', 'cfg', 'save', 'rect', 'plots', 'opset':
args.pop(arg, None)
@staticmethod
def _reset_callbacks():
for event in callbacks.default_callbacks.keys():
callbacks.default_callbacks[event] = [callbacks.default_callbacks[event][0]]

View File

@ -85,7 +85,6 @@ class BasePredictor:
self.data = self.args.data # data_dict
self.imgsz = None
self.device = None
self.classes = self.args.classes
self.dataset = None
self.vid_path, self.vid_writer = None, None
self.annotator = None
@ -103,7 +102,7 @@ class BasePredictor:
def write_results(self, results, batch, print_string):
raise NotImplementedError("print_results function needs to be implemented")
def postprocess(self, preds, img, orig_img, classes=None):
def postprocess(self, preds, img, orig_img):
return preds
@smart_inference_mode()
@ -170,13 +169,13 @@ class BasePredictor:
# postprocess
with self.dt[2]:
self.results = self.postprocess(preds, im, im0s, self.classes)
self.results = self.postprocess(preds, im, im0s)
self.run_callbacks("on_predict_postprocess_end")
# visualize, save, write results
for i in range(len(im)):
p, im0 = (path[i], im0s[i].copy()) if self.source_type.webcam or self.source_type.from_img else (path,
im0s)
p, im0 = (path[i], im0s[i].copy()) if self.source_type.webcam or self.source_type.from_img \
else (path, im0s.copy())
p = Path(p)
if self.args.verbose or self.args.save or self.args.save_txt or self.args.show:

View File

@ -1,9 +1,13 @@
from copy import deepcopy
from functools import lru_cache
import numpy as np
import torch
import torchvision.transforms.functional as F
from PIL import Image
from ultralytics.yolo.utils import LOGGER, ops
from ultralytics.yolo.utils.plotting import Annotator, colors
class Results:
@ -14,22 +18,24 @@ class Results:
boxes (Boxes, optional): A Boxes object containing the detection bounding boxes.
masks (Masks, optional): A Masks object containing the detection masks.
probs (torch.Tensor, optional): A tensor containing the detection class probabilities.
orig_shape (tuple, optional): Original image size.
orig_img (tuple, optional): Original image size.
Attributes:
boxes (Boxes, optional): A Boxes object containing the detection bounding boxes.
masks (Masks, optional): A Masks object containing the detection masks.
probs (torch.Tensor, optional): A tensor containing the detection class probabilities.
orig_shape (tuple, optional): Original image size.
orig_img (tuple, optional): Original image size.
data (torch.Tensor): The raw masks tensor
"""
def __init__(self, boxes=None, masks=None, probs=None, orig_shape=None) -> None:
self.boxes = Boxes(boxes, orig_shape) if boxes is not None else None # native size boxes
self.masks = Masks(masks, orig_shape) if masks is not None else None # native size or imgsz masks
def __init__(self, boxes=None, masks=None, probs=None, orig_img=None, names=None) -> None:
self.orig_img = orig_img
self.orig_shape = orig_img.shape[:2]
self.boxes = Boxes(boxes, self.orig_shape) if boxes is not None else None # native size boxes
self.masks = Masks(masks, self.orig_shape) if masks is not None else None # native size or imgsz masks
self.probs = probs if probs is not None else None
self.orig_shape = orig_shape
self.names = names
self.comp = ["boxes", "masks", "probs"]
def pandas(self):
@ -37,7 +43,7 @@ class Results:
# TODO masks.pandas + boxes.pandas + cls.pandas
def __getitem__(self, idx):
r = Results(orig_shape=self.orig_shape)
r = Results(orig_img=self.orig_img)
for item in self.comp:
if getattr(self, item) is None:
continue
@ -53,7 +59,7 @@ class Results:
self.probs = probs
def cpu(self):
r = Results(orig_shape=self.orig_shape)
r = Results(orig_img=self.orig_img)
for item in self.comp:
if getattr(self, item) is None:
continue
@ -61,7 +67,7 @@ class Results:
return r
def numpy(self):
r = Results(orig_shape=self.orig_shape)
r = Results(orig_img=self.orig_img)
for item in self.comp:
if getattr(self, item) is None:
continue
@ -69,7 +75,7 @@ class Results:
return r
def cuda(self):
r = Results(orig_shape=self.orig_shape)
r = Results(orig_img=self.orig_img)
for item in self.comp:
if getattr(self, item) is None:
continue
@ -77,7 +83,7 @@ class Results:
return r
def to(self, *args, **kwargs):
r = Results(orig_shape=self.orig_shape)
r = Results(orig_img=self.orig_img)
for item in self.comp:
if getattr(self, item) is None:
continue
@ -118,6 +124,40 @@ class Results:
orig_shape (tuple, optional): Original image size.
""")
def visualize(self, show_conf=True, line_width=None, font_size=None, font='Arial.ttf', pil=False, example='abc'):
"""
Plots the given result on an input RGB image. Accepts cv2(numpy) or PIL Image
Args:
show_conf (bool): Show confidence
line_width (Float): The line width of boxes. Automatically scaled to img size if not provided
font_size (Float): The font size of . Automatically scaled to img size if not provided
"""
img = deepcopy(self.orig_img)
annotator = Annotator(img, line_width, font_size, font, pil, example)
boxes = self.boxes
masks = self.masks.data
logits = self.probs
names = self.names
if boxes is not None:
for d in reversed(boxes):
cls, conf = d.cls.squeeze(), d.conf.squeeze()
c = int(cls)
label = (f'{names[c]}' if names else f'{c}') + (f'{conf:.2f}' if show_conf else '')
annotator.box_label(d.xyxy.squeeze(), label, color=colors(c, True))
if masks is not None:
im_gpu = torch.as_tensor(img, dtype=torch.float16).permute(2, 0, 1).flip(0).contiguous()
im_gpu = F.resize(im_gpu, masks.data.shape[1:]) / 255
annotator.masks(masks.data, colors=[colors(x, True) for x in boxes.cls], im_gpu=im_gpu)
if logits is not None:
top5i = logits.argsort(0, descending=True)[:5].tolist() # top 5 indices
text = f"{', '.join(f'{names[j] if names else j} {logits[j]:.2f}' for j in top5i)}, "
annotator.text((32, 32), text, txt_color=(255, 255, 255)) # TODO: allow setting colors
return img
class Boxes:
"""