ultralytics 8.0.40
TensorRT metadata and Results visualizer (#1014)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Laughing <61612323+Laughing-q@users.noreply.github.com> Co-authored-by: Bogdan Gheorghe <112427971+bogdan-galileo@users.noreply.github.com> Co-authored-by: Ayush Chaurasia <ayush.chaurarsia@gmail.com> Co-authored-by: Jaap van de Loosdrecht <jaap@vdlmv.nl> Co-authored-by: Noobtoss <96134731+Noobtoss@users.noreply.github.com> Co-authored-by: nerdyespresso <106761627+nerdyespresso@users.noreply.github.com>
This commit is contained in:
@ -18,8 +18,8 @@ TensorFlow.js | `tfjs` | yolov8n_web_model/
|
||||
PaddlePaddle | `paddle` | yolov8n_paddle_model/
|
||||
|
||||
Requirements:
|
||||
$ pip install -r requirements.txt coremltools onnx onnx-simplifier onnxruntime openvino-dev tensorflow-cpu # CPU
|
||||
$ pip install -r requirements.txt coremltools onnx onnx-simplifier onnxruntime-gpu openvino-dev tensorflow # GPU
|
||||
$ pip install -r requirements.txt coremltools onnx onnxsim onnxruntime openvino-dev tensorflow-cpu # CPU
|
||||
$ pip install -r requirements.txt coremltools onnx onnxsim onnxruntime-gpu openvino-dev tensorflow # GPU
|
||||
|
||||
Python:
|
||||
from ultralytics import YOLO
|
||||
@ -69,13 +69,14 @@ from ultralytics.nn.tasks import DetectionModel, SegmentationModel
|
||||
from ultralytics.yolo.cfg import get_cfg
|
||||
from ultralytics.yolo.data.dataloaders.stream_loaders import LoadImages
|
||||
from ultralytics.yolo.data.utils import IMAGENET_MEAN, IMAGENET_STD, check_det_dataset
|
||||
from ultralytics.yolo.utils import DEFAULT_CFG, LOGGER, __version__, callbacks, colorstr, get_default_args, yaml_save
|
||||
from ultralytics.yolo.utils import (DEFAULT_CFG, LINUX, LOGGER, MACOS, WINDOWS, __version__, callbacks, colorstr,
|
||||
get_default_args, yaml_save)
|
||||
from ultralytics.yolo.utils.checks import check_imgsz, check_requirements, check_version, check_yaml
|
||||
from ultralytics.yolo.utils.files import file_size
|
||||
from ultralytics.yolo.utils.ops import Profile
|
||||
from ultralytics.yolo.utils.torch_utils import get_latest_opset, select_device, smart_inference_mode
|
||||
|
||||
MACOS = platform.system() == 'Darwin' # macOS environment
|
||||
CUDA = torch.cuda.is_available()
|
||||
|
||||
|
||||
def export_formats():
|
||||
@ -229,27 +230,24 @@ class Exporter:
|
||||
if coreml: # CoreML
|
||||
f[4], _ = self._export_coreml()
|
||||
if any((saved_model, pb, tflite, edgetpu, tfjs)): # TensorFlow formats
|
||||
LOGGER.warning('WARNING ⚠️ YOLOv8 TensorFlow export support is still under development. '
|
||||
LOGGER.warning('WARNING ⚠️ YOLOv8 TensorFlow export is still under development. '
|
||||
'Please consider contributing to the effort if you have TF expertise. Thank you!')
|
||||
nms = False
|
||||
f[5], s_model = self._export_saved_model(nms=nms or self.args.agnostic_nms or tfjs,
|
||||
agnostic_nms=self.args.agnostic_nms or tfjs)
|
||||
|
||||
debug = False
|
||||
if debug:
|
||||
if pb or tfjs: # pb prerequisite to tfjs
|
||||
f[6], _ = self._export_pb(s_model)
|
||||
if tflite or edgetpu:
|
||||
f[7], _ = self._export_tflite(s_model,
|
||||
int8=self.args.int8 or edgetpu,
|
||||
data=self.args.data,
|
||||
nms=nms,
|
||||
agnostic_nms=self.args.agnostic_nms)
|
||||
if edgetpu:
|
||||
f[8], _ = self._export_edgetpu()
|
||||
self._add_tflite_metadata(f[8] or f[7])
|
||||
if tfjs:
|
||||
f[9], _ = self._export_tfjs()
|
||||
if pb or tfjs: # pb prerequisite to tfjs
|
||||
f[6], _ = self._export_pb(s_model)
|
||||
if tflite or edgetpu:
|
||||
f[7] = str(Path(f[5]) / (self.file.stem + '_float16.tflite'))
|
||||
# f[7], _ = self._export_tflite(s_model,
|
||||
# int8=self.args.int8 or edgetpu,
|
||||
# data=self.args.data,
|
||||
# nms=nms,
|
||||
# agnostic_nms=self.args.agnostic_nms)
|
||||
if edgetpu:
|
||||
f[8], _ = self._export_edgetpu(tflite_model=f[7])
|
||||
if tfjs:
|
||||
f[9], _ = self._export_tfjs()
|
||||
if paddle: # PaddlePaddle
|
||||
f[10], _ = self._export_paddle()
|
||||
|
||||
@ -258,13 +256,14 @@ class Exporter:
|
||||
if any(f):
|
||||
f = str(Path(f[-1]))
|
||||
square = self.imgsz[0] == self.imgsz[1]
|
||||
s = f"WARNING ⚠️ non-PyTorch val requires square images, 'imgsz={self.imgsz}' will not work. Use " \
|
||||
f"export 'imgsz={max(self.imgsz)}' if val is required." if not square else ''
|
||||
s = '' if square else f"WARNING ⚠️ non-PyTorch val requires square images, 'imgsz={self.imgsz}' will not " \
|
||||
f"work. Use export 'imgsz={max(self.imgsz)}' if val is required."
|
||||
imgsz = self.imgsz[0] if square else str(self.imgsz)[1:-1].replace(' ', '')
|
||||
data = f"data={self.args.data}" if model.task == 'segment' and format == 'pb' else ''
|
||||
LOGGER.info(
|
||||
f'\nExport complete ({time.time() - t:.1f}s)'
|
||||
f"\nResults saved to {colorstr('bold', file.parent.resolve())}"
|
||||
f"\nPredict: yolo task={model.task} mode=predict model={f} imgsz={imgsz}"
|
||||
f"\nPredict: yolo task={model.task} mode=predict model={f} imgsz={imgsz} {data}"
|
||||
f"\nValidate: yolo task={model.task} mode=val model={f} imgsz={imgsz} data={self.args.data} {s}"
|
||||
f"\nVisualize: https://netron.app")
|
||||
|
||||
@ -335,7 +334,7 @@ class Exporter:
|
||||
check_requirements('onnxsim')
|
||||
import onnxsim
|
||||
|
||||
LOGGER.info(f'{prefix} simplifying with onnx-simplifier {onnxsim.__version__}...')
|
||||
LOGGER.info(f'{prefix} simplifying with onnxsim {onnxsim.__version__}...')
|
||||
subprocess.run(f'onnxsim {f} {f}', shell=True)
|
||||
except Exception as e:
|
||||
LOGGER.info(f'{prefix} simplifier failure: {e}')
|
||||
@ -358,7 +357,7 @@ class Exporter:
|
||||
framework="onnx",
|
||||
compress_to_fp16=self.args.half) # export
|
||||
ov.serialize(ov_model, f_ov) # save
|
||||
yaml_save(Path(f) / self.file.with_suffix('.yaml').name, self.metadata) # add metadata.yaml
|
||||
yaml_save(Path(f) / 'metadata.yaml', self.metadata) # add metadata.yaml
|
||||
return f, None
|
||||
|
||||
@try_export
|
||||
@ -372,7 +371,7 @@ class Exporter:
|
||||
f = str(self.file).replace(self.file.suffix, f'_paddle_model{os.sep}')
|
||||
|
||||
pytorch2paddle(module=self.model, save_dir=f, jit_type='trace', input_examples=[self.im]) # export
|
||||
yaml_save(Path(f) / self.file.with_suffix('.yaml').name, self.metadata) # add metadata.yaml
|
||||
yaml_save(Path(f) / 'metadata.yaml', self.metadata) # add metadata.yaml
|
||||
return f, None
|
||||
|
||||
@try_export
|
||||
@ -436,7 +435,7 @@ class Exporter:
|
||||
try:
|
||||
import tensorrt as trt # noqa
|
||||
except ImportError:
|
||||
if platform.system() == 'Linux':
|
||||
if LINUX:
|
||||
check_requirements('nvidia-tensorrt', cmds='-U --index-url https://pypi.ngc.nvidia.com')
|
||||
import tensorrt as trt # noqa
|
||||
|
||||
@ -482,8 +481,16 @@ class Exporter:
|
||||
f'{prefix} building FP{16 if builder.platform_has_fast_fp16 and self.args.half else 32} engine as {f}')
|
||||
if builder.platform_has_fast_fp16 and self.args.half:
|
||||
config.set_flag(trt.BuilderFlag.FP16)
|
||||
|
||||
# Write file
|
||||
with builder.build_engine(network, config) as engine, open(f, 'wb') as t:
|
||||
# Metadata
|
||||
meta = json.dumps(self.metadata)
|
||||
t.write(len(meta).to_bytes(4, byteorder='little', signed=True))
|
||||
t.write(meta.encode())
|
||||
# Model
|
||||
t.write(engine.serialize())
|
||||
|
||||
return f, None
|
||||
|
||||
@try_export
|
||||
@ -500,10 +507,10 @@ class Exporter:
|
||||
try:
|
||||
import tensorflow as tf # noqa
|
||||
except ImportError:
|
||||
check_requirements(f"tensorflow{'' if torch.cuda.is_available() else '-macos' if MACOS else '-cpu'}")
|
||||
check_requirements(f"tensorflow{'' if CUDA else '-macos' if MACOS else '-cpu' if LINUX else ''}")
|
||||
import tensorflow as tf # noqa
|
||||
check_requirements(("onnx", "onnx2tf", "sng4onnx", "onnxsim", "onnx_graphsurgeon", "tflite_support"),
|
||||
cmds="--extra-index-url https://pypi.ngc.nvidia.com ")
|
||||
cmds="--extra-index-url https://pypi.ngc.nvidia.com")
|
||||
|
||||
LOGGER.info(f'\n{prefix} starting export with tensorflow {tf.__version__}...')
|
||||
f = str(self.file).replace(self.file.suffix, '_saved_model')
|
||||
@ -514,10 +521,11 @@ class Exporter:
|
||||
|
||||
# Export to TF SavedModel
|
||||
subprocess.run(f'onnx2tf -i {onnx} -o {f} --non_verbose', shell=True)
|
||||
yaml_save(Path(f) / 'metadata.yaml', self.metadata) # add metadata.yaml
|
||||
|
||||
# Add TFLite metadata
|
||||
for tflite_file in Path(f).rglob('*.tflite'):
|
||||
self._add_tflite_metadata(tflite_file)
|
||||
for file in Path(f).rglob('*.tflite'):
|
||||
self._add_tflite_metadata(file)
|
||||
|
||||
# Load saved_model
|
||||
keras_model = tf.saved_model.load(f, tags=None, options=None)
|
||||
@ -537,7 +545,7 @@ class Exporter:
|
||||
try:
|
||||
import tensorflow as tf # noqa
|
||||
except ImportError:
|
||||
check_requirements(f"tensorflow{'' if torch.cuda.is_available() else '-macos' if MACOS else '-cpu'}")
|
||||
check_requirements(f"tensorflow{'' if CUDA else '-macos' if MACOS else '-cpu' if LINUX else ''}")
|
||||
import tensorflow as tf # noqa
|
||||
# from models.tf import TFModel
|
||||
from tensorflow.python.framework.convert_to_constants import convert_variables_to_constants_v2 # noqa
|
||||
@ -628,11 +636,11 @@ class Exporter:
|
||||
return f, None
|
||||
|
||||
@try_export
|
||||
def _export_edgetpu(self, prefix=colorstr('Edge TPU:')):
|
||||
def _export_edgetpu(self, tflite_model='', prefix=colorstr('Edge TPU:')):
|
||||
# YOLOv8 Edge TPU export https://coral.ai/docs/edgetpu/models-intro/
|
||||
cmd = 'edgetpu_compiler --version'
|
||||
help_url = 'https://coral.ai/docs/edgetpu/compiler/'
|
||||
assert platform.system() == 'Linux', f'export only supported on Linux. See {help_url}'
|
||||
assert LINUX, f'export only supported on Linux. See {help_url}'
|
||||
if subprocess.run(f'{cmd} >/dev/null', shell=True).returncode != 0:
|
||||
LOGGER.info(f'\n{prefix} export requires Edge TPU compiler. Attempting install from {help_url}')
|
||||
sudo = subprocess.run('sudo --version >/dev/null', shell=True).returncode == 0 # sudo installed on system
|
||||
@ -646,11 +654,11 @@ class Exporter:
|
||||
ver = subprocess.run(cmd, shell=True, capture_output=True, check=True).stdout.decode().split()[-1]
|
||||
|
||||
LOGGER.info(f'\n{prefix} starting export with Edge TPU compiler {ver}...')
|
||||
f = str(self.file).replace(self.file.suffix, '-int8_edgetpu.tflite') # Edge TPU model
|
||||
f_tfl = str(self.file).replace(self.file.suffix, '-int8.tflite') # TFLite model
|
||||
f = str(tflite_model).replace('.tflite', '_edgetpu.tflite') # Edge TPU model
|
||||
|
||||
cmd = f"edgetpu_compiler -s -d -k 10 --out_dir {self.file.parent} {f_tfl}"
|
||||
cmd = f"edgetpu_compiler -s -d -k 10 --out_dir {self.file.parent} {tflite_model}"
|
||||
subprocess.run(cmd.split(), check=True)
|
||||
self._add_tflite_metadata(f)
|
||||
return f, None
|
||||
|
||||
@try_export
|
||||
@ -681,6 +689,7 @@ class Exporter:
|
||||
f_json.read_text(),
|
||||
)
|
||||
j.write(subst)
|
||||
yaml_save(Path(f) / 'metadata.yaml', self.metadata) # add metadata.yaml
|
||||
return f, None
|
||||
|
||||
def _add_tflite_metadata(self, file):
|
||||
@ -736,14 +745,6 @@ class Exporter:
|
||||
populator.populate()
|
||||
tmp_file.unlink()
|
||||
|
||||
# TODO Rename this here and in `_add_tflite_metadata`
|
||||
def _extracted_from__add_tflite_metadata_15(self, _metadata_fb, arg1, arg2):
|
||||
# Creates input info.
|
||||
result = _metadata_fb.TensorMetadataT()
|
||||
result.name = arg1
|
||||
result.description = arg2
|
||||
return result
|
||||
|
||||
def _pipeline_coreml(self, model, prefix=colorstr('CoreML Pipeline:')):
|
||||
# YOLOv8 CoreML pipeline
|
||||
import coremltools as ct # noqa
|
||||
|
@ -42,6 +42,7 @@ class YOLO:
|
||||
model (str, Path): model to load or create
|
||||
type (str): Type/version of models to use. Defaults to "v8".
|
||||
"""
|
||||
self._reset_callbacks()
|
||||
self.type = type
|
||||
self.ModelClass = None # model class
|
||||
self.TrainerClass = None # trainer class
|
||||
@ -307,3 +308,8 @@ class YOLO:
|
||||
for arg in 'augment', 'verbose', 'project', 'name', 'exist_ok', 'resume', 'batch', 'epochs', 'cache', \
|
||||
'save_json', 'half', 'v5loader', 'device', 'cfg', 'save', 'rect', 'plots', 'opset':
|
||||
args.pop(arg, None)
|
||||
|
||||
@staticmethod
|
||||
def _reset_callbacks():
|
||||
for event in callbacks.default_callbacks.keys():
|
||||
callbacks.default_callbacks[event] = [callbacks.default_callbacks[event][0]]
|
||||
|
@ -85,7 +85,6 @@ class BasePredictor:
|
||||
self.data = self.args.data # data_dict
|
||||
self.imgsz = None
|
||||
self.device = None
|
||||
self.classes = self.args.classes
|
||||
self.dataset = None
|
||||
self.vid_path, self.vid_writer = None, None
|
||||
self.annotator = None
|
||||
@ -103,7 +102,7 @@ class BasePredictor:
|
||||
def write_results(self, results, batch, print_string):
|
||||
raise NotImplementedError("print_results function needs to be implemented")
|
||||
|
||||
def postprocess(self, preds, img, orig_img, classes=None):
|
||||
def postprocess(self, preds, img, orig_img):
|
||||
return preds
|
||||
|
||||
@smart_inference_mode()
|
||||
@ -170,13 +169,13 @@ class BasePredictor:
|
||||
|
||||
# postprocess
|
||||
with self.dt[2]:
|
||||
self.results = self.postprocess(preds, im, im0s, self.classes)
|
||||
self.results = self.postprocess(preds, im, im0s)
|
||||
self.run_callbacks("on_predict_postprocess_end")
|
||||
|
||||
# visualize, save, write results
|
||||
for i in range(len(im)):
|
||||
p, im0 = (path[i], im0s[i].copy()) if self.source_type.webcam or self.source_type.from_img else (path,
|
||||
im0s)
|
||||
p, im0 = (path[i], im0s[i].copy()) if self.source_type.webcam or self.source_type.from_img \
|
||||
else (path, im0s.copy())
|
||||
p = Path(p)
|
||||
|
||||
if self.args.verbose or self.args.save or self.args.save_txt or self.args.show:
|
||||
|
@ -1,9 +1,13 @@
|
||||
from copy import deepcopy
|
||||
from functools import lru_cache
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torchvision.transforms.functional as F
|
||||
from PIL import Image
|
||||
|
||||
from ultralytics.yolo.utils import LOGGER, ops
|
||||
from ultralytics.yolo.utils.plotting import Annotator, colors
|
||||
|
||||
|
||||
class Results:
|
||||
@ -14,22 +18,24 @@ class Results:
|
||||
boxes (Boxes, optional): A Boxes object containing the detection bounding boxes.
|
||||
masks (Masks, optional): A Masks object containing the detection masks.
|
||||
probs (torch.Tensor, optional): A tensor containing the detection class probabilities.
|
||||
orig_shape (tuple, optional): Original image size.
|
||||
orig_img (tuple, optional): Original image size.
|
||||
|
||||
Attributes:
|
||||
boxes (Boxes, optional): A Boxes object containing the detection bounding boxes.
|
||||
masks (Masks, optional): A Masks object containing the detection masks.
|
||||
probs (torch.Tensor, optional): A tensor containing the detection class probabilities.
|
||||
orig_shape (tuple, optional): Original image size.
|
||||
orig_img (tuple, optional): Original image size.
|
||||
data (torch.Tensor): The raw masks tensor
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, boxes=None, masks=None, probs=None, orig_shape=None) -> None:
|
||||
self.boxes = Boxes(boxes, orig_shape) if boxes is not None else None # native size boxes
|
||||
self.masks = Masks(masks, orig_shape) if masks is not None else None # native size or imgsz masks
|
||||
def __init__(self, boxes=None, masks=None, probs=None, orig_img=None, names=None) -> None:
|
||||
self.orig_img = orig_img
|
||||
self.orig_shape = orig_img.shape[:2]
|
||||
self.boxes = Boxes(boxes, self.orig_shape) if boxes is not None else None # native size boxes
|
||||
self.masks = Masks(masks, self.orig_shape) if masks is not None else None # native size or imgsz masks
|
||||
self.probs = probs if probs is not None else None
|
||||
self.orig_shape = orig_shape
|
||||
self.names = names
|
||||
self.comp = ["boxes", "masks", "probs"]
|
||||
|
||||
def pandas(self):
|
||||
@ -37,7 +43,7 @@ class Results:
|
||||
# TODO masks.pandas + boxes.pandas + cls.pandas
|
||||
|
||||
def __getitem__(self, idx):
|
||||
r = Results(orig_shape=self.orig_shape)
|
||||
r = Results(orig_img=self.orig_img)
|
||||
for item in self.comp:
|
||||
if getattr(self, item) is None:
|
||||
continue
|
||||
@ -53,7 +59,7 @@ class Results:
|
||||
self.probs = probs
|
||||
|
||||
def cpu(self):
|
||||
r = Results(orig_shape=self.orig_shape)
|
||||
r = Results(orig_img=self.orig_img)
|
||||
for item in self.comp:
|
||||
if getattr(self, item) is None:
|
||||
continue
|
||||
@ -61,7 +67,7 @@ class Results:
|
||||
return r
|
||||
|
||||
def numpy(self):
|
||||
r = Results(orig_shape=self.orig_shape)
|
||||
r = Results(orig_img=self.orig_img)
|
||||
for item in self.comp:
|
||||
if getattr(self, item) is None:
|
||||
continue
|
||||
@ -69,7 +75,7 @@ class Results:
|
||||
return r
|
||||
|
||||
def cuda(self):
|
||||
r = Results(orig_shape=self.orig_shape)
|
||||
r = Results(orig_img=self.orig_img)
|
||||
for item in self.comp:
|
||||
if getattr(self, item) is None:
|
||||
continue
|
||||
@ -77,7 +83,7 @@ class Results:
|
||||
return r
|
||||
|
||||
def to(self, *args, **kwargs):
|
||||
r = Results(orig_shape=self.orig_shape)
|
||||
r = Results(orig_img=self.orig_img)
|
||||
for item in self.comp:
|
||||
if getattr(self, item) is None:
|
||||
continue
|
||||
@ -118,6 +124,40 @@ class Results:
|
||||
orig_shape (tuple, optional): Original image size.
|
||||
""")
|
||||
|
||||
def visualize(self, show_conf=True, line_width=None, font_size=None, font='Arial.ttf', pil=False, example='abc'):
|
||||
"""
|
||||
Plots the given result on an input RGB image. Accepts cv2(numpy) or PIL Image
|
||||
|
||||
Args:
|
||||
show_conf (bool): Show confidence
|
||||
line_width (Float): The line width of boxes. Automatically scaled to img size if not provided
|
||||
font_size (Float): The font size of . Automatically scaled to img size if not provided
|
||||
"""
|
||||
img = deepcopy(self.orig_img)
|
||||
annotator = Annotator(img, line_width, font_size, font, pil, example)
|
||||
boxes = self.boxes
|
||||
masks = self.masks.data
|
||||
logits = self.probs
|
||||
names = self.names
|
||||
if boxes is not None:
|
||||
for d in reversed(boxes):
|
||||
cls, conf = d.cls.squeeze(), d.conf.squeeze()
|
||||
c = int(cls)
|
||||
label = (f'{names[c]}' if names else f'{c}') + (f'{conf:.2f}' if show_conf else '')
|
||||
annotator.box_label(d.xyxy.squeeze(), label, color=colors(c, True))
|
||||
|
||||
if masks is not None:
|
||||
im_gpu = torch.as_tensor(img, dtype=torch.float16).permute(2, 0, 1).flip(0).contiguous()
|
||||
im_gpu = F.resize(im_gpu, masks.data.shape[1:]) / 255
|
||||
annotator.masks(masks.data, colors=[colors(x, True) for x in boxes.cls], im_gpu=im_gpu)
|
||||
|
||||
if logits is not None:
|
||||
top5i = logits.argsort(0, descending=True)[:5].tolist() # top 5 indices
|
||||
text = f"{', '.join(f'{names[j] if names else j} {logits[j]:.2f}' for j in top5i)}, "
|
||||
annotator.text((32, 32), text, txt_color=(255, 255, 255)) # TODO: allow setting colors
|
||||
|
||||
return img
|
||||
|
||||
|
||||
class Boxes:
|
||||
"""
|
||||
|
Reference in New Issue
Block a user