From a82ee2c779b5556f18532deb4ba85638879171b7 Mon Sep 17 00:00:00 2001 From: Glenn Jocher Date: Sat, 25 Feb 2023 09:24:14 -0800 Subject: [PATCH] `ultralytics 8.0.46` TFLite and Benchmarks updates (#1141) Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- .github/workflows/ci.yaml | 20 +++++++++------- ultralytics/__init__.py | 2 +- ultralytics/yolo/cfg/__init__.py | 20 +++++++++------- ultralytics/yolo/data/dataset.py | 36 ++++++++++++++++++++++------ ultralytics/yolo/engine/exporter.py | 8 +++---- ultralytics/yolo/engine/model.py | 22 +++++++++++++---- ultralytics/yolo/engine/predictor.py | 7 +++++- ultralytics/yolo/utils/__init__.py | 18 +------------- ultralytics/yolo/utils/benchmarks.py | 29 +++++++++++----------- ultralytics/yolo/utils/checks.py | 26 ++++++++++++++++++++ ultralytics/yolo/utils/dist.py | 16 +++++++------ 11 files changed, 130 insertions(+), 74 deletions(-) diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index 8e72e85..d739fed 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -55,18 +55,20 @@ jobs: - name: Benchmark DetectionModel shell: python run: | - from ultralytics.yolo.utils.benchmarks import run_benchmarks - run_benchmarks(model='${{ matrix.model }}.pt', imgsz=160, half=False, hard_fail=False) + from ultralytics.yolo.utils.benchmarks import benchmark + benchmark(model='${{ matrix.model }}.pt', imgsz=160, half=False, hard_fail=0.20) - name: Benchmark SegmentationModel shell: python run: | - from ultralytics.yolo.utils.benchmarks import run_benchmarks - run_benchmarks(model='${{ matrix.model }}-seg.pt', imgsz=160, half=False, hard_fail=False) + from ultralytics.yolo.utils.benchmarks import benchmark + benchmark(model='${{ matrix.model }}-seg.pt', imgsz=160, half=False, hard_fail=0.14) - name: Benchmark ClassificationModel shell: python run: | - from ultralytics.yolo.utils.benchmarks import run_benchmarks - run_benchmarks(model='${{ matrix.model }}-cls.pt', imgsz=160, half=False, hard_fail=False) + from ultralytics.yolo.utils.benchmarks import benchmark + benchmark(model='${{ matrix.model }}-cls.pt', imgsz=160, half=False, hard_fail=0.70) + - name: Benchmark Summary + run: cat benchmarks.log Tests: timeout-minutes: 60 @@ -88,10 +90,10 @@ jobs: - uses: actions/setup-python@v4 with: python-version: ${{ matrix.python-version }} - - name: Get cache dir - # https://github.com/actions/cache/blob/master/examples.md#multiple-oss-in-a-workflow + - name: Get cache dir # https://github.com/actions/cache/blob/master/examples.md#multiple-oss-in-a-workflow id: pip-cache - run: echo "::set-output name=dir::$(pip cache dir)" + run: echo "dir=$(pip cache dir)" >> $GITHUB_OUTPUT + shell: bash # for Windows compatibility - name: Cache pip uses: actions/cache@v3 with: diff --git a/ultralytics/__init__.py b/ultralytics/__init__.py index fc6f2a4..1f37bbc 100644 --- a/ultralytics/__init__.py +++ b/ultralytics/__init__.py @@ -1,6 +1,6 @@ # Ultralytics YOLO 🚀, GPL-3.0 license -__version__ = '8.0.45' +__version__ = '8.0.46' from ultralytics.yolo.engine.model import YOLO from ultralytics.yolo.utils.checks import check_yolo as checks diff --git a/ultralytics/yolo/cfg/__init__.py b/ultralytics/yolo/cfg/__init__.py index 12e6c0c..2fb9895 100644 --- a/ultralytics/yolo/cfg/__init__.py +++ b/ultralytics/yolo/cfg/__init__.py @@ -254,8 +254,8 @@ def entrypoint(debug=''): else: check_cfg_mismatch(full_args_dict, {a: ''}) - # Defaults - task2data = dict(detect='coco128.yaml', segment='coco128-seg.yaml', classify='imagenet100') + # Check keys + check_cfg_mismatch(full_args_dict, overrides) # Mode mode = overrides.get('mode', None) @@ -279,11 +279,12 @@ def entrypoint(debug=''): model = YOLO(model) # Task - task = overrides.get('task', None) - if task is not None and task not in TASKS: - raise ValueError(f"Invalid 'task={task}'. Valid tasks are {TASKS}.\n{CLI_HELP_MSG}") - else: - model.task = task + task = overrides.get('task', model.task) + if task is not None: + if task not in TASKS: + raise ValueError(f"Invalid 'task={task}'. Valid tasks are {TASKS}.\n{CLI_HELP_MSG}") + else: + model.task = task # Mode if mode in {'predict', 'track'} and 'source' not in overrides: @@ -292,8 +293,9 @@ def entrypoint(debug=''): LOGGER.warning(f"WARNING ⚠️ 'source' is missing. Using default 'source={overrides['source']}'.") elif mode in ('train', 'val'): if 'data' not in overrides: - overrides['data'] = task2data.get(overrides['task'], DEFAULT_CFG.data) - LOGGER.warning(f"WARNING ⚠️ 'data' is missing. Using {model.task} default 'data={overrides['data']}'.") + task2data = dict(detect='coco128.yaml', segment='coco128-seg.yaml', classify='imagenet100') + overrides['data'] = task2data.get(task or DEFAULT_CFG.task, DEFAULT_CFG.data) + LOGGER.warning(f"WARNING ⚠️ 'data' is missing. Using default 'data={overrides['data']}'.") elif mode == 'export': if 'format' not in overrides: overrides['format'] = DEFAULT_CFG.format or 'torchscript' diff --git a/ultralytics/yolo/data/dataset.py b/ultralytics/yolo/data/dataset.py index 7e132cb..76a3a32 100644 --- a/ultralytics/yolo/data/dataset.py +++ b/ultralytics/yolo/data/dataset.py @@ -16,10 +16,28 @@ from .utils import HELP_URL, LOCAL_RANK, get_hash, img2label_paths, verify_image class YOLODataset(BaseDataset): cache_version = '1.0.1' # dataset labels *.cache version, >= 1.0.0 for YOLOv8 rand_interp_methods = [cv2.INTER_NEAREST, cv2.INTER_LINEAR, cv2.INTER_CUBIC, cv2.INTER_AREA, cv2.INTER_LANCZOS4] - """YOLO Dataset. + """ + Dataset class for loading images object detection and/or segmentation labels in YOLO format. + Args: - img_path (str): image path. - prefix (str): prefix. + img_path (str): path to the folder containing images. + imgsz (int): image size (default: 640). + cache (bool): if True, a cache file of the labels is created to speed up future creation of dataset instances + (default: False). + augment (bool): if True, data augmentation is applied (default: True). + hyp (dict): hyperparameters to apply data augmentation (default: None). + prefix (str): prefix to print in log messages (default: ''). + rect (bool): if True, rectangular training is used (default: False). + batch_size (int): size of batches (default: None). + stride (int): stride (default: 32). + pad (float): padding (default: 0.0). + single_cls (bool): if True, single class training is used (default: False). + use_segments (bool): if True, segmentation masks are used as labels (default: False). + use_keypoints (bool): if True, keypoints are used as labels (default: False). + names (list): class names (default: None). + + Returns: + A PyTorch dataset object that can be used for training an object detection or segmentation model. """ def __init__(self, @@ -44,7 +62,12 @@ class YOLODataset(BaseDataset): super().__init__(img_path, imgsz, cache, augment, hyp, prefix, rect, batch_size, stride, pad, single_cls) def cache_labels(self, path=Path('./labels.cache')): - # Cache dataset labels, check images and read shapes + """Cache dataset labels, check images and read shapes. + Args: + path (Path): path where to save the cache file (default: Path('./labels.cache')). + Returns: + (dict): labels. + """ x = {'labels': []} nm, nf, ne, nc, msgs = 0, 0, 0, 0, [] # number missing, found, empty, corrupt, messages desc = f'{self.prefix}Scanning {path.parent / path.stem}...' @@ -119,9 +142,8 @@ class YOLODataset(BaseDataset): self.im_files = [lb['im_file'] for lb in labels] # update im_files # Check if the dataset is all boxes or all segments - len_cls = sum(len(lb['cls']) for lb in labels) - len_boxes = sum(len(lb['bboxes']) for lb in labels) - len_segments = sum(len(lb['segments']) for lb in labels) + lengths = ((len(lb['cls']), len(lb['bboxes']), len(lb['segments'])) for lb in labels) + len_cls, len_boxes, len_segments = (sum(x) for x in zip(*lengths)) if len_segments and len_boxes != len_segments: LOGGER.warning( f'WARNING ⚠️ Box and segment counts should be equal, but got len(segments) = {len_segments}, ' diff --git a/ultralytics/yolo/engine/exporter.py b/ultralytics/yolo/engine/exporter.py index 5eaf9a9..886a0b1 100644 --- a/ultralytics/yolo/engine/exporter.py +++ b/ultralytics/yolo/engine/exporter.py @@ -294,7 +294,7 @@ class Exporter: # YOLOv8 ONNX export requirements = ['onnx>=1.12.0'] if self.args.simplify: - requirements += ['onnxsim', 'onnxruntime-gpu' if torch.cuda.is_available() else 'onnxruntime'] + requirements += ['onnxsim>=0.4.17', 'onnxruntime-gpu' if torch.cuda.is_available() else 'onnxruntime'] check_requirements(requirements) import onnx # noqa @@ -513,8 +513,8 @@ class Exporter: cuda = torch.cuda.is_available() check_requirements(f"tensorflow{'-macos' if MACOS else '-aarch64' if ARM64 else '' if cuda else '-cpu'}") import tensorflow as tf # noqa - check_requirements(('onnx', 'onnx2tf', 'sng4onnx', 'onnxsim', 'onnx_graphsurgeon', 'tflite_support', - 'onnxruntime-gpu' if torch.cuda.is_available() else 'onnxruntime'), + check_requirements(('onnx', 'onnx2tf>=1.7.7', 'sng4onnx>=1.0.1', 'onnxsim>=0.4.17', 'onnx_graphsurgeon>=0.3.26', + 'tflite_support', 'onnxruntime-gpu' if torch.cuda.is_available() else 'onnxruntime'), cmds='--extra-index-url https://pypi.ngc.nvidia.com') LOGGER.info(f'\n{prefix} starting export with tensorflow {tf.__version__}...') @@ -529,7 +529,7 @@ class Exporter: # Export to TF int8 = '-oiqt -qt per-tensor' if self.args.int8 else '' - cmd = f'onnx2tf -i {f_onnx} -o {f} --non_verbose {int8}' + cmd = f'onnx2tf -i {f_onnx} -o {f} -nuo --non_verbose {int8}' LOGGER.info(f'\n{prefix} running {cmd}') subprocess.run(cmd, shell=True) yaml_save(f / 'metadata.yaml', self.metadata) # add metadata.yaml diff --git a/ultralytics/yolo/engine/model.py b/ultralytics/yolo/engine/model.py index 585133c..2e65836 100644 --- a/ultralytics/yolo/engine/model.py +++ b/ultralytics/yolo/engine/model.py @@ -9,8 +9,9 @@ from ultralytics.nn.tasks import (ClassificationModel, DetectionModel, Segmentat guess_model_task, nn) from ultralytics.yolo.cfg import get_cfg from ultralytics.yolo.engine.exporter import Exporter -from ultralytics.yolo.utils import DEFAULT_CFG, DEFAULT_CFG_DICT, DEFAULT_CFG_KEYS, LOGGER, RANK, callbacks, yaml_load -from ultralytics.yolo.utils.checks import check_file, check_imgsz, check_yaml +from ultralytics.yolo.utils import (DEFAULT_CFG, DEFAULT_CFG_DICT, DEFAULT_CFG_KEYS, LOGGER, RANK, ROOT, callbacks, + is_git_dir, is_pip_package, yaml_load) +from ultralytics.yolo.utils.checks import check_file, check_imgsz, check_pip_update, check_yaml from ultralytics.yolo.utils.downloads import GITHUB_ASSET_STEMS from ultralytics.yolo.utils.torch_utils import smart_inference_mode @@ -150,6 +151,13 @@ class YOLO: f"'yolo export model=yolov8n.pt', but exported formats like ONNX, TensorRT etc. only " f"support 'predict' and 'val' modes, i.e. 'yolo predict model=yolov8n.onnx'.") + def _check_pip_update(self): + """ + Inform user of ultralytics package update availability + """ + if is_pip_package(): + check_pip_update() + def reset(self): """ Resets the model modules. @@ -189,6 +197,10 @@ class YOLO: Returns: (List[ultralytics.yolo.engine.results.Results]): The prediction results. """ + if source is None: + source = ROOT / 'assets' if is_git_dir() else 'https://ultralytics.com/images/bus.jpg' + LOGGER.warning(f"WARNING ⚠️ 'source' is missing. Using 'source={source}'.") + overrides = self.overrides.copy() overrides['conf'] = 0.25 overrides.update(kwargs) # prefer kwargs @@ -251,11 +263,12 @@ class YOLO: Args: **kwargs : Any other args accepted by the validators. To see all args check 'configuration' section in docs """ - from ultralytics.yolo.utils.benchmarks import run_benchmarks + self._check_is_pytorch_model() + from ultralytics.yolo.utils.benchmarks import benchmark overrides = self.model.args.copy() overrides.update(kwargs) overrides = {**DEFAULT_CFG_DICT, **overrides} # fill in missing overrides keys with defaults - return run_benchmarks(model=self, imgsz=overrides['imgsz'], half=overrides['half'], device=overrides['device']) + return benchmark(model=self, imgsz=overrides['imgsz'], half=overrides['half'], device=overrides['device']) def export(self, **kwargs): """ @@ -283,6 +296,7 @@ class YOLO: **kwargs (Any): Any number of arguments representing the training configuration. """ self._check_is_pytorch_model() + self._check_pip_update() overrides = self.overrides.copy() overrides.update(kwargs) if kwargs.get('cfg'): diff --git a/ultralytics/yolo/engine/predictor.py b/ultralytics/yolo/engine/predictor.py index 2b56421..17eee83 100644 --- a/ultralytics/yolo/engine/predictor.py +++ b/ultralytics/yolo/engine/predictor.py @@ -178,7 +178,12 @@ class BasePredictor: self.run_callbacks('on_predict_postprocess_end') # visualize, save, write results - for i in range(len(im)): + n = len(im) + for i in range(n): + self.results[i].speed = { + 'preprocess': self.dt[0].dt * 1E3 / n, + 'inference': self.dt[1].dt * 1E3 / n, + 'postprocess': self.dt[2].dt * 1E3 / n} p, im0 = (path[i], im0s[i].copy()) if self.source_type.webcam or self.source_type.from_img \ else (path, im0s.copy()) p = Path(p) diff --git a/ultralytics/yolo/utils/__init__.py b/ultralytics/yolo/utils/__init__.py index 3bbf6c1..e0ad000 100644 --- a/ultralytics/yolo/utils/__init__.py +++ b/ultralytics/yolo/utils/__init__.py @@ -354,22 +354,6 @@ def get_git_branch(): return None # if not git dir or on error -def get_latest_pypi_version(package_name='ultralytics'): - """ - Returns the latest version of a PyPI package without downloading or installing it. - - Parameters: - package_name (str): The name of the package to find the latest version for. - - Returns: - str: The latest version of the package. - """ - response = requests.get(f'https://pypi.org/pypi/{package_name}/json') - if response.status_code == 200: - return response.json()['info']['version'] - return None - - def get_default_args(func): """Returns a dictionary of default arguments for a function. @@ -611,7 +595,7 @@ def set_settings(kwargs, file=USER_CONFIG_DIR / 'settings.yaml'): # Run below code on yolo/utils init ------------------------------------------------------------------------------------ # Set logger -set_logging(LOGGING_NAME) # run before defining LOGGER +set_logging(LOGGING_NAME, verbose=VERBOSE) # run before defining LOGGER LOGGER = logging.getLogger(LOGGING_NAME) # define globally (used in train.py, val.py, detect.py, etc.) if WINDOWS: for fn in LOGGER.info, LOGGER.warning: diff --git a/ultralytics/yolo/utils/benchmarks.py b/ultralytics/yolo/utils/benchmarks.py index e0f4e92..993e6d1 100644 --- a/ultralytics/yolo/utils/benchmarks.py +++ b/ultralytics/yolo/utils/benchmarks.py @@ -37,11 +37,7 @@ from ultralytics.yolo.utils.files import file_size from ultralytics.yolo.utils.torch_utils import select_device -def run_benchmarks(model=Path(SETTINGS['weights_dir']) / 'yolov8n.pt', - imgsz=640, - half=False, - device='cpu', - hard_fail=False): +def benchmark(model=Path(SETTINGS['weights_dir']) / 'yolov8n.pt', imgsz=160, half=False, device='cpu', hard_fail=0.30): device = select_device(device, verbose=False) if isinstance(model, (str, Path)): model = YOLO(model) @@ -52,6 +48,7 @@ def run_benchmarks(model=Path(SETTINGS['weights_dir']) / 'yolov8n.pt', try: assert i not in (9, 10), 'inference not supported' # Edge TPU and TF.js are unsupported assert i != 5 or platform.system() == 'Darwin', 'inference only supported on macOS>=10.13' # CoreML + assert i != 11 or model.task != 'classify', 'paddle-classify bug' if 'cpu' in device.type: assert cpu, 'inference not supported on CPU' @@ -85,26 +82,28 @@ def run_benchmarks(model=Path(SETTINGS['weights_dir']) / 'yolov8n.pt', y.append([name, '✅', round(file_size(filename), 1), round(metric, 4), round(speed, 2)]) except Exception as e: if hard_fail: - assert type(e) is AssertionError, f'Benchmark --hard-fail for {name}: {e}' + assert type(e) is AssertionError, f'Benchmark hard_fail for {name}: {e}' LOGGER.warning(f'ERROR ❌️ Benchmark failure for {name}: {e}') y.append([name, '❌', None, None, None]) # mAP, t_inference # Print results - LOGGER.info('\n') check_yolo(device=device) # print system info - c = ['Format', 'Status❔', 'Size (MB)', key, 'Inference time (ms/im)'] if map else ['Format', 'Export', '', ''] + c = ['Format', 'Status❔', 'Size (MB)', key, 'Inference time (ms/im)'] df = pd.DataFrame(y, columns=c) - LOGGER.info(f'\nBenchmarks complete for {Path(model.ckpt_path).name} on {data} at imgsz={imgsz} ' - f'({time.time() - t0:.2f}s)') - LOGGER.info(str(df if map else df.iloc[:, :2])) - if hard_fail and isinstance(hard_fail, str): + name = Path(model.ckpt_path).name + s = f'\nBenchmarks complete for {name} on {data} at imgsz={imgsz} ({time.time() - t0:.2f}s)\n{df}\n' + LOGGER.info(s) + with open('benchmarks.log', 'a') as f: + f.write(s) + + if hard_fail and isinstance(hard_fail, float): metrics = df[key].array # values to compare to floor - floor = eval(hard_fail) # minimum metric floor to pass, i.e. = 0.29 mAP for YOLOv5n - assert all(x > floor for x in metrics if pd.notna(x)), f'HARD FAIL: metric < floor {floor}' + floor = hard_fail # minimum metric floor to pass, i.e. = 0.29 mAP for YOLOv5n + assert all(x > floor for x in metrics if pd.notna(x)), f'HARD FAIL: one or more metric(s) < floor {floor}' return df if __name__ == '__main__': - run_benchmarks() + benchmark() diff --git a/ultralytics/yolo/utils/checks.py b/ultralytics/yolo/utils/checks.py index 0ee5952..e40f787 100644 --- a/ultralytics/yolo/utils/checks.py +++ b/ultralytics/yolo/utils/checks.py @@ -16,6 +16,7 @@ import cv2 import numpy as np import pkg_resources as pkg import psutil +import requests import torch from matplotlib import font_manager @@ -117,6 +118,31 @@ def check_version(current: str = '0.0.0', return result +def check_latest_pypi_version(package_name='ultralytics'): + """ + Returns the latest version of a PyPI package without downloading or installing it. + + Parameters: + package_name (str): The name of the package to find the latest version for. + + Returns: + str: The latest version of the package. + """ + response = requests.get(f'https://pypi.org/pypi/{package_name}/json') + if response.status_code == 200: + return response.json()['info']['version'] + return None + + +def check_pip_update(): + from ultralytics import __version__ + latest = check_latest_pypi_version() + latest = '9.0.0' + if pkg.parse_version(__version__) < pkg.parse_version(latest): + LOGGER.info(f'New https://pypi.org/project/ultralytics/{latest} available 😃 ' + f"Update with 'pip install -U ultralytics'") + + def check_font(font='Arial.ttf'): """ Find font locally or download to user's configuration directory if it does not already exist. diff --git a/ultralytics/yolo/utils/dist.py b/ultralytics/yolo/utils/dist.py index 47ad6f7..5a49819 100644 --- a/ultralytics/yolo/utils/dist.py +++ b/ultralytics/yolo/utils/dist.py @@ -1,10 +1,12 @@ # Ultralytics YOLO 🚀, GPL-3.0 license import os +import re import shutil import socket import sys import tempfile +from pathlib import Path from . import USER_CONFIG_DIR from .torch_utils import TORCH_1_9 @@ -22,12 +24,12 @@ def find_free_network_port() -> int: def generate_ddp_file(trainer): - import_path = '.'.join(str(trainer.__class__).split('.')[1:-1]) + module, name = f'{trainer.__class__.__module__}.{trainer.__class__.__name__}'.rsplit('.', 1) content = f'''cfg = {vars(trainer.args)} \nif __name__ == "__main__": - from ultralytics.{import_path} import {trainer.__class__.__name__} + from {module} import {name} - trainer = {trainer.__class__.__name__}(cfg=cfg) + trainer = {name}(cfg=cfg) trainer.train()''' (USER_CONFIG_DIR / 'DDP').mkdir(exist_ok=True) with tempfile.NamedTemporaryFile(prefix='_temp_', @@ -41,12 +43,12 @@ def generate_ddp_file(trainer): def generate_ddp_command(world_size, trainer): - import __main__ # local import to avoid https://github.com/Lightning-AI/lightning/issues/15218 - file = os.path.abspath(sys.argv[0]) - using_cli = not file.endswith('.py') + import __main__ # noqa local import to avoid https://github.com/Lightning-AI/lightning/issues/15218 if not trainer.resume: shutil.rmtree(trainer.save_dir) # remove the save_dir - if using_cli: + file = str(Path(sys.argv[0]).resolve()) + safe_pattern = re.compile(r'^[a-zA-Z0-9_. /\\-]{1,128}$') # allowed characters and maximum of 100 characters + if not (safe_pattern.match(file) and Path(file).exists() and file.endswith('.py')): # using CLI file = generate_ddp_file(trainer) dist_cmd = 'torch.distributed.run' if TORCH_1_9 else 'torch.distributed.launch' port = find_free_network_port()