Add pred, export and val callbacks (#126)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
single_channel
Ayush Chaurasia 2 years ago committed by GitHub
parent 63c7a74691
commit c6eb6720de
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -56,6 +56,7 @@ import re
import subprocess
import time
import warnings
from collections import defaultdict
from copy import deepcopy
from pathlib import Path
@ -71,6 +72,7 @@ from ultralytics.yolo.configs import get_config
from ultralytics.yolo.data.dataloaders.stream_loaders import LoadImages
from ultralytics.yolo.data.utils import check_dataset
from ultralytics.yolo.utils import DEFAULT_CONFIG, LOGGER, colorstr, get_default_args, yaml_save
from ultralytics.yolo.utils.callbacks import default_callbacks
from ultralytics.yolo.utils.checks import check_imgsz, check_requirements, check_version, check_yaml
from ultralytics.yolo.utils.files import file_size, increment_path
from ultralytics.yolo.utils.ops import Profile
@ -142,8 +144,14 @@ class Exporter:
self.save_dir = increment_path(Path(project) / name, exist_ok=self.args.exist_ok)
self.save_dir.mkdir(parents=True, exist_ok=True)
# callbacks
self.callbacks = defaultdict([])
for callback, func in default_callbacks.items():
self.add_callback(callback, func)
@smart_inference_mode()
def __call__(self, model=None):
self.run_callbacks("on_export_start")
t = time.time()
format = self.args.format.lower() # to lowercase
fmts = tuple(export_formats()['Argument'][1:]) # available export formats
@ -245,6 +253,8 @@ class Exporter:
f"\nPredict: yolo task={task} mode=predict model={f[-1]} {s}"
f"\nValidate: yolo task={task} mode=val model={f[-1]} {s}"
f"\nVisualize: https://netron.app")
self.run_callbacks("on_export_end")
return f # return list of exported files/dirs
@try_export
@ -755,6 +765,22 @@ class Exporter:
LOGGER.info(f'{prefix} pipeline success')
return model
def add_callback(self, event: str, callback):
"""
appends the given callback
"""
self.callbacks[event].append(callback)
def set_callback(self, event: str, callback):
"""
overrides the existing callbacks with the given callback
"""
self.callbacks[event] = [callback]
def run_callbacks(self, event: str):
for callback in self.callbacks.get(event, []):
callback(self)
@hydra.main(version_base=None, config_path=str(DEFAULT_CONFIG.parent), config_name=DEFAULT_CONFIG.name)
def export(cfg):

