diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml
index 5452bb8..6312d29 100644
--- a/.github/workflows/ci.yaml
+++ b/.github/workflows/ci.yaml
@@ -123,12 +123,12 @@ jobs:
shell: python
run: |
from ultralytics.yolo.utils.benchmarks import benchmark
- benchmark(model='${{ matrix.model }}-cls.pt', imgsz=160, half=False, hard_fail=0.61)
+ benchmark(model='${{ matrix.model }}-cls.pt', imgsz=160, half=False, hard_fail=0.35)
- name: Benchmark PoseModel
shell: python
run: |
from ultralytics.yolo.utils.benchmarks import benchmark
- benchmark(model='${{ matrix.model }}-pose.pt', imgsz=160, half=False, hard_fail=0.0)
+ benchmark(model='${{ matrix.model }}-pose.pt', imgsz=160, half=False, hard_fail=0.17)
- name: Benchmark Summary
run: |
cat benchmarks.log
diff --git a/docs/usage/hyperparameter_tuning.md b/docs/usage/hyperparameter_tuning.md
index 06a3841..1b25ade 100644
--- a/docs/usage/hyperparameter_tuning.md
+++ b/docs/usage/hyperparameter_tuning.md
@@ -1,29 +1,26 @@
---
comments: true
-description: Discover how to integrate hyperparameter tuning with Ray Tune and Ultralytics YOLOv8. Speed up the tuning process and optimize your model's performance.
+description: Learn to integrate hyperparameter tuning using Ray Tune with Ultralytics YOLOv8, and optimize your model's performance efficiently.
keywords: yolov8, ray tune, hyperparameter tuning, hyperparameter optimization, machine learning, computer vision, deep learning, image recognition
---
-# Hyperparameter Tuning with Ray Tune and YOLOv8
+# Efficient Hyperparameter Tuning with Ray Tune and YOLOv8
-Hyperparameter tuning (or hyperparameter optimization) is the process of determining the right combination of hyperparameters that maximizes model performance. It works by running multiple trials in a single training process, evaluating the performance of each trial, and selecting the best hyperparameter values based on the evaluation results.
+Hyperparameter tuning is vital in achieving peak model performance by discovering the optimal set of hyperparameters. This involves running trials with different hyperparameters and evaluating each trial’s performance.
-## Ultralytics YOLOv8 and Ray Tune Integration
+## Accelerate Tuning with Ultralytics YOLOv8 and Ray Tune
-[Ultralytics](https://ultralytics.com) YOLOv8 integrates hyperparameter tuning with Ray Tune, allowing you to easily optimize your YOLOv8 model's hyperparameters. By using Ray Tune, you can leverage advanced search algorithms, parallelism, and early stopping to speed up the tuning process and achieve better model performance.
+[Ultralytics YOLOv8](https://ultralytics.com) incorporates Ray Tune for hyperparameter tuning, streamlining the optimization of YOLOv8 model hyperparameters. With Ray Tune, you can utilize advanced search strategies, parallelism, and early stopping to expedite the tuning process.
### Ray Tune
-
-
-
-
+![Ray Tune Overview](https://docs.ray.io/en/latest/_images/tune_overview.png)
-[Ray Tune](https://docs.ray.io/en/latest/tune/index.html) is a powerful and flexible hyperparameter tuning library for machine learning models. It provides an efficient way to optimize hyperparameters by supporting various search algorithms, parallelism, and early stopping strategies. Ray Tune's flexible architecture enables seamless integration with popular machine learning frameworks, including Ultralytics YOLOv8.
+[Ray Tune](https://docs.ray.io/en/latest/tune/index.html) is a hyperparameter tuning library designed for efficiency and flexibility. It supports various search strategies, parallelism, and early stopping strategies, and seamlessly integrates with popular machine learning frameworks, including Ultralytics YOLOv8.
-### Weights & Biases
+### Integration with Weights & Biases
-YOLOv8 also supports optional integration with [Weights & Biases](https://wandb.ai/site) (wandb) for tracking the tuning progress.
+YOLOv8 also allows optional integration with [Weights & Biases](https://wandb.ai/site) for monitoring the tuning process.
## Installation
@@ -32,8 +29,11 @@ To install the required packages, run:
!!! tip "Installation"
```bash
- pip install -U ultralytics "ray[tune]" # install and/or update
- pip install wandb # optional
+ # Install and update Ultralytics and Ray Tune pacakges
+ pip install -U ultralytics 'ray[tune]'
+
+ # Optionally install W&B for logging
+ pip install wandb
```
## Usage
@@ -44,21 +44,21 @@ To install the required packages, run:
from ultralytics import YOLO
model = YOLO("yolov8n.pt")
- results = model.tune(data="coco128.yaml")
+ result_grid = model.tune(data="coco128.yaml")
```
## `tune()` Method Parameters
The `tune()` method in YOLOv8 provides an easy-to-use interface for hyperparameter tuning with Ray Tune. It accepts several arguments that allow you to customize the tuning process. Below is a detailed explanation of each parameter:
-| Parameter | Type | Description | Default Value |
-|-----------------|----------------|-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|---------------|
-| `data` | str | The dataset configuration file (in YAML format) to run the tuner on. This file should specify the training and validation data paths, as well as other dataset-specific settings. | |
-| `space` | dict, optional | A dictionary defining the hyperparameter search space for Ray Tune. Each key corresponds to a hyperparameter name, and the value specifies the range of values to explore during tuning. If not provided, YOLOv8 uses a default search space with various hyperparameters. | |
-| `grace_period` | int, optional | The grace period in epochs for the [ASHA scheduler]https://docs.ray.io/en/latest/tune/api/schedulers.html) in Ray Tune. The scheduler will not terminate any trial before this number of epochs, allowing the model to have some minimum training before making a decision on early stopping. | 10 |
-| `gpu_per_trial` | int, optional | The number of GPUs to allocate per trial during tuning. This helps manage GPU usage, particularly in multi-GPU environments. If not provided, the tuner will use all available GPUs. | None |
-| `max_samples` | int, optional | The maximum number of trials to run during tuning. This parameter helps control the total number of hyperparameter combinations tested, ensuring the tuning process does not run indefinitely. | 10 |
-| `train_args` | dict, optional | A dictionary of additional arguments to pass to the `train()` method during tuning. These arguments can include settings like the number of training epochs, batch size, and other training-specific configurations. | {} |
+| Parameter | Type | Description | Default Value |
+|-----------------|----------------|------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|---------------|
+| `data` | str | The dataset configuration file (in YAML format) to run the tuner on. This file should specify the training and validation data paths, as well as other dataset-specific settings. | |
+| `space` | dict, optional | A dictionary defining the hyperparameter search space for Ray Tune. Each key corresponds to a hyperparameter name, and the value specifies the range of values to explore during tuning. If not provided, YOLOv8 uses a default search space with various hyperparameters. | |
+| `grace_period` | int, optional | The grace period in epochs for the [ASHA scheduler](https://docs.ray.io/en/latest/tune/api/schedulers.html) in Ray Tune. The scheduler will not terminate any trial before this number of epochs, allowing the model to have some minimum training before making a decision on early stopping. | 10 |
+| `gpu_per_trial` | int, optional | The number of GPUs to allocate per trial during tuning. This helps manage GPU usage, particularly in multi-GPU environments. If not provided, the tuner will use all available GPUs. | None |
+| `max_samples` | int, optional | The maximum number of trials to run during tuning. This parameter helps control the total number of hyperparameter combinations tested, ensuring the tuning process does not run indefinitely. | 10 |
+| `**train_args` | dict, optional | Additional arguments to pass to the `train()` method during tuning. These arguments can include settings like the number of training epochs, batch size, and other training-specific configurations. | {} |
By customizing these parameters, you can fine-tune the hyperparameter optimization process to suit your specific needs and available computational resources.
@@ -98,14 +98,72 @@ In this example, we demonstrate how to use a custom search space for hyperparame
```python
from ultralytics import YOLO
- from ray import tune
-
+
+ # Define a YOLO model
model = YOLO("yolov8n.pt")
- result = model.tune(
- data="coco128.yaml",
- space={"lr0": tune.uniform(1e-5, 1e-1)},
- train_args={"epochs": 50}
- )
+
+ # Run Ray Tune on the model
+ result_grid = model.tune(data="coco128.yaml",
+ space={"lr0": tune.uniform(1e-5, 1e-1)},
+ epochs=50)
```
-In the code snippet above, we create a YOLO model with the "yolov8n.pt" pretrained weights. Then, we call the `tune()` method, specifying the dataset configuration with "coco128.yaml". We provide a custom search space for the initial learning rate `lr0` using a dictionary with the key "lr0" and the value `tune.uniform(1e-5, 1e-1)`. Finally, we pass additional training arguments, such as the number of epochs, using the `train_args` parameter.
\ No newline at end of file
+In the code snippet above, we create a YOLO model with the "yolov8n.pt" pretrained weights. Then, we call the `tune()` method, specifying the dataset configuration with "coco128.yaml". We provide a custom search space for the initial learning rate `lr0` using a dictionary with the key "lr0" and the value `tune.uniform(1e-5, 1e-1)`. Finally, we pass additional training arguments, such as the number of epochs directly to the tune method as `epochs=50`.
+
+# Processing Ray Tune Results
+
+After running a hyperparameter tuning experiment with Ray Tune, you might want to perform various analyses on the obtained results. This guide will take you through common workflows for processing and analyzing these results.
+
+## Loading Tune Experiment Results from a Directory
+
+After running the tuning experiment with `tuner.fit()`, you can load the results from a directory. This is useful, especially if you're performing the analysis after the initial training script has exited.
+
+```python
+experiment_path = f"{storage_path}/{exp_name}"
+print(f"Loading results from {experiment_path}...")
+
+restored_tuner = tune.Tuner.restore(experiment_path, trainable=train_mnist)
+result_grid = restored_tuner.get_results()
+```
+
+## Basic Experiment-Level Analysis
+
+Get an overview of how trials performed. You can quickly check if there were any errors during the trials.
+
+```python
+if result_grid.errors:
+ print("One or more trials failed!")
+else:
+ print("No errors!")
+```
+
+## Basic Trial-Level Analysis
+
+Access individual trial hyperparameter configurations and the last reported metrics.
+
+```python
+for i, result in enumerate(result_grid):
+ print(f"Trial #{i}: Configuration: {result.config}, Last Reported Metrics: {result.metrics}")
+```
+
+## Plotting the Entire History of Reported Metrics for a Trial
+
+You can plot the history of reported metrics for each trial to see how the metrics evolved over time.
+
+```python
+import matplotlib.pyplot as plt
+
+for result in result_grid:
+ plt.plot(result.metrics_dataframe["training_iteration"], result.metrics_dataframe["mean_accuracy"], label=f"Trial {i}")
+
+plt.xlabel('Training Iterations')
+plt.ylabel('Mean Accuracy')
+plt.legend()
+plt.show()
+```
+
+## Summary
+
+In this documentation, we covered common workflows to analyze the results of experiments run with Ray Tune using Ultralytics. The key steps include loading the experiment results from a directory, performing basic experiment-level and trial-level analysis and plotting metrics.
+
+Explore further by looking into Ray Tune’s [Analyze Results](https://docs.ray.io/en/latest/tune/examples/tune_analyze_results.html) docs page to get the most out of your hyperparameter tuning experiments.
diff --git a/mkdocs.yml b/mkdocs.yml
index e8ab7dc..c093894 100644
--- a/mkdocs.yml
+++ b/mkdocs.yml
@@ -55,7 +55,7 @@ theme:
- content.tabs.link # all code tabs change simultaneously
# Customization
-copyright: Ultralytics 2023. All rights reserved.
+copyright: © 2023 Ultralytics Inc. All rights reserved.
extra:
# version:
# provider: mike # version drop-down menu
diff --git a/ultralytics/__init__.py b/ultralytics/__init__.py
index f5ae6ab..c86adef 100644
--- a/ultralytics/__init__.py
+++ b/ultralytics/__init__.py
@@ -1,6 +1,6 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
-__version__ = '8.0.125'
+__version__ = '8.0.126'
from ultralytics.hub import start
from ultralytics.vit.rtdetr import RTDETR
diff --git a/ultralytics/yolo/cfg/__init__.py b/ultralytics/yolo/cfg/__init__.py
index 746e458..c3d0285 100644
--- a/ultralytics/yolo/cfg/__init__.py
+++ b/ultralytics/yolo/cfg/__init__.py
@@ -26,6 +26,12 @@ TASK2MODEL = {
'segment': 'yolov8n-seg.pt',
'classify': 'yolov8n-cls.pt',
'pose': 'yolov8n-pose.pt'}
+TASK2METRIC = {
+ 'detect': 'metrics/mAP50-95(B)',
+ 'segment': 'metrics/mAP50-95(M)',
+ 'classify': 'metrics/accuracy_top1',
+ 'pose': 'metrics/mAP50-95(P)'}
+
CLI_HELP_MSG = \
f"""
diff --git a/ultralytics/yolo/engine/model.py b/ultralytics/yolo/engine/model.py
index 18810e7..2019566 100644
--- a/ultralytics/yolo/engine/model.py
+++ b/ultralytics/yolo/engine/model.py
@@ -9,8 +9,8 @@ from ultralytics.nn.tasks import (ClassificationModel, DetectionModel, PoseModel
attempt_load_one_weight, guess_model_task, nn, yaml_model_load)
from ultralytics.yolo.cfg import get_cfg
from ultralytics.yolo.engine.exporter import Exporter
-from ultralytics.yolo.utils import (DEFAULT_CFG, DEFAULT_CFG_DICT, DEFAULT_CFG_KEYS, LOGGER, NUM_THREADS, RANK, ROOT,
- callbacks, is_git_dir, yaml_load)
+from ultralytics.yolo.utils import (DEFAULT_CFG, DEFAULT_CFG_DICT, DEFAULT_CFG_KEYS, LOGGER, RANK, ROOT, callbacks,
+ is_git_dir, yaml_load)
from ultralytics.yolo.utils.checks import check_file, check_imgsz, check_pip_update_available, check_yaml
from ultralytics.yolo.utils.downloads import GITHUB_ASSET_STEMS
from ultralytics.yolo.utils.torch_utils import smart_inference_mode
@@ -387,13 +387,7 @@ class YOLO:
self._check_is_pytorch_model()
self.model.to(device)
- def tune(self,
- data: str,
- space: dict = None,
- grace_period: int = 10,
- gpu_per_trial: int = None,
- max_samples: int = 10,
- train_args: dict = None):
+ def tune(self, *args, **kwargs):
"""
Runs hyperparameter tuning using Ray Tune.
@@ -411,66 +405,9 @@ class YOLO:
Raises:
ModuleNotFoundError: If Ray Tune is not installed.
"""
- if train_args is None:
- train_args = {}
-
- try:
- from ultralytics.yolo.utils.tuner import (ASHAScheduler, RunConfig, WandbLoggerCallback, default_space,
- task_metric_map, tune)
- except ImportError:
- raise ModuleNotFoundError("Install Ray Tune: `pip install 'ray[tune]'`")
-
- try:
- import wandb
- from wandb import __version__ # noqa
- except ImportError:
- wandb = False
-
- def _tune(config):
- """
- Trains the YOLO model with the specified hyperparameters and additional arguments.
-
- Args:
- config (dict): A dictionary of hyperparameters to use for training.
-
- Returns:
- None.
- """
- self._reset_callbacks()
- config.update(train_args)
- self.train(**config)
-
- if not space:
- LOGGER.warning('WARNING: search space not provided. Using default search space')
- space = default_space
-
- space['data'] = data
-
- # Define the trainable function with allocated resources
- trainable_with_resources = tune.with_resources(_tune, {'cpu': NUM_THREADS, 'gpu': gpu_per_trial or 0})
-
- # Define the ASHA scheduler for hyperparameter search
- asha_scheduler = ASHAScheduler(time_attr='epoch',
- metric=task_metric_map[self.task],
- mode='max',
- max_t=train_args.get('epochs') or 100,
- grace_period=grace_period,
- reduction_factor=3)
-
- # Define the callbacks for the hyperparameter search
- tuner_callbacks = [WandbLoggerCallback(project='YOLOv8-tune')] if wandb else []
-
- # Create the Ray Tune hyperparameter search tuner
- tuner = tune.Tuner(trainable_with_resources,
- param_space=space,
- tune_config=tune.TuneConfig(scheduler=asha_scheduler, num_samples=max_samples),
- run_config=RunConfig(callbacks=tuner_callbacks, local_dir='./runs'))
-
- # Run the hyperparameter search
- tuner.fit()
-
- # Return the results of the hyperparameter search
- return tuner.get_results()
+ self._check_is_pytorch_model()
+ from ultralytics.yolo.utils.tuner import run_ray_tune
+ return run_ray_tune(self, *args, **kwargs)
@property
def names(self):
diff --git a/ultralytics/yolo/utils/benchmarks.py b/ultralytics/yolo/utils/benchmarks.py
index a277d6b..654847b 100644
--- a/ultralytics/yolo/utils/benchmarks.py
+++ b/ultralytics/yolo/utils/benchmarks.py
@@ -33,6 +33,7 @@ import torch.cuda
from tqdm import tqdm
from ultralytics import YOLO
+from ultralytics.yolo.cfg import TASK2DATA, TASK2METRIC
from ultralytics.yolo.engine.exporter import export_formats
from ultralytics.yolo.utils import LINUX, LOGGER, MACOS, ROOT, SETTINGS
from ultralytics.yolo.utils.checks import check_requirements, check_yolo
@@ -96,6 +97,7 @@ def benchmark(model=Path(SETTINGS['weights_dir']) / 'yolov8n.pt',
emoji = '❎' # indicates export succeeded
# Predict
+ assert model.task != 'pose' or i != 7, 'GraphDef Pose inference is not supported'
assert i not in (9, 10), 'inference not supported' # Edge TPU and TF.js are unsupported
assert i != 5 or platform.system() == 'Darwin', 'inference only supported on macOS>=10.13' # CoreML
if not (ROOT / 'assets/bus.jpg').exists():
@@ -103,15 +105,8 @@ def benchmark(model=Path(SETTINGS['weights_dir']) / 'yolov8n.pt',
export.predict(ROOT / 'assets/bus.jpg', imgsz=imgsz, device=device, half=half)
# Validate
- if model.task == 'detect':
- data, key = 'coco8.yaml', 'metrics/mAP50-95(B)'
- elif model.task == 'segment':
- data, key = 'coco8-seg.yaml', 'metrics/mAP50-95(M)'
- elif model.task == 'classify':
- data, key = 'imagenet100', 'metrics/accuracy_top5'
- elif model.task == 'pose':
- data, key = 'coco8-pose.yaml', 'metrics/mAP50-95(P)'
-
+ data = TASK2DATA[model.task] # task to dataset, i.e. coco8.yaml for task=detect
+ key = TASK2METRIC[model.task] # task to metric, i.e. metrics/mAP50-95(B) for task=detect
results = export.val(data=data,
batch=1,
imgsz=imgsz,
diff --git a/ultralytics/yolo/utils/tuner.py b/ultralytics/yolo/utils/tuner.py
index 9f57677..54f10b0 100644
--- a/ultralytics/yolo/utils/tuner.py
+++ b/ultralytics/yolo/utils/tuner.py
@@ -1,44 +1,120 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
+from ultralytics.yolo.cfg import TASK2DATA, TASK2METRIC
+from ultralytics.yolo.utils import DEFAULT_CFG_DICT, LOGGER, NUM_THREADS
-from ultralytics.yolo.utils import LOGGER
-
-try:
- from ray import tune
- from ray.air import RunConfig, session # noqa
- from ray.air.integrations.wandb import WandbLoggerCallback # noqa
- from ray.tune.schedulers import ASHAScheduler # noqa
- from ray.tune.schedulers import AsyncHyperBandScheduler as AHB # noqa
-
-except ImportError:
- LOGGER.info("Tuning hyperparameters requires ray/tune. Install using `pip install 'ray[tune]'`")
- tune = None
-
-default_space = {
- # 'optimizer': tune.choice(['SGD', 'Adam', 'AdamW', 'NAdam', 'RAdam', 'RMSProp']),
- 'lr0': tune.uniform(1e-5, 1e-1),
- 'lrf': tune.uniform(0.01, 1.0), # final OneCycleLR learning rate (lr0 * lrf)
- 'momentum': tune.uniform(0.6, 0.98), # SGD momentum/Adam beta1
- 'weight_decay': tune.uniform(0.0, 0.001), # optimizer weight decay 5e-4
- 'warmup_epochs': tune.uniform(0.0, 5.0), # warmup epochs (fractions ok)
- 'warmup_momentum': tune.uniform(0.0, 0.95), # warmup initial momentum
- 'box': tune.uniform(0.02, 0.2), # box loss gain
- 'cls': tune.uniform(0.2, 4.0), # cls loss gain (scale with pixels)
- 'hsv_h': tune.uniform(0.0, 0.1), # image HSV-Hue augmentation (fraction)
- 'hsv_s': tune.uniform(0.0, 0.9), # image HSV-Saturation augmentation (fraction)
- 'hsv_v': tune.uniform(0.0, 0.9), # image HSV-Value augmentation (fraction)
- 'degrees': tune.uniform(0.0, 45.0), # image rotation (+/- deg)
- 'translate': tune.uniform(0.0, 0.9), # image translation (+/- fraction)
- 'scale': tune.uniform(0.0, 0.9), # image scale (+/- gain)
- 'shear': tune.uniform(0.0, 10.0), # image shear (+/- deg)
- 'perspective': tune.uniform(0.0, 0.001), # image perspective (+/- fraction), range 0-0.001
- 'flipud': tune.uniform(0.0, 1.0), # image flip up-down (probability)
- 'fliplr': tune.uniform(0.0, 1.0), # image flip left-right (probability)
- 'mosaic': tune.uniform(0.0, 1.0), # image mixup (probability)
- 'mixup': tune.uniform(0.0, 1.0), # image mixup (probability)
- 'copy_paste': tune.uniform(0.0, 1.0)} # segment copy-paste (probability)
-
-task_metric_map = {
- 'detect': 'metrics/mAP50-95(B)',
- 'segment': 'metrics/mAP50-95(M)',
- 'classify': 'metrics/accuracy_top1',
- 'pose': 'metrics/mAP50-95(P)'}
+
+def run_ray_tune(model,
+ space: dict = None,
+ grace_period: int = 10,
+ gpu_per_trial: int = None,
+ max_samples: int = 10,
+ **train_args):
+ """
+ Runs hyperparameter tuning using Ray Tune.
+
+ Args:
+ model (YOLO): Model to run the tuner on.
+ space (dict, optional): The hyperparameter search space. Defaults to None.
+ grace_period (int, optional): The grace period in epochs of the ASHA scheduler. Defaults to 10.
+ gpu_per_trial (int, optional): The number of GPUs to allocate per trial. Defaults to None.
+ max_samples (int, optional): The maximum number of trials to run. Defaults to 10.
+ train_args (dict, optional): Additional arguments to pass to the `train()` method. Defaults to {}.
+
+ Returns:
+ (dict): A dictionary containing the results of the hyperparameter search.
+
+ Raises:
+ ModuleNotFoundError: If Ray Tune is not installed.
+ """
+ if train_args is None:
+ train_args = {}
+
+ try:
+ from ray import tune
+ from ray.air import RunConfig
+ from ray.air.integrations.wandb import WandbLoggerCallback
+ from ray.tune.schedulers import ASHAScheduler
+ except ImportError:
+ raise ModuleNotFoundError("Tuning hyperparameters requires Ray Tune. Install with: pip install 'ray[tune]'")
+
+ try:
+ import wandb
+
+ assert hasattr(wandb, '__version__')
+ except (ImportError, AssertionError):
+ wandb = False
+
+ default_space = {
+ # 'optimizer': tune.choice(['SGD', 'Adam', 'AdamW', 'NAdam', 'RAdam', 'RMSProp']),
+ 'lr0': tune.uniform(1e-5, 1e-1),
+ 'lrf': tune.uniform(0.01, 1.0), # final OneCycleLR learning rate (lr0 * lrf)
+ 'momentum': tune.uniform(0.6, 0.98), # SGD momentum/Adam beta1
+ 'weight_decay': tune.uniform(0.0, 0.001), # optimizer weight decay 5e-4
+ 'warmup_epochs': tune.uniform(0.0, 5.0), # warmup epochs (fractions ok)
+ 'warmup_momentum': tune.uniform(0.0, 0.95), # warmup initial momentum
+ 'box': tune.uniform(0.02, 0.2), # box loss gain
+ 'cls': tune.uniform(0.2, 4.0), # cls loss gain (scale with pixels)
+ 'hsv_h': tune.uniform(0.0, 0.1), # image HSV-Hue augmentation (fraction)
+ 'hsv_s': tune.uniform(0.0, 0.9), # image HSV-Saturation augmentation (fraction)
+ 'hsv_v': tune.uniform(0.0, 0.9), # image HSV-Value augmentation (fraction)
+ 'degrees': tune.uniform(0.0, 45.0), # image rotation (+/- deg)
+ 'translate': tune.uniform(0.0, 0.9), # image translation (+/- fraction)
+ 'scale': tune.uniform(0.0, 0.9), # image scale (+/- gain)
+ 'shear': tune.uniform(0.0, 10.0), # image shear (+/- deg)
+ 'perspective': tune.uniform(0.0, 0.001), # image perspective (+/- fraction), range 0-0.001
+ 'flipud': tune.uniform(0.0, 1.0), # image flip up-down (probability)
+ 'fliplr': tune.uniform(0.0, 1.0), # image flip left-right (probability)
+ 'mosaic': tune.uniform(0.0, 1.0), # image mixup (probability)
+ 'mixup': tune.uniform(0.0, 1.0), # image mixup (probability)
+ 'copy_paste': tune.uniform(0.0, 1.0)} # segment copy-paste (probability)
+
+ def _tune(config):
+ """
+ Trains the YOLO model with the specified hyperparameters and additional arguments.
+
+ Args:
+ config (dict): A dictionary of hyperparameters to use for training.
+
+ Returns:
+ None.
+ """
+ model._reset_callbacks()
+ config.update(train_args)
+ model.train(**config)
+
+ # Get search space
+ if not space:
+ space = default_space
+ LOGGER.warning('WARNING ⚠️ search space not provided, using default search space.')
+
+ # Get dataset
+ data = train_args.get('data', TASK2DATA[model.task])
+ space['data'] = data
+ if 'data' not in train_args:
+ LOGGER.warning(f'WARNING ⚠️ data not provided, using default "data={data}".')
+
+ # Define the trainable function with allocated resources
+ trainable_with_resources = tune.with_resources(_tune, {'cpu': NUM_THREADS, 'gpu': gpu_per_trial or 0})
+
+ # Define the ASHA scheduler for hyperparameter search
+ asha_scheduler = ASHAScheduler(time_attr='epoch',
+ metric=TASK2METRIC[model.task],
+ mode='max',
+ max_t=train_args.get('epochs') or DEFAULT_CFG_DICT['epochs'] or 100,
+ grace_period=grace_period,
+ reduction_factor=3)
+
+ # Define the callbacks for the hyperparameter search
+ tuner_callbacks = [WandbLoggerCallback(project='YOLOv8-tune')] if wandb else []
+
+ # Create the Ray Tune hyperparameter search tuner
+ tuner = tune.Tuner(trainable_with_resources,
+ param_space=space,
+ tune_config=tune.TuneConfig(scheduler=asha_scheduler, num_samples=max_samples),
+ run_config=RunConfig(callbacks=tuner_callbacks, storage_path='./runs/tune'))
+
+ # Run the hyperparameter search
+ tuner.fit()
+
+ # Return the results of the hyperparameter search
+ return tuner.get_results()