`ultralytics 8.0.59` new MLFlow and feature updates (#1720)

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: St. HeMeow <sheng.heyang@gmail.com>
Co-authored-by: Danny Kim <imbird0312@gmail.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
Co-authored-by: Torge Kummerow <CySlider@users.noreply.github.com>
Co-authored-by: dankernel <dkdkernel@gmail.com>
Co-authored-by: Burhan <62214284+Burhan-Q@users.noreply.github.com>
Co-authored-by: Roshanlal <roshanlaladchitre103@gmail.com>
Co-authored-by: Lorenzo Mammana <lorenzo.mammana@orobix.com>
Co-authored-by: Yonghye Kwon <developer.0hye@gmail.com>
single_channel
Glenn Jocher 2 years ago committed by GitHub
parent ccb6419835
commit e7876e1ba9
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -7,7 +7,7 @@ on:
push:
branches: [main]
pull_request:
branches: [main, updates]
branches: [main]
schedule:
- cron: '0 0 * * *' # runs at 00:00 UTC every day

@ -5,9 +5,9 @@ name: Check Broken links
on:
push:
branches: [na]
branches: [main]
pull_request:
branches: [na]
branches: [main]
workflow_dispatch:
schedule:
- cron: '0 0 * * *' # runs at 00:00 UTC every day

@ -9,7 +9,7 @@ jobs:
stale:
runs-on: ubuntu-latest
steps:
- uses: actions/stale@v7
- uses: actions/stale@v8
with:
repo-token: ${{ secrets.GITHUB_TOKEN }}

@ -152,7 +152,8 @@ operations are cached, meaning they're only calculated once per object, and thos
```python
results = model(inputs)
masks = results[0].masks # Masks object
masks.segments # bounding coordinates of masks, List[segment] * N
masks.xy # x, y segments (pixels), List[segment] * N
masks.xyn # x, y segments (normalized), List[segment] * N
masks.data # raw masks tensor, (N, H, W) or masks.masks
```
@ -185,3 +186,47 @@ masks, classification logits, etc.) found in the results object
- `show_conf (bool)`: Show confidence
- `line_width (Float)`: The line width of boxes. Automatically scaled to img size if not provided
- `font_size (Float)`: The font size of . Automatically scaled to img size if not provided
## Streaming Source `for`-loop
Here's a Python script using OpenCV (cv2) and YOLOv8 to run inference on video frames. This script assumes you have already installed the necessary packages (opencv-python and ultralytics).
!!! example "Streaming for-loop"
```python
import cv2
from ultralytics import YOLO
# Load the YOLOv8 model
model = YOLO('yolov8n.pt')
# Open the video file
video_path = "path/to/your/video/file.mp4"
cap = cv2.VideoCapture(video_path)
# Loop through the video frames
while cap.isOpened():
# Read a frame from the video
success, frame = cap.read()
if success:
# Run YOLOv8 inference on the frame
results = model(frame)
# Visualize the results on the frame
annotated_frame = results[0].plot()
# Display the annotated frame
cv2.imshow("YOLOv8 Inference", annotated_frame)
# Break the loop if 'q' is pressed
if cv2.waitKey(1) & 0xFF == ord("q"):
break
else:
# Break the loop if the end of the video is reached
break
# Release the video capture object and close the display window
cap.release()
cv2.destroyAllWindows()
```

@ -73,7 +73,7 @@ task.
| `deterministic` | `True` | whether to enable deterministic mode |
| `single_cls` | `False` | train multi-class data as single-class |
| `image_weights` | `False` | use weighted image selection for training |
| `rect` | `False` | support rectangular training |
| `rect` | `False` | rectangular training with each batch collated for minimum padding |
| `cos_lr` | `False` | use cosine learning rate scheduler |
| `close_mosaic` | `10` | disable mosaic augmentation for final 10 epochs |
| `resume` | `False` | resume training from last checkpoint |

@ -62,7 +62,7 @@ validation dataset and to detect and prevent overfitting.
| `device` | `None` | device to run on, i.e. cuda device=0/1/2/3 or device=cpu |
| `dnn` | `False` | use OpenCV DNN for ONNX inference |
| `plots` | `False` | show plots during training |
| `rect` | `False` | support rectangular evaluation |
| `rect` | `False` | rectangular val with each batch collated for minimum padding |
| `split` | `val` | dataset split to use for validation, i.e. 'val', 'test' or 'train' |
## Export Formats

