diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index 8af1964..29a5387 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -73,6 +73,36 @@ jobs: hub.login(key) model = YOLO('https://hub.ultralytics.com/models/' + model_id) model.train() + - name: Test HUB training (Python Usage 3) + shell: python + env: + APIKEY: ${{ secrets.ULTRALYTICS_HUB_APIKEY }} + run: | + import os + from pathlib import Path + from ultralytics import YOLO, hub + from ultralytics.yolo.utils import USER_CONFIG_DIR + Path(USER_CONFIG_DIR / 'settings.yaml').unlink() + key = os.environ['APIKEY'] + hub.reset_model(key) + model = YOLO(key) + model.train() + - name: Test HUB training (Python Usage 4) + shell: python + env: + APIKEY: ${{ secrets.ULTRALYTICS_HUB_APIKEY }} + run: | + import os + from pathlib import Path + from ultralytics import YOLO, hub + from ultralytics.yolo.utils import USER_CONFIG_DIR + Path(USER_CONFIG_DIR / 'settings.yaml').unlink() + key = os.environ['APIKEY'] + hub.reset_model(key) + key, model_id = key.split('_') + hub.login(key) + model = YOLO(model_id) + model.train() Benchmarks: runs-on: ${{ matrix.os }} diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 0cc5937..ec3a176 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -56,7 +56,7 @@ repos: name: PEP8 - repo: https://github.com/codespell-project/codespell - rev: v2.2.2 + rev: v2.2.4 hooks: - id: codespell args: diff --git a/docs/modes/predict.md b/docs/modes/predict.md index 03bc4fc..30f8743 100644 --- a/docs/modes/predict.md +++ b/docs/modes/predict.md @@ -216,10 +216,20 @@ masks, classification logits, etc.) found in the results object res_plotted = res[0].plot() cv2.imshow("result", res_plotted) ``` +| Argument | Description | +| ----------- | ------------- | +| `conf (bool)` | Whether to plot the detection confidence score. | +| `line_width (float, optional)` | The line width of the bounding boxes. If None, it is scaled to the image size. | +| `font_size (float, optional)` | The font size of the text. If None, it is scaled to the image size. | +| `font (str)` | The font to use for the text. | +| `pil (bool)` | Whether to return the image as a PIL Image. | +| `example (str)` | An example string to display. Useful for indicating the expected format of the output. | +| `img (numpy.ndarray)` | Plot to another image. if not, plot to original image. | +| `labels (bool)` | Whether to plot the label of bounding boxes. | +| `boxes (bool)` | Whether to plot the bounding boxes. | +| `masks (bool)` | Whether to plot the masks. | +| `probs (bool)` | Whether to plot classification probability. | -- `show_conf (bool)`: Show confidence -- `line_width (Float)`: The line width of boxes. Automatically scaled to img size if not provided -- `font_size (Float)`: The font size of . Automatically scaled to img size if not provided ## Streaming Source `for`-loop diff --git a/docs/usage/cfg.md b/docs/usage/cfg.md index 6272476..0367ac4 100644 --- a/docs/usage/cfg.md +++ b/docs/usage/cfg.md @@ -136,8 +136,8 @@ The prediction settings for YOLO models encompass a range of hyperparameters and | `save_txt` | `False` | save results as .txt file | | `save_conf` | `False` | save results with confidence scores | | `save_crop` | `False` | save cropped images with results | -| `hide_labels` | `False` | hide labels | -| `hide_conf` | `False` | hide confidence scores | +| `show_labels` | `True` | show object labels in plots | +| `show_conf` | `True` | show object confidence scores in plots | | `max_det` | `300` | maximum number of detections per image | | `vid_stride` | `False` | video frame-rate stride | | `line_thickness` | `3` | bounding box thickness (pixels) | diff --git a/examples/README.md b/examples/README.md index 9fc542a..6fad2b7 100644 --- a/examples/README.md +++ b/examples/README.md @@ -1,22 +1,24 @@ -This is a list of real-world applications and walkthroughs. These can be folders of either python files or notebooks . +## Ultralytics YOLOv8 Example Applications -## Ultralytics YOLO example applications +This repository features a collection of real-world applications and walkthroughs, provided as either Python files or notebooks. Explore the examples below to see how YOLOv8 can be integrated into various applications. + +### Ultralytics YOLO Example Applications | Title | Format | Contributor | | ------------------------------------------------------------------------ | ------------------ | --------------------------------------------------- | -| [YOLO ONNX detection Inference with C++](./YOLOv8-CPP-Inference) | C++/ONNX | [Justas Bartnykas](https://github.com/JustasBart) | -| [YOLO OpenCV ONNX detection Python](./YOLOv8-OpenCV-ONNX-Python) | OpenCV/Python/ONNX | [Farid Inawan](https://github.com/frdteknikelektro) | -| [YOLO .Net ONNX detection C#](https://www.nuget.org/packages/Yolov8.Net) | C# .Net | [Samuel Stainback](https://github.com/sstainba) | +| [YOLO ONNX Detection Inference with C++](./YOLOv8-CPP-Inference) | C++/ONNX | [Justas Bartnykas](https://github.com/JustasBart) | +| [YOLO OpenCV ONNX Detection Python](./YOLOv8-OpenCV-ONNX-Python) | OpenCV/Python/ONNX | [Farid Inawan](https://github.com/frdteknikelektro) | +| [YOLO .Net ONNX Detection C#](https://www.nuget.org/packages/Yolov8.Net) | C# .Net | [Samuel Stainback](https://github.com/sstainba) | -## How can you contribute ? +### How to Contribute -We're looking for examples, applications and guides from the community. Here's how you can contribute: +We welcome contributions from the community in the form of examples, applications, and guides. To contribute, please follow these steps: -- Make a PR with `[Example]` prefix in title after adding your project folder in the examples/ folder of the repository -- The project should satisfy these conditions: - - It should use ultralytics framework - - It have a README.md with instructions to run the project - - It should avoid adding large assets or dependencies unless absolutely needed - - The contributor is expected to help out in issues related to their examples +1. Create a pull request (PR) with the `[Example]` prefix in the title, adding your project folder to the `examples/` directory in the repository. +1. Ensure that your project meets the following criteria: + - Utilizes the `ultralytics` package. + - Includes a `README.md` file with instructions on how to run the project. + - Avoids adding large assets or dependencies unless absolutely necessary. + - The contributor is expected to provide support for issues related to their examples. -If you're unsure about any of these requirements, make a PR and we'll happy to guide you +If you have any questions or concerns about these requirements, please submit a PR, and we will be more than happy to guide you. diff --git a/examples/YOLOv8-CPP-Inference/README.md b/examples/YOLOv8-CPP-Inference/README.md index 4eca0ce..548f9b8 100644 --- a/examples/YOLOv8-CPP-Inference/README.md +++ b/examples/YOLOv8-CPP-Inference/README.md @@ -1,17 +1,20 @@ -# yolov8/yolov5 Inference C++ +# YOLOv8/YOLOv5 Inference C++ -Usage: +This example demonstrates how to perform inference using YOLOv8 and YOLOv5 models in C++ with OpenCV's DNN API. -``` -# git clone ultralytics +## Usage + +```commandline +git clone ultralytics +cd ultralytics pip install . cd examples/cpp_ -Add a **yolov8\_.onnx** and/or **yolov5\_.onnx** model(s) to the ultralytics folder. -Edit the **main.cpp** to change the **projectBasePath** to match your user. +# Add a **yolov8\_.onnx** and/or **yolov5\_.onnx** model(s) to the ultralytics folder. +# Edit the **main.cpp** to change the **projectBasePath** to match your user. -Note that by default the CMake file will try and import the CUDA library to be used with the OpenCVs dnn (cuDNN) GPU Inference. -If your OpenCV build does not use CUDA/cuDNN you can remove that import call and run the example on CPU. +# Note that by default the CMake file will try and import the CUDA library to be used with the OpenCVs dnn (cuDNN) GPU Inference. +# If your OpenCV build does not use CUDA/cuDNN you can remove that import call and run the example on CPU. mkdir build cd build @@ -20,24 +23,18 @@ make ./Yolov8CPPInference ``` -To export yolov8 models: +## Exporting YOLOv8 and YOLOv5 Models -``` -yolo export \ -model=yolov8s.pt \ -imgsz=[480,640] \ -format=onnx \ -opset=12 +To export YOLOv8 models: + +```commandline +yolo export model=yolov8s.pt imgsz=480,640 format=onnx opset=12 ``` -To export yolov5 models: +To export YOLOv5 models: -``` -python3 export.py \ ---weights yolov5s.pt \ ---img 480 640 \ ---include onnx \ ---opset 12 +```commandline +python3 export.py --weights yolov5s.pt --img 480 640 --include onnx --opset 12 ``` yolov8s.onnx: @@ -48,10 +45,6 @@ yolov5s.onnx: ![image](https://user-images.githubusercontent.com/40023722/217357005-07464492-d1da-42e3-98a7-fc753f87d5e6.png) -This repository is based on OpenCVs dnn API to run an ONNX exported model of either yolov5/yolov8 (In theory should work -for yolov6 and yolov7 but not tested). Note that for this example the networks are exported as rectangular (640x480) -resolutions, but it would work for any resolution that you export as although you might want to use the letterBox -approach for square images depending on your use-case. +This repository utilizes OpenCV's DNN API to run ONNX exported models of YOLOv5 and YOLOv8. In theory, it should work for YOLOv6 and YOLOv7 as well, but they have not been tested. Note that the example networks are exported with rectangular (640x480) resolutions, but any exported resolution will work. You may want to use the letterbox approach for square images, depending on your use case. -The **main** branch version is based on using Qt as a GUI wrapper the main interest here is the **Inference** class file -which shows how to transpose yolov8 models to work as yolov5 models. +The **main** branch version uses Qt as a GUI wrapper. The primary focus here is the **Inference** class file, which demonstrates how to transpose YOLOv8 models to work as YOLOv5 models. diff --git a/examples/YOLOv8-CPP-Inference/inference.cpp b/examples/YOLOv8-CPP-Inference/inference.cpp index b45830e..12c2607 100644 --- a/examples/YOLOv8-CPP-Inference/inference.cpp +++ b/examples/YOLOv8-CPP-Inference/inference.cpp @@ -83,7 +83,7 @@ std::vector Inference::runInference(const cv::Mat &input) { float confidence = data[4]; - if (confidence >= modelConfidenseThreshold) + if (confidence >= modelConfidenceThreshold) { float *classes_scores = data+5; diff --git a/examples/YOLOv8-CPP-Inference/inference.h b/examples/YOLOv8-CPP-Inference/inference.h index 5763e10..dc6149f 100644 --- a/examples/YOLOv8-CPP-Inference/inference.h +++ b/examples/YOLOv8-CPP-Inference/inference.h @@ -40,7 +40,7 @@ private: cv::Size2f modelShape{}; - float modelConfidenseThreshold {0.25}; + float modelConfidenceThreshold {0.25}; float modelScoreThreshold {0.45}; float modelNMSThreshold {0.50}; diff --git a/examples/YOLOv8-OpenCV-ONNX-Python/main.py b/examples/YOLOv8-OpenCV-ONNX-Python/main.py index acae890..d1f635c 100644 --- a/examples/YOLOv8-OpenCV-ONNX-Python/main.py +++ b/examples/YOLOv8-OpenCV-ONNX-Python/main.py @@ -27,7 +27,7 @@ def main(onnx_model, input_image): image[0:height, 0:width] = original_image scale = length / 640 - blob = cv2.dnn.blobFromImage(image, scalefactor=1 / 255, size=(640, 640)) + blob = cv2.dnn.blobFromImage(image, scalefactor=1 / 255, size=(640, 640), swapRB=True) model.setInput(blob) outputs = model.forward() diff --git a/tests/test_python.py b/tests/test_python.py index 22446a9..ee2a190 100644 --- a/tests/test_python.py +++ b/tests/test_python.py @@ -207,10 +207,10 @@ def test_predict_callback_and_setup(): def test_result(): model = YOLO('yolov8n-seg.pt') res = model([SOURCE, SOURCE]) - res[0].plot(show_conf=False) + res[0].plot(show_conf=False) # raises warning + res[0].plot(conf=True, boxes=False, masks=True) res[0] = res[0].cpu().numpy() print(res[0].path, res[0].masks.masks) - model = YOLO('yolov8n.pt') res = model(SOURCE) res[0].plot() @@ -218,5 +218,5 @@ def test_result(): model = YOLO('yolov8n-cls.pt') res = model(SOURCE) - res[0].plot() + res[0].plot(probs=False) print(res[0].path) diff --git a/ultralytics/__init__.py b/ultralytics/__init__.py index 53e0961..c9addaf 100644 --- a/ultralytics/__init__.py +++ b/ultralytics/__init__.py @@ -1,6 +1,6 @@ # Ultralytics YOLO 🚀, GPL-3.0 license -__version__ = '8.0.61' +__version__ = '8.0.62' from ultralytics.hub import start from ultralytics.yolo.engine.model import YOLO diff --git a/ultralytics/yolo/cfg/__init__.py b/ultralytics/yolo/cfg/__init__.py index f81a2ef..83e997d 100644 --- a/ultralytics/yolo/cfg/__init__.py +++ b/ultralytics/yolo/cfg/__init__.py @@ -9,8 +9,8 @@ from types import SimpleNamespace from typing import Dict, List, Union from ultralytics.yolo.utils import (DEFAULT_CFG, DEFAULT_CFG_DICT, DEFAULT_CFG_PATH, LOGGER, ROOT, USER_CONFIG_DIR, - IterableSimpleNamespace, __version__, checks, colorstr, get_settings, yaml_load, - yaml_print) + IterableSimpleNamespace, __version__, checks, colorstr, deprecation_warn, + get_settings, yaml_load, yaml_print) # Define valid tasks and modes MODES = 'train', 'val', 'predict', 'export', 'track', 'benchmark' @@ -71,7 +71,7 @@ CFG_INT_KEYS = ('epochs', 'patience', 'batch', 'workers', 'seed', 'close_mosaic' 'line_thickness', 'workspace', 'nbs', 'save_period') CFG_BOOL_KEYS = ('save', 'exist_ok', 'verbose', 'deterministic', 'single_cls', 'image_weights', 'rect', 'cos_lr', 'overlap_mask', 'val', 'save_json', 'save_hybrid', 'half', 'dnn', 'plots', 'show', 'save_txt', - 'save_conf', 'save_crop', 'hide_labels', 'hide_conf', 'visualize', 'augment', 'agnostic_nms', + 'save_conf', 'save_crop', 'show_labels', 'show_conf', 'visualize', 'augment', 'agnostic_nms', 'retina_masks', 'boxes', 'keras', 'optimize', 'int8', 'dynamic', 'simplify', 'nms', 'v5loader') @@ -140,6 +140,22 @@ def get_cfg(cfg: Union[str, Path, Dict, SimpleNamespace] = DEFAULT_CFG_DICT, ove return IterableSimpleNamespace(**cfg) +def _handle_deprecation(custom): + """ + Hardcoded function to handle deprecated config keys + """ + + for key in custom.copy().keys(): + if key == 'hide_labels': + deprecation_warn(key, 'show_labels') + custom['show_labels'] = custom.pop('hide_labels') == 'False' + if key == 'hide_conf': + deprecation_warn(key, 'show_conf') + custom['show_conf'] = custom.pop('hide_conf') == 'False' + + return custom + + def check_cfg_mismatch(base: Dict, custom: Dict, e=None): """ This function checks for any mismatched keys between a custom configuration list and a base configuration list. @@ -149,6 +165,7 @@ def check_cfg_mismatch(base: Dict, custom: Dict, e=None): - custom (Dict): a dictionary of custom configuration options - base (Dict): a dictionary of base configuration options """ + custom = _handle_deprecation(custom) base, custom = (set(x.keys()) for x in (base, custom)) mismatched = [x for x in custom if x not in base] if mismatched: diff --git a/ultralytics/yolo/cfg/default.yaml b/ultralytics/yolo/cfg/default.yaml index cc6b475..f2ab36c 100644 --- a/ultralytics/yolo/cfg/default.yaml +++ b/ultralytics/yolo/cfg/default.yaml @@ -55,8 +55,8 @@ show: False # show results if possible save_txt: False # save results as .txt file save_conf: False # save results with confidence scores save_crop: False # save cropped images with results -hide_labels: False # hide labels -hide_conf: False # hide confidence scores +show_labels: True # show object labels in plots +show_conf: True # show object confidence scores in plots vid_stride: 1 # video frame-rate stride line_thickness: 3 # bounding box thickness (pixels) visualize: False # visualize model features diff --git a/ultralytics/yolo/engine/model.py b/ultralytics/yolo/engine/model.py index 5c47726..017e8e9 100644 --- a/ultralytics/yolo/engine/model.py +++ b/ultralytics/yolo/engine/model.py @@ -1,6 +1,7 @@ # Ultralytics YOLO 🚀, GPL-3.0 license import sys +from copy import deepcopy from pathlib import Path from typing import Union @@ -77,7 +78,7 @@ class YOLO: task (Any, optional): Task type for the YOLO model. Defaults to None. """ - self._reset_callbacks() + self.callbacks = deepcopy(callbacks.default_callbacks) self.predictor = None # reuse predictor self.model = None # model object self.trainer = None # trainer object @@ -91,7 +92,7 @@ class YOLO: model = str(model).strip() # strip spaces # Check if Ultralytics HUB model from https://hub.ultralytics.com - if model.startswith('https://hub.ultralytics.com/models/'): + if self.is_hub_model(model): from ultralytics.hub.session import HUBTrainingSession self.session = HUBTrainingSession(model) model = self.session.model_file @@ -112,6 +113,13 @@ class YOLO: name = self.__class__.__name__ raise AttributeError(f"'{name}' object has no attribute '{attr}'. See valid attributes below.\n{self.__doc__}") + @staticmethod + def is_hub_model(model): + return any(( + model.startswith('https://hub.ultralytics.com/models/'), + [len(x) for x in model.split('_')] == [42, 20], # APIKEY_MODELID + (len(model) == 20 and not Path(model).exists() and not any(x in model for x in './\\')))) # MODELID + def _new(self, cfg: str, task=None, verbose=True): """ Initializes a new model and infers the task type from the model definitions. @@ -220,8 +228,7 @@ class YOLO: if source is None: source = ROOT / 'assets' if is_git_dir() else 'https://ultralytics.com/images/bus.jpg' LOGGER.warning(f"WARNING ⚠️ 'source' is missing. Using 'source={source}'.") - is_cli = (sys.argv[0].endswith('yolo') or sys.argv[0].endswith('ultralytics')) and \ - ('predict' in sys.argv or 'mode=predict' in sys.argv) + is_cli = sys.argv[0].endswith('yolo') or sys.argv[0].endswith('ultralytics') overrides = self.overrides.copy() overrides['conf'] = 0.25 @@ -231,7 +238,7 @@ class YOLO: overrides['save'] = kwargs.get('save', False) # not save files by default if not self.predictor: self.task = overrides.get('task') or self.task - self.predictor = TASK_MAP[self.task][3](overrides=overrides) + self.predictor = TASK_MAP[self.task][3](overrides=overrides, _callbacks=self.callbacks) self.predictor.setup_model(model=self.model, verbose=is_cli) else: # only update args if predictor is already setup self.predictor.args = get_cfg(self.predictor.args, overrides) @@ -380,19 +387,17 @@ class YOLO: """ return self.model.transforms if hasattr(self.model, 'transforms') else None - @staticmethod - def add_callback(event: str, func): + def add_callback(self, event: str, func): """ Add callback """ - callbacks.default_callbacks[event].append(func) + self.callbacks[event].append(func) @staticmethod def _reset_ckpt_args(args): include = {'imgsz', 'data', 'task', 'single_cls'} # only remember these arguments when loading a PyTorch model return {k: v for k, v in args.items() if k in include} - @staticmethod - def _reset_callbacks(): + def _reset_callbacks(self): for event in callbacks.default_callbacks.keys(): - callbacks.default_callbacks[event] = [callbacks.default_callbacks[event][0]] + self.callbacks[event] = [callbacks.default_callbacks[event][0]] diff --git a/ultralytics/yolo/engine/predictor.py b/ultralytics/yolo/engine/predictor.py index 82905ca..cb24faf 100644 --- a/ultralytics/yolo/engine/predictor.py +++ b/ultralytics/yolo/engine/predictor.py @@ -75,7 +75,7 @@ class BasePredictor: data_path (str): Path to data. """ - def __init__(self, cfg=DEFAULT_CFG, overrides=None): + def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None): """ Initializes the BasePredictor class. @@ -104,7 +104,7 @@ class BasePredictor: self.data_path = None self.source_type = None self.batch = None - self.callbacks = defaultdict(list, callbacks.default_callbacks) # add callbacks + self.callbacks = defaultdict(list, _callbacks) if _callbacks else defaultdict(list, callbacks.default_callbacks) callbacks.add_integration_callbacks(self) def preprocess(self, img): diff --git a/ultralytics/yolo/engine/results.py b/ultralytics/yolo/engine/results.py index 1bf2d69..cbaa543 100644 --- a/ultralytics/yolo/engine/results.py +++ b/ultralytics/yolo/engine/results.py @@ -12,7 +12,7 @@ import numpy as np import torch import torchvision.transforms.functional as F -from ultralytics.yolo.utils import LOGGER, SimpleClass, ops +from ultralytics.yolo.utils import LOGGER, SimpleClass, deprecation_warn, ops from ultralytics.yolo.utils.plotting import Annotator, colors from ultralytics.yolo.utils.torch_utils import TORCHVISION_0_10 @@ -65,7 +65,7 @@ class Results(SimpleClass): self.boxes = Boxes(boxes, self.orig_shape) if masks is not None: self.masks = Masks(masks, self.orig_shape) - if boxes is not None: + if probs is not None: self.probs = probs def cpu(self): @@ -100,46 +100,72 @@ class Results(SimpleClass): def keys(self): return [k for k in self._keys if getattr(self, k) is not None] - def plot(self, show_conf=True, line_width=None, font_size=None, font='Arial.ttf', pil=False, example='abc'): + def plot( + self, + conf=True, + line_width=None, + font_size=None, + font='Arial.ttf', + pil=False, + example='abc', + img=None, + labels=True, + boxes=True, + masks=True, + probs=True, + **kwargs # deprecated args TODO: remove support in 8.2 + ): """ Plots the detection results on an input RGB image. Accepts a numpy array (cv2) or a PIL Image. Args: - show_conf (bool): Whether to show the detection confidence score. + conf (bool): Whether to plot the detection confidence score. line_width (float, optional): The line width of the bounding boxes. If None, it is scaled to the image size. font_size (float, optional): The font size of the text. If None, it is scaled to the image size. font (str): The font to use for the text. pil (bool): Whether to return the image as a PIL Image. example (str): An example string to display. Useful for indicating the expected format of the output. + img (numpy.ndarray): Plot to another image. if not, plot to original image. + labels (bool): Whether to plot the label of bounding boxes. + boxes (bool): Whether to plot the bounding boxes. + masks (bool): Whether to plot the masks. + probs (bool): Whether to plot classification probability Returns: (None) or (PIL.Image): If `pil` is True, a PIL Image is returned. Otherwise, nothing is returned. """ - annotator = Annotator(deepcopy(self.orig_img), line_width, font_size, font, pil, example) - boxes = self.boxes - masks = self.masks - probs = self.probs + # Deprecation warn TODO: remove in 8.2 + if 'show_conf' in kwargs: + deprecation_warn('show_conf', 'conf') + conf = kwargs['show_conf'] + assert type(conf) == bool, '`show_conf` should be of boolean type, i.e, show_conf=True/False' + + annotator = Annotator(deepcopy(self.orig_img if img is None else img), line_width, font_size, font, pil, + example) + pred_boxes, show_boxes = self.boxes, boxes + pred_masks, show_masks = self.masks, masks + pred_probs, show_probs = self.probs, probs names = self.names - hide_labels, hide_conf = False, not show_conf - if boxes is not None: - for d in reversed(boxes): - c, conf, id = int(d.cls), float(d.conf), None if d.id is None else int(d.id.item()) + if pred_boxes and show_boxes: + for d in reversed(pred_boxes): + c, conf, id = int(d.cls), float(d.conf) if conf else None, None if d.id is None else int(d.id.item()) name = ('' if id is None else f'id:{id} ') + names[c] - label = None if hide_labels else (name if hide_conf else f'{name} {conf:.2f}') + label = (name if not conf else f'{name} {conf:.2f}') if labels else None annotator.box_label(d.xyxy.squeeze(), label, color=colors(c, True)) - if masks is not None: - im = torch.as_tensor(annotator.im, dtype=torch.float16, device=masks.data.device).permute(2, 0, 1).flip(0) + if pred_masks and show_masks: + im = torch.as_tensor(annotator.im, dtype=torch.float16, device=pred_masks.data.device).permute(2, 0, + 1).flip(0) if TORCHVISION_0_10: - im = F.resize(im.contiguous(), masks.data.shape[1:], antialias=True) / 255 + im = F.resize(im.contiguous(), pred_masks.data.shape[1:], antialias=True) / 255 else: - im = F.resize(im.contiguous(), masks.data.shape[1:]) / 255 - annotator.masks(masks.data, colors=[colors(x, True) for x in boxes.cls], im_gpu=im) + im = F.resize(im.contiguous(), pred_masks.data.shape[1:]) / 255 + annotator.masks(pred_masks.data, colors=[colors(x, True) for x in pred_boxes.cls], im_gpu=im) - if probs is not None: + if pred_probs is not None and show_probs: n5 = min(len(names), 5) - top5i = probs.argsort(0, descending=True)[:n5].tolist() # top 5 indices - text = f"{', '.join(f'{names[j] if names else j} {probs[j]:.2f}' for j in top5i)}, " + top5i = pred_probs.argsort(0, descending=True)[:n5].tolist() # top 5 indices + text = f"{', '.join(f'{names[j] if names else j} {pred_probs[j]:.2f}' for j in top5i)}, " annotator.text((32, 32), text, txt_color=(255, 255, 255)) # TODO: allow setting colors return np.asarray(annotator.im) if annotator.pil else annotator.im diff --git a/ultralytics/yolo/utils/__init__.py b/ultralytics/yolo/utils/__init__.py index 779d59e..5d5c562 100644 --- a/ultralytics/yolo/utils/__init__.py +++ b/ultralytics/yolo/utils/__init__.py @@ -624,7 +624,8 @@ def get_settings(file=USER_CONFIG_DIR / 'settings.yaml', version='0.0.3'): # Check that settings keys and types match defaults correct = \ - settings.keys() == defaults.keys() \ + settings \ + and settings.keys() == defaults.keys() \ and all(type(a) == type(b) for a, b in zip(settings.values(), defaults.values())) \ and check_version(settings['settings_version'], version) if not correct: @@ -646,6 +647,14 @@ def set_settings(kwargs, file=USER_CONFIG_DIR / 'settings.yaml'): yaml_save(file, SETTINGS) +def deprecation_warn(arg, new_arg, version=None): + if not version: + version = float(__version__[0:3]) + 0.2 # deprecate after 2nd major release + LOGGER.warning( + f'WARNING: `{arg}` is deprecated and will be removed in upcoming major release {version}. Use `{new_arg}` instead' + ) + + # Run below code on yolo/utils init ------------------------------------------------------------------------------------ # Check first-install steps diff --git a/ultralytics/yolo/v8/detect/predict.py b/ultralytics/yolo/v8/detect/predict.py index 4df94b1..a54b6e7 100644 --- a/ultralytics/yolo/v8/detect/predict.py +++ b/ultralytics/yolo/v8/detect/predict.py @@ -70,7 +70,7 @@ class DetectionPredictor(BasePredictor): f.write(('%g ' * len(line)).rstrip() % line + '\n') if self.args.save or self.args.show: # Add bbox to image name = ('' if id is None else f'id:{id} ') + self.model.names[c] - label = None if self.args.hide_labels else (name if self.args.hide_conf else f'{name} {conf:.2f}') + label = (f'{name} {conf:.2f}' if self.args.show_conf else name) if self.args.show_labels else None self.annotator.box_label(d.xyxy.squeeze(), label, color=colors(c, True)) if self.args.save_crop: save_one_box(d.xyxy, diff --git a/ultralytics/yolo/v8/segment/predict.py b/ultralytics/yolo/v8/segment/predict.py index bc5c168..66b35f7 100644 --- a/ultralytics/yolo/v8/segment/predict.py +++ b/ultralytics/yolo/v8/segment/predict.py @@ -84,7 +84,7 @@ class SegmentationPredictor(DetectionPredictor): f.write(('%g ' * len(line)).rstrip() % line + '\n') if self.args.save or self.args.show: # Add bbox to image name = ('' if id is None else f'id:{id} ') + self.model.names[c] - label = None if self.args.hide_labels else (name if self.args.hide_conf else f'{name} {conf:.2f}') + label = (f'{name} {conf:.2f}' if self.args.show_conf else name) if self.args.show_labels else None if self.args.boxes: self.annotator.box_label(d.xyxy.squeeze(), label, color=colors(c, True)) if self.args.save_crop: