From 5065ca36a8499b43dd4b938ba73ddb98bc6e278f Mon Sep 17 00:00:00 2001 From: Glenn Jocher Date: Fri, 14 Apr 2023 01:28:34 +0200 Subject: [PATCH] `ultralytics 8.0.77` Ray[Tune] for hyperparameter optimization (#2014) Co-authored-by: JF Chen Co-authored-by: Ayush Chaurasia Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- .github/workflows/ci.yaml | 6 +- docker/Dockerfile | 2 +- docker/Dockerfile-cpu | 2 +- setup.py | 5 +- ultralytics/__init__.py | 2 +- ultralytics/yolo/engine/model.py | 83 +++++++++++++++++++++ ultralytics/yolo/engine/trainer.py | 4 +- ultralytics/yolo/utils/callbacks/base.py | 4 +- ultralytics/yolo/utils/callbacks/raytune.py | 17 +++++ ultralytics/yolo/utils/callbacks/wb.py | 48 ++++++++++++ ultralytics/yolo/utils/tuner.py | 43 +++++++++++ ultralytics/yolo/v8/detect/train.py | 2 +- 12 files changed, 205 insertions(+), 13 deletions(-) create mode 100644 ultralytics/yolo/utils/callbacks/raytune.py create mode 100644 ultralytics/yolo/utils/callbacks/wb.py create mode 100644 ultralytics/yolo/utils/tuner.py diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index 723886d..d108357 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -7,7 +7,7 @@ on: push: branches: [main] pull_request: - branches: [main, updates] + branches: [main] schedule: - cron: '0 0 * * *' # runs at 00:00 UTC every day @@ -76,9 +76,9 @@ jobs: run: | python -m pip install --upgrade pip wheel if [ "${{ matrix.os }}" == "macos-latest" ]; then - pip install -e '.[export-macos]' --extra-index-url https://download.pytorch.org/whl/cpu + pip install -e '.[export]' --extra-index-url https://download.pytorch.org/whl/cpu else - pip install -e '.[export-cpu]' --extra-index-url https://download.pytorch.org/whl/cpu + pip install -e '.[export]' --extra-index-url https://download.pytorch.org/whl/cpu fi yolo export format=tflite imgsz=32 - name: Check environment diff --git a/docker/Dockerfile b/docker/Dockerfile index c79518a..f84daee 100644 --- a/docker/Dockerfile +++ b/docker/Dockerfile @@ -30,7 +30,7 @@ ADD https://github.com/ultralytics/assets/releases/download/v0.0.0/yolov8n.pt /u # Install pip packages RUN python3 -m pip install --upgrade pip wheel -RUN pip install --no-cache . albumentations comet gsutil notebook +RUN pip install --no-cache . albumentations comet gsutil notebook tensorboard # Set environment variables ENV OMP_NUM_THREADS=1 diff --git a/docker/Dockerfile-cpu b/docker/Dockerfile-cpu index a9c3d79..f0d16e5 100644 --- a/docker/Dockerfile-cpu +++ b/docker/Dockerfile-cpu @@ -26,7 +26,7 @@ ADD https://github.com/ultralytics/assets/releases/download/v0.0.0/yolov8n.pt /u # Install pip packages RUN python3 -m pip install --upgrade pip wheel -RUN pip install --no-cache . albumentations gsutil notebook \ +RUN pip install --no-cache . albumentations gsutil notebook tensorboard \ --extra-index-url https://download.pytorch.org/whl/cpu # Cleanup diff --git a/setup.py b/setup.py index a0c44c3..e9dae87 100644 --- a/setup.py +++ b/setup.py @@ -39,9 +39,8 @@ setup( install_requires=REQUIREMENTS + PKG_REQUIREMENTS, extras_require={ 'dev': ['check-manifest', 'pytest', 'pytest-cov', 'coverage', 'mkdocs-material', 'mkdocstrings[python]'], - 'export': ['coremltools>=6.0', 'openvino-dev>=2022.3', 'tensorflow', 'tensorflowjs'], - 'export-cpu': ['coremltools>=6.0', 'openvino-dev>=2022.3', 'tensorflow-cpu', 'tensorflowjs'], - 'export-macos': ['coremltools>=6.0', 'openvino-dev>=2022.3', 'tensorflow-macos', 'tensorflowjs']}, + 'export': ['coremltools>=6.0', 'openvino-dev>=2022.3', 'tensorflowjs'], # automatically installs tensorflow + }, classifiers=[ 'Development Status :: 4 - Beta', 'Intended Audience :: Developers', diff --git a/ultralytics/__init__.py b/ultralytics/__init__.py index cf94ef8..e0c8e20 100644 --- a/ultralytics/__init__.py +++ b/ultralytics/__init__.py @@ -1,6 +1,6 @@ # Ultralytics YOLO 🚀, GPL-3.0 license -__version__ = '8.0.76' +__version__ = '8.0.77' from ultralytics.hub import start from ultralytics.yolo.engine.model import YOLO diff --git a/ultralytics/yolo/engine/model.py b/ultralytics/yolo/engine/model.py index 6361722..f7f99f3 100644 --- a/ultralytics/yolo/engine/model.py +++ b/ultralytics/yolo/engine/model.py @@ -381,6 +381,89 @@ class YOLO: self._check_is_pytorch_model() self.model.to(device) + def tune(self, + data: str, + space: dict = None, + grace_period: int = 10, + gpu_per_trial: int = None, + max_samples: int = 10, + train_args: dict = {}): + """ + Runs hyperparameter tuning using Ray Tune. + + Args: + data (str): The dataset to run the tuner on. + space (dict, optional): The hyperparameter search space. Defaults to None. + grace_period (int, optional): The grace period in epochs of the ASHA scheduler. Defaults to 10. + gpu_per_trial (int, optional): The number of GPUs to allocate per trial. Defaults to None. + max_samples (int, optional): The maximum number of trials to run. Defaults to 10. + train_args (dict, optional): Additional arguments to pass to the `train()` method. Defaults to {}. + + Returns: + A dictionary containing the results of the hyperparameter search. + + Raises: + ModuleNotFoundError: If Ray Tune is not installed. + """ + + try: + from ultralytics.yolo.utils.tuner import (ASHAScheduler, RunConfig, WandbLoggerCallback, default_space, + task_metric_map, tune) + except ImportError: + raise ModuleNotFoundError("Install Ray Tune: `pip install 'ray[tune]'`") + + try: + import wandb + from wandb import __version__ # noqa + except ImportError: + wandb = False + + def _tune(config): + """ + Trains the YOLO model with the specified hyperparameters and additional arguments. + + Args: + config (dict): A dictionary of hyperparameters to use for training. + + Returns: + None. + """ + self._reset_callbacks() + config.update(train_args) + self.train(**config) + + if not space: + LOGGER.warning('WARNING: search space not provided. Using default search space') + space = default_space + + space['data'] = data + + # Define the trainable function with allocated resources + trainable_with_resources = tune.with_resources(_tune, {'cpu': 8, 'gpu': gpu_per_trial if gpu_per_trial else 0}) + + # Define the ASHA scheduler for hyperparameter search + asha_scheduler = ASHAScheduler(time_attr='epoch', + metric=task_metric_map[self.task], + mode='max', + max_t=train_args.get('epochs') or 100, + grace_period=grace_period, + reduction_factor=3) + + # Define the callbacks for the hyperparameter search + tuner_callbacks = [WandbLoggerCallback(project='yolov8_tune') if wandb else None] + + # Create the Ray Tune hyperparameter search tuner + tuner = tune.Tuner(trainable_with_resources, + param_space=space, + tune_config=tune.TuneConfig(scheduler=asha_scheduler, num_samples=max_samples), + run_config=RunConfig(callbacks=tuner_callbacks, local_dir='./runs')) + + # Run the hyperparameter search + tuner.fit() + + # Return the results of the hyperparameter search + return tuner.get_results() + @property def names(self): """ diff --git a/ultralytics/yolo/engine/trainer.py b/ultralytics/yolo/engine/trainer.py index b920465..2bd282f 100644 --- a/ultralytics/yolo/engine/trainer.py +++ b/ultralytics/yolo/engine/trainer.py @@ -113,7 +113,7 @@ class BaseTrainer: if self.device.type == 'cpu': self.args.workers = 0 # faster CPU training as time dominated by inference, not dataloading - # Model and Dataloaders. + # Model and Dataset self.model = self.args.model try: if self.args.task == 'classify': @@ -243,7 +243,7 @@ class BaseTrainer: self.scheduler = lr_scheduler.LambdaLR(self.optimizer, lr_lambda=self.lf) self.stopper, self.stop = EarlyStopping(patience=self.args.patience), False - # dataloaders + # Dataloaders batch_size = self.batch_size // world_size if world_size > 1 else self.batch_size self.train_loader = self.get_dataloader(self.trainset, batch_size=batch_size, rank=RANK, mode='train') if RANK in (-1, 0): diff --git a/ultralytics/yolo/utils/callbacks/base.py b/ultralytics/yolo/utils/callbacks/base.py index 0779594..e372c61 100644 --- a/ultralytics/yolo/utils/callbacks/base.py +++ b/ultralytics/yolo/utils/callbacks/base.py @@ -154,9 +154,11 @@ def add_integration_callbacks(instance): from .comet import callbacks as comet_callbacks from .hub import callbacks as hub_callbacks from .mlflow import callbacks as mf_callbacks + from .raytune import callbacks as tune_callbacks from .tensorboard import callbacks as tb_callbacks + from .wb import callbacks as wb_callbacks - for x in clearml_callbacks, comet_callbacks, hub_callbacks, tb_callbacks, mf_callbacks: + for x in clearml_callbacks, comet_callbacks, hub_callbacks, tb_callbacks, mf_callbacks, tune_callbacks, wb_callbacks: for k, v in x.items(): if v not in instance.callbacks[k]: # prevent duplicate callbacks addition instance.callbacks[k].append(v) # callback[name].append(func) diff --git a/ultralytics/yolo/utils/callbacks/raytune.py b/ultralytics/yolo/utils/callbacks/raytune.py new file mode 100644 index 0000000..a57b4f4 --- /dev/null +++ b/ultralytics/yolo/utils/callbacks/raytune.py @@ -0,0 +1,17 @@ +try: + import ray + from ray import tune + from ray.air import session +except (ImportError, AssertionError): + tune = None + + +def on_fit_epoch_end(trainer): + if ray.tune.is_session_enabled(): + metrics = trainer.metrics + metrics['epoch'] = trainer.epoch + session.report(metrics) + + +callbacks = { + 'on_fit_epoch_end': on_fit_epoch_end, } if tune else {} diff --git a/ultralytics/yolo/utils/callbacks/wb.py b/ultralytics/yolo/utils/callbacks/wb.py new file mode 100644 index 0000000..7e0a087 --- /dev/null +++ b/ultralytics/yolo/utils/callbacks/wb.py @@ -0,0 +1,48 @@ +# Ultralytics YOLO 🚀, GPL-3.0 license + +from ultralytics.yolo.utils.torch_utils import get_flops, get_num_params + +try: + import wandb as wb + + assert hasattr(wb, '__version__') +except (ImportError, AssertionError): + wb = None + + +def on_pretrain_routine_start(trainer): + wb.init(project=trainer.args.project or 'YOLOv8', name=trainer.args.name, config=vars( + trainer.args)) if not wb.run else wb.run + + +def on_fit_epoch_end(trainer): + wb.run.log(trainer.metrics, step=trainer.epoch + 1) + if trainer.epoch == 0: + model_info = { + 'model/parameters': get_num_params(trainer.model), + 'model/GFLOPs': round(get_flops(trainer.model), 3), + 'model/speed(ms)': round(trainer.validator.speed['inference'], 3)} + wb.run.log(model_info, step=trainer.epoch + 1) + + +def on_train_epoch_end(trainer): + wb.run.log(trainer.label_loss_items(trainer.tloss, prefix='train'), step=trainer.epoch + 1) + wb.run.log(trainer.lr, step=trainer.epoch + 1) + if trainer.epoch == 1: + wb.run.log({f.stem: wb.Image(str(f)) + for f in trainer.save_dir.glob('train_batch*.jpg')}, + step=trainer.epoch + 1) + + +def on_train_end(trainer): + art = wb.Artifact(type='model', name=f'run_{wb.run.id}_model') + if trainer.best.exists(): + art.add_file(trainer.best) + wb.run.log_artifact(art) + + +callbacks = { + 'on_pretrain_routine_start': on_pretrain_routine_start, + 'on_train_epoch_end': on_train_epoch_end, + 'on_fit_epoch_end': on_fit_epoch_end, + 'on_train_end': on_train_end} if wb else {} diff --git a/ultralytics/yolo/utils/tuner.py b/ultralytics/yolo/utils/tuner.py new file mode 100644 index 0000000..3c69d1d --- /dev/null +++ b/ultralytics/yolo/utils/tuner.py @@ -0,0 +1,43 @@ +from ultralytics.yolo.utils import LOGGER + +try: + from ray import tune + from ray.air import RunConfig, session # noqa + from ray.air.integrations.wandb import WandbLoggerCallback # noqa + from ray.tune.schedulers import ASHAScheduler # noqa + from ray.tune.schedulers import AsyncHyperBandScheduler as AHB # noqa + +except ImportError: + LOGGER.info("Tuning hyperparameters requires ray/tune. Install using `pip install 'ray[tune]'`") + tune = None + +default_space = { + # 'optimizer': tune.choice(['SGD', 'Adam', 'AdamW', 'RMSProp']), + 'lr0': tune.uniform(1e-5, 1e-1), + 'lrf': tune.uniform(0.01, 1.0), # final OneCycleLR learning rate (lr0 * lrf) + 'momentum': tune.uniform(0.6, 0.98), # SGD momentum/Adam beta1 + 'weight_decay': tune.uniform(0.0, 0.001), # optimizer weight decay 5e-4 + 'warmup_epochs': tune.uniform(0.0, 5.0), # warmup epochs (fractions ok) + 'warmup_momentum': tune.uniform(0.0, 0.95), # warmup initial momentum + 'box': tune.uniform(0.02, 0.2), # box loss gain + 'cls': tune.uniform(0.2, 4.0), # cls loss gain (scale with pixels) + 'fl_gamma': tune.uniform(0.0, 2.0), # focal loss gamma (efficientDet default gamma=1.5) + 'hsv_h': tune.uniform(0.0, 0.1), # image HSV-Hue augmentation (fraction) + 'hsv_s': tune.uniform(0.0, 0.9), # image HSV-Saturation augmentation (fraction) + 'hsv_v': tune.uniform(0.0, 0.9), # image HSV-Value augmentation (fraction) + 'degrees': tune.uniform(0.0, 45.0), # image rotation (+/- deg) + 'translate': tune.uniform(0.0, 0.9), # image translation (+/- fraction) + 'scale': tune.uniform(0.0, 0.9), # image scale (+/- gain) + 'shear': tune.uniform(0.0, 10.0), # image shear (+/- deg) + 'perspective': tune.uniform(0.0, 0.001), # image perspective (+/- fraction), range 0-0.001 + 'flipud': tune.uniform(0.0, 1.0), # image flip up-down (probability) + 'fliplr': tune.uniform(0.0, 1.0), # image flip left-right (probability) + 'mosaic': tune.uniform(0.0, 1.0), # image mixup (probability) + 'mixup': tune.uniform(0.0, 1.0), # image mixup (probability) + 'copy_paste': tune.uniform(0.0, 1.0)} # segment copy-paste (probability) + +task_metric_map = { + 'detect': 'metrics/mAP50-95(B)', + 'segment': 'metrics/mAP50-95(M)', + 'classify': 'top1_acc', + 'pose': None} diff --git a/ultralytics/yolo/v8/detect/train.py b/ultralytics/yolo/v8/detect/train.py index 33d463d..5a8a4ca 100644 --- a/ultralytics/yolo/v8/detect/train.py +++ b/ultralytics/yolo/v8/detect/train.py @@ -21,7 +21,7 @@ from ultralytics.yolo.utils.torch_utils import de_parallel # BaseTrainer python usage class DetectionTrainer(BaseTrainer): - def get_dataloader(self, dataset_path, batch_size, mode='train', rank=0): + def get_dataloader(self, dataset_path, batch_size, rank=0, mode='train'): # TODO: manage splits differently # calculate stride - check if model is initialized gs = max(int(de_parallel(self.model).stride.max() if self.model else 0), 32)