ultralytics 8.0.77 Ray[Tune] for hyperparameter optimization (#2014)

Co-authored-by: JF Chen <k-2feng@hotmail.com>
Co-authored-by: Ayush Chaurasia <ayush.chaurarsia@gmail.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
Glenn Jocher
2023-04-14 01:28:34 +02:00
committed by GitHub
parent 4916014af2
commit 5065ca36a8
12 changed files with 205 additions and 13 deletions

View File

@ -381,6 +381,89 @@ class YOLO:
self._check_is_pytorch_model()
self.model.to(device)
def tune(self,
data: str,
space: dict = None,
grace_period: int = 10,
gpu_per_trial: int = None,
max_samples: int = 10,
train_args: dict = {}):
"""
Runs hyperparameter tuning using Ray Tune.
Args:
data (str): The dataset to run the tuner on.
space (dict, optional): The hyperparameter search space. Defaults to None.
grace_period (int, optional): The grace period in epochs of the ASHA scheduler. Defaults to 10.
gpu_per_trial (int, optional): The number of GPUs to allocate per trial. Defaults to None.
max_samples (int, optional): The maximum number of trials to run. Defaults to 10.
train_args (dict, optional): Additional arguments to pass to the `train()` method. Defaults to {}.
Returns:
A dictionary containing the results of the hyperparameter search.
Raises:
ModuleNotFoundError: If Ray Tune is not installed.
"""
try:
from ultralytics.yolo.utils.tuner import (ASHAScheduler, RunConfig, WandbLoggerCallback, default_space,
task_metric_map, tune)
except ImportError:
raise ModuleNotFoundError("Install Ray Tune: `pip install 'ray[tune]'`")
try:
import wandb
from wandb import __version__ # noqa
except ImportError:
wandb = False
def _tune(config):
"""
Trains the YOLO model with the specified hyperparameters and additional arguments.
Args:
config (dict): A dictionary of hyperparameters to use for training.
Returns:
None.
"""
self._reset_callbacks()
config.update(train_args)
self.train(**config)
if not space:
LOGGER.warning('WARNING: search space not provided. Using default search space')
space = default_space
space['data'] = data
# Define the trainable function with allocated resources
trainable_with_resources = tune.with_resources(_tune, {'cpu': 8, 'gpu': gpu_per_trial if gpu_per_trial else 0})
# Define the ASHA scheduler for hyperparameter search
asha_scheduler = ASHAScheduler(time_attr='epoch',
metric=task_metric_map[self.task],
mode='max',
max_t=train_args.get('epochs') or 100,
grace_period=grace_period,
reduction_factor=3)
# Define the callbacks for the hyperparameter search
tuner_callbacks = [WandbLoggerCallback(project='yolov8_tune') if wandb else None]
# Create the Ray Tune hyperparameter search tuner
tuner = tune.Tuner(trainable_with_resources,
param_space=space,
tune_config=tune.TuneConfig(scheduler=asha_scheduler, num_samples=max_samples),
run_config=RunConfig(callbacks=tuner_callbacks, local_dir='./runs'))
# Run the hyperparameter search
tuner.fit()
# Return the results of the hyperparameter search
return tuner.get_results()
@property
def names(self):
"""

View File

@ -113,7 +113,7 @@ class BaseTrainer:
if self.device.type == 'cpu':
self.args.workers = 0 # faster CPU training as time dominated by inference, not dataloading
# Model and Dataloaders.
# Model and Dataset
self.model = self.args.model
try:
if self.args.task == 'classify':
@ -243,7 +243,7 @@ class BaseTrainer:
self.scheduler = lr_scheduler.LambdaLR(self.optimizer, lr_lambda=self.lf)
self.stopper, self.stop = EarlyStopping(patience=self.args.patience), False
# dataloaders
# Dataloaders
batch_size = self.batch_size // world_size if world_size > 1 else self.batch_size
self.train_loader = self.get_dataloader(self.trainset, batch_size=batch_size, rank=RANK, mode='train')
if RANK in (-1, 0):