@ -26,6 +26,7 @@ Usage - formats:
yolov8n_paddle_model # PaddlePaddle
"""
import platform
from collections import defaultdict
from pathlib import Path
import cv2
@ -35,6 +36,7 @@ from ultralytics.yolo.configs import get_config
from ultralytics.yolo.data.dataloaders.stream_loaders import LoadImages, LoadScreenshots, LoadStreams
from ultralytics.yolo.data.utils import IMG_FORMATS, VID_FORMATS
from ultralytics.yolo.utils import DEFAULT_CONFIG, LOGGER, colorstr, ops
from ultralytics.yolo.utils.callbacks import default_callbacks
from ultralytics.yolo.utils.checks import check_file, check_imgsz, check_imshow
from ultralytics.yolo.utils.files import increment_path
from ultralytics.yolo.utils.torch_utils import select_device, smart_inference_mode
@ -89,6 +91,11 @@ class BasePredictor:
self.annotator = None
self.data_path = None
# callbacks
self.callbacks = defaultdict([])
for callback, func in default_callbacks.items():
self.add_callback(callback, func)
def preprocess(self, img):
pass
@ -143,9 +150,11 @@ class BasePredictor:
@smart_inference_mode()
def __call__(self, source=None, model=None):
self.run_callbacks("on_predict_start")
model = self.model if self.done_setup else self.setup(source, model)
self.seen, self.windows, self.dt = 0, [], (ops.Profile(), ops.Profile(), ops.Profile())
for batch in self.dataset:
self.run_callbacks("on_predict_batch_start")
path, im, im0s, vid_cap, s = batch
visualize = increment_path(self.save_dir / Path(path).stem, mkdir=True) if self.args.visualize else False
with self.dt[0]:
@ -176,6 +185,8 @@ class BasePredictor:
# Print time (inference-only)
LOGGER.info(f"{s}{'' if len(preds) else '(no detections), '}{self.dt[1].dt * 1E3:.1f}ms")
self.run_callbacks("on_predict_batch_end")
# Print results
t = tuple(x.t / self.seen * 1E3 for x in self.dt) # speeds per image
LOGGER.info(
@ -185,6 +196,8 @@ class BasePredictor:
s = f"\n{len(list(self.save_dir.glob('labels/*.txt')))} labels saved to {self.save_dir / 'labels'}" if self.args.save_txt else ''
LOGGER.info(f"Results saved to {colorstr('bold', self.save_dir)}{s}")
self.run_callbacks("on_predict_end")
def show(self, p):
im0 = self.annotator.result()
if platform.system() == 'Linux' and p not in self.windows:
@ -213,3 +226,19 @@ class BasePredictor:
save_path = str(Path(save_path).with_suffix('.mp4')) # force *.mp4 suffix on results videos
self.vid_writer[idx] = cv2.VideoWriter(save_path, cv2.VideoWriter_fourcc(*'mp4v'), fps, (w, h))
self.vid_writer[idx].write(im0)
def add_callback(self, event: str, callback):
"""
appends the given callback
"""
self.callbacks[event].append(callback)
def set_callback(self, event: str, callback):
"""
overrides the existing callbacks with the given callback
"""
self.callbacks[event] = [callback]
def run_callbacks(self, event: str):
for callback in self.callbacks.get(event, []):
callback(self)

@ -136,20 +136,20 @@ class BaseTrainer:
if RANK in {0, -1}:
callbacks.add_integration_callbacks(self)
def add_callback(self, onevent: str, callback):
def add_callback(self, event: str, callback):
"""
appends the given callback
"""
self.callbacks[onevent].append(callback)
self.callbacks[event].append(callback)
def set_callback(self, onevent: str, callback):
def set_callback(self, event: str, callback):
"""
overrides the existing callbacks with the given callback
"""
self.callbacks[onevent] = [callback]
self.callbacks[event] = [callback]
def trigger_callbacks(self, onevent: str):
for callback in self.callbacks.get(onevent, []):
def run_callbacks(self, event: str):
for callback in self.callbacks.get(event, []):
callback(self)
def train(self):
@ -178,7 +178,7 @@ class BaseTrainer:
Builds dataloaders and optimizer on correct rank process
"""
# model
self.trigger_callbacks("on_pretrain_routine_start")
self.run_callbacks("on_pretrain_routine_start")
ckpt = self.setup_model()
self.model = self.model.to(self.device)
self.set_model_attributes()
@ -210,7 +210,7 @@ class BaseTrainer:
metric_keys = self.validator.metric_keys + self.label_loss_items(prefix="val")
self.metrics = dict(zip(metric_keys, [0] * len(metric_keys))) # TODO: init metrics for plot_results()?
self.ema = ModelEMA(self.model)
self.trigger_callbacks("on_pretrain_routine_end")
self.run_callbacks("on_pretrain_routine_end")
def _do_train(self, rank=-1, world_size=1):
if world_size > 1:
@ -224,14 +224,14 @@ class BaseTrainer:
nb = len(self.train_loader) # number of batches
nw = max(round(self.args.warmup_epochs * nb), 100) # number of warmup iterations
last_opt_step = -1
self.trigger_callbacks("on_train_start")
self.run_callbacks("on_train_start")
self.log(f"Image sizes {self.args.imgsz} train, {self.args.imgsz} val\n"
f'Using {self.train_loader.num_workers * (world_size or 1)} dataloader workers\n'
f"Logging results to {colorstr('bold', self.save_dir)}\n"
f"Starting training for {self.epochs} epochs...")
for epoch in range(self.start_epoch, self.epochs):
self.epoch = epoch
self.trigger_callbacks("on_train_epoch_start")
self.run_callbacks("on_train_epoch_start")
self.model.train()
if rank != -1:
self.train_loader.sampler.set_epoch(epoch)
@ -242,7 +242,7 @@ class BaseTrainer:
self.tloss = None
self.optimizer.zero_grad()
for i, batch in pbar:
self.trigger_callbacks("on_train_batch_start")
self.run_callbacks("on_train_batch_start")
# Update dataloader attributes (optional)
if epoch == (self.epochs - self.args.close_mosaic) and hasattr(self.train_loader.dataset, 'mosaic'):
@ -287,35 +287,34 @@ class BaseTrainer:
pbar.set_description(
('%11s' * 2 + '%11.4g' * (2 + loss_len)) %
(f'{epoch + 1}/{self.epochs}', mem, *losses, batch["cls"].shape[0], batch["img"].shape[-1]))
self.trigger_callbacks('on_batch_end')
self.run_callbacks('on_batch_end')
if self.args.plots and ni < 3:
self.plot_training_samples(batch, ni)
self.trigger_callbacks("on_train_batch_end")
self.run_callbacks("on_train_batch_end")
lr = {f"lr{ir}": x['lr'] for ir, x in enumerate(self.optimizer.param_groups)} # for loggers
self.scheduler.step()
self.trigger_callbacks("on_train_epoch_end")
self.run_callbacks("on_train_epoch_end")
if rank in {-1, 0}:
# Validation
self.trigger_callbacks('on_val_start')
self.ema.update_attr(self.model, include=['yaml', 'nc', 'args', 'names', 'stride', 'class_weights'])
final_epoch = (epoch + 1 == self.epochs)
if self.args.val or final_epoch:
self.metrics, self.fitness = self.validate()
self.trigger_callbacks('on_val_end')
self.save_metrics(metrics={**self.label_loss_items(self.tloss), **self.metrics, **lr})
# Save model
if self.args.save or (epoch + 1 == self.epochs):
self.save_model()
self.trigger_callbacks('on_model_save')
self.run_callbacks('on_model_save')
tnow = time.time()
self.epoch_time = tnow - self.epoch_time_start
self.epoch_time_start = tnow
self.run_callbacks("on_fit_epoch_end")
# TODO: termination condition
if rank in {-1, 0}:
@ -326,9 +325,9 @@ class BaseTrainer:
if self.args.plots:
self.plot_metrics()
self.log(f"Results saved to {colorstr('bold', self.save_dir)}")
self.trigger_callbacks('on_train_end')
self.run_callbacks('on_train_end')
torch.cuda.empty_cache()
self.trigger_callbacks('teardown')
self.run_callbacks('teardown')
def save_model(self):
ckpt = {
@ -470,7 +469,7 @@ class BaseTrainer:
self.validator.args.save_json = True
self.metrics = self.validator(model=f)
self.metrics.pop('fitness', None)
self.trigger_callbacks('on_val_end')
self.run_callbacks('on_val_end')
def check_resume(self):
resume = self.args.resume

@ -1,4 +1,5 @@
import json
from collections import defaultdict
from pathlib import Path
import torch
@ -8,6 +9,7 @@ from tqdm import tqdm
from ultralytics.nn.autobackend import AutoBackend
from ultralytics.yolo.data.utils import check_dataset, check_dataset_yaml
from ultralytics.yolo.utils import DEFAULT_CONFIG, LOGGER, RANK, TQDM_BAR_FORMAT
from ultralytics.yolo.utils.callbacks import default_callbacks
from ultralytics.yolo.utils.checks import check_imgsz
from ultralytics.yolo.utils.files import increment_path
from ultralytics.yolo.utils.ops import Profile
@ -64,12 +66,18 @@ class BaseValidator:
exist_ok=self.args.exist_ok if RANK in {-1, 0} else True)
(self.save_dir / 'labels' if self.args.save_txt else self.save_dir).mkdir(parents=True, exist_ok=True)
# callbacks
self.callbacks = defaultdict(list)
for callback, func in default_callbacks.items():
self.add_callback(callback, func)
@smart_inference_mode()
def __call__(self, trainer=None, model=None):
"""
Supports validation of a pre-trained model if passed or a model being trained
if trainer is passed (trainer gets priority).
"""
self.run_callbacks('on_val_start')
self.training = trainer is not None
if self.training:
self.device = trainer.device
@ -116,6 +124,7 @@ class BaseValidator:
self.init_metrics(de_parallel(model))
self.jdict = [] # empty before each val
for batch_i, batch in enumerate(bar):
self.run_callbacks('on_val_batch_start')
self.batch_i = batch_i
# pre-process
with dt[0]:
@ -139,10 +148,12 @@ class BaseValidator:
self.plot_val_samples(batch, batch_i)
self.plot_predictions(batch, preds, batch_i)
self.run_callbacks('on_val_batch_end')
stats = self.get_stats()
self.check_stats(stats)
self.print_results()
self.speed = tuple(x.t / len(self.dataloader.dataset) * 1E3 for x in dt) # speeds per image
self.run_callbacks('on_val_end')
if self.training:
model.float()
return {**stats, **trainer.label_loss_items(self.loss.cpu() / len(self.dataloader), prefix="val")}
@ -156,6 +167,22 @@ class BaseValidator:
stats = self.eval_json(stats) # update stats
return stats
def add_callback(self, event: str, callback):
"""
appends the given callback
"""
self.callbacks[event].append(callback)
def set_callback(self, event: str, callback):
"""
overrides the existing callbacks with the given callback
"""
self.callbacks[event] = [callback]
def run_callbacks(self, event: str):
for callback in self.callbacks.get(event, []):
callback(self)
def get_dataloader(self, dataset_path, batch_size):
raise NotImplementedError("get_dataloader function not implemented for this validator")

