Fix Classification train logging (#157)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Ayush Chaurasia <ayush.chaurarsia@gmail.com>
This commit is contained in:
Glenn Jocher
2023-01-08 17:52:44 +01:00
committed by GitHub
parent d387359f74
commit e79ea1666c
7 changed files with 86 additions and 40 deletions

View File

@ -1,5 +1,4 @@
import hydra
import torch
from ultralytics.yolo.data import build_classification_dataloader
from ultralytics.yolo.engine.validator import BaseValidator
@ -13,8 +12,12 @@ class ClassificationValidator(BaseValidator):
super().__init__(dataloader, save_dir, pbar, logger, args)
self.metrics = ClassifyMetrics()
def get_desc(self):
return ('%22s' + '%11s' * 2) % ('classes', 'top1_acc', 'top5_acc')
def init_metrics(self, model):
self.correct = torch.tensor([], device=next(model.parameters()).device)
self.pred = []
self.targets = []
def preprocess(self, batch):
batch["img"] = batch["img"].to(self.device, non_blocking=True)
@ -23,17 +26,20 @@ class ClassificationValidator(BaseValidator):
return batch
def update_metrics(self, preds, batch):
targets = batch["cls"]
correct_in_batch = (targets[:, None] == preds).float()
self.correct = torch.cat((self.correct, correct_in_batch))
self.pred.append(preds.argsort(1, descending=True)[:, :5])
self.targets.append(batch["cls"])
def get_stats(self):
self.metrics.process(self.correct)
self.metrics.process(self.targets, self.pred)
return self.metrics.results_dict
def get_dataloader(self, dataset_path, batch_size):
return build_classification_dataloader(path=dataset_path, imgsz=self.args.imgsz, batch_size=batch_size)
def print_results(self):
pf = '%22s' + '%11.3g' * len(self.metrics.keys) # print format
self.logger.info(pf % ("all", self.metrics.top1, self.metrics.top5))
@hydra.main(version_base=None, config_path=str(DEFAULT_CONFIG.parent), config_name=DEFAULT_CONFIG.name)
def val(cfg):