From 6c65934b555e64bf26edd699865754b5ff651d0c Mon Sep 17 00:00:00 2001 From: Ayush Chaurasia Date: Thu, 1 Jun 2023 04:03:06 +0530 Subject: [PATCH] W&B updates (#2895) Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- ultralytics/yolo/utils/callbacks/wb.py | 19 +++++++++++++++---- ultralytics/yolo/utils/metrics.py | 10 ++++++++-- 2 files changed, 23 insertions(+), 6 deletions(-) diff --git a/ultralytics/yolo/utils/callbacks/wb.py b/ultralytics/yolo/utils/callbacks/wb.py index 2b3d40d..827f797 100644 --- a/ultralytics/yolo/utils/callbacks/wb.py +++ b/ultralytics/yolo/utils/callbacks/wb.py @@ -1,5 +1,4 @@ # Ultralytics YOLO 🚀, AGPL-3.0 license - from ultralytics.yolo.utils import TESTS_RUNNING from ultralytics.yolo.utils.torch_utils import model_info_for_loggers @@ -11,6 +10,16 @@ try: except (ImportError, AssertionError): wb = None +_processed_plots = {} + + +def _log_plots(plots, step): + for name, params in plots.items(): + timestamp = params['timestamp'] + if _processed_plots.get(name, None) != timestamp: + wb.run.log({name.stem: wb.Image(str(name))}, step=step) + _processed_plots[name] = timestamp + def on_pretrain_routine_start(trainer): """Initiate and start project if module is present.""" @@ -20,6 +29,8 @@ def on_pretrain_routine_start(trainer): def on_fit_epoch_end(trainer): """Logs training metrics and model information at the end of an epoch.""" wb.run.log(trainer.metrics, step=trainer.epoch + 1) + _log_plots(trainer.plots, step=trainer.epoch + 1) + _log_plots(trainer.validator.plots, step=trainer.epoch + 1) if trainer.epoch == 0: wb.run.log(model_info_for_loggers(trainer), step=trainer.epoch + 1) @@ -29,13 +40,13 @@ def on_train_epoch_end(trainer): wb.run.log(trainer.label_loss_items(trainer.tloss, prefix='train'), step=trainer.epoch + 1) wb.run.log(trainer.lr, step=trainer.epoch + 1) if trainer.epoch == 1: - wb.run.log({f.stem: wb.Image(str(f)) - for f in trainer.save_dir.glob('train_batch*.jpg')}, - step=trainer.epoch + 1) + _log_plots(trainer.plots, step=trainer.epoch + 1) def on_train_end(trainer): """Save the best model as an artifact at end of training.""" + _log_plots(trainer.validator.plots, step=trainer.epoch + 1) + _log_plots(trainer.plots, step=trainer.epoch + 1) art = wb.Artifact(type='model', name=f'run_{wb.run.id}_model') if trainer.best.exists(): art.add_file(trainer.best) diff --git a/ultralytics/yolo/utils/metrics.py b/ultralytics/yolo/utils/metrics.py index 8cb4580..30899e4 100644 --- a/ultralytics/yolo/utils/metrics.py +++ b/ultralytics/yolo/utils/metrics.py @@ -707,8 +707,14 @@ class DetMetrics(SimpleClass): def process(self, tp, conf, pred_cls, target_cls): """Process predicted results for object detection and update metrics.""" - results = ap_per_class(tp, conf, pred_cls, target_cls, plot=self.plot, save_dir=self.save_dir, - names=self.names)[2:] + results = ap_per_class(tp, + conf, + pred_cls, + target_cls, + plot=self.plot, + save_dir=self.save_dir, + names=self.names, + on_plot=self.on_plot)[2:] self.box.nc = len(self.names) self.box.update(results)