`ultralytics 8.0.77` Ray[Tune] for hyperparameter optimization (#2014)

Co-authored-by: JF Chen <k-2feng@hotmail.com>
Co-authored-by: Ayush Chaurasia <ayush.chaurarsia@gmail.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
single_channel
Glenn Jocher 2 years ago committed by GitHub
parent 4916014af2
commit 5065ca36a8
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -7,7 +7,7 @@ on:
push:
branches: [main]
pull_request:
branches: [main, updates]
branches: [main]
schedule:
- cron: '0 0 * * *' # runs at 00:00 UTC every day
@ -76,9 +76,9 @@ jobs:
run: |
python -m pip install --upgrade pip wheel
if [ "${{ matrix.os }}" == "macos-latest" ]; then
pip install -e '.[export-macos]' --extra-index-url https://download.pytorch.org/whl/cpu
pip install -e '.[export]' --extra-index-url https://download.pytorch.org/whl/cpu
else
pip install -e '.[export-cpu]' --extra-index-url https://download.pytorch.org/whl/cpu
pip install -e '.[export]' --extra-index-url https://download.pytorch.org/whl/cpu
fi
yolo export format=tflite imgsz=32
- name: Check environment

@ -30,7 +30,7 @@ ADD https://github.com/ultralytics/assets/releases/download/v0.0.0/yolov8n.pt /u
# Install pip packages
RUN python3 -m pip install --upgrade pip wheel
RUN pip install --no-cache . albumentations comet gsutil notebook
RUN pip install --no-cache . albumentations comet gsutil notebook tensorboard
# Set environment variables
ENV OMP_NUM_THREADS=1

@ -26,7 +26,7 @@ ADD https://github.com/ultralytics/assets/releases/download/v0.0.0/yolov8n.pt /u
# Install pip packages
RUN python3 -m pip install --upgrade pip wheel
RUN pip install --no-cache . albumentations gsutil notebook \
RUN pip install --no-cache . albumentations gsutil notebook tensorboard \
--extra-index-url https://download.pytorch.org/whl/cpu
# Cleanup

@ -39,9 +39,8 @@ setup(
install_requires=REQUIREMENTS + PKG_REQUIREMENTS,
extras_require={
'dev': ['check-manifest', 'pytest', 'pytest-cov', 'coverage', 'mkdocs-material', 'mkdocstrings[python]'],
'export': ['coremltools>=6.0', 'openvino-dev>=2022.3', 'tensorflow', 'tensorflowjs'],
'export-cpu': ['coremltools>=6.0', 'openvino-dev>=2022.3', 'tensorflow-cpu', 'tensorflowjs'],
'export-macos': ['coremltools>=6.0', 'openvino-dev>=2022.3', 'tensorflow-macos', 'tensorflowjs']},
'export': ['coremltools>=6.0', 'openvino-dev>=2022.3', 'tensorflowjs'], # automatically installs tensorflow
},
classifiers=[
'Development Status :: 4 - Beta',
'Intended Audience :: Developers',

@ -1,6 +1,6 @@
# Ultralytics YOLO 🚀, GPL-3.0 license
__version__ = '8.0.76'
__version__ = '8.0.77'
from ultralytics.hub import start
from ultralytics.yolo.engine.model import YOLO

@ -381,6 +381,89 @@ class YOLO:
self._check_is_pytorch_model()
self.model.to(device)
def tune(self,
data: str,
space: dict = None,
grace_period: int = 10,
gpu_per_trial: int = None,
max_samples: int = 10,
train_args: dict = {}):
"""
Runs hyperparameter tuning using Ray Tune.
Args:
data (str): The dataset to run the tuner on.
space (dict, optional): The hyperparameter search space. Defaults to None.
grace_period (int, optional): The grace period in epochs of the ASHA scheduler. Defaults to 10.
gpu_per_trial (int, optional): The number of GPUs to allocate per trial. Defaults to None.
max_samples (int, optional): The maximum number of trials to run. Defaults to 10.
train_args (dict, optional): Additional arguments to pass to the `train()` method. Defaults to {}.
Returns:
A dictionary containing the results of the hyperparameter search.
Raises:
ModuleNotFoundError: If Ray Tune is not installed.
"""
try:
from ultralytics.yolo.utils.tuner import (ASHAScheduler, RunConfig, WandbLoggerCallback, default_space,
task_metric_map, tune)
except ImportError:
raise ModuleNotFoundError("Install Ray Tune: `pip install 'ray[tune]'`")
try:
import wandb
from wandb import __version__ # noqa
except ImportError:
wandb = False
def _tune(config):
"""
Trains the YOLO model with the specified hyperparameters and additional arguments.
Args:
config (dict): A dictionary of hyperparameters to use for training.
Returns:
None.
"""
self._reset_callbacks()
config.update(train_args)
self.train(**config)
if not space:
LOGGER.warning('WARNING: search space not provided. Using default search space')
space = default_space
space['data'] = data
# Define the trainable function with allocated resources
trainable_with_resources = tune.with_resources(_tune, {'cpu': 8, 'gpu': gpu_per_trial if gpu_per_trial else 0})
# Define the ASHA scheduler for hyperparameter search
asha_scheduler = ASHAScheduler(time_attr='epoch',
metric=task_metric_map[self.task],
mode='max',
max_t=train_args.get('epochs') or 100,
grace_period=grace_period,
reduction_factor=3)
# Define the callbacks for the hyperparameter search
tuner_callbacks = [WandbLoggerCallback(project='yolov8_tune') if wandb else None]
# Create the Ray Tune hyperparameter search tuner
tuner = tune.Tuner(trainable_with_resources,
param_space=space,
tune_config=tune.TuneConfig(scheduler=asha_scheduler, num_samples=max_samples),
run_config=RunConfig(callbacks=tuner_callbacks, local_dir='./runs'))
# Run the hyperparameter search
tuner.fit()
# Return the results of the hyperparameter search
return tuner.get_results()
@property
def names(self):
"""

@ -113,7 +113,7 @@ class BaseTrainer:
if self.device.type == 'cpu':
self.args.workers = 0 # faster CPU training as time dominated by inference, not dataloading
# Model and Dataloaders.
# Model and Dataset
self.model = self.args.model
try:
if self.args.task == 'classify':
@ -243,7 +243,7 @@ class BaseTrainer:
self.scheduler = lr_scheduler.LambdaLR(self.optimizer, lr_lambda=self.lf)
self.stopper, self.stop = EarlyStopping(patience=self.args.patience), False
# dataloaders
# Dataloaders
batch_size = self.batch_size // world_size if world_size > 1 else self.batch_size
self.train_loader = self.get_dataloader(self.trainset, batch_size=batch_size, rank=RANK, mode='train')
if RANK in (-1, 0):

@ -154,9 +154,11 @@ def add_integration_callbacks(instance):
from .comet import callbacks as comet_callbacks
from .hub import callbacks as hub_callbacks
from .mlflow import callbacks as mf_callbacks
from .raytune import callbacks as tune_callbacks
from .tensorboard import callbacks as tb_callbacks
from .wb import callbacks as wb_callbacks
for x in clearml_callbacks, comet_callbacks, hub_callbacks, tb_callbacks, mf_callbacks:
for x in clearml_callbacks, comet_callbacks, hub_callbacks, tb_callbacks, mf_callbacks, tune_callbacks, wb_callbacks:
for k, v in x.items():
if v not in instance.callbacks[k]: # prevent duplicate callbacks addition
instance.callbacks[k].append(v) # callback[name].append(func)

@ -0,0 +1,17 @@
try:
import ray
from ray import tune
from ray.air import session
except (ImportError, AssertionError):
tune = None
def on_fit_epoch_end(trainer):
if ray.tune.is_session_enabled():
metrics = trainer.metrics
metrics['epoch'] = trainer.epoch
session.report(metrics)
callbacks = {
'on_fit_epoch_end': on_fit_epoch_end, } if tune else {}

@ -0,0 +1,48 @@
# Ultralytics YOLO 🚀, GPL-3.0 license
from ultralytics.yolo.utils.torch_utils import get_flops, get_num_params
try:
import wandb as wb
assert hasattr(wb, '__version__')
except (ImportError, AssertionError):
wb = None
def on_pretrain_routine_start(trainer):
wb.init(project=trainer.args.project or 'YOLOv8', name=trainer.args.name, config=vars(
trainer.args)) if not wb.run else wb.run
def on_fit_epoch_end(trainer):
wb.run.log(trainer.metrics, step=trainer.epoch + 1)
if trainer.epoch == 0:
model_info = {
'model/parameters': get_num_params(trainer.model),
'model/GFLOPs': round(get_flops(trainer.model), 3),
'model/speed(ms)': round(trainer.validator.speed['inference'], 3)}
wb.run.log(model_info, step=trainer.epoch + 1)
def on_train_epoch_end(trainer):
wb.run.log(trainer.label_loss_items(trainer.tloss, prefix='train'), step=trainer.epoch + 1)
wb.run.log(trainer.lr, step=trainer.epoch + 1)
if trainer.epoch == 1:
wb.run.log({f.stem: wb.Image(str(f))
for f in trainer.save_dir.glob('train_batch*.jpg')},
step=trainer.epoch + 1)
def on_train_end(trainer):
art = wb.Artifact(type='model', name=f'run_{wb.run.id}_model')
if trainer.best.exists():
art.add_file(trainer.best)
wb.run.log_artifact(art)
callbacks = {
'on_pretrain_routine_start': on_pretrain_routine_start,
'on_train_epoch_end': on_train_epoch_end,
'on_fit_epoch_end': on_fit_epoch_end,
'on_train_end': on_train_end} if wb else {}

@ -0,0 +1,43 @@
from ultralytics.yolo.utils import LOGGER
try:
from ray import tune
from ray.air import RunConfig, session # noqa
from ray.air.integrations.wandb import WandbLoggerCallback # noqa
from ray.tune.schedulers import ASHAScheduler # noqa
from ray.tune.schedulers import AsyncHyperBandScheduler as AHB # noqa
except ImportError:
LOGGER.info("Tuning hyperparameters requires ray/tune. Install using `pip install 'ray[tune]'`")
tune = None
default_space = {
# 'optimizer': tune.choice(['SGD', 'Adam', 'AdamW', 'RMSProp']),
'lr0': tune.uniform(1e-5, 1e-1),
'lrf': tune.uniform(0.01, 1.0), # final OneCycleLR learning rate (lr0 * lrf)
'momentum': tune.uniform(0.6, 0.98), # SGD momentum/Adam beta1
'weight_decay': tune.uniform(0.0, 0.001), # optimizer weight decay 5e-4
'warmup_epochs': tune.uniform(0.0, 5.0), # warmup epochs (fractions ok)
'warmup_momentum': tune.uniform(0.0, 0.95), # warmup initial momentum
'box': tune.uniform(0.02, 0.2), # box loss gain
'cls': tune.uniform(0.2, 4.0), # cls loss gain (scale with pixels)
'fl_gamma': tune.uniform(0.0, 2.0), # focal loss gamma (efficientDet default gamma=1.5)
'hsv_h': tune.uniform(0.0, 0.1), # image HSV-Hue augmentation (fraction)
'hsv_s': tune.uniform(0.0, 0.9), # image HSV-Saturation augmentation (fraction)
'hsv_v': tune.uniform(0.0, 0.9), # image HSV-Value augmentation (fraction)
'degrees': tune.uniform(0.0, 45.0), # image rotation (+/- deg)
'translate': tune.uniform(0.0, 0.9), # image translation (+/- fraction)
'scale': tune.uniform(0.0, 0.9), # image scale (+/- gain)
'shear': tune.uniform(0.0, 10.0), # image shear (+/- deg)
'perspective': tune.uniform(0.0, 0.001), # image perspective (+/- fraction), range 0-0.001
'flipud': tune.uniform(0.0, 1.0), # image flip up-down (probability)
'fliplr': tune.uniform(0.0, 1.0), # image flip left-right (probability)
'mosaic': tune.uniform(0.0, 1.0), # image mixup (probability)
'mixup': tune.uniform(0.0, 1.0), # image mixup (probability)
'copy_paste': tune.uniform(0.0, 1.0)} # segment copy-paste (probability)
task_metric_map = {
'detect': 'metrics/mAP50-95(B)',
'segment': 'metrics/mAP50-95(M)',
'classify': 'top1_acc',
'pose': None}

@ -21,7 +21,7 @@ from ultralytics.yolo.utils.torch_utils import de_parallel
# BaseTrainer python usage
class DetectionTrainer(BaseTrainer):
def get_dataloader(self, dataset_path, batch_size, mode='train', rank=0):
def get_dataloader(self, dataset_path, batch_size, rank=0, mode='train'):
# TODO: manage splits differently
# calculate stride - check if model is initialized
gs = max(int(de_parallel(self.model).stride.max() if self.model else 0), 32)

Loading…
Cancel
Save