ultralytics 8.0.41
TF SavedModel and EdgeTPU export (#1034)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Noobtoss <96134731+Noobtoss@users.noreply.github.com> Co-authored-by: Ayush Chaurasia <ayush.chaurarsia@gmail.com>
This commit is contained in:
@ -18,29 +18,28 @@ TensorFlow.js | `tfjs` | yolov8n_web_model/
|
||||
PaddlePaddle | `paddle` | yolov8n_paddle_model/
|
||||
|
||||
Requirements:
|
||||
$ pip install -r requirements.txt coremltools onnx onnxsim onnxruntime openvino-dev tensorflow-cpu # CPU
|
||||
$ pip install -r requirements.txt coremltools onnx onnxsim onnxruntime-gpu openvino-dev tensorflow # GPU
|
||||
$ pip install ultralytics[export]
|
||||
|
||||
Python:
|
||||
from ultralytics import YOLO
|
||||
model = YOLO('yolov8n.yaml')
|
||||
model = YOLO('yolov8n.pt')
|
||||
results = model.export(format='onnx')
|
||||
|
||||
CLI:
|
||||
$ yolo mode=export model=yolov8n.pt format=onnx
|
||||
|
||||
Inference:
|
||||
$ python detect.py --weights yolov8n.pt # PyTorch
|
||||
yolov8n.torchscript # TorchScript
|
||||
yolov8n.onnx # ONNX Runtime or OpenCV DNN with --dnn
|
||||
yolov8n_openvino_model # OpenVINO
|
||||
yolov8n.engine # TensorRT
|
||||
yolov8n.mlmodel # CoreML (macOS-only)
|
||||
yolov8n_saved_model # TensorFlow SavedModel
|
||||
yolov8n.pb # TensorFlow GraphDef
|
||||
yolov8n.tflite # TensorFlow Lite
|
||||
yolov8n_edgetpu.tflite # TensorFlow Edge TPU
|
||||
yolov8n_paddle_model # PaddlePaddle
|
||||
$ yolo predict model=yolov8n.pt # PyTorch
|
||||
yolov8n.torchscript # TorchScript
|
||||
yolov8n.onnx # ONNX Runtime or OpenCV DNN with --dnn
|
||||
yolov8n_openvino_model # OpenVINO
|
||||
yolov8n.engine # TensorRT
|
||||
yolov8n.mlmodel # CoreML (macOS-only)
|
||||
yolov8n_saved_model # TensorFlow SavedModel
|
||||
yolov8n.pb # TensorFlow GraphDef
|
||||
yolov8n.tflite # TensorFlow Lite
|
||||
yolov8n_edgetpu.tflite # TensorFlow Edge TPU
|
||||
yolov8n_paddle_model # PaddlePaddle
|
||||
|
||||
TensorFlow.js:
|
||||
$ cd .. && git clone https://github.com/zldrobit/tfjs-yolov5-example.git && cd tfjs-yolov5-example
|
||||
@ -64,12 +63,12 @@ import pandas as pd
|
||||
import torch
|
||||
|
||||
from ultralytics.nn.autobackend import check_class_names
|
||||
from ultralytics.nn.modules import Detect, Segment
|
||||
from ultralytics.nn.modules import C2f, Detect, Segment
|
||||
from ultralytics.nn.tasks import DetectionModel, SegmentationModel
|
||||
from ultralytics.yolo.cfg import get_cfg
|
||||
from ultralytics.yolo.data.dataloaders.stream_loaders import LoadImages
|
||||
from ultralytics.yolo.data.utils import IMAGENET_MEAN, IMAGENET_STD, check_det_dataset
|
||||
from ultralytics.yolo.utils import (DEFAULT_CFG, LINUX, LOGGER, MACOS, WINDOWS, __version__, callbacks, colorstr,
|
||||
from ultralytics.yolo.utils import (DEFAULT_CFG, LINUX, LOGGER, MACOS, __version__, callbacks, colorstr,
|
||||
get_default_args, yaml_save)
|
||||
from ultralytics.yolo.utils.checks import check_imgsz, check_requirements, check_version, check_yaml
|
||||
from ultralytics.yolo.utils.files import file_size
|
||||
@ -77,6 +76,7 @@ from ultralytics.yolo.utils.ops import Profile
|
||||
from ultralytics.yolo.utils.torch_utils import get_latest_opset, select_device, smart_inference_mode
|
||||
|
||||
CUDA = torch.cuda.is_available()
|
||||
ARM64 = platform.machine() in ('arm64', 'aarch64')
|
||||
|
||||
|
||||
def export_formats():
|
||||
@ -157,11 +157,10 @@ class Exporter:
|
||||
|
||||
# Load PyTorch model
|
||||
self.device = select_device('cpu' if self.args.device is None else self.args.device)
|
||||
if self.args.half:
|
||||
if self.device.type == 'cpu' and not coreml and not xml:
|
||||
LOGGER.info('half=True only compatible with GPU or CoreML export, i.e. use device=0 or format=coreml')
|
||||
self.args.half = False
|
||||
assert not self.args.dynamic, '--half not compatible with --dynamic, i.e. use either --half or --dynamic'
|
||||
if self.args.half and onnx and self.device.type == 'cpu':
|
||||
LOGGER.warning('WARNING ⚠️ half=True only compatible with GPU export, i.e. use device=0')
|
||||
self.args.half = False
|
||||
assert not self.args.dynamic, 'half=True not compatible with dynamic=True, i.e. use only one.'
|
||||
|
||||
# Checks
|
||||
model.names = check_class_names(model.names)
|
||||
@ -188,11 +187,15 @@ class Exporter:
|
||||
if isinstance(m, (Detect, Segment)):
|
||||
m.dynamic = self.args.dynamic
|
||||
m.export = True
|
||||
m.format = self.args.format
|
||||
elif isinstance(m, C2f) and not edgetpu:
|
||||
# EdgeTPU does not support FlexSplitV while split provides cleaner ONNX graph
|
||||
m.forward = m.forward_split
|
||||
|
||||
y = None
|
||||
for _ in range(2):
|
||||
y = model(im) # dry runs
|
||||
if self.args.half and not coreml and not xml:
|
||||
if self.args.half and (engine or onnx) and self.device.type != 'cpu':
|
||||
im, model = im.half(), model.half() # to FP16
|
||||
|
||||
# Warnings
|
||||
@ -207,7 +210,7 @@ class Exporter:
|
||||
self.output_shape = tuple(y.shape) if isinstance(y, torch.Tensor) else tuple(tuple(x.shape) for x in y)
|
||||
self.pretty_name = self.file.stem.replace('yolo', 'YOLO')
|
||||
self.metadata = {
|
||||
'description': f'Ultralytics {self.pretty_name} model trained on {self.args.data}',
|
||||
'description': f'Ultralytics {self.pretty_name} model trained on {Path(self.args.data).name}',
|
||||
'author': 'Ultralytics',
|
||||
'license': 'GPL-3.0 https://ultralytics.com/license',
|
||||
'version': __version__,
|
||||
@ -233,19 +236,16 @@ class Exporter:
|
||||
LOGGER.warning('WARNING ⚠️ YOLOv8 TensorFlow export is still under development. '
|
||||
'Please consider contributing to the effort if you have TF expertise. Thank you!')
|
||||
nms = False
|
||||
self.args.int8 |= edgetpu
|
||||
f[5], s_model = self._export_saved_model(nms=nms or self.args.agnostic_nms or tfjs,
|
||||
agnostic_nms=self.args.agnostic_nms or tfjs)
|
||||
if pb or tfjs: # pb prerequisite to tfjs
|
||||
f[6], _ = self._export_pb(s_model)
|
||||
if tflite or edgetpu:
|
||||
f[7] = str(Path(f[5]) / (self.file.stem + '_float16.tflite'))
|
||||
# f[7], _ = self._export_tflite(s_model,
|
||||
# int8=self.args.int8 or edgetpu,
|
||||
# data=self.args.data,
|
||||
# nms=nms,
|
||||
# agnostic_nms=self.args.agnostic_nms)
|
||||
if tflite:
|
||||
f[7], _ = self._export_tflite(s_model, nms=nms, agnostic_nms=self.args.agnostic_nms)
|
||||
if edgetpu:
|
||||
f[8], _ = self._export_edgetpu(tflite_model=f[7])
|
||||
f[8], _ = self._export_edgetpu(tflite_model=str(
|
||||
Path(f[5]) / (self.file.stem + '_full_integer_quant.tflite'))) # int8 in/out
|
||||
if tfjs:
|
||||
f[9], _ = self._export_tfjs()
|
||||
if paddle: # PaddlePaddle
|
||||
@ -263,8 +263,8 @@ class Exporter:
|
||||
LOGGER.info(
|
||||
f'\nExport complete ({time.time() - t:.1f}s)'
|
||||
f"\nResults saved to {colorstr('bold', file.parent.resolve())}"
|
||||
f'\nPredict: yolo task={model.task} mode=predict model={f} imgsz={imgsz} {data}'
|
||||
f'\nValidate: yolo task={model.task} mode=val model={f} imgsz={imgsz} data={self.args.data} {s}'
|
||||
f'\nPredict: yolo predict task={model.task} model={f} imgsz={imgsz} {data}'
|
||||
f'\nValidate: yolo val task={model.task} model={f} imgsz={imgsz} data={self.args.data} {s}'
|
||||
f'\nVisualize: https://netron.app')
|
||||
|
||||
self.run_callbacks('on_export_end')
|
||||
@ -319,25 +319,27 @@ class Exporter:
|
||||
|
||||
# Checks
|
||||
model_onnx = onnx.load(f) # load onnx model
|
||||
onnx.checker.check_model(model_onnx) # check onnx model
|
||||
|
||||
# Metadata
|
||||
d = {'stride': int(max(self.model.stride)), 'names': self.model.names}
|
||||
for k, v in d.items():
|
||||
meta = model_onnx.metadata_props.add()
|
||||
meta.key, meta.value = k, str(v)
|
||||
onnx.save(model_onnx, f)
|
||||
# onnx.checker.check_model(model_onnx) # check onnx model
|
||||
|
||||
# Simplify
|
||||
if self.args.simplify:
|
||||
try:
|
||||
check_requirements('onnxsim')
|
||||
check_requirements(('onnxsim', 'onnxruntime-gpu' if CUDA else 'onnxruntime'))
|
||||
import onnxsim
|
||||
|
||||
LOGGER.info(f'{prefix} simplifying with onnxsim {onnxsim.__version__}...')
|
||||
subprocess.run(f'onnxsim {f} {f}', shell=True)
|
||||
# subprocess.run(f'onnxsim {f} {f}', shell=True)
|
||||
model_onnx, check = onnxsim.simplify(model_onnx)
|
||||
assert check, 'Simplified ONNX model could not be validated'
|
||||
except Exception as e:
|
||||
LOGGER.info(f'{prefix} simplifier failure: {e}')
|
||||
|
||||
# Metadata
|
||||
for k, v in self.metadata.items():
|
||||
meta = model_onnx.metadata_props.add()
|
||||
meta.key, meta.value = k, str(v)
|
||||
|
||||
onnx.save(model_onnx, f)
|
||||
return f, model_onnx
|
||||
|
||||
@try_export
|
||||
@ -402,7 +404,7 @@ class Exporter:
|
||||
if self.model.task == 'classify':
|
||||
bias = [-x for x in IMAGENET_MEAN]
|
||||
scale = 1 / 255 / (sum(IMAGENET_STD) / 3)
|
||||
classifier_config = ct.ClassifierConfig(list(self.model.names.values()))
|
||||
classifier_config = ct.ClassifierConfig(list(self.model.names.values())) if self.args.nms else None
|
||||
else:
|
||||
bias = [0.0, 0.0, 0.0]
|
||||
scale = 1 / 255
|
||||
@ -414,10 +416,7 @@ class Exporter:
|
||||
classifier_config=classifier_config)
|
||||
bits, mode = (8, 'kmeans_lut') if self.args.int8 else (16, 'linear') if self.args.half else (32, None)
|
||||
if bits < 32:
|
||||
if MACOS: # quantization only supported on macOS
|
||||
ct_model = ct.models.neural_network.quantization_utils.quantize_weights(ct_model, bits, mode)
|
||||
else:
|
||||
LOGGER.info(f'{prefix} quantization only supported on macOS, skipping...')
|
||||
ct_model = ct.models.neural_network.quantization_utils.quantize_weights(ct_model, bits, mode)
|
||||
if self.args.nms:
|
||||
ct_model = self._pipeline_coreml(ct_model)
|
||||
|
||||
@ -440,11 +439,11 @@ class Exporter:
|
||||
import tensorrt as trt # noqa
|
||||
|
||||
check_version(trt.__version__, '7.0.0', hard=True) # require tensorrt>=8.0.0
|
||||
self._export_onnx()
|
||||
onnx = self.file.with_suffix('.onnx')
|
||||
self.args.simplify = True
|
||||
f_onnx, _ = self._export_onnx()
|
||||
|
||||
LOGGER.info(f'\n{prefix} starting export with TensorRT {trt.__version__}...')
|
||||
assert onnx.exists(), f'failed to export ONNX file: {onnx}'
|
||||
assert Path(f_onnx).exists(), f'failed to export ONNX file: {f_onnx}'
|
||||
f = self.file.with_suffix('.engine') # TensorRT engine file
|
||||
logger = trt.Logger(trt.Logger.INFO)
|
||||
if verbose:
|
||||
@ -458,8 +457,8 @@ class Exporter:
|
||||
flag = (1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH))
|
||||
network = builder.create_network(flag)
|
||||
parser = trt.OnnxParser(network, logger)
|
||||
if not parser.parse_from_file(str(onnx)):
|
||||
raise RuntimeError(f'failed to load ONNX file: {onnx}')
|
||||
if not parser.parse_from_file(f_onnx):
|
||||
raise RuntimeError(f'failed to load ONNX file: {f_onnx}')
|
||||
|
||||
inputs = [network.get_input(i) for i in range(network.num_inputs)]
|
||||
outputs = [network.get_output(i) for i in range(network.num_outputs)]
|
||||
@ -507,77 +506,37 @@ class Exporter:
|
||||
try:
|
||||
import tensorflow as tf # noqa
|
||||
except ImportError:
|
||||
check_requirements(f"tensorflow{'' if CUDA else '-macos' if MACOS else '-cpu' if LINUX else ''}")
|
||||
check_requirements(f"tensorflow{'-macos' if MACOS else '-aarch64' if ARM64 else '' if CUDA else '-cpu'}")
|
||||
import tensorflow as tf # noqa
|
||||
check_requirements(('onnx', 'onnx2tf', 'sng4onnx', 'onnxsim', 'onnx_graphsurgeon', 'tflite_support'),
|
||||
check_requirements(('onnx', 'onnx2tf', 'sng4onnx', 'onnxsim', 'onnx_graphsurgeon', 'tflite_support',
|
||||
'onnxruntime-gpu' if CUDA else 'onnxruntime'),
|
||||
cmds='--extra-index-url https://pypi.ngc.nvidia.com')
|
||||
|
||||
LOGGER.info(f'\n{prefix} starting export with tensorflow {tf.__version__}...')
|
||||
f = str(self.file).replace(self.file.suffix, '_saved_model')
|
||||
f = Path(str(self.file).replace(self.file.suffix, '_saved_model'))
|
||||
if f.is_dir():
|
||||
import shutil
|
||||
shutil.rmtree(f) # delete output folder
|
||||
|
||||
# Export to ONNX
|
||||
self._export_onnx()
|
||||
onnx = self.file.with_suffix('.onnx')
|
||||
self.args.simplify = True
|
||||
f_onnx, _ = self._export_onnx()
|
||||
|
||||
# Export to TF SavedModel
|
||||
subprocess.run(f'onnx2tf -i {onnx} -o {f} --non_verbose', shell=True)
|
||||
yaml_save(Path(f) / 'metadata.yaml', self.metadata) # add metadata.yaml
|
||||
# Export to TF
|
||||
int8 = '-oiqt -qt per-tensor' if self.args.int8 else ''
|
||||
cmd = f'onnx2tf -i {f_onnx} -o {f} --non_verbose {int8}'
|
||||
LOGGER.info(f'\n{prefix} running {cmd}')
|
||||
subprocess.run(cmd, shell=True)
|
||||
yaml_save(f / 'metadata.yaml', self.metadata) # add metadata.yaml
|
||||
|
||||
# Add TFLite metadata
|
||||
for file in Path(f).rglob('*.tflite'):
|
||||
for file in f.rglob('*.tflite'):
|
||||
self._add_tflite_metadata(file)
|
||||
|
||||
# Load saved_model
|
||||
keras_model = tf.saved_model.load(f, tags=None, options=None)
|
||||
|
||||
return f, keras_model
|
||||
|
||||
@try_export
|
||||
def _export_saved_model_OLD(self,
|
||||
nms=False,
|
||||
agnostic_nms=False,
|
||||
topk_per_class=100,
|
||||
topk_all=100,
|
||||
iou_thres=0.45,
|
||||
conf_thres=0.25,
|
||||
prefix=colorstr('TensorFlow SavedModel:')):
|
||||
# YOLOv8 TensorFlow SavedModel export
|
||||
try:
|
||||
import tensorflow as tf # noqa
|
||||
except ImportError:
|
||||
check_requirements(f"tensorflow{'' if CUDA else '-macos' if MACOS else '-cpu' if LINUX else ''}")
|
||||
import tensorflow as tf # noqa
|
||||
# from models.tf import TFModel
|
||||
from tensorflow.python.framework.convert_to_constants import convert_variables_to_constants_v2 # noqa
|
||||
|
||||
LOGGER.info(f'\n{prefix} starting export with tensorflow {tf.__version__}...')
|
||||
f = str(self.file).replace(self.file.suffix, '_saved_model')
|
||||
batch_size, ch, *imgsz = list(self.im.shape) # BCHW
|
||||
|
||||
tf_models = None # TODO: no TF modules available
|
||||
tf_model = tf_models.TFModel(cfg=self.model.yaml, model=self.model.cpu(), nc=self.model.nc, imgsz=imgsz)
|
||||
im = tf.zeros((batch_size, *imgsz, ch)) # BHWC order for TensorFlow
|
||||
_ = tf_model.predict(im, nms, agnostic_nms, topk_per_class, topk_all, iou_thres, conf_thres)
|
||||
inputs = tf.keras.Input(shape=(*imgsz, ch), batch_size=None if self.args.dynamic else batch_size)
|
||||
outputs = tf_model.predict(inputs, nms, agnostic_nms, topk_per_class, topk_all, iou_thres, conf_thres)
|
||||
keras_model = tf.keras.Model(inputs=inputs, outputs=outputs)
|
||||
keras_model.trainable = False
|
||||
keras_model.summary()
|
||||
if self.args.keras:
|
||||
keras_model.save(f, save_format='tf')
|
||||
else:
|
||||
spec = tf.TensorSpec(keras_model.inputs[0].shape, keras_model.inputs[0].dtype)
|
||||
m = tf.function(lambda x: keras_model(x)) # full model
|
||||
m = m.get_concrete_function(spec)
|
||||
frozen_func = convert_variables_to_constants_v2(m)
|
||||
tfm = tf.Module()
|
||||
tfm.__call__ = tf.function(lambda x: frozen_func(x)[:4] if nms else frozen_func(x), [spec])
|
||||
tfm.__call__(im)
|
||||
tf.saved_model.save(tfm,
|
||||
f,
|
||||
options=tf.saved_model.SaveOptions(experimental_custom_gradients=False)
|
||||
if check_version(tf.__version__, '2.6') else tf.saved_model.SaveOptions())
|
||||
return f, keras_model
|
||||
return str(f), keras_model
|
||||
|
||||
@try_export
|
||||
def _export_pb(self, keras_model, prefix=colorstr('TensorFlow GraphDef:')):
|
||||
@ -596,8 +555,18 @@ class Exporter:
|
||||
return f, None
|
||||
|
||||
@try_export
|
||||
def _export_tflite(self, keras_model, int8, data, nms, agnostic_nms, prefix=colorstr('TensorFlow Lite:')):
|
||||
def _export_tflite(self, keras_model, nms, agnostic_nms, prefix=colorstr('TensorFlow Lite:')):
|
||||
# YOLOv8 TensorFlow Lite export
|
||||
saved_model = Path(str(self.file).replace(self.file.suffix, '_saved_model'))
|
||||
if self.args.int8:
|
||||
f = saved_model / (self.file.stem + 'yolov8n_integer_quant.tflite') # fp32 in/out
|
||||
elif self.args.half:
|
||||
f = saved_model / (self.file.stem + '_float16.tflite')
|
||||
else:
|
||||
f = saved_model / (self.file.stem + '_float32.tflite')
|
||||
return str(f), None # noqa
|
||||
|
||||
# OLD VERSION BELOW ---------------------------------------------------------------
|
||||
import tensorflow as tf # noqa
|
||||
|
||||
LOGGER.info(f'\n{prefix} starting export with tensorflow {tf.__version__}...')
|
||||
@ -608,7 +577,7 @@ class Exporter:
|
||||
converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS]
|
||||
converter.target_spec.supported_types = [tf.float16]
|
||||
converter.optimizations = [tf.lite.Optimize.DEFAULT]
|
||||
if int8:
|
||||
if self.args.int8:
|
||||
|
||||
def representative_dataset_gen(dataset, n_images=100):
|
||||
# Dataset generator for use with converter.representative_dataset, returns a generator of np arrays
|
||||
@ -620,7 +589,7 @@ class Exporter:
|
||||
if n >= n_images:
|
||||
break
|
||||
|
||||
dataset = LoadImages(check_det_dataset(check_yaml(data))['train'], imgsz=imgsz, auto=False)
|
||||
dataset = LoadImages(check_det_dataset(check_yaml(self.args.data))['train'], imgsz=imgsz, auto=False)
|
||||
converter.representative_dataset = lambda: representative_dataset_gen(dataset, n_images=100)
|
||||
converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
|
||||
converter.target_spec.supported_types = []
|
||||
@ -641,7 +610,7 @@ class Exporter:
|
||||
cmd = 'edgetpu_compiler --version'
|
||||
help_url = 'https://coral.ai/docs/edgetpu/compiler/'
|
||||
assert LINUX, f'export only supported on Linux. See {help_url}'
|
||||
if subprocess.run(f'{cmd} >/dev/null', shell=True).returncode != 0:
|
||||
if subprocess.run(f'{cmd} > /dev/null', shell=True).returncode != 0:
|
||||
LOGGER.info(f'\n{prefix} export requires Edge TPU compiler. Attempting install from {help_url}')
|
||||
sudo = subprocess.run('sudo --version >/dev/null', shell=True).returncode == 0 # sudo installed on system
|
||||
for c in (
|
||||
@ -656,7 +625,7 @@ class Exporter:
|
||||
LOGGER.info(f'\n{prefix} starting export with Edge TPU compiler {ver}...')
|
||||
f = str(tflite_model).replace('.tflite', '_edgetpu.tflite') # Edge TPU model
|
||||
|
||||
cmd = f'edgetpu_compiler -s -d -k 10 --out_dir {self.file.parent} {tflite_model}'
|
||||
cmd = f'edgetpu_compiler -s -d -k 10 --out_dir {Path(f).parent} {tflite_model}'
|
||||
subprocess.run(cmd.split(), check=True)
|
||||
self._add_tflite_metadata(f)
|
||||
return f, None
|
||||
@ -674,7 +643,7 @@ class Exporter:
|
||||
|
||||
cmd = f'tensorflowjs_converter --input_format=tf_frozen_model ' \
|
||||
f'--output_node_names=Identity,Identity_1,Identity_2,Identity_3 {f_pb} {f}'
|
||||
subprocess.run(cmd.split())
|
||||
subprocess.run(cmd.split(), check=True)
|
||||
|
||||
with open(f_json, 'w') as j: # sort JSON Identity_* in ascending order
|
||||
subst = re.sub(
|
||||
@ -698,14 +667,23 @@ class Exporter:
|
||||
from tflite_support import metadata as _metadata # noqa
|
||||
from tflite_support import metadata_schema_py_generated as _metadata_fb # noqa
|
||||
|
||||
# Creates model info.
|
||||
# Create model info
|
||||
model_meta = _metadata_fb.ModelMetadataT()
|
||||
model_meta.name = self.metadata['description']
|
||||
model_meta.version = self.metadata['version']
|
||||
model_meta.author = self.metadata['author']
|
||||
model_meta.license = self.metadata['license']
|
||||
|
||||
# Creates input info.
|
||||
# Label file
|
||||
tmp_file = file.parent / 'temp_meta.txt'
|
||||
with open(tmp_file, 'w') as f:
|
||||
f.write(str(self.metadata))
|
||||
|
||||
label_file = _metadata_fb.AssociatedFileT()
|
||||
label_file.name = tmp_file.name
|
||||
label_file.type = _metadata_fb.AssociatedFileType.TENSOR_AXIS_LABELS
|
||||
|
||||
# Create input info
|
||||
input_meta = _metadata_fb.TensorMetadataT()
|
||||
input_meta.name = 'image'
|
||||
input_meta.description = 'Input image to be detected.'
|
||||
@ -714,25 +692,21 @@ class Exporter:
|
||||
input_meta.content.contentProperties.colorSpace = _metadata_fb.ColorSpaceType.RGB
|
||||
input_meta.content.contentPropertiesType = _metadata_fb.ContentProperties.ImageProperties
|
||||
|
||||
# Creates output info.
|
||||
output_meta = _metadata_fb.TensorMetadataT()
|
||||
output_meta.name = 'output'
|
||||
output_meta.description = 'Coordinates of detected objects, class labels, and confidence score.'
|
||||
# Create output info
|
||||
output1 = _metadata_fb.TensorMetadataT()
|
||||
output1.name = 'output'
|
||||
output1.description = 'Coordinates of detected objects, class labels, and confidence score'
|
||||
output1.associatedFiles = [label_file]
|
||||
if self.model.task == 'segment':
|
||||
output2 = _metadata_fb.TensorMetadataT()
|
||||
output2.name = 'output'
|
||||
output2.description = 'Mask protos'
|
||||
output2.associatedFiles = [label_file]
|
||||
|
||||
# Label file
|
||||
tmp_file = Path('/tmp/meta.txt')
|
||||
with open(tmp_file, 'w') as meta_f:
|
||||
meta_f.write(str(self.metadata))
|
||||
|
||||
label_file = _metadata_fb.AssociatedFileT()
|
||||
label_file.name = tmp_file.name
|
||||
label_file.type = _metadata_fb.AssociatedFileType.TENSOR_AXIS_LABELS
|
||||
output_meta.associatedFiles = [label_file]
|
||||
|
||||
# Creates subgraph info.
|
||||
# Create subgraph info
|
||||
subgraph = _metadata_fb.SubGraphMetadataT()
|
||||
subgraph.inputTensorMetadata = [input_meta]
|
||||
subgraph.outputTensorMetadata = [output_meta]
|
||||
subgraph.outputTensorMetadata = [output1, output2] if self.model.task == 'segment' else [output1]
|
||||
model_meta.subgraphMetadata = [subgraph]
|
||||
|
||||
b = flatbuffers.Builder(0)
|
||||
|
@ -29,14 +29,45 @@ MODEL_MAP = {
|
||||
|
||||
class YOLO:
|
||||
"""
|
||||
YOLO
|
||||
YOLO (You Only Look Once) object detection model.
|
||||
|
||||
A python interface which emulates a model-like behaviour by wrapping trainers.
|
||||
"""
|
||||
Args:
|
||||
model (str or Path): Path to the model file to load or create.
|
||||
type (str): Type/version of models to use. Defaults to "v8".
|
||||
|
||||
Attributes:
|
||||
type (str): Type/version of models being used.
|
||||
ModelClass (Any): Model class.
|
||||
TrainerClass (Any): Trainer class.
|
||||
ValidatorClass (Any): Validator class.
|
||||
PredictorClass (Any): Predictor class.
|
||||
predictor (Any): Predictor object.
|
||||
model (Any): Model object.
|
||||
trainer (Any): Trainer object.
|
||||
task (str): Type of model task.
|
||||
ckpt (Any): Checkpoint object if model loaded from *.pt file.
|
||||
cfg (str): Model configuration if loaded from *.yaml file.
|
||||
ckpt_path (str): Checkpoint file path.
|
||||
overrides (dict): Overrides for trainer object.
|
||||
metrics_data (Any): Data for metrics.
|
||||
|
||||
Methods:
|
||||
__call__(): Alias for predict method.
|
||||
_new(cfg, verbose=True): Initializes a new model and infers the task type from the model definitions.
|
||||
_load(weights): Initializes a new model and infers the task type from the model head.
|
||||
_check_is_pytorch_model(): Raises TypeError if model is not a PyTorch model.
|
||||
reset(): Resets the model modules.
|
||||
info(verbose=False): Logs model info.
|
||||
fuse(): Fuse model for faster inference.
|
||||
predict(source=None, stream=False, **kwargs): Perform prediction using the YOLO model.
|
||||
|
||||
Returns:
|
||||
List[ultralytics.yolo.engine.results.Results]: The prediction results.
|
||||
"""
|
||||
|
||||
def __init__(self, model='yolov8n.pt', type='v8') -> None:
|
||||
"""
|
||||
Initializes the YOLO object.
|
||||
Initializes the YOLO model.
|
||||
|
||||
Args:
|
||||
model (str, Path): model to load or create
|
||||
@ -97,11 +128,12 @@ class YOLO:
|
||||
self.task = self.model.args['task']
|
||||
self.overrides = self.model.args
|
||||
self._reset_ckpt_args(self.overrides)
|
||||
self.ckpt_path = self.model.pt_path
|
||||
else:
|
||||
check_file(weights)
|
||||
weights = check_file(weights)
|
||||
self.model, self.ckpt = weights, None
|
||||
self.task = guess_model_task(weights)
|
||||
self.ckpt_path = weights
|
||||
self.ckpt_path = weights
|
||||
self.overrides['model'] = weights
|
||||
self.ModelClass, self.TrainerClass, self.ValidatorClass, self.PredictorClass = self._assign_ops_from_task()
|
||||
|
||||
@ -204,7 +236,6 @@ class YOLO:
|
||||
|
||||
return validator.metrics
|
||||
|
||||
@smart_inference_mode()
|
||||
def export(self, **kwargs):
|
||||
"""
|
||||
Export model.
|
||||
@ -279,6 +310,13 @@ class YOLO:
|
||||
"""
|
||||
return self.model.names if hasattr(self.model, 'names') else None
|
||||
|
||||
@property
|
||||
def device(self):
|
||||
"""
|
||||
Returns device if PyTorch model
|
||||
"""
|
||||
return next(self.model.parameters()).device if isinstance(self.model, nn.Module) else None
|
||||
|
||||
@property
|
||||
def transforms(self):
|
||||
"""
|
||||
@ -293,7 +331,6 @@ class YOLO:
|
||||
"""
|
||||
if not self.metrics_data:
|
||||
LOGGER.info('No metrics data found! Run training or validation operation first.')
|
||||
|
||||
return self.metrics_data
|
||||
|
||||
@staticmethod
|
||||
@ -306,7 +343,7 @@ class YOLO:
|
||||
@staticmethod
|
||||
def _reset_ckpt_args(args):
|
||||
for arg in 'augment', 'verbose', 'project', 'name', 'exist_ok', 'resume', 'batch', 'epochs', 'cache', \
|
||||
'save_json', 'half', 'v5loader', 'device', 'cfg', 'save', 'rect', 'plots', 'opset':
|
||||
'save_json', 'half', 'v5loader', 'device', 'cfg', 'save', 'rect', 'plots', 'opset', 'simplify':
|
||||
args.pop(arg, None)
|
||||
|
||||
@staticmethod
|
||||
|
@ -1,30 +1,32 @@
|
||||
# Ultralytics YOLO 🚀, GPL-3.0 license
|
||||
"""
|
||||
Run prediction on images, videos, directories, globs, YouTube, webcam, streams, etc.
|
||||
|
||||
Usage - sources:
|
||||
$ yolo task=... mode=predict model=s.pt --source 0 # webcam
|
||||
img.jpg # image
|
||||
vid.mp4 # video
|
||||
screen # screenshot
|
||||
path/ # directory
|
||||
list.txt # list of images
|
||||
list.streams # list of streams
|
||||
'path/*.jpg' # glob
|
||||
'https://youtu.be/Zgi9g1ksQHc' # YouTube
|
||||
'rtsp://example.com/media.mp4' # RTSP, RTMP, HTTP stream
|
||||
$ yolo mode=predict model=yolov8n.pt --source 0 # webcam
|
||||
img.jpg # image
|
||||
vid.mp4 # video
|
||||
screen # screenshot
|
||||
path/ # directory
|
||||
list.txt # list of images
|
||||
list.streams # list of streams
|
||||
'path/*.jpg' # glob
|
||||
'https://youtu.be/Zgi9g1ksQHc' # YouTube
|
||||
'rtsp://example.com/media.mp4' # RTSP, RTMP, HTTP stream
|
||||
|
||||
Usage - formats:
|
||||
$ yolo task=... mode=predict --weights yolov8n.pt # PyTorch
|
||||
yolov8n.torchscript # TorchScript
|
||||
yolov8n.onnx # ONNX Runtime or OpenCV DNN with --dnn
|
||||
yolov8n_openvino_model # OpenVINO
|
||||
yolov8n.engine # TensorRT
|
||||
yolov8n.mlmodel # CoreML (macOS-only)
|
||||
yolov8n_saved_model # TensorFlow SavedModel
|
||||
yolov8n.pb # TensorFlow GraphDef
|
||||
yolov8n.tflite # TensorFlow Lite
|
||||
yolov8n_edgetpu.tflite # TensorFlow Edge TPU
|
||||
yolov8n_paddle_model # PaddlePaddle
|
||||
"""
|
||||
$ yolo mode=predict model=yolov8n.pt # PyTorch
|
||||
yolov8n.torchscript # TorchScript
|
||||
yolov8n.onnx # ONNX Runtime or OpenCV DNN with dnn=True
|
||||
yolov8n_openvino_model # OpenVINO
|
||||
yolov8n.engine # TensorRT
|
||||
yolov8n.mlmodel # CoreML (macOS-only)
|
||||
yolov8n_saved_model # TensorFlow SavedModel
|
||||
yolov8n.pb # TensorFlow GraphDef
|
||||
yolov8n.tflite # TensorFlow Lite
|
||||
yolov8n_edgetpu.tflite # TensorFlow Edge TPU
|
||||
yolov8n_paddle_model # PaddlePaddle
|
||||
"""
|
||||
import platform
|
||||
from collections import defaultdict
|
||||
from pathlib import Path
|
||||
@ -200,9 +202,9 @@ class BasePredictor:
|
||||
# Print results
|
||||
if self.args.verbose and self.seen:
|
||||
t = tuple(x.t / self.seen * 1E3 for x in self.dt) # speeds per image
|
||||
LOGGER.info(f'Speed: %.1fms pre-process, %.1fms inference, %.1fms postprocess per image at shape '
|
||||
LOGGER.info(f'Speed: %.1fms preprocess, %.1fms inference, %.1fms postprocess per image at shape '
|
||||
f'{(1, 3, *self.imgsz)}' % t)
|
||||
if self.args.save_txt or self.args.save:
|
||||
if self.args.save or self.args.save_txt or self.args.save_crop:
|
||||
nl = len(list(self.save_dir.glob('labels/*.txt'))) # number of labels
|
||||
s = f"\n{nl} label{'s' * (nl > 1)} saved to {self.save_dir / 'labels'}" if self.args.save_txt else ''
|
||||
LOGGER.info(f"Results saved to {colorstr('bold', self.save_dir)}{s}")
|
||||
|
@ -4,7 +4,6 @@ from functools import lru_cache
|
||||
import numpy as np
|
||||
import torch
|
||||
import torchvision.transforms.functional as F
|
||||
from PIL import Image
|
||||
|
||||
from ultralytics.yolo.utils import LOGGER, ops
|
||||
from ultralytics.yolo.utils.plotting import Annotator, colors
|
||||
@ -136,7 +135,7 @@ class Results:
|
||||
img = deepcopy(self.orig_img)
|
||||
annotator = Annotator(img, line_width, font_size, font, pil, example)
|
||||
boxes = self.boxes
|
||||
masks = self.masks.data
|
||||
masks = self.masks
|
||||
logits = self.probs
|
||||
names = self.names
|
||||
if boxes is not None:
|
||||
|
@ -1,8 +1,10 @@
|
||||
# Ultralytics YOLO 🚀, GPL-3.0 license
|
||||
"""
|
||||
Simple training loop; Boilerplate that could apply to any arbitrary neural network,
|
||||
"""
|
||||
Train a model on a dataset
|
||||
|
||||
Usage:
|
||||
$ yolo mode=train model=yolov8n.pt data=coco128.yaml imgsz=640 epochs=100 batch=16
|
||||
"""
|
||||
import os
|
||||
import subprocess
|
||||
import time
|
||||
|
@ -1,5 +1,23 @@
|
||||
# Ultralytics YOLO 🚀, GPL-3.0 license
|
||||
"""
|
||||
Check a model's accuracy on a test or val split of a dataset
|
||||
|
||||
Usage:
|
||||
$ yolo mode=val model=yolov8n.pt data=coco128.yaml imgsz=640
|
||||
|
||||
Usage - formats:
|
||||
$ yolo mode=val model=yolov8n.pt # PyTorch
|
||||
yolov8n.torchscript # TorchScript
|
||||
yolov8n.onnx # ONNX Runtime or OpenCV DNN with dnn=True
|
||||
yolov8n_openvino_model # OpenVINO
|
||||
yolov8n.engine # TensorRT
|
||||
yolov8n.mlmodel # CoreML (macOS-only)
|
||||
yolov8n_saved_model # TensorFlow SavedModel
|
||||
yolov8n.pb # TensorFlow GraphDef
|
||||
yolov8n.tflite # TensorFlow Lite
|
||||
yolov8n_edgetpu.tflite # TensorFlow Edge TPU
|
||||
yolov8n_paddle_model # PaddlePaddle
|
||||
"""
|
||||
import json
|
||||
from collections import defaultdict
|
||||
from pathlib import Path
|
||||
@ -105,8 +123,7 @@ class BaseValidator:
|
||||
self.device = model.device
|
||||
if not pt and not jit:
|
||||
self.args.batch = 1 # export.py models default to batch-size 1
|
||||
self.logger.info(
|
||||
f'Forcing --batch-size 1 square inference (1,3,{imgsz},{imgsz}) for non-PyTorch models')
|
||||
self.logger.info(f'Forcing batch=1 square inference (1,3,{imgsz},{imgsz}) for non-PyTorch models')
|
||||
|
||||
if isinstance(self.args.data, str) and self.args.data.endswith('.yaml'):
|
||||
self.data = check_det_dataset(self.args.data)
|
||||
@ -136,7 +153,7 @@ class BaseValidator:
|
||||
for batch_i, batch in enumerate(bar):
|
||||
self.run_callbacks('on_val_batch_start')
|
||||
self.batch_i = batch_i
|
||||
# pre-process
|
||||
# preprocess
|
||||
with dt[0]:
|
||||
batch = self.preprocess(batch)
|
||||
|
||||
@ -149,7 +166,7 @@ class BaseValidator:
|
||||
if self.training:
|
||||
self.loss += trainer.criterion(preds, batch)[1]
|
||||
|
||||
# pre-process predictions
|
||||
# postprocess
|
||||
with dt[3]:
|
||||
preds = self.postprocess(preds)
|
||||
|
||||
@ -163,13 +180,14 @@ class BaseValidator:
|
||||
self.check_stats(stats)
|
||||
self.print_results()
|
||||
self.speed = tuple(x.t / len(self.dataloader.dataset) * 1E3 for x in dt) # speeds per image
|
||||
self.finalize_metrics()
|
||||
self.run_callbacks('on_val_end')
|
||||
if self.training:
|
||||
model.float()
|
||||
results = {**stats, **trainer.label_loss_items(self.loss.cpu() / len(self.dataloader), prefix='val')}
|
||||
return {k: round(float(v), 5) for k, v in results.items()} # return results as 5 decimal place floats
|
||||
else:
|
||||
self.logger.info('Speed: %.1fms pre-process, %.1fms inference, %.1fms loss, %.1fms post-process per image' %
|
||||
self.logger.info('Speed: %.1fms preprocess, %.1fms inference, %.1fms loss, %.1fms postprocess per image' %
|
||||
self.speed)
|
||||
if self.args.save_json and self.jdict:
|
||||
with open(str(self.save_dir / 'predictions.json'), 'w') as f:
|
||||
@ -197,6 +215,9 @@ class BaseValidator:
|
||||
def update_metrics(self, preds, batch):
|
||||
pass
|
||||
|
||||
def finalize_metrics(self, *args, **kwargs):
|
||||
pass
|
||||
|
||||
def get_stats(self):
|
||||
return {}
|
||||
|
||||
|
Reference in New Issue
Block a user