@ -12,75 +12,74 @@ In this example, we want to return the original frame with each result object. H
```python
def on_predict_batch_end(predictor):
# results -> List[batch_size]
# Retrieve the batch data
_, _, im0s, _, _ = predictor.batch
# Ensure that im0s is a list
im0s = im0s if isinstance(im0s, list) else [im0s]
# Combine the prediction results with the corresponding frames
predictor.results = zip(predictor.results, im0s)
# Create a YOLO model instance
model = YOLO(f'yolov8n.pt')
# Add the custom callback to the model
model.add_callback("on_predict_batch_end", on_predict_batch_end)
# Iterate through the results and frames
for (result, frame) in model.track/predict():
pass
```
## All callbacks
Here are all supported callbacks.
### Trainer
`on_pretrain_routine_start`
`on_pretrain_routine_end`
`on_train_start`
`on_train_epoch_start`
`on_train_batch_start`
`optimizer_step`
`on_before_zero_grad`
`on_train_batch_end`
`on_train_epoch_end`
`on_fit_epoch_end`
`on_model_save`
`on_train_end`
`on_params_update`
`teardown`
### Validator
`on_val_start`
`on_val_batch_start`
Here are all supported callbacks. See callbacks [source code](https://github.com/ultralytics/ultralytics/blob/main/ultralytics/yolo/utils/callbacks/base.py) for additional details.
`on_val_batch_end`
`on_val_end`
### Trainer Callbacks
### Predictor
| Callback | Description |
|-----------------------------|---------------------------------------------------------|
| `on_pretrain_routine_start` | Triggered at the beginning of pre-training routine |
| `on_pretrain_routine_end` | Triggered at the end of pre-training routine |
| `on_train_start` | Triggered when the training starts |
| `on_train_epoch_start` | Triggered at the start of each training epoch |
| `on_train_batch_start` | Triggered at the start of each training batch |
| `optimizer_step` | Triggered during the optimizer step |
| `on_before_zero_grad` | Triggered before gradients are zeroed |
| `on_train_batch_end` | Triggered at the end of each training batch |
| `on_train_epoch_end` | Triggered at the end of each training epoch |
| `on_fit_epoch_end` | Triggered at the end of each fit epoch |
| `on_model_save` | Triggered when the model is saved |
| `on_train_end` | Triggered when the training process ends |
| `on_params_update` | Triggered when model parameters are updated |
| `teardown` | Triggered when the training process is being cleaned up |
`on_predict_start`
`on_predict_batch_start`
### Validator Callbacks
`on_predict_postprocess_end`
| Callback | Description |
|----------------------|-------------------------------------------------|
| `on_val_start` | Triggered when the validation starts |
| `on_val_batch_start` | Triggered at the start of each validation batch |
| `on_val_batch_end` | Triggered at the end of each validation batch |
| `on_val_end` | Triggered when the validation ends |
`on_predict_batch_end`
`on_predict_end`
### Predictor Callbacks
### Exporter
| Callback | Description |
|------------------------------|---------------------------------------------------|
| `on_predict_start` | Triggered when the prediction process starts |
| `on_predict_batch_start` | Triggered at the start of each prediction batch |
| `on_predict_postprocess_end` | Triggered at the end of prediction postprocessing |
| `on_predict_batch_end` | Triggered at the end of each prediction batch |
| `on_predict_end` | Triggered when the prediction process ends |
`on_export_start`
### Exporter Callbacks
`on_export_end`
| Callback | Description |
|-------------------|------------------------------------------|
| `on_export_start` | Triggered when the export process starts |
| `on_export_end` | Triggered when the export process ends |

@ -12,6 +12,18 @@ YOLOv8 'yolo' CLI commands use the following syntax:
yolo TASK MODE ARGS
```
=== "Python"
```python
from ultralytics import YOLO
# Load a YOLOv8 model from a pre-trained weights file
model = YOLO('yolov8n.pt')
# Run MODE mode using the custom arguments ARGS (guess TASK)
model.MODE(ARGS)
```
Where:
- `TASK` (optional) is one of `[detect, segment, classify, pose]`. If it is not passed explicitly YOLOv8 will try to
@ -36,6 +48,8 @@ differ in the type of output they produce and the specific problem they are desi
|--------|------------|-------------------------------------------------|
| `task` | `'detect'` | YOLO task, i.e. detect, segment, classify, pose |
[Tasks Guide](../tasks/index.md){ .md-button .md-button--primary}
#### Modes
YOLO models can be used in different modes depending on the specific problem you are trying to solve. These modes
@ -52,14 +66,11 @@ include:
|--------|-----------|---------------------------------------------------------------|
| `mode` | `'train'` | YOLO mode, i.e. train, val, predict, export, track, benchmark |
### Training
[Modes Guide](../modes/index.md){ .md-button .md-button--primary}
Training settings for YOLO models refer to the various hyperparameters and configurations used to train the model on a
dataset. These settings can affect the model's performance, speed, and accuracy. Some common YOLO training settings
include the batch size, learning rate, momentum, and weight decay. Other factors that may affect the training process
include the choice of optimizer, the choice of loss function, and the size and composition of the training dataset. It
is important to carefully tune and experiment with these settings to achieve the best possible performance for a given
task.
## Train
The training settings for YOLO models encompass various hyperparameters and configurations used during the training process. These settings influence the model's performance, speed, and accuracy. Key training settings include batch size, learning rate, momentum, and weight decay. Additionally, the choice of optimizer, loss function, and training dataset composition can impact the training process. Careful tuning and experimentation with these settings are crucial for optimizing performance.
| Key | Value | Description |
|-------------------|----------|-----------------------------------------------------------------------------|
@ -84,7 +95,7 @@ task.
| `deterministic` | `True` | whether to enable deterministic mode |
| `single_cls` | `False` | train multi-class data as single-class |
| `image_weights` | `False` | use weighted image selection for training |
| `rect` | `False` | support rectangular training |
| `rect` | `False` | rectangular training with each batch collated for minimum padding |
| `cos_lr` | `False` | use cosine learning rate scheduler |
| `close_mosaic` | `10` | disable mosaic augmentation for final 10 epochs |
| `resume` | `False` | resume training from last checkpoint |
@ -107,15 +118,11 @@ task.
| `dropout` | `0.0` | use dropout regularization (classify train only) |
| `val` | `True` | validate/test during training |
### Prediction
[Train Guide](../modes/train.md){ .md-button .md-button--primary}
## Predict
Prediction settings for YOLO models refer to the various hyperparameters and configurations used to make predictions
with the model on new data. These settings can affect the model's performance, speed, and accuracy. Some common YOLO
prediction settings include the confidence threshold, non-maximum suppression (NMS) threshold, and the number of classes
to consider. Other factors that may affect the prediction process include the size and format of the input data, the
presence of additional features such as masks or multiple labels per box, and the specific task the model is being used
for. It is important to carefully tune and experiment with these settings to achieve the best possible performance for a
given task.
The prediction settings for YOLO models encompass a range of hyperparameters and configurations that influence the model's performance, speed, and accuracy during inference on new data. Careful tuning and experimentation with these settings are essential to achieve optimal performance for a specific task. Key settings include the confidence threshold, Non-Maximum Suppression (NMS) threshold, and the number of classes considered. Additional factors affecting the prediction process are input data size and format, the presence of supplementary features such as masks or multiple labels per box, and the particular task the model is employed for.
| Key | Value | Description |
|------------------|------------------------|----------------------------------------------------------|
@ -141,15 +148,11 @@ given task.
| `classes` | `None` | filter results by class, i.e. class=0, or class=[0,2,3] |
| `boxes` | `True` | Show boxes in segmentation predictions |
### Validation
[Predict Guide](../modes/predict.md){ .md-button .md-button--primary}
## Val
Validation settings for YOLO models refer to the various hyperparameters and configurations used to
evaluate the model's performance on a validation dataset. These settings can affect the model's performance, speed, and
accuracy. Some common YOLO validation settings include the batch size, the frequency with which validation is performed
during training, and the metrics used to evaluate the model's performance. Other factors that may affect the validation
process include the size and composition of the validation dataset and the specific task the model is being used for. It
is important to carefully tune and experiment with these settings to ensure that the model is performing well on the
validation dataset and to detect and prevent overfitting.
The val (validation) settings for YOLO models involve various hyperparameters and configurations used to evaluate the model's performance on a validation dataset. These settings influence the model's performance, speed, and accuracy. Common YOLO validation settings include batch size, validation frequency during training, and performance evaluation metrics. Other factors affecting the validation process include the validation dataset's size and composition, as well as the specific task the model is employed for. Careful tuning and experimentation with these settings are crucial to ensure optimal performance on the validation dataset and detect and prevent overfitting.
| Key | Value | Description |
|---------------|---------|--------------------------------------------------------------------|
@ -162,19 +165,14 @@ validation dataset and to detect and prevent overfitting.
| `device` | `None` | device to run on, i.e. cuda device=0/1/2/3 or device=cpu |
| `dnn` | `False` | use OpenCV DNN for ONNX inference |
| `plots` | `False` | show plots during training |
| `rect` | `False` | support rectangular evaluation |
| `rect` | `False` | rectangular val with each batch collated for minimum padding |
| `split` | `val` | dataset split to use for validation, i.e. 'val', 'test' or 'train' |
### Export
[Val Guide](../modes/val.md){ .md-button .md-button--primary}
Export settings for YOLO models refer to the various configurations and options used to save or
export the model for use in other environments or platforms. These settings can affect the model's performance, size,
and compatibility with different systems. Some common YOLO export settings include the format of the exported model
file (e.g. ONNX, TensorFlow SavedModel), the device on which the model will be run (e.g. CPU, GPU), and the presence of
additional features such as masks or multiple labels per box. Other factors that may affect the export process include
the specific task the model is being used for and the requirements or constraints of the target environment or platform.
It is important to carefully consider and configure these settings to ensure that the exported model is optimized for
the intended use case and can be used effectively in the target environment.
## Export
Export settings for YOLO models encompass configurations and options related to saving or exporting the model for use in different environments or platforms. These settings can impact the model's performance, size, and compatibility with various systems. Key export settings include the exported model file format (e.g., ONNX, TensorFlow SavedModel), the target device (e.g., CPU, GPU), and additional features such as masks or multiple labels per box. The export process may also be affected by the model's specific task and the requirements or constraints of the destination environment or platform. It is crucial to thoughtfully configure these settings to ensure the exported model is optimized for the intended use case and functions effectively in the target environment.
| Key | Value | Description |
|-------------|-----------------|------------------------------------------------------|
@ -190,7 +188,9 @@ the intended use case and can be used effectively in the target environment.
| `workspace` | `4` | TensorRT: workspace size (GB) |
| `nms` | `False` | CoreML: add NMS |
### Augmentation
[Export Guide](../modes/export.md){ .md-button .md-button--primary}
## Augmentation
Augmentation settings for YOLO models refer to the various transformations and modifications
applied to the training data to increase the diversity and size of the dataset. These settings can affect the model's
@ -217,7 +217,7 @@ ensure that the augmented dataset is diverse and representative enough to train
| `mixup` | 0.0 | image mixup (probability) |
| `copy_paste` | 0.0 | segment copy-paste (probability) |
### Logging, checkpoints, plotting and file management
## Logging, checkpoints, plotting and file management
Logging, checkpoints, plotting, and file management are important considerations when training a YOLO model.

@ -61,7 +61,7 @@ Where:
- `TASK` (optional) is one of `[detect, segment, classify]`. If it is not passed explicitly YOLOv8 will try to guess
the `TASK` from the model type.
- `MODE` (required) is one of `[train, val, predict, export]`
- `MODE` (required) is one of `[train, val, predict, export, track]`
- `ARGS` (optional) are any number of custom `arg=value` pairs like `imgsz=320` that override defaults.
For a full list of available `ARGS` see the [Configuration](cfg.md) page and `defaults.yaml`
GitHub [source](https://github.com/ultralytics/ultralytics/blob/main/ultralytics/yolo/cfg/default.yaml).

@ -150,7 +150,8 @@ predicts the classes and locations of objects in the input images or videos.
# segmentation
result.masks.masks # masks, (N, H, W)
result.masks.segments # bounding coordinates of masks, List[segment] * N
result.masks.xy # x,y segments (pixels), List[segment] * N
result.masks.xyn # x,y segments (normalized), List[segment] * N
# classification
result.probs # cls prob, (num_class, )

@ -25,7 +25,7 @@ Creating a custom model to detect your objects is an iterative process of collec
YOLOv5 models must be trained on labelled data in order to learn classes of objects in that data. There are two options for creating your dataset before you start training:
<details markdown>
<summary>Use <a href="https://roboflow.com/?ref=ultralytics">Roboflow</a> to manage your dataset in YOLO format</summary>
<summary>Use Roboflow to manage your dataset in YOLO format</summary>
### 1.1 Collect Images
@ -102,7 +102,7 @@ names:
### 1.2 Create Labels
After using a tool like [Roboflow Annotate](https://roboflow.com/annotate?ref=ultralytics) to label your images, export your labels to **YOLO format**, with one `*.txt` file per image (if no objects in image, no `*.txt` file is required). The `*.txt` file specifications are:
After using an annotation tool to label your images, export your labels to **YOLO format**, with one `*.txt` file per image (if no objects in image, no `*.txt` file is required). The `*.txt` file specifications are:
- One row per object
- Each row is `class x_center y_center width height` format.

@ -1,6 +1,6 @@
# Ultralytics YOLO 🚀, GPL-3.0 license
__version__ = '8.0.58'
__version__ = '8.0.59'
from ultralytics.yolo.engine.model import YOLO
from ultralytics.yolo.utils.checks import check_yolo as checks

@ -14,11 +14,7 @@ def start(key=''):
Start training models with Ultralytics HUB. Usage: from ultralytics.hub import start; start('API_KEY')
"""
auth = Auth(key)
if not auth.get_state():
model_id = request_api_key(auth)
else:
_, model_id = split_key(key)
model_id = split_key(key)[1] if auth.get_state() else request_api_key(auth)
if not model_id:
raise ConnectionError(emojis('Connecting with global API key is not currently supported. ❌'))
@ -36,7 +32,8 @@ def request_api_key(auth, max_attempts=3):
import getpass
for attempts in range(max_attempts):
LOGGER.info(f'{PREFIX}Login. Attempt {attempts + 1} of {max_attempts}')
input_key = getpass.getpass('Enter your Ultralytics HUB API key:\n')
input_key = getpass.getpass(
'Enter your Ultralytics API Key from https://hub.ultralytics.com/settings?tab=api+keys:\n')
auth.api_key, model_id = split_key(input_key)
if auth.authenticate():

@ -12,9 +12,9 @@ from random import random
import requests
from tqdm import tqdm
from ultralytics.yolo.utils import (DEFAULT_CFG_DICT, ENVIRONMENT, LOGGER, ONLINE, RANK, SETTINGS, TESTS_RUNNING,
TQDM_BAR_FORMAT, TryExcept, __version__, colorstr, emojis, get_git_origin_url,
is_colab, is_git_dir, is_pip_package)
from ultralytics.yolo.utils import (ENVIRONMENT, LOGGER, ONLINE, RANK, SETTINGS, TESTS_RUNNING, TQDM_BAR_FORMAT,
TryExcept, __version__, colorstr, emojis, get_git_origin_url, is_colab, is_git_dir,
is_pip_package)
PREFIX = colorstr('Ultralytics HUB: ')
HELP_MSG = 'If this issue persists please visit https://github.com/ultralytics/hub/issues for assistance.'
@ -76,7 +76,7 @@ def split_key(key=''):
error_string = emojis(f'{PREFIX}Invalid API key ⚠️\n') # error string
if not key:
key = getpass.getpass('Enter model key: ')
sep = '_' if '_' in key else '.' if '.' in key else None # separator
sep = '_' if '_' in key else None # separator
assert sep, error_string
api_key, model_id = key.split(sep)
assert len(api_key) and len(model_id), error_string
@ -172,7 +172,8 @@ class Traces:
"""
Initialize Traces for error tracking and reporting if tests are not currently running.
"""
self.rate_limit = 3.0 # rate limit (seconds)
from ultralytics.yolo.cfg import MODES, TASKS
self.rate_limit = 60.0 # rate limit (seconds)
self.t = 0.0 # rate limit timer (seconds)
self.metadata = {
'sys_argv_name': Path(sys.argv[0]).name,
@ -186,6 +187,7 @@ class Traces:
not TESTS_RUNNING and \
ONLINE and \
(is_pip_package() or get_git_origin_url() == 'https://github.com/ultralytics/ultralytics.git')
self.usage = {'tasks': {k: 0 for k in TASKS}, 'modes': {k: 0 for k in MODES}}
def __call__(self, cfg, all_keys=False, traces_sample_rate=1.0):
"""
@ -197,15 +199,22 @@ class Traces:
traces_sample_rate (float): Fraction of traces captured from 0.0 to 1.0
"""
t = time.time() # current time
if self.enabled and random() < traces_sample_rate and (t - self.t) > self.rate_limit:
if not self.enabled or random() > traces_sample_rate:
# Traces disabled or not randomly selected, do nothing
return
elif (t - self.t) < self.rate_limit:
# Time is under rate limiter, do nothing
return
else:
# Time is over rate limiter, send trace now
self.t = t # reset rate limit timer
cfg = vars(cfg) # convert type from IterableSimpleNamespace to dict
if not all_keys: # filter cfg
include_keys = {'task', 'mode'} # always include
cfg = {
k: (v.split(os.sep)[-1] if isinstance(v, str) and os.sep in v else v)
for k, v in cfg.items() if v != DEFAULT_CFG_DICT.get(k, None) or k in include_keys}
trace = {'uuid': SETTINGS['uuid'], 'cfg': cfg, 'metadata': self.metadata}
# Build trace
if cfg.task in self.usage['tasks']:
self.usage['tasks'][cfg.task] += 1
if cfg.mode in self.usage['modes']:
self.usage['modes'][cfg.mode] += 1
trace = {'uuid': SETTINGS['uuid'], 'usage': self.usage, 'metadata': self.metadata}
# Send a request to the HUB API to sync analytics
smart_request('post', f'{HUB_API_ROOT}/v1/usage/anonymous', json=trace, code=3, retry=0, verbose=False)

@ -45,7 +45,7 @@ Any of these models can be used by loading their configs or pretrained checkpoin
### 1. YOLOv8
**About** - Cutting edge Detection, Segmentation and Classification models developed by Ultralytics. </br>
**Citation** -
Available Models:
- Detection - `yolov8n`, `yolov8s`, `yolov8m`, `yolov8l`, `yolov8x`
@ -89,21 +89,28 @@ Available Models:
### 2. YOLOv5u
**About** - Anchor-free YOLOv5 models with new detection head and better speed-accuracy tradeoff </br>
**Citation** -
Available Models:
- Detection - `yolov5nu`, `yolov5su`, `yolov5mu`, `yolov5lu`, `yolov5xu`
- Detection P5/32 - `yolov5nu`, `yolov5su`, `yolov5mu`, `yolov5lu`, `yolov5xu`
- Detection P6/64 - `yolov5n6u`, `yolov5s6u`, `yolov5m6u`, `yolov5l6u`, `yolov5x6u`
<details><summary>Performance</summary>
### Detection
| Model | size<br><sup>(pixels) | mAP<sup>val<br>50-95 | Speed<br><sup>CPU ONNX<br>(ms) | Speed<br><sup>A100 TensorRT<br>(ms) | params<br><sup>(M) | FLOPs<br><sup>(B) |
| -------------------------------------------------------------------------------------- | --------------------- | -------------------- | ------------------------------ | ----------------------------------- | ------------------ | ----------------- |
| [YOLOv5nu](https://github.com/ultralytics/assets/releases/download/v0.0.0/yolov5nu.pt) | 640 | 34.3 | 73.6 | 1.06 | 2.6 | 7.7 |
| [YOLOv5su](https://github.com/ultralytics/assets/releases/download/v0.0.0/yolov5su.pt) | 640 | 43.0 | 120.7 | 1.27 | 9.1 | 24.0 |
| [YOLOv5mu](https://github.com/ultralytics/assets/releases/download/v0.0.0/yolov5mu.pt) | 640 | 49.0 | 233.9 | 1.86 | 25.1 | 64.2 |
| [YOLOv5lu](https://github.com/ultralytics/assets/releases/download/v0.0.0/yolov5lu.pt) | 640 | 52.2 | 408.4 | 2.50 | 53.2 | 135.0 |
| [YOLOv5xu](https://github.com/ultralytics/assets/releases/download/v0.0.0/yolov5xu.pt) | 640 | 53.2 | 763.2 | 3.81 | 97.2 | 246.4 |
| Model | size<br><sup>(pixels) | mAP<sup>val<br>50-95 | Speed<br><sup>CPU ONNX<br>(ms) | Speed<br><sup>A100 TensorRT<br>(ms) | params<br><sup>(M) | FLOPs<br><sup>(B) |
| ---------------------------------------------------------------------------------------- | --------------------- | -------------------- | ------------------------------ | ----------------------------------- | ------------------ | ----------------- |
| [YOLOv5nu](https://github.com/ultralytics/assets/releases/download/v0.0.0/yolov5nu.pt) | 640 | 34.3 | 73.6 | 1.06 | 2.6 | 7.7 |
| [YOLOv5su](https://github.com/ultralytics/assets/releases/download/v0.0.0/yolov5su.pt) | 640 | 43.0 | 120.7 | 1.27 | 9.1 | 24.0 |
| [YOLOv5mu](https://github.com/ultralytics/assets/releases/download/v0.0.0/yolov5mu.pt) | 640 | 49.0 | 233.9 | 1.86 | 25.1 | 64.2 |
| [YOLOv5lu](https://github.com/ultralytics/assets/releases/download/v0.0.0/yolov5lu.pt) | 640 | 52.2 | 408.4 | 2.50 | 53.2 | 135.0 |
| [YOLOv5xu](https://github.com/ultralytics/assets/releases/download/v0.0.0/yolov5xu.pt) | 640 | 53.2 | 763.2 | 3.81 | 97.2 | 246.4 |
| | | | | | | |
| [YOLOv5n6u](https://github.com/ultralytics/assets/releases/download/v0.0.0/yolov5n6u.pt) | 1280 | 42.1 | - | - | 4.3 | 7.8 |
| [YOLOv5s6u](https://github.com/ultralytics/assets/releases/download/v0.0.0/yolov5s6u.pt) | 1280 | 48.6 | - | - | 15.3 | 24.6 |
| [YOLOv5m6u](https://github.com/ultralytics/assets/releases/download/v0.0.0/yolov5m6u.pt) | 1280 | 53.6 | - | - | 41.2 | 65.7 |
| [YOLOv5l6u](https://github.com/ultralytics/assets/releases/download/v0.0.0/yolov5l6u.pt) | 1280 | 55.7 | - | - | 86.1 | 137.4 |
| [YOLOv5x6u](https://github.com/ultralytics/assets/releases/download/v0.0.0/yolov5x6u.pt) | 1280 | 56.8 | - | - | 155.4 | 250.7 |
</details>

@ -11,14 +11,28 @@ from typing import Dict, List, Union
from ultralytics.yolo.utils import (DEFAULT_CFG, DEFAULT_CFG_DICT, DEFAULT_CFG_PATH, LOGGER, ROOT, USER_CONFIG_DIR,
IterableSimpleNamespace, __version__, checks, colorstr, yaml_load, yaml_print)
# Define valid tasks and modes
MODES = 'train', 'val', 'predict', 'export', 'track', 'benchmark'
TASKS = 'detect', 'segment', 'classify', 'pose'
TASK2DATA = {
'detect': 'coco128.yaml',
'segment': 'coco128-seg.yaml',
'pose': 'coco128-pose.yaml',
'classify': 'imagenet100'}
TASK2MODEL = {
'detect': 'yolov8n.pt',
'segment': 'yolov8n-seg.pt',
'pose': 'yolov8n-pose.yaml',
'classify': 'yolov8n-cls.pt'} # temp
CLI_HELP_MSG = \
f"""
Arguments received: {str(['yolo'] + sys.argv[1:])}. Ultralytics 'yolo' commands use the following syntax:
yolo TASK MODE ARGS
Where TASK (optional) is one of [detect, segment, classify]
MODE (required) is one of [train, val, predict, export, track]
Where TASK (optional) is one of {TASKS}
MODE (required) is one of {MODES}
ARGS (optional) are any number of custom 'arg=value' pairs like 'imgsz=320' that override defaults.
See all ARGS at https://docs.ultralytics.com/usage/cfg or with 'yolo cfg'
@ -59,12 +73,6 @@ CFG_BOOL_KEYS = ('save', 'exist_ok', 'verbose', 'deterministic', 'single_cls', '
'save_conf', 'save_crop', 'hide_labels', 'hide_conf', 'visualize', 'augment', 'agnostic_nms',
'retina_masks', 'boxes', 'keras', 'optimize', 'int8', 'dynamic', 'simplify', 'nms', 'v5loader')
# Define valid tasks and modes
MODES = 'train', 'val', 'predict', 'export', 'track', 'benchmark'
TASKS = 'detect', 'segment', 'classify'
TASK2DATA = {'detect': 'coco128.yaml', 'segment': 'coco128-seg.yaml', 'classify': 'imagenet100'}
TASK2MODEL = {'detect': 'yolov8n.pt', 'segment': 'yolov8n-seg.pt', 'classify': 'yolov8n-cls.pt'}
def cfg2dict(cfg):
"""

@ -26,7 +26,7 @@ seed: 0 # random seed for reproducibility
deterministic: True # whether to enable deterministic mode
single_cls: False # train multi-class data as single-class
image_weights: False # use weighted image selection for training
rect: False # support rectangular training if mode='train', support rectangular evaluation if mode='val'
rect: False # rectangular training if mode='train' or rectangular validation if mode='val'
cos_lr: False # use cosine learning rate scheduler
close_mosaic: 10 # disable mosaic augmentation for final 10 epochs
resume: False # resume training from last checkpoint

@ -278,6 +278,7 @@ def check_cls_dataset(dataset: str):
data (dict): A dictionary containing the following keys and values:
'train': Path object for the directory containing the training set of the dataset
'val': Path object for the directory containing the validation set of the dataset
'test': Path object for the directory containing the test set of the dataset
'nc': Number of classes in the dataset
'names': List of class names in the dataset
"""
@ -293,11 +294,12 @@ def check_cls_dataset(dataset: str):
s = f"Dataset download success ✅ ({time.time() - t:.1f}s), saved to {colorstr('bold', data_dir)}\n"
LOGGER.info(s)
train_set = data_dir / 'train'
test_set = data_dir / 'test' if (data_dir / 'test').exists() else data_dir / 'val' # data/test or data/val
val_set = data_dir / 'val' if (data_dir / 'val').exists() else None # data/test or data/val
test_set = data_dir / 'test' if (data_dir / 'test').exists() else None # data/val or data/test
nc = len([x for x in (data_dir / 'train').glob('*') if x.is_dir()]) # number of classes
names = [x.name for x in (data_dir / 'train').iterdir() if x.is_dir()] # class names list
names = dict(enumerate(sorted(names)))
return {'train': train_set, 'val': test_set, 'nc': nc, 'names': names}
return {'train': train_set, 'val': val_set, 'test': test_set, 'nc': nc, 'names': names}
class HUBDatasetStats():

@ -230,8 +230,9 @@ class YOLO:
return self.predictor.predict_cli(source=source) if is_cli else self.predictor(source=source, stream=stream)
def track(self, source=None, stream=False, **kwargs):
from ultralytics.tracker import register_tracker
register_tracker(self)
if not hasattr(self.predictor, 'trackers'):
from ultralytics.tracker import register_tracker
register_tracker(self)
# ByteTrack-based method needs low confidence predictions as input
conf = kwargs.get('conf') or 0.1
kwargs['conf'] = conf

@ -3,16 +3,16 @@
Run prediction on images, videos, directories, globs, YouTube, webcam, streams, etc.
Usage - sources:
$ yolo mode=predict model=yolov8n.pt --source 0 # webcam
img.jpg # image
vid.mp4 # video
screen # screenshot
path/ # directory
list.txt # list of images
list.streams # list of streams
'path/*.jpg' # glob
'https://youtu.be/Zgi9g1ksQHc' # YouTube
'rtsp://example.com/media.mp4' # RTSP, RTMP, HTTP stream
$ yolo mode=predict model=yolov8n.pt source=0 # webcam
img.jpg # image
vid.mp4 # video
screen # screenshot
path/ # directory
list.txt # list of images
list.streams # list of streams
'path/*.jpg' # glob
'https://youtu.be/Zgi9g1ksQHc' # YouTube
'rtsp://example.com/media.mp4' # RTSP, RTMP, HTTP stream
Usage - formats:
$ yolo mode=predict model=yolov8n.pt # PyTorch

@ -14,6 +14,7 @@ import torchvision.transforms.functional as F
from ultralytics.yolo.utils import LOGGER, SimpleClass, ops
from ultralytics.yolo.utils.plotting import Annotator, colors
from ultralytics.yolo.utils.torch_utils import TORCHVISION_0_10
class Results(SimpleClass):
@ -129,7 +130,10 @@ class Results(SimpleClass):
if masks is not None:
im = torch.as_tensor(annotator.im, dtype=torch.float16, device=masks.data.device).permute(2, 0, 1).flip(0)
im = F.resize(im.contiguous(), masks.data.shape[1:]) / 255
if TORCHVISION_0_10:
im = F.resize(im.contiguous(), masks.data.shape[1:], antialias=True) / 255
else:
im = F.resize(im.contiguous(), masks.data.shape[1:]) / 255
annotator.masks(masks.data, colors=[colors(x, True) for x in boxes.cls], im_gpu=im)
if probs is not None:
@ -259,7 +263,8 @@ class Masks(SimpleClass):
orig_shape (tuple): Original image size, in the format (height, width).
Properties:
segments (list): A list of segments which includes x, y, w, h, label, confidence, and mask of each detection.
xy (list): A list of segments (pixels) which includes x, y segments of each detection.
xyn (list): A list of segments (normalized) which includes x, y segments of each detection.
Methods:
cpu(): Returns a copy of the masks tensor on CPU memory.
@ -272,13 +277,28 @@ class Masks(SimpleClass):
self.masks = masks # N, h, w
self.orig_shape = orig_shape
def segments(self):
# Segments-deprecated (normalized)
LOGGER.warning("WARNING ⚠️ 'Masks.segments' is deprecated. Use 'Masks.xyn' for segments (normalized) and "
"'Masks.xy' for segments (pixels) instead.")
return self.xyn
@property
@lru_cache(maxsize=1)
def segments(self):
def xyn(self):
# Segments (normalized)
return [
ops.scale_segments(self.masks.shape[1:], x, self.orig_shape, normalize=True)
for x in ops.masks2segments(self.masks)]
@property
@lru_cache(maxsize=1)
def xy(self):
# Segments (pixels)
return [
ops.scale_segments(self.masks.shape[1:], x, self.orig_shape, normalize=False)
for x in ops.masks2segments(self.masks)]
@property
def shape(self):
return self.masks.shape

@ -370,6 +370,7 @@ class BaseTrainer:
self.epoch_time = tnow - self.epoch_time_start
self.epoch_time_start = tnow
self.run_callbacks('on_fit_epoch_end')
torch.cuda.empty_cache() # clears GPU vRAM at end of epoch, can help with out of memory errors
# Early Stopping
if RANK != -1: # if DDP training

@ -484,7 +484,7 @@ def get_user_config_dir(sub_dir='Ultralytics'):
return path
USER_CONFIG_DIR = os.getenv('YOLO_CONFIG_DIR', get_user_config_dir()) # Ultralytics settings dir
USER_CONFIG_DIR = Path(os.getenv('YOLO_CONFIG_DIR', get_user_config_dir())) # Ultralytics settings dir
def emojis(string=''):

@ -48,8 +48,6 @@ def benchmark(model=Path(SETTINGS['weights_dir']) / 'yolov8n.pt', imgsz=160, hal
for i, (name, format, suffix, cpu, gpu) in export_formats().iterrows(): # index, (name, format, suffix, CPU, GPU)
emoji, filename = '', None # export defaults
try:
if model.task == 'classify':
assert i != 11, 'paddle cls exports coming soon'
assert i != 9 or LINUX, 'Edge TPU export only supported on Linux'
if i == 10:
assert MACOS or LINUX, 'TF.js export only supported on macOS and Linux'

@ -147,9 +147,10 @@ def add_integration_callbacks(instance):
from .clearml import callbacks as clearml_callbacks
from .comet import callbacks as comet_callbacks
from .hub import callbacks as hub_callbacks
from .mlflow import callbacks as mf_callbacks
from .tensorboard import callbacks as tb_callbacks
for x in clearml_callbacks, comet_callbacks, hub_callbacks, tb_callbacks:
for x in clearml_callbacks, comet_callbacks, hub_callbacks, tb_callbacks, mf_callbacks:
for k, v in x.items():
if v not in instance.callbacks[k]: # prevent duplicate callbacks addition
instance.callbacks[k].append(v) # callback[name].append(func)

@ -6,9 +6,9 @@ try:
import clearml
from clearml import Task
assert clearml.__version__ # verify package is not directory
assert hasattr(clearml, '__version__') # verify package is not directory
assert not TESTS_RUNNING # do not log pytest
except (ImportError, AssertionError, AttributeError):
except (ImportError, AssertionError):
clearml = None

@ -6,8 +6,8 @@ try:
import comet_ml
assert not TESTS_RUNNING # do not log pytest
assert comet_ml.__version__ # verify package is not directory
except (ImportError, AssertionError, AttributeError):
assert hasattr(comet_ml, '__version__') # verify package is not directory
except (ImportError, AssertionError):
comet_ml = None

@ -0,0 +1,75 @@
# Ultralytics YOLO 🚀, GPL-3.0 license
import os
import re
from pathlib import Path
from ultralytics.yolo.utils import LOGGER, TESTS_RUNNING, colorstr
try:
import mlflow
assert not TESTS_RUNNING # do not log pytest
assert hasattr(mlflow, '__version__') # verify package is not directory
except (ImportError, AssertionError):
mlflow = None
def on_pretrain_routine_end(trainer):
global mlflow, run, run_id, experiment_name
if os.environ.get('MLFLOW_TRACKING_URI') is None:
mlflow = None
if mlflow:
mlflow_location = os.environ['MLFLOW_TRACKING_URI'] # "http://192.168.xxx.xxx:5000"
mlflow.set_tracking_uri(mlflow_location)
experiment_name = trainer.args.project or 'YOLOv8'
experiment = mlflow.get_experiment_by_name(experiment_name)
if experiment is None:
mlflow.create_experiment(experiment_name)
mlflow.set_experiment(experiment_name)
prefix = colorstr('MLFlow: ')
try:
run, active_run = mlflow, mlflow.start_run() if mlflow else None
if active_run is not None:
run_id = active_run.info.run_id
LOGGER.info(f'{prefix}Using run_id({run_id}) at {mlflow_location}')
except Exception as err:
LOGGER.error(f'{prefix}Failing init - {repr(err)}')
LOGGER.warning(f'{prefix}Continuing without Mlflow')
run = None
run.log_params(vars(trainer.model.args))
def on_fit_epoch_end(trainer):
if mlflow:
metrics_dict = {f"{re.sub('[()]', '', k)}": float(v) for k, v in trainer.metrics.items()}
run.log_metrics(metrics=metrics_dict, step=trainer.epoch)
def on_model_save(trainer):
if mlflow:
run.log_artifact(trainer.last)
def on_train_end(trainer):
if mlflow:
root_dir = Path(__file__).resolve().parents[3]
run.log_artifact(trainer.best)
model_uri = f'runs:/{run_id}/'
run.register_model(model_uri, experiment_name)
run.pyfunc.log_model(artifact_path=experiment_name,
code_path=[str(root_dir)],
artifacts={'model_path': str(trainer.save_dir)},
python_model=run.pyfunc.PythonModel())
callbacks = {
'on_pretrain_routine_end': on_pretrain_routine_end,
'on_fit_epoch_end': on_fit_epoch_end,
'on_model_save': on_model_save,
'on_train_end': on_train_end} if mlflow else {}

@ -16,10 +16,12 @@ import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from ultralytics.yolo.utils import DEFAULT_CFG_DICT, DEFAULT_CFG_KEYS, LOGGER, RANK, __version__
from ultralytics.yolo.utils.checks import check_version
TORCHVISION_0_10 = check_version(torchvision.__version__, '0.10.0')
TORCH_1_9 = check_version(torch.__version__, '1.9.0')
TORCH_1_11 = check_version(torch.__version__, '1.11.0')
TORCH_1_12 = check_version(torch.__version__, '1.12.0')

Loading…
Cancel
Save