Check PyTorch model status for all `YOLO` methods (#945)

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Laughing <61612323+Laughing-q@users.noreply.github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
Co-authored-by: Ayush Chaurasia <ayush.chaurarsia@gmail.com>
single_channel
Glenn Jocher 2 years ago committed by GitHub
parent fd5be10c66
commit 20fe708f31
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -29,7 +29,7 @@ jobs:
- os: ubuntu-latest
python-version: '3.8' # torch 1.7.0 requires python >=3.6, <=3.8
model: yolov8n
torch: '1.7.0' # min torch version CI https://pypi.org/project/torchvision/
torch: '1.8.0' # min torch version CI https://pypi.org/project/torchvision/
steps:
- uses: actions/checkout@v3
- uses: actions/setup-python@v4
@ -48,13 +48,12 @@ jobs:
- name: Install requirements
run: |
python -m pip install --upgrade pip wheel
if [ "${{ matrix.torch }}" == "1.7.0" ]; then
pip install -r requirements.txt torch==1.7.0 torchvision==0.8.1 --extra-index-url https://download.pytorch.org/whl/cpu
if [ "${{ matrix.torch }}" == "1.8.0" ]; then
pip install -e . torch==1.8.0 torchvision==0.9.0 onnx openvino-dev>=2022.3 pytest --extra-index-url https://download.pytorch.org/whl/cpu
else
pip install -r requirements.txt --extra-index-url https://download.pytorch.org/whl/cpu
pip install -e . onnx openvino-dev>=2022.3 pytest --extra-index-url https://download.pytorch.org/whl/cpu
fi
# pip install ultralytics (production)
pip install -e . pytest
shell: bash # for Windows compatibility
- name: Check environment
run: |

@ -18,7 +18,7 @@ jobs:
steps:
- name: "CLA Assistant"
if: (github.event.comment.body == 'recheck' || github.event.comment.body == 'I have read the CLA Document and I sign the CLA') || github.event_name == 'pull_request_target'
uses: contributor-assistant/github-action@v2.2.1
uses: contributor-assistant/github-action@v2.3.0
env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
# must be repository secret token

@ -114,8 +114,8 @@ We are still working on several parts of YOLOv8! We aim to have these completed
to par with YOLOv5, including export and inference to all the same formats. We are also writing a YOLOv8 paper which we
will submit to [arxiv.org](https://arxiv.org) once complete.
- [ ] TensorFlow exports
- [ ] DDP resume
- [x] TensorFlow exports
- [x] DDP resume
- [ ] [arxiv.org](https://arxiv.org) paper
</details>
@ -246,8 +246,7 @@ YOLOv8 is available under two different licenses:
## <div align="center">Contact</div>
For YOLOv8 bugs and feature requests please visit [GitHub Issues](https://github.com/ultralytics/ultralytics/issues).
For professional support please [Contact Us](https://ultralytics.com/contact).
For YOLOv8 bug reports and feature requests please visit [GitHub Issues](https://github.com/ultralytics/ultralytics/issues) or the [Ultralytics Community Forum](https://community.ultralytics.com/).
<br>
<div align="center">

@ -101,8 +101,8 @@ success = model.export(format="onnx") # 将模型导出为 ONNX 格式
我们仍在努力完善 YOLOv8 的几个部分!我们的目标是尽快完成这些工作,使 YOLOv8 的功能设置达到YOLOv5 的水平,包括对所有相同格式的导出和推理。我们还在写一篇 YOLOv8 的论文,一旦完成,我们将提交给 [arxiv.org](https://arxiv.org)。
- [ ] TensorFlow 导出
- [ ] DDP 恢复训练
- [x] TensorFlow 导出
- [x] DDP 恢复训练
- [ ] [arxiv.org](https://arxiv.org) 论文
</details>
@ -214,7 +214,7 @@ success = model.export(format="onnx") # 将模型导出为 ONNX 格式
## <div align="center">联系我们</div>
若发现 YOLOv8 的 Bug 或有功能需求,请访问 [GitHub 问题](https://github.com/ultralytics/ultralytics/issues)。如需专业支持,请 [联系我们](https://ultralytics.com/contact)
请访问 [GitHub Issues](https://github.com/ultralytics/ultralytics/issues) 或 [Ultralytics Community Forum](https://community.ultralytis.com) 以报告 YOLOv8 错误和请求功能
<br>
<div align="center">

@ -0,0 +1,17 @@
At [Ultralytics](https://ultralytics.com), the security of our users' data and systems is of utmost importance. To ensure the safety and security of our [open-source projects](https://github.com/ultralytics), we have implemented several measures to detect and prevent security vulnerabilities.
[![ultralytics](https://snyk.io/advisor/python/ultralytics/badge.svg)](https://snyk.io/advisor/python/ultralytics)
## Snyk Scanning
We use [Snyk](https://snyk.io/advisor/python/ultralytics) to regularly scan the YOLOv8 repository for vulnerabilities and security issues. Our goal is to identify and remediate any potential threats as soon as possible, to minimize any risks to our users.
## GitHub CodeQL Scanning
In addition to our Snyk scans, we also use GitHub's [CodeQL](https://docs.github.com/en/code-security/code-scanning/automatically-scanning-your-code-for-vulnerabilities-and-errors/about-code-scanning-with-codeql) scans to proactively identify and address security vulnerabilities.
## Reporting Security Issues
If you suspect or discover a security vulnerability in the YOLOv8 repository, please let us know immediately. You can reach out to us directly via our [contact form](https://ultralytics.com/contact) or via [security@ultralytics.com](mailto:security@ultralytics.com). Our security team will investigate and respond as soon as possible.
We appreciate your help in keeping the YOLOv8 repository secure and safe for everyone.

@ -122,3 +122,4 @@ nav:
- Results: reference/results.md
- ultralytics.nn: reference/nn.md
- Operations: reference/ops.md
- Security: SECURITY.md

@ -48,18 +48,18 @@ def test_val_classify():
# Predict checks -------------------------------------------------------------------------------------------------------
def test_predict_detect():
run(f"yolo predict detect model={MODEL}.pt source={ROOT / 'assets'} imgsz=32")
run(f"yolo predict detect model={MODEL}.pt source=https://ultralytics.com/images/bus.jpg imgsz=32")
run(f"yolo predict detect model={MODEL}.pt source=https://ultralytics.com/assets/decelera_landscape.mov imgsz=32")
run(f"yolo predict detect model={MODEL}.pt source=https://ultralytics.com/assets/decelera_portrait.mov imgsz=32")
run(f"yolo predict model={MODEL}.pt source={ROOT / 'assets'} imgsz=32")
run(f"yolo predict model={MODEL}.pt source=https://ultralytics.com/images/bus.jpg imgsz=32")
run(f"yolo predict model={MODEL}.pt source=https://ultralytics.com/assets/decelera_landscape_min.mov imgsz=32")
run(f"yolo predict model={MODEL}.pt source=https://ultralytics.com/assets/decelera_portrait_min.mov imgsz=32")
def test_predict_segment():
run(f"yolo predict segment model={MODEL}-seg.pt source={ROOT / 'assets'} imgsz=32")
run(f"yolo predict model={MODEL}-seg.pt source={ROOT / 'assets'} imgsz=32")
def test_predict_classify():
run(f"yolo predict classify model={MODEL}-cls.pt source={ROOT / 'assets'} imgsz=32")
run(f"yolo predict model={MODEL}-cls.pt source={ROOT / 'assets'} imgsz=32")
# Export checks --------------------------------------------------------------------------------------------------------

@ -18,7 +18,6 @@ SOURCE = ROOT / 'assets/bus.jpg'
def test_model_forward():
model = YOLO(CFG)
model.predict(SOURCE)
model(SOURCE)
@ -38,11 +37,10 @@ def test_model_fuse():
def test_predict_dir():
model = YOLO(MODEL)
model.predict(source=ROOT / "assets")
model(source=ROOT / "assets")
def test_predict_img():
model = YOLO(MODEL)
img = Image.open(str(SOURCE))
output = model(source=img, save=True, verbose=True) # PIL
@ -106,22 +104,26 @@ def test_export_torchscript():
print(export_formats())
model = YOLO(MODEL)
model.export(format='torchscript')
f = model.export(format='torchscript')
YOLO(f)(SOURCE) # exported model inference
def test_export_onnx():
model = YOLO(MODEL)
model.export(format='onnx')
f = model.export(format='onnx')
YOLO(f)(SOURCE) # exported model inference
def test_export_openvino():
model = YOLO(MODEL)
model.export(format='openvino')
f = model.export(format='openvino')
YOLO(f)(SOURCE) # exported model inference
def test_export_coreml():
model = YOLO(MODEL)
model.export(format='coreml')
# YOLO(f)(SOURCE) # model prediction only supported on macOS
def test_export_paddle(enabled=False):
@ -140,6 +142,7 @@ def test_workflow():
model = YOLO(MODEL)
model.train(data="coco8.yaml", epochs=1, imgsz=32)
model.val()
print(model.metrics)
model.predict(SOURCE)
model.export(format="onnx", opset=12) # export a model to ONNX format
@ -164,6 +167,3 @@ def test_predict_callback_and_setup():
print('test_callback', bs)
boxes = result.boxes # Boxes object for bbox outputs
print(boxes)
test_predict_img()

@ -1,6 +1,6 @@
# Ultralytics YOLO 🚀, GPL-3.0 license
__version__ = "8.0.35"
__version__ = "8.0.36"
from ultralytics.yolo.engine.model import YOLO
from ultralytics.yolo.utils.checks import check_yolo as checks

@ -5,12 +5,12 @@ import requests
from ultralytics.hub.auth import Auth
from ultralytics.hub.session import HubTrainingSession
from ultralytics.hub.utils import split_key
from ultralytics.yolo.engine.exporter import export_formats
from ultralytics.yolo.engine.exporter import EXPORT_FORMATS_LIST
from ultralytics.yolo.engine.model import YOLO
from ultralytics.yolo.utils import LOGGER, PREFIX, emojis
# Define all export formats
EXPORT_FORMATS = list(export_formats()['Argument'][1:]) + ["ultralytics_tflite", "ultralytics_coreml"]
EXPORT_FORMATS_HUB = EXPORT_FORMATS_LIST + ["ultralytics_tflite", "ultralytics_coreml"]
def start(key=""):
@ -69,7 +69,7 @@ def reset_model(key=""):
def export_model(key="", format="torchscript"):
# Export a model to all formats
assert format in EXPORT_FORMATS, f"Unsupported export format '{format}' passed, valid formats are {EXPORT_FORMATS}"
assert format in EXPORT_FORMATS_HUB, f"Unsupported export format '{format}', valid formats are {EXPORT_FORMATS_HUB}"
api_key, model_id = split_key(key)
r = requests.post("https://api.ultralytics.com/export",
json={
@ -82,7 +82,7 @@ def export_model(key="", format="torchscript"):
def get_export(key="", format="torchscript"):
# Get an exported model dictionary with download URL
assert format in EXPORT_FORMATS, f"Unsupported export format '{format}' passed, valid formats are {EXPORT_FORMATS}"
assert format in EXPORT_FORMATS_HUB, f"Unsupported export format '{format}', valid formats are {EXPORT_FORMATS_HUB}"
api_key, model_id = split_key(key)
r = requests.post("https://api.ultralytics.com/get-export",
json={

@ -193,7 +193,7 @@ class AutoBackend(nn.Module):
from tflite_runtime.interpreter import Interpreter, load_delegate
except ImportError:
import tensorflow as tf
Interpreter, load_delegate = tf.lite.Interpreter, tf.lite.experimental.load_delegate,
Interpreter, load_delegate = tf.lite.Interpreter, tf.lite.experimental.load_delegate
if edgetpu: # TF Edge TPU https://coral.ai/software/#edgetpu-runtime
LOGGER.info(f'Loading {w} for TensorFlow Lite Edge TPU inference...')
delegate = {
@ -232,8 +232,10 @@ class AutoBackend(nn.Module):
nhwc = model.runtime.startswith("tensorflow")
'''
else:
raise NotImplementedError(f"ERROR: '{w}' is not a supported format. For supported formats see "
f"https://docs.ultralytics.com/reference/nn/")
from ultralytics.yolo.engine.exporter import EXPORT_FORMATS_TABLE
raise TypeError(f"model='{w}' is not a supported model format. "
"See https://docs.ultralytics.com/tasks/detection/#export for help."
f"\n\n{EXPORT_FORMATS_TABLE}")
# class names
if 'names' not in locals(): # names missing

@ -356,7 +356,7 @@ def attempt_load_weights(weights, device=None, inplace=True, fuse=False):
model = (ckpt.get('ema') or ckpt['model']).to(device).float() # FP32 model
# Model compatibility updates
model.args = {k: v for k, v in args.items() if k in DEFAULT_CFG_KEYS} # attach args to model
model.args = args # attach args to model
model.pt_path = weights # attach *.pt file path to model
model.task = guess_model_task(model)
if not hasattr(model, 'stride'):

@ -12,8 +12,8 @@ from ultralytics.yolo.utils import (DEFAULT_CFG, DEFAULT_CFG_DICT, DEFAULT_CFG_P
IterableSimpleNamespace, __version__, checks, colorstr, yaml_load, yaml_print)
CLI_HELP_MSG = \
"""
YOLOv8 'yolo' CLI commands use the following syntax:
f"""
Arguments received: {str(['yolo'] + sys.argv[1:])}. Note that Ultralytics 'yolo' commands use the following syntax:
yolo TASK MODE ARGS
@ -64,9 +64,7 @@ CFG_BOOL_KEYS = {
def cfg2dict(cfg):
"""
Convert a configuration object to a dictionary.
This function converts a configuration object to a dictionary, whether it is a file path, a string, or a SimpleNamespace object.
Convert a configuration object to a dictionary, whether it is a file path, a string, or a SimpleNamespace object.
Inputs:
cfg (str) or (Path) or (SimpleNamespace): Configuration object to be converted to a dictionary.
@ -143,8 +141,9 @@ def check_cfg_mismatch(base: Dict, custom: Dict, e=None):
if mismatched:
string = ''
for x in mismatched:
matches = get_close_matches(x, base)
match_str = f"Similar arguments are {matches}." if matches else ''
matches = get_close_matches(x, base) # key list
matches = [f"{k}={DEFAULT_CFG_DICT[k]}" if DEFAULT_CFG_DICT[k] is not None else k for k in matches] # k=v
match_str = f"Similar arguments are i.e. {matches}." if matches else ''
string += f"'{colorstr('red', 'bold', x)}' is not a valid YOLO argument. {match_str}\n"
raise SyntaxError(string + CLI_HELP_MSG) from e
@ -265,7 +264,7 @@ def entrypoint(debug=''):
LOGGER.warning(f"WARNING ⚠️ 'mode' is missing. Valid modes are {modes}. Using default 'mode={mode}'.")
elif mode not in modes:
if mode != 'checks':
raise ValueError(f"Invalid 'mode={mode}'. Valid modes are {modes}.")
raise ValueError(f"Invalid 'mode={mode}'. Valid modes are {modes}.\n{CLI_HELP_MSG}")
LOGGER.warning("WARNING ⚠️ 'yolo mode=checks' is deprecated. Use 'yolo checks' instead.")
checks.check_yolo()
return

@ -682,7 +682,8 @@ def v8_transforms(dataset, imgsz, hyp):
# Classification augmentations -----------------------------------------------------------------------------------------
def classify_transforms(size=224):
# Transforms to apply if albumentations not installed
assert isinstance(size, int), f"ERROR: classify_transforms size {size} must be integer, not (list, tuple)"
if not isinstance(size, int):
raise TypeError(f"classify_transforms() size {size} must be integer, not (list, tuple)")
# T.Compose([T.ToTensor(), T.Resize(size), T.CenterCrop(size), T.Normalize(IMAGENET_MEAN, IMAGENET_STD)])
return T.Compose([CenterCrop(size), ToTensor(), T.Normalize(IMAGENET_MEAN, IMAGENET_STD)])

@ -48,7 +48,6 @@ TensorFlow.js:
$ ln -s ../../yolov5/yolov8n_web_model public/yolov8n_web_model
$ npm start
"""
import contextlib
import json
import os
import platform
@ -74,7 +73,7 @@ from ultralytics.yolo.utils import DEFAULT_CFG, LOGGER, __version__, callbacks,
from ultralytics.yolo.utils.checks import check_imgsz, check_requirements, check_version, check_yaml
from ultralytics.yolo.utils.files import file_size
from ultralytics.yolo.utils.ops import Profile
from ultralytics.yolo.utils.torch_utils import select_device, smart_inference_mode
from ultralytics.yolo.utils.torch_utils import select_device, smart_inference_mode, get_latest_opset
MACOS = platform.system() == 'Darwin' # macOS environment
@ -97,6 +96,10 @@ def export_formats():
return pd.DataFrame(x, columns=['Format', 'Argument', 'Suffix', 'CPU', 'GPU'])
EXPORT_FORMATS_LIST = list(export_formats()['Argument'][1:])
EXPORT_FORMATS_TABLE = str(export_formats())
def try_export(inner_func):
# YOLOv8 export decorator, i..e @try_export
inner_args = get_default_args(inner_func)
@ -244,7 +247,7 @@ class Exporter:
agnostic_nms=self.args.agnostic_nms)
if edgetpu:
f[8], _ = self._export_edgetpu()
self._add_tflite_metadata(f[8] or f[7], num_outputs=len(self.output_shape))
self._add_tflite_metadata(f[8] or f[7])
if tfjs:
f[9], _ = self._export_tfjs()
if paddle: # PaddlePaddle
@ -253,11 +256,11 @@ class Exporter:
# Finish
f = [str(x) for x in f if x] # filter out '' and None
if any(f):
s = "-WARNING ⚠️ not yet supported for YOLOv8 exported models"
f = str(Path(f[-1]))
LOGGER.info(f'\nExport complete ({time.time() - t:.1f}s)'
f"\nResults saved to {colorstr('bold', file.parent.resolve())}"
f"\nPredict: yolo task={model.task} mode=predict model={f[-1]} {s}"
f"\nValidate: yolo task={model.task} mode=val model={f[-1]} {s}"
f"\nPredict: yolo task={model.task} mode=predict model={f}"
f"\nValidate: yolo task={model.task} mode=val model={f}"
f"\nVisualize: https://netron.app")
self.run_callbacks("on_export_end")
@ -304,7 +307,7 @@ class Exporter:
self.im.cpu() if dynamic else self.im,
f,
verbose=False,
opset_version=self.args.opset,
opset_version=self.args.opset or get_latest_opset(),
do_constant_folding=True, # WARNING: DNN inference with torch>=1.12 may require do_constant_folding=False
input_names=['images'],
output_names=output_names,
@ -507,6 +510,10 @@ class Exporter:
# Export to TF SavedModel
subprocess.run(f'onnx2tf -i {onnx} --output_signaturedefs -o {f}', shell=True)
# Add TFLite metadata
for tflite_file in Path(f).rglob('*.tflite'):
self._add_tflite_metadata(tflite_file)
# Load saved_model
keras_model = tf.saved_model.load(f, tags=None, options=None)
@ -661,44 +668,47 @@ class Exporter:
r'{"outputs": {"Identity.?.?": {"name": "Identity.?.?"}, '
r'"Identity.?.?": {"name": "Identity.?.?"}, '
r'"Identity.?.?": {"name": "Identity.?.?"}, '
r'"Identity.?.?": {"name": "Identity.?.?"}}}', r'{"outputs": {"Identity": {"name": "Identity"}, '
r'"Identity.?.?": {"name": "Identity.?.?"}}}',
r'{"outputs": {"Identity": {"name": "Identity"}, '
r'"Identity_1": {"name": "Identity_1"}, '
r'"Identity_2": {"name": "Identity_2"}, '
r'"Identity_3": {"name": "Identity_3"}}}', f_json.read_text())
r'"Identity_3": {"name": "Identity_3"}}}',
f_json.read_text(),
)
j.write(subst)
return f, None
def _add_tflite_metadata(self, file, num_outputs):
def _add_tflite_metadata(self, file):
# Add metadata to *.tflite models per https://www.tensorflow.org/lite/models/convert/metadata
with contextlib.suppress(ImportError):
# check_requirements('tflite_support')
from tflite_support import flatbuffers # noqa
from tflite_support import metadata as _metadata # noqa
from tflite_support import metadata_schema_py_generated as _metadata_fb # noqa
tmp_file = Path('/tmp/meta.txt')
with open(tmp_file, 'w') as meta_f:
meta_f.write(str(self.metadata))
model_meta = _metadata_fb.ModelMetadataT()
label_file = _metadata_fb.AssociatedFileT()
label_file.name = tmp_file.name
model_meta.associatedFiles = [label_file]
subgraph = _metadata_fb.SubGraphMetadataT()
subgraph.inputTensorMetadata = [_metadata_fb.TensorMetadataT()]
subgraph.outputTensorMetadata = [_metadata_fb.TensorMetadataT()] * num_outputs
model_meta.subgraphMetadata = [subgraph]
b = flatbuffers.Builder(0)
b.Finish(model_meta.Pack(b), _metadata.MetadataPopulator.METADATA_FILE_IDENTIFIER)
metadata_buf = b.Output()
populator = _metadata.MetadataPopulator.with_model_file(file)
populator.load_metadata_buffer(metadata_buf)
populator.load_associated_files([str(tmp_file)])
populator.populate()
tmp_file.unlink()
check_requirements('tflite_support')
from tflite_support import flatbuffers # noqa
from tflite_support import metadata as _metadata # noqa
from tflite_support import metadata_schema_py_generated as _metadata_fb # noqa
tmp_file = Path('/tmp/meta.txt')
with open(tmp_file, 'w') as meta_f:
meta_f.write(str(self.metadata))
model_meta = _metadata_fb.ModelMetadataT()
label_file = _metadata_fb.AssociatedFileT()
label_file.name = tmp_file.name
model_meta.associatedFiles = [label_file]
subgraph = _metadata_fb.SubGraphMetadataT()
subgraph.inputTensorMetadata = [_metadata_fb.TensorMetadataT()]
subgraph.outputTensorMetadata = [_metadata_fb.TensorMetadataT()] * len(self.output_shape)
model_meta.subgraphMetadata = [subgraph]
b = flatbuffers.Builder(0)
b.Finish(model_meta.Pack(b), _metadata.MetadataPopulator.METADATA_FILE_IDENTIFIER)
metadata_buf = b.Output()
populator = _metadata.MetadataPopulator.with_model_file(file)
populator.load_metadata_buffer(metadata_buf)
populator.load_associated_files([str(tmp_file)])
populator.populate()
tmp_file.unlink()
def _pipeline_coreml(self, model, prefix=colorstr('CoreML Pipeline:')):
# YOLOv8 CoreML pipeline

@ -6,11 +6,11 @@ from typing import List
from ultralytics import yolo # noqa
from ultralytics.nn.tasks import (ClassificationModel, DetectionModel, SegmentationModel, attempt_load_one_weight,
guess_model_task)
guess_model_task, nn)
from ultralytics.yolo.cfg import get_cfg
from ultralytics.yolo.engine.exporter import Exporter
from ultralytics.yolo.utils import DEFAULT_CFG, LOGGER, RANK, callbacks, yaml_load
from ultralytics.yolo.utils.checks import check_imgsz, check_yaml
from ultralytics.yolo.utils.checks import check_file, check_imgsz, check_yaml
from ultralytics.yolo.utils.downloads import GITHUB_ASSET_STEMS
from ultralytics.yolo.utils.torch_utils import smart_inference_mode
@ -55,19 +55,16 @@ class YOLO:
self.cfg = None # if loaded from *.yaml
self.ckpt_path = None
self.overrides = {} # overrides for trainer object
self.metrics_data = None
# Load or create new YOLO model
suffix = Path(model).suffix
if not suffix and Path(model).stem in GITHUB_ASSET_STEMS:
model, suffix = Path(model).with_suffix('.pt'), '.pt' # add suffix, i.e. yolov8n -> yolov8n.pt
try:
if suffix == '.yaml':
self._new(model)
else:
self._load(model)
except Exception as e:
raise NotImplementedError(f"Unable to load model='{model}'. "
f"As an example try model='yolov8n.pt' or model='yolov8n.yaml'") from e
if suffix == '.yaml':
self._new(model)
else:
self._load(model)
def __call__(self, source=None, stream=False, **kwargs):
return self.predict(source, stream, **kwargs)
@ -100,15 +97,27 @@ class YOLO:
self.overrides = self.model.args
self._reset_ckpt_args(self.overrides)
else:
check_file(weights)
self.model, self.ckpt = weights, None
self.task = guess_model_task(weights)
self.ckpt_path = weights
self.ModelClass, self.TrainerClass, self.ValidatorClass, self.PredictorClass = self._assign_ops_from_task()
def _check_is_pytorch_model(self):
"""
Raises TypeError is model is not a PyTorch model
"""
if not isinstance(self.model, nn.Module):
raise TypeError(f"model='{self.model}' must be a PyTorch model, but is a different type. PyTorch models "
f"can be used to train, val, predict and export, i.e. "
f"'yolo export model=yolov8n.pt', but exported formats like ONNX, TensorRT etc. only "
f"support 'predict' and 'val' modes, i.e. 'yolo predict model=yolov8n.onnx'.")
def reset(self):
"""
Resets the model modules.
"""
self._check_is_pytorch_model()
for m in self.model.modules():
if hasattr(m, 'reset_parameters'):
m.reset_parameters()
@ -122,9 +131,11 @@ class YOLO:
Args:
verbose (bool): Controls verbosity.
"""
self._check_is_pytorch_model()
self.model.info(verbose=verbose)
def fuse(self):
self._check_is_pytorch_model()
self.model.fuse()
def predict(self, source=None, stream=False, **kwargs):
@ -176,6 +187,8 @@ class YOLO:
validator = self.ValidatorClass(args=args)
validator(model=self.model)
self.metrics_data = validator.metrics
return validator.metrics
@smart_inference_mode()
@ -186,7 +199,7 @@ class YOLO:
Args:
**kwargs : Any other args accepted by the predictors. To see all args check 'configuration' section in docs
"""
self._check_is_pytorch_model()
overrides = self.overrides.copy()
overrides.update(kwargs)
args = get_cfg(cfg=DEFAULT_CFG, overrides=overrides)
@ -196,7 +209,7 @@ class YOLO:
if args.batch == DEFAULT_CFG.batch:
args.batch = 1 # default to 1 if not modified
exporter = Exporter(overrides=args)
exporter(model=self.model)
return exporter(model=self.model)
def train(self, **kwargs):
"""
@ -205,6 +218,7 @@ class YOLO:
Args:
**kwargs (Any): Any number of arguments representing the training configuration.
"""
self._check_is_pytorch_model()
overrides = self.overrides.copy()
overrides.update(kwargs)
if kwargs.get("cfg"):
@ -226,6 +240,7 @@ class YOLO:
if RANK in {0, -1}:
self.model, _ = attempt_load_one_weight(str(self.trainer.best))
self.overrides = self.model.args
self.metrics_data = self.trainer.validator.metrics
def to(self, device):
"""
@ -234,15 +249,14 @@ class YOLO:
Args:
device (str): device
"""
self._check_is_pytorch_model()
self.model.to(device)
def _assign_ops_from_task(self):
model_class, train_lit, val_lit, pred_lit = MODEL_MAP[self.task]
# warning: eval is unsafe. Use with caution
trainer_class = eval(train_lit.replace("TYPE", f"{self.type}"))
validator_class = eval(val_lit.replace("TYPE", f"{self.type}"))
predictor_class = eval(pred_lit.replace("TYPE", f"{self.type}"))
return model_class, trainer_class, validator_class, predictor_class
@property
@ -250,7 +264,7 @@ class YOLO:
"""
Returns class names of the loaded model.
"""
return self.model.names
return self.model.names if hasattr(self.model, 'names') else None
@property
def transforms(self):
@ -259,6 +273,16 @@ class YOLO:
"""
return self.model.transforms if hasattr(self.model, 'transforms') else None
@property
def metrics(self):
"""
Returns metrics if computed
"""
if not self.metrics_data:
LOGGER.info("No metrics data found! Run training or validation operation first.")
return self.metrics_data
@staticmethod
def add_callback(event: str, func):
"""
@ -269,5 +293,5 @@ class YOLO:
@staticmethod
def _reset_ckpt_args(args):
for arg in 'augment', 'verbose', 'project', 'name', 'exist_ok', 'resume', 'batch', 'epochs', 'cache', \
'save_json', 'half', 'v5loader', 'device', 'cfg', 'save', 'rect', 'plots':
'save_json', 'half', 'v5loader', 'device', 'cfg', 'save', 'rect', 'plots', 'opset':
args.pop(arg, None)

@ -35,6 +35,7 @@ import torch
from ultralytics.nn.autobackend import AutoBackend
from ultralytics.yolo.cfg import get_cfg
from ultralytics.yolo.data import load_inference_source
from ultralytics.yolo.data.augment import classify_transforms
from ultralytics.yolo.utils import DEFAULT_CFG, LOGGER, SETTINGS, callbacks, colorstr, ops
from ultralytics.yolo.utils.checks import check_imgsz, check_imshow
from ultralytics.yolo.utils.files import increment_path
@ -121,8 +122,12 @@ class BasePredictor:
def setup_source(self, source):
self.imgsz = check_imgsz(self.args.imgsz, stride=self.model.stride, min_dim=2) # check image size
if self.args.task == 'classify':
transforms = getattr(self.model.model, 'transforms', classify_transforms(self.imgsz[0]))
else: # predict, segment
transforms = None
self.dataset = load_inference_source(source=source,
transforms=getattr(self.model.model, 'transforms', None),
transforms=transforms,
imgsz=self.imgsz,
vid_stride=self.args.vid_stride,
stride=self.model.stride,

@ -217,19 +217,18 @@ class BaseTrainer:
# Optimizer
self.accumulate = max(round(self.args.nbs / self.batch_size), 1) # accumulate loss before optimizing
self.args.weight_decay *= self.batch_size * self.accumulate / self.args.nbs # scale weight_decay
weight_decay = self.args.weight_decay * self.batch_size * self.accumulate / self.args.nbs # scale weight_decay
self.optimizer = self.build_optimizer(model=self.model,
name=self.args.optimizer,
lr=self.args.lr0,
momentum=self.args.momentum,
decay=self.args.weight_decay)
decay=weight_decay)
# Scheduler
if self.args.cos_lr:
self.lf = one_cycle(1, self.args.lrf, self.epochs) # cosine 1->hyp['lrf']
else:
self.lf = lambda x: (1 - x / self.epochs) * (1.0 - self.args.lrf) + self.args.lrf # linear
self.scheduler = lr_scheduler.LambdaLR(self.optimizer, lr_lambda=self.lf)
self.scheduler.last_epoch = self.start_epoch - 1 # do not move
self.stopper, self.stop = EarlyStopping(patience=self.args.patience), False
# dataloaders
@ -242,6 +241,7 @@ class BaseTrainer:
self.metrics = dict(zip(metric_keys, [0] * len(metric_keys))) # TODO: init metrics for plot_results()?
self.ema = ModelEMA(self.model)
self.resume_training(ckpt)
self.scheduler.last_epoch = self.start_epoch - 1 # do not move
self.run_callbacks("on_pretrain_routine_end")
def _do_train(self, rank=-1, world_size=1):
@ -555,6 +555,12 @@ class BaseTrainer:
self.epochs += ckpt['epoch'] # finetune additional epochs
self.best_fitness = best_fitness
self.start_epoch = start_epoch
if start_epoch > (self.epochs - self.args.close_mosaic):
self.console.info("Closing dataloader mosaic")
if hasattr(self.train_loader.dataset, 'mosaic'):
self.train_loader.dataset.mosaic = False
if hasattr(self.train_loader.dataset, 'close_mosaic'):
self.train_loader.dataset.close_mosaic(hyp=self.args)
@staticmethod
def build_optimizer(model, name='Adam', lr=0.001, momentum=0.9, decay=1e-5):

@ -234,17 +234,17 @@ def check_yolov5u_filename(file: str):
return file
def check_file(file, suffix=''):
def check_file(file, suffix='', download=True):
# Search/download file (if necessary) and return path
check_suffix(file, suffix) # optional
file = str(file) # convert to string
file = check_yolov5u_filename(file) # yolov5n -> yolov5nu
if not file or ('://' not in file and Path(file).is_file()): # exists ('://' check required in Windows Python<3.10)
if not file or ('://' not in file and Path(file).exists()): # exists ('://' check required in Windows Python<3.10)
return file
elif file.lower().startswith(('https://', 'http://', 'rtsp://', 'rtmp://')): # download
elif download and file.lower().startswith(('https://', 'http://', 'rtsp://', 'rtmp://')): # download
url = file # warning: Pathlib turns :// -> :/
file = Path(urllib.parse.unquote(file).split('?')[0]).name # '%2F' to '/', split https://url.com/file.txt?auth
if Path(file).is_file():
if Path(file).exists():
LOGGER.info(f'Found {url} locally at {file}') # file already exists
else:
downloads.safe_download(url=url, file=file, unzip=False)

@ -44,11 +44,17 @@ def generate_ddp_file(trainer):
def generate_ddp_command(world_size, trainer):
import __main__ # noqa local import to avoid https://github.com/Lightning-AI/lightning/issues/15218
file = generate_ddp_file(trainer) if sys.argv[0].endswith('yolo') else os.path.abspath(sys.argv[0])
# Get file and args (do not use sys.argv due to security vulnerability)
exclude_args = ['save_dir']
args = [f"{k}={v}" for k, v in vars(trainer.args).items() if k not in exclude_args]
file = generate_ddp_file(trainer) # if argv[0].endswith('yolo') else os.path.abspath(argv[0])
# Build command
torch_distributed_cmd = "torch.distributed.run" if TORCH_1_9 else "torch.distributed.launch"
cmd = [
sys.executable, "-m", torch_distributed_cmd, "--nproc_per_node", f"{world_size}", "--master_port",
f"{find_free_network_port()}", file] + sys.argv[1:]
f"{find_free_network_port()}", file] + args
return cmd, file

@ -242,6 +242,11 @@ def copy_attr(a, b, include=(), exclude=()):
setattr(a, k, v)
def get_latest_opset():
# Return max supported ONNX opset by this version of torch
return max(int(k[14:]) for k in vars(torch.onnx) if 'symbolic_opset' in k) # opset
def intersect_dicts(da, db, exclude=()):
# Dictionary intersection of matching keys and shapes, omitting 'exclude' keys, using da values
return {k: v for k, v in da.items() if k in db and all(x not in k for x in exclude) and v.shape == db[k].shape}

Loading…
Cancel
Save