diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index 77ec9dd..1d6b6e0 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -92,12 +92,12 @@ jobs: run: | yolo task=detect mode=train model=yolov8n.yaml data=coco128.yaml epochs=1 imgsz=64 yolo task=detect mode=val model=runs/detect/train/weights/last.pt imgsz=64 - - name: Test segmentation # TODO: segmentation CI + - name: Test segmentation shell: bash # for Windows compatibility run: | - # yolo task=segment mode=train model=yolov8n-seg.yaml data=coco128-seg.yaml epochs=1 imgsz=64 - # yolo task=segment mode=val model=runs/segment/train/weights/last.pt data=coco128-seg.yaml imgsz=64 - - name: Test classification # TODO: change to exp3 on Segmentation CI update + yolo task=segment mode=train model=yolov8n-seg.yaml data=coco128-seg.yaml epochs=1 imgsz=64 + yolo task=segment mode=val model=runs/segment/train/weights/last.pt data=coco128-seg.yaml imgsz=64 + - name: Test classification shell: bash # for Windows compatibility run: | yolo task=classify mode=train model=resnet18 data=mnist160 epochs=1 imgsz=32 diff --git a/tests/assets/dummy_model.yaml b/tests/assets/dummy_model.yaml deleted file mode 100644 index 5154339..0000000 --- a/tests/assets/dummy_model.yaml +++ /dev/null @@ -1,49 +0,0 @@ -# Ultralytics, GPL-3.0 license - -# Parameters -nc: 80 # number of classes -depth_multiple: 0.33 # model depth multiple -width_multiple: 0.50 # layer channel multiple -anchors: - - [10,13, 16,30, 33,23] # P3/8 - - [30,61, 62,45, 59,119] # P4/16 - - [116,90, 156,198, 373,326] # P5/32 - -# YOLOv5 v6.0 backbone -backbone: - # [from, number, module, args] - [[-1, 1, Conv, [64, 6, 2, 2]], # 0-P1/2 - [-1, 1, Conv, [128, 3, 2]], # 1-P2/4 - [-1, 3, C3, [128]], - [-1, 1, Conv, [256, 3, 2]], # 3-P3/8 - [-1, 6, C3, [256]], - [-1, 1, Conv, [512, 3, 2]], # 5-P4/16 - [-1, 9, C3, [512]], - [-1, 1, Conv, [1024, 3, 2]], # 7-P5/32 - [-1, 3, C3, [1024]], - [-1, 1, SPPF, [1024, 5]], # 9 - ] - -# YOLOv5 v6.0 head -head: - [[-1, 1, Conv, [512, 1, 1]], - [-1, 1, nn.Upsample, [None, 2, 'nearest']], - [[-1, 6], 1, Concat, [1]], # cat backbone P4 - [-1, 3, C3, [512, False]], # 13 - - [-1, 1, Conv, [256, 1, 1]], - [-1, 1, nn.Upsample, [None, 2, 'nearest']], - [[-1, 4], 1, Concat, [1]], # cat backbone P3 - [-1, 3, C3, [256, False]], # 17 (P3/8-small) - - [-1, 1, Conv, [256, 3, 2]], - [[-1, 14], 1, Concat, [1]], # cat head P4 - [-1, 3, C3, [512, False]], # 20 (P4/16-medium) - - [-1, 1, Conv, [512, 3, 2]], - [[-1, 10], 1, Concat, [1]], # cat head P5 - [-1, 3, C3, [1024, False]], # 23 (P5/32-large) - - [[17, 20, 23], 1, Detect, [nc, anchors]], # Detect(P3, P4, P5) - ] - diff --git a/tests/check_flops.py b/tests/check_flops.py index 1dd0604..031d19d 100644 --- a/tests/check_flops.py +++ b/tests/check_flops.py @@ -1,64 +1,16 @@ -import torch - from ultralytics import YOLO -from ultralytics.nn.modules import Detect, Segment - - -def export_onnx(model, file): - # YOLOv5 ONNX export - import onnx - im = torch.zeros(1, 3, 640, 640) - model.eval() - model(im, profile=True) - for k, m in model.named_modules(): - if isinstance(m, (Detect, Segment)): - m.export = True - - torch.onnx.export( - model, - im, - file, - verbose=False, - opset_version=12, - do_constant_folding=True, # WARNING: DNN inference with torch>=1.12 may require do_constant_folding=False - input_names=['images']) - - # Checks - model_onnx = onnx.load(file) # load onnx model - onnx.checker.check_model(model_onnx) # check onnx model - - # Metadata - d = {'stride': int(max(model.stride)), 
'names': model.names} - for k, v in d.items(): - meta = model_onnx.metadata_props.add() - meta.key, meta.value = k, str(v) - onnx.save(model_onnx, file) - if __name__ == "__main__": - model = YOLO() - print("yolov8n") - model.new("yolov8n.yaml") - print("yolov8n-seg") - model.new("yolov8n-seg.yaml") - print("yolov8s") - model.new("yolov8s.yaml") - # export_onnx(model.model, "yolov8s.onnx") - print("yolov8s-seg") - model.new("yolov8s-seg.yaml") - # export_onnx(model.model, "yolov8s-seg.onnx") - print("yolov8m") - model.new("yolov8m.yaml") - print("yolov8m-seg") - model.new("yolov8m-seg.yaml") - print("yolov8l") - model.new("yolov8l.yaml") - print("yolov8l-seg") - model.new("yolov8l-seg.yaml") - print("yolov8x") - model.new("yolov8x.yaml") - print("yolov8x-seg") - model.new("yolov8x-seg.yaml") + YOLO.new("yolov8n.yaml") + YOLO.new("yolov8n-seg.yaml") + YOLO.new("yolov8s.yaml") + YOLO.new("yolov8s-seg.yaml") + YOLO.new("yolov8m.yaml") + YOLO.new("yolov8m-seg.yaml") + YOLO.new("yolov8l.yaml") + YOLO.new("yolov8l-seg.yaml") + YOLO.new("yolov8x.yaml") + YOLO.new("yolov8x-seg.yaml") # n vs n-seg: 8.9GFLOPs vs 12.8GFLOPs, 3.16M vs 3.6M. ch[0] // 4 (11.9GFLOPs, 3.39M) # s vs s-seg: 28.8GFLOPs vs 44.4GFLOPs, 11.1M vs 12.9M. ch[0] // 4 (39.5GFLOPs, 11.7M) diff --git a/tests/data/dataloader/yolodetection.py b/tests/data/dataloader/yolodetection.py index e30ea61..eb03ac3 100644 --- a/tests/data/dataloader/yolodetection.py +++ b/tests/data/dataloader/yolodetection.py @@ -2,11 +2,9 @@ import cv2 import hydra from ultralytics.yolo.data import build_dataloader -from ultralytics.yolo.utils import ROOT +from ultralytics.yolo.utils import DEFAULT_CONFIG from ultralytics.yolo.utils.plotting import plot_images -DEFAULT_CONFIG = ROOT / "yolo/utils/configs/default.yaml" - class Colors: # Ultralytics color palette https://ultralytics.com/ diff --git a/tests/data/dataloader/yolosegment.py b/tests/data/dataloader/yolosegment.py index f8cca68..8daf406 100644 --- a/tests/data/dataloader/yolosegment.py +++ b/tests/data/dataloader/yolosegment.py @@ -2,11 +2,9 @@ import cv2 import hydra from ultralytics.yolo.data import build_dataloader -from ultralytics.yolo.utils import ROOT +from ultralytics.yolo.utils import DEFAULT_CONFIG from ultralytics.yolo.utils.plotting import plot_images -DEFAULT_CONFIG = ROOT / "yolo/utils/configs/default.yaml" - class Colors: # Ultralytics color palette https://ultralytics.com/ diff --git a/tests/functional/test_loaders.py b/tests/functional/test_loaders.py index ea481c3..84b0d76 100644 --- a/tests/functional/test_loaders.py +++ b/tests/functional/test_loaders.py @@ -3,11 +3,11 @@ from ultralytics.yolo.utils.checks import check_yaml def test_model_parser(): - cfg = check_yaml("../assets/dummy_model.yaml") # check YAML + cfg = check_yaml("yolov8n.yaml") # check YAML # Create model model = DetectionModel(cfg) - print(model) + model.info() ''' # Options if opt.line_profile: # profile layer by layer diff --git a/tests/test_model.py b/tests/test_model.py index 91d143d..119f32d 100644 --- a/tests/test_model.py +++ b/tests/test_model.py @@ -62,6 +62,35 @@ def test_model_train_pretrained(): model(img) +def test_exports(): + """ + Format Argument Suffix CPU GPU + 0 PyTorch - .pt True True + 1 TorchScript torchscript .torchscript True True + 2 ONNX onnx .onnx True True + 3 OpenVINO openvino _openvino_model True False + 4 TensorRT engine .engine False True + 5 CoreML coreml .mlmodel True False + 6 TensorFlow SavedModel saved_model _saved_model True True + 7 TensorFlow GraphDef pb .pb True True + 8 
TensorFlow Lite tflite .tflite True False + 9 TensorFlow Edge TPU edgetpu _edgetpu.tflite False False + 10 TensorFlow.js tfjs _web_model False False + 11 PaddlePaddle paddle _paddle_model True True + """ + from ultralytics import YOLO + from ultralytics.yolo.engine.exporter import export_formats + + print(export_formats()) + + model = YOLO.new("yolov8n.yaml") + model.export(format='torchscript') + model.export(format='onnx') + model.export(format='openvino') + model.export(format='coreml') + model.export(format='paddle') + + def test(): test_model_forward() test_model_info() diff --git a/ultralytics/nn/autobackend.py b/ultralytics/nn/autobackend.py index b635d00..91da340 100644 --- a/ultralytics/nn/autobackend.py +++ b/ultralytics/nn/autobackend.py @@ -19,7 +19,7 @@ from ultralytics.yolo.utils.ops import xywh2xyxy class AutoBackend(nn.Module): # YOLOv5 MultiBackend class for python inference on various backends - def __init__(self, weights='yolov5s.pt', device=torch.device('cpu'), dnn=False, data=None, fp16=False, fuse=True): + def __init__(self, weights='yolov8n.pt', device=torch.device('cpu'), dnn=False, data=None, fp16=False, fuse=True): # Usage: # PyTorch: weights = *.pt # TorchScript: *.torchscript diff --git a/ultralytics/nn/tasks.py b/ultralytics/nn/tasks.py index 9ac875f..62aef70 100644 --- a/ultralytics/nn/tasks.py +++ b/ultralytics/nn/tasks.py @@ -6,12 +6,12 @@ import thop import torch import torch.nn as nn import torchvision -import yaml from ultralytics.nn.modules import (C1, C2, C3, C3TR, SPP, SPPF, Bottleneck, BottleneckCSP, C2f, C3Ghost, C3x, Classify, Concat, Conv, ConvTranspose, Detect, DWConv, DWConvTranspose2d, Ensemble, Focus, GhostBottleneck, GhostConv, Segment) from ultralytics.yolo.utils import LOGGER, colorstr +from ultralytics.yolo.utils.files import yaml_load from ultralytics.yolo.utils.torch_utils import (fuse_conv_and_bn, initialize_weights, intersect_state_dicts, make_divisible, model_info, scale_img, time_sync) @@ -78,14 +78,9 @@ class BaseModel(nn.Module): class DetectionModel(BaseModel): # YOLOv5 detection model - def __init__(self, cfg='yolov5s.yaml', ch=3, nc=None, verbose=True): # model, input channels, number of classes + def __init__(self, cfg='yolov8n.yaml', ch=3, nc=None, verbose=True): # model, input channels, number of classes super().__init__() - if isinstance(cfg, dict): - self.yaml = cfg # model dict - else: # is *.yaml - self.yaml_file = Path(cfg).name - with open(cfg, encoding='ascii', errors='ignore') as f: - self.yaml = yaml.safe_load(f) # model dict + self.yaml = cfg if isinstance(cfg, dict) else yaml_load(cfg) # cfg dict # Define model ch = self.yaml['ch'] = self.yaml.get('ch', ch) # input channels @@ -163,7 +158,7 @@ class DetectionModel(BaseModel): class SegmentationModel(DetectionModel): # YOLOv5 segmentation model - def __init__(self, cfg='yolov5s-seg.yaml', ch=3, nc=None, verbose=True): + def __init__(self, cfg='yolov8n-seg.yaml', ch=3, nc=None, verbose=True): super().__init__(cfg, ch, nc, verbose) diff --git a/ultralytics/yolo/cli.py b/ultralytics/yolo/cli.py index c0ffeec..3ca907b 100644 --- a/ultralytics/yolo/cli.py +++ b/ultralytics/yolo/cli.py @@ -1,43 +1,48 @@ -import os import shutil +from pathlib import Path import hydra import ultralytics -import ultralytics.yolo.v8 as yolo -from ultralytics.yolo.engine.trainer import DEFAULT_CONFIG +from ultralytics import yolo -from .utils import LOGGER, colorstr +from .utils import DEFAULT_CONFIG, LOGGER, colorstr -@hydra.main(version_base=None, config_path="utils/configs", 
config_name="default") +@hydra.main(version_base=None, config_path="configs", config_name="default") def cli(cfg): + cwd = Path().cwd() LOGGER.info(f"{colorstr(f'Ultralytics YOLO v{ultralytics.__version__}')}") task, mode = cfg.task.lower(), cfg.mode.lower() if task == "init": # special case - shutil.copy2(DEFAULT_CONFIG, os.getcwd()) + shutil.copy2(DEFAULT_CONFIG, cwd) LOGGER.info(f""" - {colorstr("YOLO :")} configuration saved to {os.getcwd()}/{DEFAULT_CONFIG.name}. + {colorstr("YOLO:")} configuration saved to {cwd / DEFAULT_CONFIG.name}. To run experiments using custom configuration: yolo task='task' mode='mode' --config-name config_file.yaml """) return + elif task == "detect": - module_file = yolo.detect + module = yolo.v8.detect elif task == "segment": - module_file = yolo.segment + module = yolo.v8.segment elif task == "classify": - module_file = yolo.classify + module = yolo.v8.classify + elif task == "export": + func = yolo.trainer.exporter.export_model else: raise SyntaxError("task not recognized. Choices are `'detect', 'segment', 'classify'`") if mode == "train": - module_function = module_file.train + func = module.train elif mode == "val": - module_function = module_file.val + func = module.val elif mode == "predict": - module_function = module_file.predict + func = module.predict + elif mode == "export": + func = yolo.trainer.exporter.export_model else: - raise SyntaxError("mode not recognized. Choices are `'train', 'val', 'predict'`") - module_function(cfg) + raise SyntaxError("mode not recognized. Choices are `'train', 'val', 'predict', 'export'`") + func(cfg) diff --git a/ultralytics/yolo/utils/configs/__init__.py b/ultralytics/yolo/configs/__init__.py similarity index 91% rename from ultralytics/yolo/utils/configs/__init__.py rename to ultralytics/yolo/configs/__init__.py index a1e6cf9..7786668 100644 --- a/ultralytics/yolo/utils/configs/__init__.py +++ b/ultralytics/yolo/configs/__init__.py @@ -3,7 +3,7 @@ from typing import Dict, Union from omegaconf import DictConfig, OmegaConf -from ultralytics.yolo.utils.configs.hydra_patch import check_config_mismatch +from ultralytics.yolo.configs.hydra_patch import check_config_mismatch def get_config(config: Union[str, DictConfig], overrides: Union[str, Dict] = {}): diff --git a/ultralytics/yolo/utils/configs/default.yaml b/ultralytics/yolo/configs/default.yaml similarity index 86% rename from ultralytics/yolo/utils/configs/default.yaml rename to ultralytics/yolo/configs/default.yaml index 3865870..9d24a09 100644 --- a/ultralytics/yolo/utils/configs/default.yaml +++ b/ultralytics/yolo/configs/default.yaml @@ -44,11 +44,11 @@ save_hybrid: False conf_thres: 0.001 iou_thres: 0.7 max_det: 300 -half: True +half: False dnn: False # use OpenCV DNN for ONNX inference plots: True -# Prediction settings: +# Prediction settings -------------------------------------------------------------------------------------------------- source: "ultralytics/assets/" view_img: False save_txt: False @@ -64,6 +64,15 @@ augment: False agnostic_nms: False # class-agnostic NMS retina_masks: False +# Export settings ------------------------------------------------------------------------------------------------------ +keras: False # use Keras +optimize: False # TorchScript: optimize for mobile +int8: False # CoreML/TF INT8 quantization +dynamic: False # ONNX/TF/TensorRT: dynamic axes +simplify: False # ONNX: simplify model +opset: 17 # ONNX: opset version +workspace: 4 # TensorRT: workspace size (GB) + # Hyperparameters 
------------------------------------------------------------------------------------------------------ lr0: 0.01 # initial learning rate (SGD=1E-2, Adam=1E-3) lrf: 0.01 # final OneCycleLR learning rate (lr0 * lrf) @@ -93,7 +102,7 @@ mixup: 0.0 # image mixup (probability) copy_paste: 0.0 # segment copy-paste (probability) # For debugging. Don't change -v5loader: True +v5loader: False # Hydra configs -------------------------------------------------------------------------------------------------------- hydra: diff --git a/ultralytics/yolo/utils/configs/hydra_patch.py b/ultralytics/yolo/configs/hydra_patch.py similarity index 89% rename from ultralytics/yolo/utils/configs/hydra_patch.py rename to ultralytics/yolo/configs/hydra_patch.py index d381697..18075a0 100644 --- a/ultralytics/yolo/utils/configs/hydra_patch.py +++ b/ultralytics/yolo/configs/hydra_patch.py @@ -4,8 +4,8 @@ from textwrap import dedent import hydra from hydra.errors import ConfigCompositionException -from omegaconf import OmegaConf, open_dict -from omegaconf.errors import ConfigAttributeError, ConfigKeyError, OmegaConfBaseException +from omegaconf import OmegaConf, open_dict # noqa +from omegaconf.errors import ConfigAttributeError, ConfigKeyError, OmegaConfBaseException # noqa from ultralytics.yolo.utils import LOGGER, colorstr @@ -16,8 +16,7 @@ def override_config(overrides, cfg): for override in overrides: if override.package is not None: raise ConfigCompositionException(f"Override {override.input_line} looks like a config group" - f" override, but config group '{override.key_or_group}' does not" - " exist.") + f" override, but config group '{override.key_or_group}' does not exist.") key = override.key_or_group value = override.value() @@ -37,7 +36,7 @@ def override_config(overrides, cfg): if last_dot == -1: del cfg[key] else: - node = OmegaConf.select(cfg, key[0:last_dot]) + node = OmegaConf.select(cfg, key[:last_dot]) del node[key[last_dot + 1:]] elif override.is_add(): @@ -65,10 +64,7 @@ def override_config(overrides, cfg): def check_config_mismatch(overrides, cfg): - mismatched = [] - for option in overrides: - if option not in cfg and 'hydra.' not in option: - mismatched.append(option) + mismatched = [option for option in overrides if option not in cfg and 'hydra.' not in option] for option in mismatched: LOGGER.info(f"{colorstr(option)} is not a valid key. Similar keys: {get_close_matches(option, cfg, 3, 0.6)}") diff --git a/ultralytics/yolo/data/utils.py b/ultralytics/yolo/data/utils.py index 7cbb059..1800054 100644 --- a/ultralytics/yolo/data/utils.py +++ b/ultralytics/yolo/data/utils.py @@ -192,7 +192,7 @@ def polygons2masks_overlap(imgsz, segments, downsample_ratio=1): def check_dataset_yaml(data, autodownload=True): # Download, check and/or unzip dataset if not found locally data = check_file(data) - DATASETS_DIR = Path.cwd() / "../datasets" # TODO: handle global dataset dir + DATASETS_DIR = (Path.cwd() / "../datasets").resolve() # TODO: handle global dataset dir # Download (optional) extract_dir = '' if isinstance(data, (str, Path)) and (is_zipfile(data) or is_tarfile(data)): diff --git a/ultralytics/yolo/engine/exporter.py b/ultralytics/yolo/engine/exporter.py index 16dbcd8..d5dd82a 100644 --- a/ultralytics/yolo/engine/exporter.py +++ b/ultralytics/yolo/engine/exporter.py @@ -1,4 +1,77 @@ +# YOLOv5 🚀 by Ultralytics, GPL-3.0 license +""" +Export a YOLOv5 PyTorch model to other formats. 
TensorFlow exports authored by https://github.com/zldrobit
+
+Format | `export.py --include` | Model
+--- | --- | ---
+PyTorch | - | yolov8n.pt
+TorchScript | `torchscript` | yolov8n.torchscript
+ONNX | `onnx` | yolov8n.onnx
+OpenVINO | `openvino` | yolov5s_openvino_model/
+TensorRT | `engine` | yolov8n.engine
+CoreML | `coreml` | yolov8n.mlmodel
+TensorFlow SavedModel | `saved_model` | yolov5s_saved_model/
+TensorFlow GraphDef | `pb` | yolov8n.pb
+TensorFlow Lite | `tflite` | yolov8n.tflite
+TensorFlow Edge TPU | `edgetpu` | yolov5s_edgetpu.tflite
+TensorFlow.js | `tfjs` | yolov5s_web_model/
+PaddlePaddle | `paddle` | yolov5s_paddle_model/
+
+Requirements:
+    $ pip install -r requirements.txt coremltools onnx onnx-simplifier onnxruntime openvino-dev tensorflow-cpu  # CPU
+    $ pip install -r requirements.txt coremltools onnx onnx-simplifier onnxruntime-gpu openvino-dev tensorflow  # GPU
+
+Usage:
+    $ python export.py --weights yolov8n.pt --include torchscript onnx openvino engine coreml tflite ...
+
+Inference:
+    $ python detect.py --weights yolov8n.pt           # PyTorch
+                                 yolov8n.torchscript  # TorchScript
+                                 yolov8n.onnx         # ONNX Runtime or OpenCV DNN with --dnn
+                                 yolov5s_openvino_model  # OpenVINO
+                                 yolov8n.engine       # TensorRT
+                                 yolov8n.mlmodel      # CoreML (macOS-only)
+                                 yolov5s_saved_model  # TensorFlow SavedModel
+                                 yolov8n.pb           # TensorFlow GraphDef
+                                 yolov8n.tflite       # TensorFlow Lite
+                                 yolov5s_edgetpu.tflite  # TensorFlow Edge TPU
+                                 yolov5s_paddle_model  # PaddlePaddle
+
+TensorFlow.js:
+    $ cd .. && git clone https://github.com/zldrobit/tfjs-yolov5-example.git && cd tfjs-yolov5-example
+    $ npm install
+    $ ln -s ../../yolov5/yolov5s_web_model public/yolov5s_web_model
+    $ npm start
+
+
+from ultralytics import YOLO
+model = YOLO().new('yolov8n.yaml')
+results = model.export(format='onnx')
+"""
+import contextlib
+import json
+import os
+import platform
+import re
+import subprocess
+import time
+import warnings
+from copy import deepcopy
+from pathlib import Path
+
 import pandas as pd
+import torch
+from torch.utils.mobile_optimizer import optimize_for_mobile
+
+from ultralytics.nn.modules import Detect, Segment
+from ultralytics.nn.tasks import ClassificationModel, DetectionModel, SegmentationModel
+from ultralytics.yolo.utils import LOGGER, ROOT, colorstr, get_default_args
+from ultralytics.yolo.utils.checks import check_imgsz, check_requirements, check_version
+from ultralytics.yolo.utils.files import file_size, yaml_save
+from ultralytics.yolo.utils.ops import Profile
+from ultralytics.yolo.utils.torch_utils import select_device, smart_inference_mode
+
+MACOS = platform.system() == 'Darwin'  # macOS environment

 def export_formats():
@@ -17,3 +90,519 @@
         ['TensorFlow.js', 'tfjs', '_web_model', False, False],
         ['PaddlePaddle', 'paddle', '_paddle_model', True, True],]
     return pd.DataFrame(x, columns=['Format', 'Argument', 'Suffix', 'CPU', 'GPU'])
+
+
+def try_export(inner_func):
+    # YOLOv5 export decorator, i.e. @try_export
+    inner_args = get_default_args(inner_func)
+
+    def outer_func(*args, **kwargs):
+        prefix = inner_args['prefix']
+        try:
+            with Profile() as dt:
+                f, model = inner_func(*args, **kwargs)
+            LOGGER.info(f'{prefix} export success ✅ {dt.t:.1f}s, saved as {f} ({file_size(f):.1f} MB)')
+            return f, model
+        except Exception as e:
+            LOGGER.info(f'{prefix} export failure ❌ {dt.t:.1f}s: {e}')
+            return None, None
+
+    return outer_func
+
+
+@try_export
+def export_torchscript(model, im, file, optimize, prefix=colorstr('TorchScript:')):
+    # YOLOv5 TorchScript model export
+
LOGGER.info(f'\n{prefix} starting export with torch {torch.__version__}...') + f = file.with_suffix('.torchscript') + + ts = torch.jit.trace(model, im, strict=False) + d = {"shape": im.shape, "stride": int(max(model.stride)), "names": model.names} + extra_files = {'config.txt': json.dumps(d)} # torch._C.ExtraFilesMap() + if optimize: # https://pytorch.org/tutorials/recipes/mobile_interpreter.html + optimize_for_mobile(ts)._save_for_lite_interpreter(str(f), _extra_files=extra_files) + else: + ts.save(str(f), _extra_files=extra_files) + return f, None + + +@try_export +def export_onnx(model, im, file, opset, dynamic, simplify, prefix=colorstr('ONNX:')): + # YOLOv5 ONNX export + check_requirements('onnx>=1.12.0') + import onnx # noqa + + LOGGER.info(f'\n{prefix} starting export with onnx {onnx.__version__}...') + f = file.with_suffix('.onnx') + + output_names = ['output0', 'output1'] if isinstance(model, SegmentationModel) else ['output0'] + if dynamic: + dynamic = {'images': {0: 'batch', 2: 'height', 3: 'width'}} # shape(1,3,640,640) + if isinstance(model, SegmentationModel): + dynamic['output0'] = {0: 'batch', 1: 'anchors'} # shape(1,25200,85) + dynamic['output1'] = {0: 'batch', 2: 'mask_height', 3: 'mask_width'} # shape(1,32,160,160) + elif isinstance(model, DetectionModel): + dynamic['output0'] = {0: 'batch', 1: 'anchors'} # shape(1,25200,85) + + torch.onnx.export( + model.cpu() if dynamic else model, # --dynamic only compatible with cpu + im.cpu() if dynamic else im, + f, + verbose=False, + opset_version=opset, + do_constant_folding=True, # WARNING: DNN inference with torch>=1.12 may require do_constant_folding=False + input_names=['images'], + output_names=output_names, + dynamic_axes=dynamic or None) + + # Checks + model_onnx = onnx.load(f) # load onnx model + onnx.checker.check_model(model_onnx) # check onnx model + + # Metadata + d = {'stride': int(max(model.stride)), 'names': model.names} + for k, v in d.items(): + meta = model_onnx.metadata_props.add() + meta.key, meta.value = k, str(v) + onnx.save(model_onnx, f) + + # Simplify + if simplify: + try: + cuda = torch.cuda.is_available() + check_requirements(('onnxruntime-gpu' if cuda else 'onnxruntime', 'onnx-simplifier>=0.4.1')) + import onnxsim + + LOGGER.info(f'{prefix} simplifying with onnx-simplifier {onnxsim.__version__}...') + model_onnx, check = onnxsim.simplify(model_onnx) + assert check, 'assert check failed' + onnx.save(model_onnx, f) + except Exception as e: + LOGGER.info(f'{prefix} simplifier failure: {e}') + return f, model_onnx + + +@try_export +def export_openvino(file, metadata, half, prefix=colorstr('OpenVINO:')): + # YOLOv5 OpenVINO export + check_requirements('openvino-dev') # requires openvino-dev: https://pypi.org/project/openvino-dev/ + import openvino.inference_engine as ie # noqa + + LOGGER.info(f'\n{prefix} starting export with openvino {ie.__version__}...') + f = str(file).replace('.pt', f'_openvino_model{os.sep}') + + cmd = f"mo --input_model {file.with_suffix('.onnx')} --output_dir {f} --data_type {'FP16' if half else 'FP32'}" + subprocess.run(cmd.split(), check=True, env=os.environ) # export + yaml_save(Path(f) / file.with_suffix('.yaml').name, metadata) # add metadata.yaml + return f, None + + +@try_export +def export_paddle(model, im, file, metadata, prefix=colorstr('PaddlePaddle:')): + # YOLOv5 Paddle export + check_requirements(('paddlepaddle', 'x2paddle')) + import x2paddle # noqa + from x2paddle.convert import pytorch2paddle # noqa + + LOGGER.info(f'\n{prefix} starting export with X2Paddle 
{x2paddle.__version__}...') + f = str(file).replace('.pt', f'_paddle_model{os.sep}') + + pytorch2paddle(module=model, save_dir=f, jit_type='trace', input_examples=[im]) # export + yaml_save(Path(f) / file.with_suffix('.yaml').name, metadata) # add metadata.yaml + return f, None + + +@try_export +def export_coreml(model, im, file, int8, half, prefix=colorstr('CoreML:')): + # YOLOv5 CoreML export + check_requirements('coremltools') + import coremltools as ct # noqa + + LOGGER.info(f'\n{prefix} starting export with coremltools {ct.__version__}...') + f = file.with_suffix('.mlmodel') + + ts = torch.jit.trace(model, im, strict=False) # TorchScript model + ct_model = ct.convert(ts, inputs=[ct.ImageType('image', shape=im.shape, scale=1 / 255, bias=[0, 0, 0])]) + bits, mode = (8, 'kmeans_lut') if int8 else (16, 'linear') if half else (32, None) + if bits < 32: + if MACOS: # quantization only supported on macOS + ct_model = ct.models.neural_network.quantization_utils.quantize_weights(ct_model, bits, mode) + else: + LOGGER.info(f'{prefix} quantization only supported on macOS, skipping...') + ct_model.save(f) + return f, ct_model + + +@try_export +def export_engine(model, im, file, half, dynamic, simplify, workspace=4, verbose=False, prefix=colorstr('TensorRT:')): + # YOLOv5 TensorRT export https://developer.nvidia.com/tensorrt + assert im.device.type != 'cpu', 'export running on CPU but must be on GPU, i.e. `python export.py --device 0`' + try: + import tensorrt as trt + except Exception: + if platform.system() == 'Linux': + check_requirements('nvidia-tensorrt', cmds='-U --index-url https://pypi.ngc.nvidia.com') + import tensorrt as trt + + if trt.__version__[0] == '7': # TensorRT 7 handling https://github.com/ultralytics/yolov5/issues/6012 + grid = model.model[-1].anchor_grid + model.model[-1].anchor_grid = [a[..., :1, :1, :] for a in grid] + export_onnx(model, im, file, 12, dynamic, simplify) # opset 12 + model.model[-1].anchor_grid = grid + else: # TensorRT >= 8 + check_version(trt.__version__, '8.0.0', hard=True) # require tensorrt>=8.0.0 + export_onnx(model, im, file, 12, dynamic, simplify) # opset 12 + onnx = file.with_suffix('.onnx') + + LOGGER.info(f'\n{prefix} starting export with TensorRT {trt.__version__}...') + assert onnx.exists(), f'failed to export ONNX file: {onnx}' + f = file.with_suffix('.engine') # TensorRT engine file + logger = trt.Logger(trt.Logger.INFO) + if verbose: + logger.min_severity = trt.Logger.Severity.VERBOSE + + builder = trt.Builder(logger) + config = builder.create_builder_config() + config.max_workspace_size = workspace * 1 << 30 + # config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, workspace << 30) # fix TRT 8.4 deprecation notice + + flag = (1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)) + network = builder.create_network(flag) + parser = trt.OnnxParser(network, logger) + if not parser.parse_from_file(str(onnx)): + raise RuntimeError(f'failed to load ONNX file: {onnx}') + + inputs = [network.get_input(i) for i in range(network.num_inputs)] + outputs = [network.get_output(i) for i in range(network.num_outputs)] + for inp in inputs: + LOGGER.info(f'{prefix} input "{inp.name}" with shape{inp.shape} {inp.dtype}') + for out in outputs: + LOGGER.info(f'{prefix} output "{out.name}" with shape{out.shape} {out.dtype}') + + if dynamic: + if im.shape[0] <= 1: + LOGGER.warning(f"{prefix} WARNING ⚠️ --dynamic model requires maximum --batch-size argument") + profile = builder.create_optimization_profile() + for inp in inputs: + 
profile.set_shape(inp.name, (1, *im.shape[1:]), (max(1, im.shape[0] // 2), *im.shape[1:]), im.shape) + config.add_optimization_profile(profile) + + LOGGER.info(f'{prefix} building FP{16 if builder.platform_has_fast_fp16 and half else 32} engine as {f}') + if builder.platform_has_fast_fp16 and half: + config.set_flag(trt.BuilderFlag.FP16) + with builder.build_engine(network, config) as engine, open(f, 'wb') as t: + t.write(engine.serialize()) + return f, None + + +@try_export +def export_saved_model(model, + im, + file, + dynamic, + tf_nms=False, + agnostic_nms=False, + topk_per_class=100, + topk_all=100, + iou_thres=0.45, + conf_thres=0.25, + keras=False, + prefix=colorstr('TensorFlow SavedModel:')): + # YOLOv5 TensorFlow SavedModel export + try: + import tensorflow as tf + except Exception: + check_requirements(f"tensorflow{'' if torch.cuda.is_available() else '-macos' if MACOS else '-cpu'}") + import tensorflow as tf + from models.tf import TFModel + from tensorflow.python.framework.convert_to_constants import convert_variables_to_constants_v2 # noqa + + LOGGER.info(f'\n{prefix} starting export with tensorflow {tf.__version__}...') + f = str(file).replace('.pt', '_saved_model') + batch_size, ch, *imgsz = list(im.shape) # BCHW + + tf_model = TFModel(cfg=model.yaml, model=model, nc=model.nc, imgsz=imgsz) + im = tf.zeros((batch_size, *imgsz, ch)) # BHWC order for TensorFlow + _ = tf_model.predict(im, tf_nms, agnostic_nms, topk_per_class, topk_all, iou_thres, conf_thres) + inputs = tf.keras.Input(shape=(*imgsz, ch), batch_size=None if dynamic else batch_size) + outputs = tf_model.predict(inputs, tf_nms, agnostic_nms, topk_per_class, topk_all, iou_thres, conf_thres) + keras_model = tf.keras.Model(inputs=inputs, outputs=outputs) + keras_model.trainable = False + keras_model.summary() + if keras: + keras_model.save(f, save_format='tf') + else: + spec = tf.TensorSpec(keras_model.inputs[0].shape, keras_model.inputs[0].dtype) + m = tf.function(lambda x: keras_model(x)) # full model + m = m.get_concrete_function(spec) + frozen_func = convert_variables_to_constants_v2(m) + tfm = tf.Module() + tfm.__call__ = tf.function(lambda x: frozen_func(x)[:4] if tf_nms else frozen_func(x), [spec]) + tfm.__call__(im) + tf.saved_model.save(tfm, + f, + options=tf.saved_model.SaveOptions(experimental_custom_gradients=False) if check_version( + tf.__version__, '2.6') else tf.saved_model.SaveOptions()) + return f, keras_model + + +@try_export +def export_pb(keras_model, file, prefix=colorstr('TensorFlow GraphDef:')): + # YOLOv5 TensorFlow GraphDef *.pb export https://github.com/leimao/Frozen_Graph_TensorFlow + import tensorflow as tf # noqa + from tensorflow.python.framework.convert_to_constants import convert_variables_to_constants_v2 # noqa + + LOGGER.info(f'\n{prefix} starting export with tensorflow {tf.__version__}...') + f = file.with_suffix('.pb') + + m = tf.function(lambda x: keras_model(x)) # full model + m = m.get_concrete_function(tf.TensorSpec(keras_model.inputs[0].shape, keras_model.inputs[0].dtype)) + frozen_func = convert_variables_to_constants_v2(m) + frozen_func.graph.as_graph_def() + tf.io.write_graph(graph_or_graph_def=frozen_func.graph, logdir=str(f.parent), name=f.name, as_text=False) + return f, None + + +@try_export +def export_tflite(keras_model, im, file, int8, data, nms, agnostic_nms, prefix=colorstr('TensorFlow Lite:')): + # YOLOv5 TensorFlow Lite export + import tensorflow as tf # noqa + + LOGGER.info(f'\n{prefix} starting export with tensorflow {tf.__version__}...') + batch_size, ch, 
*imgsz = list(im.shape) # BCHW + f = str(file).replace('.pt', '-fp16.tflite') + + converter = tf.lite.TFLiteConverter.from_keras_model(keras_model) + converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS] + converter.target_spec.supported_types = [tf.float16] + converter.optimizations = [tf.lite.Optimize.DEFAULT] + if int8: + # from models.tf import representative_dataset_gen + # dataset = LoadImages(check_dataset(check_yaml(data))['train'], imgsz=imgsz, auto=False) + # converter.representative_dataset = lambda: representative_dataset_gen(dataset, ncalib=100) + converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8] + converter.target_spec.supported_types = [] + converter.inference_input_type = tf.uint8 # or tf.int8 + converter.inference_output_type = tf.uint8 # or tf.int8 + converter.experimental_new_quantizer = True + f = str(file).replace('.pt', '-int8.tflite') + if nms or agnostic_nms: + converter.target_spec.supported_ops.append(tf.lite.OpsSet.SELECT_TF_OPS) + + tflite_model = converter.convert() + open(f, "wb").write(tflite_model) + return f, None + + +@try_export +def export_edgetpu(file, prefix=colorstr('Edge TPU:')): + # YOLOv5 Edge TPU export https://coral.ai/docs/edgetpu/models-intro/ + cmd = 'edgetpu_compiler --version' + help_url = 'https://coral.ai/docs/edgetpu/compiler/' + assert platform.system() == 'Linux', f'export only supported on Linux. See {help_url}' + if subprocess.run(f'{cmd} >/dev/null', shell=True).returncode != 0: + LOGGER.info(f'\n{prefix} export requires Edge TPU compiler. Attempting install from {help_url}') + sudo = subprocess.run('sudo --version >/dev/null', shell=True).returncode == 0 # sudo installed on system + for c in ( + 'curl https://packages.cloud.google.com/apt/doc/apt-key.gpg | sudo apt-key add -', + 'echo "deb https://packages.cloud.google.com/apt coral-edgetpu-stable main" | sudo tee /etc/apt/sources.list.d/coral-edgetpu.list', + 'sudo apt-get update', 'sudo apt-get install edgetpu-compiler'): + subprocess.run(c if sudo else c.replace('sudo ', ''), shell=True, check=True) + ver = subprocess.run(cmd, shell=True, capture_output=True, check=True).stdout.decode().split()[-1] + + LOGGER.info(f'\n{prefix} starting export with Edge TPU compiler {ver}...') + f = str(file).replace('.pt', '-int8_edgetpu.tflite') # Edge TPU model + f_tfl = str(file).replace('.pt', '-int8.tflite') # TFLite model + + cmd = f"edgetpu_compiler -s -d -k 10 --out_dir {file.parent} {f_tfl}" + subprocess.run(cmd.split(), check=True) + return f, None + + +@try_export +def export_tfjs(file, prefix=colorstr('TensorFlow.js:')): + # YOLOv5 TensorFlow.js export + check_requirements('tensorflowjs') + import tensorflowjs as tfjs # noqa + + LOGGER.info(f'\n{prefix} starting export with tensorflowjs {tfjs.__version__}...') + f = str(file).replace('.pt', '_web_model') # js dir + f_pb = file.with_suffix('.pb') # *.pb path + f_json = f'{f}/model.json' # *.json path + + cmd = f'tensorflowjs_converter --input_format=tf_frozen_model ' \ + f'--output_node_names=Identity,Identity_1,Identity_2,Identity_3 {f_pb} {f}' + subprocess.run(cmd.split()) + + json = Path(f_json).read_text() + with open(f_json, 'w') as j: # sort JSON Identity_* in ascending order + subst = re.sub( + r'{"outputs": {"Identity.?.?": {"name": "Identity.?.?"}, ' + r'"Identity.?.?": {"name": "Identity.?.?"}, ' + r'"Identity.?.?": {"name": "Identity.?.?"}, ' + r'"Identity.?.?": {"name": "Identity.?.?"}}}', r'{"outputs": {"Identity": {"name": "Identity"}, ' + r'"Identity_1": {"name": "Identity_1"}, 
' + r'"Identity_2": {"name": "Identity_2"}, ' + r'"Identity_3": {"name": "Identity_3"}}}', json) + j.write(subst) + return f, None + + +def add_tflite_metadata(file, metadata, num_outputs): + # Add metadata to *.tflite models per https://www.tensorflow.org/lite/models/convert/metadata + with contextlib.suppress(ImportError): + # check_requirements('tflite_support') + from tflite_support import flatbuffers # noqa + from tflite_support import metadata as _metadata # noqa + from tflite_support import metadata_schema_py_generated as _metadata_fb # noqa + + tmp_file = Path('/tmp/meta.txt') + with open(tmp_file, 'w') as meta_f: + meta_f.write(str(metadata)) + + model_meta = _metadata_fb.ModelMetadataT() + label_file = _metadata_fb.AssociatedFileT() + label_file.name = tmp_file.name + model_meta.associatedFiles = [label_file] + + subgraph = _metadata_fb.SubGraphMetadataT() + subgraph.inputTensorMetadata = [_metadata_fb.TensorMetadataT()] + subgraph.outputTensorMetadata = [_metadata_fb.TensorMetadataT()] * num_outputs + model_meta.subgraphMetadata = [subgraph] + + b = flatbuffers.Builder(0) + b.Finish(model_meta.Pack(b), _metadata.MetadataPopulator.METADATA_FILE_IDENTIFIER) + metadata_buf = b.Output() + + populator = _metadata.MetadataPopulator.with_model_file(file) + populator.load_metadata_buffer(metadata_buf) + populator.load_associated_files([str(tmp_file)]) + populator.populate() + tmp_file.unlink() + + +@smart_inference_mode() +def export_model( + model, # model + file=ROOT / 'yolov8n.pt', + data=ROOT / 'data/coco128.yaml', # 'dataset.yaml path' + imgsz=(640, 640), # image (height, width) + batch_size=1, # batch size + device=torch.device('cpu'), # cuda device, i.e. 0 or 0,1,2,3 or cpu + format='onnx', # export format + half=False, # FP16 half-precision export + keras=False, # use Keras + optimize=False, # TorchScript: optimize for mobile + int8=False, # CoreML/TF INT8 quantization + dynamic=False, # ONNX/TF/TensorRT: dynamic axes + simplify=False, # ONNX: simplify model + opset=17, # ONNX: opset version + verbose=False, # TensorRT: verbose log + workspace=4, # TensorRT: workspace size (GB) + nms=False, # TF: add NMS to model + agnostic_nms=False, # TF: add agnostic NMS to model + topk_per_class=100, # TF.js NMS: topk per class to keep + topk_all=100, # TF.js NMS: topk for all classes to keep + iou_thres=0.45, # TF.js NMS: IoU threshold + conf_thres=0.25, # TF.js NMS: confidence threshold +): + t = time.time() + format = format.lower() # to lowercase + fmts = tuple(export_formats()['Argument'][1:]) # available export formats + flags = [x == format for x in fmts] + assert sum(flags), f'ERROR: Invalid format={format}, valid formats are {fmts}' + jit, onnx, xml, engine, coreml, saved_model, pb, tflite, edgetpu, tfjs, paddle = flags # export booleans + + # Load PyTorch model + device = select_device(device) + if half: + assert device.type != 'cpu' or coreml, '--half only compatible with GPU export, i.e. use --device 0' + assert not dynamic, '--half not compatible with --dynamic, i.e. use either --half or --dynamic but not both' + model = deepcopy(model).fuse() # load FP32 model + + # Checks + if isinstance(imgsz, int): + imgsz = [imgsz] + imgsz *= 2 if len(imgsz) == 1 else 1 # expand + if optimize: + assert device.type == 'cpu', '--optimize not compatible with cuda devices, i.e. 
use --device cpu' + + # Input + gs = int(max(model.stride)) # grid size (max stride) + imgsz = [check_imgsz(x, gs) for x in imgsz] # verify img_size are gs-multiples + im = torch.zeros(batch_size, 3, *imgsz).to(device) # image size(1,3,320,192) BCHW iDetection + + # Update model + model.eval() + for k, m in model.named_modules(): + if isinstance(m, (Detect, Segment)): + m.dynamic = dynamic + m.export = True + + for _ in range(2): + y = model(im) # dry runs + if half and not coreml: + im, model = im.half(), model.half() # to FP16 + shape = tuple((y[0] if isinstance(y, tuple) else y).shape) # model output shape + metadata = {'stride': int(max(model.stride)), 'names': model.names} # model metadata + LOGGER.info(f"\n{colorstr('PyTorch:')} starting from {file} with output shape {shape} ({file_size(file):.1f} MB)") + + # Warnings + warnings.filterwarnings('ignore', category=torch.jit.TracerWarning) # suppress TracerWarning + warnings.filterwarnings('ignore', category=UserWarning) # suppress shape prim::Constant type missing ONNX warning + warnings.filterwarnings('ignore', category=DeprecationWarning) # suppress CoreML np.bool deprecation warning + + # Exports + f = [''] * len(fmts) # exported filenames + if jit: # TorchScript + f[0], _ = export_torchscript(model, im, file, optimize) + if engine: # TensorRT required before ONNX + f[1], _ = export_engine(model, im, file, half, dynamic, simplify, workspace, verbose) + if onnx or xml: # OpenVINO requires ONNX + f[2], _ = export_onnx(model, im, file, opset, dynamic, simplify) + if xml: # OpenVINO + f[3], _ = export_openvino(file, metadata, half) + if coreml: # CoreML + f[4], _ = export_coreml(model, im, file, int8, half) + if any((saved_model, pb, tflite, edgetpu, tfjs)): # TensorFlow formats + assert not tflite or not tfjs, 'TFLite and TF.js models must be exported separately, please pass only one type.' + assert not isinstance(model, ClassificationModel), 'ClassificationModel export to TF formats not yet supported.' 
+ f[5], s_model = export_saved_model(model.cpu(), + im, + file, + dynamic, + tf_nms=nms or agnostic_nms or tfjs, + agnostic_nms=agnostic_nms or tfjs, + topk_per_class=topk_per_class, + topk_all=topk_all, + iou_thres=iou_thres, + conf_thres=conf_thres, + keras=keras) + if pb or tfjs: # pb prerequisite to tfjs + f[6], _ = export_pb(s_model, file) + if tflite or edgetpu: + f[7], _ = export_tflite(s_model, im, file, int8 or edgetpu, data=data, nms=nms, agnostic_nms=agnostic_nms) + if edgetpu: + f[8], _ = export_edgetpu(file) + add_tflite_metadata(f[8] or f[7], metadata, num_outputs=len(s_model.outputs)) + if tfjs: + f[9], _ = export_tfjs(file) + if paddle: # PaddlePaddle + f[10], _ = export_paddle(model, im, file, metadata) + + # Finish + f = [str(x) for x in f if x] # filter out '' and None + if any(f): + cls, det, seg = (isinstance(model, x) for x in (ClassificationModel, DetectionModel, SegmentationModel)) # type + det &= not seg # segmentation models inherit from SegmentationModel(DetectionModel) + dir = Path('segment' if seg else 'classify' if cls else '') + h = '--half' if half else '' # --half FP16 inference arg + s = "# WARNING ⚠️ ClassificationModel not yet supported for PyTorch Hub AutoShape inference" if cls else \ + "# WARNING ⚠️ SegmentationModel not yet supported for PyTorch Hub AutoShape inference" if seg else '' + LOGGER.info(f'\nExport complete ({time.time() - t:.1f}s)' + f"\nResults saved to {colorstr('bold', file.parent.resolve())}" + f"\nDetect: python {dir / 'predict.py'} --weights {f[-1]} {h}" + f"\nValidate: python {dir / 'val.py'} --weights {f[-1]} {h}" + f"\nPyTorch Hub: model = torch.hub.load('ultralytics/yolov5', 'custom', '{f[-1]}') {s}" + f"\nVisualize: https://netron.app") + return f # return list of exported files/dirs diff --git a/ultralytics/yolo/engine/model.py b/ultralytics/yolo/engine/model.py index 7dfedf9..05b2d5b 100644 --- a/ultralytics/yolo/engine/model.py +++ b/ultralytics/yolo/engine/model.py @@ -1,13 +1,13 @@ +from pathlib import Path + import torch -import yaml from ultralytics import yolo # noqa required for python usage from ultralytics.nn.tasks import ClassificationModel, DetectionModel, SegmentationModel, attempt_load_weights -# from ultralytics.yolo.data.utils import check_dataset, check_dataset_yaml -from ultralytics.yolo.engine.trainer import DEFAULT_CONFIG -from ultralytics.yolo.utils import HELP_MSG, LOGGER +from ultralytics.yolo.configs import get_config +from ultralytics.yolo.engine.exporter import export_model +from ultralytics.yolo.utils import DEFAULT_CONFIG, HELP_MSG, LOGGER from ultralytics.yolo.utils.checks import check_yaml -from ultralytics.yolo.utils.configs import get_config from ultralytics.yolo.utils.files import yaml_load from ultralytics.yolo.utils.torch_utils import smart_inference_mode @@ -36,7 +36,7 @@ class YOLO: type (str): Type/version of models to use """ if init_key != YOLO.__init_key: - raise Exception(HELP_MSG) + raise SyntaxError(HELP_MSG) self.type = type self.ModelClass = None @@ -46,7 +46,8 @@ class YOLO: self.model = None self.trainer = None self.task = None - self.ckpt = None + self.ckpt = None # if loaded from *.pt + self.cfg = None # if loaded from *.yaml self.overrides = {} self.init_disabled = False @@ -59,12 +60,12 @@ class YOLO: cfg (str): model configuration file """ cfg = check_yaml(cfg) # check YAML - with open(cfg, encoding='ascii', errors='ignore') as f: - cfg = yaml.safe_load(f) # model dict + cfg_dict = yaml_load(cfg) # model dict obj = cls(init_key=cls.__init_key) - obj.task = 
obj._guess_task_from_head(cfg["head"][-1][-2])
+        obj.task = obj._guess_task_from_head(cfg_dict["head"][-1][-2])
         obj.ModelClass, obj.TrainerClass, obj.ValidatorClass, obj.PredictorClass = obj._guess_ops_from_task(obj.task)
-        obj.model = obj.ModelClass(cfg)  # initialize
+        obj.model = obj.ModelClass(cfg_dict)  # initialize
+        obj.cfg = cfg

         return obj

@@ -116,13 +117,14 @@
             LOGGER.info("model not initialized!")
         self.model.fuse()

+    @smart_inference_mode()
     def predict(self, source, **kwargs):
         """
-        Visualize prection.
+        Visualize prediction.

         Args:
             source (str): Accepts all source types accepted by yolo
-            **kwargs : Any other args accepted by the predictors. Too see all args check 'configuration' section in the docs
+            **kwargs : Any other args accepted by the predictors. To see all args check 'configuration' section in the docs
         """
         overrides = self.overrides.copy()
         overrides.update(kwargs)
@@ -131,7 +133,7 @@
         # check size type
         sz = predictor.args.imgsz
-        if type(sz) != int:  # recieved listConfig
+        if type(sz) != int:  # received listConfig
             predictor.args.imgsz = [sz[0], sz[0]] if len(sz) == 1 else [sz[0], sz[1]]  # expand
         else:
             predictor.args.imgsz = [sz, sz]
@@ -139,16 +141,17 @@
         predictor.setup(model=self.model, source=source)
         predictor()

+    @smart_inference_mode()
     def val(self, data=None, **kwargs):
         """
         Validate a model on a given dataset

         Args:
             data (str): The dataset to validate on. Accepts all formats accepted by yolo
-            kwargs: Any other args accepted by the validators. Too see all args check 'configuration' section in the docs
+            kwargs: Any other args accepted by the validators. To see all args check 'configuration' section in the docs
         """
         if not self.model:
-            raise Exception("model not initialized!")
+            raise ModuleNotFoundError("model not initialized!")

         overrides = self.overrides.copy()
         overrides.update(kwargs)
@@ -160,6 +163,51 @@
         validator = self.ValidatorClass(args=args)
         validator(model=self.model)

+    @smart_inference_mode()
+    def export(self, format='', save_dir='', **kwargs):
+        """
+        Export model.
+
+        Args:
+            format (str): Export format
+            **kwargs : Any other args accepted by the exporter. To see all args check 'configuration' section in the docs
+        """
+
+        overrides = self.overrides.copy()
+        overrides.update(kwargs)
+        args = get_config(config=DEFAULT_CONFIG, overrides=overrides)
+        args.task = self.task
+        args.format = format
+
+        file = self.ckpt or Path(Path(self.cfg).name)
+        if save_dir:
+            file = Path(save_dir) / file.name
+            file.parent.mkdir(parents=True, exist_ok=True)
+
+        export_model(
+            model=self.model,
+            file=file,
+            data=args.data,  # 'dataset.yaml path'
+            imgsz=args.imgsz or (640, 640),  # image (height, width)
+            batch_size=1,  # batch size
+            device=args.device,  # cuda device, i.e.
0 or 0,1,2,3 or cpu + format=args.format, # include formats + half=args.half or False, # FP16 half-precision export + keras=args.keras or False, # use Keras + optimize=args.optimize or False, # TorchScript: optimize for mobile + int8=args.int8 or False, # CoreML/TF INT8 quantization + dynamic=args.dynamic or False, # ONNX/TF/TensorRT: dynamic axes + opset=args.opset or 17, # ONNX: opset version + verbose=False, # TensorRT: verbose log + workspace=args.workspace or 4, # TensorRT: workspace size (GB) + nms=False, # TF: add NMS to model + agnostic_nms=False, # TF: add agnostic NMS to model + topk_per_class=100, # TF.js NMS: topk per class to keep + topk_all=100, # TF.js NMS: topk for all classes to keep + iou_thres=0.45, # TF.js NMS: IoU threshold + conf_thres=0.25, # TF.js NMS: confidence threshold + ) + def train(self, **kwargs): """ Trains the model on given dataset. @@ -178,7 +226,7 @@ class YOLO: overrides["task"] = self.task overrides["mode"] = "train" if not overrides.get("data"): - raise AttributeError("dataset not provided! Please check if you have defined `data` in you configs") + raise AttributeError("dataset not provided! Please define `data` in config.yaml or pass as an argument.") self.trainer = self.TrainerClass(overrides=overrides) self.trainer.model = self.trainer.load_model(weights=self.ckpt, @@ -189,11 +237,11 @@ class YOLO: def resume(self, task=None, model=None): """ - Resume a training task. Requires either `task` or `model`. `model` takes the higher precederence. + Resume a training task. Requires either `task` or `model`. `model` takes the higher precedence. Args: task (str): The task type you want to resume. Automatically finds the last run to resume if `model` is not specified. model (str): The model checkpoint to resume from. If not found, the last run of the given task type is resumed. - If `model` is speficied + If `model` is specified """ if task: if task.lower() not in MODEL_MAP: diff --git a/ultralytics/yolo/engine/predictor.py b/ultralytics/yolo/engine/predictor.py index bcd939f..641d05f 100644 --- a/ultralytics/yolo/engine/predictor.py +++ b/ultralytics/yolo/engine/predictor.py @@ -1,6 +1,6 @@ # predictor engine by Ultralytics """ -Run prection on images, videos, directories, globs, YouTube, webcam, streams, etc. +Run prediction on images, videos, directories, globs, YouTube, webcam, streams, etc. Usage - sources: $ yolo task=... mode=predict model=s.pt --source 0 # webcam img.jpg # image @@ -13,15 +13,15 @@ Usage - sources: 'https://youtu.be/Zgi9g1ksQHc' # YouTube 'rtsp://example.com/media.mp4' # RTSP, RTMP, HTTP stream Usage - formats: - $ yolo task=... mode=predict --weights yolov5s.pt # PyTorch - yolov5s.torchscript # TorchScript - yolov5s.onnx # ONNX Runtime or OpenCV DNN with --dnn + $ yolo task=... 
mode=predict --weights yolov8n.pt # PyTorch + yolov8n.torchscript # TorchScript + yolov8n.onnx # ONNX Runtime or OpenCV DNN with --dnn yolov5s_openvino_model # OpenVINO - yolov5s.engine # TensorRT - yolov5s.mlmodel # CoreML (macOS-only) + yolov8n.engine # TensorRT + yolov8n.mlmodel # CoreML (macOS-only) yolov5s_saved_model # TensorFlow SavedModel - yolov5s.pb # TensorFlow GraphDef - yolov5s.tflite # TensorFlow Lite + yolov8n.pb # TensorFlow GraphDef + yolov8n.tflite # TensorFlow Lite yolov5s_edgetpu.tflite # TensorFlow Edge TPU yolov5s_paddle_model # PaddlePaddle """ @@ -31,16 +31,14 @@ from pathlib import Path import cv2 from ultralytics.nn.autobackend import AutoBackend +from ultralytics.yolo.configs import get_config from ultralytics.yolo.data.dataloaders.stream_loaders import LoadImages, LoadScreenshots, LoadStreams -from ultralytics.yolo.data.utils import IMG_FORMATS, VID_FORMATS, check_dataset, check_dataset_yaml -from ultralytics.yolo.utils import LOGGER, ROOT, colorstr, ops +from ultralytics.yolo.data.utils import IMG_FORMATS, VID_FORMATS +from ultralytics.yolo.utils import DEFAULT_CONFIG, LOGGER, colorstr, ops from ultralytics.yolo.utils.checks import check_file, check_imshow -from ultralytics.yolo.utils.configs import get_config from ultralytics.yolo.utils.files import increment_path from ultralytics.yolo.utils.torch_utils import check_imgsz, select_device, smart_inference_mode -DEFAULT_CONFIG = ROOT / "yolo/utils/configs/default.yaml" - class BasePredictor: diff --git a/ultralytics/yolo/engine/trainer.py b/ultralytics/yolo/engine/trainer.py index 71c2e20..bdb589b 100644 --- a/ultralytics/yolo/engine/trainer.py +++ b/ultralytics/yolo/engine/trainer.py @@ -23,16 +23,14 @@ from tqdm import tqdm import ultralytics.yolo.utils as utils import ultralytics.yolo.utils.callbacks as callbacks from ultralytics import __version__ +from ultralytics.yolo.configs import get_config from ultralytics.yolo.data.utils import check_dataset, check_dataset_yaml -from ultralytics.yolo.utils import LOGGER, RANK, ROOT, TQDM_BAR_FORMAT, colorstr +from ultralytics.yolo.utils import DEFAULT_CONFIG, LOGGER, RANK, TQDM_BAR_FORMAT, colorstr from ultralytics.yolo.utils.checks import check_file, print_args -from ultralytics.yolo.utils.configs import get_config from ultralytics.yolo.utils.dist import ddp_cleanup, generate_ddp_command -from ultralytics.yolo.utils.files import get_latest_run, increment_path, save_yaml +from ultralytics.yolo.utils.files import get_latest_run, increment_path, yaml_save from ultralytics.yolo.utils.torch_utils import ModelEMA, de_parallel, init_seeds, one_cycle, strip_optimizer -DEFAULT_CONFIG = ROOT / "yolo/utils/configs/default.yaml" - class BaseTrainer: @@ -53,8 +51,7 @@ class BaseTrainer: self.wdir = self.save_dir / 'weights' # weights dir if RANK in {-1, 0}: self.wdir.mkdir(parents=True, exist_ok=True) # make dir - # Save run settings - save_yaml(self.save_dir / 'args.yaml', OmegaConf.to_container(self.args, resolve=True)) + yaml_save(self.save_dir / 'args.yaml', OmegaConf.to_container(self.args, resolve=True)) # save run args self.last, self.best = self.wdir / 'last.pt', self.wdir / 'best.pt' # checkpoint paths self.batch_size = self.args.batch_size @@ -452,8 +449,9 @@ class BaseTrainer: self.ema.ema.load_state_dict(ckpt['ema'].float().state_dict()) # EMA self.ema.updates = ckpt['updates'] if self.args.resume: - assert start_epoch > 0, f'{self.args.model} training to {self.epochs} epochs is finished, nothing to resume.\n' \ - f"Start a new training without --resume, i.e. 
'yolo task=... mode=train model={self.args.model}'" + assert start_epoch > 0, \ + f'{self.args.model} training to {self.epochs} epochs is finished, nothing to resume.\n' \ + f"Start a new training without --resume, i.e. 'yolo task=... mode=train model={self.args.model}'" LOGGER.info( f'Resuming training from {self.args.model} from epoch {start_epoch} to {self.epochs} total epochs') if self.epochs < start_epoch: diff --git a/ultralytics/yolo/engine/validator.py b/ultralytics/yolo/engine/validator.py index d0b001a..3be5587 100644 --- a/ultralytics/yolo/engine/validator.py +++ b/ultralytics/yolo/engine/validator.py @@ -66,7 +66,7 @@ class BaseValidator: self.args.batch_size = model.batch_size else: self.device = model.device - if not (pt or jit): + if not pt and not jit: self.args.batch_size = 1 # export.py models default to batch-size 1 self.logger.info( f'Forcing --batch-size 1 square inference (1,3,{imgsz},{imgsz}) for non-PyTorch models') @@ -75,8 +75,8 @@ class BaseValidator: data = check_dataset_yaml(self.args.data) else: data = check_dataset(self.args.data) - self.dataloader = self.get_dataloader(data.get("val") or data.set("test"), - self.args.batch_size) if not self.dataloader else self.dataloader + self.dataloader = self.dataloader or \ + self.get_dataloader(data.get("val") or data.set("test"), self.args.batch_size) model.eval() @@ -139,7 +139,7 @@ class BaseValidator: def postprocess(self, preds): return preds - def init_metrics(self): + def init_metrics(self, model): pass def update_metrics(self, preds, batch): diff --git a/ultralytics/yolo/utils/__init__.py b/ultralytics/yolo/utils/__init__.py index 76fc0c6..ecbdfe6 100644 --- a/ultralytics/yolo/utils/__init__.py +++ b/ultralytics/yolo/utils/__init__.py @@ -1,4 +1,5 @@ import contextlib +import inspect import logging.config import os import platform @@ -13,6 +14,7 @@ import pandas as pd # Constants FILE = Path(__file__).resolve() ROOT = FILE.parents[2] # YOLO +DEFAULT_CONFIG = ROOT / "yolo/configs/default.yaml" RANK = int(os.getenv('RANK', -1)) DATASETS_DIR = Path(os.getenv('YOLOv5_DATASETS_DIR', ROOT.parent / 'datasets')) # global datasets directory NUM_THREADS = min(8, max(1, os.cpu_count() - 1)) # number of YOLOv5 multiprocessing threads @@ -98,6 +100,12 @@ def is_writeable(dir, test=False): return False +def get_default_args(func): + # Get func() default arguments + signature = inspect.signature(func) + return {k: v.default for k, v in signature.parameters.items() if v.default is not inspect.Parameter.empty} + + def user_config_dir(dir='Ultralytics', env_var='YOLOV5_CONFIG_DIR'): # Return path of user configuration directory. Prefer environment variable if exists. Make dir if required. env = os.getenv(env_var) diff --git a/ultralytics/yolo/utils/checks.py b/ultralytics/yolo/utils/checks.py index 6e8da07..113e0e9 100644 --- a/ultralytics/yolo/utils/checks.py +++ b/ultralytics/yolo/utils/checks.py @@ -13,6 +13,7 @@ import torch from ultralytics.yolo.utils import (AUTOINSTALL, FONT, LOGGER, ROOT, USER_CONFIG_DIR, TryExcept, colorstr, emojis, is_docker, is_notebook) +from ultralytics.yolo.utils.ops import make_divisible def is_ascii(s=''): @@ -21,6 +22,18 @@ def is_ascii(s=''): return len(s.encode().decode('ascii', 'ignore')) == len(s) +def check_imgsz(imgsz, s=32, floor=0): + # Verify image size is a multiple of stride s in each dimension + if isinstance(imgsz, int): # integer i.e. img_size=640 + new_size = max(make_divisible(imgsz, int(s)), floor) + else: # list i.e. 
img_size=[640, 480] + imgsz = list(imgsz) # convert to list if tuple + new_size = [max(make_divisible(x, int(s)), floor) for x in imgsz] + if new_size != imgsz: + LOGGER.warning(f'WARNING ⚠️ --img-size {imgsz} must be multiple of max stride {s}, updating to {new_size}') + return new_size + + def check_version(current="0.0.0", minimum="0.0.0", name="version ", pinned=False, hard=False, verbose=False): # Check version vs. required version current, minimum = (pkg.parse_version(x) for x in (current, minimum)) @@ -93,7 +106,7 @@ def check_requirements(requirements=ROOT / 'requirements.txt', exclude=(), insta LOGGER.warning(f'{prefix} ❌ {e}') -def check_suffix(file='yolov5s.pt', suffix=('.pt',), msg=''): +def check_suffix(file='yolov8n.pt', suffix=('.pt',), msg=''): # Check file(s) for acceptable suffix if file and suffix: if isinstance(suffix, str): diff --git a/ultralytics/yolo/utils/downloads.py b/ultralytics/yolo/utils/downloads.py index 71fa63d..6eb0806 100644 --- a/ultralytics/yolo/utils/downloads.py +++ b/ultralytics/yolo/utils/downloads.py @@ -49,7 +49,7 @@ def attempt_download(file, repo='ultralytics/yolov5', release='v6.2'): # Attempt file download from GitHub release assets if not found locally. release = 'latest', 'v6.2', etc. def github_assets(repository, version='latest'): - # Return GitHub repo tag and assets (i.e. ['yolov5s.pt', 'yolov5m.pt', ...]) + # Return GitHub repo tag and assets (i.e. ['yolov8n.pt', 'yolov5m.pt', ...]) if version != 'latest': version = f'tags/{version}' # i.e. tags/v6.2 response = requests.get(f'https://api.github.com/repos/{repository}/releases/{version}').json() # github api diff --git a/ultralytics/yolo/utils/files.py b/ultralytics/yolo/utils/files.py index 0e97491..7d0e5e3 100644 --- a/ultralytics/yolo/utils/files.py +++ b/ultralytics/yolo/utils/files.py @@ -1,6 +1,7 @@ import contextlib import glob import os +import urllib from datetime import datetime from pathlib import Path from zipfile import ZipFile @@ -43,7 +44,7 @@ def increment_path(path, exist_ok=False, sep='', mkdir=False): return path -def save_yaml(file='data.yaml', data=None): +def yaml_save(file='data.yaml', data=None): # Single-line safe yaml saving with open(file, 'w') as f: yaml.safe_dump({k: str(v) if isinstance(v, Path) else v for k, v in data.items()}, f, sort_keys=False) @@ -52,7 +53,7 @@ def save_yaml(file='data.yaml', data=None): def yaml_load(file='data.yaml'): # Single-line safe yaml loading with open(file, errors='ignore') as f: - return yaml.safe_load(f) + return {**yaml.safe_load(f), 'yaml_file': file} # add YAML filename to dict and return def unzip_file(file, path=None, exclude=('.DS_Store', '__MACOSX')): @@ -77,6 +78,24 @@ def file_date(path=__file__): return f'{t.year}-{t.month}-{t.day}' +def file_size(path): + # Return file/dir size (MB) + mb = 1 << 20 # bytes to MiB (1024 ** 2) + path = Path(path) + if path.is_file(): + return path.stat().st_size / mb + elif path.is_dir(): + return sum(f.stat().st_size for f in path.glob('**/*') if f.is_file()) / mb + else: + return 0.0 + + +def url2file(url): + # Convert URL to filename, i.e. https://url.com/file.txt?auth -> file.txt + url = str(Path(url)).replace(':/', '://') # Pathlib turns :// -> :/ + return Path(urllib.parse.unquote(url)).name.split('?')[0] # '%2F' to '/', split https://url.com/file.txt?auth + + def get_latest_run(search_dir='.'): # Return path to most recent 'last.pt' in /runs (i.e. 
diff --git a/ultralytics/yolo/utils/ops.py b/ultralytics/yolo/utils/ops.py
index 75f2975..55efc10 100644
--- a/ultralytics/yolo/utils/ops.py
+++ b/ultralytics/yolo/utils/ops.py
@@ -135,7 +135,7 @@ def non_max_suppression(
     for xi, x in enumerate(prediction):  # image index, image inference
         # Apply constraints
         # x[((x[:, 2:4] < min_wh) | (x[:, 2:4] > max_wh)).any(1), 4] = 0  # width-height
-        x = x.T[xc[xi]]  # confidence
+        x = x.transpose(0, -1)[xc[xi]]  # confidence
 
         # Cat apriori labels if autolabelling
         if labels and len(labels[xi]):
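
For the 2-D per-image prediction tensor `x` used here, `x.T` and `x.transpose(0, -1)` are interchangeable, so this change is behavior-preserving; `transpose(0, -1)` additionally sidesteps `Tensor.T` on tensors with more than two dims, which reverses all dims and is deprecated in newer torch (presumably relevant once masks add a dimension). A minimal check:

```python
import torch

x = torch.rand(84, 6000)  # e.g. a per-image (channels, anchors) prediction slice
assert torch.equal(x.T, x.transpose(0, -1))  # identical for 2-D tensors

y = torch.rand(2, 3, 4)
print(y.transpose(0, -1).shape)  # torch.Size([4, 3, 2]): only dims 0 and -1 swapped
# y.T on a 3-D tensor would reverse *all* dims (and is deprecated in recent torch)
```
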
diff --git a/ultralytics/yolo/utils/torch_utils.py b/ultralytics/yolo/utils/torch_utils.py
index c53cdb7..e3caa8f 100644
--- a/ultralytics/yolo/utils/torch_utils.py
+++ b/ultralytics/yolo/utils/torch_utils.py
@@ -135,8 +135,8 @@ def model_info(model, verbose=False, imgsz=640):
         flops = get_flops(model, imgsz)
         fs = f', {flops:.1f} GFLOPs' if flops else ''
 
-    name = Path(model.yaml_file).stem.replace('yolov5', 'YOLOv5') if hasattr(model, 'yaml_file') else 'Model'
-    LOGGER.info(f"{name} summary: {len(list(model.modules()))} layers, {n_p} parameters, {n_g} gradients{fs}")
+    m = Path(getattr(model, 'yaml_file', '') or model.yaml.get('yaml_file', '')).stem.replace('yolo', 'YOLO') or 'Model'
+    LOGGER.info(f"{m} summary: {len(list(model.modules()))} layers, {n_p} parameters, {n_g} gradients{fs}")
 
 
 def get_num_params(model):
diff --git a/ultralytics/yolo/v8/__init__.py b/ultralytics/yolo/v8/__init__.py
index 8efc544..e43e165 100644
--- a/ultralytics/yolo/v8/__init__.py
+++ b/ultralytics/yolo/v8/__init__.py
@@ -6,4 +6,4 @@ ROOT = Path(__file__).parents[0]  # yolov8 ROOT
 
 __all__ = ["classify", "segment", "detect"]
 
-from ultralytics.yolo.utils.configs import hydra_patch  # noqa (patch hydra cli)
+from ultralytics.yolo.configs import hydra_patch  # noqa (patch hydra cli)
diff --git a/ultralytics/yolo/v8/classify/predict.py b/ultralytics/yolo/v8/classify/predict.py
index 9b6e112..fe9b7b5 100644
--- a/ultralytics/yolo/v8/classify/predict.py
+++ b/ultralytics/yolo/v8/classify/predict.py
@@ -55,7 +55,7 @@ def predict(cfg):
     cfg.model = cfg.model or "squeezenet1_0"
     sz = cfg.imgsz
-    if type(sz) != int:  # recieved listConfig
+    if type(sz) != int:  # received listConfig
         cfg.imgsz = [sz[0], sz[0]] if len(cfg.imgsz) == 1 else [sz[0], sz[1]]  # expand
     else:
         cfg.imgsz = [sz, sz]
diff --git a/ultralytics/yolo/v8/classify/train.py b/ultralytics/yolo/v8/classify/train.py
index 7e95ba1..0a6c71b 100644
--- a/ultralytics/yolo/v8/classify/train.py
+++ b/ultralytics/yolo/v8/classify/train.py
@@ -4,7 +4,8 @@ import torch
 from ultralytics.nn.tasks import ClassificationModel, get_model
 from ultralytics.yolo import v8
 from ultralytics.yolo.data import build_classification_dataloader
-from ultralytics.yolo.engine.trainer import DEFAULT_CONFIG, BaseTrainer
+from ultralytics.yolo.engine.trainer import BaseTrainer
+from ultralytics.yolo.utils import DEFAULT_CONFIG
 
 
 class ClassificationTrainer(BaseTrainer):
diff --git a/ultralytics/yolo/v8/detect/predict.py b/ultralytics/yolo/v8/detect/predict.py
index d537e64..26176c1 100644
--- a/ultralytics/yolo/v8/detect/predict.py
+++ b/ultralytics/yolo/v8/detect/predict.py
@@ -85,7 +85,7 @@ def predict(cfg):
     cfg.model = cfg.model or "n.pt"
     sz = cfg.imgsz
-    if type(sz) != int:  # recieved listConfig
+    if type(sz) != int:  # received listConfig
         cfg.imgsz = [sz[0], sz[0]] if len(cfg.imgsz) == 1 else [sz[0], sz[1]]  # expand
     else:
         cfg.imgsz = [sz, sz]
diff --git a/ultralytics/yolo/v8/detect/train.py b/ultralytics/yolo/v8/detect/train.py
index 1b7f867..efc1296 100644
--- a/ultralytics/yolo/v8/detect/train.py
+++ b/ultralytics/yolo/v8/detect/train.py
@@ -6,8 +6,8 @@ from ultralytics.nn.tasks import DetectionModel
 from ultralytics.yolo import v8
 from ultralytics.yolo.data import build_dataloader
 from ultralytics.yolo.data.dataloaders.v5loader import create_dataloader
-from ultralytics.yolo.engine.trainer import DEFAULT_CONFIG, BaseTrainer
-from ultralytics.yolo.utils import colorstr
+from ultralytics.yolo.engine.trainer import BaseTrainer
+from ultralytics.yolo.utils import DEFAULT_CONFIG, colorstr
 from ultralytics.yolo.utils.loss import BboxLoss
 from ultralytics.yolo.utils.ops import xywh2xyxy
 from ultralytics.yolo.utils.plotting import plot_images, plot_results
@@ -185,7 +185,7 @@ class Loss:
 
 @hydra.main(version_base=None, config_path=str(DEFAULT_CONFIG.parent), config_name=DEFAULT_CONFIG.name)
 def train(cfg):
-    cfg.model = cfg.model or "models/yolov8n.yaml"
+    cfg.model = cfg.model or "yolov8n.yaml"
     cfg.data = cfg.data or "coco128.yaml"  # or yolo.ClassificationDataset("mnist")
     # cfg.imgsz = 160
     # cfg.epochs = 5
diff --git a/ultralytics/yolo/v8/segment/predict.py b/ultralytics/yolo/v8/segment/predict.py
index 55cb9f5..cc3066b 100644
--- a/ultralytics/yolo/v8/segment/predict.py
+++ b/ultralytics/yolo/v8/segment/predict.py
@@ -98,7 +98,7 @@ def predict(cfg):
     cfg.model = cfg.model or "n.pt"
     sz = cfg.imgsz
-    if type(sz) != int:  # recieved listConfig
+    if type(sz) != int:  # received listConfig
         cfg.imgsz = [sz[0], sz[0]] if len(cfg.imgsz) == 1 else [sz[0], sz[1]]  # expand
     else:
         cfg.imgsz = [sz, sz]
diff --git a/ultralytics/yolo/v8/segment/train.py b/ultralytics/yolo/v8/segment/train.py
index 4809f6e..757a0dd 100644
--- a/ultralytics/yolo/v8/segment/train.py
+++ b/ultralytics/yolo/v8/segment/train.py
@@ -12,11 +12,9 @@ from ultralytics.yolo.utils.plotting import plot_images, plot_results
 from ultralytics.yolo.utils.tal import TaskAlignedAssigner, dist2bbox, make_anchors
 from ultralytics.yolo.utils.torch_utils import de_parallel
 
-from ..detect import DetectionTrainer
-
 
 # BaseTrainer python usage
-class SegmentationTrainer(DetectionTrainer):
+class SegmentationTrainer(v8.detect.DetectionTrainer):
 
     def load_model(self, model_cfg=None, weights=None, verbose=True):
         model = SegmentationModel(model_cfg or weights["model"].yaml, ch=3, nc=self.data["nc"], verbose=verbose)
@@ -174,7 +172,7 @@ class SegLoss:
 
 @hydra.main(version_base=None, config_path=str(DEFAULT_CONFIG.parent), config_name=DEFAULT_CONFIG.name)
 def train(cfg):
-    cfg.model = cfg.model or "models/yolov8n-seg.yaml"
+    cfg.model = cfg.model or "yolov8n-seg.yaml"
     cfg.data = cfg.data or "coco128-seg.yaml"  # or yolo.ClassificationDataset("mnist")
     trainer = SegmentationTrainer(cfg)
     trainer.train()
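
The hydra entry points above now resolve `DEFAULT_CONFIG` from `ultralytics.yolo.utils` and default to bare `yolov8n*.yaml` model names. A minimal sketch of driving `SegmentationTrainer` from Python rather than the CLI, assuming OmegaConf (which hydra uses internally) can stand in for what `@hydra.main` passes to `train(cfg)`; the override values here are arbitrary smoke-run settings, not project defaults:

```python
from omegaconf import OmegaConf

from ultralytics.yolo.utils import DEFAULT_CONFIG
from ultralytics.yolo.v8.segment.train import SegmentationTrainer

cfg = OmegaConf.load(DEFAULT_CONFIG)  # same defaults hydra would compose
cfg.model = "yolov8n-seg.yaml"        # new default style: no "models/" prefix
cfg.data = "coco128-seg.yaml"
cfg.epochs = 1                        # assumption: quick smoke-run values
cfg.imgsz = 64

trainer = SegmentationTrainer(cfg)    # mirrors train(cfg) in the hunk above
trainer.train()
```
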