Add pred, export and val callbacks (#126)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
single_channel
Ayush Chaurasia 2 years ago committed by GitHub
parent 63c7a74691
commit c6eb6720de
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -56,6 +56,7 @@ import re
import subprocess import subprocess
import time import time
import warnings import warnings
from collections import defaultdict
from copy import deepcopy from copy import deepcopy
from pathlib import Path from pathlib import Path
@ -71,6 +72,7 @@ from ultralytics.yolo.configs import get_config
from ultralytics.yolo.data.dataloaders.stream_loaders import LoadImages from ultralytics.yolo.data.dataloaders.stream_loaders import LoadImages
from ultralytics.yolo.data.utils import check_dataset from ultralytics.yolo.data.utils import check_dataset
from ultralytics.yolo.utils import DEFAULT_CONFIG, LOGGER, colorstr, get_default_args, yaml_save from ultralytics.yolo.utils import DEFAULT_CONFIG, LOGGER, colorstr, get_default_args, yaml_save
from ultralytics.yolo.utils.callbacks import default_callbacks
from ultralytics.yolo.utils.checks import check_imgsz, check_requirements, check_version, check_yaml from ultralytics.yolo.utils.checks import check_imgsz, check_requirements, check_version, check_yaml
from ultralytics.yolo.utils.files import file_size, increment_path from ultralytics.yolo.utils.files import file_size, increment_path
from ultralytics.yolo.utils.ops import Profile from ultralytics.yolo.utils.ops import Profile
@ -142,8 +144,14 @@ class Exporter:
self.save_dir = increment_path(Path(project) / name, exist_ok=self.args.exist_ok) self.save_dir = increment_path(Path(project) / name, exist_ok=self.args.exist_ok)
self.save_dir.mkdir(parents=True, exist_ok=True) self.save_dir.mkdir(parents=True, exist_ok=True)
# callbacks
self.callbacks = defaultdict([])
for callback, func in default_callbacks.items():
self.add_callback(callback, func)
@smart_inference_mode() @smart_inference_mode()
def __call__(self, model=None): def __call__(self, model=None):
self.run_callbacks("on_export_start")
t = time.time() t = time.time()
format = self.args.format.lower() # to lowercase format = self.args.format.lower() # to lowercase
fmts = tuple(export_formats()['Argument'][1:]) # available export formats fmts = tuple(export_formats()['Argument'][1:]) # available export formats
@ -245,6 +253,8 @@ class Exporter:
f"\nPredict: yolo task={task} mode=predict model={f[-1]} {s}" f"\nPredict: yolo task={task} mode=predict model={f[-1]} {s}"
f"\nValidate: yolo task={task} mode=val model={f[-1]} {s}" f"\nValidate: yolo task={task} mode=val model={f[-1]} {s}"
f"\nVisualize: https://netron.app") f"\nVisualize: https://netron.app")
self.run_callbacks("on_export_end")
return f # return list of exported files/dirs return f # return list of exported files/dirs
@try_export @try_export
@ -755,6 +765,22 @@ class Exporter:
LOGGER.info(f'{prefix} pipeline success') LOGGER.info(f'{prefix} pipeline success')
return model return model
def add_callback(self, event: str, callback):
"""
appends the given callback
"""
self.callbacks[event].append(callback)
def set_callback(self, event: str, callback):
"""
overrides the existing callbacks with the given callback
"""
self.callbacks[event] = [callback]
def run_callbacks(self, event: str):
for callback in self.callbacks.get(event, []):
callback(self)
@hydra.main(version_base=None, config_path=str(DEFAULT_CONFIG.parent), config_name=DEFAULT_CONFIG.name) @hydra.main(version_base=None, config_path=str(DEFAULT_CONFIG.parent), config_name=DEFAULT_CONFIG.name)
def export(cfg): def export(cfg):

