ultralytics 8.0.46
TFLite and Benchmarks updates (#1141)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
@ -294,7 +294,7 @@ class Exporter:
|
||||
# YOLOv8 ONNX export
|
||||
requirements = ['onnx>=1.12.0']
|
||||
if self.args.simplify:
|
||||
requirements += ['onnxsim', 'onnxruntime-gpu' if torch.cuda.is_available() else 'onnxruntime']
|
||||
requirements += ['onnxsim>=0.4.17', 'onnxruntime-gpu' if torch.cuda.is_available() else 'onnxruntime']
|
||||
check_requirements(requirements)
|
||||
import onnx # noqa
|
||||
|
||||
@ -513,8 +513,8 @@ class Exporter:
|
||||
cuda = torch.cuda.is_available()
|
||||
check_requirements(f"tensorflow{'-macos' if MACOS else '-aarch64' if ARM64 else '' if cuda else '-cpu'}")
|
||||
import tensorflow as tf # noqa
|
||||
check_requirements(('onnx', 'onnx2tf', 'sng4onnx', 'onnxsim', 'onnx_graphsurgeon', 'tflite_support',
|
||||
'onnxruntime-gpu' if torch.cuda.is_available() else 'onnxruntime'),
|
||||
check_requirements(('onnx', 'onnx2tf>=1.7.7', 'sng4onnx>=1.0.1', 'onnxsim>=0.4.17', 'onnx_graphsurgeon>=0.3.26',
|
||||
'tflite_support', 'onnxruntime-gpu' if torch.cuda.is_available() else 'onnxruntime'),
|
||||
cmds='--extra-index-url https://pypi.ngc.nvidia.com')
|
||||
|
||||
LOGGER.info(f'\n{prefix} starting export with tensorflow {tf.__version__}...')
|
||||
@ -529,7 +529,7 @@ class Exporter:
|
||||
|
||||
# Export to TF
|
||||
int8 = '-oiqt -qt per-tensor' if self.args.int8 else ''
|
||||
cmd = f'onnx2tf -i {f_onnx} -o {f} --non_verbose {int8}'
|
||||
cmd = f'onnx2tf -i {f_onnx} -o {f} -nuo --non_verbose {int8}'
|
||||
LOGGER.info(f'\n{prefix} running {cmd}')
|
||||
subprocess.run(cmd, shell=True)
|
||||
yaml_save(f / 'metadata.yaml', self.metadata) # add metadata.yaml
|
||||
|
@ -9,8 +9,9 @@ from ultralytics.nn.tasks import (ClassificationModel, DetectionModel, Segmentat
|
||||
guess_model_task, nn)
|
||||
from ultralytics.yolo.cfg import get_cfg
|
||||
from ultralytics.yolo.engine.exporter import Exporter
|
||||
from ultralytics.yolo.utils import DEFAULT_CFG, DEFAULT_CFG_DICT, DEFAULT_CFG_KEYS, LOGGER, RANK, callbacks, yaml_load
|
||||
from ultralytics.yolo.utils.checks import check_file, check_imgsz, check_yaml
|
||||
from ultralytics.yolo.utils import (DEFAULT_CFG, DEFAULT_CFG_DICT, DEFAULT_CFG_KEYS, LOGGER, RANK, ROOT, callbacks,
|
||||
is_git_dir, is_pip_package, yaml_load)
|
||||
from ultralytics.yolo.utils.checks import check_file, check_imgsz, check_pip_update, check_yaml
|
||||
from ultralytics.yolo.utils.downloads import GITHUB_ASSET_STEMS
|
||||
from ultralytics.yolo.utils.torch_utils import smart_inference_mode
|
||||
|
||||
@ -150,6 +151,13 @@ class YOLO:
|
||||
f"'yolo export model=yolov8n.pt', but exported formats like ONNX, TensorRT etc. only "
|
||||
f"support 'predict' and 'val' modes, i.e. 'yolo predict model=yolov8n.onnx'.")
|
||||
|
||||
def _check_pip_update(self):
|
||||
"""
|
||||
Inform user of ultralytics package update availability
|
||||
"""
|
||||
if is_pip_package():
|
||||
check_pip_update()
|
||||
|
||||
def reset(self):
|
||||
"""
|
||||
Resets the model modules.
|
||||
@ -189,6 +197,10 @@ class YOLO:
|
||||
Returns:
|
||||
(List[ultralytics.yolo.engine.results.Results]): The prediction results.
|
||||
"""
|
||||
if source is None:
|
||||
source = ROOT / 'assets' if is_git_dir() else 'https://ultralytics.com/images/bus.jpg'
|
||||
LOGGER.warning(f"WARNING ⚠️ 'source' is missing. Using 'source={source}'.")
|
||||
|
||||
overrides = self.overrides.copy()
|
||||
overrides['conf'] = 0.25
|
||||
overrides.update(kwargs) # prefer kwargs
|
||||
@ -251,11 +263,12 @@ class YOLO:
|
||||
Args:
|
||||
**kwargs : Any other args accepted by the validators. To see all args check 'configuration' section in docs
|
||||
"""
|
||||
from ultralytics.yolo.utils.benchmarks import run_benchmarks
|
||||
self._check_is_pytorch_model()
|
||||
from ultralytics.yolo.utils.benchmarks import benchmark
|
||||
overrides = self.model.args.copy()
|
||||
overrides.update(kwargs)
|
||||
overrides = {**DEFAULT_CFG_DICT, **overrides} # fill in missing overrides keys with defaults
|
||||
return run_benchmarks(model=self, imgsz=overrides['imgsz'], half=overrides['half'], device=overrides['device'])
|
||||
return benchmark(model=self, imgsz=overrides['imgsz'], half=overrides['half'], device=overrides['device'])
|
||||
|
||||
def export(self, **kwargs):
|
||||
"""
|
||||
@ -283,6 +296,7 @@ class YOLO:
|
||||
**kwargs (Any): Any number of arguments representing the training configuration.
|
||||
"""
|
||||
self._check_is_pytorch_model()
|
||||
self._check_pip_update()
|
||||
overrides = self.overrides.copy()
|
||||
overrides.update(kwargs)
|
||||
if kwargs.get('cfg'):
|
||||
|
@ -178,7 +178,12 @@ class BasePredictor:
|
||||
self.run_callbacks('on_predict_postprocess_end')
|
||||
|
||||
# visualize, save, write results
|
||||
for i in range(len(im)):
|
||||
n = len(im)
|
||||
for i in range(n):
|
||||
self.results[i].speed = {
|
||||
'preprocess': self.dt[0].dt * 1E3 / n,
|
||||
'inference': self.dt[1].dt * 1E3 / n,
|
||||
'postprocess': self.dt[2].dt * 1E3 / n}
|
||||
p, im0 = (path[i], im0s[i].copy()) if self.source_type.webcam or self.source_type.from_img \
|
||||
else (path, im0s.copy())
|
||||
p = Path(p)
|
||||
|
Reference in New Issue
Block a user