ultralytics 8.0.136
refactor and simplify package (#3748)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
@@ -0,0 +1,10 @@
import importlib
import sys

from ultralytics.utils import LOGGER

# Set modules in sys.modules under their old name
sys.modules['ultralytics.yolo.engine'] = importlib.import_module('ultralytics.engine')

LOGGER.warning("WARNING ⚠️ 'ultralytics.yolo.engine' is deprecated since '8.0.136' and will be removed in '8.1.0'. "
               "Please use 'ultralytics.engine' instead.")
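The new module above is a compatibility shim: it registers the relocated ultralytics.engine package in sys.modules under its pre-8.0.136 name, so old import paths keep resolving while logging a deprecation warning. A minimal sketch of the effect (the assertion is illustrative, not part of this commit):

    import importlib
    import sys

    import ultralytics.yolo.engine  # executes the shim and logs the deprecation warning above

    # Both names now resolve to the same module object
    assert sys.modules['ultralytics.yolo.engine'] is importlib.import_module('ultralytics.engine')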
@@ -1,926 +0,0 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
"""
Export a YOLOv8 PyTorch model to other formats. TensorFlow exports authored by https://github.com/zldrobit

Format | `format=argument` | Model
--- | --- | ---
PyTorch | - | yolov8n.pt
TorchScript | `torchscript` | yolov8n.torchscript
ONNX | `onnx` | yolov8n.onnx
OpenVINO | `openvino` | yolov8n_openvino_model/
TensorRT | `engine` | yolov8n.engine
CoreML | `coreml` | yolov8n.mlmodel
TensorFlow SavedModel | `saved_model` | yolov8n_saved_model/
TensorFlow GraphDef | `pb` | yolov8n.pb
TensorFlow Lite | `tflite` | yolov8n.tflite
TensorFlow Edge TPU | `edgetpu` | yolov8n_edgetpu.tflite
TensorFlow.js | `tfjs` | yolov8n_web_model/
PaddlePaddle | `paddle` | yolov8n_paddle_model/
NCNN | `ncnn` | yolov8n_ncnn_model/

Requirements:
    $ pip install ultralytics[export]

Python:
    from ultralytics import YOLO
    model = YOLO('yolov8n.pt')
    results = model.export(format='onnx')

CLI:
    $ yolo mode=export model=yolov8n.pt format=onnx

Inference:
    $ yolo predict model=yolov8n.pt          # PyTorch
                   yolov8n.torchscript       # TorchScript
                   yolov8n.onnx              # ONNX Runtime or OpenCV DNN with --dnn
                   yolov8n_openvino_model    # OpenVINO
                   yolov8n.engine            # TensorRT
                   yolov8n.mlmodel           # CoreML (macOS-only)
                   yolov8n_saved_model       # TensorFlow SavedModel
                   yolov8n.pb                # TensorFlow GraphDef
                   yolov8n.tflite            # TensorFlow Lite
                   yolov8n_edgetpu.tflite    # TensorFlow Edge TPU
                   yolov8n_paddle_model      # PaddlePaddle

TensorFlow.js:
    $ cd .. && git clone https://github.com/zldrobit/tfjs-yolov5-example.git && cd tfjs-yolov5-example
    $ npm install
    $ ln -s ../../yolov5/yolov8n_web_model public/yolov8n_web_model
    $ npm start
"""
import json
import os
import shutil
import subprocess
import time
import warnings
from copy import deepcopy
from datetime import datetime
from pathlib import Path

import torch

from ultralytics.nn.autobackend import check_class_names
from ultralytics.nn.modules import C2f, Detect, RTDETRDecoder
from ultralytics.nn.tasks import DetectionModel, SegmentationModel
from ultralytics.yolo.cfg import get_cfg
from ultralytics.yolo.utils import (ARM64, DEFAULT_CFG, LINUX, LOGGER, MACOS, ROOT, __version__, callbacks, colorstr,
                                    get_default_args, yaml_save)
from ultralytics.yolo.utils.checks import check_imgsz, check_requirements, check_version
from ultralytics.yolo.utils.downloads import attempt_download_asset, get_github_assets
from ultralytics.yolo.utils.files import file_size
from ultralytics.yolo.utils.ops import Profile
from ultralytics.yolo.utils.torch_utils import get_latest_opset, select_device, smart_inference_mode

def export_formats():
    """YOLOv8 export formats."""
    import pandas
    x = [
        ['PyTorch', '-', '.pt', True, True],
        ['TorchScript', 'torchscript', '.torchscript', True, True],
        ['ONNX', 'onnx', '.onnx', True, True],
        ['OpenVINO', 'openvino', '_openvino_model', True, False],
        ['TensorRT', 'engine', '.engine', False, True],
        ['CoreML', 'coreml', '.mlmodel', True, False],
        ['TensorFlow SavedModel', 'saved_model', '_saved_model', True, True],
        ['TensorFlow GraphDef', 'pb', '.pb', True, True],
        ['TensorFlow Lite', 'tflite', '.tflite', True, False],
        ['TensorFlow Edge TPU', 'edgetpu', '_edgetpu.tflite', True, False],
        ['TensorFlow.js', 'tfjs', '_web_model', True, False],
        ['PaddlePaddle', 'paddle', '_paddle_model', True, True],
        ['NCNN', 'ncnn', '_ncnn_model', True, True], ]
    return pandas.DataFrame(x, columns=['Format', 'Argument', 'Suffix', 'CPU', 'GPU'])


def gd_outputs(gd):
    """TensorFlow GraphDef model output node names."""
    name_list, input_list = [], []
    for node in gd.node:  # tensorflow.core.framework.node_def_pb2.NodeDef
        name_list.append(node.name)
        input_list.extend(node.input)
    return sorted(f'{x}:0' for x in list(set(name_list) - set(input_list)) if not x.startswith('NoOp'))

def try_export(inner_func):
    """YOLOv8 export decorator, i.e. @try_export."""
    inner_args = get_default_args(inner_func)

    def outer_func(*args, **kwargs):
        """Export a model."""
        prefix = inner_args['prefix']
        try:
            with Profile() as dt:
                f, model = inner_func(*args, **kwargs)
            LOGGER.info(f'{prefix} export success ✅ {dt.t:.1f}s, saved as {f} ({file_size(f):.1f} MB)')
            return f, model
        except Exception as e:
            LOGGER.info(f'{prefix} export failure ❌ {dt.t:.1f}s: {e}')
            return None, None

    return outer_func

class Exporter:
    """
    A class for exporting a model.

    Attributes:
        args (SimpleNamespace): Configuration for the exporter.
        save_dir (Path): Directory to save results.
    """

    def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
        """
        Initializes the Exporter class.

        Args:
            cfg (str, optional): Path to a configuration file. Defaults to DEFAULT_CFG.
            overrides (dict, optional): Configuration overrides. Defaults to None.
            _callbacks (list, optional): List of callback functions. Defaults to None.
        """
        self.args = get_cfg(cfg, overrides)
        self.callbacks = _callbacks or callbacks.get_default_callbacks()
        callbacks.add_integration_callbacks(self)
    @smart_inference_mode()
    def __call__(self, model=None):
        """Returns list of exported files/dirs after running callbacks."""
        self.run_callbacks('on_export_start')
        t = time.time()
        format = self.args.format.lower()  # to lowercase
        if format in ('tensorrt', 'trt'):  # engine aliases
            format = 'engine'
        fmts = tuple(export_formats()['Argument'][1:])  # available export formats
        flags = [x == format for x in fmts]
        if sum(flags) != 1:
            raise ValueError(f"Invalid export format='{format}'. Valid formats are {fmts}")
        jit, onnx, xml, engine, coreml, saved_model, pb, tflite, edgetpu, tfjs, paddle, ncnn = flags  # export booleans

        # Load PyTorch model
        self.device = select_device('cpu' if self.args.device is None else self.args.device)

        # Checks
        model.names = check_class_names(model.names)
        if self.args.half and onnx and self.device.type == 'cpu':
            LOGGER.warning('WARNING ⚠️ half=True only compatible with GPU export, i.e. use device=0')
            self.args.half = False
            assert not self.args.dynamic, 'half=True not compatible with dynamic=True, i.e. use only one.'
        self.imgsz = check_imgsz(self.args.imgsz, stride=model.stride, min_dim=2)  # check image size
        if self.args.optimize:
            assert not ncnn, "optimize=True not compatible with format='ncnn', i.e. use optimize=False"
            assert self.device.type == 'cpu', "optimize=True not compatible with cuda devices, i.e. use device='cpu'"
        if edgetpu and not LINUX:
            raise SystemError('Edge TPU export only supported on Linux. See https://coral.ai/docs/edgetpu/compiler/')

        # Input
        im = torch.zeros(self.args.batch, 3, *self.imgsz).to(self.device)
        file = Path(
            getattr(model, 'pt_path', None) or getattr(model, 'yaml_file', None) or model.yaml.get('yaml_file', ''))
        if file.suffix == '.yaml':
            file = Path(file.name)

        # Update model
        model = deepcopy(model).to(self.device)
        for p in model.parameters():
            p.requires_grad = False
        model.eval()
        model.float()
        model = model.fuse()
        for k, m in model.named_modules():
            if isinstance(m, (Detect, RTDETRDecoder)):  # Segment and Pose use Detect base class
                m.dynamic = self.args.dynamic
                m.export = True
                m.format = self.args.format
            elif isinstance(m, C2f) and not any((saved_model, pb, tflite, edgetpu, tfjs)):
                # EdgeTPU does not support FlexSplitV while split provides cleaner ONNX graph
                m.forward = m.forward_split

        y = None
        for _ in range(2):
            y = model(im)  # dry runs
        if self.args.half and (engine or onnx) and self.device.type != 'cpu':
            im, model = im.half(), model.half()  # to FP16

        # Filter warnings
        warnings.filterwarnings('ignore', category=torch.jit.TracerWarning)  # suppress TracerWarning
        warnings.filterwarnings('ignore', category=UserWarning)  # suppress shape prim::Constant missing ONNX warning
        warnings.filterwarnings('ignore', category=DeprecationWarning)  # suppress CoreML np.bool deprecation warning

        # Assign
        self.im = im
        self.model = model
        self.file = file
        self.output_shape = tuple(y.shape) if isinstance(y, torch.Tensor) else \
            tuple(tuple(x.shape if isinstance(x, torch.Tensor) else []) for x in y)
        self.pretty_name = Path(self.model.yaml.get('yaml_file', self.file)).stem.replace('yolo', 'YOLO')
        trained_on = f'trained on {Path(self.args.data).name}' if self.args.data else '(untrained)'
        description = f'Ultralytics {self.pretty_name} model {trained_on}'
        self.metadata = {
            'description': description,
            'author': 'Ultralytics',
            'license': 'AGPL-3.0 https://ultralytics.com/license',
            'date': datetime.now().isoformat(),
            'version': __version__,
            'stride': int(max(model.stride)),
            'task': model.task,
            'batch': self.args.batch,
            'imgsz': self.imgsz,
            'names': model.names}  # model metadata
        if model.task == 'pose':
            self.metadata['kpt_shape'] = model.model[-1].kpt_shape

        LOGGER.info(f"\n{colorstr('PyTorch:')} starting from {file} with input shape {tuple(im.shape)} BCHW and "
                    f'output shape(s) {self.output_shape} ({file_size(file):.1f} MB)')

        # Exports
        f = [''] * len(fmts)  # exported filenames
        if jit or ncnn:  # TorchScript
            f[0], _ = self.export_torchscript()
        if engine:  # TensorRT required before ONNX
            f[1], _ = self.export_engine()
        if onnx or xml:  # OpenVINO requires ONNX
            f[2], _ = self.export_onnx()
        if xml:  # OpenVINO
            f[3], _ = self.export_openvino()
        if coreml:  # CoreML
            f[4], _ = self.export_coreml()
        if any((saved_model, pb, tflite, edgetpu, tfjs)):  # TensorFlow formats
            self.args.int8 |= edgetpu
            f[5], s_model = self.export_saved_model()
            if pb or tfjs:  # pb prerequisite to tfjs
                f[6], _ = self.export_pb(s_model)
            if tflite:
                f[7], _ = self.export_tflite(s_model, nms=False, agnostic_nms=self.args.agnostic_nms)
            if edgetpu:
                f[8], _ = self.export_edgetpu(tflite_model=Path(f[5]) / f'{self.file.stem}_full_integer_quant.tflite')
            if tfjs:
                f[9], _ = self.export_tfjs()
        if paddle:  # PaddlePaddle
            f[10], _ = self.export_paddle()
        if ncnn:  # NCNN
            f[11], _ = self.export_ncnn()

        # Finish
        f = [str(x) for x in f if x]  # filter out '' and None
        if any(f):
            f = str(Path(f[-1]))
            square = self.imgsz[0] == self.imgsz[1]
            s = '' if square else f"WARNING ⚠️ non-PyTorch val requires square images, 'imgsz={self.imgsz}' will not " \
                                  f"work. Use export 'imgsz={max(self.imgsz)}' if val is required."
            imgsz = self.imgsz[0] if square else str(self.imgsz)[1:-1].replace(' ', '')
            data = f'data={self.args.data}' if model.task == 'segment' and format == 'pb' else ''
            LOGGER.info(
                f'\nExport complete ({time.time() - t:.1f}s)'
                f"\nResults saved to {colorstr('bold', file.parent.resolve())}"
                f'\nPredict: yolo predict task={model.task} model={f} imgsz={imgsz} {data}'
                f'\nValidate: yolo val task={model.task} model={f} imgsz={imgsz} data={self.args.data} {s}'
                f'\nVisualize: https://netron.app')

        self.run_callbacks('on_export_end')
        return f  # return list of exported files/dirs
    @try_export
    def export_torchscript(self, prefix=colorstr('TorchScript:')):
        """YOLOv8 TorchScript model export."""
        LOGGER.info(f'\n{prefix} starting export with torch {torch.__version__}...')
        f = self.file.with_suffix('.torchscript')

        ts = torch.jit.trace(self.model, self.im, strict=False)
        extra_files = {'config.txt': json.dumps(self.metadata)}  # torch._C.ExtraFilesMap()
        if self.args.optimize:  # https://pytorch.org/tutorials/recipes/mobile_interpreter.html
            LOGGER.info(f'{prefix} optimizing for mobile...')
            from torch.utils.mobile_optimizer import optimize_for_mobile
            optimize_for_mobile(ts)._save_for_lite_interpreter(str(f), _extra_files=extra_files)
        else:
            ts.save(str(f), _extra_files=extra_files)
        return f, None
    @try_export
    def export_onnx(self, prefix=colorstr('ONNX:')):
        """YOLOv8 ONNX export."""
        requirements = ['onnx>=1.12.0']
        if self.args.simplify:
            requirements += ['onnxsim>=0.4.17', 'onnxruntime-gpu' if torch.cuda.is_available() else 'onnxruntime']
        check_requirements(requirements)
        import onnx  # noqa

        opset_version = self.args.opset or get_latest_opset()
        LOGGER.info(f'\n{prefix} starting export with onnx {onnx.__version__} opset {opset_version}...')
        f = str(self.file.with_suffix('.onnx'))

        output_names = ['output0', 'output1'] if isinstance(self.model, SegmentationModel) else ['output0']
        dynamic = self.args.dynamic
        if dynamic:
            dynamic = {'images': {0: 'batch', 2: 'height', 3: 'width'}}  # shape(1,3,640,640)
            if isinstance(self.model, SegmentationModel):
                dynamic['output0'] = {0: 'batch', 1: 'anchors'}  # shape(1,25200,85)
                dynamic['output1'] = {0: 'batch', 2: 'mask_height', 3: 'mask_width'}  # shape(1,32,160,160)
            elif isinstance(self.model, DetectionModel):
                dynamic['output0'] = {0: 'batch', 1: 'anchors'}  # shape(1,25200,85)

        torch.onnx.export(
            self.model.cpu() if dynamic else self.model,  # --dynamic only compatible with cpu
            self.im.cpu() if dynamic else self.im,
            f,
            verbose=False,
            opset_version=opset_version,
            do_constant_folding=True,  # WARNING: DNN inference with torch>=1.12 may require do_constant_folding=False
            input_names=['images'],
            output_names=output_names,
            dynamic_axes=dynamic or None)

        # Checks
        model_onnx = onnx.load(f)  # load onnx model
        # onnx.checker.check_model(model_onnx)  # check onnx model

        # Simplify
        if self.args.simplify:
            try:
                import onnxsim

                LOGGER.info(f'{prefix} simplifying with onnxsim {onnxsim.__version__}...')
                # subprocess.run(f'onnxsim {f} {f}', shell=True)
                model_onnx, check = onnxsim.simplify(model_onnx)
                assert check, 'Simplified ONNX model could not be validated'
            except Exception as e:
                LOGGER.info(f'{prefix} simplifier failure: {e}')

        # Metadata
        for k, v in self.metadata.items():
            meta = model_onnx.metadata_props.add()
            meta.key, meta.value = k, str(v)

        onnx.save(model_onnx, f)
        return f, model_onnx
    @try_export
    def export_openvino(self, prefix=colorstr('OpenVINO:')):
        """YOLOv8 OpenVINO export."""
        check_requirements('openvino-dev>=2023.0')  # requires openvino-dev: https://pypi.org/project/openvino-dev/
        import openvino.runtime as ov  # noqa
        from openvino.tools import mo  # noqa

        LOGGER.info(f'\n{prefix} starting export with openvino {ov.__version__}...')
        f = str(self.file).replace(self.file.suffix, f'_openvino_model{os.sep}')
        f_onnx = self.file.with_suffix('.onnx')
        f_ov = str(Path(f) / self.file.with_suffix('.xml').name)

        ov_model = mo.convert_model(f_onnx,
                                    model_name=self.pretty_name,
                                    framework='onnx',
                                    compress_to_fp16=self.args.half)  # export

        # Set RT info
        ov_model.set_rt_info('YOLOv8', ['model_info', 'model_type'])
        ov_model.set_rt_info(True, ['model_info', 'reverse_input_channels'])
        ov_model.set_rt_info(114, ['model_info', 'pad_value'])
        ov_model.set_rt_info([255.0], ['model_info', 'scale_values'])
        ov_model.set_rt_info(self.args.iou, ['model_info', 'iou_threshold'])
        ov_model.set_rt_info([v.replace(' ', '_') for k, v in sorted(self.model.names.items())],
                             ['model_info', 'labels'])
        if self.model.task != 'classify':
            ov_model.set_rt_info('fit_to_window_letterbox', ['model_info', 'resize_type'])

        ov.serialize(ov_model, f_ov)  # save
        yaml_save(Path(f) / 'metadata.yaml', self.metadata)  # add metadata.yaml
        return f, None
    @try_export
    def export_paddle(self, prefix=colorstr('PaddlePaddle:')):
        """YOLOv8 Paddle export."""
        check_requirements(('paddlepaddle', 'x2paddle'))
        import x2paddle  # noqa
        from x2paddle.convert import pytorch2paddle  # noqa

        LOGGER.info(f'\n{prefix} starting export with X2Paddle {x2paddle.__version__}...')
        f = str(self.file).replace(self.file.suffix, f'_paddle_model{os.sep}')

        pytorch2paddle(module=self.model, save_dir=f, jit_type='trace', input_examples=[self.im])  # export
        yaml_save(Path(f) / 'metadata.yaml', self.metadata)  # add metadata.yaml
        return f, None
    @try_export
    def export_ncnn(self, prefix=colorstr('NCNN:')):
        """YOLOv8 NCNN export using PNNX https://github.com/pnnx/pnnx."""
        check_requirements('git+https://github.com/Tencent/ncnn.git' if ARM64 else 'ncnn')  # requires NCNN
        import ncnn  # noqa

        LOGGER.info(f'\n{prefix} starting export with NCNN {ncnn.__version__}...')
        f = Path(str(self.file).replace(self.file.suffix, f'_ncnn_model{os.sep}'))
        f_ts = str(self.file.with_suffix('.torchscript'))

        if Path('./pnnx').is_file():
            pnnx = './pnnx'
        elif (ROOT / 'pnnx').is_file():
            pnnx = ROOT / 'pnnx'
        else:
            LOGGER.warning(
                f'{prefix} WARNING ⚠️ PNNX not found. Attempting to download binary file from '
                'https://github.com/pnnx/pnnx/.\nNote PNNX Binary file must be placed in current working directory '
                f'or in {ROOT}. See PNNX repo for full installation instructions.')
            _, assets = get_github_assets(repo='pnnx/pnnx')
            asset = [x for x in assets if ('macos' if MACOS else 'ubuntu' if LINUX else 'windows') in x][0]
            attempt_download_asset(asset, repo='pnnx/pnnx', release='latest')
            unzip_dir = Path(asset).with_suffix('')
            pnnx = ROOT / 'pnnx'  # new location
            (unzip_dir / 'pnnx').rename(pnnx)  # move binary to ROOT
            shutil.rmtree(unzip_dir)  # delete unzip dir
            Path(asset).unlink()  # delete zip
            pnnx.chmod(0o777)  # set read, write, and execute permissions for everyone

        cmd = [
            str(pnnx),
            f_ts,
            f'pnnxparam={f / "model.pnnx.param"}',
            f'pnnxbin={f / "model.pnnx.bin"}',
            f'pnnxpy={f / "model_pnnx.py"}',
            f'pnnxonnx={f / "model.pnnx.onnx"}',
            f'ncnnparam={f / "model.ncnn.param"}',
            f'ncnnbin={f / "model.ncnn.bin"}',
            f'ncnnpy={f / "model_ncnn.py"}',
            f'fp16={int(self.args.half)}',
            f'device={self.device.type}',
            f'inputshape="{[self.args.batch, 3, *self.imgsz]}"', ]
        f.mkdir(exist_ok=True)  # make ncnn_model directory
        LOGGER.info(f"{prefix} running '{' '.join(cmd)}'")
        subprocess.run(cmd, check=True)
        for f_debug in 'debug.bin', 'debug.param', 'debug2.bin', 'debug2.param':  # remove debug files
            Path(f_debug).unlink(missing_ok=True)

        yaml_save(f / 'metadata.yaml', self.metadata)  # add metadata.yaml
        return str(f), None
    @try_export
    def export_coreml(self, prefix=colorstr('CoreML:')):
        """YOLOv8 CoreML export."""
        check_requirements('coremltools>=6.0')
        import coremltools as ct  # noqa

        LOGGER.info(f'\n{prefix} starting export with coremltools {ct.__version__}...')
        f = self.file.with_suffix('.mlmodel')

        bias = [0.0, 0.0, 0.0]
        scale = 1 / 255
        classifier_config = None
        if self.model.task == 'classify':
            classifier_config = ct.ClassifierConfig(list(self.model.names.values())) if self.args.nms else None
            model = self.model
        elif self.model.task == 'detect':
            model = iOSDetectModel(self.model, self.im) if self.args.nms else self.model
        else:
            # TODO CoreML Segment and Pose model pipelining
            model = self.model

        ts = torch.jit.trace(model.eval(), self.im, strict=False)  # TorchScript model
        ct_model = ct.convert(ts,
                              inputs=[ct.ImageType('image', shape=self.im.shape, scale=scale, bias=bias)],
                              classifier_config=classifier_config)
        bits, mode = (8, 'kmeans_lut') if self.args.int8 else (16, 'linear') if self.args.half else (32, None)
        if bits < 32:
            if 'kmeans' in mode:
                check_requirements('scikit-learn')  # scikit-learn package required for k-means quantization
            ct_model = ct.models.neural_network.quantization_utils.quantize_weights(ct_model, bits, mode)
        if self.args.nms and self.model.task == 'detect':
            ct_model = self._pipeline_coreml(ct_model)

        m = self.metadata  # metadata dict
        ct_model.short_description = m.pop('description')
        ct_model.author = m.pop('author')
        ct_model.license = m.pop('license')
        ct_model.version = m.pop('version')
        ct_model.user_defined_metadata.update({k: str(v) for k, v in m.items()})
        ct_model.save(str(f))
        return f, ct_model
    @try_export
    def export_engine(self, prefix=colorstr('TensorRT:')):
        """YOLOv8 TensorRT export https://developer.nvidia.com/tensorrt."""
        assert self.im.device.type != 'cpu', "export running on CPU but must be on GPU, i.e. use 'device=0'"
        try:
            import tensorrt as trt  # noqa
        except ImportError:
            if LINUX:
                check_requirements('nvidia-tensorrt', cmds='-U --index-url https://pypi.ngc.nvidia.com')
            import tensorrt as trt  # noqa

        check_version(trt.__version__, '7.0.0', hard=True)  # require tensorrt>=7.0.0
        self.args.simplify = True
        f_onnx, _ = self.export_onnx()

        LOGGER.info(f'\n{prefix} starting export with TensorRT {trt.__version__}...')
        assert Path(f_onnx).exists(), f'failed to export ONNX file: {f_onnx}'
        f = self.file.with_suffix('.engine')  # TensorRT engine file
        logger = trt.Logger(trt.Logger.INFO)
        if self.args.verbose:
            logger.min_severity = trt.Logger.Severity.VERBOSE

        builder = trt.Builder(logger)
        config = builder.create_builder_config()
        config.max_workspace_size = self.args.workspace * 1 << 30
        # config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, workspace << 30)  # fix TRT 8.4 deprecation notice

        flag = (1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH))
        network = builder.create_network(flag)
        parser = trt.OnnxParser(network, logger)
        if not parser.parse_from_file(f_onnx):
            raise RuntimeError(f'failed to load ONNX file: {f_onnx}')

        inputs = [network.get_input(i) for i in range(network.num_inputs)]
        outputs = [network.get_output(i) for i in range(network.num_outputs)]
        for inp in inputs:
            LOGGER.info(f'{prefix} input "{inp.name}" with shape{inp.shape} {inp.dtype}')
        for out in outputs:
            LOGGER.info(f'{prefix} output "{out.name}" with shape{out.shape} {out.dtype}')

        if self.args.dynamic:
            shape = self.im.shape
            if shape[0] <= 1:
                LOGGER.warning(f'{prefix} WARNING ⚠️ --dynamic model requires maximum --batch-size argument')
            profile = builder.create_optimization_profile()
            for inp in inputs:
                profile.set_shape(inp.name, (1, *shape[1:]), (max(1, shape[0] // 2), *shape[1:]), shape)
            config.add_optimization_profile(profile)

        LOGGER.info(
            f'{prefix} building FP{16 if builder.platform_has_fast_fp16 and self.args.half else 32} engine as {f}')
        if builder.platform_has_fast_fp16 and self.args.half:
            config.set_flag(trt.BuilderFlag.FP16)

        # Write file
        with builder.build_engine(network, config) as engine, open(f, 'wb') as t:
            # Metadata
            meta = json.dumps(self.metadata)
            t.write(len(meta).to_bytes(4, byteorder='little', signed=True))
            t.write(meta.encode())
            # Model
            t.write(engine.serialize())

        return f, None
    @try_export
    def export_saved_model(self, prefix=colorstr('TensorFlow SavedModel:')):
        """YOLOv8 TensorFlow SavedModel export."""
        try:
            import tensorflow as tf  # noqa
        except ImportError:
            cuda = torch.cuda.is_available()
            check_requirements(f"tensorflow{'-macos' if MACOS else '-aarch64' if ARM64 else '' if cuda else '-cpu'}")
            import tensorflow as tf  # noqa
        check_requirements(('onnx', 'onnx2tf>=1.7.7', 'sng4onnx>=1.0.1', 'onnxsim>=0.4.17', 'onnx_graphsurgeon>=0.3.26',
                            'tflite_support', 'onnxruntime-gpu' if torch.cuda.is_available() else 'onnxruntime'),
                           cmds='--extra-index-url https://pypi.ngc.nvidia.com')

        LOGGER.info(f'\n{prefix} starting export with tensorflow {tf.__version__}...')
        f = Path(str(self.file).replace(self.file.suffix, '_saved_model'))
        if f.is_dir():
            import shutil
            shutil.rmtree(f)  # delete output folder

        # Export to ONNX
        self.args.simplify = True
        f_onnx, _ = self.export_onnx()

        # Export to TF
        int8 = '-oiqt -qt per-tensor' if self.args.int8 else ''
        cmd = f'onnx2tf -i {f_onnx} -o {f} -nuo --non_verbose {int8}'
        LOGGER.info(f"\n{prefix} running '{cmd.strip()}'")
        subprocess.run(cmd, shell=True)
        yaml_save(f / 'metadata.yaml', self.metadata)  # add metadata.yaml

        # Remove/rename TFLite models
        if self.args.int8:
            for file in f.rglob('*_dynamic_range_quant.tflite'):
                file.rename(file.with_name(file.stem.replace('_dynamic_range_quant', '_int8') + file.suffix))
            for file in f.rglob('*_integer_quant_with_int16_act.tflite'):
                file.unlink()  # delete extra fp16 activation TFLite files

        # Add TFLite metadata
        for file in f.rglob('*.tflite'):
            file.unlink() if 'quant_with_int16_act.tflite' in str(file) else self._add_tflite_metadata(file)

        # Load saved_model
        keras_model = tf.saved_model.load(f, tags=None, options=None)

        return str(f), keras_model
    @try_export
    def export_pb(self, keras_model, prefix=colorstr('TensorFlow GraphDef:')):
        """YOLOv8 TensorFlow GraphDef *.pb export https://github.com/leimao/Frozen_Graph_TensorFlow."""
        import tensorflow as tf  # noqa
        from tensorflow.python.framework.convert_to_constants import convert_variables_to_constants_v2  # noqa

        LOGGER.info(f'\n{prefix} starting export with tensorflow {tf.__version__}...')
        f = self.file.with_suffix('.pb')

        m = tf.function(lambda x: keras_model(x))  # full model
        m = m.get_concrete_function(tf.TensorSpec(keras_model.inputs[0].shape, keras_model.inputs[0].dtype))
        frozen_func = convert_variables_to_constants_v2(m)
        frozen_func.graph.as_graph_def()
        tf.io.write_graph(graph_or_graph_def=frozen_func.graph, logdir=str(f.parent), name=f.name, as_text=False)
        return f, None
    @try_export
    def export_tflite(self, keras_model, nms, agnostic_nms, prefix=colorstr('TensorFlow Lite:')):
        """YOLOv8 TensorFlow Lite export."""
        import tensorflow as tf  # noqa

        LOGGER.info(f'\n{prefix} starting export with tensorflow {tf.__version__}...')
        saved_model = Path(str(self.file).replace(self.file.suffix, '_saved_model'))
        if self.args.int8:
            f = saved_model / f'{self.file.stem}_int8.tflite'  # fp32 in/out
        elif self.args.half:
            f = saved_model / f'{self.file.stem}_float16.tflite'  # fp32 in/out
        else:
            f = saved_model / f'{self.file.stem}_float32.tflite'
        return str(f), None
    @try_export
    def export_edgetpu(self, tflite_model='', prefix=colorstr('Edge TPU:')):
        """YOLOv8 Edge TPU export https://coral.ai/docs/edgetpu/models-intro/."""
        LOGGER.warning(f'{prefix} WARNING ⚠️ Edge TPU known bug https://github.com/ultralytics/ultralytics/issues/1185')

        cmd = 'edgetpu_compiler --version'
        help_url = 'https://coral.ai/docs/edgetpu/compiler/'
        assert LINUX, f'export only supported on Linux. See {help_url}'
        if subprocess.run(cmd, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL, shell=True).returncode != 0:
            LOGGER.info(f'\n{prefix} export requires Edge TPU compiler. Attempting install from {help_url}')
            sudo = subprocess.run('sudo --version >/dev/null', shell=True).returncode == 0  # sudo installed on system
            for c in (
                    'curl https://packages.cloud.google.com/apt/doc/apt-key.gpg | sudo apt-key add -',
                    'echo "deb https://packages.cloud.google.com/apt coral-edgetpu-stable main" | sudo tee /etc/apt/sources.list.d/coral-edgetpu.list',
                    'sudo apt-get update', 'sudo apt-get install edgetpu-compiler'):
                subprocess.run(c if sudo else c.replace('sudo ', ''), shell=True, check=True)
        ver = subprocess.run(cmd, shell=True, capture_output=True, check=True).stdout.decode().split()[-1]

        LOGGER.info(f'\n{prefix} starting export with Edge TPU compiler {ver}...')
        f = str(tflite_model).replace('.tflite', '_edgetpu.tflite')  # Edge TPU model

        cmd = f'edgetpu_compiler -s -d -k 10 --out_dir {Path(f).parent} {tflite_model}'
        LOGGER.info(f"{prefix} running '{cmd}'")
        subprocess.run(cmd.split(), check=True)
        self._add_tflite_metadata(f)
        return f, None
    @try_export
    def export_tfjs(self, prefix=colorstr('TensorFlow.js:')):
        """YOLOv8 TensorFlow.js export."""
        check_requirements('tensorflowjs')
        import tensorflow as tf
        import tensorflowjs as tfjs  # noqa

        LOGGER.info(f'\n{prefix} starting export with tensorflowjs {tfjs.__version__}...')
        f = str(self.file).replace(self.file.suffix, '_web_model')  # js dir
        f_pb = self.file.with_suffix('.pb')  # *.pb path

        gd = tf.Graph().as_graph_def()  # TF GraphDef
        with open(f_pb, 'rb') as file:
            gd.ParseFromString(file.read())
        outputs = ','.join(gd_outputs(gd))
        LOGGER.info(f'\n{prefix} output node names: {outputs}')

        cmd = f'tensorflowjs_converter --input_format=tf_frozen_model --output_node_names={outputs} {f_pb} {f}'
        subprocess.run(cmd.split(), check=True)

        # f_json = Path(f) / 'model.json'  # *.json path
        # with open(f_json, 'w') as j:  # sort JSON Identity_* in ascending order
        #     subst = re.sub(
        #         r'{"outputs": {"Identity.?.?": {"name": "Identity.?.?"}, '
        #         r'"Identity.?.?": {"name": "Identity.?.?"}, '
        #         r'"Identity.?.?": {"name": "Identity.?.?"}, '
        #         r'"Identity.?.?": {"name": "Identity.?.?"}}}',
        #         r'{"outputs": {"Identity": {"name": "Identity"}, '
        #         r'"Identity_1": {"name": "Identity_1"}, '
        #         r'"Identity_2": {"name": "Identity_2"}, '
        #         r'"Identity_3": {"name": "Identity_3"}}}',
        #         f_json.read_text(),
        #     )
        #     j.write(subst)
        yaml_save(Path(f) / 'metadata.yaml', self.metadata)  # add metadata.yaml
        return f, None
    def _add_tflite_metadata(self, file):
        """Add metadata to *.tflite models per https://www.tensorflow.org/lite/models/convert/metadata."""
        from tflite_support import flatbuffers  # noqa
        from tflite_support import metadata as _metadata  # noqa
        from tflite_support import metadata_schema_py_generated as _metadata_fb  # noqa

        # Create model info
        model_meta = _metadata_fb.ModelMetadataT()
        model_meta.name = self.metadata['description']
        model_meta.version = self.metadata['version']
        model_meta.author = self.metadata['author']
        model_meta.license = self.metadata['license']

        # Label file
        tmp_file = Path(file).parent / 'temp_meta.txt'
        with open(tmp_file, 'w') as f:
            f.write(str(self.metadata))

        label_file = _metadata_fb.AssociatedFileT()
        label_file.name = tmp_file.name
        label_file.type = _metadata_fb.AssociatedFileType.TENSOR_AXIS_LABELS

        # Create input info
        input_meta = _metadata_fb.TensorMetadataT()
        input_meta.name = 'image'
        input_meta.description = 'Input image to be detected.'
        input_meta.content = _metadata_fb.ContentT()
        input_meta.content.contentProperties = _metadata_fb.ImagePropertiesT()
        input_meta.content.contentProperties.colorSpace = _metadata_fb.ColorSpaceType.RGB
        input_meta.content.contentPropertiesType = _metadata_fb.ContentProperties.ImageProperties

        # Create output info
        output1 = _metadata_fb.TensorMetadataT()
        output1.name = 'output'
        output1.description = 'Coordinates of detected objects, class labels, and confidence score'
        output1.associatedFiles = [label_file]
        if self.model.task == 'segment':
            output2 = _metadata_fb.TensorMetadataT()
            output2.name = 'output'
            output2.description = 'Mask protos'
            output2.associatedFiles = [label_file]

        # Create subgraph info
        subgraph = _metadata_fb.SubGraphMetadataT()
        subgraph.inputTensorMetadata = [input_meta]
        subgraph.outputTensorMetadata = [output1, output2] if self.model.task == 'segment' else [output1]
        model_meta.subgraphMetadata = [subgraph]

        b = flatbuffers.Builder(0)
        b.Finish(model_meta.Pack(b), _metadata.MetadataPopulator.METADATA_FILE_IDENTIFIER)
        metadata_buf = b.Output()

        populator = _metadata.MetadataPopulator.with_model_file(str(file))
        populator.load_metadata_buffer(metadata_buf)
        populator.load_associated_files([str(tmp_file)])
        populator.populate()
        tmp_file.unlink()
    def _pipeline_coreml(self, model, prefix=colorstr('CoreML Pipeline:')):
        """YOLOv8 CoreML pipeline."""
        import coremltools as ct  # noqa

        LOGGER.info(f'{prefix} starting pipeline with coremltools {ct.__version__}...')
        batch_size, ch, h, w = list(self.im.shape)  # BCHW

        # Output shapes
        spec = model.get_spec()
        out0, out1 = iter(spec.description.output)
        if MACOS:
            from PIL import Image
            img = Image.new('RGB', (w, h))  # img(192 width, 320 height)
            # img = torch.zeros((*opt.img_size, 3)).numpy()  # img size(320,192,3) iDetection
            out = model.predict({'image': img})
            out0_shape = out[out0.name].shape
            out1_shape = out[out1.name].shape
        else:  # linux and windows can not run model.predict(), get sizes from pytorch output y
            out0_shape = self.output_shape[2], self.output_shape[1] - 4  # (3780, 80)
            out1_shape = self.output_shape[2], 4  # (3780, 4)

        # Checks
        names = self.metadata['names']
        nx, ny = spec.description.input[0].type.imageType.width, spec.description.input[0].type.imageType.height
        na, nc = out0_shape
        # na, nc = out0.type.multiArrayType.shape  # number anchors, classes
        assert len(names) == nc, f'{len(names)} names found for nc={nc}'  # check

        # Define output shapes (missing)
        out0.type.multiArrayType.shape[:] = out0_shape  # (3780, 80)
        out1.type.multiArrayType.shape[:] = out1_shape  # (3780, 4)
        # spec.neuralNetwork.preprocessing[0].featureName = '0'

        # Flexible input shapes
        # from coremltools.models.neural_network import flexible_shape_utils
        # s = []  # shapes
        # s.append(flexible_shape_utils.NeuralNetworkImageSize(320, 192))
        # s.append(flexible_shape_utils.NeuralNetworkImageSize(640, 384))  # (height, width)
        # flexible_shape_utils.add_enumerated_image_sizes(spec, feature_name='image', sizes=s)
        # r = flexible_shape_utils.NeuralNetworkImageSizeRange()  # shape ranges
        # r.add_height_range((192, 640))
        # r.add_width_range((192, 640))
        # flexible_shape_utils.update_image_size_range(spec, feature_name='image', size_range=r)

        # Print
        # print(spec.description)

        # Model from spec
        model = ct.models.MLModel(spec)

        # 3. Create NMS protobuf
        nms_spec = ct.proto.Model_pb2.Model()
        nms_spec.specificationVersion = 5
        for i in range(2):
            decoder_output = model._spec.description.output[i].SerializeToString()
            nms_spec.description.input.add()
            nms_spec.description.input[i].ParseFromString(decoder_output)
            nms_spec.description.output.add()
            nms_spec.description.output[i].ParseFromString(decoder_output)

        nms_spec.description.output[0].name = 'confidence'
        nms_spec.description.output[1].name = 'coordinates'

        output_sizes = [nc, 4]
        for i in range(2):
            ma_type = nms_spec.description.output[i].type.multiArrayType
            ma_type.shapeRange.sizeRanges.add()
            ma_type.shapeRange.sizeRanges[0].lowerBound = 0
            ma_type.shapeRange.sizeRanges[0].upperBound = -1
            ma_type.shapeRange.sizeRanges.add()
            ma_type.shapeRange.sizeRanges[1].lowerBound = output_sizes[i]
            ma_type.shapeRange.sizeRanges[1].upperBound = output_sizes[i]
            del ma_type.shape[:]

        nms = nms_spec.nonMaximumSuppression
        nms.confidenceInputFeatureName = out0.name  # 1x507x80
        nms.coordinatesInputFeatureName = out1.name  # 1x507x4
        nms.confidenceOutputFeatureName = 'confidence'
        nms.coordinatesOutputFeatureName = 'coordinates'
        nms.iouThresholdInputFeatureName = 'iouThreshold'
        nms.confidenceThresholdInputFeatureName = 'confidenceThreshold'
        nms.iouThreshold = 0.45
        nms.confidenceThreshold = 0.25
        nms.pickTop.perClass = True
        nms.stringClassLabels.vector.extend(names.values())
        nms_model = ct.models.MLModel(nms_spec)

        # 4. Pipeline models together
        pipeline = ct.models.pipeline.Pipeline(input_features=[('image', ct.models.datatypes.Array(3, ny, nx)),
                                                               ('iouThreshold', ct.models.datatypes.Double()),
                                                               ('confidenceThreshold', ct.models.datatypes.Double())],
                                               output_features=['confidence', 'coordinates'])
        pipeline.add_model(model)
        pipeline.add_model(nms_model)

        # Correct datatypes
        pipeline.spec.description.input[0].ParseFromString(model._spec.description.input[0].SerializeToString())
        pipeline.spec.description.output[0].ParseFromString(nms_model._spec.description.output[0].SerializeToString())
        pipeline.spec.description.output[1].ParseFromString(nms_model._spec.description.output[1].SerializeToString())

        # Update metadata
        pipeline.spec.specificationVersion = 5
        pipeline.spec.description.metadata.userDefined.update({
            'IoU threshold': str(nms.iouThreshold),
            'Confidence threshold': str(nms.confidenceThreshold)})

        # Save the model
        model = ct.models.MLModel(pipeline.spec)
        model.input_description['image'] = 'Input image'
        model.input_description['iouThreshold'] = f'(optional) IOU threshold override (default: {nms.iouThreshold})'
        model.input_description['confidenceThreshold'] = \
            f'(optional) Confidence threshold override (default: {nms.confidenceThreshold})'
        model.output_description['confidence'] = 'Boxes × Class confidence (see user-defined metadata "classes")'
        model.output_description['coordinates'] = 'Boxes × [x, y, width, height] (relative to image size)'
        LOGGER.info(f'{prefix} pipeline success')
        return model
    def add_callback(self, event: str, callback):
        """Appends the given callback."""
        self.callbacks[event].append(callback)

    def run_callbacks(self, event: str):
        """Execute all callbacks for a given event."""
        for callback in self.callbacks.get(event, []):
            callback(self)

class iOSDetectModel(torch.nn.Module):
    """Wrap an Ultralytics YOLO model for iOS export."""

    def __init__(self, model, im):
        """Initialize the iOSDetectModel class with a YOLO model and example image."""
        super().__init__()
        b, c, h, w = im.shape  # batch, channel, height, width
        self.model = model
        self.nc = len(model.names)  # number of classes
        if w == h:
            self.normalize = 1.0 / w  # scalar
        else:
            self.normalize = torch.tensor([1.0 / w, 1.0 / h, 1.0 / w, 1.0 / h])  # broadcast (slower, smaller)

    def forward(self, x):
        """Normalize predictions of object detection model with input size-dependent factors."""
        xywh, cls = self.model(x)[0].transpose(0, 1).split((4, self.nc), 1)
        return cls, xywh * self.normalize  # confidence (3780, 80), coordinates (3780, 4)

def export(cfg=DEFAULT_CFG):
    """Export a YOLOv8 model to a specific format."""
    cfg.model = cfg.model or 'yolov8n.yaml'
    cfg.format = cfg.format or 'torchscript'

    from ultralytics import YOLO
    model = YOLO(cfg.model)
    model.export(**vars(cfg))

if __name__ == '__main__':
    """
    CLI:
    yolo mode=export model=yolov8n.yaml format=onnx
    """
    export()
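Note on the deleted export_engine above: the .engine file it writes is a serialized TensorRT engine prefixed with a self-describing metadata header (a signed little-endian 4-byte length, then the JSON metadata dict). A minimal reader sketch mirroring that write side (the function name is illustrative, not part of this commit):

    import json

    def read_engine_metadata(path):
        """Read the header written by export_engine: 4-byte little-endian length,
        JSON metadata, then the serialized TensorRT engine bytes."""
        with open(path, 'rb') as f:
            meta_len = int.from_bytes(f.read(4), byteorder='little', signed=True)
            metadata = json.loads(f.read(meta_len).decode())
            engine_bytes = f.read()  # remainder is the serialized engine
        return metadata, engine_bytes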
@@ -1,436 +0,0 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license

import sys
from pathlib import Path
from typing import Union

from ultralytics import yolo  # noqa
from ultralytics.nn.tasks import (ClassificationModel, DetectionModel, PoseModel, SegmentationModel,
                                  attempt_load_one_weight, guess_model_task, nn, yaml_model_load)
from ultralytics.yolo.cfg import get_cfg
from ultralytics.yolo.engine.exporter import Exporter
from ultralytics.yolo.utils import (DEFAULT_CFG, DEFAULT_CFG_DICT, DEFAULT_CFG_KEYS, LOGGER, RANK, ROOT, callbacks,
                                    is_git_dir, yaml_load)
from ultralytics.yolo.utils.checks import check_file, check_imgsz, check_pip_update_available, check_yaml
from ultralytics.yolo.utils.downloads import GITHUB_ASSET_STEMS
from ultralytics.yolo.utils.torch_utils import smart_inference_mode

# Map head to model, trainer, validator, and predictor classes
TASK_MAP = {
    'classify': [
        ClassificationModel, yolo.v8.classify.ClassificationTrainer, yolo.v8.classify.ClassificationValidator,
        yolo.v8.classify.ClassificationPredictor],
    'detect': [
        DetectionModel, yolo.v8.detect.DetectionTrainer, yolo.v8.detect.DetectionValidator,
        yolo.v8.detect.DetectionPredictor],
    'segment': [
        SegmentationModel, yolo.v8.segment.SegmentationTrainer, yolo.v8.segment.SegmentationValidator,
        yolo.v8.segment.SegmentationPredictor],
    'pose': [PoseModel, yolo.v8.pose.PoseTrainer, yolo.v8.pose.PoseValidator, yolo.v8.pose.PosePredictor]}

class YOLO:
    """
    YOLO (You Only Look Once) object detection model.

    Args:
        model (str, Path): Path to the model file to load or create.
        task (Any, optional): Task type for the YOLO model. Defaults to None.

    Attributes:
        predictor (Any): The predictor object.
        model (Any): The model object.
        trainer (Any): The trainer object.
        task (str): The type of model task.
        ckpt (Any): The checkpoint object if the model loaded from *.pt file.
        cfg (str): The model configuration if loaded from *.yaml file.
        ckpt_path (str): The checkpoint file path.
        overrides (dict): Overrides for the trainer object.
        metrics (Any): The data for metrics.

    Methods:
        __call__(source=None, stream=False, **kwargs):
            Alias for the predict method.
        _new(cfg:str, verbose:bool=True) -> None:
            Initializes a new model and infers the task type from the model definitions.
        _load(weights:str, task:str='') -> None:
            Initializes a new model and infers the task type from the model head.
        _check_is_pytorch_model() -> None:
            Raises TypeError if the model is not a PyTorch model.
        reset() -> None:
            Resets the model modules.
        info(verbose:bool=False) -> None:
            Logs the model info.
        fuse() -> None:
            Fuses the model for faster inference.
        predict(source=None, stream=False, **kwargs) -> List[ultralytics.yolo.engine.results.Results]:
            Performs prediction using the YOLO model.

    Returns:
        list(ultralytics.yolo.engine.results.Results): The prediction results.
    """

    def __init__(self, model: Union[str, Path] = 'yolov8n.pt', task=None) -> None:
        """
        Initializes the YOLO model.

        Args:
            model (Union[str, Path], optional): Path or name of the model to load or create. Defaults to 'yolov8n.pt'.
            task (Any, optional): Task type for the YOLO model. Defaults to None.
        """
        self.callbacks = callbacks.get_default_callbacks()
        self.predictor = None  # reuse predictor
        self.model = None  # model object
        self.trainer = None  # trainer object
        self.task = None  # task type
        self.ckpt = None  # if loaded from *.pt
        self.cfg = None  # if loaded from *.yaml
        self.ckpt_path = None
        self.overrides = {}  # overrides for trainer object
        self.metrics = None  # validation/training metrics
        self.session = None  # HUB session
        model = str(model).strip()  # strip spaces

        # Check if Ultralytics HUB model from https://hub.ultralytics.com
        if self.is_hub_model(model):
            from ultralytics.hub.session import HUBTrainingSession
            self.session = HUBTrainingSession(model)
            model = self.session.model_file

        # Load or create new YOLO model
        suffix = Path(model).suffix
        if not suffix and Path(model).stem in GITHUB_ASSET_STEMS:
            model, suffix = Path(model).with_suffix('.pt'), '.pt'  # add suffix, i.e. yolov8n -> yolov8n.pt
        if suffix == '.yaml':
            self._new(model, task)
        else:
            self._load(model, task)
    def __call__(self, source=None, stream=False, **kwargs):
        """Calls the 'predict' function with given arguments to perform object detection."""
        return self.predict(source, stream, **kwargs)

    def __getattr__(self, attr):
        """Raises error if object has no requested attribute."""
        name = self.__class__.__name__
        raise AttributeError(f"'{name}' object has no attribute '{attr}'. See valid attributes below.\n{self.__doc__}")

    @staticmethod
    def is_hub_model(model):
        """Check if the provided model is a HUB model."""
        return any((
            model.startswith('https://hub.ultralytics.com/models/'),  # i.e. https://hub.ultralytics.com/models/MODEL_ID
            [len(x) for x in model.split('_')] == [42, 20],  # APIKEY_MODELID
            len(model) == 20 and not Path(model).exists() and all(x not in model for x in './\\')))  # MODELID
    def _new(self, cfg: str, task=None, verbose=True):
        """
        Initializes a new model and infers the task type from the model definitions.

        Args:
            cfg (str): model configuration file
            task (str | None): model task
            verbose (bool): display model info on load
        """
        cfg_dict = yaml_model_load(cfg)
        self.cfg = cfg
        self.task = task or guess_model_task(cfg_dict)
        self.model = TASK_MAP[self.task][0](cfg_dict, verbose=verbose and RANK == -1)  # build model
        self.overrides['model'] = self.cfg

        # Below added to allow export from yamls
        args = {**DEFAULT_CFG_DICT, **self.overrides}  # combine model and default args, preferring model args
        self.model.args = {k: v for k, v in args.items() if k in DEFAULT_CFG_KEYS}  # attach args to model
        self.model.task = self.task
    def _load(self, weights: str, task=None):
        """
        Initializes a new model and infers the task type from the model head.

        Args:
            weights (str): model checkpoint to be loaded
            task (str | None): model task
        """
        suffix = Path(weights).suffix
        if suffix == '.pt':
            self.model, self.ckpt = attempt_load_one_weight(weights)
            self.task = self.model.args['task']
            self.overrides = self.model.args = self._reset_ckpt_args(self.model.args)
            self.ckpt_path = self.model.pt_path
        else:
            weights = check_file(weights)
            self.model, self.ckpt = weights, None
            self.task = task or guess_model_task(weights)
            self.ckpt_path = weights
        self.overrides['model'] = weights
        self.overrides['task'] = self.task
    def _check_is_pytorch_model(self):
        """Raises TypeError if the model is not a PyTorch model."""
        pt_str = isinstance(self.model, (str, Path)) and Path(self.model).suffix == '.pt'
        pt_module = isinstance(self.model, nn.Module)
        if not (pt_module or pt_str):
            raise TypeError(f"model='{self.model}' must be a *.pt PyTorch model, but is a different type. "
                            f'PyTorch models can be used to train, val, predict and export, i.e. '
                            f"'yolo export model=yolov8n.pt', but exported formats like ONNX, TensorRT etc. only "
                            f"support 'predict' and 'val' modes, i.e. 'yolo predict model=yolov8n.onnx'.")
    @smart_inference_mode()
    def reset_weights(self):
        """Resets the model modules parameters to randomly initialized values, losing all training information."""
        self._check_is_pytorch_model()
        for m in self.model.modules():
            if hasattr(m, 'reset_parameters'):
                m.reset_parameters()
        for p in self.model.parameters():
            p.requires_grad = True
        return self
    @smart_inference_mode()
    def load(self, weights='yolov8n.pt'):
        """Transfers parameters with matching names and shapes from 'weights' to model."""
        self._check_is_pytorch_model()
        if isinstance(weights, (str, Path)):
            weights, self.ckpt = attempt_load_one_weight(weights)
        self.model.load(weights)
        return self
    def info(self, detailed=False, verbose=True):
        """
        Logs model info.

        Args:
            detailed (bool): Show detailed information about model.
            verbose (bool): Controls verbosity.
        """
        self._check_is_pytorch_model()
        return self.model.info(detailed=detailed, verbose=verbose)

    def fuse(self):
        """Fuse PyTorch Conv2d and BatchNorm2d layers."""
        self._check_is_pytorch_model()
        self.model.fuse()
    @smart_inference_mode()
    def predict(self, source=None, stream=False, **kwargs):
        """
        Perform prediction using the YOLO model.

        Args:
            source (str | int | PIL | np.ndarray): The source of the image to make predictions on.
                Accepts all source types accepted by the YOLO model.
            stream (bool): Whether to stream the predictions or not. Defaults to False.
            **kwargs : Additional keyword arguments passed to the predictor.
                Check the 'configuration' section in the documentation for all available options.

        Returns:
            (List[ultralytics.yolo.engine.results.Results]): The prediction results.
        """
        if source is None:
            source = ROOT / 'assets' if is_git_dir() else 'https://ultralytics.com/images/bus.jpg'
            LOGGER.warning(f"WARNING ⚠️ 'source' is missing. Using 'source={source}'.")
        is_cli = (sys.argv[0].endswith('yolo') or sys.argv[0].endswith('ultralytics')) and any(
            x in sys.argv for x in ('predict', 'track', 'mode=predict', 'mode=track'))
        overrides = self.overrides.copy()
        overrides['conf'] = 0.25
        overrides.update(kwargs)  # prefer kwargs
        overrides['mode'] = kwargs.get('mode', 'predict')
        assert overrides['mode'] in ['track', 'predict']
        if not is_cli:
            overrides['save'] = kwargs.get('save', False)  # do not save by default if called in Python
        if not self.predictor:
            self.task = overrides.get('task') or self.task
            self.predictor = TASK_MAP[self.task][3](overrides=overrides, _callbacks=self.callbacks)
            self.predictor.setup_model(model=self.model, verbose=is_cli)
        else:  # only update args if predictor is already setup
            self.predictor.args = get_cfg(self.predictor.args, overrides)
            if 'project' in overrides or 'name' in overrides:
                self.predictor.save_dir = self.predictor.get_save_dir()
        return self.predictor.predict_cli(source=source) if is_cli else self.predictor(source=source, stream=stream)
    def track(self, source=None, stream=False, persist=False, **kwargs):
        """
        Perform object tracking on the input source using the registered trackers.

        Args:
            source (str, optional): The input source for object tracking. Can be a file path or a video stream.
            stream (bool, optional): Whether the input source is a video stream. Defaults to False.
            persist (bool, optional): Whether to persist the trackers if they already exist. Defaults to False.
            **kwargs (optional): Additional keyword arguments for the tracking process.

        Returns:
            (List[ultralytics.yolo.engine.results.Results]): The tracking results.
        """
        if not hasattr(self.predictor, 'trackers'):
            from ultralytics.tracker import register_tracker
            register_tracker(self, persist)
        # ByteTrack-based method needs low confidence predictions as input
        conf = kwargs.get('conf') or 0.1
        kwargs['conf'] = conf
        kwargs['mode'] = 'track'
        return self.predict(source=source, stream=stream, **kwargs)
    @smart_inference_mode()
    def val(self, data=None, **kwargs):
        """
        Validate a model on a given dataset.

        Args:
            data (str): The dataset to validate on. Accepts all formats accepted by yolo.
            **kwargs : Any other args accepted by the validators. To see all args check 'configuration' section in docs.
        """
        overrides = self.overrides.copy()
        overrides['rect'] = True  # rect batches as default
        overrides.update(kwargs)
        overrides['mode'] = 'val'
        args = get_cfg(cfg=DEFAULT_CFG, overrides=overrides)
        args.data = data or args.data
        if 'task' in overrides:
            self.task = args.task
        else:
            args.task = self.task
        if args.imgsz == DEFAULT_CFG.imgsz and not isinstance(self.model, (str, Path)):
            args.imgsz = self.model.args['imgsz']  # use trained imgsz unless custom value is passed
        args.imgsz = check_imgsz(args.imgsz, max_dim=1)

        validator = TASK_MAP[self.task][2](args=args, _callbacks=self.callbacks)
        validator(model=self.model)
        self.metrics = validator.metrics

        return validator.metrics
    @smart_inference_mode()
    def benchmark(self, **kwargs):
        """
        Benchmark a model on all export formats.

        Args:
            **kwargs (Any): Any other args accepted by the validators. See the 'Configuration' section in the docs for all available args.
        """
        self._check_is_pytorch_model()
        from ultralytics.yolo.utils.benchmarks import benchmark
        overrides = self.model.args.copy()
        overrides.update(kwargs)
        overrides['mode'] = 'benchmark'
        overrides = {**DEFAULT_CFG_DICT, **overrides}  # fill in missing overrides keys with defaults
        return benchmark(model=self, imgsz=overrides['imgsz'], half=overrides['half'], device=overrides['device'])

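    # Usage sketch for benchmark() (hypothetical values; only imgsz/half/device are forwarded, per the call above):
    #   model.benchmark(imgsz=640, half=False, device='cpu')
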
    def export(self, **kwargs):
        """
        Export the model to another format.

        Args:
            **kwargs (Any): Any other args accepted by the Exporter. See the 'Configuration' section in the docs for all available args.
        """
        self._check_is_pytorch_model()
        overrides = self.overrides.copy()
        overrides.update(kwargs)
        overrides['mode'] = 'export'
        if overrides.get('imgsz') is None:
            overrides['imgsz'] = self.model.args['imgsz']  # use trained imgsz unless custom value is passed
        if 'batch' not in kwargs:
            overrides['batch'] = 1  # default to 1 if not modified
        args = get_cfg(cfg=DEFAULT_CFG, overrides=overrides)
        args.task = self.task
        return Exporter(overrides=args, _callbacks=self.callbacks)(model=self.model)

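    # Usage sketch for export() (hypothetical; imgsz and batch fall back to the defaults set above when omitted):
    #   model.export(format='onnx', imgsz=640)  # writes an ONNX file next to the source weights
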
    def train(self, **kwargs):
        """
        Trains the model on a given dataset.

        Args:
            **kwargs (Any): Any number of arguments representing the training configuration.
        """
        self._check_is_pytorch_model()
        if self.session:  # Ultralytics HUB session
            if any(kwargs):
                LOGGER.warning('WARNING ⚠️ using HUB training arguments, ignoring local training arguments.')
            kwargs = self.session.train_args
        check_pip_update_available()
        overrides = self.overrides.copy()
        if kwargs.get('cfg'):
            LOGGER.info(f"cfg file passed. Overriding default params with {kwargs['cfg']}.")
            overrides = yaml_load(check_yaml(kwargs['cfg']))
        overrides.update(kwargs)
        overrides['mode'] = 'train'
        if not overrides.get('data'):
            raise AttributeError("Dataset required but missing, i.e. pass 'data=coco128.yaml'")
        if overrides.get('resume'):
            overrides['resume'] = self.ckpt_path
        self.task = overrides.get('task') or self.task
        self.trainer = TASK_MAP[self.task][1](overrides=overrides, _callbacks=self.callbacks)
        if not overrides.get('resume'):  # manually set model only if not resuming
            self.trainer.model = self.trainer.get_model(weights=self.model if self.ckpt else None, cfg=self.model.yaml)
            self.model = self.trainer.model
        self.trainer.hub_session = self.session  # attach optional HUB session
        self.trainer.train()
        # Update model and cfg after training
        if RANK in (-1, 0):
            self.model, _ = attempt_load_one_weight(str(self.trainer.best))
            self.overrides = self.model.args
            self.metrics = getattr(self.trainer.validator, 'metrics', None)  # TODO: no metrics returned by DDP

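    # Usage sketch for train() (assumes a local 'coco128.yaml'; resume=True restarts from self.ckpt_path):
    #   model = YOLO('yolov8n.pt')
    #   model.train(data='coco128.yaml', epochs=3, imgsz=640)
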
    def to(self, device):
        """
        Sends the model to the given device.

        Args:
            device (str): The device to move the model to, e.g. 'cpu' or 'cuda:0'.
        """
        self._check_is_pytorch_model()
        self.model.to(device)

    def tune(self, *args, **kwargs):
        """
        Runs hyperparameter tuning using Ray Tune. See ultralytics.yolo.utils.tuner.run_ray_tune for Args.

        Returns:
            (dict): A dictionary containing the results of the hyperparameter search.

        Raises:
            ModuleNotFoundError: If Ray Tune is not installed.
        """
        self._check_is_pytorch_model()
        from ultralytics.yolo.utils.tuner import run_ray_tune
        return run_ray_tune(self, *args, **kwargs)

    @property
    def names(self):
        """Returns class names of the loaded model."""
        return self.model.names if hasattr(self.model, 'names') else None

    @property
    def device(self):
        """Returns device if PyTorch model."""
        return next(self.model.parameters()).device if isinstance(self.model, nn.Module) else None

    @property
    def transforms(self):
        """Returns transform of the loaded model."""
        return self.model.transforms if hasattr(self.model, 'transforms') else None

    def add_callback(self, event: str, func):
        """Add a callback."""
        self.callbacks[event].append(func)

    def clear_callback(self, event: str):
        """Clear all event callbacks."""
        self.callbacks[event] = []

    @staticmethod
    def _reset_ckpt_args(args):
        """Reset arguments when loading a PyTorch model."""
        include = {'imgsz', 'data', 'task', 'single_cls'}  # only remember these arguments when loading a PyTorch model
        return {k: v for k, v in args.items() if k in include}

    def _reset_callbacks(self):
        """Reset all registered callbacks."""
        for event in callbacks.default_callbacks.keys():
            self.callbacks[event] = [callbacks.default_callbacks[event][0]]

@@ -1,357 +0,0 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
"""
Run prediction on images, videos, directories, globs, YouTube, webcam, streams, etc.

Usage - sources:
    $ yolo mode=predict model=yolov8n.pt source=0                               # webcam
                                                img.jpg                         # image
                                                vid.mp4                         # video
                                                screen                          # screenshot
                                                path/                           # directory
                                                list.txt                        # list of images
                                                list.streams                    # list of streams
                                                'path/*.jpg'                    # glob
                                                'https://youtu.be/Zgi9g1ksQHc'  # YouTube
                                                'rtsp://example.com/media.mp4'  # RTSP, RTMP, HTTP stream

Usage - formats:
    $ yolo mode=predict model=yolov8n.pt          # PyTorch
                              yolov8n.torchscript # TorchScript
                              yolov8n.onnx        # ONNX Runtime or OpenCV DNN with dnn=True
                              yolov8n_openvino_model  # OpenVINO
                              yolov8n.engine          # TensorRT
                              yolov8n.mlmodel         # CoreML (macOS-only)
                              yolov8n_saved_model     # TensorFlow SavedModel
                              yolov8n.pb              # TensorFlow GraphDef
                              yolov8n.tflite          # TensorFlow Lite
                              yolov8n_edgetpu.tflite  # TensorFlow Edge TPU
                              yolov8n_paddle_model    # PaddlePaddle
"""
import platform
from pathlib import Path

import cv2
import numpy as np
import torch

from ultralytics.nn.autobackend import AutoBackend
from ultralytics.yolo.cfg import get_cfg
from ultralytics.yolo.data import load_inference_source
from ultralytics.yolo.data.augment import LetterBox, classify_transforms
from ultralytics.yolo.utils import DEFAULT_CFG, LOGGER, MACOS, SETTINGS, WINDOWS, callbacks, colorstr, ops
from ultralytics.yolo.utils.checks import check_imgsz, check_imshow
from ultralytics.yolo.utils.files import increment_path
from ultralytics.yolo.utils.torch_utils import select_device, smart_inference_mode

STREAM_WARNING = """
WARNING ⚠️ stream/video/webcam/dir predict source will accumulate results in RAM unless `stream=True` is passed,
causing potential out-of-memory errors for large sources or long-running streams/videos.

Usage:
    results = model(source=..., stream=True)  # generator of Results objects
    for r in results:
        boxes = r.boxes  # Boxes object for bbox outputs
        masks = r.masks  # Masks object for segment masks outputs
        probs = r.probs  # Class probabilities for classification outputs
"""


class BasePredictor:
    """
    BasePredictor

    A base class for creating predictors.

    Attributes:
        args (SimpleNamespace): Configuration for the predictor.
        save_dir (Path): Directory to save results.
        done_warmup (bool): Whether the predictor has finished setup.
        model (nn.Module): Model used for prediction.
        data (dict): Data configuration.
        device (torch.device): Device used for prediction.
        dataset (Dataset): Dataset used for prediction.
        vid_path (str): Path to video file.
        vid_writer (cv2.VideoWriter): Video writer for saving video output.
        data_path (str): Path to data.
    """

    def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
        """
        Initializes the BasePredictor class.

        Args:
            cfg (str, optional): Path to a configuration file. Defaults to DEFAULT_CFG.
            overrides (dict, optional): Configuration overrides. Defaults to None.
        """
        self.args = get_cfg(cfg, overrides)
        self.save_dir = self.get_save_dir()
        if self.args.conf is None:
            self.args.conf = 0.25  # default conf=0.25
        self.done_warmup = False
        if self.args.show:
            self.args.show = check_imshow(warn=True)

        # Usable if setup is done
        self.model = None
        self.data = self.args.data  # data_dict
        self.imgsz = None
        self.device = None
        self.dataset = None
        self.vid_path, self.vid_writer = None, None
        self.plotted_img = None
        self.data_path = None
        self.source_type = None
        self.batch = None
        self.results = None
        self.transforms = None
        self.callbacks = _callbacks or callbacks.get_default_callbacks()
        callbacks.add_integration_callbacks(self)

    def get_save_dir(self):
        """Return the save directory (runs_dir/task/name by default, incremented if it already exists)."""
        project = self.args.project or Path(SETTINGS['runs_dir']) / self.args.task
        name = self.args.name or f'{self.args.mode}'
        return increment_path(Path(project) / name, exist_ok=self.args.exist_ok)

    def preprocess(self, im):
        """
        Prepares input image before inference.

        Args:
            im (torch.Tensor | List(np.ndarray)): BCHW for tensor, [(HWC) x B] for list.
        """
        not_tensor = not isinstance(im, torch.Tensor)
        if not_tensor:
            im = np.stack(self.pre_transform(im))
            im = im[..., ::-1].transpose((0, 3, 1, 2))  # BGR to RGB, BHWC to BCHW, (n, 3, h, w)
            im = np.ascontiguousarray(im)  # contiguous
            im = torch.from_numpy(im)

        img = im.to(self.device)
        img = img.half() if self.model.fp16 else img.float()  # uint8 to fp16/32
        if not_tensor:
            img /= 255  # 0 - 255 to 0.0 - 1.0
        return img

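    # Shape sketch for preprocess() (hypothetical 2-image batch of 480x640 BGR frames):
    #   [(480, 640, 3), (480, 640, 3)]  --pre_transform-->  2 letterboxed HWC arrays
    #   np.stack -> (2, H, W, 3)  --BGR->RGB, BHWC->BCHW-->  (2, 3, H, W)
    #   then torch tensor, fp16/fp32, scaled from 0-255 to 0.0-1.0
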
    def inference(self, im, *args, **kwargs):
        """Run inference on a preprocessed image batch, optionally saving feature visualizations."""
        visualize = increment_path(self.save_dir / Path(self.batch[0][0]).stem,
                                   mkdir=True) if self.args.visualize and (not self.source_type.tensor) else False
        return self.model(im, augment=self.args.augment, visualize=visualize)

    def pre_transform(self, im):
        """
        Pre-transform input image before inference.

        Args:
            im (List(np.ndarray)): (N, 3, h, w) for tensor, [(h, w, 3) x N] for list.

        Returns:
            (list): A list of transformed images.
        """
        same_shapes = all(x.shape == im[0].shape for x in im)
        auto = same_shapes and self.model.pt
        return [LetterBox(self.imgsz, auto=auto, stride=self.model.stride)(image=x) for x in im]

    def write_results(self, idx, results, batch):
        """Write inference results to a file or directory."""
        p, im, _ = batch
        log_string = ''
        if len(im.shape) == 3:
            im = im[None]  # expand for batch dim
        if self.source_type.webcam or self.source_type.from_img or self.source_type.tensor:  # batch_size >= 1
            log_string += f'{idx}: '
            frame = self.dataset.count
        else:
            frame = getattr(self.dataset, 'frame', 0)
        self.data_path = p
        self.txt_path = str(self.save_dir / 'labels' / p.stem) + ('' if self.dataset.mode == 'image' else f'_{frame}')
        log_string += '%gx%g ' % im.shape[2:]  # print string
        result = results[idx]
        log_string += result.verbose()

        if self.args.save or self.args.show:  # Add bbox to image
            plot_args = {
                'line_width': self.args.line_width,
                'boxes': self.args.boxes,
                'conf': self.args.show_conf,
                'labels': self.args.show_labels}
            if not self.args.retina_masks:
                plot_args['im_gpu'] = im[idx]
            self.plotted_img = result.plot(**plot_args)
        # Write
        if self.args.save_txt:
            result.save_txt(f'{self.txt_path}.txt', save_conf=self.args.save_conf)
        if self.args.save_crop:
            result.save_crop(save_dir=self.save_dir / 'crops', file_name=self.data_path.stem)

        return log_string

    def postprocess(self, preds, img, orig_imgs):
        """Post-processes predictions for an image and returns them."""
        return preds

    def __call__(self, source=None, model=None, stream=False, *args, **kwargs):
        """Performs inference on an image or stream."""
        self.stream = stream
        if stream:
            return self.stream_inference(source, model, *args, **kwargs)
        else:
            return list(self.stream_inference(source, model, *args, **kwargs))  # merge list of Result into one

    def predict_cli(self, source=None, model=None):
        """Method used for CLI prediction. Always consumes the generator, since outputs need not accumulate in CLI mode."""
        gen = self.stream_inference(source, model)
        for _ in gen:  # running CLI inference without accumulating any outputs (do not modify)
            pass

    def setup_source(self, source):
        """Sets up source and inference mode."""
        self.imgsz = check_imgsz(self.args.imgsz, stride=self.model.stride, min_dim=2)  # check image size
        self.transforms = getattr(self.model.model, 'transforms', classify_transforms(
            self.imgsz[0])) if self.args.task == 'classify' else None
        self.dataset = load_inference_source(source=source, imgsz=self.imgsz, vid_stride=self.args.vid_stride)
        self.source_type = self.dataset.source_type
        if not getattr(self, 'stream', True) and (self.dataset.mode == 'stream' or  # streams
                                                  len(self.dataset) > 1000 or  # images
                                                  any(getattr(self.dataset, 'video_flag', [False]))):  # videos
            LOGGER.warning(STREAM_WARNING)
        self.vid_path, self.vid_writer = [None] * self.dataset.bs, [None] * self.dataset.bs

    @smart_inference_mode()
    def stream_inference(self, source=None, model=None, *args, **kwargs):
        """Streams real-time inference on camera feed and saves results to file."""
        if self.args.verbose:
            LOGGER.info('')

        # Setup model
        if not self.model:
            self.setup_model(model)

        # Setup source every time predict is called
        self.setup_source(source if source is not None else self.args.source)

        # Check if save_dir/ label file exists
        if self.args.save or self.args.save_txt:
            (self.save_dir / 'labels' if self.args.save_txt else self.save_dir).mkdir(parents=True, exist_ok=True)

        # Warmup model
        if not self.done_warmup:
            self.model.warmup(imgsz=(1 if self.model.pt or self.model.triton else self.dataset.bs, 3, *self.imgsz))
            self.done_warmup = True

        self.seen, self.windows, self.batch, profilers = 0, [], None, (ops.Profile(), ops.Profile(), ops.Profile())
        self.run_callbacks('on_predict_start')
        for batch in self.dataset:
            self.run_callbacks('on_predict_batch_start')
            self.batch = batch
            path, im0s, vid_cap, s = batch

            # Preprocess
            with profilers[0]:
                im = self.preprocess(im0s)

            # Inference
            with profilers[1]:
                preds = self.inference(im, *args, **kwargs)

            # Postprocess
            with profilers[2]:
                self.results = self.postprocess(preds, im, im0s)
            self.run_callbacks('on_predict_postprocess_end')

            # Visualize, save, write results
            n = len(im0s)
            for i in range(n):
                self.seen += 1
                self.results[i].speed = {
                    'preprocess': profilers[0].dt * 1E3 / n,
                    'inference': profilers[1].dt * 1E3 / n,
                    'postprocess': profilers[2].dt * 1E3 / n}
                p, im0 = path[i], None if self.source_type.tensor else im0s[i].copy()
                p = Path(p)

                if self.args.verbose or self.args.save or self.args.save_txt or self.args.show:
                    s += self.write_results(i, self.results, (p, im, im0))
                if self.args.save or self.args.save_txt:
                    self.results[i].save_dir = self.save_dir.__str__()
                if self.args.show and self.plotted_img is not None:
                    self.show(p)
                if self.args.save and self.plotted_img is not None:
                    self.save_preds(vid_cap, i, str(self.save_dir / p.name))

            self.run_callbacks('on_predict_batch_end')
            yield from self.results

            # Print time (inference-only)
            if self.args.verbose:
                LOGGER.info(f'{s}{profilers[1].dt * 1E3:.1f}ms')

        # Release assets
        if isinstance(self.vid_writer[-1], cv2.VideoWriter):
            self.vid_writer[-1].release()  # release final video writer

        # Print results
        if self.args.verbose and self.seen:
            t = tuple(x.t / self.seen * 1E3 for x in profilers)  # speeds per image
            LOGGER.info(f'Speed: %.1fms preprocess, %.1fms inference, %.1fms postprocess per image at shape '
                        f'{(1, 3, *im.shape[2:])}' % t)
        if self.args.save or self.args.save_txt or self.args.save_crop:
            nl = len(list(self.save_dir.glob('labels/*.txt')))  # number of labels
            s = f"\n{nl} label{'s' * (nl > 1)} saved to {self.save_dir / 'labels'}" if self.args.save_txt else ''
            LOGGER.info(f"Results saved to {colorstr('bold', self.save_dir)}{s}")

        self.run_callbacks('on_predict_end')

    def setup_model(self, model, verbose=True):
        """Initialize YOLO model with given parameters and set it to evaluation mode."""
        self.model = AutoBackend(model or self.args.model,
                                 device=select_device(self.args.device, verbose=verbose),
                                 dnn=self.args.dnn,
                                 data=self.args.data,
                                 fp16=self.args.half,
                                 fuse=True,
                                 verbose=verbose)

        self.device = self.model.device  # update device
        self.args.half = self.model.fp16  # update half
        self.model.eval()

    def show(self, p):
        """Display an image in a window using OpenCV imshow()."""
        im0 = self.plotted_img
        if platform.system() == 'Linux' and p not in self.windows:
            self.windows.append(p)
            cv2.namedWindow(str(p), cv2.WINDOW_NORMAL | cv2.WINDOW_KEEPRATIO)  # allow window resize (Linux)
            cv2.resizeWindow(str(p), im0.shape[1], im0.shape[0])
        cv2.imshow(str(p), im0)
        cv2.waitKey(500 if self.batch[3].startswith('image') else 1)  # wait 500 ms for images, 1 ms for video frames

    def save_preds(self, vid_cap, idx, save_path):
        """Save video predictions as mp4 at specified path."""
        im0 = self.plotted_img
        # Save imgs
        if self.dataset.mode == 'image':
            cv2.imwrite(save_path, im0)
        else:  # 'video' or 'stream'
            if self.vid_path[idx] != save_path:  # new video
                self.vid_path[idx] = save_path
                if isinstance(self.vid_writer[idx], cv2.VideoWriter):
                    self.vid_writer[idx].release()  # release previous video writer
                if vid_cap:  # video
                    fps = int(vid_cap.get(cv2.CAP_PROP_FPS))  # integer required, floats produce error in MP4 codec
                    w = int(vid_cap.get(cv2.CAP_PROP_FRAME_WIDTH))
                    h = int(vid_cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
                else:  # stream
                    fps, w, h = 30, im0.shape[1], im0.shape[0]
                suffix = '.mp4' if MACOS else '.avi'
                fourcc = 'avc1' if MACOS else 'WMV2' if WINDOWS else 'MJPG'
                save_path = str(Path(save_path).with_suffix(suffix))
                self.vid_writer[idx] = cv2.VideoWriter(save_path, cv2.VideoWriter_fourcc(*fourcc), fps, (w, h))
            self.vid_writer[idx].write(im0)

    def run_callbacks(self, event: str):
        """Runs all registered callbacks for a specific event."""
        for callback in self.callbacks.get(event, []):
            callback(self)

    def add_callback(self, event: str, func):
        """Add a callback."""
        self.callbacks[event].append(func)

@@ -1,614 +0,0 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
"""
Ultralytics Results, Boxes and Masks classes for handling inference results.

Usage: See https://docs.ultralytics.com/modes/predict/
"""

from copy import deepcopy
from functools import lru_cache
from pathlib import Path

import numpy as np
import torch

from ultralytics.yolo.data.augment import LetterBox
from ultralytics.yolo.utils import LOGGER, SimpleClass, deprecation_warn, ops
from ultralytics.yolo.utils.plotting import Annotator, colors, save_one_box


class BaseTensor(SimpleClass):
    """Base tensor class with additional methods for easy manipulation and device handling."""

    def __init__(self, data, orig_shape) -> None:
        """
        Initialize BaseTensor with data and original shape.

        Args:
            data (torch.Tensor | np.ndarray): Predictions, such as bboxes, masks and keypoints.
            orig_shape (tuple): Original shape of image.
        """
        assert isinstance(data, (torch.Tensor, np.ndarray))
        self.data = data
        self.orig_shape = orig_shape

    @property
    def shape(self):
        """Return the shape of the data tensor."""
        return self.data.shape

    def cpu(self):
        """Return a copy of the tensor on CPU memory."""
        return self if isinstance(self.data, np.ndarray) else self.__class__(self.data.cpu(), self.orig_shape)

    def numpy(self):
        """Return a copy of the tensor as a numpy array."""
        return self if isinstance(self.data, np.ndarray) else self.__class__(self.data.numpy(), self.orig_shape)

    def cuda(self):
        """Return a copy of the tensor on GPU memory."""
        return self.__class__(torch.as_tensor(self.data).cuda(), self.orig_shape)

    def to(self, *args, **kwargs):
        """Return a copy of the tensor with the specified device and dtype."""
        return self.__class__(torch.as_tensor(self.data).to(*args, **kwargs), self.orig_shape)

    def __len__(self):  # override len(results)
        """Return the length of the data tensor."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return a BaseTensor with the specified index of the data tensor."""
        return self.__class__(self.data[idx], self.orig_shape)


class Results(SimpleClass):
    """
    A class for storing and manipulating inference results.

    Args:
        orig_img (numpy.ndarray): The original image as a numpy array.
        path (str): The path to the image file.
        names (dict): A dictionary of class names.
        boxes (torch.tensor, optional): A 2D tensor of bounding box coordinates for each detection.
        masks (torch.tensor, optional): A 3D tensor of detection masks, where each mask is a binary image.
        probs (torch.tensor, optional): A 1D tensor of probabilities of each class for classification task.
        keypoints (List[List[float]], optional): A list of detected keypoints for each object.

    Attributes:
        orig_img (numpy.ndarray): The original image as a numpy array.
        orig_shape (tuple): The original image shape in (height, width) format.
        boxes (Boxes, optional): A Boxes object containing the detection bounding boxes.
        masks (Masks, optional): A Masks object containing the detection masks.
        probs (Probs, optional): A Probs object containing probabilities of each class for classification task.
        names (dict): A dictionary of class names.
        path (str): The path to the image file.
        keypoints (Keypoints, optional): A Keypoints object containing detected keypoints for each object.
        speed (dict): A dictionary of preprocess, inference and postprocess speeds in milliseconds per image.
        _keys (tuple): A tuple of attribute names for non-empty attributes.
    """

    def __init__(self, orig_img, path, names, boxes=None, masks=None, probs=None, keypoints=None) -> None:
        """Initialize the Results class."""
        self.orig_img = orig_img
        self.orig_shape = orig_img.shape[:2]
        self.boxes = Boxes(boxes, self.orig_shape) if boxes is not None else None  # native size boxes
        self.masks = Masks(masks, self.orig_shape) if masks is not None else None  # native size or imgsz masks
        self.probs = Probs(probs) if probs is not None else None
        self.keypoints = Keypoints(keypoints, self.orig_shape) if keypoints is not None else None
        self.speed = {'preprocess': None, 'inference': None, 'postprocess': None}  # milliseconds per image
        self.names = names
        self.path = path
        self.save_dir = None
        self._keys = ('boxes', 'masks', 'probs', 'keypoints')

    def __getitem__(self, idx):
        """Return a Results object for the specified index."""
        r = self.new()
        for k in self.keys:
            setattr(r, k, getattr(self, k)[idx])
        return r

    def update(self, boxes=None, masks=None, probs=None):
        """Update the boxes, masks, and probs attributes of the Results object."""
        if boxes is not None:
            self.boxes = Boxes(boxes, self.orig_shape)
        if masks is not None:
            self.masks = Masks(masks, self.orig_shape)
        if probs is not None:
            self.probs = probs

    def cpu(self):
        """Return a copy of the Results object with all tensors on CPU memory."""
        r = self.new()
        for k in self.keys:
            setattr(r, k, getattr(self, k).cpu())
        return r

    def numpy(self):
        """Return a copy of the Results object with all tensors as numpy arrays."""
        r = self.new()
        for k in self.keys:
            setattr(r, k, getattr(self, k).numpy())
        return r

    def cuda(self):
        """Return a copy of the Results object with all tensors on GPU memory."""
        r = self.new()
        for k in self.keys:
            setattr(r, k, getattr(self, k).cuda())
        return r

    def to(self, *args, **kwargs):
        """Return a copy of the Results object with tensors on the specified device and dtype."""
        r = self.new()
        for k in self.keys:
            setattr(r, k, getattr(self, k).to(*args, **kwargs))
        return r

    def __len__(self):
        """Return the number of detections in the Results object."""
        for k in self.keys:
            return len(getattr(self, k))

    def new(self):
        """Return a new Results object with the same image, path, and names."""
        return Results(orig_img=self.orig_img, path=self.path, names=self.names)

    @property
    def keys(self):
        """Return a list of non-empty attribute names."""
        return [k for k in self._keys if getattr(self, k) is not None]

    def plot(
            self,
            conf=True,
            line_width=None,
            font_size=None,
            font='Arial.ttf',
            pil=False,
            img=None,
            im_gpu=None,
            kpt_line=True,
            labels=True,
            boxes=True,
            masks=True,
            probs=True,
            **kwargs  # deprecated args TODO: remove support in 8.2
    ):
        """
        Plots the detection results on an input RGB image. Accepts a numpy array (cv2) or a PIL Image.

        Args:
            conf (bool): Whether to plot the detection confidence score.
            line_width (float, optional): The line width of the bounding boxes. If None, it is scaled to the image size.
            font_size (float, optional): The font size of the text. If None, it is scaled to the image size.
            font (str): The font to use for the text.
            pil (bool): Whether to return the image as a PIL Image.
            img (numpy.ndarray, optional): Image to plot onto. If None, plots onto the original image.
            im_gpu (torch.Tensor): Normalized image in gpu with shape (1, 3, 640, 640), for faster mask plotting.
            kpt_line (bool): Whether to draw lines connecting keypoints.
            labels (bool): Whether to plot the label of bounding boxes.
            boxes (bool): Whether to plot the bounding boxes.
            masks (bool): Whether to plot the masks.
            probs (bool): Whether to plot classification probabilities.

        Returns:
            (numpy.ndarray): A numpy array of the annotated image.
        """
        if img is None and isinstance(self.orig_img, torch.Tensor):
            img = np.ascontiguousarray(self.orig_img[0].permute(1, 2, 0).cpu().detach().numpy()) * 255

        # Deprecation warn TODO: remove in 8.2
        if 'show_conf' in kwargs:
            deprecation_warn('show_conf', 'conf')
            conf = kwargs['show_conf']
            assert type(conf) == bool, '`show_conf` should be of boolean type, i.e, show_conf=True/False'

        if 'line_thickness' in kwargs:
            deprecation_warn('line_thickness', 'line_width')
            line_width = kwargs['line_thickness']
            assert type(line_width) == int, '`line_width` should be of int type, i.e, line_width=3'

        names = self.names
        pred_boxes, show_boxes = self.boxes, boxes
        pred_masks, show_masks = self.masks, masks
        pred_probs, show_probs = self.probs, probs
        annotator = Annotator(
            deepcopy(self.orig_img if img is None else img),
            line_width,
            font_size,
            font,
            pil or (pred_probs is not None and show_probs),  # Classify tasks default to pil=True
            example=names)

        # Plot Segment results
        if pred_masks and show_masks:
            if im_gpu is None:
                img = LetterBox(pred_masks.shape[1:])(image=annotator.result())
                im_gpu = torch.as_tensor(img, dtype=torch.float16, device=pred_masks.data.device).permute(
                    2, 0, 1).flip(0).contiguous() / 255
            idx = pred_boxes.cls if pred_boxes else range(len(pred_masks))
            annotator.masks(pred_masks.data, colors=[colors(x, True) for x in idx], im_gpu=im_gpu)

        # Plot Detect results
        if pred_boxes and show_boxes:
            for d in reversed(pred_boxes):
                c, conf, id = int(d.cls), float(d.conf) if conf else None, None if d.id is None else int(d.id.item())
                name = ('' if id is None else f'id:{id} ') + names[c]
                label = (f'{name} {conf:.2f}' if conf else name) if labels else None
                annotator.box_label(d.xyxy.squeeze(), label, color=colors(c, True))

        # Plot Classify results
        if pred_probs is not None and show_probs:
            text = ',\n'.join(f'{names[j] if names else j} {pred_probs.data[j]:.2f}' for j in pred_probs.top5)
            x = round(self.orig_shape[0] * 0.03)
            annotator.text([x, x], text, txt_color=(255, 255, 255))  # TODO: allow setting colors

        # Plot Pose results
        if self.keypoints is not None:
            for k in reversed(self.keypoints.data):
                annotator.kpts(k, self.orig_shape, kpt_line=kpt_line)

        return annotator.result()

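    # Usage sketch for plot() (hypothetical paths; returns an annotated numpy array):
    #   res = model('bus.jpg')[0]
    #   annotated = res.plot(conf=True, line_width=2)
    #   cv2.imwrite('annotated.jpg', annotated)
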
    def verbose(self):
        """Return log string for each task."""
        log_string = ''
        probs = self.probs
        boxes = self.boxes
        if len(self) == 0:
            return log_string if probs is not None else f'{log_string}(no detections), '
        if probs is not None:
            log_string += f"{', '.join(f'{self.names[j]} {probs.data[j]:.2f}' for j in probs.top5)}, "
        if boxes:
            for c in boxes.cls.unique():
                n = (boxes.cls == c).sum()  # detections per class
                log_string += f"{n} {self.names[int(c)]}{'s' * (n > 1)}, "
        return log_string

    def save_txt(self, txt_file, save_conf=False):
        """
        Save predictions into txt file.

        Args:
            txt_file (str): txt file path.
            save_conf (bool): save confidence score or not.
        """
        boxes = self.boxes
        masks = self.masks
        probs = self.probs
        kpts = self.keypoints
        texts = []
        if probs is not None:
            # Classify
            texts.extend(f'{probs.data[j]:.2f} {self.names[j]}' for j in probs.top5)
        elif boxes:
            # Detect/segment/pose
            for j, d in enumerate(boxes):
                c, conf, id = int(d.cls), float(d.conf), None if d.id is None else int(d.id.item())
                line = (c, *d.xywhn.view(-1))
                if masks:
                    seg = masks[j].xyn[0].copy().reshape(-1)  # reversed mask.xyn, (n,2) to (n*2)
                    line = (c, *seg)
                if kpts is not None:
                    kpt = torch.cat((kpts[j].xyn, kpts[j].conf[..., None]), 2) if kpts[j].has_visible else kpts[j].xyn
                    line += (*kpt.reshape(-1).tolist(), )
                line += (conf, ) * save_conf + (() if id is None else (id, ))
                texts.append(('%g ' * len(line)).rstrip() % line)

        if texts:
            with open(txt_file, 'a') as f:
                f.writelines(text + '\n' for text in texts)

    def save_crop(self, save_dir, file_name=Path('im.jpg')):
        """
        Save cropped predictions to `save_dir/cls/file_name.jpg`.

        Args:
            save_dir (str | pathlib.Path): Save path.
            file_name (str | pathlib.Path): File name.
        """
        if self.probs is not None:
            LOGGER.warning('WARNING ⚠️ Classify task does not support `save_crop`.')
            return
        if isinstance(save_dir, str):
            save_dir = Path(save_dir)
        if isinstance(file_name, str):
            file_name = Path(file_name)
        for d in self.boxes:
            save_one_box(d.xyxy,
                         self.orig_img.copy(),
                         file=save_dir / self.names[int(d.cls)] / f'{file_name.stem}.jpg',
                         BGR=True)

    def pandas(self):
        """Convert the object to a pandas DataFrame (not yet implemented)."""
        LOGGER.warning("WARNING ⚠️ 'Results.pandas' method is not yet implemented.")

    def tojson(self, normalize=False):
        """Convert the object to JSON format."""
        if self.probs is not None:
            LOGGER.warning('WARNING ⚠️ Classify task does not support `tojson` yet.')
            return

        import json

        # Create list of detection dictionaries
        results = []
        data = self.boxes.data.cpu().tolist()
        h, w = self.orig_shape if normalize else (1, 1)
        for i, row in enumerate(data):
            box = {'x1': row[0] / w, 'y1': row[1] / h, 'x2': row[2] / w, 'y2': row[3] / h}
            conf = row[4]
            id = int(row[5])
            name = self.names[id]
            result = {'name': name, 'class': id, 'confidence': conf, 'box': box}
            if self.masks:
                x, y = self.masks.xy[i][:, 0], self.masks.xy[i][:, 1]  # numpy array
                result['segments'] = {'x': (x / w).tolist(), 'y': (y / h).tolist()}
            if self.keypoints is not None:
                x, y, visible = self.keypoints[i].data[0].cpu().unbind(dim=1)  # torch Tensor
                result['keypoints'] = {'x': (x / w).tolist(), 'y': (y / h).tolist(), 'visible': visible.tolist()}
            results.append(result)

        # Convert detections to JSON
        return json.dumps(results, indent=2)
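
    # Output sketch for tojson() (one entry per box; values hypothetical):
    #   [{"name": "person", "class": 0, "confidence": 0.92,
    #     "box": {"x1": 12.0, "y1": 34.0, "x2": 56.0, "y2": 78.0}}]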


class Boxes(BaseTensor):
    """
    A class for storing and manipulating detection boxes.

    Args:
        boxes (torch.Tensor | numpy.ndarray): A tensor or numpy array containing the detection boxes,
            with shape (num_boxes, 6), or (num_boxes, 7) if track IDs are present. The last two columns
            contain confidence and class values.
        orig_shape (tuple): Original image size, in the format (height, width).

    Attributes:
        boxes (torch.Tensor | numpy.ndarray): The detection boxes with shape (num_boxes, 6) or (num_boxes, 7).
        orig_shape (tuple): Original image size, in the format (height, width).
        is_track (bool): True if the boxes also include track IDs, False otherwise.

    Properties:
        xyxy (torch.Tensor | numpy.ndarray): The boxes in xyxy format.
        conf (torch.Tensor | numpy.ndarray): The confidence values of the boxes.
        cls (torch.Tensor | numpy.ndarray): The class values of the boxes.
        id (torch.Tensor | numpy.ndarray): The track IDs of the boxes (if available).
        xywh (torch.Tensor | numpy.ndarray): The boxes in xywh format.
        xyxyn (torch.Tensor | numpy.ndarray): The boxes in xyxy format normalized by original image size.
        xywhn (torch.Tensor | numpy.ndarray): The boxes in xywh format normalized by original image size.
        data (torch.Tensor): The raw bboxes tensor.

    Methods:
        cpu(): Move the object to CPU memory.
        numpy(): Convert the object to a numpy array.
        cuda(): Move the object to CUDA memory.
        to(*args, **kwargs): Move the object to the specified device.
        pandas(): Convert the object to a pandas DataFrame (not yet implemented).
    """

    def __init__(self, boxes, orig_shape) -> None:
        """Initialize the Boxes class."""
        if boxes.ndim == 1:
            boxes = boxes[None, :]
        n = boxes.shape[-1]
        assert n in (6, 7), f'expected `n` in [6, 7], but got {n}'  # xyxy, (track_id), conf, cls
        super().__init__(boxes, orig_shape)
        self.is_track = n == 7
        self.orig_shape = orig_shape

    @property
    def xyxy(self):
        """Return the boxes in xyxy format."""
        return self.data[:, :4]

    @property
    def conf(self):
        """Return the confidence values of the boxes."""
        return self.data[:, -2]

    @property
    def cls(self):
        """Return the class values of the boxes."""
        return self.data[:, -1]

    @property
    def id(self):
        """Return the track IDs of the boxes (if available)."""
        return self.data[:, -3] if self.is_track else None

    @property
    @lru_cache(maxsize=2)  # maxsize 1 should suffice
    def xywh(self):
        """Return the boxes in xywh format."""
        return ops.xyxy2xywh(self.xyxy)

    @property
    @lru_cache(maxsize=2)
    def xyxyn(self):
        """Return the boxes in xyxy format normalized by original image size."""
        xyxy = self.xyxy.clone() if isinstance(self.xyxy, torch.Tensor) else np.copy(self.xyxy)
        xyxy[..., [0, 2]] /= self.orig_shape[1]
        xyxy[..., [1, 3]] /= self.orig_shape[0]
        return xyxy

    @property
    @lru_cache(maxsize=2)
    def xywhn(self):
        """Return the boxes in xywh format normalized by original image size."""
        xywh = ops.xyxy2xywh(self.xyxy)
        xywh[..., [0, 2]] /= self.orig_shape[1]
        xywh[..., [1, 3]] /= self.orig_shape[0]
        return xywh

    @property
    def boxes(self):
        """Return the raw bboxes tensor (deprecated)."""
        LOGGER.warning("WARNING ⚠️ 'Boxes.boxes' is deprecated. Use 'Boxes.data' instead.")
        return self.data
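
    # Usage sketch for Boxes (columns are xyxy, optional track id, conf, cls):
    #   b = results[0].boxes
    #   b.xyxy, b.conf, b.cls  # pixel boxes, confidences, class indices
    #   b.xywhn                # xywh normalized by the original (height, width)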


class Masks(BaseTensor):
    """
    A class for storing and manipulating detection masks.

    Args:
        masks (torch.Tensor | np.ndarray): A tensor containing the detection masks, with shape (num_masks, height, width).
        orig_shape (tuple): Original image size, in the format (height, width).

    Attributes:
        masks (torch.Tensor | np.ndarray): A tensor containing the detection masks, with shape (num_masks, height, width).
        orig_shape (tuple): Original image size, in the format (height, width).

    Properties:
        xy (list): A list of segments (pixels) which includes x, y segments of each detection.
        xyn (list): A list of segments (normalized) which includes x, y segments of each detection.

    Methods:
        cpu(): Returns a copy of the masks tensor on CPU memory.
        numpy(): Returns a copy of the masks tensor as a numpy array.
        cuda(): Returns a copy of the masks tensor on GPU memory.
        to(): Returns a copy of the masks tensor with the specified device and dtype.
    """

    def __init__(self, masks, orig_shape) -> None:
        """Initialize the Masks class."""
        if masks.ndim == 2:
            masks = masks[None, :]
        super().__init__(masks, orig_shape)

    @property
    @lru_cache(maxsize=1)
    def segments(self):
        """Return segments (deprecated; normalized)."""
        LOGGER.warning("WARNING ⚠️ 'Masks.segments' is deprecated. Use 'Masks.xyn' for segments (normalized) and "
                       "'Masks.xy' for segments (pixels) instead.")
        return self.xyn

    @property
    @lru_cache(maxsize=1)
    def xyn(self):
        """Return segments (normalized)."""
        return [
            ops.scale_coords(self.data.shape[1:], x, self.orig_shape, normalize=True)
            for x in ops.masks2segments(self.data)]

    @property
    @lru_cache(maxsize=1)
    def xy(self):
        """Return segments (pixels)."""
        return [
            ops.scale_coords(self.data.shape[1:], x, self.orig_shape, normalize=False)
            for x in ops.masks2segments(self.data)]

    @property
    def masks(self):
        """Return the raw masks tensor (deprecated)."""
        LOGGER.warning("WARNING ⚠️ 'Masks.masks' is deprecated. Use 'Masks.data' instead.")
        return self.data

    def pandas(self):
        """Convert the object to a pandas DataFrame (not yet implemented)."""
        LOGGER.warning("WARNING ⚠️ 'Masks.pandas' method is not yet implemented.")


class Keypoints(BaseTensor):
    """
    A class for storing and manipulating detection keypoints.

    Args:
        keypoints (torch.Tensor | np.ndarray): A tensor containing the detection keypoints, with shape (num_dets, num_kpts, 2/3).
        orig_shape (tuple): Original image size, in the format (height, width).

    Attributes:
        keypoints (torch.Tensor | np.ndarray): A tensor containing the detection keypoints, with shape (num_dets, num_kpts, 2/3).
        orig_shape (tuple): Original image size, in the format (height, width).

    Properties:
        xy (list): A list of keypoints (pixels) which includes x, y keypoints of each detection.
        xyn (list): A list of keypoints (normalized) which includes x, y keypoints of each detection.

    Methods:
        cpu(): Returns a copy of the keypoints tensor on CPU memory.
        numpy(): Returns a copy of the keypoints tensor as a numpy array.
        cuda(): Returns a copy of the keypoints tensor on GPU memory.
        to(): Returns a copy of the keypoints tensor with the specified device and dtype.
    """

    def __init__(self, keypoints, orig_shape) -> None:
        """Initialize the Keypoints class."""
        if keypoints.ndim == 2:
            keypoints = keypoints[None, :]
        super().__init__(keypoints, orig_shape)
        self.has_visible = self.data.shape[-1] == 3

    @property
    @lru_cache(maxsize=1)
    def xy(self):
        """Return keypoints (pixels)."""
        return self.data[..., :2]

    @property
    @lru_cache(maxsize=1)
    def xyn(self):
        """Return keypoints (normalized by original image size)."""
        xy = self.xy.clone() if isinstance(self.xy, torch.Tensor) else np.copy(self.xy)
        xy[..., 0] /= self.orig_shape[1]
        xy[..., 1] /= self.orig_shape[0]
        return xy

    @property
    @lru_cache(maxsize=1)
    def conf(self):
        """Return keypoint visibility confidences if available, else None."""
        return self.data[..., 2] if self.has_visible else None


class Probs(BaseTensor):
    """
    A class for storing and manipulating classification predictions.

    Args:
        probs (torch.Tensor | np.ndarray): A tensor containing the class probabilities, with shape (num_class, ).

    Attributes:
        probs (torch.Tensor | np.ndarray): A tensor containing the class probabilities, with shape (num_class, ).

    Properties:
        top5 (list[int]): Indices of the top 5 classes.
        top1 (int): Index of the top 1 class.

    Methods:
        cpu(): Returns a copy of the probs tensor on CPU memory.
        numpy(): Returns a copy of the probs tensor as a numpy array.
        cuda(): Returns a copy of the probs tensor on GPU memory.
        to(): Returns a copy of the probs tensor with the specified device and dtype.
    """

    def __init__(self, probs, orig_shape=None) -> None:
        """Initialize the Probs class."""
        super().__init__(probs, orig_shape)

    @property
    @lru_cache(maxsize=1)
    def top5(self):
        """Return the indices of top 5."""
        return (-self.data).argsort(0)[:5].tolist()  # this way works with both torch and numpy.

    @property
    @lru_cache(maxsize=1)
    def top1(self):
        """Return the index of top 1."""
        return int(self.data.argmax())

    @property
    @lru_cache(maxsize=1)
    def top5conf(self):
        """Return the confidences of top 5."""
        return self.data[self.top5]

    @property
    @lru_cache(maxsize=1)
    def top1conf(self):
        """Return the confidence of top 1."""
        return self.data[self.top1]
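
    # Usage sketch for Probs (classification results; names come from the parent Results):
    #   p = results[0].probs
    #   p.top1, p.top1conf  # best class index and its confidence
    #   p.top5, p.top5conf  # top-5 class indices and confidences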

@@ -1,664 +0,0 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
"""
Train a model on a dataset.

Usage:
    $ yolo mode=train model=yolov8n.pt data=coco128.yaml imgsz=640 epochs=100 batch=16
"""
import math
import os
import subprocess
import time
from copy import deepcopy
from datetime import datetime, timedelta
from pathlib import Path

import numpy as np
import torch
from torch import distributed as dist
from torch import nn, optim
from torch.cuda import amp
from torch.nn.parallel import DistributedDataParallel as DDP
from tqdm import tqdm

from ultralytics.nn.tasks import attempt_load_one_weight, attempt_load_weights
from ultralytics.yolo.cfg import get_cfg
from ultralytics.yolo.data.utils import check_cls_dataset, check_det_dataset
from ultralytics.yolo.utils import (DEFAULT_CFG, LOGGER, RANK, SETTINGS, TQDM_BAR_FORMAT, __version__, callbacks,
                                    clean_url, colorstr, emojis, yaml_save)
from ultralytics.yolo.utils.autobatch import check_train_batch_size
from ultralytics.yolo.utils.checks import check_amp, check_file, check_imgsz, print_args
from ultralytics.yolo.utils.dist import ddp_cleanup, generate_ddp_command
from ultralytics.yolo.utils.files import get_latest_run, increment_path
from ultralytics.yolo.utils.torch_utils import (EarlyStopping, ModelEMA, de_parallel, init_seeds, one_cycle,
                                                select_device, strip_optimizer)


class BaseTrainer:
    """
    BaseTrainer

    A base class for creating trainers.

    Attributes:
        args (SimpleNamespace): Configuration for the trainer.
        check_resume (method): Method to check if training should be resumed from a saved checkpoint.
        validator (BaseValidator): Validator instance.
        model (nn.Module): Model instance.
        callbacks (defaultdict): Dictionary of callbacks.
        save_dir (Path): Directory to save results.
        wdir (Path): Directory to save weights.
        last (Path): Path to last checkpoint.
        best (Path): Path to best checkpoint.
        save_period (int): Save checkpoint every x epochs (disabled if < 1).
        batch_size (int): Batch size for training.
        epochs (int): Number of epochs to train for.
        start_epoch (int): Starting epoch for training.
        device (torch.device): Device to use for training.
        amp (bool): Flag to enable AMP (Automatic Mixed Precision).
        scaler (amp.GradScaler): Gradient scaler for AMP.
        data (str): Path to data.
        trainset (torch.utils.data.Dataset): Training dataset.
        testset (torch.utils.data.Dataset): Testing dataset.
        ema (nn.Module): EMA (Exponential Moving Average) of the model.
        lf (nn.Module): Loss function.
        scheduler (torch.optim.lr_scheduler._LRScheduler): Learning rate scheduler.
        best_fitness (float): The best fitness value achieved.
        fitness (float): Current fitness value.
        loss (float): Current loss value.
        tloss (float): Total loss value.
        loss_names (list): List of loss names.
        csv (Path): Path to results CSV file.
    """

    def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
        """
        Initializes the BaseTrainer class.

        Args:
            cfg (str, optional): Path to a configuration file. Defaults to DEFAULT_CFG.
            overrides (dict, optional): Configuration overrides. Defaults to None.
        """
        self.args = get_cfg(cfg, overrides)
        self.device = select_device(self.args.device, self.args.batch)
        self.check_resume()
        self.validator = None
        self.model = None
        self.metrics = None
        self.plots = {}
        init_seeds(self.args.seed + 1 + RANK, deterministic=self.args.deterministic)

        # Dirs
        project = self.args.project or Path(SETTINGS['runs_dir']) / self.args.task
        name = self.args.name or f'{self.args.mode}'
        if hasattr(self.args, 'save_dir'):
            self.save_dir = Path(self.args.save_dir)
        else:
            self.save_dir = Path(
                increment_path(Path(project) / name, exist_ok=self.args.exist_ok if RANK in (-1, 0) else True))
        self.wdir = self.save_dir / 'weights'  # weights dir
        if RANK in (-1, 0):
            self.wdir.mkdir(parents=True, exist_ok=True)  # make dir
            self.args.save_dir = str(self.save_dir)
            yaml_save(self.save_dir / 'args.yaml', vars(self.args))  # save run args
        self.last, self.best = self.wdir / 'last.pt', self.wdir / 'best.pt'  # checkpoint paths
        self.save_period = self.args.save_period

        self.batch_size = self.args.batch
        self.epochs = self.args.epochs
        self.start_epoch = 0
        if RANK == -1:
            print_args(vars(self.args))

        # Device
        if self.device.type == 'cpu':
            self.args.workers = 0  # faster CPU training as time dominated by inference, not dataloading

        # Model and Dataset
        self.model = self.args.model
        try:
            if self.args.task == 'classify':
                self.data = check_cls_dataset(self.args.data)
            elif self.args.data.endswith('.yaml') or self.args.task in ('detect', 'segment'):
                self.data = check_det_dataset(self.args.data)
                if 'yaml_file' in self.data:
                    self.args.data = self.data['yaml_file']  # for validating 'yolo train data=url.zip' usage
        except Exception as e:
            raise RuntimeError(emojis(f"Dataset '{clean_url(self.args.data)}' error ❌ {e}")) from e

        self.trainset, self.testset = self.get_dataset(self.data)
        self.ema = None

        # Optimization utils init
        self.lf = None
        self.scheduler = None

        # Epoch level metrics
        self.best_fitness = None
        self.fitness = None
        self.loss = None
        self.tloss = None
        self.loss_names = ['Loss']
        self.csv = self.save_dir / 'results.csv'
        self.plot_idx = [0, 1, 2]

        # Callbacks
        self.callbacks = _callbacks or callbacks.get_default_callbacks()
        if RANK in (-1, 0):
            callbacks.add_integration_callbacks(self)

    def add_callback(self, event: str, callback):
        """Appends the given callback."""
        self.callbacks[event].append(callback)

    def set_callback(self, event: str, callback):
        """Overrides the existing callbacks with the given callback."""
        self.callbacks[event] = [callback]

    def run_callbacks(self, event: str):
        """Run all existing callbacks associated with a particular event."""
        for callback in self.callbacks.get(event, []):
            callback(self)

    def train(self):
        """Entry point for training; allows device='' or device=None on Multi-GPU systems to default to device=0."""
        if isinstance(self.args.device, int) or self.args.device:  # i.e. device=0 or device=[0,1,2,3]
            world_size = torch.cuda.device_count()
        elif torch.cuda.is_available():  # i.e. device=None or device=''
            world_size = 1  # default to device 0
        else:  # i.e. device='cpu' or 'mps'
            world_size = 0

        # Run subprocess if DDP training, else train normally
        if world_size > 1 and 'LOCAL_RANK' not in os.environ:
            # Argument checks
            if self.args.rect:
                LOGGER.warning("WARNING ⚠️ 'rect=True' is incompatible with Multi-GPU training, setting rect=False")
                self.args.rect = False
            # Command
            cmd, file = generate_ddp_command(world_size, self)
            try:
                LOGGER.info(f'DDP command: {cmd}')
                subprocess.run(cmd, check=True)
            except Exception as e:
                raise e
            finally:
                ddp_cleanup(self, str(file))
        else:
            self._do_train(world_size)

    def _setup_ddp(self, world_size):
        """Initializes and sets the DistributedDataParallel parameters for training."""
        torch.cuda.set_device(RANK)
        self.device = torch.device('cuda', RANK)
        LOGGER.info(f'DDP info: RANK {RANK}, WORLD_SIZE {world_size}, DEVICE {self.device}')
        os.environ['NCCL_BLOCKING_WAIT'] = '1'  # set to enforce timeout
        dist.init_process_group(
            'nccl' if dist.is_nccl_available() else 'gloo',
            timeout=timedelta(seconds=10800),  # 3 hours
            rank=RANK,
            world_size=world_size)

    def _setup_train(self, world_size):
        """Builds dataloaders and optimizer on the correct rank process."""
        # Model
        self.run_callbacks('on_pretrain_routine_start')
        ckpt = self.setup_model()
        self.model = self.model.to(self.device)
        self.set_model_attributes()
        # Check AMP
        self.amp = torch.tensor(self.args.amp).to(self.device)  # True or False
        if self.amp and RANK in (-1, 0):  # Single-GPU and DDP
            callbacks_backup = callbacks.default_callbacks.copy()  # backup callbacks as check_amp() resets them
            self.amp = torch.tensor(check_amp(self.model), device=self.device)
            callbacks.default_callbacks = callbacks_backup  # restore callbacks
        if RANK > -1 and world_size > 1:  # DDP
            dist.broadcast(self.amp, src=0)  # broadcast the tensor from rank 0 to all other ranks (returns None)
        self.amp = bool(self.amp)  # as boolean
        self.scaler = amp.GradScaler(enabled=self.amp)
        if world_size > 1:
            self.model = DDP(self.model, device_ids=[RANK])
        # Check imgsz
        gs = max(int(self.model.stride.max() if hasattr(self.model, 'stride') else 32), 32)  # grid size (max stride)
        self.args.imgsz = check_imgsz(self.args.imgsz, stride=gs, floor=gs, max_dim=1)
        # Batch size
        if self.batch_size == -1:
            if RANK == -1:  # single-GPU only, estimate best batch size
                self.args.batch = self.batch_size = check_train_batch_size(self.model, self.args.imgsz, self.amp)
            else:
                raise SyntaxError('batch=-1 to use AutoBatch is only available in Single-GPU training. '
                                  'Please pass a valid batch size value for Multi-GPU DDP training, i.e. batch=16')

        # Dataloaders
        batch_size = self.batch_size // max(world_size, 1)
        self.train_loader = self.get_dataloader(self.trainset, batch_size=batch_size, rank=RANK, mode='train')
        if RANK in (-1, 0):
            self.test_loader = self.get_dataloader(self.testset, batch_size=batch_size * 2, rank=-1, mode='val')
            self.validator = self.get_validator()
            metric_keys = self.validator.metrics.keys + self.label_loss_items(prefix='val')
            self.metrics = dict(zip(metric_keys, [0] * len(metric_keys)))  # TODO: init metrics for plot_results()?
            self.ema = ModelEMA(self.model)
            if self.args.plots:
                self.plot_training_labels()

        # Optimizer
        self.accumulate = max(round(self.args.nbs / self.batch_size), 1)  # accumulate loss before optimizing
        weight_decay = self.args.weight_decay * self.batch_size * self.accumulate / self.args.nbs  # scale weight_decay
        iterations = math.ceil(len(self.train_loader.dataset) / max(self.batch_size, self.args.nbs)) * self.epochs
        self.optimizer = self.build_optimizer(model=self.model,
                                              name=self.args.optimizer,
                                              lr=self.args.lr0,
                                              momentum=self.args.momentum,
                                              decay=weight_decay,
                                              iterations=iterations)
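        # Worked example (hypothetical values): with nbs=64 and batch_size=16, accumulate = round(64/16) = 4,
        # so gradients are summed over 4 batches (effective batch 64) before each optimizer step, and
        # weight_decay is scaled by 16 * 4 / 64 = 1.0, keeping regularization consistent across batch sizes.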
        # Scheduler
        if self.args.cos_lr:
            self.lf = one_cycle(1, self.args.lrf, self.epochs)  # cosine 1->hyp['lrf']
        else:
            self.lf = lambda x: (1 - x / self.epochs) * (1.0 - self.args.lrf) + self.args.lrf  # linear
        self.scheduler = optim.lr_scheduler.LambdaLR(self.optimizer, lr_lambda=self.lf)
        self.stopper, self.stop = EarlyStopping(patience=self.args.patience), False
        self.resume_training(ckpt)
        self.scheduler.last_epoch = self.start_epoch - 1  # do not move
        self.run_callbacks('on_pretrain_routine_end')

    def _do_train(self, world_size=1):
        """Runs the main training loop, then evaluates and plots if specified by arguments."""
        if world_size > 1:
            self._setup_ddp(world_size)

        self._setup_train(world_size)

        self.epoch_time = None
        self.epoch_time_start = time.time()
        self.train_time_start = time.time()
        nb = len(self.train_loader)  # number of batches
        nw = max(round(self.args.warmup_epochs * nb), 100) if self.args.warmup_epochs > 0 else -1  # warmup iterations
        last_opt_step = -1
        self.run_callbacks('on_train_start')
        LOGGER.info(f'Image sizes {self.args.imgsz} train, {self.args.imgsz} val\n'
                    f'Using {self.train_loader.num_workers * (world_size or 1)} dataloader workers\n'
                    f"Logging results to {colorstr('bold', self.save_dir)}\n"
                    f'Starting training for {self.epochs} epochs...')
        if self.args.close_mosaic:
            base_idx = (self.epochs - self.args.close_mosaic) * nb
            self.plot_idx.extend([base_idx, base_idx + 1, base_idx + 2])
        epoch = self.epochs  # predefine for resume fully trained model edge cases
        for epoch in range(self.start_epoch, self.epochs):
            self.epoch = epoch
            self.run_callbacks('on_train_epoch_start')
            self.model.train()
            if RANK != -1:
                self.train_loader.sampler.set_epoch(epoch)
            pbar = enumerate(self.train_loader)
            # Update dataloader attributes (optional)
            if epoch == (self.epochs - self.args.close_mosaic):
                LOGGER.info('Closing dataloader mosaic')
                if hasattr(self.train_loader.dataset, 'mosaic'):
                    self.train_loader.dataset.mosaic = False
                if hasattr(self.train_loader.dataset, 'close_mosaic'):
                    self.train_loader.dataset.close_mosaic(hyp=self.args)
                self.train_loader.reset()

            if RANK in (-1, 0):
                LOGGER.info(self.progress_string())
                pbar = tqdm(enumerate(self.train_loader), total=nb, bar_format=TQDM_BAR_FORMAT)
            self.tloss = None
            self.optimizer.zero_grad()
            for i, batch in pbar:
                self.run_callbacks('on_train_batch_start')
                # Warmup
                ni = i + nb * epoch
                if ni <= nw:
                    xi = [0, nw]  # x interp
                    self.accumulate = max(1, np.interp(ni, xi, [1, self.args.nbs / self.batch_size]).round())
                    for j, x in enumerate(self.optimizer.param_groups):
                        # Bias lr falls from 0.1 to lr0, all other lrs rise from 0.0 to lr0
                        x['lr'] = np.interp(
                            ni, xi, [self.args.warmup_bias_lr if j == 0 else 0.0, x['initial_lr'] * self.lf(epoch)])
                        if 'momentum' in x:
                            x['momentum'] = np.interp(ni, xi, [self.args.warmup_momentum, self.args.momentum])
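                # Worked example (illustrative): halfway through warmup (ni = nw/2) with
                # warmup_bias_lr=0.1 and lr0=0.01, the bias group has decayed to ~0.055 while the
                # other groups have risen from 0.0 to ~0.005; both then follow lr0 * lf(epoch).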

                # Forward
                with torch.cuda.amp.autocast(self.amp):
                    batch = self.preprocess_batch(batch)
                    self.loss, self.loss_items = self.model(batch)
                    if RANK != -1:
                        self.loss *= world_size
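                    # Note (explanatory, an assumption about intent): DDP averages gradients across
                    # ranks, so scaling the loss by world_size keeps effective gradient magnitudes
                    # consistent with single-GPU training (a sum rather than a mean over devices).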
                    self.tloss = (self.tloss * i + self.loss_items) / (i + 1) if self.tloss is not None \
                        else self.loss_items

                # Backward
                self.scaler.scale(self.loss).backward()

                # Optimize - https://pytorch.org/docs/master/notes/amp_examples.html
                if ni - last_opt_step >= self.accumulate:
                    self.optimizer_step()
                    last_opt_step = ni

                # Log
                mem = f'{torch.cuda.memory_reserved() / 1E9 if torch.cuda.is_available() else 0:.3g}G'  # (GB)
                loss_len = self.tloss.shape[0] if len(self.tloss.size()) else 1
                losses = self.tloss if loss_len > 1 else torch.unsqueeze(self.tloss, 0)
                if RANK in (-1, 0):
                    pbar.set_description(
                        ('%11s' * 2 + '%11.4g' * (2 + loss_len)) %
                        (f'{epoch + 1}/{self.epochs}', mem, *losses, batch['cls'].shape[0], batch['img'].shape[-1]))
                    self.run_callbacks('on_batch_end')
                    if self.args.plots and ni in self.plot_idx:
                        self.plot_training_samples(batch, ni)

                self.run_callbacks('on_train_batch_end')

            self.lr = {f'lr/pg{ir}': x['lr'] for ir, x in enumerate(self.optimizer.param_groups)}  # for loggers

            self.scheduler.step()
            self.run_callbacks('on_train_epoch_end')

            if RANK in (-1, 0):

                # Validation
                self.ema.update_attr(self.model, include=['yaml', 'nc', 'args', 'names', 'stride', 'class_weights'])
                final_epoch = (epoch + 1 == self.epochs) or self.stopper.possible_stop

                if self.args.val or final_epoch:
                    self.metrics, self.fitness = self.validate()
                self.save_metrics(metrics={**self.label_loss_items(self.tloss), **self.metrics, **self.lr})
                self.stop = self.stopper(epoch + 1, self.fitness)

                # Save model
                if self.args.save or (epoch + 1 == self.epochs):
                    self.save_model()
                    self.run_callbacks('on_model_save')

            tnow = time.time()
            self.epoch_time = tnow - self.epoch_time_start
            self.epoch_time_start = tnow
            self.run_callbacks('on_fit_epoch_end')
            torch.cuda.empty_cache()  # clears GPU vRAM at end of epoch, can help with out-of-memory errors

            # Early Stopping
            if RANK != -1:  # if DDP training
                broadcast_list = [self.stop if RANK == 0 else None]
                dist.broadcast_object_list(broadcast_list, 0)  # broadcast 'stop' to all ranks
                if RANK != 0:
                    self.stop = broadcast_list[0]
            if self.stop:
                break  # must break all DDP ranks
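            # Note (explanatory): only rank 0 evaluates fitness and sets self.stop, so the flag is
            # broadcast to every rank; all ranks must break together or DDP collectives would hang.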

        if RANK in (-1, 0):
            # Do final val with best.pt
            LOGGER.info(f'\n{epoch - self.start_epoch + 1} epochs completed in '
                        f'{(time.time() - self.train_time_start) / 3600:.3f} hours.')
            self.final_eval()
            if self.args.plots:
                self.plot_metrics()
            self.run_callbacks('on_train_end')
        torch.cuda.empty_cache()
        self.run_callbacks('teardown')

    def save_model(self):
        """Save model checkpoints based on various conditions."""
        ckpt = {
            'epoch': self.epoch,
            'best_fitness': self.best_fitness,
            'model': deepcopy(de_parallel(self.model)).half(),
            'ema': deepcopy(self.ema.ema).half(),
            'updates': self.ema.updates,
            'optimizer': self.optimizer.state_dict(),
            'train_args': vars(self.args),  # save as dict
            'date': datetime.now().isoformat(),
            'version': __version__}

        # Use dill (if available) to serialize the lambda functions, which pickle cannot handle
        try:
            import dill as pickle
        except ImportError:
            import pickle

        # Save last and best, plus periodic checkpoints if save_period is set
        torch.save(ckpt, self.last, pickle_module=pickle)
        if self.best_fitness == self.fitness:
            torch.save(ckpt, self.best, pickle_module=pickle)
        if (self.epoch > 0) and (self.save_period > 0) and (self.epoch % self.save_period == 0):
            torch.save(ckpt, self.wdir / f'epoch{self.epoch}.pt', pickle_module=pickle)
        del ckpt

    @staticmethod
    def get_dataset(data):
        """Get train and val paths from the data dict if present. Returns None if the data format is not recognized."""
        return data['train'], data.get('val') or data.get('test')
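    # Worked example (illustrative): for a COCO-style dict like
    # {'train': 'images/train2017', 'val': 'images/val2017', ...}, get_dataset returns
    # ('images/train2017', 'images/val2017'); 'test' is used only when 'val' is absent.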

    def setup_model(self):
        """Load, create or download the model for any task."""
        if isinstance(self.model, torch.nn.Module):  # if model is loaded beforehand, no setup needed
            return

        model, weights = self.model, None
        ckpt = None
        if str(model).endswith('.pt'):
            weights, ckpt = attempt_load_one_weight(model)
            cfg = ckpt['model'].yaml
        else:
            cfg = model
        self.model = self.get_model(cfg=cfg, weights=weights, verbose=RANK == -1)  # calls Model(cfg, weights)
        return ckpt

    def optimizer_step(self):
        """Perform a single step of the training optimizer with gradient clipping and EMA update."""
        self.scaler.unscale_(self.optimizer)  # unscale gradients
        torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=10.0)  # clip gradients
        self.scaler.step(self.optimizer)
        self.scaler.update()
        self.optimizer.zero_grad()
        if self.ema:
            self.ema.update(self.model)

    def preprocess_batch(self, batch):
        """Allows custom preprocessing of model inputs and ground truths depending on the task type."""
        return batch

    def validate(self):
        """
        Runs validation on the test set using self.validator. The returned metrics dict is expected to contain
        a 'fitness' key; if it does not, the negated loss is used as the fitness measure.
        """
        metrics = self.validator(self)
        fitness = metrics.pop('fitness', -self.loss.detach().cpu().numpy())  # use loss as fitness measure if not found
        if not self.best_fitness or self.best_fitness < fitness:
            self.best_fitness = fitness
        return metrics, fitness

    def get_model(self, cfg=None, weights=None, verbose=True):
        """Get model and raise NotImplementedError for loading cfg files."""
        raise NotImplementedError("This task trainer doesn't support loading cfg files")

    def get_validator(self):
        """Raises NotImplementedError; task trainers must implement their own validator."""
        raise NotImplementedError('get_validator function not implemented in trainer')

    def get_dataloader(self, dataset_path, batch_size=16, rank=0, mode='train'):
        """Returns a dataloader derived from torch.utils.data.DataLoader."""
        raise NotImplementedError('get_dataloader function not implemented in trainer')

    def build_dataset(self, img_path, mode='train', batch=None):
        """Build dataset."""
        raise NotImplementedError('build_dataset function not implemented in trainer')

    def label_loss_items(self, loss_items=None, prefix='train'):
        """Returns a loss dict with labelled training loss items tensor."""
        # Not needed for classification but necessary for segmentation & detection
        return {'loss': loss_items} if loss_items is not None else ['loss']

    def set_model_attributes(self):
        """Sets or updates model parameters before training."""
        self.model.names = self.data['names']

    def build_targets(self, preds, targets):
        """Builds target tensors for training YOLO model."""
        pass

    def progress_string(self):
        """Returns a string describing training progress."""
        return ''

    # TODO: may need to put these following functions into callback
    def plot_training_samples(self, batch, ni):
        """Plots training samples during YOLO training."""
        pass

    def plot_training_labels(self):
        """Plots training labels for YOLO model."""
        pass

    def save_metrics(self, metrics):
        """Saves training metrics to a CSV file."""
        keys, vals = list(metrics.keys()), list(metrics.values())
        n = len(metrics) + 1  # number of cols
        s = '' if self.csv.exists() else (('%23s,' * n % tuple(['epoch'] + keys)).rstrip(',') + '\n')  # header
        with open(self.csv, 'a') as f:
            f.write(s + ('%23.5g,' * n % tuple([self.epoch] + vals)).rstrip(',') + '\n')
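        # Illustrative output (values invented): the first call writes a header row, then one row is
        # appended per epoch, e.g. "epoch,train/box_loss,metrics/mAP50(B),..." followed by
        # "0,1.20001,0.45002,...", each field right-aligned to 23 characters by the %23 formats.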

    def plot_metrics(self):
        """Plot and display metrics visually."""
        pass

    def on_plot(self, name, data=None):
        """Registers plots (e.g. to be consumed in callbacks)."""
        self.plots[name] = {'data': data, 'timestamp': time.time()}

    def final_eval(self):
        """Performs final evaluation and validation for the object detection YOLO model."""
        for f in self.last, self.best:
            if f.exists():
                strip_optimizer(f)  # strip optimizers
                if f is self.best:
                    LOGGER.info(f'\nValidating {f}...')
                    self.metrics = self.validator(model=f)
                    self.metrics.pop('fitness', None)
                    self.run_callbacks('on_fit_epoch_end')

    def check_resume(self):
        """Check if a resume checkpoint exists and update arguments accordingly."""
        resume = self.args.resume
        if resume:
            try:
                exists = isinstance(resume, (str, Path)) and Path(resume).exists()
                last = Path(check_file(resume) if exists else get_latest_run())

                # Check that resume data YAML exists, otherwise strip to force re-download of dataset
                ckpt_args = attempt_load_weights(last).args
                if not Path(ckpt_args['data']).exists():
                    ckpt_args['data'] = self.args.data

                self.args = get_cfg(ckpt_args)
                self.args.model, resume = str(last), True  # reinstate
            except Exception as e:
                raise FileNotFoundError('Resume checkpoint not found. Please pass a valid checkpoint to resume from, '
                                        "i.e. 'yolo train resume model=path/to/last.pt'") from e
        self.resume = resume

    def resume_training(self, ckpt):
        """Resume YOLO training from a given epoch and best fitness."""
        if ckpt is None:
            return
        best_fitness = 0.0
        start_epoch = ckpt['epoch'] + 1
        if ckpt['optimizer'] is not None:
            self.optimizer.load_state_dict(ckpt['optimizer'])  # optimizer
            best_fitness = ckpt['best_fitness']
        if self.ema and ckpt.get('ema'):
            self.ema.ema.load_state_dict(ckpt['ema'].float().state_dict())  # EMA
            self.ema.updates = ckpt['updates']
        if self.resume:
            assert start_epoch > 0, \
                f'{self.args.model} training to {self.epochs} epochs is finished, nothing to resume.\n' \
                f"Start a new training without resuming, i.e. 'yolo train model={self.args.model}'"
            LOGGER.info(
                f'Resuming training from {self.args.model} from epoch {start_epoch + 1} to {self.epochs} total epochs')
        if self.epochs < start_epoch:
            LOGGER.info(
                f"{self.model} has been trained for {ckpt['epoch']} epochs. Fine-tuning for {self.epochs} more epochs.")
            self.epochs += ckpt['epoch']  # finetune additional epochs
        self.best_fitness = best_fitness
        self.start_epoch = start_epoch
        if start_epoch > (self.epochs - self.args.close_mosaic):
            LOGGER.info('Closing dataloader mosaic')
            if hasattr(self.train_loader.dataset, 'mosaic'):
                self.train_loader.dataset.mosaic = False
            if hasattr(self.train_loader.dataset, 'close_mosaic'):
                self.train_loader.dataset.close_mosaic(hyp=self.args)

    def build_optimizer(self, model, name='auto', lr=0.001, momentum=0.9, decay=1e-5, iterations=1e5):
        """
        Constructs an optimizer for the given model, based on the specified optimizer name, learning rate, momentum,
        weight decay, and number of iterations.

        Args:
            model (torch.nn.Module): The model for which to build an optimizer.
            name (str, optional): The name of the optimizer to use. If 'auto', the optimizer is selected
                based on the number of iterations. Default: 'auto'.
            lr (float, optional): The learning rate for the optimizer. Default: 0.001.
            momentum (float, optional): The momentum factor for the optimizer. Default: 0.9.
            decay (float, optional): The weight decay for the optimizer. Default: 1e-5.
            iterations (float, optional): The number of iterations, which determines the optimizer if
                name is 'auto'. Default: 1e5.

        Returns:
            (torch.optim.Optimizer): The constructed optimizer.
        """
        g = [], [], []  # optimizer parameter groups
        bn = tuple(v for k, v in nn.__dict__.items() if 'Norm' in k)  # normalization layers, i.e. BatchNorm2d()
        if name == 'auto':
            nc = getattr(model, 'nc', 10)  # number of classes
            lr_fit = round(0.002 * 5 / (4 + nc), 6)  # lr0 fit equation to 6 decimal places
            name, lr, momentum = ('SGD', 0.01, 0.9) if iterations > 10000 else ('AdamW', lr_fit, 0.9)
            self.args.warmup_bias_lr = 0.0  # no higher than 0.01 for Adam

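        # Worked example (illustrative): for a detection model with nc=80 classes and a short run
        # (iterations <= 10000), 'auto' picks AdamW with lr = round(0.002 * 5 / (4 + 80), 6) = 0.000119;
        # longer runs fall back to SGD with lr=0.01 and momentum=0.9.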
        for module_name, module in model.named_modules():
            for param_name, param in module.named_parameters(recurse=False):
                fullname = f'{module_name}.{param_name}' if module_name else param_name
                if 'bias' in fullname:  # bias (no decay)
                    g[2].append(param)
                elif isinstance(module, bn):  # weight (no decay)
                    g[1].append(param)
                else:  # weight (with decay)
                    g[0].append(param)
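
        # Note (explanatory): parameters end up in three groups: g[0] weights with decay,
        # g[1] normalization-layer weights (no decay) and g[2] biases (no decay); g[2] seeds the
        # optimizer below and the other two are attached via add_param_group.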
        if name in ('Adam', 'Adamax', 'AdamW', 'NAdam', 'RAdam'):
            optimizer = getattr(optim, name, optim.Adam)(g[2], lr=lr, betas=(momentum, 0.999), weight_decay=0.0)
        elif name == 'RMSProp':
            optimizer = optim.RMSprop(g[2], lr=lr, momentum=momentum)
        elif name == 'SGD':
            optimizer = optim.SGD(g[2], lr=lr, momentum=momentum, nesterov=True)
        else:
            raise NotImplementedError(
                f"Optimizer '{name}' not found in list of available optimizers "
                f'[Adam, AdamW, NAdam, RAdam, RMSProp, SGD, auto]. '
                'To request support for additional optimizers please visit https://github.com/ultralytics/ultralytics.')

        optimizer.add_param_group({'params': g[0], 'weight_decay': decay})  # add g0 with weight_decay
        optimizer.add_param_group({'params': g[1], 'weight_decay': 0.0})  # add g1 (BatchNorm2d weights)
        LOGGER.info(
            f"{colorstr('optimizer:')} {type(optimizer).__name__}(lr={lr}, momentum={momentum}) with parameter groups "
            f'{len(g[1])} weight(decay=0.0), {len(g[0])} weight(decay={decay}), {len(g[2])} bias(decay=0.0)')
        return optimizer
@ -1,278 +0,0 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
"""
Check a model's accuracy on a test or val split of a dataset.

Usage:
    $ yolo mode=val model=yolov8n.pt data=coco128.yaml imgsz=640

Usage - formats:
    $ yolo mode=val model=yolov8n.pt                 # PyTorch
                          yolov8n.torchscript        # TorchScript
                          yolov8n.onnx               # ONNX Runtime or OpenCV DNN with dnn=True
                          yolov8n_openvino_model     # OpenVINO
                          yolov8n.engine             # TensorRT
                          yolov8n.mlmodel            # CoreML (macOS-only)
                          yolov8n_saved_model        # TensorFlow SavedModel
                          yolov8n.pb                 # TensorFlow GraphDef
                          yolov8n.tflite             # TensorFlow Lite
                          yolov8n_edgetpu.tflite     # TensorFlow Edge TPU
                          yolov8n_paddle_model       # PaddlePaddle
"""
import json
import time
from pathlib import Path

import torch
from tqdm import tqdm

from ultralytics.nn.autobackend import AutoBackend
from ultralytics.yolo.cfg import get_cfg
from ultralytics.yolo.data.utils import check_cls_dataset, check_det_dataset
from ultralytics.yolo.utils import DEFAULT_CFG, LOGGER, RANK, SETTINGS, TQDM_BAR_FORMAT, callbacks, colorstr, emojis
from ultralytics.yolo.utils.checks import check_imgsz
from ultralytics.yolo.utils.files import increment_path
from ultralytics.yolo.utils.ops import Profile
from ultralytics.yolo.utils.torch_utils import de_parallel, select_device, smart_inference_mode


class BaseValidator:
    """
    BaseValidator

    A base class for creating validators.

    Attributes:
        dataloader (DataLoader): Dataloader to use for validation.
        pbar (tqdm): Progress bar to update during validation.
        args (SimpleNamespace): Configuration for the validator.
        model (nn.Module): Model to validate.
        data (dict): Data dictionary.
        device (torch.device): Device to use for validation.
        batch_i (int): Current batch index.
        training (bool): Whether the model is in training mode.
        speed (dict): Processing speed in ms per image, keyed by 'preprocess', 'inference', 'loss', 'postprocess'.
        jdict (list): List to store JSON validation results.
        save_dir (Path): Directory to save results.
    """

    def __init__(self, dataloader=None, save_dir=None, pbar=None, args=None, _callbacks=None):
        """
        Initializes a BaseValidator instance.

        Args:
            dataloader (torch.utils.data.DataLoader): Dataloader to be used for validation.
            save_dir (Path): Directory to save results.
            pbar (tqdm.tqdm): Progress bar for displaying progress.
            args (SimpleNamespace): Configuration for the validator.
        """
        self.dataloader = dataloader
        self.pbar = pbar
        self.args = args or get_cfg(DEFAULT_CFG)
        self.model = None
        self.data = None
        self.device = None
        self.batch_i = None
        self.training = True
        self.speed = {'preprocess': 0.0, 'inference': 0.0, 'loss': 0.0, 'postprocess': 0.0}
        self.jdict = None

        project = self.args.project or Path(SETTINGS['runs_dir']) / self.args.task
        name = self.args.name or f'{self.args.mode}'
        self.save_dir = save_dir or increment_path(Path(project) / name,
                                                   exist_ok=self.args.exist_ok if RANK in (-1, 0) else True)
        (self.save_dir / 'labels' if self.args.save_txt else self.save_dir).mkdir(parents=True, exist_ok=True)

        if self.args.conf is None:
            self.args.conf = 0.001  # default conf=0.001

        self.plots = {}
        self.callbacks = _callbacks or callbacks.get_default_callbacks()

    @smart_inference_mode()
    def __call__(self, trainer=None, model=None):
        """
        Supports validation of a pre-trained model if passed, or of a model being trained if a trainer is passed
        (the trainer takes priority).
        """
        self.training = trainer is not None
        if self.training:
            self.device = trainer.device
            self.data = trainer.data
            model = trainer.ema.ema or trainer.model
            self.args.half = self.device.type != 'cpu'  # force FP16 val during training
            model = model.half() if self.args.half else model.float()
            self.model = model
            self.loss = torch.zeros_like(trainer.loss_items, device=trainer.device)
            self.args.plots = trainer.stopper.possible_stop or (trainer.epoch == trainer.epochs - 1)
            model.eval()
        else:
            callbacks.add_integration_callbacks(self)
            self.run_callbacks('on_val_start')
            assert model is not None, 'Either trainer or model is needed for validation'
            model = AutoBackend(model,
                                device=select_device(self.args.device, self.args.batch),
                                dnn=self.args.dnn,
                                data=self.args.data,
                                fp16=self.args.half)
            self.model = model
            self.device = model.device  # update device
            self.args.half = model.fp16  # update half
            stride, pt, jit, engine = model.stride, model.pt, model.jit, model.engine
            imgsz = check_imgsz(self.args.imgsz, stride=stride)
            if engine:
                self.args.batch = model.batch_size
            elif not pt and not jit:
                self.args.batch = 1  # export.py models default to batch-size 1
                LOGGER.info(f'Forcing batch=1 square inference (1,3,{imgsz},{imgsz}) for non-PyTorch models')

            if isinstance(self.args.data, str) and self.args.data.endswith('.yaml'):
                self.data = check_det_dataset(self.args.data)
            elif self.args.task == 'classify':
                self.data = check_cls_dataset(self.args.data, split=self.args.split)
            else:
                raise FileNotFoundError(emojis(f"Dataset '{self.args.data}' for task={self.args.task} not found ❌"))

            if self.device.type == 'cpu':
                self.args.workers = 0  # faster CPU val as time dominated by inference, not dataloading
            if not pt:
                self.args.rect = False
            self.dataloader = self.dataloader or self.get_dataloader(self.data.get(self.args.split), self.args.batch)

            model.eval()
            model.warmup(imgsz=(1 if pt else self.args.batch, 3, imgsz, imgsz))  # warmup

        dt = Profile(), Profile(), Profile(), Profile()
        n_batches = len(self.dataloader)
        desc = self.get_desc()
        # NOTE: keeping `not self.training` in tqdm will eliminate pbar after segmentation evaluation during training,
        # which may affect classification task since this arg is in yolov5/classify/val.py.
        # bar = tqdm(self.dataloader, desc, n_batches, not self.training, bar_format=TQDM_BAR_FORMAT)
        bar = tqdm(self.dataloader, desc, n_batches, bar_format=TQDM_BAR_FORMAT)
        self.init_metrics(de_parallel(model))
        self.jdict = []  # empty before each val
        for batch_i, batch in enumerate(bar):
            self.run_callbacks('on_val_batch_start')
            self.batch_i = batch_i
            # Preprocess
            with dt[0]:
                batch = self.preprocess(batch)

            # Inference
            with dt[1]:
                preds = model(batch['img'], augment=self.args.augment)

            # Loss
            with dt[2]:
                if self.training:
                    self.loss += model.loss(batch, preds)[1]

            # Postprocess
            with dt[3]:
                preds = self.postprocess(preds)

            self.update_metrics(preds, batch)
            if self.args.plots and batch_i < 3:
                self.plot_val_samples(batch, batch_i)
                self.plot_predictions(batch, preds, batch_i)

            self.run_callbacks('on_val_batch_end')
        stats = self.get_stats()
        self.check_stats(stats)
        self.speed = dict(zip(self.speed.keys(), (x.t / len(self.dataloader.dataset) * 1E3 for x in dt)))
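        # Worked example (illustrative, values invented): each Profile accumulates total seconds per
        # stage, so x.t / len(dataset) * 1E3 converts to ms per image, e.g. 0.64 s of inference over
        # 640 images is reported as 1.0 ms/image under 'inference'.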
        self.finalize_metrics()
        self.print_results()
        self.run_callbacks('on_val_end')
        if self.training:
            model.float()
            results = {**stats, **trainer.label_loss_items(self.loss.cpu() / len(self.dataloader), prefix='val')}
            return {k: round(float(v), 5) for k, v in results.items()}  # return results as 5 decimal place floats
        else:
            LOGGER.info('Speed: %.1fms preprocess, %.1fms inference, %.1fms loss, %.1fms postprocess per image' %
                        tuple(self.speed.values()))
            if self.args.save_json and self.jdict:
                with open(str(self.save_dir / 'predictions.json'), 'w') as f:
                    LOGGER.info(f'Saving {f.name}...')
                    json.dump(self.jdict, f)  # flatten and save
                stats = self.eval_json(stats)  # update stats
            if self.args.plots or self.args.save_json:
                LOGGER.info(f"Results saved to {colorstr('bold', self.save_dir)}")
            return stats

    def add_callback(self, event: str, callback):
        """Appends the given callback."""
        self.callbacks[event].append(callback)

    def run_callbacks(self, event: str):
        """Runs all callbacks associated with a specified event."""
        for callback in self.callbacks.get(event, []):
            callback(self)

    def get_dataloader(self, dataset_path, batch_size):
        """Get data loader from dataset path and batch size."""
        raise NotImplementedError('get_dataloader function not implemented for this validator')

    def build_dataset(self, img_path):
        """Build dataset."""
        raise NotImplementedError('build_dataset function not implemented in validator')

    def preprocess(self, batch):
        """Preprocesses an input batch."""
        return batch

    def postprocess(self, preds):
        """Post-processes the raw model predictions; the base implementation returns them unchanged."""
        return preds

    def init_metrics(self, model):
        """Initialize performance metrics for the YOLO model."""
        pass

    def update_metrics(self, preds, batch):
        """Updates metrics based on predictions and batch."""
        pass

    def finalize_metrics(self, *args, **kwargs):
        """Finalizes and returns all metrics."""
        pass

    def get_stats(self):
        """Returns statistics about the model's performance."""
        return {}

    def check_stats(self, stats):
        """Checks statistics."""
        pass

    def print_results(self):
        """Prints the results of the model's predictions."""
        pass

    def get_desc(self):
        """Get description of the YOLO model."""
        pass

    @property
    def metric_keys(self):
        """Returns the metric keys used in YOLO training/validation."""
        return []

    def on_plot(self, name, data=None):
        """Registers plots (e.g. to be consumed in callbacks)."""
        self.plots[name] = {'data': data, 'timestamp': time.time()}

    # TODO: may need to put these following functions into callback
    def plot_val_samples(self, batch, ni):
        """Plots validation samples during training."""
        pass

    def plot_predictions(self, batch, preds, ni):
        """Plots YOLO model predictions on batch images."""
        pass

    def pred_to_json(self, preds, batch):
        """Convert predictions to JSON format."""
        pass

    def eval_json(self, stats):
        """Evaluate and return JSON format of prediction statistics."""
        pass