From 21ae321bc2744a334788d4c3f617fb0698950b31 Mon Sep 17 00:00:00 2001 From: Glenn Jocher Date: Sat, 4 Feb 2023 19:54:34 +0400 Subject: [PATCH] Update YOLOv5 YAMLs to 'u' YAMLs (#800) Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- CONTRIBUTING.md | 6 +++--- README.md | 6 +++--- README.zh-CN.md | 6 +++--- docker/Dockerfile | 2 +- docs/cfg.md | 2 +- docs/engine.md | 8 +++---- docs/hub.md | 2 +- docs/reference/base_trainer.md | 2 +- ultralytics/__init__.py | 2 +- ultralytics/hub/utils.py | 4 +++- ultralytics/nn/modules.py | 3 ++- ultralytics/yolo/engine/exporter.py | 18 ++++++++++------ ultralytics/yolo/engine/model.py | 4 ++++ ultralytics/yolo/engine/predictor.py | 7 ++++--- ultralytics/yolo/utils/checks.py | 28 +++++++++++++++++++------ ultralytics/yolo/utils/downloads.py | 12 ++--------- ultralytics/yolo/v8/classify/predict.py | 2 +- 17 files changed, 68 insertions(+), 46 deletions(-) diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 6ecaae7..82b1fce 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -59,18 +59,18 @@ To allow your work to be integrated as seamlessly as possible, we advise you to: ### Docstrings -Not all functions or classes require docstrings but when they do, we follow [google-stlye docstrings format](https://google.github.io/styleguide/pyguide.html#38-comments-and-docstrings). Here is an example: +Not all functions or classes require docstrings but when they do, we follow [google-style docstrings format](https://google.github.io/styleguide/pyguide.html#38-comments-and-docstrings). Here is an example: ```python """ - What the function does - performs nms on given detection predictions + What the function does. Performs NMS on given detection predictions. Args: arg1: The description of the 1st argument arg2: The description of the 2nd argument Returns: - What the function returns. Empty if nothing is returned + What the function returns. Empty if nothing is returned. Raises: Exception Class: When and why this exception can be raised by the function. diff --git a/README.md b/README.md index 406be6f..0376d83 100644 --- a/README.md +++ b/README.md @@ -1,6 +1,6 @@

- +

@@ -191,7 +191,7 @@ See [Classification Docs](https://docs.ultralytics.com/tasks/classification/) fo ##
Integrations

- +

@@ -220,7 +220,7 @@ See [Classification Docs](https://docs.ultralytics.com/tasks/classification/) fo 🚀 models, and deploy to the real world in a seamless experience. Get started for **Free** now! Also run YOLOv8 models on your iOS or Android device by downloading the [Ultralytics App](https://ultralytics.com/app_install)! - + ##
Contribute
diff --git a/README.zh-CN.md b/README.zh-CN.md index 33cb71f..5650a46 100644 --- a/README.zh-CN.md +++ b/README.zh-CN.md @@ -1,6 +1,6 @@

- +

@@ -167,7 +167,7 @@ success = model.export(format="onnx") # 将模型导出为 ONNX 格式 ##
模块集成

- +

@@ -194,7 +194,7 @@ success = model.export(format="onnx") # 将模型导出为 ONNX 格式 [Ultralytics HUB](https://bit.ly/ultralytics_hub) 是我们⭐ **新**的无代码解决方案,用于可视化数据集,训练 YOLOv8🚀 模型,并以无缝体验方式部署到现实世界。现在开始**免费**! 还可以通过下载 [Ultralytics App](https://ultralytics.com/app_install) 在你的 iOS 或 Android 设备上运行 YOLOv8 模型! - + ##
贡献
diff --git a/docker/Dockerfile b/docker/Dockerfile index 3bbecd5..c49534b 100644 --- a/docker/Dockerfile +++ b/docker/Dockerfile @@ -13,7 +13,7 @@ ADD https://ultralytics.com/assets/Arial.ttf https://ultralytics.com/assets/Aria ENV DEBIAN_FRONTEND noninteractive RUN apt update RUN TZ=Etc/UTC apt install -y tzdata -RUN apt install --no-install-recommends -y git zip curl htop libgl1-mesa-glx libglib2.0-0 libpython3-dev gnupg +RUN apt install --no-install-recommends -y gcc git zip curl htop libgl1-mesa-glx libglib2.0-0 libpython3-dev gnupg # RUN alias python=python3 # Create working directory diff --git a/docs/cfg.md b/docs/cfg.md index 5c29fad..aed6883 100644 --- a/docs/cfg.md +++ b/docs/cfg.md @@ -125,7 +125,7 @@ given task. | show | False | show results if possible | | save_txt | False | save results as .txt file | | save_conf | False | save results with confidence scores | -| save_crop | Fasle | save cropped images with results | +| save_crop | False | save cropped images with results | | hide_labels | False | hide labels | | hide_conf | False | hide confidence scores | | vid_stride | False | video frame-rate stride | diff --git a/docs/engine.md b/docs/engine.md index 848ff81..3f90a1f 100644 --- a/docs/engine.md +++ b/docs/engine.md @@ -3,12 +3,12 @@ executors. Let's take a look at the Trainer engine. ## BaseTrainer -BaseTrainer contains the generic boilerplate training routine. It can be customized for any task based over overidding +BaseTrainer contains the generic boilerplate training routine. It can be customized for any task based over overriding the required functions or operations as long the as correct formats are followed. For example, you can support your own -custom model and dataloder by just overriding these functions: +custom model and dataloader by just overriding these functions: * `get_model(cfg, weights)` - The function that builds the model to be trained -* `get_dataloder()` - The function that builds the dataloder +* `get_dataloder()` - The function that builds the dataloader More details and source code can be found in [`BaseTrainer` Reference](reference/base_trainer.md) ## DetectionTrainer @@ -78,6 +78,6 @@ To know more about Callback triggering events and entry point, checkout our Call ## Other engine components -There are other componenets that can be customized similarly like `Validators` and `Predictors` +There are other components that can be customized similarly like `Validators` and `Predictors` See Reference section for more information on these. diff --git a/docs/hub.md b/docs/hub.md index d87f92c..f7a7dbb 100644 --- a/docs/hub.md +++ b/docs/hub.md @@ -80,7 +80,7 @@ training! + Ultralytics mobile app ## ❓ Issues diff --git a/docs/reference/base_trainer.md b/docs/reference/base_trainer.md index 687bdf8..a93af69 100644 --- a/docs/reference/base_trainer.md +++ b/docs/reference/base_trainer.md @@ -1,4 +1,4 @@ -All task Trainers are inherited from `BaseTrainer` class that contains the model training and optimzation routine +All task Trainers are inherited from `BaseTrainer` class that contains the model training and optimization routine boilerplate. You can override any function of these Trainers to suit your needs. --- diff --git a/ultralytics/__init__.py b/ultralytics/__init__.py index 8575f07..be0c499 100644 --- a/ultralytics/__init__.py +++ b/ultralytics/__init__.py @@ -1,6 +1,6 @@ # Ultralytics YOLO 🚀, GPL-3.0 license -__version__ = "8.0.27" +__version__ = "8.0.28" from ultralytics.yolo.engine.model import YOLO from ultralytics.yolo.utils import ops diff --git a/ultralytics/hub/utils.py b/ultralytics/hub/utils.py index 2463ab0..eec139f 100644 --- a/ultralytics/hub/utils.py +++ b/ultralytics/hub/utils.py @@ -175,7 +175,9 @@ class Traces: cfg = vars(cfg) # convert type from IterableSimpleNamespace to dict if not all_keys: # filter cfg include_keys = {'task', 'mode'} # always include - cfg = {k: v for k, v in cfg.items() if v != DEFAULT_CFG_DICT.get(k, None) or k in include_keys} + cfg = { + k: (v.split(os.sep)[-1] if isinstance(v, str) and os.sep in v else v) + for k, v in cfg.items() if v != DEFAULT_CFG_DICT.get(k, None) or k in include_keys} trace = {'uuid': SETTINGS['uuid'], 'cfg': cfg, 'metadata': self.metadata} # Send a request to the HUB API to sync analytics diff --git a/ultralytics/nn/modules.py b/ultralytics/nn/modules.py index b98ba65..2bbf338 100644 --- a/ultralytics/nn/modules.py +++ b/ultralytics/nn/modules.py @@ -456,4 +456,5 @@ class Classify(nn.Module): def forward(self, x): if isinstance(x, list): x = torch.cat(x, 1) - return self.linear(self.drop(self.pool(self.conv(x)).flatten(1))) + x = self.linear(self.drop(self.pool(self.conv(x)).flatten(1))) + return x if self.training else x.softmax(1) diff --git a/ultralytics/yolo/engine/exporter.py b/ultralytics/yolo/engine/exporter.py index 5e31237..4c978e8 100644 --- a/ultralytics/yolo/engine/exporter.py +++ b/ultralytics/yolo/engine/exporter.py @@ -70,7 +70,7 @@ from ultralytics.nn.modules import Detect, Segment from ultralytics.nn.tasks import ClassificationModel, DetectionModel, SegmentationModel, guess_model_task from ultralytics.yolo.cfg import get_cfg from ultralytics.yolo.data.dataloaders.stream_loaders import LoadImages -from ultralytics.yolo.data.utils import check_det_dataset +from ultralytics.yolo.data.utils import check_det_dataset, IMAGENET_MEAN, IMAGENET_STD from ultralytics.yolo.utils import DEFAULT_CFG, LOGGER, callbacks, colorstr, get_default_args, yaml_save from ultralytics.yolo.utils.checks import check_imgsz, check_requirements, check_version, check_yaml from ultralytics.yolo.utils.files import file_size @@ -185,8 +185,8 @@ class Exporter: if self.args.half and not coreml and not xml: im, model = im.half(), model.half() # to FP16 shape = tuple((y[0] if isinstance(y, tuple) else y).shape) # model output shape - LOGGER.info( - f"\n{colorstr('PyTorch:')} starting from {file} with output shape {shape} ({file_size(file):.1f} MB)") + LOGGER.info(f"\n{colorstr('PyTorch:')} starting from {file} with input shape {tuple(im.shape)} and " + f"output shape {shape} ({file_size(file):.1f} MB)") # Warnings warnings.filterwarnings('ignore', category=torch.jit.TracerWarning) # suppress TracerWarning @@ -384,12 +384,18 @@ class Exporter: LOGGER.info(f'\n{prefix} starting export with coremltools {ct.__version__}...') f = self.file.with_suffix('.mlmodel') - task = self.model.task + if self.model.task == 'classify': + bias = [-x for x in IMAGENET_MEAN] + scale = 1 / 255 / (sum(IMAGENET_STD) / 3) + classifier_config = ct.ClassifierConfig(list(self.model.names.values())) + else: + bias = [0.0, 0.0, 0.0] + scale = 1 / 255 + classifier_config = None model = iOSModel(self.model, self.im).eval() if self.args.nms else self.model ts = torch.jit.trace(model, self.im, strict=False) # TorchScript model - classifier_config = ct.ClassifierConfig(list(model.names.values())) if task == 'classify' else None ct_model = ct.convert(ts, - inputs=[ct.ImageType('image', shape=self.im.shape, scale=1 / 255, bias=[0, 0, 0])], + inputs=[ct.ImageType('image', shape=self.im.shape, scale=scale, bias=bias)], classifier_config=classifier_config) bits, mode = (8, 'kmeans_lut') if self.args.int8 else (16, 'linear') if self.args.half else (32, None) if bits < 32: diff --git a/ultralytics/yolo/engine/model.py b/ultralytics/yolo/engine/model.py index 5107d73..4762452 100644 --- a/ultralytics/yolo/engine/model.py +++ b/ultralytics/yolo/engine/model.py @@ -162,6 +162,8 @@ class YOLO: args = get_cfg(cfg=DEFAULT_CFG, overrides=overrides) args.data = data or args.data args.task = self.task + if args.imgsz == DEFAULT_CFG.imgsz: + args.imgsz = self.model.args['imgsz'] # use trained imgsz unless custom value is passed args.imgsz = check_imgsz(args.imgsz, max_dim=1) validator = self.ValidatorClass(args=args) @@ -180,6 +182,8 @@ class YOLO: overrides.update(kwargs) args = get_cfg(cfg=DEFAULT_CFG, overrides=overrides) args.task = self.task + if args.imgsz == DEFAULT_CFG.imgsz: + args.imgsz = self.model.args['imgsz'] # use trained imgsz unless custom value is passed exporter = Exporter(overrides=args) exporter(model=self.model) diff --git a/ultralytics/yolo/engine/predictor.py b/ultralytics/yolo/engine/predictor.py index 462120e..59dd85c 100644 --- a/ultralytics/yolo/engine/predictor.py +++ b/ultralytics/yolo/engine/predictor.py @@ -120,9 +120,6 @@ class BasePredictor: pass def setup_source(self, source): - if not self.model: - raise Exception("Model not initialized!") - self.imgsz = check_imgsz(self.args.imgsz, stride=self.model.stride, min_dim=2) # check image size self.dataset = load_inference_source(source=source, transforms=getattr(self.model.model, 'transforms', None), @@ -190,6 +187,10 @@ class BasePredictor: if self.args.verbose: LOGGER.info(f"{s}{'' if len(preds) else '(no detections), '}{self.dt[1].dt * 1E3:.1f}ms") + # Release assets + if isinstance(self.vid_writer[-1], cv2.VideoWriter): + self.vid_writer[-1].release() # release final video writer + # Print results if self.args.verbose and self.seen: t = tuple(x.t / self.seen * 1E3 for x in self.dt) # speeds per image diff --git a/ultralytics/yolo/utils/checks.py b/ultralytics/yolo/utils/checks.py index 2fa129f..928b493 100644 --- a/ultralytics/yolo/utils/checks.py +++ b/ultralytics/yolo/utils/checks.py @@ -5,6 +5,7 @@ import inspect import math import os import platform +import re import shutil import urllib from pathlib import Path @@ -67,12 +68,13 @@ def check_imgsz(imgsz, stride=32, min_dim=1, max_dim=2, floor=0): f"Valid imgsz types are int i.e. 'imgsz=640' or list i.e. 'imgsz=[640,640]'") # Apply max_dim - if max_dim == 1: - LOGGER.warning(f"WARNING ⚠️ 'train' and 'val' imgsz types must be integer, updating to 'imgsz={max(imgsz)}'. " - f"'predict' and 'export' imgsz may be list or integer, " - f"i.e. 'yolo export imgsz=640,480' or 'yolo export imgsz=640'") + if len(imgsz) > max_dim: + msg = "'train' and 'val' imgsz must be an integer, while 'predict' and 'export' imgsz may be a [h, w] list " \ + "or an integer, i.e. 'yolo export imgsz=640,480' or 'yolo export imgsz=640'" + if max_dim != 1: + raise ValueError(f"imgsz={imgsz} is not a valid image size. {msg}") + LOGGER.warning(f"WARNING ⚠️ updating to 'imgsz={max(imgsz)}'. {msg}") imgsz = [max(imgsz)] - # Make image size a multiple of the stride sz = [max(math.ceil(x / stride) * stride, floor) for x in imgsz] @@ -220,10 +222,24 @@ def check_suffix(file='yolov8n.pt', suffix=('.pt',), msg=''): assert s in suffix, f"{msg}{f} acceptable suffix is {suffix}" +def check_yolov5u_filename(file: str): + # Replace legacy YOLOv5 filenames with updated YOLOv5u filenames + if 'yolov3' in file or 'yolov5' in file and 'u' not in file: + original_file = file + file = re.sub(r"(.*yolov5([nsmlx]))\.", "\\1u.", file) # i.e. yolov5n.pt -> yolov5nu.pt + file = re.sub(r"(.*yolov3(|-tiny|-spp))\.", "\\1u.", file) # i.e. yolov3-spp.pt -> yolov3-sppu.pt + if file != original_file: + LOGGER.info(f"PRO TIP 💡 Replace 'model={original_file}' with new 'model={file}'.\nYOLOv5 'u' models are " + f"trained with https://github.com/ultralytics/ultralytics and feature improved performance vs " + f"standard YOLOv5 models trained with https://github.com/ultralytics/yolov5.\n") + return file + + def check_file(file, suffix=''): # Search/download file (if necessary) and return path check_suffix(file, suffix) # optional - file = str(file) # convert to str() + file = str(file) # convert to string + file = check_yolov5u_filename(file) # yolov5n -> yolov5nu if not file or ('://' not in file and Path(file).is_file()): # exists ('://' check required in Windows Python<3.10) return file elif file.lower().startswith(('https://', 'http://', 'rtsp://', 'rtmp://')): # download diff --git a/ultralytics/yolo/utils/downloads.py b/ultralytics/yolo/utils/downloads.py index a340454..622d1ae 100644 --- a/ultralytics/yolo/utils/downloads.py +++ b/ultralytics/yolo/utils/downloads.py @@ -1,7 +1,6 @@ # Ultralytics YOLO 🚀, GPL-3.0 license import contextlib -import re import subprocess from itertools import repeat from multiprocessing.pool import ThreadPool @@ -111,6 +110,7 @@ def safe_download(url, def attempt_download_asset(file, repo='ultralytics/assets', release='v0.0.0'): # Attempt file download from GitHub release assets if not found locally. release = 'latest', 'v6.2', etc. from ultralytics.yolo.utils import SETTINGS + from ultralytics.yolo.utils.checks import check_yolov5u_filename def github_assets(repository, version='latest'): # Return GitHub repo tag and assets (i.e. ['yolov8n.pt', 'yolov8s.pt', ...]) @@ -121,15 +121,7 @@ def attempt_download_asset(file, repo='ultralytics/assets', release='v0.0.0'): # YOLOv3/5u updates file = str(file) - if 'yolov3' in file or 'yolov5' in file and 'u' not in file: - original_file = file - file = re.sub(r"(.*yolov5([nsmlx]))\.pt", "\\1u.pt", file) # i.e. yolov5n.pt -> yolov5nu.pt - file = re.sub(r"(.*yolov3(|-tiny|-spp))\.pt", "\\1u.pt", file) # i.e. yolov3-spp.pt -> yolov3-sppu.pt - if file != original_file: - LOGGER.info(f"PRO TIP 💡 Replace 'model={original_file}' with new 'model={file}'.\nYOLOv5 'u' models are " - f"trained with https://github.com/ultralytics/ultralytics and feature improved performance vs " - f"standard YOLOv5 models trained with https://github.com/ultralytics/yolov5.\n") - + file = check_yolov5u_filename(file) file = Path(file.strip().replace("'", '')) if file.exists(): return str(file) diff --git a/ultralytics/yolo/v8/classify/predict.py b/ultralytics/yolo/v8/classify/predict.py index 4ae40dd..1eed2d7 100644 --- a/ultralytics/yolo/v8/classify/predict.py +++ b/ultralytics/yolo/v8/classify/predict.py @@ -23,7 +23,7 @@ class ClassificationPredictor(BasePredictor): results = [] for i, pred in enumerate(preds): shape = orig_img[i].shape if isinstance(orig_img, list) else orig_img.shape - results.append(Results(probs=pred.softmax(0), orig_shape=shape[:2])) + results.append(Results(probs=pred, orig_shape=shape[:2])) return results def write_results(self, idx, results, batch):