Fix ClearML Mosaic callback to 'on_train_epoch_end' (#92)

single_channel
Glenn Jocher 2 years ago committed by GitHub
parent 0298821467
commit 249dfbdc05

@@ -47,6 +47,7 @@ class BaseValidator:
             model = model.half() if self.args.half else model.float()
             self.model = model
             self.loss = torch.zeros_like(trainer.loss_items, device=trainer.device)
+            self.args.plots = trainer.epoch == trainer.epochs - 1  # always plot final epoch
         else:
             assert model is not None, "Either trainer or model is needed for validation"
             self.device = select_device(self.args.device, self.args.batch_size)
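The added line is a simple toggle; a minimal sketch of its effect, with an assumed epoch count, showing that validation plots are produced only on the final epoch:

# Sketch of the final-epoch plot toggle (epoch count assumed)
epochs = 100
for epoch in (0, 50, 99):
    plots = epoch == epochs - 1  # -> False, False, True; only the last epoch plots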

@@ -1,6 +1,3 @@
-import os
-from pathlib import Path
-
 from ultralytics.yolo.utils.torch_utils import get_flops, get_num_params
 
 try:
@@ -30,12 +27,9 @@ def on_train_start(trainer):
     task.connect(dict(trainer.args), name='General')
 
 
-def on_epoch_start(trainer):
+def on_train_epoch_end(trainer):
     if trainer.epoch == 1:
-        plots = [filename for filename in os.listdir(trainer.save_dir) if filename.startswith("train_batch")]
-        imgs_dict = {f"train_batch_{i}": Path(trainer.save_dir) / img for i, img in enumerate(plots)}
-        if imgs_dict:
-            _log_images(imgs_dict, "Mosaic", trainer.epoch)
+        _log_images({f.stem: str(f) for f in trainer.save_dir.glob('train_batch*.jpg')}, "Mosaic", trainer.epoch)
 
 
 def on_val_end(trainer):
@@ -55,6 +49,6 @@ def on_train_end(trainer):
 
 callbacks = {
     "on_train_start": on_train_start,
-    "on_epoch_start": on_epoch_start,
+    "on_train_epoch_end": on_train_epoch_end,
     "on_val_end": on_val_end,
     "on_train_end": on_train_end} if clearml else {}

@@ -343,7 +343,7 @@ def compute_ap(recall, precision):
     return ap, mpre, mrec
 
 
-def ap_per_class(tp, conf, pred_cls, target_cls, plot=False, save_dir='.', names=(), eps=1e-16, prefix=""):
+def ap_per_class(tp, conf, pred_cls, target_cls, plot=False, save_dir=Path(), names=(), eps=1e-16, prefix=""):
     """ Compute the average precision, given the recall and precision curves.
     Source: https://github.com/rafaelpadilla/Object-Detection-Metrics.
     # Arguments
@@ -398,10 +398,10 @@ def ap_per_class(tp, conf, pred_cls, target_cls, plot=False, save_dir='.', names
     names = [v for k, v in names.items() if k in unique_classes]  # list: only classes that have data
     names = dict(enumerate(names))  # to dict
     if plot:
-        plot_pr_curve(px, py, ap, Path(save_dir) / f'{prefix}PR_curve.png', names)
-        plot_mc_curve(px, f1, Path(save_dir) / f'{prefix}F1_curve.png', names, ylabel='F1')
-        plot_mc_curve(px, p, Path(save_dir) / f'{prefix}P_curve.png', names, ylabel='Precision')
-        plot_mc_curve(px, r, Path(save_dir) / f'{prefix}R_curve.png', names, ylabel='Recall')
+        plot_pr_curve(px, py, ap, save_dir / f'{prefix}PR_curve.png', names)
+        plot_mc_curve(px, f1, save_dir / f'{prefix}F1_curve.png', names, ylabel='F1')
+        plot_mc_curve(px, p, save_dir / f'{prefix}P_curve.png', names, ylabel='Precision')
+        plot_mc_curve(px, r, save_dir / f'{prefix}R_curve.png', names, ylabel='Recall')
 
     i = smooth(f1.mean(0), 0.1).argmax()  # max F1 index
     p, r, f1 = p[:, i], r[:, i], f1[:, i]

@@ -63,16 +63,15 @@ class DetectionTrainer(BaseTrainer):
         return dict(zip(keys, loss_items)) if loss_items is not None else keys
 
     def progress_string(self):
-        return ('\n' + '%11s' * 7) % \
-            ('Epoch', 'GPU_mem', *self.loss_names, 'Instances', 'Size')
+        return ('\n' + '%11s' * 7) % ('Epoch', 'GPU_mem', *self.loss_names, 'Instances', 'Size')
 
     def plot_training_samples(self, batch, ni):
-        images = batch["img"]
-        cls = batch["cls"].squeeze(-1)
-        bboxes = batch["bboxes"]
-        paths = batch["im_file"]
-        batch_idx = batch["batch_idx"]
-        plot_images(images, batch_idx, cls, bboxes, paths=paths, fname=self.save_dir / f"train_batch{ni}.jpg")
+        plot_images(images=batch["img"],
+                    batch_idx=batch["batch_idx"],
+                    cls=batch["cls"].squeeze(-1),
+                    bboxes=batch["bboxes"],
+                    paths=batch["im_file"],
+                    fname=self.save_dir / f"train_batch{ni}.jpg")
 
     def plot_metrics(self):
         plot_results(file=self.csv)  # save results.png
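The collapsed return statement is pure reformatting; a minimal sketch of what the format string produces, with assumed detection loss names:

loss_names = ("box_loss", "cls_loss", "dfl_loss")  # assumed loss names
header = ("\n" + "%11s" * 7) % ("Epoch", "GPU_mem", *loss_names, "Instances", "Size")
print(header)  # seven right-aligned, 11-character-wide column titles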

@@ -214,8 +214,7 @@ class SegmentationTrainer(DetectionTrainer):
         return dict(zip(keys, loss_items)) if loss_items is not None else keys
 
     def progress_string(self):
-        return ('\n' + '%11s' * 8) % \
-            ('Epoch', 'GPU_mem', *self.loss_names, 'Instances', 'Size')
+        return ('\n' + '%11s' * 8) % ('Epoch', 'GPU_mem', *self.loss_names, 'Instances', 'Size')
 
     def plot_training_samples(self, batch, ni):
         images = batch["img"]