@ -1,3 +1,7 @@
# Ultralytics YOLO base callbacks
# Trainer callbacks ----------------------------------------------------------------------------------------------------
def on_pretrain_routine_start(trainer):
pass
@ -34,47 +38,71 @@ def on_train_epoch_end(trainer):
pass
def on_val_start(trainer):
def on_fit_epoch_end(trainer):
pass
def on_val_batch_start(trainer):
def on_model_save(trainer):
pass
def on_val_image_end(trainer):
def on_train_end(trainer):
pass
def on_val_batch_end(trainer):
def on_params_update(trainer):
pass
def on_val_end(trainer):
def teardown(trainer):
pass
def on_fit_epoch_end(trainer):
# Validator callbacks --------------------------------------------------------------------------------------------------
def on_val_start(validator):
pass
def on_model_save(trainer):
def on_val_batch_start(validator):
pass
def on_train_end(trainer):
def on_val_batch_end(validator):
pass
def on_params_update(trainer):
def on_val_end(validator):
pass
def teardown(trainer):
# Predictor callbacks --------------------------------------------------------------------------------------------------
def on_predict_start(predictor):
pass
def on_predict_batch_start(predictor):
pass
def on_predict_batch_end(predictor):
pass
def on_predict_end(predictor):
pass
# Exporter callbacks ---------------------------------------------------------------------------------------------------
def on_export_start(exporter):
pass
def on_export_end(exporter):
pass
default_callbacks = {
# Run in trainer
'on_pretrain_routine_start': on_pretrain_routine_start,
'on_pretrain_routine_end': on_pretrain_routine_end,
'on_train_start': on_train_start,
@ -84,16 +112,27 @@ default_callbacks = {
'on_before_zero_grad': on_before_zero_grad,
'on_train_batch_end': on_train_batch_end,
'on_train_epoch_end': on_train_epoch_end,
'on_val_start': on_val_start,
'on_val_batch_start': on_val_batch_start,
'on_val_image_end': on_val_image_end,
'on_val_batch_end': on_val_batch_end,
'on_val_end': on_val_end,
'on_fit_epoch_end': on_fit_epoch_end, # fit = train + val
'on_model_save': on_model_save,
'on_train_end': on_train_end,
'on_params_update': on_params_update,
'teardown': teardown}
'teardown': teardown,
# Run in validator
'on_val_start': on_val_start,
'on_val_batch_start': on_val_batch_start,
'on_val_batch_end': on_val_batch_end,
'on_val_end': on_val_end,
# Run in predictor
'on_predict_start': on_predict_start,
'on_predict_batch_start': on_predict_batch_start,
'on_predict_batch_end': on_predict_batch_end,
'on_predict_end': on_predict_end,
# Run in exporter
'on_export_start': on_export_start,
'on_export_end': on_export_end}
def add_integration_callbacks(trainer):

