diff --git a/ultralytics/yolo/engine/exporter.py b/ultralytics/yolo/engine/exporter.py
index 0f2fab2..efd1252 100644
--- a/ultralytics/yolo/engine/exporter.py
+++ b/ultralytics/yolo/engine/exporter.py
@@ -56,6 +56,7 @@
 import re
 import subprocess
 import time
 import warnings
+from collections import defaultdict
 from copy import deepcopy
 from pathlib import Path
@@ -71,6 +72,7 @@ from ultralytics.yolo.configs import get_config
 from ultralytics.yolo.data.dataloaders.stream_loaders import LoadImages
 from ultralytics.yolo.data.utils import check_dataset
 from ultralytics.yolo.utils import DEFAULT_CONFIG, LOGGER, colorstr, get_default_args, yaml_save
+from ultralytics.yolo.utils.callbacks import default_callbacks
 from ultralytics.yolo.utils.checks import check_imgsz, check_requirements, check_version, check_yaml
 from ultralytics.yolo.utils.files import file_size, increment_path
 from ultralytics.yolo.utils.ops import Profile
@@ -142,8 +144,14 @@ class Exporter:
         self.save_dir = increment_path(Path(project) / name, exist_ok=self.args.exist_ok)
         self.save_dir.mkdir(parents=True, exist_ok=True)
 
+        # callbacks
+        self.callbacks = defaultdict(list)
+        for callback, func in default_callbacks.items():
+            self.add_callback(callback, func)
+
     @smart_inference_mode()
     def __call__(self, model=None):
+        self.run_callbacks("on_export_start")
         t = time.time()
         format = self.args.format.lower()  # to lowercase
         fmts = tuple(export_formats()['Argument'][1:])  # available export formats
@@ -245,6 +253,8 @@ class Exporter:
                     f"\nPredict: yolo task={task} mode=predict model={f[-1]} {s}"
                     f"\nValidate: yolo task={task} mode=val model={f[-1]} {s}"
                     f"\nVisualize: https://netron.app")
+
+        self.run_callbacks("on_export_end")
         return f  # return list of exported files/dirs
 
     @try_export
@@ -755,6 +765,22 @@ class Exporter:
         LOGGER.info(f'{prefix} pipeline success')
         return model
 
+    def add_callback(self, event: str, callback):
+        """
+        appends the given callback
+        """
+        self.callbacks[event].append(callback)
+
+    def set_callback(self, event: str, callback):
+        """
+        overrides the existing callbacks with the given callback
+        """
+        self.callbacks[event] = [callback]
+
+    def run_callbacks(self, event: str):
+        for callback in self.callbacks.get(event, []):
+            callback(self)
+
 
 @hydra.main(version_base=None, config_path=str(DEFAULT_CONFIG.parent), config_name=DEFAULT_CONFIG.name)
 def export(cfg):
diff --git a/ultralytics/yolo/engine/predictor.py b/ultralytics/yolo/engine/predictor.py
index 0792dc7..8dd6851 100644
--- a/ultralytics/yolo/engine/predictor.py
+++ b/ultralytics/yolo/engine/predictor.py
@@ -26,6 +26,7 @@ Usage - formats:
                          yolov8n_paddle_model  # PaddlePaddle
 """
 import platform
+from collections import defaultdict
 from pathlib import Path
 
 import cv2
@@ -35,6 +36,7 @@ from ultralytics.yolo.configs import get_config
 from ultralytics.yolo.data.dataloaders.stream_loaders import LoadImages, LoadScreenshots, LoadStreams
 from ultralytics.yolo.data.utils import IMG_FORMATS, VID_FORMATS
 from ultralytics.yolo.utils import DEFAULT_CONFIG, LOGGER, colorstr, ops
+from ultralytics.yolo.utils.callbacks import default_callbacks
 from ultralytics.yolo.utils.checks import check_file, check_imgsz, check_imshow
 from ultralytics.yolo.utils.files import increment_path
 from ultralytics.yolo.utils.torch_utils import select_device, smart_inference_mode
@@ -89,6 +91,11 @@ class BasePredictor:
         self.annotator = None
         self.data_path = None
 
+        # callbacks
+        self.callbacks = defaultdict(list)
+        for callback, func in default_callbacks.items():
+            self.add_callback(callback, func)
+
     def preprocess(self, img):
         pass
 
@@ -143,9 +150,11 @@ class BasePredictor:
 
     @smart_inference_mode()
     def __call__(self, source=None, model=None):
+        self.run_callbacks("on_predict_start")
         model = self.model if self.done_setup else self.setup(source, model)
         self.seen, self.windows, self.dt = 0, [], (ops.Profile(), ops.Profile(), ops.Profile())
         for batch in self.dataset:
+            self.run_callbacks("on_predict_batch_start")
             path, im, im0s, vid_cap, s = batch
             visualize = increment_path(self.save_dir / Path(path).stem, mkdir=True) if self.args.visualize else False
             with self.dt[0]:
@@ -176,6 +185,8 @@ class BasePredictor:
             # Print time (inference-only)
             LOGGER.info(f"{s}{'' if len(preds) else '(no detections), '}{self.dt[1].dt * 1E3:.1f}ms")
 
+            self.run_callbacks("on_predict_batch_end")
+
         # Print results
         t = tuple(x.t / self.seen * 1E3 for x in self.dt)  # speeds per image
         LOGGER.info(
@@ -185,6 +196,8 @@ class BasePredictor:
             s = f"\n{len(list(self.save_dir.glob('labels/*.txt')))} labels saved to {self.save_dir / 'labels'}" if self.args.save_txt else ''
             LOGGER.info(f"Results saved to {colorstr('bold', self.save_dir)}{s}")
 
+        self.run_callbacks("on_predict_end")
+
     def show(self, p):
         im0 = self.annotator.result()
         if platform.system() == 'Linux' and p not in self.windows:
@@ -213,3 +226,19 @@ class BasePredictor:
                 save_path = str(Path(save_path).with_suffix('.mp4'))  # force *.mp4 suffix on results videos
                 self.vid_writer[idx] = cv2.VideoWriter(save_path, cv2.VideoWriter_fourcc(*'mp4v'), fps, (w, h))
             self.vid_writer[idx].write(im0)
+
+    def add_callback(self, event: str, callback):
+        """
+        appends the given callback
+        """
+        self.callbacks[event].append(callback)
+
+    def set_callback(self, event: str, callback):
+        """
+        overrides the existing callbacks with the given callback
+        """
+        self.callbacks[event] = [callback]
+
+    def run_callbacks(self, event: str):
+        for callback in self.callbacks.get(event, []):
+            callback(self)
diff --git a/ultralytics/yolo/engine/trainer.py b/ultralytics/yolo/engine/trainer.py
index a126847..4c7915d 100644
--- a/ultralytics/yolo/engine/trainer.py
+++ b/ultralytics/yolo/engine/trainer.py
@@ -136,20 +136,20 @@ class BaseTrainer:
         if RANK in {0, -1}:
             callbacks.add_integration_callbacks(self)
 
-    def add_callback(self, onevent: str, callback):
+    def add_callback(self, event: str, callback):
         """
         appends the given callback
         """
-        self.callbacks[onevent].append(callback)
+        self.callbacks[event].append(callback)
 
-    def set_callback(self, onevent: str, callback):
+    def set_callback(self, event: str, callback):
         """
         overrides the existing callbacks with the given callback
         """
-        self.callbacks[onevent] = [callback]
+        self.callbacks[event] = [callback]
 
-    def trigger_callbacks(self, onevent: str):
-        for callback in self.callbacks.get(onevent, []):
+    def run_callbacks(self, event: str):
+        for callback in self.callbacks.get(event, []):
             callback(self)
 
     def train(self):
@@ -178,7 +178,7 @@ class BaseTrainer:
         Builds dataloaders and optimizer on correct rank process
         """
         # model
-        self.trigger_callbacks("on_pretrain_routine_start")
+        self.run_callbacks("on_pretrain_routine_start")
         ckpt = self.setup_model()
         self.model = self.model.to(self.device)
         self.set_model_attributes()
@@ -210,7 +210,7 @@ class BaseTrainer:
             metric_keys = self.validator.metric_keys + self.label_loss_items(prefix="val")
             self.metrics = dict(zip(metric_keys, [0] * len(metric_keys)))  # TODO: init metrics for plot_results()?
         self.ema = ModelEMA(self.model)
-        self.trigger_callbacks("on_pretrain_routine_end")
+        self.run_callbacks("on_pretrain_routine_end")
 
     def _do_train(self, rank=-1, world_size=1):
         if world_size > 1:
@@ -224,14 +224,14 @@ class BaseTrainer:
         nb = len(self.train_loader)  # number of batches
         nw = max(round(self.args.warmup_epochs * nb), 100)  # number of warmup iterations
         last_opt_step = -1
-        self.trigger_callbacks("on_train_start")
+        self.run_callbacks("on_train_start")
         self.log(f"Image sizes {self.args.imgsz} train, {self.args.imgsz} val\n"
                  f'Using {self.train_loader.num_workers * (world_size or 1)} dataloader workers\n'
                  f"Logging results to {colorstr('bold', self.save_dir)}\n"
                  f"Starting training for {self.epochs} epochs...")
         for epoch in range(self.start_epoch, self.epochs):
             self.epoch = epoch
-            self.trigger_callbacks("on_train_epoch_start")
+            self.run_callbacks("on_train_epoch_start")
             self.model.train()
             if rank != -1:
                 self.train_loader.sampler.set_epoch(epoch)
@@ -242,7 +242,7 @@ class BaseTrainer:
             self.tloss = None
             self.optimizer.zero_grad()
             for i, batch in pbar:
-                self.trigger_callbacks("on_train_batch_start")
+                self.run_callbacks("on_train_batch_start")
 
                 # Update dataloader attributes (optional)
                 if epoch == (self.epochs - self.args.close_mosaic) and hasattr(self.train_loader.dataset, 'mosaic'):
@@ -287,35 +287,34 @@ class BaseTrainer:
                     pbar.set_description(
                         ('%11s' * 2 + '%11.4g' * (2 + loss_len)) %
                         (f'{epoch + 1}/{self.epochs}', mem, *losses, batch["cls"].shape[0], batch["img"].shape[-1]))
-                    self.trigger_callbacks('on_batch_end')
+                    self.run_callbacks('on_batch_end')
                     if self.args.plots and ni < 3:
                         self.plot_training_samples(batch, ni)
 
-                self.trigger_callbacks("on_train_batch_end")
+                self.run_callbacks("on_train_batch_end")
 
             lr = {f"lr{ir}": x['lr'] for ir, x in enumerate(self.optimizer.param_groups)}  # for loggers
             self.scheduler.step()
-            self.trigger_callbacks("on_train_epoch_end")
+            self.run_callbacks("on_train_epoch_end")
 
             if rank in {-1, 0}:
+
                 # Validation
-                self.trigger_callbacks('on_val_start')
                 self.ema.update_attr(self.model, include=['yaml', 'nc', 'args', 'names', 'stride', 'class_weights'])
                 final_epoch = (epoch + 1 == self.epochs)
                 if self.args.val or final_epoch:
                     self.metrics, self.fitness = self.validate()
-                    self.trigger_callbacks('on_val_end')
                 self.save_metrics(metrics={**self.label_loss_items(self.tloss), **self.metrics, **lr})
 
                 # Save model
                 if self.args.save or (epoch + 1 == self.epochs):
                     self.save_model()
-                    self.trigger_callbacks('on_model_save')
+                    self.run_callbacks('on_model_save')
 
             tnow = time.time()
             self.epoch_time = tnow - self.epoch_time_start
             self.epoch_time_start = tnow
-
+            self.run_callbacks("on_fit_epoch_end")
 
             # TODO: termination condition
             if rank in {-1, 0}:
@@ -326,9 +325,9 @@ class BaseTrainer:
             if self.args.plots:
                 self.plot_metrics()
             self.log(f"Results saved to {colorstr('bold', self.save_dir)}")
-            self.trigger_callbacks('on_train_end')
+            self.run_callbacks('on_train_end')
         torch.cuda.empty_cache()
-        self.trigger_callbacks('teardown')
+        self.run_callbacks('teardown')
 
     def save_model(self):
         ckpt = {
@@ -470,7 +469,7 @@ class BaseTrainer:
                     self.validator.args.save_json = True
                     self.metrics = self.validator(model=f)
                     self.metrics.pop('fitness', None)
-                    self.trigger_callbacks('on_val_end')
+                    self.run_callbacks('on_val_end')
 
     def check_resume(self):
         resume = self.args.resume
diff --git a/ultralytics/yolo/engine/validator.py b/ultralytics/yolo/engine/validator.py
index 7f6d8b3..dde017b 100644
--- a/ultralytics/yolo/engine/validator.py
+++ b/ultralytics/yolo/engine/validator.py
@@ -1,4 +1,5 @@
 import json
+from collections import defaultdict
 from pathlib import Path
 
 import torch
@@ -8,6 +9,7 @@ from tqdm import tqdm
 from ultralytics.nn.autobackend import AutoBackend
 from ultralytics.yolo.data.utils import check_dataset, check_dataset_yaml
 from ultralytics.yolo.utils import DEFAULT_CONFIG, LOGGER, RANK, TQDM_BAR_FORMAT
+from ultralytics.yolo.utils.callbacks import default_callbacks
 from ultralytics.yolo.utils.checks import check_imgsz
 from ultralytics.yolo.utils.files import increment_path
 from ultralytics.yolo.utils.ops import Profile
@@ -64,12 +66,18 @@ class BaseValidator:
                                        exist_ok=self.args.exist_ok if RANK in {-1, 0} else True)
         (self.save_dir / 'labels' if self.args.save_txt else self.save_dir).mkdir(parents=True, exist_ok=True)
 
+        # callbacks
+        self.callbacks = defaultdict(list)
+        for callback, func in default_callbacks.items():
+            self.add_callback(callback, func)
+
     @smart_inference_mode()
     def __call__(self, trainer=None, model=None):
         """
         Supports validation of a pre-trained model if passed or a model being trained
         if trainer is passed (trainer gets priority).
         """
+        self.run_callbacks('on_val_start')
         self.training = trainer is not None
         if self.training:
             self.device = trainer.device
@@ -116,6 +124,7 @@ class BaseValidator:
         self.init_metrics(de_parallel(model))
         self.jdict = []  # empty before each val
         for batch_i, batch in enumerate(bar):
+            self.run_callbacks('on_val_batch_start')
             self.batch_i = batch_i
             # pre-process
             with dt[0]:
@@ -139,10 +148,12 @@ class BaseValidator:
                 self.plot_val_samples(batch, batch_i)
                 self.plot_predictions(batch, preds, batch_i)
 
+            self.run_callbacks('on_val_batch_end')
         stats = self.get_stats()
         self.check_stats(stats)
         self.print_results()
         self.speed = tuple(x.t / len(self.dataloader.dataset) * 1E3 for x in dt)  # speeds per image
+        self.run_callbacks('on_val_end')
         if self.training:
             model.float()
             return {**stats, **trainer.label_loss_items(self.loss.cpu() / len(self.dataloader), prefix="val")}
@@ -156,6 +167,22 @@ class BaseValidator:
             stats = self.eval_json(stats)  # update stats
             return stats
 
+    def add_callback(self, event: str, callback):
+        """
+        appends the given callback
+        """
+        self.callbacks[event].append(callback)
+
+    def set_callback(self, event: str, callback):
+        """
+        overrides the existing callbacks with the given callback
+        """
+        self.callbacks[event] = [callback]
+
+    def run_callbacks(self, event: str):
+        for callback in self.callbacks.get(event, []):
+            callback(self)
+
     def get_dataloader(self, dataset_path, batch_size):
         raise NotImplementedError("get_dataloader function not implemented for this validator")
diff --git a/ultralytics/yolo/utils/callbacks/base.py b/ultralytics/yolo/utils/callbacks/base.py
index d91d255..7cdb114 100644
--- a/ultralytics/yolo/utils/callbacks/base.py
+++ b/ultralytics/yolo/utils/callbacks/base.py
@@ -1,3 +1,7 @@
+# Ultralytics YOLO base callbacks
+
+
+# Trainer callbacks ----------------------------------------------------------------------------------------------------
 def on_pretrain_routine_start(trainer):
     pass
 
@@ -34,47 +38,71 @@ def on_train_epoch_end(trainer):
     pass
 
 
-def on_val_start(trainer):
+def on_fit_epoch_end(trainer):
     pass
 
 
-def on_val_batch_start(trainer):
+def on_model_save(trainer):
     pass
 
 
-def on_val_image_end(trainer):
+def on_train_end(trainer):
     pass
 
 
-def on_val_batch_end(trainer):
+def on_params_update(trainer):
     pass
 
 
-def on_val_end(trainer):
+def teardown(trainer):
     pass
 
 
-def on_fit_epoch_end(trainer):
+# Validator callbacks --------------------------------------------------------------------------------------------------
+def on_val_start(validator):
     pass
 
 
-def on_model_save(trainer):
+def on_val_batch_start(validator):
     pass
 
 
-def on_train_end(trainer):
+def on_val_batch_end(validator):
     pass
 
 
-def on_params_update(trainer):
+def on_val_end(validator):
     pass
 
 
-def teardown(trainer):
+# Predictor callbacks --------------------------------------------------------------------------------------------------
+def on_predict_start(predictor):
+    pass
+
+
+def on_predict_batch_start(predictor):
+    pass
+
+
+def on_predict_batch_end(predictor):
+    pass
+
+
+def on_predict_end(predictor):
+    pass
+
+
+# Exporter callbacks ---------------------------------------------------------------------------------------------------
+def on_export_start(exporter):
+    pass
+
+
+def on_export_end(exporter):
     pass
 
 
 default_callbacks = {
+    # Run in trainer
     'on_pretrain_routine_start': on_pretrain_routine_start,
     'on_pretrain_routine_end': on_pretrain_routine_end,
     'on_train_start': on_train_start,
@@ -84,16 +112,27 @@ default_callbacks = {
     'on_before_zero_grad': on_before_zero_grad,
     'on_train_batch_end': on_train_batch_end,
     'on_train_epoch_end': on_train_epoch_end,
-    'on_val_start': on_val_start,
-    'on_val_batch_start': on_val_batch_start,
-    'on_val_image_end': on_val_image_end,
-    'on_val_batch_end': on_val_batch_end,
-    'on_val_end': on_val_end,
     'on_fit_epoch_end': on_fit_epoch_end,  # fit = train + val
     'on_model_save': on_model_save,
     'on_train_end': on_train_end,
     'on_params_update': on_params_update,
-    'teardown': teardown}
+    'teardown': teardown,
+
+    # Run in validator
+    'on_val_start': on_val_start,
+    'on_val_batch_start': on_val_batch_start,
+    'on_val_batch_end': on_val_batch_end,
+    'on_val_end': on_val_end,
+
+    # Run in predictor
+    'on_predict_start': on_predict_start,
+    'on_predict_batch_start': on_predict_batch_start,
+    'on_predict_batch_end': on_predict_batch_end,
+    'on_predict_end': on_predict_end,
+
+    # Run in exporter
+    'on_export_start': on_export_start,
+    'on_export_end': on_export_end}
 
 
 def add_integration_callbacks(trainer):
diff --git a/ultralytics/yolo/utils/callbacks/clearml.py b/ultralytics/yolo/utils/callbacks/clearml.py
index 225a25f..defd320 100644
--- a/ultralytics/yolo/utils/callbacks/clearml.py
+++ b/ultralytics/yolo/utils/callbacks/clearml.py
@@ -18,7 +18,7 @@ def _log_images(imgs_dict, group="", step=0):
 
 def on_pretrain_routine_start(trainer):
     # TODO: reuse existing task
-    task = Task.init(project_name=trainer.args.project if trainer.args.project != 'runs/train' else 'YOLOv8',
+    task = Task.init(project_name=trainer.args.project or "YOLOv8",
                      task_name=trainer.args.name,
                      tags=['YOLOv8'],
                      output_uri=True,
@@ -32,7 +32,7 @@ def on_train_epoch_end(trainer):
     _log_images({f.stem: str(f) for f in trainer.save_dir.glob('train_batch*.jpg')}, "Mosaic", trainer.epoch)
 
 
-def on_val_end(trainer):
+def on_fit_epoch_end(trainer):
     if trainer.epoch == 0:
         model_info = {
             "Parameters": get_num_params(trainer.model),
@@ -50,5 +50,5 @@ def on_train_end(trainer):
 callbacks = {
     "on_pretrain_routine_start": on_pretrain_routine_start,
     "on_train_epoch_end": on_train_epoch_end,
-    "on_val_end": on_val_end,
+    "on_fit_epoch_end": on_fit_epoch_end,
     "on_train_end": on_train_end} if clearml else {}
diff --git a/ultralytics/yolo/utils/callbacks/tb.py b/ultralytics/yolo/utils/callbacks/tb.py
index 294112a..1093c92 100644
--- a/ultralytics/yolo/utils/callbacks/tb.py
+++ b/ultralytics/yolo/utils/callbacks/tb.py
@@ -17,11 +17,11 @@ def on_batch_end(trainer):
     _log_scalars(trainer.label_loss_items(trainer.tloss, prefix="train"), trainer.epoch + 1)
 
 
-def on_val_end(trainer):
+def on_fit_epoch_end(trainer):
     _log_scalars(trainer.metrics, trainer.epoch + 1)
 
 
 callbacks = {
     "on_pretrain_routine_start": on_pretrain_routine_start,
-    "on_val_end": on_val_end,
+    "on_fit_epoch_end": on_fit_epoch_end,
     "on_batch_end": on_batch_end}
diff --git a/ultralytics/yolo/utils/callbacks/wb.py b/ultralytics/yolo/utils/callbacks/wb.py
index eed8ead..e3d6d21 100644
--- a/ultralytics/yolo/utils/callbacks/wb.py
+++ b/ultralytics/yolo/utils/callbacks/wb.py
@@ -9,12 +9,11 @@ except (ImportError, AssertionError):
 
 
 def on_pretrain_routine_start(trainer):
-    wandb.init(project=trainer.args.project if trainer.args.project != 'runs/train' else 'YOLOv8',
-               name=trainer.args.name,
-               config=dict(trainer.args)) if not wandb.run else wandb.run
+    wandb.init(project=trainer.args.project or "YOLOv8", name=trainer.args.name, config=dict(
+        trainer.args)) if not wandb.run else wandb.run
 
 
-def on_val_end(trainer):
+def on_fit_epoch_end(trainer):
     wandb.run.log(trainer.metrics, step=trainer.epoch + 1)
     if trainer.epoch == 0:
         model_info = {
@@ -42,5 +41,5 @@ def on_train_end(trainer):
 callbacks = {
     "on_pretrain_routine_start": on_pretrain_routine_start,
     "on_train_epoch_end": on_train_epoch_end,
-    "on_val_end": on_val_end,
+    "on_fit_epoch_end": on_fit_epoch_end,
     "on_train_end": on_train_end} if wandb else {}
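
Usage note (illustrative, not part of the diff): the sketch below is a minimal, standalone stand-in that mirrors the callback plumbing added above, so reviewers can see how add_callback, set_callback, and run_callbacks interact with a defaultdict(list) registry. EngineStub and the two default hooks are hypothetical; they are not identifiers from this PR. With the real classes the pattern is the same: build the predictor/validator/exporter, register handlers with add_callback(event, fn) before calling it, and the hooks fire at the instrumented points.

# --- illustrative sketch only, not applied by this PR ---
from collections import defaultdict

# Stand-in for ultralytics.yolo.utils.callbacks.default_callbacks: each event
# name maps to a no-op hook that receives the calling object.
default_callbacks = {"on_predict_start": lambda obj: None,
                     "on_predict_end": lambda obj: None}


class EngineStub:
    # Hypothetical stand-in for BasePredictor / BaseValidator / Exporter,
    # reproducing only the callback plumbing added in this PR.

    def __init__(self):
        # defaultdict needs a callable factory (list), not an instance ([]);
        # this is why the exporter/predictor hunks above use defaultdict(list).
        self.callbacks = defaultdict(list)
        for event, func in default_callbacks.items():
            self.add_callback(event, func)

    def add_callback(self, event: str, callback):
        # append: several handlers can listen to the same event
        self.callbacks[event].append(callback)

    def set_callback(self, event: str, callback):
        # replace: the given handler becomes the only listener for the event
        self.callbacks[event] = [callback]

    def run_callbacks(self, event: str):
        # fire every handler registered for the event, passing the engine itself
        for callback in self.callbacks.get(event, []):
            callback(self)

    def __call__(self):
        self.run_callbacks("on_predict_start")
        # ... the real engines do their inference/validation/export work here ...
        self.run_callbacks("on_predict_end")


engine = EngineStub()
engine.add_callback("on_predict_end", lambda obj: print("prediction finished"))
engine()  # -> prints "prediction finished"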