diff --git a/ultralytics/yolo/utils/callbacks/base.py b/ultralytics/yolo/utils/callbacks/base.py
index 671c91f..d91d255 100644
--- a/ultralytics/yolo/utils/callbacks/base.py
+++ b/ultralytics/yolo/utils/callbacks/base.py
@@ -99,7 +99,8 @@ default_callbacks = {
 def add_integration_callbacks(trainer):
     from .clearml import callbacks as clearml_callbacks
     from .tb import callbacks as tb_callbacks
+    from .wb import callbacks as wb_callbacks
 
-    for x in tb_callbacks, clearml_callbacks:
+    for x in clearml_callbacks, tb_callbacks, wb_callbacks:
         for k, v in x.items():
             trainer.add_callback(k, v)  # add_callback(name, func)
diff --git a/ultralytics/yolo/utils/callbacks/clearml.py b/ultralytics/yolo/utils/callbacks/clearml.py
index 3cfd4b1..8c01cbd 100644
--- a/ultralytics/yolo/utils/callbacks/clearml.py
+++ b/ultralytics/yolo/utils/callbacks/clearml.py
@@ -16,7 +16,7 @@ def _log_images(imgs_dict, group="", step=0):
             task.get_logger().report_image(group, k, step, v)
 
 
-def on_train_start(trainer):
+def on_pretrain_routine_start(trainer):
     # TODO: reuse existing task
     task = Task.init(project_name=trainer.args.project if trainer.args.project != 'runs/train' else 'YOLOv8',
                      task_name=trainer.args.name,
@@ -48,7 +48,7 @@ def on_train_end(trainer):
 
 
 callbacks = {
-    "on_train_start": on_train_start,
+    "on_pretrain_routine_start": on_pretrain_routine_start,
     "on_train_epoch_end": on_train_epoch_end,
     "on_val_end": on_val_end,
     "on_train_end": on_train_end} if clearml else {}
diff --git a/ultralytics/yolo/utils/callbacks/tb.py b/ultralytics/yolo/utils/callbacks/tb.py
index b442424..294112a 100644
--- a/ultralytics/yolo/utils/callbacks/tb.py
+++ b/ultralytics/yolo/utils/callbacks/tb.py
@@ -8,7 +8,7 @@ def _log_scalars(scalars, step=0):
             writer.add_scalar(k, v, step)
 
 
-def on_train_start(trainer):
+def on_pretrain_routine_start(trainer):
     global writer
     writer = SummaryWriter(str(trainer.save_dir))
 
@@ -21,4 +21,7 @@ def on_val_end(trainer):
     _log_scalars(trainer.metrics, trainer.epoch + 1)
 
 
-callbacks = {"on_train_start": on_train_start, "on_val_end": on_val_end, "on_batch_end": on_batch_end}
+callbacks = {
+    "on_pretrain_routine_start": on_pretrain_routine_start,
+    "on_val_end": on_val_end,
+    "on_batch_end": on_batch_end}
diff --git a/ultralytics/yolo/utils/callbacks/wb.py b/ultralytics/yolo/utils/callbacks/wb.py
new file mode 100644
index 0000000..d0287d1
--- /dev/null
+++ b/ultralytics/yolo/utils/callbacks/wb.py
@@ -0,0 +1,46 @@
+from ultralytics.yolo.utils.torch_utils import get_flops, get_num_params
+
+try:
+    import wandb
+
+    assert hasattr(wandb, '__version__')
+except (ImportError, AssertionError):
+    wandb = None
+
+
+def on_pretrain_routine_start(trainer):
+    wandb.init(project=trainer.args.project if trainer.args.project != 'runs/train' else 'YOLOv8',
+               name=trainer.args.name,
+               config=dict(trainer.args)) if not wandb.run else wandb.run
+
+
+def on_val_end(trainer):
+    wandb.run.log(trainer.metrics, step=trainer.epoch + 1)
+    if trainer.epoch == 0:
+        model_info = {
+            "model/parameters": get_num_params(trainer.model),
+            "model/GFLOPs": round(get_flops(trainer.model), 1),
+            "model/speed(ms)": round(trainer.validator.speed[1], 1)}
+        wandb.run.log(model_info, step=trainer.epoch + 1)
+
+
+def on_train_epoch_end(trainer):
+    wandb.run.log(trainer.label_loss_items(trainer.tloss, prefix="train"), step=trainer.epoch + 1)
+    if trainer.epoch == 1:
+        wandb.run.log({f.stem: wandb.Image(str(f))
+                       for f in trainer.save_dir.glob('train_batch*.jpg')},
+                      step=trainer.epoch + 1)
+
+
+def on_train_end(trainer):
+    art = wandb.Artifact(type="model", name=f"run_{wandb.run.id}_model")
+    if trainer.best.exists():
+        art.add_file(trainer.best)
+    wandb.run.log_artifact(art)
+
+
+callbacks = {
+    "on_pretrain_routine_start": on_pretrain_routine_start,
+    "on_train_epoch_end": on_train_epoch_end,
+    "on_val_end": on_val_end,
+    "on_train_end": on_train_end} if wandb else {}
diff --git a/ultralytics/yolo/utils/torch_utils.py b/ultralytics/yolo/utils/torch_utils.py
index a61c669..573a6e9 100644
--- a/ultralytics/yolo/utils/torch_utils.py
+++ b/ultralytics/yolo/utils/torch_utils.py
@@ -55,7 +55,7 @@ def DDP_model(model):
     return DDP(model, device_ids=[LOCAL_RANK], output_device=LOCAL_RANK)
 
 
-def select_device(device='', batch_size=0, newline=True):
+def select_device(device='', batch_size=0, newline=False):
     # device = None or 'cpu' or 0 or '0' or '0,1,2,3'
     ver = git_describe() or ultralytics.__version__  # git commit or pip package version
     s = f'Ultralytics YOLO 🚀 {ver} Python-{platform.python_version()} torch-{torch.__version__} '
@@ -86,9 +86,7 @@ def select_device(device='', batch_size=0, newline=True):
         s += 'CPU\n'
         arg = 'cpu'
 
-    if not newline:
-        s = s.rstrip()
-    LOGGER.info(s)
+    LOGGER.info(s if newline else s.rstrip())
     return torch.device(arg)
 
 
@@ -150,6 +148,7 @@ def get_num_gradients(model):
 
 def get_flops(model, imgsz=640):
     try:
+        model = de_parallel(model)
         p = next(model.parameters())
         stride = max(int(model.stride.max()), 32) if hasattr(model, 'stride') else 32  # max stride
         im = torch.empty((1, p.shape[1], stride, stride), device=p.device)  # input image in BCHW format
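
For context, every integration file in this diff follows the same pattern: the module exposes a `callbacks` dict keyed by event name, and `add_integration_callbacks` merges those dicts into the trainer, which later fires them by name. The standalone sketch below only illustrates that dispatch pattern under simplified assumptions; `DummyTrainer`, its `run_callbacks` helper, and the stand-in callback dicts are hypothetical and are not the Ultralytics implementation.

# Minimal sketch of the callback-dict pattern used in the diff above.
# DummyTrainer is a hypothetical stand-in that only mimics the
# add_callback(name, func) plumbing seen in base.py.
from collections import defaultdict


class DummyTrainer:
    def __init__(self):
        self.callbacks = defaultdict(list)

    def add_callback(self, name, func):  # same signature as trainer.add_callback(k, v) above
        self.callbacks[name].append(func)

    def run_callbacks(self, name):  # fire every hook registered for an event name
        for func in self.callbacks[name]:
            func(self)


# Stand-ins for the `callbacks` dicts exported by tb.py and the new wb.py
tb_callbacks = {"on_pretrain_routine_start": lambda trainer: print("tb: create SummaryWriter")}
wb_callbacks = {"on_pretrain_routine_start": lambda trainer: print("wb: wandb.init(...)"),
                "on_train_end": lambda trainer: print("wb: log best weights as a W&B Artifact")}

trainer = DummyTrainer()
for x in (tb_callbacks, wb_callbacks):  # mirrors add_integration_callbacks(trainer)
    for k, v in x.items():
        trainer.add_callback(k, v)

trainer.run_callbacks("on_pretrain_routine_start")  # both integrations initialize
trainer.run_callbacks("on_train_end")               # only the W&B hook fires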