diff --git a/requirements.txt b/requirements.txt index 5eb68d4..c062769 100644 --- a/requirements.txt +++ b/requirements.txt @@ -14,6 +14,7 @@ tqdm>=4.64.0 # Logging ------------------------------------- # tensorboard>=2.13.0 +# dvclive>=2.11.0 # clearml # comet diff --git a/ultralytics/yolo/utils/callbacks/base.py b/ultralytics/yolo/utils/callbacks/base.py index 638f6af..0b17347 100644 --- a/ultralytics/yolo/utils/callbacks/base.py +++ b/ultralytics/yolo/utils/callbacks/base.py @@ -198,6 +198,7 @@ def add_integration_callbacks(instance): """ from .clearml import callbacks as clearml_cb from .comet import callbacks as comet_cb + from .dvc import callbacks as dvc_cb from .hub import callbacks as hub_cb from .mlflow import callbacks as mlflow_cb from .neptune import callbacks as neptune_cb @@ -205,7 +206,7 @@ def add_integration_callbacks(instance): from .tensorboard import callbacks as tensorboard_cb from .wb import callbacks as wb_cb - for x in clearml_cb, comet_cb, hub_cb, mlflow_cb, neptune_cb, tune_cb, tensorboard_cb, wb_cb: + for x in clearml_cb, comet_cb, hub_cb, mlflow_cb, neptune_cb, tune_cb, tensorboard_cb, wb_cb, dvc_cb: for k, v in x.items(): if v not in instance.callbacks[k]: # prevent duplicate callbacks addition instance.callbacks[k].append(v) # callback[name].append(func) diff --git a/ultralytics/yolo/utils/callbacks/dvc.py b/ultralytics/yolo/utils/callbacks/dvc.py new file mode 100644 index 0000000..61489aa --- /dev/null +++ b/ultralytics/yolo/utils/callbacks/dvc.py @@ -0,0 +1,135 @@ +# Ultralytics YOLO 🚀, GPL-3.0 license +import os + +from ultralytics.yolo.utils import LOGGER, TESTS_RUNNING +from ultralytics.yolo.utils.torch_utils import get_flops, get_num_params + +try: + from importlib.metadata import version + + import dvclive + + assert not TESTS_RUNNING # do not log pytest + assert version('dvclive') +except (ImportError, AssertionError): + dvclive = None + +# DVCLive logger instance +live = None +_processed_plots = {} + +# `on_fit_epoch_end` is called on final validation (probably need to be fixed) +# for now this is the way we distinguish final evaluation of the best model vs +# last epoch validation +_training_epoch = False + + +def _logger_disabled(): + return os.getenv('ULTRALYTICS_DVC_DISABLED', 'false').lower() == 'true' + + +def _log_images(image_path, prefix=''): + if live: + live.log_image(os.path.join(prefix, image_path.name), image_path) + + +def _log_plots(plots, prefix=''): + for name, params in plots.items(): + timestamp = params['timestamp'] + if _processed_plots.get(name, None) != timestamp: + _log_images(name, prefix) + _processed_plots[name] = timestamp + + +def _log_confusion_matrix(validator): + targets = [] + preds = [] + matrix = validator.confusion_matrix.matrix + names = list(validator.names.values()) + if validator.confusion_matrix.task == 'detect': + names += ['background'] + + for ti, pred in enumerate(matrix.T.astype(int)): + for pi, num in enumerate(pred): + targets.extend([names[ti]] * num) + preds.extend([names[pi]] * num) + + live.log_sklearn_plot('confusion_matrix', targets, preds, name='cf.json', normalized=True) + + +def on_pretrain_routine_start(trainer): + try: + global live + if not _logger_disabled(): + live = dvclive.Live(save_dvc_exp=True) + LOGGER.info( + 'DVCLive is detected and auto logging is enabled (can be disabled with `ULTRALYTICS_DVC_DISABLED=true`).' + ) + else: + LOGGER.debug('DVCLive is detected and auto logging is disabled via `ULTRALYTICS_DVC_DISABLED`.') + live = None + except Exception as e: + LOGGER.warning(f'WARNING ⚠️ DVCLive installed but not initialized correctly, not logging this run. {e}') + + +def on_pretrain_routine_end(trainer): + _log_plots(trainer.plots, 'train') + + +def on_train_start(trainer): + if live: + live.log_params(trainer.args) + + +def on_train_epoch_start(trainer): + global _training_epoch + _training_epoch = True + + +def on_fit_epoch_end(trainer): + global _training_epoch + if live and _training_epoch: + all_metrics = {**trainer.label_loss_items(trainer.tloss, prefix='train'), **trainer.metrics, **trainer.lr} + for metric, value in all_metrics.items(): + live.log_metric(metric, value) + + if trainer.epoch == 0: + model_info = { + 'model/parameters': get_num_params(trainer.model), + 'model/GFLOPs': round(get_flops(trainer.model), 3), + 'model/speed(ms)': round(trainer.validator.speed['inference'], 3)} + + for metric, value in model_info.items(): + live.log_metric(metric, value, plot=False) + + _log_plots(trainer.plots, 'train') + _log_plots(trainer.validator.plots, 'val') + + live.next_step() + _training_epoch = False + + +def on_train_end(trainer): + if live: + # At the end log the best metrics. It runs validator on the best model internally. + all_metrics = {**trainer.label_loss_items(trainer.tloss, prefix='train'), **trainer.metrics, **trainer.lr} + for metric, value in all_metrics.items(): + live.log_metric(metric, value, plot=False) + + _log_plots(trainer.plots, 'eval') + _log_plots(trainer.validator.plots, 'eval') + _log_confusion_matrix(trainer.validator) + + if trainer.best.exists(): + live.log_artifact(trainer.best, copy=True) + + live.end() + + +callbacks = { + 'on_pretrain_routine_start': on_pretrain_routine_start, + 'on_pretrain_routine_end': on_pretrain_routine_end, + 'on_train_start': on_train_start, + 'on_train_epoch_start': on_train_epoch_start, + 'on_fit_epoch_end': on_fit_epoch_end, + 'on_train_end': on_train_end} if dvclive else {}