@ -26,6 +26,7 @@ Usage - formats:
yolov8n_paddle_model # PaddlePaddle yolov8n_paddle_model # PaddlePaddle
""" """
import platform import platform
from collections import defaultdict
from pathlib import Path from pathlib import Path
import cv2 import cv2
@ -35,6 +36,7 @@ from ultralytics.yolo.configs import get_config
from ultralytics.yolo.data.dataloaders.stream_loaders import LoadImages, LoadScreenshots, LoadStreams from ultralytics.yolo.data.dataloaders.stream_loaders import LoadImages, LoadScreenshots, LoadStreams
from ultralytics.yolo.data.utils import IMG_FORMATS, VID_FORMATS from ultralytics.yolo.data.utils import IMG_FORMATS, VID_FORMATS
from ultralytics.yolo.utils import DEFAULT_CONFIG, LOGGER, colorstr, ops from ultralytics.yolo.utils import DEFAULT_CONFIG, LOGGER, colorstr, ops
from ultralytics.yolo.utils.callbacks import default_callbacks
from ultralytics.yolo.utils.checks import check_file, check_imgsz, check_imshow from ultralytics.yolo.utils.checks import check_file, check_imgsz, check_imshow
from ultralytics.yolo.utils.files import increment_path from ultralytics.yolo.utils.files import increment_path
from ultralytics.yolo.utils.torch_utils import select_device, smart_inference_mode from ultralytics.yolo.utils.torch_utils import select_device, smart_inference_mode
@ -89,6 +91,11 @@ class BasePredictor:
self.annotator = None self.annotator = None
self.data_path = None self.data_path = None
# callbacks
self.callbacks = defaultdict([])
for callback, func in default_callbacks.items():
self.add_callback(callback, func)
def preprocess(self, img): def preprocess(self, img):
pass pass
@ -143,9 +150,11 @@ class BasePredictor:
@smart_inference_mode() @smart_inference_mode()
def __call__(self, source=None, model=None): def __call__(self, source=None, model=None):
self.run_callbacks("on_predict_start")
model = self.model if self.done_setup else self.setup(source, model) model = self.model if self.done_setup else self.setup(source, model)
self.seen, self.windows, self.dt = 0, [], (ops.Profile(), ops.Profile(), ops.Profile()) self.seen, self.windows, self.dt = 0, [], (ops.Profile(), ops.Profile(), ops.Profile())
for batch in self.dataset: for batch in self.dataset:
self.run_callbacks("on_predict_batch_start")
path, im, im0s, vid_cap, s = batch path, im, im0s, vid_cap, s = batch
visualize = increment_path(self.save_dir / Path(path).stem, mkdir=True) if self.args.visualize else False visualize = increment_path(self.save_dir / Path(path).stem, mkdir=True) if self.args.visualize else False
with self.dt[0]: with self.dt[0]:
@ -176,6 +185,8 @@ class BasePredictor:
# Print time (inference-only) # Print time (inference-only)
LOGGER.info(f"{s}{'' if len(preds) else '(no detections), '}{self.dt[1].dt * 1E3:.1f}ms") LOGGER.info(f"{s}{'' if len(preds) else '(no detections), '}{self.dt[1].dt * 1E3:.1f}ms")
self.run_callbacks("on_predict_batch_end")
# Print results # Print results
t = tuple(x.t / self.seen * 1E3 for x in self.dt) # speeds per image t = tuple(x.t / self.seen * 1E3 for x in self.dt) # speeds per image
LOGGER.info( LOGGER.info(
@ -185,6 +196,8 @@ class BasePredictor:
s = f"\n{len(list(self.save_dir.glob('labels/*.txt')))} labels saved to {self.save_dir / 'labels'}" if self.args.save_txt else '' s = f"\n{len(list(self.save_dir.glob('labels/*.txt')))} labels saved to {self.save_dir / 'labels'}" if self.args.save_txt else ''
LOGGER.info(f"Results saved to {colorstr('bold', self.save_dir)}{s}") LOGGER.info(f"Results saved to {colorstr('bold', self.save_dir)}{s}")
self.run_callbacks("on_predict_end")
def show(self, p): def show(self, p):
im0 = self.annotator.result() im0 = self.annotator.result()
if platform.system() == 'Linux' and p not in self.windows: if platform.system() == 'Linux' and p not in self.windows:
@ -213,3 +226,19 @@ class BasePredictor:
save_path = str(Path(save_path).with_suffix('.mp4')) # force *.mp4 suffix on results videos save_path = str(Path(save_path).with_suffix('.mp4')) # force *.mp4 suffix on results videos
self.vid_writer[idx] = cv2.VideoWriter(save_path, cv2.VideoWriter_fourcc(*'mp4v'), fps, (w, h)) self.vid_writer[idx] = cv2.VideoWriter(save_path, cv2.VideoWriter_fourcc(*'mp4v'), fps, (w, h))
self.vid_writer[idx].write(im0) self.vid_writer[idx].write(im0)
def add_callback(self, event: str, callback):
"""
appends the given callback
"""
self.callbacks[event].append(callback)
def set_callback(self, event: str, callback):
"""
overrides the existing callbacks with the given callback
"""
self.callbacks[event] = [callback]
def run_callbacks(self, event: str):
for callback in self.callbacks.get(event, []):
callback(self)

@ -136,20 +136,20 @@ class BaseTrainer:
if RANK in {0, -1}: if RANK in {0, -1}:
callbacks.add_integration_callbacks(self) callbacks.add_integration_callbacks(self)
def add_callback(self, onevent: str, callback): def add_callback(self, event: str, callback):
""" """
appends the given callback appends the given callback
""" """
self.callbacks[onevent].append(callback) self.callbacks[event].append(callback)
def set_callback(self, onevent: str, callback): def set_callback(self, event: str, callback):
""" """
overrides the existing callbacks with the given callback overrides the existing callbacks with the given callback
""" """
self.callbacks[onevent] = [callback] self.callbacks[event] = [callback]
def trigger_callbacks(self, onevent: str): def run_callbacks(self, event: str):
for callback in self.callbacks.get(onevent, []): for callback in self.callbacks.get(event, []):
callback(self) callback(self)
def train(self): def train(self):
@ -178,7 +178,7 @@ class BaseTrainer:
Builds dataloaders and optimizer on correct rank process Builds dataloaders and optimizer on correct rank process
""" """
# model # model
self.trigger_callbacks("on_pretrain_routine_start") self.run_callbacks("on_pretrain_routine_start")
ckpt = self.setup_model() ckpt = self.setup_model()
self.model = self.model.to(self.device) self.model = self.model.to(self.device)
self.set_model_attributes() self.set_model_attributes()
@ -210,7 +210,7 @@ class BaseTrainer:
metric_keys = self.validator.metric_keys + self.label_loss_items(prefix="val") metric_keys = self.validator.metric_keys + self.label_loss_items(prefix="val")
self.metrics = dict(zip(metric_keys, [0] * len(metric_keys))) # TODO: init metrics for plot_results()? self.metrics = dict(zip(metric_keys, [0] * len(metric_keys))) # TODO: init metrics for plot_results()?
self.ema = ModelEMA(self.model) self.ema = ModelEMA(self.model)
self.trigger_callbacks("on_pretrain_routine_end") self.run_callbacks("on_pretrain_routine_end")
def _do_train(self, rank=-1, world_size=1): def _do_train(self, rank=-1, world_size=1):
if world_size > 1: if world_size > 1:
@ -224,14 +224,14 @@ class BaseTrainer:
nb = len(self.train_loader) # number of batches nb = len(self.train_loader) # number of batches
nw = max(round(self.args.warmup_epochs * nb), 100) # number of warmup iterations nw = max(round(self.args.warmup_epochs * nb), 100) # number of warmup iterations
last_opt_step = -1 last_opt_step = -1
self.trigger_callbacks("on_train_start") self.run_callbacks("on_train_start")
self.log(f"Image sizes {self.args.imgsz} train, {self.args.imgsz} val\n" self.log(f"Image sizes {self.args.imgsz} train, {self.args.imgsz} val\n"
f'Using {self.train_loader.num_workers * (world_size or 1)} dataloader workers\n' f'Using {self.train_loader.num_workers * (world_size or 1)} dataloader workers\n'
f"Logging results to {colorstr('bold', self.save_dir)}\n" f"Logging results to {colorstr('bold', self.save_dir)}\n"
f"Starting training for {self.epochs} epochs...") f"Starting training for {self.epochs} epochs...")
for epoch in range(self.start_epoch, self.epochs): for epoch in range(self.start_epoch, self.epochs):
self.epoch = epoch self.epoch = epoch
self.trigger_callbacks("on_train_epoch_start") self.run_callbacks("on_train_epoch_start")
self.model.train() self.model.train()
if rank != -1: if rank != -1:
self.train_loader.sampler.set_epoch(epoch) self.train_loader.sampler.set_epoch(epoch)
@ -242,7 +242,7 @@ class BaseTrainer:
self.tloss = None self.tloss = None
self.optimizer.zero_grad() self.optimizer.zero_grad()
for i, batch in pbar: for i, batch in pbar:
self.trigger_callbacks("on_train_batch_start") self.run_callbacks("on_train_batch_start")
# Update dataloader attributes (optional) # Update dataloader attributes (optional)
if epoch == (self.epochs - self.args.close_mosaic) and hasattr(self.train_loader.dataset, 'mosaic'): if epoch == (self.epochs - self.args.close_mosaic) and hasattr(self.train_loader.dataset, 'mosaic'):
@ -287,35 +287,34 @@ class BaseTrainer:
pbar.set_description( pbar.set_description(
('%11s' * 2 + '%11.4g' * (2 + loss_len)) % ('%11s' * 2 + '%11.4g' * (2 + loss_len)) %
(f'{epoch + 1}/{self.epochs}', mem, *losses, batch["cls"].shape[0], batch["img"].shape[-1])) (f'{epoch + 1}/{self.epochs}', mem, *losses, batch["cls"].shape[0], batch["img"].shape[-1]))
self.trigger_callbacks('on_batch_end') self.run_callbacks('on_batch_end')
if self.args.plots and ni < 3: if self.args.plots and ni < 3:
self.plot_training_samples(batch, ni) self.plot_training_samples(batch, ni)
self.trigger_callbacks("on_train_batch_end") self.run_callbacks("on_train_batch_end")
lr = {f"lr{ir}": x['lr'] for ir, x in enumerate(self.optimizer.param_groups)} # for loggers lr = {f"lr{ir}": x['lr'] for ir, x in enumerate(self.optimizer.param_groups)} # for loggers
self.scheduler.step() self.scheduler.step()
self.trigger_callbacks("on_train_epoch_end") self.run_callbacks("on_train_epoch_end")
if rank in {-1, 0}: if rank in {-1, 0}:
# Validation # Validation
self.trigger_callbacks('on_val_start')
self.ema.update_attr(self.model, include=['yaml', 'nc', 'args', 'names', 'stride', 'class_weights']) self.ema.update_attr(self.model, include=['yaml', 'nc', 'args', 'names', 'stride', 'class_weights'])
final_epoch = (epoch + 1 == self.epochs) final_epoch = (epoch + 1 == self.epochs)
if self.args.val or final_epoch: if self.args.val or final_epoch:
self.metrics, self.fitness = self.validate() self.metrics, self.fitness = self.validate()
self.trigger_callbacks('on_val_end')
self.save_metrics(metrics={**self.label_loss_items(self.tloss), **self.metrics, **lr}) self.save_metrics(metrics={**self.label_loss_items(self.tloss), **self.metrics, **lr})
# Save model # Save model
if self.args.save or (epoch + 1 == self.epochs): if self.args.save or (epoch + 1 == self.epochs):
self.save_model() self.save_model()
self.trigger_callbacks('on_model_save') self.run_callbacks('on_model_save')
tnow = time.time() tnow = time.time()
self.epoch_time = tnow - self.epoch_time_start self.epoch_time = tnow - self.epoch_time_start
self.epoch_time_start = tnow self.epoch_time_start = tnow
self.run_callbacks("on_fit_epoch_end")
# TODO: termination condition # TODO: termination condition
if rank in {-1, 0}: if rank in {-1, 0}:
@ -326,9 +325,9 @@ class BaseTrainer:
if self.args.plots: if self.args.plots:
self.plot_metrics() self.plot_metrics()
self.log(f"Results saved to {colorstr('bold', self.save_dir)}") self.log(f"Results saved to {colorstr('bold', self.save_dir)}")
self.trigger_callbacks('on_train_end') self.run_callbacks('on_train_end')
torch.cuda.empty_cache() torch.cuda.empty_cache()
self.trigger_callbacks('teardown') self.run_callbacks('teardown')
def save_model(self): def save_model(self):
ckpt = { ckpt = {
@ -470,7 +469,7 @@ class BaseTrainer:
self.validator.args.save_json = True self.validator.args.save_json = True
self.metrics = self.validator(model=f) self.metrics = self.validator(model=f)
self.metrics.pop('fitness', None) self.metrics.pop('fitness', None)
self.trigger_callbacks('on_val_end') self.run_callbacks('on_val_end')
def check_resume(self): def check_resume(self):
resume = self.args.resume resume = self.args.resume

@ -1,4 +1,5 @@
import json import json
from collections import defaultdict
from pathlib import Path from pathlib import Path
import torch import torch
@ -8,6 +9,7 @@ from tqdm import tqdm
from ultralytics.nn.autobackend import AutoBackend from ultralytics.nn.autobackend import AutoBackend
from ultralytics.yolo.data.utils import check_dataset, check_dataset_yaml from ultralytics.yolo.data.utils import check_dataset, check_dataset_yaml
from ultralytics.yolo.utils import DEFAULT_CONFIG, LOGGER, RANK, TQDM_BAR_FORMAT from ultralytics.yolo.utils import DEFAULT_CONFIG, LOGGER, RANK, TQDM_BAR_FORMAT
from ultralytics.yolo.utils.callbacks import default_callbacks
from ultralytics.yolo.utils.checks import check_imgsz from ultralytics.yolo.utils.checks import check_imgsz
from ultralytics.yolo.utils.files import increment_path from ultralytics.yolo.utils.files import increment_path
from ultralytics.yolo.utils.ops import Profile from ultralytics.yolo.utils.ops import Profile
@ -64,12 +66,18 @@ class BaseValidator:
exist_ok=self.args.exist_ok if RANK in {-1, 0} else True) exist_ok=self.args.exist_ok if RANK in {-1, 0} else True)
(self.save_dir / 'labels' if self.args.save_txt else self.save_dir).mkdir(parents=True, exist_ok=True) (self.save_dir / 'labels' if self.args.save_txt else self.save_dir).mkdir(parents=True, exist_ok=True)
# callbacks
self.callbacks = defaultdict(list)
for callback, func in default_callbacks.items():
self.add_callback(callback, func)
@smart_inference_mode() @smart_inference_mode()
def __call__(self, trainer=None, model=None): def __call__(self, trainer=None, model=None):
""" """
Supports validation of a pre-trained model if passed or a model being trained Supports validation of a pre-trained model if passed or a model being trained
if trainer is passed (trainer gets priority). if trainer is passed (trainer gets priority).
""" """
self.run_callbacks('on_val_start')
self.training = trainer is not None self.training = trainer is not None
if self.training: if self.training:
self.device = trainer.device self.device = trainer.device
@ -116,6 +124,7 @@ class BaseValidator:
self.init_metrics(de_parallel(model)) self.init_metrics(de_parallel(model))
self.jdict = [] # empty before each val self.jdict = [] # empty before each val
for batch_i, batch in enumerate(bar): for batch_i, batch in enumerate(bar):
self.run_callbacks('on_val_batch_start')
self.batch_i = batch_i self.batch_i = batch_i
# pre-process # pre-process
with dt[0]: with dt[0]:
@ -139,10 +148,12 @@ class BaseValidator:
self.plot_val_samples(batch, batch_i) self.plot_val_samples(batch, batch_i)
self.plot_predictions(batch, preds, batch_i) self.plot_predictions(batch, preds, batch_i)
self.run_callbacks('on_val_batch_end')
stats = self.get_stats() stats = self.get_stats()
self.check_stats(stats) self.check_stats(stats)
self.print_results() self.print_results()
self.speed = tuple(x.t / len(self.dataloader.dataset) * 1E3 for x in dt) # speeds per image self.speed = tuple(x.t / len(self.dataloader.dataset) * 1E3 for x in dt) # speeds per image
self.run_callbacks('on_val_end')
if self.training: if self.training:
model.float() model.float()
return {**stats, **trainer.label_loss_items(self.loss.cpu() / len(self.dataloader), prefix="val")} return {**stats, **trainer.label_loss_items(self.loss.cpu() / len(self.dataloader), prefix="val")}
@ -156,6 +167,22 @@ class BaseValidator:
stats = self.eval_json(stats) # update stats stats = self.eval_json(stats) # update stats
return stats return stats
def add_callback(self, event: str, callback):
"""
appends the given callback
"""
self.callbacks[event].append(callback)
def set_callback(self, event: str, callback):
"""
overrides the existing callbacks with the given callback
"""
self.callbacks[event] = [callback]
def run_callbacks(self, event: str):
for callback in self.callbacks.get(event, []):
callback(self)
def get_dataloader(self, dataset_path, batch_size): def get_dataloader(self, dataset_path, batch_size):
raise NotImplementedError("get_dataloader function not implemented for this validator") raise NotImplementedError("get_dataloader function not implemented for this validator")

@ -1,3 +1,7 @@
# Ultralytics YOLO base callbacks
# Trainer callbacks ----------------------------------------------------------------------------------------------------
def on_pretrain_routine_start(trainer): def on_pretrain_routine_start(trainer):
pass pass
@ -34,47 +38,71 @@ def on_train_epoch_end(trainer):
pass pass
def on_val_start(trainer): def on_fit_epoch_end(trainer):
pass pass
def on_val_batch_start(trainer): def on_model_save(trainer):
pass pass
def on_val_image_end(trainer): def on_train_end(trainer):
pass pass
def on_val_batch_end(trainer): def on_params_update(trainer):
pass pass
def on_val_end(trainer): def teardown(trainer):
pass pass
def on_fit_epoch_end(trainer): # Validator callbacks --------------------------------------------------------------------------------------------------
def on_val_start(validator):
pass pass
def on_model_save(trainer): def on_val_batch_start(validator):
pass pass
def on_train_end(trainer): def on_val_batch_end(validator):
pass pass
def on_params_update(trainer): def on_val_end(validator):
pass pass
def teardown(trainer): # Predictor callbacks --------------------------------------------------------------------------------------------------
def on_predict_start(predictor):
pass
def on_predict_batch_start(predictor):
pass
def on_predict_batch_end(predictor):
pass
def on_predict_end(predictor):
pass
# Exporter callbacks ---------------------------------------------------------------------------------------------------
def on_export_start(exporter):
pass
def on_export_end(exporter):
pass pass
default_callbacks = { default_callbacks = {
# Run in trainer
'on_pretrain_routine_start': on_pretrain_routine_start, 'on_pretrain_routine_start': on_pretrain_routine_start,
'on_pretrain_routine_end': on_pretrain_routine_end, 'on_pretrain_routine_end': on_pretrain_routine_end,
'on_train_start': on_train_start, 'on_train_start': on_train_start,
@ -84,16 +112,27 @@ default_callbacks = {
'on_before_zero_grad': on_before_zero_grad, 'on_before_zero_grad': on_before_zero_grad,
'on_train_batch_end': on_train_batch_end, 'on_train_batch_end': on_train_batch_end,
'on_train_epoch_end': on_train_epoch_end, 'on_train_epoch_end': on_train_epoch_end,
'on_val_start': on_val_start,
'on_val_batch_start': on_val_batch_start,
'on_val_image_end': on_val_image_end,
'on_val_batch_end': on_val_batch_end,
'on_val_end': on_val_end,
'on_fit_epoch_end': on_fit_epoch_end, # fit = train + val 'on_fit_epoch_end': on_fit_epoch_end, # fit = train + val
'on_model_save': on_model_save, 'on_model_save': on_model_save,
'on_train_end': on_train_end, 'on_train_end': on_train_end,
'on_params_update': on_params_update, 'on_params_update': on_params_update,
'teardown': teardown} 'teardown': teardown,
# Run in validator
'on_val_start': on_val_start,
'on_val_batch_start': on_val_batch_start,
'on_val_batch_end': on_val_batch_end,
'on_val_end': on_val_end,
# Run in predictor
'on_predict_start': on_predict_start,
'on_predict_batch_start': on_predict_batch_start,
'on_predict_batch_end': on_predict_batch_end,
'on_predict_end': on_predict_end,
# Run in exporter
'on_export_start': on_export_start,
'on_export_end': on_export_end}
def add_integration_callbacks(trainer): def add_integration_callbacks(trainer):

@ -18,7 +18,7 @@ def _log_images(imgs_dict, group="", step=0):
def on_pretrain_routine_start(trainer): def on_pretrain_routine_start(trainer):
# TODO: reuse existing task # TODO: reuse existing task
task = Task.init(project_name=trainer.args.project if trainer.args.project != 'runs/train' else 'YOLOv8', task = Task.init(project_name=trainer.args.project or "YOLOv8",
task_name=trainer.args.name, task_name=trainer.args.name,
tags=['YOLOv8'], tags=['YOLOv8'],
output_uri=True, output_uri=True,
@ -32,7 +32,7 @@ def on_train_epoch_end(trainer):
_log_images({f.stem: str(f) for f in trainer.save_dir.glob('train_batch*.jpg')}, "Mosaic", trainer.epoch) _log_images({f.stem: str(f) for f in trainer.save_dir.glob('train_batch*.jpg')}, "Mosaic", trainer.epoch)
def on_val_end(trainer): def on_fit_epoch_end(trainer):
if trainer.epoch == 0: if trainer.epoch == 0:
model_info = { model_info = {
"Parameters": get_num_params(trainer.model), "Parameters": get_num_params(trainer.model),
@ -50,5 +50,5 @@ def on_train_end(trainer):
callbacks = { callbacks = {
"on_pretrain_routine_start": on_pretrain_routine_start, "on_pretrain_routine_start": on_pretrain_routine_start,
"on_train_epoch_end": on_train_epoch_end, "on_train_epoch_end": on_train_epoch_end,
"on_val_end": on_val_end, "on_fit_epoch_end": on_fit_epoch_end,
"on_train_end": on_train_end} if clearml else {} "on_train_end": on_train_end} if clearml else {}

@ -17,11 +17,11 @@ def on_batch_end(trainer):
_log_scalars(trainer.label_loss_items(trainer.tloss, prefix="train"), trainer.epoch + 1) _log_scalars(trainer.label_loss_items(trainer.tloss, prefix="train"), trainer.epoch + 1)
def on_val_end(trainer): def on_fit_epoch_end(trainer):
_log_scalars(trainer.metrics, trainer.epoch + 1) _log_scalars(trainer.metrics, trainer.epoch + 1)
callbacks = { callbacks = {
"on_pretrain_routine_start": on_pretrain_routine_start, "on_pretrain_routine_start": on_pretrain_routine_start,
"on_val_end": on_val_end, "on_fit_epoch_end": on_fit_epoch_end,
"on_batch_end": on_batch_end} "on_batch_end": on_batch_end}

@ -9,12 +9,11 @@ except (ImportError, AssertionError):
def on_pretrain_routine_start(trainer): def on_pretrain_routine_start(trainer):
wandb.init(project=trainer.args.project if trainer.args.project != 'runs/train' else 'YOLOv8', wandb.init(project=trainer.args.project or "YOLOv8", name=trainer.args.name, config=dict(
name=trainer.args.name, trainer.args)) if not wandb.run else wandb.run
config=dict(trainer.args)) if not wandb.run else wandb.run
def on_val_end(trainer): def on_fit_epoch_end(trainer):
wandb.run.log(trainer.metrics, step=trainer.epoch + 1) wandb.run.log(trainer.metrics, step=trainer.epoch + 1)
if trainer.epoch == 0: if trainer.epoch == 0:
model_info = { model_info = {
@ -42,5 +41,5 @@ def on_train_end(trainer):
callbacks = { callbacks = {
"on_pretrain_routine_start": on_pretrain_routine_start, "on_pretrain_routine_start": on_pretrain_routine_start,
"on_train_epoch_end": on_train_epoch_end, "on_train_epoch_end": on_train_epoch_end,
"on_val_end": on_val_end, "on_fit_epoch_end": on_fit_epoch_end,
"on_train_end": on_train_end} if wandb else {} "on_train_end": on_train_end} if wandb else {}

Loading…
Cancel
Save