ultralytics 8.0.129
add YOLOv8 Tencent NCNN export (#3529)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
@ -16,6 +16,7 @@ TensorFlow Lite | `tflite` | yolov8n.tflite
|
||||
TensorFlow Edge TPU | `edgetpu` | yolov8n_edgetpu.tflite
|
||||
TensorFlow.js | `tfjs` | yolov8n_web_model/
|
||||
PaddlePaddle | `paddle` | yolov8n_paddle_model/
|
||||
NCNN | `ncnn` | yolov8n_ncnn_model/
|
||||
|
||||
Requirements:
|
||||
$ pip install ultralytics[export]
|
||||
@ -50,6 +51,7 @@ TensorFlow.js:
|
||||
import json
|
||||
import os
|
||||
import platform
|
||||
import shutil
|
||||
import subprocess
|
||||
import time
|
||||
import warnings
|
||||
@ -62,9 +64,10 @@ from ultralytics.nn.autobackend import check_class_names
|
||||
from ultralytics.nn.modules import C2f, Detect, RTDETRDecoder
|
||||
from ultralytics.nn.tasks import DetectionModel, SegmentationModel
|
||||
from ultralytics.yolo.cfg import get_cfg
|
||||
from ultralytics.yolo.utils import (DEFAULT_CFG, LINUX, LOGGER, MACOS, __version__, callbacks, colorstr,
|
||||
from ultralytics.yolo.utils import (DEFAULT_CFG, LINUX, LOGGER, MACOS, ROOT, __version__, callbacks, colorstr,
|
||||
get_default_args, yaml_save)
|
||||
from ultralytics.yolo.utils.checks import check_imgsz, check_requirements, check_version
|
||||
from ultralytics.yolo.utils.downloads import attempt_download_asset, get_github_assets
|
||||
from ultralytics.yolo.utils.files import file_size
|
||||
from ultralytics.yolo.utils.ops import Profile
|
||||
from ultralytics.yolo.utils.torch_utils import get_latest_opset, select_device, smart_inference_mode
|
||||
@ -87,7 +90,8 @@ def export_formats():
|
||||
['TensorFlow Lite', 'tflite', '.tflite', True, False],
|
||||
['TensorFlow Edge TPU', 'edgetpu', '_edgetpu.tflite', True, False],
|
||||
['TensorFlow.js', 'tfjs', '_web_model', True, False],
|
||||
['PaddlePaddle', 'paddle', '_paddle_model', True, True], ]
|
||||
['PaddlePaddle', 'paddle', '_paddle_model', True, True],
|
||||
['NCNN', 'ncnn', '_ncnn_model', True, True], ]
|
||||
return pandas.DataFrame(x, columns=['Format', 'Argument', 'Suffix', 'CPU', 'GPU'])
|
||||
|
||||
|
||||
@ -153,7 +157,7 @@ class Exporter:
|
||||
flags = [x == format for x in fmts]
|
||||
if sum(flags) != 1:
|
||||
raise ValueError(f"Invalid export format='{format}'. Valid formats are {fmts}")
|
||||
jit, onnx, xml, engine, coreml, saved_model, pb, tflite, edgetpu, tfjs, paddle = flags # export booleans
|
||||
jit, onnx, xml, engine, coreml, saved_model, pb, tflite, edgetpu, tfjs, paddle, ncnn = flags # export booleans
|
||||
|
||||
# Load PyTorch model
|
||||
self.device = select_device('cpu' if self.args.device is None else self.args.device)
|
||||
@ -231,7 +235,7 @@ class Exporter:
|
||||
|
||||
# Exports
|
||||
f = [''] * len(fmts) # exported filenames
|
||||
if jit: # TorchScript
|
||||
if jit or ncnn: # TorchScript
|
||||
f[0], _ = self.export_torchscript()
|
||||
if engine: # TensorRT required before ONNX
|
||||
f[1], _ = self.export_engine()
|
||||
@ -254,6 +258,8 @@ class Exporter:
|
||||
f[9], _ = self.export_tfjs()
|
||||
if paddle: # PaddlePaddle
|
||||
f[10], _ = self.export_paddle()
|
||||
if ncnn: # NCNN
|
||||
f[11], _ = self.export_ncnn()
|
||||
|
||||
# Finish
|
||||
f = [str(x) for x in f if x] # filter out '' and None
|
||||
@ -394,6 +400,57 @@ class Exporter:
|
||||
yaml_save(Path(f) / 'metadata.yaml', self.metadata) # add metadata.yaml
|
||||
return f, None
|
||||
|
||||
@try_export
|
||||
def export_ncnn(self, prefix=colorstr('NCNN:')):
|
||||
"""
|
||||
YOLOv8 NCNN export using PNNX https://github.com/pnnx/pnnx.
|
||||
"""
|
||||
check_requirements('ncnn') # requires NCNN
|
||||
import ncnn # noqa
|
||||
|
||||
LOGGER.info(f'\n{prefix} starting export with NCNN {ncnn.__version__}...')
|
||||
f = Path(str(self.file).replace(self.file.suffix, f'_ncnn_model{os.sep}'))
|
||||
f_ts = str(self.file.with_suffix('.torchscript'))
|
||||
|
||||
if Path('./pnnx').is_file():
|
||||
pnnx = './pnnx'
|
||||
elif (ROOT / 'pnnx').is_file():
|
||||
pnnx = ROOT / 'pnnx'
|
||||
else:
|
||||
LOGGER.warning(
|
||||
f'{prefix} WARNING ⚠️ PNNX not found. Attempting to download binary file from '
|
||||
'https://github.com/pnnx/pnnx/.\nNote PNNX Binary file must be placed in current working directory '
|
||||
f'or in {ROOT}. See PNNX repo for full installation instructions.')
|
||||
_, assets = get_github_assets(repo='pnnx/pnnx')
|
||||
asset = [x for x in assets if ('macos' if MACOS else 'ubuntu' if LINUX else 'windows') in x][0]
|
||||
attempt_download_asset(asset, repo='pnnx/pnnx', release='latest')
|
||||
unzip_dir = Path(asset).with_suffix('')
|
||||
pnnx = ROOT / 'pnnx' # new location
|
||||
(unzip_dir / 'pnnx').rename(pnnx) # move binary to ROOT
|
||||
shutil.rmtree(unzip_dir) # delete unzip dir
|
||||
Path(asset).unlink() # delete zip
|
||||
pnnx.chmod(0o777) # set read, write, and execute permissions for everyone
|
||||
|
||||
cmd = [
|
||||
str(pnnx),
|
||||
f_ts,
|
||||
f'pnnxparam={f / "model.pnnx.param"}',
|
||||
f'pnnxbin={f / "model.pnnx.bin"}',
|
||||
f'pnnxpy={f / "model_pnnx.py"}',
|
||||
f'pnnxonnx={f / "model.pnnx.onnx"}',
|
||||
f'ncnnparam={f / "model.ncnn.param"}',
|
||||
f'ncnnbin={f / "model.ncnn.bin"}',
|
||||
f'ncnnpy={f / "model_ncnn.py"}',
|
||||
f'fp16={int(self.args.half)}',
|
||||
f'device={self.device.type}',
|
||||
f'inputshape="{[self.args.batch, 3, *self.imgsz]}"', ]
|
||||
f.mkdir(exist_ok=True) # make ncnn_model directory
|
||||
LOGGER.info(f"{prefix} running '{' '.join(cmd)}'")
|
||||
subprocess.run(cmd, check=True)
|
||||
|
||||
yaml_save(f / 'metadata.yaml', self.metadata) # add metadata.yaml
|
||||
return str(f), None
|
||||
|
||||
@try_export
|
||||
def export_coreml(self, prefix=colorstr('CoreML:')):
|
||||
"""YOLOv8 CoreML export."""
|
||||
|
@ -21,6 +21,7 @@ TensorFlow Lite | `tflite` | yolov8n.tflite
|
||||
TensorFlow Edge TPU | `edgetpu` | yolov8n_edgetpu.tflite
|
||||
TensorFlow.js | `tfjs` | yolov8n_web_model/
|
||||
PaddlePaddle | `paddle` | yolov8n_paddle_model/
|
||||
NCNN | `ncnn` | yolov8n_ncnn_model/
|
||||
"""
|
||||
|
||||
import glob
|
||||
@ -98,7 +99,7 @@ def benchmark(model=Path(SETTINGS['weights_dir']) / 'yolov8n.pt',
|
||||
|
||||
# Predict
|
||||
assert model.task != 'pose' or i != 7, 'GraphDef Pose inference is not supported'
|
||||
assert i not in (9, 10), 'inference not supported' # Edge TPU and TF.js are unsupported
|
||||
assert i not in (9, 10, 12), 'inference not supported' # Edge TPU, TF.js and NCNN are unsupported
|
||||
assert i != 5 or platform.system() == 'Darwin', 'inference only supported on macOS>=10.13' # CoreML
|
||||
if not (ROOT / 'assets/bus.jpg').exists():
|
||||
download(url='https://ultralytics.com/images/bus.jpg', dir=ROOT / 'assets')
|
||||
|
@ -8,6 +8,7 @@ import platform
|
||||
import re
|
||||
import shutil
|
||||
import subprocess
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
@ -235,13 +236,16 @@ def check_requirements(requirements=ROOT.parent / 'requirements.txt', exclude=()
|
||||
|
||||
if s:
|
||||
if install and AUTOINSTALL: # check environment variable
|
||||
LOGGER.info(f"{prefix} Ultralytics requirement{'s' * (n > 1)} {s}not found, attempting AutoUpdate...")
|
||||
pkgs = file or requirements # missing packages
|
||||
LOGGER.info(f"{prefix} Ultralytics requirement{'s' * (n > 1)} {pkgs} not found, attempting AutoUpdate...")
|
||||
try:
|
||||
t = time.time()
|
||||
assert is_online(), 'AutoUpdate skipped (offline)'
|
||||
LOGGER.info(subprocess.check_output(f'pip install --no-cache {s} {cmds}', shell=True).decode())
|
||||
s = f"{prefix} {n} package{'s' * (n > 1)} updated per {file or requirements}\n" \
|
||||
f"{prefix} ⚠️ {colorstr('bold', 'Restart runtime or rerun command for updates to take effect')}\n"
|
||||
LOGGER.info(s)
|
||||
dt = time.time() - t
|
||||
LOGGER.info(
|
||||
f"{prefix} AutoUpdate success ✅ {dt:.1f}s, installed {n} package{'s' * (n > 1)}: {pkgs}\n"
|
||||
f"{prefix} ⚠️ {colorstr('bold', 'Restart runtime or rerun command for updates to take effect')}\n")
|
||||
except Exception as e:
|
||||
LOGGER.warning(f'{prefix} ❌ {e}')
|
||||
return False
|
||||
|
@ -189,7 +189,7 @@ def safe_download(url,
|
||||
|
||||
if unzip and f.exists() and f.suffix in ('', '.zip', '.tar', '.gz'):
|
||||
unzip_dir = dir or f.parent # unzip to dir if provided else unzip in place
|
||||
LOGGER.info(f'Unzipping {f} to {unzip_dir}...')
|
||||
LOGGER.info(f'Unzipping {f} to {unzip_dir.absolute()}...')
|
||||
if is_zipfile(f):
|
||||
unzip_dir = unzip_file(file=f, path=unzip_dir) # unzip
|
||||
elif f.suffix == '.tar':
|
||||
@ -201,17 +201,18 @@ def safe_download(url,
|
||||
return unzip_dir
|
||||
|
||||
|
||||
def get_github_assets(repo='ultralytics/assets', version='latest'):
|
||||
"""Return GitHub repo tag and assets (i.e. ['yolov8n.pt', 'yolov8s.pt', ...])."""
|
||||
if version != 'latest':
|
||||
version = f'tags/{version}' # i.e. tags/v6.2
|
||||
response = requests.get(f'https://api.github.com/repos/{repo}/releases/{version}').json() # github api
|
||||
return response['tag_name'], [x['name'] for x in response['assets']] # tag, assets
|
||||
|
||||
|
||||
def attempt_download_asset(file, repo='ultralytics/assets', release='v0.0.0'):
|
||||
"""Attempt file download from GitHub release assets if not found locally. release = 'latest', 'v6.2', etc."""
|
||||
from ultralytics.yolo.utils import SETTINGS # scoped for circular import
|
||||
|
||||
def github_assets(repository, version='latest'):
|
||||
"""Return GitHub repo tag and assets (i.e. ['yolov8n.pt', 'yolov8s.pt', ...])."""
|
||||
if version != 'latest':
|
||||
version = f'tags/{version}' # i.e. tags/v6.2
|
||||
response = requests.get(f'https://api.github.com/repos/{repository}/releases/{version}').json() # github api
|
||||
return response['tag_name'], [x['name'] for x in response['assets']] # tag, assets
|
||||
|
||||
# YOLOv3/5u updates
|
||||
file = str(file)
|
||||
file = checks.check_yolov5u_filename(file)
|
||||
@ -235,10 +236,10 @@ def attempt_download_asset(file, repo='ultralytics/assets', release='v0.0.0'):
|
||||
# GitHub assets
|
||||
assets = GITHUB_ASSET_NAMES
|
||||
try:
|
||||
tag, assets = github_assets(repo, release)
|
||||
tag, assets = get_github_assets(repo, release)
|
||||
except Exception:
|
||||
try:
|
||||
tag, assets = github_assets(repo) # latest release
|
||||
tag, assets = get_github_assets(repo) # latest release
|
||||
except Exception:
|
||||
try:
|
||||
tag = subprocess.check_output(['git', 'tag']).decode().split()[-1]
|
||||
|
Reference in New Issue
Block a user