From 9a2f67b3b46c961aee10e35dd8e3ab3b3a01064e Mon Sep 17 00:00:00 2001 From: Ayush Chaurasia Date: Mon, 9 Jan 2023 04:57:14 +0530 Subject: [PATCH] Log lr for param groups (#159) --- ultralytics/yolo/engine/trainer.py | 5 +++-- ultralytics/yolo/utils/callbacks/wb.py | 1 + 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/ultralytics/yolo/engine/trainer.py b/ultralytics/yolo/engine/trainer.py index 04d101a..9f89804 100644 --- a/ultralytics/yolo/engine/trainer.py +++ b/ultralytics/yolo/engine/trainer.py @@ -317,7 +317,8 @@ class BaseTrainer: self.run_callbacks("on_train_batch_end") - lr = {f"lr{ir}": x['lr'] for ir, x in enumerate(self.optimizer.param_groups)} # for loggers + self.lr = {f"lr/pg{ir}": x['lr'] for ir, x in enumerate(self.optimizer.param_groups)} # for loggers + self.scheduler.step() self.run_callbacks("on_train_epoch_end") @@ -328,7 +329,7 @@ class BaseTrainer: final_epoch = (epoch + 1 == self.epochs) if self.args.val or final_epoch: self.metrics, self.fitness = self.validate() - self.save_metrics(metrics={**self.label_loss_items(self.tloss), **self.metrics, **lr}) + self.save_metrics(metrics={**self.label_loss_items(self.tloss), **self.metrics, **self.lr}) # Save model if self.args.save or (epoch + 1 == self.epochs): diff --git a/ultralytics/yolo/utils/callbacks/wb.py b/ultralytics/yolo/utils/callbacks/wb.py index e3d6d21..5ebea51 100644 --- a/ultralytics/yolo/utils/callbacks/wb.py +++ b/ultralytics/yolo/utils/callbacks/wb.py @@ -25,6 +25,7 @@ def on_fit_epoch_end(trainer): def on_train_epoch_end(trainer): wandb.run.log(trainer.label_loss_items(trainer.tloss, prefix="train"), step=trainer.epoch + 1) + wandb.run.log(trainer.lr, step=trainer.epoch + 1) if trainer.epoch == 1: wandb.run.log({f.stem: wandb.Image(str(f)) for f in trainer.save_dir.glob('train_batch*.jpg')},