Update metrics names (#85)

This commit is contained in:
Glenn Jocher
2022-12-24 02:32:24 +01:00
committed by GitHub
parent 6432afc5f9
commit 248d54ca03
9 changed files with 30 additions and 36 deletions

View File

@ -46,6 +46,7 @@ class DetectionTrainer(BaseTrainer):
return model
def get_validator(self):
self.loss_names = 'box_loss', 'obj_loss', 'cls_loss'
return v8.detect.DetectionValidator(self.test_loader,
save_dir=self.save_dir,
logger=self.console,
@ -190,15 +191,14 @@ class DetectionTrainer(BaseTrainer):
loss = lbox + lobj + lcls
return loss * bs, torch.cat((lbox, lobj, lcls)).detach()
# TODO: improve from API users perspective
def label_loss_items(self, loss_items=None, prefix="train"):
# We should just use named tensors here in future
keys = [f"{prefix}/lbox", f"{prefix}/lobj", f"{prefix}/lcls"]
keys = [f"{prefix}/{x}" for x in self.loss_names]
return dict(zip(keys, loss_items)) if loss_items is not None else keys
def progress_string(self):
return ('\n' + '%11s' * 6) % \
('Epoch', 'GPU_mem', 'box_loss', 'obj_loss', 'cls_loss', 'Size')
('Epoch', 'GPU_mem', *self.loss_names, 'Size')
def plot_training_samples(self, batch, ni):
images = batch["img"]