@ -18,7 +18,7 @@ def _log_images(imgs_dict, group="", step=0):
def on_pretrain_routine_start(trainer):
# TODO: reuse existing task
task = Task.init(project_name=trainer.args.project if trainer.args.project != 'runs/train' else 'YOLOv8',
task = Task.init(project_name=trainer.args.project or "YOLOv8",
task_name=trainer.args.name,
tags=['YOLOv8'],
output_uri=True,
@ -32,7 +32,7 @@ def on_train_epoch_end(trainer):
_log_images({f.stem: str(f) for f in trainer.save_dir.glob('train_batch*.jpg')}, "Mosaic", trainer.epoch)
def on_val_end(trainer):
def on_fit_epoch_end(trainer):
if trainer.epoch == 0:
model_info = {
"Parameters": get_num_params(trainer.model),
@ -50,5 +50,5 @@ def on_train_end(trainer):
callbacks = {
"on_pretrain_routine_start": on_pretrain_routine_start,
"on_train_epoch_end": on_train_epoch_end,
"on_val_end": on_val_end,
"on_fit_epoch_end": on_fit_epoch_end,
"on_train_end": on_train_end} if clearml else {}

@ -17,11 +17,11 @@ def on_batch_end(trainer):
_log_scalars(trainer.label_loss_items(trainer.tloss, prefix="train"), trainer.epoch + 1)
def on_val_end(trainer):
def on_fit_epoch_end(trainer):
_log_scalars(trainer.metrics, trainer.epoch + 1)
callbacks = {
"on_pretrain_routine_start": on_pretrain_routine_start,
"on_val_end": on_val_end,
"on_fit_epoch_end": on_fit_epoch_end,
"on_batch_end": on_batch_end}

@ -9,12 +9,11 @@ except (ImportError, AssertionError):
def on_pretrain_routine_start(trainer):
wandb.init(project=trainer.args.project if trainer.args.project != 'runs/train' else 'YOLOv8',
name=trainer.args.name,
config=dict(trainer.args)) if not wandb.run else wandb.run
wandb.init(project=trainer.args.project or "YOLOv8", name=trainer.args.name, config=dict(
trainer.args)) if not wandb.run else wandb.run
def on_val_end(trainer):
def on_fit_epoch_end(trainer):
wandb.run.log(trainer.metrics, step=trainer.epoch + 1)
if trainer.epoch == 0:
model_info = {
@ -42,5 +41,5 @@ def on_train_end(trainer):
callbacks = {
"on_pretrain_routine_start": on_pretrain_routine_start,
"on_train_epoch_end": on_train_epoch_end,
"on_val_end": on_val_end,
"on_fit_epoch_end": on_fit_epoch_end,
"on_train_end": on_train_end} if wandb else {}

Loading…
Cancel
Save