diff --git a/docs/modes/export.md b/docs/modes/export.md
index 42e4527..86bc686 100644
--- a/docs/modes/export.md
+++ b/docs/modes/export.md
@@ -1,7 +1,7 @@
---
comments: true
description: 'Export mode: Create a deployment-ready YOLOv8 model by converting it to various formats. Export to ONNX or OpenVINO for up to 3x CPU speedup.'
-keywords: ultralytics docs, YOLOv8, export YOLOv8, YOLOv8 model deployment, exporting YOLOv8, ONNX, OpenVINO, TensorRT, CoreML, TF SavedModel, PaddlePaddle, TorchScript, ONNX format, OpenVINO format, TensorRT format, CoreML format, TF SavedModel format, PaddlePaddle format
+keywords: ultralytics docs, YOLOv8, export YOLOv8, YOLOv8 model deployment, exporting YOLOv8, ONNX, OpenVINO, TensorRT, CoreML, TF SavedModel, PaddlePaddle, TorchScript, ONNX format, OpenVINO format, TensorRT format, CoreML format, TF SavedModel format, PaddlePaddle format, Tencent NCNN, NCNN
---
@@ -84,4 +84,5 @@ i.e. `format='onnx'` or `format='engine'`.
| [TF Lite](https://www.tensorflow.org/lite) | `tflite` | `yolov8n.tflite` | ✅ | `imgsz`, `half`, `int8` |
| [TF Edge TPU](https://coral.ai/docs/edgetpu/models-intro/) | `edgetpu` | `yolov8n_edgetpu.tflite` | ✅ | `imgsz` |
| [TF.js](https://www.tensorflow.org/js) | `tfjs` | `yolov8n_web_model/` | ✅ | `imgsz` |
-| [PaddlePaddle](https://github.com/PaddlePaddle) | `paddle` | `yolov8n_paddle_model/` | ✅ | `imgsz` |
\ No newline at end of file
+| [PaddlePaddle](https://github.com/PaddlePaddle) | `paddle` | `yolov8n_paddle_model/` | ✅ | `imgsz` |
+| [NCNN](https://github.com/Tencent/ncnn) | `ncnn` | `yolov8n_ncnn_model/` | ✅ | `imgsz` |
\ No newline at end of file
diff --git a/docs/modes/val.md b/docs/modes/val.md
index 79fdf6f..4ffff73 100644
--- a/docs/modes/val.md
+++ b/docs/modes/val.md
@@ -6,9 +6,7 @@ keywords: Ultralytics, YOLO, YOLOv8, Val, Validation, Hyperparameters, Performan
-**Val mode** is used for validating a YOLOv8 model after it has been trained. In this mode, the model is evaluated on a
-validation set to measure its accuracy and generalization performance. This mode can be used to tune the hyperparameters
-of the model to improve its performance.
+**Val mode** is used for validating a YOLOv8 model after it has been trained. In this mode, the model is evaluated on a validation set to measure its accuracy and generalization performance. This mode can be used to tune the hyperparameters of the model to improve its performance.
!!! tip "Tip"
@@ -16,8 +14,7 @@ of the model to improve its performance.
## Usage Examples
-Validate trained YOLOv8n model accuracy on the COCO128 dataset. No argument need to passed as the `model` retains it's
-training `data` and arguments as model attributes. See Arguments section below for a full list of export arguments.
+Validate trained YOLOv8n model accuracy on the COCO128 dataset. No argument need to passed as the `model` retains it's training `data` and arguments as model attributes. See Arguments section below for a full list of export arguments.
!!! example ""
@@ -46,13 +43,7 @@ training `data` and arguments as model attributes. See Arguments section below f
## Arguments
-Validation settings for YOLO models refer to the various hyperparameters and configurations used to
-evaluate the model's performance on a validation dataset. These settings can affect the model's performance, speed, and
-accuracy. Some common YOLO validation settings include the batch size, the frequency with which validation is performed
-during training, and the metrics used to evaluate the model's performance. Other factors that may affect the validation
-process include the size and composition of the validation dataset and the specific task the model is being used for. It
-is important to carefully tune and experiment with these settings to ensure that the model is performing well on the
-validation dataset and to detect and prevent overfitting.
+Validation settings for YOLO models refer to the various hyperparameters and configurations used to evaluate the model's performance on a validation dataset. These settings can affect the model's performance, speed, and accuracy. Some common YOLO validation settings include the batch size, the frequency with which validation is performed during training, and the metrics used to evaluate the model's performance. Other factors that may affect the validation process include the size and composition of the validation dataset and the specific task the model is being used for. It is important to carefully tune and experiment with these settings to ensure that the model is performing well on the validation dataset and to detect and prevent overfitting.
| Key | Value | Description |
|---------------|---------|--------------------------------------------------------------------|
@@ -70,23 +61,4 @@ validation dataset and to detect and prevent overfitting.
| `plots` | `False` | show plots during training |
| `rect` | `False` | rectangular val with each batch collated for minimum padding |
| `split` | `val` | dataset split to use for validation, i.e. 'val', 'test' or 'train' |
-
-## Export Formats
-
-Available YOLOv8 export formats are in the table below. You can export to any format using the `format` argument,
-i.e. `format='onnx'` or `format='engine'`.
-
-| Format | `format` Argument | Model | Metadata | Arguments |
-|--------------------------------------------------------------------|-------------------|---------------------------|----------|-----------------------------------------------------|
-| [PyTorch](https://pytorch.org/) | - | `yolov8n.pt` | ✅ | - |
-| [TorchScript](https://pytorch.org/docs/stable/jit.html) | `torchscript` | `yolov8n.torchscript` | ✅ | `imgsz`, `optimize` |
-| [ONNX](https://onnx.ai/) | `onnx` | `yolov8n.onnx` | ✅ | `imgsz`, `half`, `dynamic`, `simplify`, `opset` |
-| [OpenVINO](https://docs.openvino.ai/latest/index.html) | `openvino` | `yolov8n_openvino_model/` | ✅ | `imgsz`, `half` |
-| [TensorRT](https://developer.nvidia.com/tensorrt) | `engine` | `yolov8n.engine` | ✅ | `imgsz`, `half`, `dynamic`, `simplify`, `workspace` |
-| [CoreML](https://github.com/apple/coremltools) | `coreml` | `yolov8n.mlmodel` | ✅ | `imgsz`, `half`, `int8`, `nms` |
-| [TF SavedModel](https://www.tensorflow.org/guide/saved_model) | `saved_model` | `yolov8n_saved_model/` | ✅ | `imgsz`, `keras` |
-| [TF GraphDef](https://www.tensorflow.org/api_docs/python/tf/Graph) | `pb` | `yolov8n.pb` | ❌ | `imgsz` |
-| [TF Lite](https://www.tensorflow.org/lite) | `tflite` | `yolov8n.tflite` | ✅ | `imgsz`, `half`, `int8` |
-| [TF Edge TPU](https://coral.ai/docs/edgetpu/models-intro/) | `edgetpu` | `yolov8n_edgetpu.tflite` | ✅ | `imgsz` |
-| [TF.js](https://www.tensorflow.org/js) | `tfjs` | `yolov8n_web_model/` | ✅ | `imgsz` |
-| [PaddlePaddle](https://github.com/PaddlePaddle) | `paddle` | `yolov8n_paddle_model/` | ✅ | `imgsz` |
\ No newline at end of file
+|
\ No newline at end of file
diff --git a/docs/tasks/classify.md b/docs/tasks/classify.md
index fe0b939..dfe7e07 100644
--- a/docs/tasks/classify.md
+++ b/docs/tasks/classify.md
@@ -176,5 +176,6 @@ i.e. `yolo predict model=yolov8n-cls.onnx`. Usage examples are shown for your mo
| [TF Edge TPU](https://coral.ai/docs/edgetpu/models-intro/) | `edgetpu` | `yolov8n-cls_edgetpu.tflite` | ✅ | `imgsz` |
| [TF.js](https://www.tensorflow.org/js) | `tfjs` | `yolov8n-cls_web_model/` | ✅ | `imgsz` |
| [PaddlePaddle](https://github.com/PaddlePaddle) | `paddle` | `yolov8n-cls_paddle_model/` | ✅ | `imgsz` |
+| [NCNN](https://github.com/Tencent/ncnn) | `ncnn` | `yolov8n-cls_ncnn_model/` | ✅ | `imgsz` |
See full `export` details in the [Export](https://docs.ultralytics.com/modes/export/) page.
\ No newline at end of file
diff --git a/docs/tasks/detect.md b/docs/tasks/detect.md
index 35a3d44..68bef46 100644
--- a/docs/tasks/detect.md
+++ b/docs/tasks/detect.md
@@ -167,5 +167,6 @@ Available YOLOv8 export formats are in the table below. You can predict or valid
| [TF Edge TPU](https://coral.ai/docs/edgetpu/models-intro/) | `edgetpu` | `yolov8n_edgetpu.tflite` | ✅ | `imgsz` |
| [TF.js](https://www.tensorflow.org/js) | `tfjs` | `yolov8n_web_model/` | ✅ | `imgsz` |
| [PaddlePaddle](https://github.com/PaddlePaddle) | `paddle` | `yolov8n_paddle_model/` | ✅ | `imgsz` |
+| [NCNN](https://github.com/Tencent/ncnn) | `ncnn` | `yolov8n_ncnn_model/` | ✅ | `imgsz` |
See full `export` details in the [Export](https://docs.ultralytics.com/modes/export/) page.
\ No newline at end of file
diff --git a/docs/tasks/pose.md b/docs/tasks/pose.md
index 094f95b..79d35db 100644
--- a/docs/tasks/pose.md
+++ b/docs/tasks/pose.md
@@ -181,5 +181,6 @@ i.e. `yolo predict model=yolov8n-pose.onnx`. Usage examples are shown for your m
| [TF Edge TPU](https://coral.ai/docs/edgetpu/models-intro/) | `edgetpu` | `yolov8n-pose_edgetpu.tflite` | ✅ | `imgsz` |
| [TF.js](https://www.tensorflow.org/js) | `tfjs` | `yolov8n-pose_web_model/` | ✅ | `imgsz` |
| [PaddlePaddle](https://github.com/PaddlePaddle) | `paddle` | `yolov8n-pose_paddle_model/` | ✅ | `imgsz` |
+| [NCNN](https://github.com/Tencent/ncnn) | `ncnn` | `yolov8n-pose_ncnn_model/` | ✅ | `imgsz` |
See full `export` details in the [Export](https://docs.ultralytics.com/modes/export/) page.
\ No newline at end of file
diff --git a/docs/tasks/segment.md b/docs/tasks/segment.md
index 4f9192f..cb1b30a 100644
--- a/docs/tasks/segment.md
+++ b/docs/tasks/segment.md
@@ -181,5 +181,6 @@ i.e. `yolo predict model=yolov8n-seg.onnx`. Usage examples are shown for your mo
| [TF Edge TPU](https://coral.ai/docs/edgetpu/models-intro/) | `edgetpu` | `yolov8n-seg_edgetpu.tflite` | ✅ | `imgsz` |
| [TF.js](https://www.tensorflow.org/js) | `tfjs` | `yolov8n-seg_web_model/` | ✅ | `imgsz` |
| [PaddlePaddle](https://github.com/PaddlePaddle) | `paddle` | `yolov8n-seg_paddle_model/` | ✅ | `imgsz` |
+| [NCNN](https://github.com/Tencent/ncnn) | `ncnn` | `yolov8n-seg_ncnn_model/` | ✅ | `imgsz` |
See full `export` details in the [Export](https://docs.ultralytics.com/modes/export/) page.
\ No newline at end of file
diff --git a/ultralytics/__init__.py b/ultralytics/__init__.py
index 5351f39..d4839cf 100644
--- a/ultralytics/__init__.py
+++ b/ultralytics/__init__.py
@@ -1,6 +1,6 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
-__version__ = '8.0.128'
+__version__ = '8.0.129'
from ultralytics.hub import start
from ultralytics.vit.rtdetr import RTDETR
diff --git a/ultralytics/nn/autobackend.py b/ultralytics/nn/autobackend.py
index 6f1695d..e277957 100644
--- a/ultralytics/nn/autobackend.py
+++ b/ultralytics/nn/autobackend.py
@@ -79,7 +79,8 @@ class AutoBackend(nn.Module):
super().__init__()
w = str(weights[0] if isinstance(weights, list) else weights)
nn_module = isinstance(weights, torch.nn.Module)
- pt, jit, onnx, xml, engine, coreml, saved_model, pb, tflite, edgetpu, tfjs, paddle, triton = self._model_type(w)
+ pt, jit, onnx, xml, engine, coreml, saved_model, pb, tflite, edgetpu, tfjs, paddle, ncnn, triton = \
+ self._model_type(w)
fp16 &= pt or jit or onnx or engine or nn_module or triton # FP16
nhwc = coreml or saved_model or pb or tflite or edgetpu # BHWC formats (vs torch BCWH)
stride = 32 # default stride
@@ -237,7 +238,7 @@ class AutoBackend(nn.Module):
meta_file = model.namelist()[0]
metadata = ast.literal_eval(model.read(meta_file).decode('utf-8'))
elif tfjs: # TF.js
- raise NotImplementedError('YOLOv8 TF.js inference is not supported')
+ raise NotImplementedError('YOLOv8 TF.js inference is not currently supported.')
elif paddle: # PaddlePaddle
LOGGER.info(f'Loading {w} for PaddlePaddle inference...')
check_requirements('paddlepaddle-gpu' if cuda else 'paddlepaddle')
@@ -252,6 +253,8 @@ class AutoBackend(nn.Module):
input_handle = predictor.get_input_handle(predictor.get_input_names()[0])
output_names = predictor.get_output_names()
metadata = w.parents[1] / 'metadata.yaml'
+ elif ncnn: # PaddlePaddle
+ raise NotImplementedError('YOLOv8 NCNN inference is not currently supported.')
elif triton: # NVIDIA Triton Inference Server
LOGGER.info('Triton Inference Server not supported...')
'''
diff --git a/ultralytics/yolo/engine/exporter.py b/ultralytics/yolo/engine/exporter.py
index 7fa6cf4..0a70ff4 100644
--- a/ultralytics/yolo/engine/exporter.py
+++ b/ultralytics/yolo/engine/exporter.py
@@ -16,6 +16,7 @@ TensorFlow Lite | `tflite` | yolov8n.tflite
TensorFlow Edge TPU | `edgetpu` | yolov8n_edgetpu.tflite
TensorFlow.js | `tfjs` | yolov8n_web_model/
PaddlePaddle | `paddle` | yolov8n_paddle_model/
+NCNN | `ncnn` | yolov8n_ncnn_model/
Requirements:
$ pip install ultralytics[export]
@@ -50,6 +51,7 @@ TensorFlow.js:
import json
import os
import platform
+import shutil
import subprocess
import time
import warnings
@@ -62,9 +64,10 @@ from ultralytics.nn.autobackend import check_class_names
from ultralytics.nn.modules import C2f, Detect, RTDETRDecoder
from ultralytics.nn.tasks import DetectionModel, SegmentationModel
from ultralytics.yolo.cfg import get_cfg
-from ultralytics.yolo.utils import (DEFAULT_CFG, LINUX, LOGGER, MACOS, __version__, callbacks, colorstr,
+from ultralytics.yolo.utils import (DEFAULT_CFG, LINUX, LOGGER, MACOS, ROOT, __version__, callbacks, colorstr,
get_default_args, yaml_save)
from ultralytics.yolo.utils.checks import check_imgsz, check_requirements, check_version
+from ultralytics.yolo.utils.downloads import attempt_download_asset, get_github_assets
from ultralytics.yolo.utils.files import file_size
from ultralytics.yolo.utils.ops import Profile
from ultralytics.yolo.utils.torch_utils import get_latest_opset, select_device, smart_inference_mode
@@ -87,7 +90,8 @@ def export_formats():
['TensorFlow Lite', 'tflite', '.tflite', True, False],
['TensorFlow Edge TPU', 'edgetpu', '_edgetpu.tflite', True, False],
['TensorFlow.js', 'tfjs', '_web_model', True, False],
- ['PaddlePaddle', 'paddle', '_paddle_model', True, True], ]
+ ['PaddlePaddle', 'paddle', '_paddle_model', True, True],
+ ['NCNN', 'ncnn', '_ncnn_model', True, True], ]
return pandas.DataFrame(x, columns=['Format', 'Argument', 'Suffix', 'CPU', 'GPU'])
@@ -153,7 +157,7 @@ class Exporter:
flags = [x == format for x in fmts]
if sum(flags) != 1:
raise ValueError(f"Invalid export format='{format}'. Valid formats are {fmts}")
- jit, onnx, xml, engine, coreml, saved_model, pb, tflite, edgetpu, tfjs, paddle = flags # export booleans
+ jit, onnx, xml, engine, coreml, saved_model, pb, tflite, edgetpu, tfjs, paddle, ncnn = flags # export booleans
# Load PyTorch model
self.device = select_device('cpu' if self.args.device is None else self.args.device)
@@ -231,7 +235,7 @@ class Exporter:
# Exports
f = [''] * len(fmts) # exported filenames
- if jit: # TorchScript
+ if jit or ncnn: # TorchScript
f[0], _ = self.export_torchscript()
if engine: # TensorRT required before ONNX
f[1], _ = self.export_engine()
@@ -254,6 +258,8 @@ class Exporter:
f[9], _ = self.export_tfjs()
if paddle: # PaddlePaddle
f[10], _ = self.export_paddle()
+ if ncnn: # NCNN
+ f[11], _ = self.export_ncnn()
# Finish
f = [str(x) for x in f if x] # filter out '' and None
@@ -394,6 +400,57 @@ class Exporter:
yaml_save(Path(f) / 'metadata.yaml', self.metadata) # add metadata.yaml
return f, None
+ @try_export
+ def export_ncnn(self, prefix=colorstr('NCNN:')):
+ """
+ YOLOv8 NCNN export using PNNX https://github.com/pnnx/pnnx.
+ """
+ check_requirements('ncnn') # requires NCNN
+ import ncnn # noqa
+
+ LOGGER.info(f'\n{prefix} starting export with NCNN {ncnn.__version__}...')
+ f = Path(str(self.file).replace(self.file.suffix, f'_ncnn_model{os.sep}'))
+ f_ts = str(self.file.with_suffix('.torchscript'))
+
+ if Path('./pnnx').is_file():
+ pnnx = './pnnx'
+ elif (ROOT / 'pnnx').is_file():
+ pnnx = ROOT / 'pnnx'
+ else:
+ LOGGER.warning(
+ f'{prefix} WARNING ⚠️ PNNX not found. Attempting to download binary file from '
+ 'https://github.com/pnnx/pnnx/.\nNote PNNX Binary file must be placed in current working directory '
+ f'or in {ROOT}. See PNNX repo for full installation instructions.')
+ _, assets = get_github_assets(repo='pnnx/pnnx')
+ asset = [x for x in assets if ('macos' if MACOS else 'ubuntu' if LINUX else 'windows') in x][0]
+ attempt_download_asset(asset, repo='pnnx/pnnx', release='latest')
+ unzip_dir = Path(asset).with_suffix('')
+ pnnx = ROOT / 'pnnx' # new location
+ (unzip_dir / 'pnnx').rename(pnnx) # move binary to ROOT
+ shutil.rmtree(unzip_dir) # delete unzip dir
+ Path(asset).unlink() # delete zip
+ pnnx.chmod(0o777) # set read, write, and execute permissions for everyone
+
+ cmd = [
+ str(pnnx),
+ f_ts,
+ f'pnnxparam={f / "model.pnnx.param"}',
+ f'pnnxbin={f / "model.pnnx.bin"}',
+ f'pnnxpy={f / "model_pnnx.py"}',
+ f'pnnxonnx={f / "model.pnnx.onnx"}',
+ f'ncnnparam={f / "model.ncnn.param"}',
+ f'ncnnbin={f / "model.ncnn.bin"}',
+ f'ncnnpy={f / "model_ncnn.py"}',
+ f'fp16={int(self.args.half)}',
+ f'device={self.device.type}',
+ f'inputshape="{[self.args.batch, 3, *self.imgsz]}"', ]
+ f.mkdir(exist_ok=True) # make ncnn_model directory
+ LOGGER.info(f"{prefix} running '{' '.join(cmd)}'")
+ subprocess.run(cmd, check=True)
+
+ yaml_save(f / 'metadata.yaml', self.metadata) # add metadata.yaml
+ return str(f), None
+
@try_export
def export_coreml(self, prefix=colorstr('CoreML:')):
"""YOLOv8 CoreML export."""
diff --git a/ultralytics/yolo/utils/benchmarks.py b/ultralytics/yolo/utils/benchmarks.py
index 654847b..f9e8f00 100644
--- a/ultralytics/yolo/utils/benchmarks.py
+++ b/ultralytics/yolo/utils/benchmarks.py
@@ -21,6 +21,7 @@ TensorFlow Lite | `tflite` | yolov8n.tflite
TensorFlow Edge TPU | `edgetpu` | yolov8n_edgetpu.tflite
TensorFlow.js | `tfjs` | yolov8n_web_model/
PaddlePaddle | `paddle` | yolov8n_paddle_model/
+NCNN | `ncnn` | yolov8n_ncnn_model/
"""
import glob
@@ -98,7 +99,7 @@ def benchmark(model=Path(SETTINGS['weights_dir']) / 'yolov8n.pt',
# Predict
assert model.task != 'pose' or i != 7, 'GraphDef Pose inference is not supported'
- assert i not in (9, 10), 'inference not supported' # Edge TPU and TF.js are unsupported
+ assert i not in (9, 10, 12), 'inference not supported' # Edge TPU, TF.js and NCNN are unsupported
assert i != 5 or platform.system() == 'Darwin', 'inference only supported on macOS>=10.13' # CoreML
if not (ROOT / 'assets/bus.jpg').exists():
download(url='https://ultralytics.com/images/bus.jpg', dir=ROOT / 'assets')
diff --git a/ultralytics/yolo/utils/checks.py b/ultralytics/yolo/utils/checks.py
index 2188aa8..0ea860c 100644
--- a/ultralytics/yolo/utils/checks.py
+++ b/ultralytics/yolo/utils/checks.py
@@ -8,6 +8,7 @@ import platform
import re
import shutil
import subprocess
+import time
from pathlib import Path
from typing import Optional
@@ -235,13 +236,16 @@ def check_requirements(requirements=ROOT.parent / 'requirements.txt', exclude=()
if s:
if install and AUTOINSTALL: # check environment variable
- LOGGER.info(f"{prefix} Ultralytics requirement{'s' * (n > 1)} {s}not found, attempting AutoUpdate...")
+ pkgs = file or requirements # missing packages
+ LOGGER.info(f"{prefix} Ultralytics requirement{'s' * (n > 1)} {pkgs} not found, attempting AutoUpdate...")
try:
+ t = time.time()
assert is_online(), 'AutoUpdate skipped (offline)'
LOGGER.info(subprocess.check_output(f'pip install --no-cache {s} {cmds}', shell=True).decode())
- s = f"{prefix} {n} package{'s' * (n > 1)} updated per {file or requirements}\n" \
- f"{prefix} ⚠️ {colorstr('bold', 'Restart runtime or rerun command for updates to take effect')}\n"
- LOGGER.info(s)
+ dt = time.time() - t
+ LOGGER.info(
+ f"{prefix} AutoUpdate success ✅ {dt:.1f}s, installed {n} package{'s' * (n > 1)}: {pkgs}\n"
+ f"{prefix} ⚠️ {colorstr('bold', 'Restart runtime or rerun command for updates to take effect')}\n")
except Exception as e:
LOGGER.warning(f'{prefix} ❌ {e}')
return False
diff --git a/ultralytics/yolo/utils/downloads.py b/ultralytics/yolo/utils/downloads.py
index eaa8f0d..53f58cf 100644
--- a/ultralytics/yolo/utils/downloads.py
+++ b/ultralytics/yolo/utils/downloads.py
@@ -189,7 +189,7 @@ def safe_download(url,
if unzip and f.exists() and f.suffix in ('', '.zip', '.tar', '.gz'):
unzip_dir = dir or f.parent # unzip to dir if provided else unzip in place
- LOGGER.info(f'Unzipping {f} to {unzip_dir}...')
+ LOGGER.info(f'Unzipping {f} to {unzip_dir.absolute()}...')
if is_zipfile(f):
unzip_dir = unzip_file(file=f, path=unzip_dir) # unzip
elif f.suffix == '.tar':
@@ -201,17 +201,18 @@ def safe_download(url,
return unzip_dir
+def get_github_assets(repo='ultralytics/assets', version='latest'):
+ """Return GitHub repo tag and assets (i.e. ['yolov8n.pt', 'yolov8s.pt', ...])."""
+ if version != 'latest':
+ version = f'tags/{version}' # i.e. tags/v6.2
+ response = requests.get(f'https://api.github.com/repos/{repo}/releases/{version}').json() # github api
+ return response['tag_name'], [x['name'] for x in response['assets']] # tag, assets
+
+
def attempt_download_asset(file, repo='ultralytics/assets', release='v0.0.0'):
"""Attempt file download from GitHub release assets if not found locally. release = 'latest', 'v6.2', etc."""
from ultralytics.yolo.utils import SETTINGS # scoped for circular import
- def github_assets(repository, version='latest'):
- """Return GitHub repo tag and assets (i.e. ['yolov8n.pt', 'yolov8s.pt', ...])."""
- if version != 'latest':
- version = f'tags/{version}' # i.e. tags/v6.2
- response = requests.get(f'https://api.github.com/repos/{repository}/releases/{version}').json() # github api
- return response['tag_name'], [x['name'] for x in response['assets']] # tag, assets
-
# YOLOv3/5u updates
file = str(file)
file = checks.check_yolov5u_filename(file)
@@ -235,10 +236,10 @@ def attempt_download_asset(file, repo='ultralytics/assets', release='v0.0.0'):
# GitHub assets
assets = GITHUB_ASSET_NAMES
try:
- tag, assets = github_assets(repo, release)
+ tag, assets = get_github_assets(repo, release)
except Exception:
try:
- tag, assets = github_assets(repo) # latest release
+ tag, assets = get_github_assets(repo) # latest release
except Exception:
try:
tag = subprocess.check_output(['git', 'tag']).decode().split()[-1]