Start export implementation (#110)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
single_channel
Glenn Jocher 2 years ago committed by GitHub
parent c1b38428bc
commit 92dad1c1b5
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -92,12 +92,12 @@ jobs:
run: | run: |
yolo task=detect mode=train model=yolov8n.yaml data=coco128.yaml epochs=1 imgsz=64 yolo task=detect mode=train model=yolov8n.yaml data=coco128.yaml epochs=1 imgsz=64
yolo task=detect mode=val model=runs/detect/train/weights/last.pt imgsz=64 yolo task=detect mode=val model=runs/detect/train/weights/last.pt imgsz=64
- name: Test segmentation # TODO: segmentation CI - name: Test segmentation
shell: bash # for Windows compatibility shell: bash # for Windows compatibility
run: | run: |
# yolo task=segment mode=train model=yolov8n-seg.yaml data=coco128-seg.yaml epochs=1 imgsz=64 yolo task=segment mode=train model=yolov8n-seg.yaml data=coco128-seg.yaml epochs=1 imgsz=64
# yolo task=segment mode=val model=runs/segment/train/weights/last.pt data=coco128-seg.yaml imgsz=64 yolo task=segment mode=val model=runs/segment/train/weights/last.pt data=coco128-seg.yaml imgsz=64
- name: Test classification # TODO: change to exp3 on Segmentation CI update - name: Test classification
shell: bash # for Windows compatibility shell: bash # for Windows compatibility
run: | run: |
yolo task=classify mode=train model=resnet18 data=mnist160 epochs=1 imgsz=32 yolo task=classify mode=train model=resnet18 data=mnist160 epochs=1 imgsz=32

@ -1,49 +0,0 @@
# Ultralytics, GPL-3.0 license
# Parameters
nc: 80 # number of classes
depth_multiple: 0.33 # model depth multiple
width_multiple: 0.50 # layer channel multiple
anchors:
- [10,13, 16,30, 33,23] # P3/8
- [30,61, 62,45, 59,119] # P4/16
- [116,90, 156,198, 373,326] # P5/32
# YOLOv5 v6.0 backbone
backbone:
# [from, number, module, args]
[[-1, 1, Conv, [64, 6, 2, 2]], # 0-P1/2
[-1, 1, Conv, [128, 3, 2]], # 1-P2/4
[-1, 3, C3, [128]],
[-1, 1, Conv, [256, 3, 2]], # 3-P3/8
[-1, 6, C3, [256]],
[-1, 1, Conv, [512, 3, 2]], # 5-P4/16
[-1, 9, C3, [512]],
[-1, 1, Conv, [1024, 3, 2]], # 7-P5/32
[-1, 3, C3, [1024]],
[-1, 1, SPPF, [1024, 5]], # 9
]
# YOLOv5 v6.0 head
head:
[[-1, 1, Conv, [512, 1, 1]],
[-1, 1, nn.Upsample, [None, 2, 'nearest']],
[[-1, 6], 1, Concat, [1]], # cat backbone P4
[-1, 3, C3, [512, False]], # 13
[-1, 1, Conv, [256, 1, 1]],
[-1, 1, nn.Upsample, [None, 2, 'nearest']],
[[-1, 4], 1, Concat, [1]], # cat backbone P3
[-1, 3, C3, [256, False]], # 17 (P3/8-small)
[-1, 1, Conv, [256, 3, 2]],
[[-1, 14], 1, Concat, [1]], # cat head P4
[-1, 3, C3, [512, False]], # 20 (P4/16-medium)
[-1, 1, Conv, [512, 3, 2]],
[[-1, 10], 1, Concat, [1]], # cat head P5
[-1, 3, C3, [1024, False]], # 23 (P5/32-large)
[[17, 20, 23], 1, Detect, [nc, anchors]], # Detect(P3, P4, P5)
]

@ -1,64 +1,16 @@
import torch
from ultralytics import YOLO from ultralytics import YOLO
from ultralytics.nn.modules import Detect, Segment
def export_onnx(model, file):
# YOLOv5 ONNX export
import onnx
im = torch.zeros(1, 3, 640, 640)
model.eval()
model(im, profile=True)
for k, m in model.named_modules():
if isinstance(m, (Detect, Segment)):
m.export = True
torch.onnx.export(
model,
im,
file,
verbose=False,
opset_version=12,
do_constant_folding=True, # WARNING: DNN inference with torch>=1.12 may require do_constant_folding=False
input_names=['images'])
# Checks
model_onnx = onnx.load(file) # load onnx model
onnx.checker.check_model(model_onnx) # check onnx model
# Metadata
d = {'stride': int(max(model.stride)), 'names': model.names}
for k, v in d.items():
meta = model_onnx.metadata_props.add()
meta.key, meta.value = k, str(v)
onnx.save(model_onnx, file)
if __name__ == "__main__": if __name__ == "__main__":
model = YOLO() YOLO.new("yolov8n.yaml")
print("yolov8n") YOLO.new("yolov8n-seg.yaml")
model.new("yolov8n.yaml") YOLO.new("yolov8s.yaml")
print("yolov8n-seg") YOLO.new("yolov8s-seg.yaml")
model.new("yolov8n-seg.yaml") YOLO.new("yolov8m.yaml")
print("yolov8s") YOLO.new("yolov8m-seg.yaml")
model.new("yolov8s.yaml") YOLO.new("yolov8l.yaml")
# export_onnx(model.model, "yolov8s.onnx") YOLO.new("yolov8l-seg.yaml")
print("yolov8s-seg") YOLO.new("yolov8x.yaml")
model.new("yolov8s-seg.yaml") YOLO.new("yolov8x-seg.yaml")
# export_onnx(model.model, "yolov8s-seg.onnx")
print("yolov8m")
model.new("yolov8m.yaml")
print("yolov8m-seg")
model.new("yolov8m-seg.yaml")
print("yolov8l")
model.new("yolov8l.yaml")
print("yolov8l-seg")
model.new("yolov8l-seg.yaml")
print("yolov8x")
model.new("yolov8x.yaml")
print("yolov8x-seg")
model.new("yolov8x-seg.yaml")
# n vs n-seg: 8.9GFLOPs vs 12.8GFLOPs, 3.16M vs 3.6M. ch[0] // 4 (11.9GFLOPs, 3.39M) # n vs n-seg: 8.9GFLOPs vs 12.8GFLOPs, 3.16M vs 3.6M. ch[0] // 4 (11.9GFLOPs, 3.39M)
# s vs s-seg: 28.8GFLOPs vs 44.4GFLOPs, 11.1M vs 12.9M. ch[0] // 4 (39.5GFLOPs, 11.7M) # s vs s-seg: 28.8GFLOPs vs 44.4GFLOPs, 11.1M vs 12.9M. ch[0] // 4 (39.5GFLOPs, 11.7M)

@ -2,11 +2,9 @@ import cv2
import hydra import hydra
from ultralytics.yolo.data import build_dataloader from ultralytics.yolo.data import build_dataloader
from ultralytics.yolo.utils import ROOT from ultralytics.yolo.utils import DEFAULT_CONFIG
from ultralytics.yolo.utils.plotting import plot_images from ultralytics.yolo.utils.plotting import plot_images
DEFAULT_CONFIG = ROOT / "yolo/utils/configs/default.yaml"
class Colors: class Colors:
# Ultralytics color palette https://ultralytics.com/ # Ultralytics color palette https://ultralytics.com/

@ -2,11 +2,9 @@ import cv2
import hydra import hydra
from ultralytics.yolo.data import build_dataloader from ultralytics.yolo.data import build_dataloader
from ultralytics.yolo.utils import ROOT from ultralytics.yolo.utils import DEFAULT_CONFIG
from ultralytics.yolo.utils.plotting import plot_images from ultralytics.yolo.utils.plotting import plot_images
DEFAULT_CONFIG = ROOT / "yolo/utils/configs/default.yaml"
class Colors: class Colors:
# Ultralytics color palette https://ultralytics.com/ # Ultralytics color palette https://ultralytics.com/

@ -3,11 +3,11 @@ from ultralytics.yolo.utils.checks import check_yaml
def test_model_parser(): def test_model_parser():
cfg = check_yaml("../assets/dummy_model.yaml") # check YAML cfg = check_yaml("yolov8n.yaml") # check YAML
# Create model # Create model
model = DetectionModel(cfg) model = DetectionModel(cfg)
print(model) model.info()
''' '''
# Options # Options
if opt.line_profile: # profile layer by layer if opt.line_profile: # profile layer by layer

@ -62,6 +62,35 @@ def test_model_train_pretrained():
model(img) model(img)
def test_exports():
"""
Format Argument Suffix CPU GPU
0 PyTorch - .pt True True
1 TorchScript torchscript .torchscript True True
2 ONNX onnx .onnx True True
3 OpenVINO openvino _openvino_model True False
4 TensorRT engine .engine False True
5 CoreML coreml .mlmodel True False
6 TensorFlow SavedModel saved_model _saved_model True True
7 TensorFlow GraphDef pb .pb True True
8 TensorFlow Lite tflite .tflite True False
9 TensorFlow Edge TPU edgetpu _edgetpu.tflite False False
10 TensorFlow.js tfjs _web_model False False
11 PaddlePaddle paddle _paddle_model True True
"""
from ultralytics import YOLO
from ultralytics.yolo.engine.exporter import export_formats
print(export_formats())
model = YOLO.new("yolov8n.yaml")
model.export(format='torchscript')
model.export(format='onnx')
model.export(format='openvino')
model.export(format='coreml')
model.export(format='paddle')
def test(): def test():
test_model_forward() test_model_forward()
test_model_info() test_model_info()

@ -19,7 +19,7 @@ from ultralytics.yolo.utils.ops import xywh2xyxy
class AutoBackend(nn.Module): class AutoBackend(nn.Module):
# YOLOv5 MultiBackend class for python inference on various backends # YOLOv5 MultiBackend class for python inference on various backends
def __init__(self, weights='yolov5s.pt', device=torch.device('cpu'), dnn=False, data=None, fp16=False, fuse=True): def __init__(self, weights='yolov8n.pt', device=torch.device('cpu'), dnn=False, data=None, fp16=False, fuse=True):
# Usage: # Usage:
# PyTorch: weights = *.pt # PyTorch: weights = *.pt
# TorchScript: *.torchscript # TorchScript: *.torchscript

@ -6,12 +6,12 @@ import thop
import torch import torch
import torch.nn as nn import torch.nn as nn
import torchvision import torchvision
import yaml
from ultralytics.nn.modules import (C1, C2, C3, C3TR, SPP, SPPF, Bottleneck, BottleneckCSP, C2f, C3Ghost, C3x, Classify, from ultralytics.nn.modules import (C1, C2, C3, C3TR, SPP, SPPF, Bottleneck, BottleneckCSP, C2f, C3Ghost, C3x, Classify,
Concat, Conv, ConvTranspose, Detect, DWConv, DWConvTranspose2d, Ensemble, Focus, Concat, Conv, ConvTranspose, Detect, DWConv, DWConvTranspose2d, Ensemble, Focus,
GhostBottleneck, GhostConv, Segment) GhostBottleneck, GhostConv, Segment)
from ultralytics.yolo.utils import LOGGER, colorstr from ultralytics.yolo.utils import LOGGER, colorstr
from ultralytics.yolo.utils.files import yaml_load
from ultralytics.yolo.utils.torch_utils import (fuse_conv_and_bn, initialize_weights, intersect_state_dicts, from ultralytics.yolo.utils.torch_utils import (fuse_conv_and_bn, initialize_weights, intersect_state_dicts,
make_divisible, model_info, scale_img, time_sync) make_divisible, model_info, scale_img, time_sync)
@ -78,14 +78,9 @@ class BaseModel(nn.Module):
class DetectionModel(BaseModel): class DetectionModel(BaseModel):
# YOLOv5 detection model # YOLOv5 detection model
def __init__(self, cfg='yolov5s.yaml', ch=3, nc=None, verbose=True): # model, input channels, number of classes def __init__(self, cfg='yolov8n.yaml', ch=3, nc=None, verbose=True): # model, input channels, number of classes
super().__init__() super().__init__()
if isinstance(cfg, dict): self.yaml = cfg if isinstance(cfg, dict) else yaml_load(cfg) # cfg dict
self.yaml = cfg # model dict
else: # is *.yaml
self.yaml_file = Path(cfg).name
with open(cfg, encoding='ascii', errors='ignore') as f:
self.yaml = yaml.safe_load(f) # model dict
# Define model # Define model
ch = self.yaml['ch'] = self.yaml.get('ch', ch) # input channels ch = self.yaml['ch'] = self.yaml.get('ch', ch) # input channels
@ -163,7 +158,7 @@ class DetectionModel(BaseModel):
class SegmentationModel(DetectionModel): class SegmentationModel(DetectionModel):
# YOLOv5 segmentation model # YOLOv5 segmentation model
def __init__(self, cfg='yolov5s-seg.yaml', ch=3, nc=None, verbose=True): def __init__(self, cfg='yolov8n-seg.yaml', ch=3, nc=None, verbose=True):
super().__init__(cfg, ch, nc, verbose) super().__init__(cfg, ch, nc, verbose)

@ -1,43 +1,48 @@
import os
import shutil import shutil
from pathlib import Path
import hydra import hydra
import ultralytics import ultralytics
import ultralytics.yolo.v8 as yolo from ultralytics import yolo
from ultralytics.yolo.engine.trainer import DEFAULT_CONFIG
from .utils import LOGGER, colorstr from .utils import DEFAULT_CONFIG, LOGGER, colorstr
@hydra.main(version_base=None, config_path="utils/configs", config_name="default") @hydra.main(version_base=None, config_path="configs", config_name="default")
def cli(cfg): def cli(cfg):
cwd = Path().cwd()
LOGGER.info(f"{colorstr(f'Ultralytics YOLO v{ultralytics.__version__}')}") LOGGER.info(f"{colorstr(f'Ultralytics YOLO v{ultralytics.__version__}')}")
task, mode = cfg.task.lower(), cfg.mode.lower() task, mode = cfg.task.lower(), cfg.mode.lower()
if task == "init": # special case if task == "init": # special case
shutil.copy2(DEFAULT_CONFIG, os.getcwd()) shutil.copy2(DEFAULT_CONFIG, cwd)
LOGGER.info(f""" LOGGER.info(f"""
{colorstr("YOLO :")} configuration saved to {os.getcwd()}/{DEFAULT_CONFIG.name}. {colorstr("YOLO:")} configuration saved to {cwd / DEFAULT_CONFIG.name}.
To run experiments using custom configuration: To run experiments using custom configuration:
yolo task='task' mode='mode' --config-name config_file.yaml yolo task='task' mode='mode' --config-name config_file.yaml
""") """)
return return
elif task == "detect": elif task == "detect":
module_file = yolo.detect module = yolo.v8.detect
elif task == "segment": elif task == "segment":
module_file = yolo.segment module = yolo.v8.segment
elif task == "classify": elif task == "classify":
module_file = yolo.classify module = yolo.v8.classify
elif task == "export":
func = yolo.trainer.exporter.export_model
else: else:
raise SyntaxError("task not recognized. Choices are `'detect', 'segment', 'classify'`") raise SyntaxError("task not recognized. Choices are `'detect', 'segment', 'classify'`")
if mode == "train": if mode == "train":
module_function = module_file.train func = module.train
elif mode == "val": elif mode == "val":
module_function = module_file.val func = module.val
elif mode == "predict": elif mode == "predict":
module_function = module_file.predict func = module.predict
elif mode == "export":
func = yolo.trainer.exporter.export_model
else: else:
raise SyntaxError("mode not recognized. Choices are `'train', 'val', 'predict'`") raise SyntaxError("mode not recognized. Choices are `'train', 'val', 'predict', 'export'`")
module_function(cfg) func(cfg)

@ -3,7 +3,7 @@ from typing import Dict, Union
from omegaconf import DictConfig, OmegaConf from omegaconf import DictConfig, OmegaConf
from ultralytics.yolo.utils.configs.hydra_patch import check_config_mismatch from ultralytics.yolo.configs.hydra_patch import check_config_mismatch
def get_config(config: Union[str, DictConfig], overrides: Union[str, Dict] = {}): def get_config(config: Union[str, DictConfig], overrides: Union[str, Dict] = {}):

@ -44,11 +44,11 @@ save_hybrid: False
conf_thres: 0.001 conf_thres: 0.001
iou_thres: 0.7 iou_thres: 0.7
max_det: 300 max_det: 300
half: True half: False
dnn: False # use OpenCV DNN for ONNX inference dnn: False # use OpenCV DNN for ONNX inference
plots: True plots: True
# Prediction settings: # Prediction settings --------------------------------------------------------------------------------------------------
source: "ultralytics/assets/" source: "ultralytics/assets/"
view_img: False view_img: False
save_txt: False save_txt: False
@ -64,6 +64,15 @@ augment: False
agnostic_nms: False # class-agnostic NMS agnostic_nms: False # class-agnostic NMS
retina_masks: False retina_masks: False
# Export settings ------------------------------------------------------------------------------------------------------
keras: False # use Keras
optimize: False # TorchScript: optimize for mobile
int8: False # CoreML/TF INT8 quantization
dynamic: False # ONNX/TF/TensorRT: dynamic axes
simplify: False # ONNX: simplify model
opset: 17 # ONNX: opset version
workspace: 4 # TensorRT: workspace size (GB)
# Hyperparameters ------------------------------------------------------------------------------------------------------ # Hyperparameters ------------------------------------------------------------------------------------------------------
lr0: 0.01 # initial learning rate (SGD=1E-2, Adam=1E-3) lr0: 0.01 # initial learning rate (SGD=1E-2, Adam=1E-3)
lrf: 0.01 # final OneCycleLR learning rate (lr0 * lrf) lrf: 0.01 # final OneCycleLR learning rate (lr0 * lrf)
@ -93,7 +102,7 @@ mixup: 0.0 # image mixup (probability)
copy_paste: 0.0 # segment copy-paste (probability) copy_paste: 0.0 # segment copy-paste (probability)
# For debugging. Don't change # For debugging. Don't change
v5loader: True v5loader: False
# Hydra configs -------------------------------------------------------------------------------------------------------- # Hydra configs --------------------------------------------------------------------------------------------------------
hydra: hydra:

@ -4,8 +4,8 @@ from textwrap import dedent
import hydra import hydra
from hydra.errors import ConfigCompositionException from hydra.errors import ConfigCompositionException
from omegaconf import OmegaConf, open_dict from omegaconf import OmegaConf, open_dict # noqa
from omegaconf.errors import ConfigAttributeError, ConfigKeyError, OmegaConfBaseException from omegaconf.errors import ConfigAttributeError, ConfigKeyError, OmegaConfBaseException # noqa
from ultralytics.yolo.utils import LOGGER, colorstr from ultralytics.yolo.utils import LOGGER, colorstr
@ -16,8 +16,7 @@ def override_config(overrides, cfg):
for override in overrides: for override in overrides:
if override.package is not None: if override.package is not None:
raise ConfigCompositionException(f"Override {override.input_line} looks like a config group" raise ConfigCompositionException(f"Override {override.input_line} looks like a config group"
f" override, but config group '{override.key_or_group}' does not" f" override, but config group '{override.key_or_group}' does not exist.")
" exist.")
key = override.key_or_group key = override.key_or_group
value = override.value() value = override.value()
@ -37,7 +36,7 @@ def override_config(overrides, cfg):
if last_dot == -1: if last_dot == -1:
del cfg[key] del cfg[key]
else: else:
node = OmegaConf.select(cfg, key[0:last_dot]) node = OmegaConf.select(cfg, key[:last_dot])
del node[key[last_dot + 1:]] del node[key[last_dot + 1:]]
elif override.is_add(): elif override.is_add():
@ -65,10 +64,7 @@ def override_config(overrides, cfg):
def check_config_mismatch(overrides, cfg): def check_config_mismatch(overrides, cfg):
mismatched = [] mismatched = [option for option in overrides if option not in cfg and 'hydra.' not in option]
for option in overrides:
if option not in cfg and 'hydra.' not in option:
mismatched.append(option)
for option in mismatched: for option in mismatched:
LOGGER.info(f"{colorstr(option)} is not a valid key. Similar keys: {get_close_matches(option, cfg, 3, 0.6)}") LOGGER.info(f"{colorstr(option)} is not a valid key. Similar keys: {get_close_matches(option, cfg, 3, 0.6)}")

@ -192,7 +192,7 @@ def polygons2masks_overlap(imgsz, segments, downsample_ratio=1):
def check_dataset_yaml(data, autodownload=True): def check_dataset_yaml(data, autodownload=True):
# Download, check and/or unzip dataset if not found locally # Download, check and/or unzip dataset if not found locally
data = check_file(data) data = check_file(data)
DATASETS_DIR = Path.cwd() / "../datasets" # TODO: handle global dataset dir DATASETS_DIR = (Path.cwd() / "../datasets").resolve() # TODO: handle global dataset dir
# Download (optional) # Download (optional)
extract_dir = '' extract_dir = ''
if isinstance(data, (str, Path)) and (is_zipfile(data) or is_tarfile(data)): if isinstance(data, (str, Path)) and (is_zipfile(data) or is_tarfile(data)):

@ -1,4 +1,77 @@
# YOLOv5 🚀 by Ultralytics, GPL-3.0 license
"""
Export a YOLOv5 PyTorch model to other formats. TensorFlow exports authored by https://github.com/zldrobit
Format | `export.py --include` | Model
--- | --- | ---
PyTorch | - | yolov8n.pt
TorchScript | `torchscript` | yolov8n.torchscript
ONNX | `onnx` | yolov8n.onnx
OpenVINO | `openvino` | yolov5s_openvino_model/
TensorRT | `engine` | yolov8n.engine
CoreML | `coreml` | yolov8n.mlmodel
TensorFlow SavedModel | `saved_model` | yolov5s_saved_model/
TensorFlow GraphDef | `pb` | yolov8n.pb
TensorFlow Lite | `tflite` | yolov8n.tflite
TensorFlow Edge TPU | `edgetpu` | yolov5s_edgetpu.tflite
TensorFlow.js | `tfjs` | yolov5s_web_model/
PaddlePaddle | `paddle` | yolov5s_paddle_model/
Requirements:
$ pip install -r requirements.txt coremltools onnx onnx-simplifier onnxruntime openvino-dev tensorflow-cpu # CPU
$ pip install -r requirements.txt coremltools onnx onnx-simplifier onnxruntime-gpu openvino-dev tensorflow # GPU
Usage:
$ python export.py --weights yolov8n.pt --include torchscript onnx openvino engine coreml tflite ...
Inference:
$ python detect.py --weights yolov8n.pt # PyTorch
yolov8n.torchscript # TorchScript
yolov8n.onnx # ONNX Runtime or OpenCV DNN with --dnn
yolov5s_openvino_model # OpenVINO
yolov8n.engine # TensorRT
yolov8n.mlmodel # CoreML (macOS-only)
yolov5s_saved_model # TensorFlow SavedModel
yolov8n.pb # TensorFlow GraphDef
yolov8n.tflite # TensorFlow Lite
yolov5s_edgetpu.tflite # TensorFlow Edge TPU
yolov5s_paddle_model # PaddlePaddle
TensorFlow.js:
$ cd .. && git clone https://github.com/zldrobit/tfjs-yolov5-example.git && cd tfjs-yolov5-example
$ npm install
$ ln -s ../../yolov5/yolov5s_web_model public/yolov5s_web_model
$ npm start
from ultralytics import YOLO
model = YOLO().new('yolov8n.yaml')
results = model.export(format='onnx')
"""
import contextlib
import json
import os
import platform
import re
import subprocess
import time
import warnings
from copy import deepcopy
from pathlib import Path
import pandas as pd import pandas as pd
import torch
from torch.utils.mobile_optimizer import optimize_for_mobile
from ultralytics.nn.modules import Detect, Segment
from ultralytics.nn.tasks import ClassificationModel, DetectionModel, SegmentationModel
from ultralytics.yolo.utils import LOGGER, ROOT, colorstr, get_default_args
from ultralytics.yolo.utils.checks import check_imgsz, check_requirements, check_version
from ultralytics.yolo.utils.files import file_size, yaml_save
from ultralytics.yolo.utils.ops import Profile
from ultralytics.yolo.utils.torch_utils import select_device, smart_inference_mode
MACOS = platform.system() == 'Darwin' # macOS environment
def export_formats(): def export_formats():
@ -17,3 +90,519 @@ def export_formats():
['TensorFlow.js', 'tfjs', '_web_model', False, False], ['TensorFlow.js', 'tfjs', '_web_model', False, False],
['PaddlePaddle', 'paddle', '_paddle_model', True, True],] ['PaddlePaddle', 'paddle', '_paddle_model', True, True],]
return pd.DataFrame(x, columns=['Format', 'Argument', 'Suffix', 'CPU', 'GPU']) return pd.DataFrame(x, columns=['Format', 'Argument', 'Suffix', 'CPU', 'GPU'])
def try_export(inner_func):
# YOLOv5 export decorator, i..e @try_export
inner_args = get_default_args(inner_func)
def outer_func(*args, **kwargs):
prefix = inner_args['prefix']
try:
with Profile() as dt:
f, model = inner_func(*args, **kwargs)
LOGGER.info(f'{prefix} export success ✅ {dt.t:.1f}s, saved as {f} ({file_size(f):.1f} MB)')
return f, model
except Exception as e:
LOGGER.info(f'{prefix} export failure ❌ {dt.t:.1f}s: {e}')
return None, None
return outer_func
@try_export
def export_torchscript(model, im, file, optimize, prefix=colorstr('TorchScript:')):
# YOLOv5 TorchScript model export
LOGGER.info(f'\n{prefix} starting export with torch {torch.__version__}...')
f = file.with_suffix('.torchscript')
ts = torch.jit.trace(model, im, strict=False)
d = {"shape": im.shape, "stride": int(max(model.stride)), "names": model.names}
extra_files = {'config.txt': json.dumps(d)} # torch._C.ExtraFilesMap()
if optimize: # https://pytorch.org/tutorials/recipes/mobile_interpreter.html
optimize_for_mobile(ts)._save_for_lite_interpreter(str(f), _extra_files=extra_files)
else:
ts.save(str(f), _extra_files=extra_files)
return f, None
@try_export
def export_onnx(model, im, file, opset, dynamic, simplify, prefix=colorstr('ONNX:')):
# YOLOv5 ONNX export
check_requirements('onnx>=1.12.0')
import onnx # noqa
LOGGER.info(f'\n{prefix} starting export with onnx {onnx.__version__}...')
f = file.with_suffix('.onnx')
output_names = ['output0', 'output1'] if isinstance(model, SegmentationModel) else ['output0']
if dynamic:
dynamic = {'images': {0: 'batch', 2: 'height', 3: 'width'}} # shape(1,3,640,640)
if isinstance(model, SegmentationModel):
dynamic['output0'] = {0: 'batch', 1: 'anchors'} # shape(1,25200,85)
dynamic['output1'] = {0: 'batch', 2: 'mask_height', 3: 'mask_width'} # shape(1,32,160,160)
elif isinstance(model, DetectionModel):
dynamic['output0'] = {0: 'batch', 1: 'anchors'} # shape(1,25200,85)
torch.onnx.export(
model.cpu() if dynamic else model, # --dynamic only compatible with cpu
im.cpu() if dynamic else im,
f,
verbose=False,
opset_version=opset,
do_constant_folding=True, # WARNING: DNN inference with torch>=1.12 may require do_constant_folding=False
input_names=['images'],
output_names=output_names,
dynamic_axes=dynamic or None)
# Checks
model_onnx = onnx.load(f) # load onnx model
onnx.checker.check_model(model_onnx) # check onnx model
# Metadata
d = {'stride': int(max(model.stride)), 'names': model.names}
for k, v in d.items():
meta = model_onnx.metadata_props.add()
meta.key, meta.value = k, str(v)
onnx.save(model_onnx, f)
# Simplify
if simplify:
try:
cuda = torch.cuda.is_available()
check_requirements(('onnxruntime-gpu' if cuda else 'onnxruntime', 'onnx-simplifier>=0.4.1'))
import onnxsim
LOGGER.info(f'{prefix} simplifying with onnx-simplifier {onnxsim.__version__}...')
model_onnx, check = onnxsim.simplify(model_onnx)
assert check, 'assert check failed'
onnx.save(model_onnx, f)
except Exception as e:
LOGGER.info(f'{prefix} simplifier failure: {e}')
return f, model_onnx
@try_export
def export_openvino(file, metadata, half, prefix=colorstr('OpenVINO:')):
# YOLOv5 OpenVINO export
check_requirements('openvino-dev') # requires openvino-dev: https://pypi.org/project/openvino-dev/
import openvino.inference_engine as ie # noqa
LOGGER.info(f'\n{prefix} starting export with openvino {ie.__version__}...')
f = str(file).replace('.pt', f'_openvino_model{os.sep}')
cmd = f"mo --input_model {file.with_suffix('.onnx')} --output_dir {f} --data_type {'FP16' if half else 'FP32'}"
subprocess.run(cmd.split(), check=True, env=os.environ) # export
yaml_save(Path(f) / file.with_suffix('.yaml').name, metadata) # add metadata.yaml
return f, None
@try_export
def export_paddle(model, im, file, metadata, prefix=colorstr('PaddlePaddle:')):
# YOLOv5 Paddle export
check_requirements(('paddlepaddle', 'x2paddle'))
import x2paddle # noqa
from x2paddle.convert import pytorch2paddle # noqa
LOGGER.info(f'\n{prefix} starting export with X2Paddle {x2paddle.__version__}...')
f = str(file).replace('.pt', f'_paddle_model{os.sep}')
pytorch2paddle(module=model, save_dir=f, jit_type='trace', input_examples=[im]) # export
yaml_save(Path(f) / file.with_suffix('.yaml').name, metadata) # add metadata.yaml
return f, None
@try_export
def export_coreml(model, im, file, int8, half, prefix=colorstr('CoreML:')):
# YOLOv5 CoreML export
check_requirements('coremltools')
import coremltools as ct # noqa
LOGGER.info(f'\n{prefix} starting export with coremltools {ct.__version__}...')
f = file.with_suffix('.mlmodel')
ts = torch.jit.trace(model, im, strict=False) # TorchScript model
ct_model = ct.convert(ts, inputs=[ct.ImageType('image', shape=im.shape, scale=1 / 255, bias=[0, 0, 0])])
bits, mode = (8, 'kmeans_lut') if int8 else (16, 'linear') if half else (32, None)
if bits < 32:
if MACOS: # quantization only supported on macOS
ct_model = ct.models.neural_network.quantization_utils.quantize_weights(ct_model, bits, mode)
else:
LOGGER.info(f'{prefix} quantization only supported on macOS, skipping...')
ct_model.save(f)
return f, ct_model
@try_export
def export_engine(model, im, file, half, dynamic, simplify, workspace=4, verbose=False, prefix=colorstr('TensorRT:')):
# YOLOv5 TensorRT export https://developer.nvidia.com/tensorrt
assert im.device.type != 'cpu', 'export running on CPU but must be on GPU, i.e. `python export.py --device 0`'
try:
import tensorrt as trt
except Exception:
if platform.system() == 'Linux':
check_requirements('nvidia-tensorrt', cmds='-U --index-url https://pypi.ngc.nvidia.com')
import tensorrt as trt
if trt.__version__[0] == '7': # TensorRT 7 handling https://github.com/ultralytics/yolov5/issues/6012
grid = model.model[-1].anchor_grid
model.model[-1].anchor_grid = [a[..., :1, :1, :] for a in grid]
export_onnx(model, im, file, 12, dynamic, simplify) # opset 12
model.model[-1].anchor_grid = grid
else: # TensorRT >= 8
check_version(trt.__version__, '8.0.0', hard=True) # require tensorrt>=8.0.0
export_onnx(model, im, file, 12, dynamic, simplify) # opset 12
onnx = file.with_suffix('.onnx')
LOGGER.info(f'\n{prefix} starting export with TensorRT {trt.__version__}...')
assert onnx.exists(), f'failed to export ONNX file: {onnx}'
f = file.with_suffix('.engine') # TensorRT engine file
logger = trt.Logger(trt.Logger.INFO)
if verbose:
logger.min_severity = trt.Logger.Severity.VERBOSE
builder = trt.Builder(logger)
config = builder.create_builder_config()
config.max_workspace_size = workspace * 1 << 30
# config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, workspace << 30) # fix TRT 8.4 deprecation notice
flag = (1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH))
network = builder.create_network(flag)
parser = trt.OnnxParser(network, logger)
if not parser.parse_from_file(str(onnx)):
raise RuntimeError(f'failed to load ONNX file: {onnx}')
inputs = [network.get_input(i) for i in range(network.num_inputs)]
outputs = [network.get_output(i) for i in range(network.num_outputs)]
for inp in inputs:
LOGGER.info(f'{prefix} input "{inp.name}" with shape{inp.shape} {inp.dtype}')
for out in outputs:
LOGGER.info(f'{prefix} output "{out.name}" with shape{out.shape} {out.dtype}')
if dynamic:
if im.shape[0] <= 1:
LOGGER.warning(f"{prefix} WARNING ⚠️ --dynamic model requires maximum --batch-size argument")
profile = builder.create_optimization_profile()
for inp in inputs:
profile.set_shape(inp.name, (1, *im.shape[1:]), (max(1, im.shape[0] // 2), *im.shape[1:]), im.shape)
config.add_optimization_profile(profile)
LOGGER.info(f'{prefix} building FP{16 if builder.platform_has_fast_fp16 and half else 32} engine as {f}')
if builder.platform_has_fast_fp16 and half:
config.set_flag(trt.BuilderFlag.FP16)
with builder.build_engine(network, config) as engine, open(f, 'wb') as t:
t.write(engine.serialize())
return f, None
@try_export
def export_saved_model(model,
im,
file,
dynamic,
tf_nms=False,
agnostic_nms=False,
topk_per_class=100,
topk_all=100,
iou_thres=0.45,
conf_thres=0.25,
keras=False,
prefix=colorstr('TensorFlow SavedModel:')):
# YOLOv5 TensorFlow SavedModel export
try:
import tensorflow as tf
except Exception:
check_requirements(f"tensorflow{'' if torch.cuda.is_available() else '-macos' if MACOS else '-cpu'}")
import tensorflow as tf
from models.tf import TFModel
from tensorflow.python.framework.convert_to_constants import convert_variables_to_constants_v2 # noqa
LOGGER.info(f'\n{prefix} starting export with tensorflow {tf.__version__}...')
f = str(file).replace('.pt', '_saved_model')
batch_size, ch, *imgsz = list(im.shape) # BCHW
tf_model = TFModel(cfg=model.yaml, model=model, nc=model.nc, imgsz=imgsz)
im = tf.zeros((batch_size, *imgsz, ch)) # BHWC order for TensorFlow
_ = tf_model.predict(im, tf_nms, agnostic_nms, topk_per_class, topk_all, iou_thres, conf_thres)
inputs = tf.keras.Input(shape=(*imgsz, ch), batch_size=None if dynamic else batch_size)
outputs = tf_model.predict(inputs, tf_nms, agnostic_nms, topk_per_class, topk_all, iou_thres, conf_thres)
keras_model = tf.keras.Model(inputs=inputs, outputs=outputs)
keras_model.trainable = False
keras_model.summary()
if keras:
keras_model.save(f, save_format='tf')
else:
spec = tf.TensorSpec(keras_model.inputs[0].shape, keras_model.inputs[0].dtype)
m = tf.function(lambda x: keras_model(x)) # full model
m = m.get_concrete_function(spec)
frozen_func = convert_variables_to_constants_v2(m)
tfm = tf.Module()
tfm.__call__ = tf.function(lambda x: frozen_func(x)[:4] if tf_nms else frozen_func(x), [spec])
tfm.__call__(im)
tf.saved_model.save(tfm,
f,
options=tf.saved_model.SaveOptions(experimental_custom_gradients=False) if check_version(
tf.__version__, '2.6') else tf.saved_model.SaveOptions())
return f, keras_model
@try_export
def export_pb(keras_model, file, prefix=colorstr('TensorFlow GraphDef:')):
# YOLOv5 TensorFlow GraphDef *.pb export https://github.com/leimao/Frozen_Graph_TensorFlow
import tensorflow as tf # noqa
from tensorflow.python.framework.convert_to_constants import convert_variables_to_constants_v2 # noqa
LOGGER.info(f'\n{prefix} starting export with tensorflow {tf.__version__}...')
f = file.with_suffix('.pb')
m = tf.function(lambda x: keras_model(x)) # full model
m = m.get_concrete_function(tf.TensorSpec(keras_model.inputs[0].shape, keras_model.inputs[0].dtype))
frozen_func = convert_variables_to_constants_v2(m)
frozen_func.graph.as_graph_def()
tf.io.write_graph(graph_or_graph_def=frozen_func.graph, logdir=str(f.parent), name=f.name, as_text=False)
return f, None
@try_export
def export_tflite(keras_model, im, file, int8, data, nms, agnostic_nms, prefix=colorstr('TensorFlow Lite:')):
# YOLOv5 TensorFlow Lite export
import tensorflow as tf # noqa
LOGGER.info(f'\n{prefix} starting export with tensorflow {tf.__version__}...')
batch_size, ch, *imgsz = list(im.shape) # BCHW
f = str(file).replace('.pt', '-fp16.tflite')
converter = tf.lite.TFLiteConverter.from_keras_model(keras_model)
converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS]
converter.target_spec.supported_types = [tf.float16]
converter.optimizations = [tf.lite.Optimize.DEFAULT]
if int8:
# from models.tf import representative_dataset_gen
# dataset = LoadImages(check_dataset(check_yaml(data))['train'], imgsz=imgsz, auto=False)
# converter.representative_dataset = lambda: representative_dataset_gen(dataset, ncalib=100)
converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
converter.target_spec.supported_types = []
converter.inference_input_type = tf.uint8 # or tf.int8
converter.inference_output_type = tf.uint8 # or tf.int8
converter.experimental_new_quantizer = True
f = str(file).replace('.pt', '-int8.tflite')
if nms or agnostic_nms:
converter.target_spec.supported_ops.append(tf.lite.OpsSet.SELECT_TF_OPS)
tflite_model = converter.convert()
open(f, "wb").write(tflite_model)
return f, None
@try_export
def export_edgetpu(file, prefix=colorstr('Edge TPU:')):
# YOLOv5 Edge TPU export https://coral.ai/docs/edgetpu/models-intro/
cmd = 'edgetpu_compiler --version'
help_url = 'https://coral.ai/docs/edgetpu/compiler/'
assert platform.system() == 'Linux', f'export only supported on Linux. See {help_url}'
if subprocess.run(f'{cmd} >/dev/null', shell=True).returncode != 0:
LOGGER.info(f'\n{prefix} export requires Edge TPU compiler. Attempting install from {help_url}')
sudo = subprocess.run('sudo --version >/dev/null', shell=True).returncode == 0 # sudo installed on system
for c in (
'curl https://packages.cloud.google.com/apt/doc/apt-key.gpg | sudo apt-key add -',
'echo "deb https://packages.cloud.google.com/apt coral-edgetpu-stable main" | sudo tee /etc/apt/sources.list.d/coral-edgetpu.list',
'sudo apt-get update', 'sudo apt-get install edgetpu-compiler'):
subprocess.run(c if sudo else c.replace('sudo ', ''), shell=True, check=True)
ver = subprocess.run(cmd, shell=True, capture_output=True, check=True).stdout.decode().split()[-1]
LOGGER.info(f'\n{prefix} starting export with Edge TPU compiler {ver}...')
f = str(file).replace('.pt', '-int8_edgetpu.tflite') # Edge TPU model
f_tfl = str(file).replace('.pt', '-int8.tflite') # TFLite model
cmd = f"edgetpu_compiler -s -d -k 10 --out_dir {file.parent} {f_tfl}"
subprocess.run(cmd.split(), check=True)
return f, None
@try_export
def export_tfjs(file, prefix=colorstr('TensorFlow.js:')):
# YOLOv5 TensorFlow.js export
check_requirements('tensorflowjs')
import tensorflowjs as tfjs # noqa
LOGGER.info(f'\n{prefix} starting export with tensorflowjs {tfjs.__version__}...')
f = str(file).replace('.pt', '_web_model') # js dir
f_pb = file.with_suffix('.pb') # *.pb path
f_json = f'{f}/model.json' # *.json path
cmd = f'tensorflowjs_converter --input_format=tf_frozen_model ' \
f'--output_node_names=Identity,Identity_1,Identity_2,Identity_3 {f_pb} {f}'
subprocess.run(cmd.split())
json = Path(f_json).read_text()
with open(f_json, 'w') as j: # sort JSON Identity_* in ascending order
subst = re.sub(
r'{"outputs": {"Identity.?.?": {"name": "Identity.?.?"}, '
r'"Identity.?.?": {"name": "Identity.?.?"}, '
r'"Identity.?.?": {"name": "Identity.?.?"}, '
r'"Identity.?.?": {"name": "Identity.?.?"}}}', r'{"outputs": {"Identity": {"name": "Identity"}, '
r'"Identity_1": {"name": "Identity_1"}, '
r'"Identity_2": {"name": "Identity_2"}, '
r'"Identity_3": {"name": "Identity_3"}}}', json)
j.write(subst)
return f, None
def add_tflite_metadata(file, metadata, num_outputs):
# Add metadata to *.tflite models per https://www.tensorflow.org/lite/models/convert/metadata
with contextlib.suppress(ImportError):
# check_requirements('tflite_support')
from tflite_support import flatbuffers # noqa
from tflite_support import metadata as _metadata # noqa
from tflite_support import metadata_schema_py_generated as _metadata_fb # noqa
tmp_file = Path('/tmp/meta.txt')
with open(tmp_file, 'w') as meta_f:
meta_f.write(str(metadata))
model_meta = _metadata_fb.ModelMetadataT()
label_file = _metadata_fb.AssociatedFileT()
label_file.name = tmp_file.name
model_meta.associatedFiles = [label_file]
subgraph = _metadata_fb.SubGraphMetadataT()
subgraph.inputTensorMetadata = [_metadata_fb.TensorMetadataT()]
subgraph.outputTensorMetadata = [_metadata_fb.TensorMetadataT()] * num_outputs
model_meta.subgraphMetadata = [subgraph]
b = flatbuffers.Builder(0)
b.Finish(model_meta.Pack(b), _metadata.MetadataPopulator.METADATA_FILE_IDENTIFIER)
metadata_buf = b.Output()
populator = _metadata.MetadataPopulator.with_model_file(file)
populator.load_metadata_buffer(metadata_buf)
populator.load_associated_files([str(tmp_file)])
populator.populate()
tmp_file.unlink()
@smart_inference_mode()
def export_model(
model, # model
file=ROOT / 'yolov8n.pt',
data=ROOT / 'data/coco128.yaml', # 'dataset.yaml path'
imgsz=(640, 640), # image (height, width)
batch_size=1, # batch size
device=torch.device('cpu'), # cuda device, i.e. 0 or 0,1,2,3 or cpu
format='onnx', # export format
half=False, # FP16 half-precision export
keras=False, # use Keras
optimize=False, # TorchScript: optimize for mobile
int8=False, # CoreML/TF INT8 quantization
dynamic=False, # ONNX/TF/TensorRT: dynamic axes
simplify=False, # ONNX: simplify model
opset=17, # ONNX: opset version
verbose=False, # TensorRT: verbose log
workspace=4, # TensorRT: workspace size (GB)
nms=False, # TF: add NMS to model
agnostic_nms=False, # TF: add agnostic NMS to model
topk_per_class=100, # TF.js NMS: topk per class to keep
topk_all=100, # TF.js NMS: topk for all classes to keep
iou_thres=0.45, # TF.js NMS: IoU threshold
conf_thres=0.25, # TF.js NMS: confidence threshold
):
t = time.time()
format = format.lower() # to lowercase
fmts = tuple(export_formats()['Argument'][1:]) # available export formats
flags = [x == format for x in fmts]
assert sum(flags), f'ERROR: Invalid format={format}, valid formats are {fmts}'
jit, onnx, xml, engine, coreml, saved_model, pb, tflite, edgetpu, tfjs, paddle = flags # export booleans
# Load PyTorch model
device = select_device(device)
if half:
assert device.type != 'cpu' or coreml, '--half only compatible with GPU export, i.e. use --device 0'
assert not dynamic, '--half not compatible with --dynamic, i.e. use either --half or --dynamic but not both'
model = deepcopy(model).fuse() # load FP32 model
# Checks
if isinstance(imgsz, int):
imgsz = [imgsz]
imgsz *= 2 if len(imgsz) == 1 else 1 # expand
if optimize:
assert device.type == 'cpu', '--optimize not compatible with cuda devices, i.e. use --device cpu'
# Input
gs = int(max(model.stride)) # grid size (max stride)
imgsz = [check_imgsz(x, gs) for x in imgsz] # verify img_size are gs-multiples
im = torch.zeros(batch_size, 3, *imgsz).to(device) # image size(1,3,320,192) BCHW iDetection
# Update model
model.eval()
for k, m in model.named_modules():
if isinstance(m, (Detect, Segment)):
m.dynamic = dynamic
m.export = True
for _ in range(2):
y = model(im) # dry runs
if half and not coreml:
im, model = im.half(), model.half() # to FP16
shape = tuple((y[0] if isinstance(y, tuple) else y).shape) # model output shape
metadata = {'stride': int(max(model.stride)), 'names': model.names} # model metadata
LOGGER.info(f"\n{colorstr('PyTorch:')} starting from {file} with output shape {shape} ({file_size(file):.1f} MB)")
# Warnings
warnings.filterwarnings('ignore', category=torch.jit.TracerWarning) # suppress TracerWarning
warnings.filterwarnings('ignore', category=UserWarning) # suppress shape prim::Constant type missing ONNX warning
warnings.filterwarnings('ignore', category=DeprecationWarning) # suppress CoreML np.bool deprecation warning
# Exports
f = [''] * len(fmts) # exported filenames
if jit: # TorchScript
f[0], _ = export_torchscript(model, im, file, optimize)
if engine: # TensorRT required before ONNX
f[1], _ = export_engine(model, im, file, half, dynamic, simplify, workspace, verbose)
if onnx or xml: # OpenVINO requires ONNX
f[2], _ = export_onnx(model, im, file, opset, dynamic, simplify)
if xml: # OpenVINO
f[3], _ = export_openvino(file, metadata, half)
if coreml: # CoreML
f[4], _ = export_coreml(model, im, file, int8, half)
if any((saved_model, pb, tflite, edgetpu, tfjs)): # TensorFlow formats
assert not tflite or not tfjs, 'TFLite and TF.js models must be exported separately, please pass only one type.'
assert not isinstance(model, ClassificationModel), 'ClassificationModel export to TF formats not yet supported.'
f[5], s_model = export_saved_model(model.cpu(),
im,
file,
dynamic,
tf_nms=nms or agnostic_nms or tfjs,
agnostic_nms=agnostic_nms or tfjs,
topk_per_class=topk_per_class,
topk_all=topk_all,
iou_thres=iou_thres,
conf_thres=conf_thres,
keras=keras)
if pb or tfjs: # pb prerequisite to tfjs
f[6], _ = export_pb(s_model, file)
if tflite or edgetpu:
f[7], _ = export_tflite(s_model, im, file, int8 or edgetpu, data=data, nms=nms, agnostic_nms=agnostic_nms)
if edgetpu:
f[8], _ = export_edgetpu(file)
add_tflite_metadata(f[8] or f[7], metadata, num_outputs=len(s_model.outputs))
if tfjs:
f[9], _ = export_tfjs(file)
if paddle: # PaddlePaddle
f[10], _ = export_paddle(model, im, file, metadata)
# Finish
f = [str(x) for x in f if x] # filter out '' and None
if any(f):
cls, det, seg = (isinstance(model, x) for x in (ClassificationModel, DetectionModel, SegmentationModel)) # type
det &= not seg # segmentation models inherit from SegmentationModel(DetectionModel)
dir = Path('segment' if seg else 'classify' if cls else '')
h = '--half' if half else '' # --half FP16 inference arg
s = "# WARNING ⚠️ ClassificationModel not yet supported for PyTorch Hub AutoShape inference" if cls else \
"# WARNING ⚠️ SegmentationModel not yet supported for PyTorch Hub AutoShape inference" if seg else ''
LOGGER.info(f'\nExport complete ({time.time() - t:.1f}s)'
f"\nResults saved to {colorstr('bold', file.parent.resolve())}"
f"\nDetect: python {dir / 'predict.py'} --weights {f[-1]} {h}"
f"\nValidate: python {dir / 'val.py'} --weights {f[-1]} {h}"
f"\nPyTorch Hub: model = torch.hub.load('ultralytics/yolov5', 'custom', '{f[-1]}') {s}"
f"\nVisualize: https://netron.app")
return f # return list of exported files/dirs

@ -1,13 +1,13 @@
from pathlib import Path
import torch import torch
import yaml
from ultralytics import yolo # noqa required for python usage from ultralytics import yolo # noqa required for python usage
from ultralytics.nn.tasks import ClassificationModel, DetectionModel, SegmentationModel, attempt_load_weights from ultralytics.nn.tasks import ClassificationModel, DetectionModel, SegmentationModel, attempt_load_weights
# from ultralytics.yolo.data.utils import check_dataset, check_dataset_yaml from ultralytics.yolo.configs import get_config
from ultralytics.yolo.engine.trainer import DEFAULT_CONFIG from ultralytics.yolo.engine.exporter import export_model
from ultralytics.yolo.utils import HELP_MSG, LOGGER from ultralytics.yolo.utils import DEFAULT_CONFIG, HELP_MSG, LOGGER
from ultralytics.yolo.utils.checks import check_yaml from ultralytics.yolo.utils.checks import check_yaml
from ultralytics.yolo.utils.configs import get_config
from ultralytics.yolo.utils.files import yaml_load from ultralytics.yolo.utils.files import yaml_load
from ultralytics.yolo.utils.torch_utils import smart_inference_mode from ultralytics.yolo.utils.torch_utils import smart_inference_mode
@ -36,7 +36,7 @@ class YOLO:
type (str): Type/version of models to use type (str): Type/version of models to use
""" """
if init_key != YOLO.__init_key: if init_key != YOLO.__init_key:
raise Exception(HELP_MSG) raise SyntaxError(HELP_MSG)
self.type = type self.type = type
self.ModelClass = None self.ModelClass = None
@ -46,7 +46,8 @@ class YOLO:
self.model = None self.model = None
self.trainer = None self.trainer = None
self.task = None self.task = None
self.ckpt = None self.ckpt = None # if loaded from *.pt
self.cfg = None # if loaded from *.yaml
self.overrides = {} self.overrides = {}
self.init_disabled = False self.init_disabled = False
@ -59,12 +60,12 @@ class YOLO:
cfg (str): model configuration file cfg (str): model configuration file
""" """
cfg = check_yaml(cfg) # check YAML cfg = check_yaml(cfg) # check YAML
with open(cfg, encoding='ascii', errors='ignore') as f: cfg_dict = yaml_load(cfg) # model dict
cfg = yaml.safe_load(f) # model dict
obj = cls(init_key=cls.__init_key) obj = cls(init_key=cls.__init_key)
obj.task = obj._guess_task_from_head(cfg["head"][-1][-2]) obj.task = obj._guess_task_from_head(cfg_dict["head"][-1][-2])
obj.ModelClass, obj.TrainerClass, obj.ValidatorClass, obj.PredictorClass = obj._guess_ops_from_task(obj.task) obj.ModelClass, obj.TrainerClass, obj.ValidatorClass, obj.PredictorClass = obj._guess_ops_from_task(obj.task)
obj.model = obj.ModelClass(cfg) # initialize obj.model = obj.ModelClass(cfg_dict) # initialize
obj.cfg = cfg
return obj return obj
@ -116,13 +117,14 @@ class YOLO:
LOGGER.info("model not initialized!") LOGGER.info("model not initialized!")
self.model.fuse() self.model.fuse()
@smart_inference_mode()
def predict(self, source, **kwargs): def predict(self, source, **kwargs):
""" """
Visualize prection. Visualize prediction.
Args: Args:
source (str): Accepts all source types accepted by yolo source (str): Accepts all source types accepted by yolo
**kwargs : Any other args accepted by the predictors. Too see all args check 'configuration' section in the docs **kwargs : Any other args accepted by the predictors. To see all args check 'configuration' section in the docs
""" """
overrides = self.overrides.copy() overrides = self.overrides.copy()
overrides.update(kwargs) overrides.update(kwargs)
@ -131,7 +133,7 @@ class YOLO:
# check size type # check size type
sz = predictor.args.imgsz sz = predictor.args.imgsz
if type(sz) != int: # recieved listConfig if type(sz) != int: # received listConfig
predictor.args.imgsz = [sz[0], sz[0]] if len(sz) == 1 else [sz[0], sz[1]] # expand predictor.args.imgsz = [sz[0], sz[0]] if len(sz) == 1 else [sz[0], sz[1]] # expand
else: else:
predictor.args.imgsz = [sz, sz] predictor.args.imgsz = [sz, sz]
@ -139,16 +141,17 @@ class YOLO:
predictor.setup(model=self.model, source=source) predictor.setup(model=self.model, source=source)
predictor() predictor()
@smart_inference_mode()
def val(self, data=None, **kwargs): def val(self, data=None, **kwargs):
""" """
Validate a model on a given dataset Validate a model on a given dataset
Args: Args:
data (str): The dataset to validate on. Accepts all formats accepted by yolo data (str): The dataset to validate on. Accepts all formats accepted by yolo
kwargs: Any other args accepted by the validators. Too see all args check 'configuration' section in the docs kwargs: Any other args accepted by the validators. To see all args check 'configuration' section in the docs
""" """
if not self.model: if not self.model:
raise Exception("model not initialized!") raise ModuleNotFoundError("model not initialized!")
overrides = self.overrides.copy() overrides = self.overrides.copy()
overrides.update(kwargs) overrides.update(kwargs)
@ -160,6 +163,51 @@ class YOLO:
validator = self.ValidatorClass(args=args) validator = self.ValidatorClass(args=args)
validator(model=self.model) validator(model=self.model)
@smart_inference_mode()
def export(self, format='', save_dir='', **kwargs):
"""
Export model.
Args:
format (str): Export format
**kwargs : Any other args accepted by the predictors. To see all args check 'configuration' section in the docs
"""
overrides = self.overrides.copy()
overrides.update(kwargs)
args = get_config(config=DEFAULT_CONFIG, overrides=overrides)
args.task = self.task
args.format = format
file = self.ckpt or Path(Path(self.cfg).name)
if save_dir:
file = Path(save_dir) / file.name
file.parent.mkdir(parents=True, exist_ok=True)
export_model(
model=self.model,
file=file,
data=args.data, # 'dataset.yaml path'
imgsz=args.imgsz or (640, 640), # image (height, width)
batch_size=1, # batch size
device=args.device, # cuda device, i.e. 0 or 0,1,2,3 or cpu
format=args.format, # include formats
half=args.half or False, # FP16 half-precision export
keras=args.keras or False, # use Keras
optimize=args.optimize or False, # TorchScript: optimize for mobile
int8=args.int8 or False, # CoreML/TF INT8 quantization
dynamic=args.dynamic or False, # ONNX/TF/TensorRT: dynamic axes
opset=args.opset or 17, # ONNX: opset version
verbose=False, # TensorRT: verbose log
workspace=args.workspace or 4, # TensorRT: workspace size (GB)
nms=False, # TF: add NMS to model
agnostic_nms=False, # TF: add agnostic NMS to model
topk_per_class=100, # TF.js NMS: topk per class to keep
topk_all=100, # TF.js NMS: topk for all classes to keep
iou_thres=0.45, # TF.js NMS: IoU threshold
conf_thres=0.25, # TF.js NMS: confidence threshold
)
def train(self, **kwargs): def train(self, **kwargs):
""" """
Trains the model on given dataset. Trains the model on given dataset.
@ -178,7 +226,7 @@ class YOLO:
overrides["task"] = self.task overrides["task"] = self.task
overrides["mode"] = "train" overrides["mode"] = "train"
if not overrides.get("data"): if not overrides.get("data"):
raise AttributeError("dataset not provided! Please check if you have defined `data` in you configs") raise AttributeError("dataset not provided! Please define `data` in config.yaml or pass as an argument.")
self.trainer = self.TrainerClass(overrides=overrides) self.trainer = self.TrainerClass(overrides=overrides)
self.trainer.model = self.trainer.load_model(weights=self.ckpt, self.trainer.model = self.trainer.load_model(weights=self.ckpt,
@ -189,11 +237,11 @@ class YOLO:
def resume(self, task=None, model=None): def resume(self, task=None, model=None):
""" """
Resume a training task. Requires either `task` or `model`. `model` takes the higher precederence. Resume a training task. Requires either `task` or `model`. `model` takes the higher precedence.
Args: Args:
task (str): The task type you want to resume. Automatically finds the last run to resume if `model` is not specified. task (str): The task type you want to resume. Automatically finds the last run to resume if `model` is not specified.
model (str): The model checkpoint to resume from. If not found, the last run of the given task type is resumed. model (str): The model checkpoint to resume from. If not found, the last run of the given task type is resumed.
If `model` is speficied If `model` is specified
""" """
if task: if task:
if task.lower() not in MODEL_MAP: if task.lower() not in MODEL_MAP:

@ -1,6 +1,6 @@
# predictor engine by Ultralytics # predictor engine by Ultralytics
""" """
Run prection on images, videos, directories, globs, YouTube, webcam, streams, etc. Run prediction on images, videos, directories, globs, YouTube, webcam, streams, etc.
Usage - sources: Usage - sources:
$ yolo task=... mode=predict model=s.pt --source 0 # webcam $ yolo task=... mode=predict model=s.pt --source 0 # webcam
img.jpg # image img.jpg # image
@ -13,15 +13,15 @@ Usage - sources:
'https://youtu.be/Zgi9g1ksQHc' # YouTube 'https://youtu.be/Zgi9g1ksQHc' # YouTube
'rtsp://example.com/media.mp4' # RTSP, RTMP, HTTP stream 'rtsp://example.com/media.mp4' # RTSP, RTMP, HTTP stream
Usage - formats: Usage - formats:
$ yolo task=... mode=predict --weights yolov5s.pt # PyTorch $ yolo task=... mode=predict --weights yolov8n.pt # PyTorch
yolov5s.torchscript # TorchScript yolov8n.torchscript # TorchScript
yolov5s.onnx # ONNX Runtime or OpenCV DNN with --dnn yolov8n.onnx # ONNX Runtime or OpenCV DNN with --dnn
yolov5s_openvino_model # OpenVINO yolov5s_openvino_model # OpenVINO
yolov5s.engine # TensorRT yolov8n.engine # TensorRT
yolov5s.mlmodel # CoreML (macOS-only) yolov8n.mlmodel # CoreML (macOS-only)
yolov5s_saved_model # TensorFlow SavedModel yolov5s_saved_model # TensorFlow SavedModel
yolov5s.pb # TensorFlow GraphDef yolov8n.pb # TensorFlow GraphDef
yolov5s.tflite # TensorFlow Lite yolov8n.tflite # TensorFlow Lite
yolov5s_edgetpu.tflite # TensorFlow Edge TPU yolov5s_edgetpu.tflite # TensorFlow Edge TPU
yolov5s_paddle_model # PaddlePaddle yolov5s_paddle_model # PaddlePaddle
""" """
@ -31,16 +31,14 @@ from pathlib import Path
import cv2 import cv2
from ultralytics.nn.autobackend import AutoBackend from ultralytics.nn.autobackend import AutoBackend
from ultralytics.yolo.configs import get_config
from ultralytics.yolo.data.dataloaders.stream_loaders import LoadImages, LoadScreenshots, LoadStreams from ultralytics.yolo.data.dataloaders.stream_loaders import LoadImages, LoadScreenshots, LoadStreams
from ultralytics.yolo.data.utils import IMG_FORMATS, VID_FORMATS, check_dataset, check_dataset_yaml from ultralytics.yolo.data.utils import IMG_FORMATS, VID_FORMATS
from ultralytics.yolo.utils import LOGGER, ROOT, colorstr, ops from ultralytics.yolo.utils import DEFAULT_CONFIG, LOGGER, colorstr, ops
from ultralytics.yolo.utils.checks import check_file, check_imshow from ultralytics.yolo.utils.checks import check_file, check_imshow
from ultralytics.yolo.utils.configs import get_config
from ultralytics.yolo.utils.files import increment_path from ultralytics.yolo.utils.files import increment_path
from ultralytics.yolo.utils.torch_utils import check_imgsz, select_device, smart_inference_mode from ultralytics.yolo.utils.torch_utils import check_imgsz, select_device, smart_inference_mode
DEFAULT_CONFIG = ROOT / "yolo/utils/configs/default.yaml"
class BasePredictor: class BasePredictor:

@ -23,16 +23,14 @@ from tqdm import tqdm
import ultralytics.yolo.utils as utils import ultralytics.yolo.utils as utils
import ultralytics.yolo.utils.callbacks as callbacks import ultralytics.yolo.utils.callbacks as callbacks
from ultralytics import __version__ from ultralytics import __version__
from ultralytics.yolo.configs import get_config
from ultralytics.yolo.data.utils import check_dataset, check_dataset_yaml from ultralytics.yolo.data.utils import check_dataset, check_dataset_yaml
from ultralytics.yolo.utils import LOGGER, RANK, ROOT, TQDM_BAR_FORMAT, colorstr from ultralytics.yolo.utils import DEFAULT_CONFIG, LOGGER, RANK, TQDM_BAR_FORMAT, colorstr
from ultralytics.yolo.utils.checks import check_file, print_args from ultralytics.yolo.utils.checks import check_file, print_args
from ultralytics.yolo.utils.configs import get_config
from ultralytics.yolo.utils.dist import ddp_cleanup, generate_ddp_command from ultralytics.yolo.utils.dist import ddp_cleanup, generate_ddp_command
from ultralytics.yolo.utils.files import get_latest_run, increment_path, save_yaml from ultralytics.yolo.utils.files import get_latest_run, increment_path, yaml_save
from ultralytics.yolo.utils.torch_utils import ModelEMA, de_parallel, init_seeds, one_cycle, strip_optimizer from ultralytics.yolo.utils.torch_utils import ModelEMA, de_parallel, init_seeds, one_cycle, strip_optimizer
DEFAULT_CONFIG = ROOT / "yolo/utils/configs/default.yaml"
class BaseTrainer: class BaseTrainer:
@ -53,8 +51,7 @@ class BaseTrainer:
self.wdir = self.save_dir / 'weights' # weights dir self.wdir = self.save_dir / 'weights' # weights dir
if RANK in {-1, 0}: if RANK in {-1, 0}:
self.wdir.mkdir(parents=True, exist_ok=True) # make dir self.wdir.mkdir(parents=True, exist_ok=True) # make dir
# Save run settings yaml_save(self.save_dir / 'args.yaml', OmegaConf.to_container(self.args, resolve=True)) # save run args
save_yaml(self.save_dir / 'args.yaml', OmegaConf.to_container(self.args, resolve=True))
self.last, self.best = self.wdir / 'last.pt', self.wdir / 'best.pt' # checkpoint paths self.last, self.best = self.wdir / 'last.pt', self.wdir / 'best.pt' # checkpoint paths
self.batch_size = self.args.batch_size self.batch_size = self.args.batch_size
@ -452,7 +449,8 @@ class BaseTrainer:
self.ema.ema.load_state_dict(ckpt['ema'].float().state_dict()) # EMA self.ema.ema.load_state_dict(ckpt['ema'].float().state_dict()) # EMA
self.ema.updates = ckpt['updates'] self.ema.updates = ckpt['updates']
if self.args.resume: if self.args.resume:
assert start_epoch > 0, f'{self.args.model} training to {self.epochs} epochs is finished, nothing to resume.\n' \ assert start_epoch > 0, \
f'{self.args.model} training to {self.epochs} epochs is finished, nothing to resume.\n' \
f"Start a new training without --resume, i.e. 'yolo task=... mode=train model={self.args.model}'" f"Start a new training without --resume, i.e. 'yolo task=... mode=train model={self.args.model}'"
LOGGER.info( LOGGER.info(
f'Resuming training from {self.args.model} from epoch {start_epoch} to {self.epochs} total epochs') f'Resuming training from {self.args.model} from epoch {start_epoch} to {self.epochs} total epochs')

@ -66,7 +66,7 @@ class BaseValidator:
self.args.batch_size = model.batch_size self.args.batch_size = model.batch_size
else: else:
self.device = model.device self.device = model.device
if not (pt or jit): if not pt and not jit:
self.args.batch_size = 1 # export.py models default to batch-size 1 self.args.batch_size = 1 # export.py models default to batch-size 1
self.logger.info( self.logger.info(
f'Forcing --batch-size 1 square inference (1,3,{imgsz},{imgsz}) for non-PyTorch models') f'Forcing --batch-size 1 square inference (1,3,{imgsz},{imgsz}) for non-PyTorch models')
@ -75,8 +75,8 @@ class BaseValidator:
data = check_dataset_yaml(self.args.data) data = check_dataset_yaml(self.args.data)
else: else:
data = check_dataset(self.args.data) data = check_dataset(self.args.data)
self.dataloader = self.get_dataloader(data.get("val") or data.set("test"), self.dataloader = self.dataloader or \
self.args.batch_size) if not self.dataloader else self.dataloader self.get_dataloader(data.get("val") or data.set("test"), self.args.batch_size)
model.eval() model.eval()
@ -139,7 +139,7 @@ class BaseValidator:
def postprocess(self, preds): def postprocess(self, preds):
return preds return preds
def init_metrics(self): def init_metrics(self, model):
pass pass
def update_metrics(self, preds, batch): def update_metrics(self, preds, batch):

@ -1,4 +1,5 @@
import contextlib import contextlib
import inspect
import logging.config import logging.config
import os import os
import platform import platform
@ -13,6 +14,7 @@ import pandas as pd
# Constants # Constants
FILE = Path(__file__).resolve() FILE = Path(__file__).resolve()
ROOT = FILE.parents[2] # YOLO ROOT = FILE.parents[2] # YOLO
DEFAULT_CONFIG = ROOT / "yolo/configs/default.yaml"
RANK = int(os.getenv('RANK', -1)) RANK = int(os.getenv('RANK', -1))
DATASETS_DIR = Path(os.getenv('YOLOv5_DATASETS_DIR', ROOT.parent / 'datasets')) # global datasets directory DATASETS_DIR = Path(os.getenv('YOLOv5_DATASETS_DIR', ROOT.parent / 'datasets')) # global datasets directory
NUM_THREADS = min(8, max(1, os.cpu_count() - 1)) # number of YOLOv5 multiprocessing threads NUM_THREADS = min(8, max(1, os.cpu_count() - 1)) # number of YOLOv5 multiprocessing threads
@ -98,6 +100,12 @@ def is_writeable(dir, test=False):
return False return False
def get_default_args(func):
# Get func() default arguments
signature = inspect.signature(func)
return {k: v.default for k, v in signature.parameters.items() if v.default is not inspect.Parameter.empty}
def user_config_dir(dir='Ultralytics', env_var='YOLOV5_CONFIG_DIR'): def user_config_dir(dir='Ultralytics', env_var='YOLOV5_CONFIG_DIR'):
# Return path of user configuration directory. Prefer environment variable if exists. Make dir if required. # Return path of user configuration directory. Prefer environment variable if exists. Make dir if required.
env = os.getenv(env_var) env = os.getenv(env_var)

@ -13,6 +13,7 @@ import torch
from ultralytics.yolo.utils import (AUTOINSTALL, FONT, LOGGER, ROOT, USER_CONFIG_DIR, TryExcept, colorstr, emojis, from ultralytics.yolo.utils import (AUTOINSTALL, FONT, LOGGER, ROOT, USER_CONFIG_DIR, TryExcept, colorstr, emojis,
is_docker, is_notebook) is_docker, is_notebook)
from ultralytics.yolo.utils.ops import make_divisible
def is_ascii(s=''): def is_ascii(s=''):
@ -21,6 +22,18 @@ def is_ascii(s=''):
return len(s.encode().decode('ascii', 'ignore')) == len(s) return len(s.encode().decode('ascii', 'ignore')) == len(s)
def check_imgsz(imgsz, s=32, floor=0):
# Verify image size is a multiple of stride s in each dimension
if isinstance(imgsz, int): # integer i.e. img_size=640
new_size = max(make_divisible(imgsz, int(s)), floor)
else: # list i.e. img_size=[640, 480]
imgsz = list(imgsz) # convert to list if tuple
new_size = [max(make_divisible(x, int(s)), floor) for x in imgsz]
if new_size != imgsz:
LOGGER.warning(f'WARNING ⚠️ --img-size {imgsz} must be multiple of max stride {s}, updating to {new_size}')
return new_size
def check_version(current="0.0.0", minimum="0.0.0", name="version ", pinned=False, hard=False, verbose=False): def check_version(current="0.0.0", minimum="0.0.0", name="version ", pinned=False, hard=False, verbose=False):
# Check version vs. required version # Check version vs. required version
current, minimum = (pkg.parse_version(x) for x in (current, minimum)) current, minimum = (pkg.parse_version(x) for x in (current, minimum))
@ -93,7 +106,7 @@ def check_requirements(requirements=ROOT / 'requirements.txt', exclude=(), insta
LOGGER.warning(f'{prefix}{e}') LOGGER.warning(f'{prefix}{e}')
def check_suffix(file='yolov5s.pt', suffix=('.pt',), msg=''): def check_suffix(file='yolov8n.pt', suffix=('.pt',), msg=''):
# Check file(s) for acceptable suffix # Check file(s) for acceptable suffix
if file and suffix: if file and suffix:
if isinstance(suffix, str): if isinstance(suffix, str):

@ -49,7 +49,7 @@ def attempt_download(file, repo='ultralytics/yolov5', release='v6.2'):
# Attempt file download from GitHub release assets if not found locally. release = 'latest', 'v6.2', etc. # Attempt file download from GitHub release assets if not found locally. release = 'latest', 'v6.2', etc.
def github_assets(repository, version='latest'): def github_assets(repository, version='latest'):
# Return GitHub repo tag and assets (i.e. ['yolov5s.pt', 'yolov5m.pt', ...]) # Return GitHub repo tag and assets (i.e. ['yolov8n.pt', 'yolov5m.pt', ...])
if version != 'latest': if version != 'latest':
version = f'tags/{version}' # i.e. tags/v6.2 version = f'tags/{version}' # i.e. tags/v6.2
response = requests.get(f'https://api.github.com/repos/{repository}/releases/{version}').json() # github api response = requests.get(f'https://api.github.com/repos/{repository}/releases/{version}').json() # github api

@ -1,6 +1,7 @@
import contextlib import contextlib
import glob import glob
import os import os
import urllib
from datetime import datetime from datetime import datetime
from pathlib import Path from pathlib import Path
from zipfile import ZipFile from zipfile import ZipFile
@ -43,7 +44,7 @@ def increment_path(path, exist_ok=False, sep='', mkdir=False):
return path return path
def save_yaml(file='data.yaml', data=None): def yaml_save(file='data.yaml', data=None):
# Single-line safe yaml saving # Single-line safe yaml saving
with open(file, 'w') as f: with open(file, 'w') as f:
yaml.safe_dump({k: str(v) if isinstance(v, Path) else v for k, v in data.items()}, f, sort_keys=False) yaml.safe_dump({k: str(v) if isinstance(v, Path) else v for k, v in data.items()}, f, sort_keys=False)
@ -52,7 +53,7 @@ def save_yaml(file='data.yaml', data=None):
def yaml_load(file='data.yaml'): def yaml_load(file='data.yaml'):
# Single-line safe yaml loading # Single-line safe yaml loading
with open(file, errors='ignore') as f: with open(file, errors='ignore') as f:
return yaml.safe_load(f) return {**yaml.safe_load(f), 'yaml_file': file} # add YAML filename to dict and return
def unzip_file(file, path=None, exclude=('.DS_Store', '__MACOSX')): def unzip_file(file, path=None, exclude=('.DS_Store', '__MACOSX')):
@ -77,6 +78,24 @@ def file_date(path=__file__):
return f'{t.year}-{t.month}-{t.day}' return f'{t.year}-{t.month}-{t.day}'
def file_size(path):
# Return file/dir size (MB)
mb = 1 << 20 # bytes to MiB (1024 ** 2)
path = Path(path)
if path.is_file():
return path.stat().st_size / mb
elif path.is_dir():
return sum(f.stat().st_size for f in path.glob('**/*') if f.is_file()) / mb
else:
return 0.0
def url2file(url):
# Convert URL to filename, i.e. https://url.com/file.txt?auth -> file.txt
url = str(Path(url)).replace(':/', '://') # Pathlib turns :// -> :/
return Path(urllib.parse.unquote(url)).name.split('?')[0] # '%2F' to '/', split https://url.com/file.txt?auth
def get_latest_run(search_dir='.'): def get_latest_run(search_dir='.'):
# Return path to most recent 'last.pt' in /runs (i.e. to --resume from) # Return path to most recent 'last.pt' in /runs (i.e. to --resume from)
last_list = glob.glob(f'{search_dir}/**/last*.pt', recursive=True) last_list = glob.glob(f'{search_dir}/**/last*.pt', recursive=True)

@ -135,7 +135,7 @@ def non_max_suppression(
for xi, x in enumerate(prediction): # image index, image inference for xi, x in enumerate(prediction): # image index, image inference
# Apply constraints # Apply constraints
# x[((x[:, 2:4] < min_wh) | (x[:, 2:4] > max_wh)).any(1), 4] = 0 # width-height # x[((x[:, 2:4] < min_wh) | (x[:, 2:4] > max_wh)).any(1), 4] = 0 # width-height
x = x.T[xc[xi]] # confidence x = x.transpose(0, -1)[xc[xi]] # confidence
# Cat apriori labels if autolabelling # Cat apriori labels if autolabelling
if labels and len(labels[xi]): if labels and len(labels[xi]):

@ -135,8 +135,8 @@ def model_info(model, verbose=False, imgsz=640):
flops = get_flops(model, imgsz) flops = get_flops(model, imgsz)
fs = f', {flops:.1f} GFLOPs' if flops else '' fs = f', {flops:.1f} GFLOPs' if flops else ''
name = Path(model.yaml_file).stem.replace('yolov5', 'YOLOv5') if hasattr(model, 'yaml_file') else 'Model' m = Path(getattr(model, 'yaml_file', '') or model.yaml.get('yaml_file', '')).stem.replace('yolo', 'YOLO') or 'Model'
LOGGER.info(f"{name} summary: {len(list(model.modules()))} layers, {n_p} parameters, {n_g} gradients{fs}") LOGGER.info(f"{m} summary: {len(list(model.modules()))} layers, {n_p} parameters, {n_g} gradients{fs}")
def get_num_params(model): def get_num_params(model):

@ -6,4 +6,4 @@ ROOT = Path(__file__).parents[0] # yolov8 ROOT
__all__ = ["classify", "segment", "detect"] __all__ = ["classify", "segment", "detect"]
from ultralytics.yolo.utils.configs import hydra_patch # noqa (patch hydra cli) from ultralytics.yolo.configs import hydra_patch # noqa (patch hydra cli)

@ -55,7 +55,7 @@ class ClassificationPredictor(BasePredictor):
def predict(cfg): def predict(cfg):
cfg.model = cfg.model or "squeezenet1_0" cfg.model = cfg.model or "squeezenet1_0"
sz = cfg.imgsz sz = cfg.imgsz
if type(sz) != int: # recieved listConfig if type(sz) != int: # received listConfig
cfg.imgsz = [sz[0], sz[0]] if len(cfg.imgsz) == 1 else [sz[0], sz[1]] # expand cfg.imgsz = [sz[0], sz[0]] if len(cfg.imgsz) == 1 else [sz[0], sz[1]] # expand
else: else:
cfg.imgsz = [sz, sz] cfg.imgsz = [sz, sz]

@ -4,7 +4,8 @@ import torch
from ultralytics.nn.tasks import ClassificationModel, get_model from ultralytics.nn.tasks import ClassificationModel, get_model
from ultralytics.yolo import v8 from ultralytics.yolo import v8
from ultralytics.yolo.data import build_classification_dataloader from ultralytics.yolo.data import build_classification_dataloader
from ultralytics.yolo.engine.trainer import DEFAULT_CONFIG, BaseTrainer from ultralytics.yolo.engine.trainer import BaseTrainer
from ultralytics.yolo.utils import DEFAULT_CONFIG
class ClassificationTrainer(BaseTrainer): class ClassificationTrainer(BaseTrainer):

@ -85,7 +85,7 @@ class DetectionPredictor(BasePredictor):
def predict(cfg): def predict(cfg):
cfg.model = cfg.model or "n.pt" cfg.model = cfg.model or "n.pt"
sz = cfg.imgsz sz = cfg.imgsz
if type(sz) != int: # recieved listConfig if type(sz) != int: # received listConfig
cfg.imgsz = [sz[0], sz[0]] if len(cfg.imgsz) == 1 else [sz[0], sz[1]] # expand cfg.imgsz = [sz[0], sz[0]] if len(cfg.imgsz) == 1 else [sz[0], sz[1]] # expand
else: else:
cfg.imgsz = [sz, sz] cfg.imgsz = [sz, sz]

@ -6,8 +6,8 @@ from ultralytics.nn.tasks import DetectionModel
from ultralytics.yolo import v8 from ultralytics.yolo import v8
from ultralytics.yolo.data import build_dataloader from ultralytics.yolo.data import build_dataloader
from ultralytics.yolo.data.dataloaders.v5loader import create_dataloader from ultralytics.yolo.data.dataloaders.v5loader import create_dataloader
from ultralytics.yolo.engine.trainer import DEFAULT_CONFIG, BaseTrainer from ultralytics.yolo.engine.trainer import BaseTrainer
from ultralytics.yolo.utils import colorstr from ultralytics.yolo.utils import DEFAULT_CONFIG, colorstr
from ultralytics.yolo.utils.loss import BboxLoss from ultralytics.yolo.utils.loss import BboxLoss
from ultralytics.yolo.utils.ops import xywh2xyxy from ultralytics.yolo.utils.ops import xywh2xyxy
from ultralytics.yolo.utils.plotting import plot_images, plot_results from ultralytics.yolo.utils.plotting import plot_images, plot_results
@ -185,7 +185,7 @@ class Loss:
@hydra.main(version_base=None, config_path=str(DEFAULT_CONFIG.parent), config_name=DEFAULT_CONFIG.name) @hydra.main(version_base=None, config_path=str(DEFAULT_CONFIG.parent), config_name=DEFAULT_CONFIG.name)
def train(cfg): def train(cfg):
cfg.model = cfg.model or "models/yolov8n.yaml" cfg.model = cfg.model or "yolov8n.yaml"
cfg.data = cfg.data or "coco128.yaml" # or yolo.ClassificationDataset("mnist") cfg.data = cfg.data or "coco128.yaml" # or yolo.ClassificationDataset("mnist")
# cfg.imgsz = 160 # cfg.imgsz = 160
# cfg.epochs = 5 # cfg.epochs = 5

@ -98,7 +98,7 @@ class SegmentationPredictor(DetectionPredictor):
def predict(cfg): def predict(cfg):
cfg.model = cfg.model or "n.pt" cfg.model = cfg.model or "n.pt"
sz = cfg.imgsz sz = cfg.imgsz
if type(sz) != int: # recieved listConfig if type(sz) != int: # received listConfig
cfg.imgsz = [sz[0], sz[0]] if len(cfg.imgsz) == 1 else [sz[0], sz[1]] # expand cfg.imgsz = [sz[0], sz[0]] if len(cfg.imgsz) == 1 else [sz[0], sz[1]] # expand
else: else:
cfg.imgsz = [sz, sz] cfg.imgsz = [sz, sz]

@ -12,11 +12,9 @@ from ultralytics.yolo.utils.plotting import plot_images, plot_results
from ultralytics.yolo.utils.tal import TaskAlignedAssigner, dist2bbox, make_anchors from ultralytics.yolo.utils.tal import TaskAlignedAssigner, dist2bbox, make_anchors
from ultralytics.yolo.utils.torch_utils import de_parallel from ultralytics.yolo.utils.torch_utils import de_parallel
from ..detect import DetectionTrainer
# BaseTrainer python usage # BaseTrainer python usage
class SegmentationTrainer(DetectionTrainer): class SegmentationTrainer(v8.detect.DetectionTrainer):
def load_model(self, model_cfg=None, weights=None, verbose=True): def load_model(self, model_cfg=None, weights=None, verbose=True):
model = SegmentationModel(model_cfg or weights["model"].yaml, ch=3, nc=self.data["nc"], verbose=verbose) model = SegmentationModel(model_cfg or weights["model"].yaml, ch=3, nc=self.data["nc"], verbose=verbose)
@ -174,7 +172,7 @@ class SegLoss:
@hydra.main(version_base=None, config_path=str(DEFAULT_CONFIG.parent), config_name=DEFAULT_CONFIG.name) @hydra.main(version_base=None, config_path=str(DEFAULT_CONFIG.parent), config_name=DEFAULT_CONFIG.name)
def train(cfg): def train(cfg):
cfg.model = cfg.model or "models/yolov8n-seg.yaml" cfg.model = cfg.model or "yolov8n-seg.yaml"
cfg.data = cfg.data or "coco128-seg.yaml" # or yolo.ClassificationDataset("mnist") cfg.data = cfg.data or "coco128-seg.yaml" # or yolo.ClassificationDataset("mnist")
trainer = SegmentationTrainer(cfg) trainer = SegmentationTrainer(cfg)
trainer.train() trainer.train()

Loading…
Cancel
Save