Metrics and loss structure (#28)
Co-authored-by: Ayush Chaurasia <ayush.chuararsia@gmail.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>single_channel
parent
d0b3c9812b
commit
c5cb76b356
@ -1,3 +1,4 @@
|
|||||||
from .engine.trainer import BaseTrainer
|
from .engine.trainer import BaseTrainer
|
||||||
|
from .engine.validator import BaseValidator
|
||||||
|
|
||||||
__all__ = ["BaseTrainer"] # allow simpler import
|
__all__ = ["BaseTrainer", "BaseValidator"] # allow simpler import
|
||||||
|
@ -0,0 +1,105 @@
|
|||||||
|
import logging
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
from ultralytics.yolo.utils import Profile, select_device
|
||||||
|
|
||||||
|
|
||||||
|
class BaseValidator:
|
||||||
|
"""
|
||||||
|
Base validator class.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, dataloader, device='', half=False, pbar=None, logger=None):
|
||||||
|
self.dataloader = dataloader
|
||||||
|
self.half = half
|
||||||
|
self.device = select_device(device, dataloader.batch_size)
|
||||||
|
self.pbar = pbar
|
||||||
|
self.logger = logger or logging.getLogger()
|
||||||
|
|
||||||
|
def __call__(self, trainer=None, model=None):
|
||||||
|
"""
|
||||||
|
Supports validation of a pre-trained model if passed or a model being trained
|
||||||
|
if trainer is passed (trainer gets priority).
|
||||||
|
"""
|
||||||
|
training = trainer is not None
|
||||||
|
# trainer = trainer or self.trainer_class.get_trainer()
|
||||||
|
assert training or model is not None, "Either trainer or model is needed for validation"
|
||||||
|
if training:
|
||||||
|
model = trainer.model
|
||||||
|
self.half &= self.device.type != 'cpu'
|
||||||
|
model = model.half() if self.half else model
|
||||||
|
else: # TODO: handle this when detectMultiBackend is supported
|
||||||
|
# model = DetectMultiBacked(model)
|
||||||
|
pass
|
||||||
|
|
||||||
|
model.eval()
|
||||||
|
dt = Profile(), Profile(), Profile(), Profile()
|
||||||
|
loss = 0
|
||||||
|
n_batches = len(self.dataloader)
|
||||||
|
desc = self.set_desc()
|
||||||
|
bar = tqdm(self.dataloader, desc, n_batches, not training, bar_format='{l_bar}{bar:10}{r_bar}{bar:-10b}')
|
||||||
|
self.init_metrics()
|
||||||
|
with torch.cuda.amp.autocast(enabled=self.device.type != 'cpu'):
|
||||||
|
for images, labels in bar:
|
||||||
|
# pre-process
|
||||||
|
with dt[0]:
|
||||||
|
images, labels = self.preprocess_batch(images, labels)
|
||||||
|
|
||||||
|
# inference
|
||||||
|
with dt[1]:
|
||||||
|
preds = model(images)
|
||||||
|
# TODO: remember to add native augmentation support when implementing model, like:
|
||||||
|
# preds, train_out = model(im, augment=augment)
|
||||||
|
|
||||||
|
# loss
|
||||||
|
with dt[2]:
|
||||||
|
if training:
|
||||||
|
loss += trainer.criterion(preds, labels) / images.shape[0]
|
||||||
|
|
||||||
|
# pre-process predictions
|
||||||
|
with dt[3]:
|
||||||
|
preds = self.preprocess_preds(preds)
|
||||||
|
|
||||||
|
self.update_metrics(preds, labels)
|
||||||
|
|
||||||
|
stats = self.get_stats()
|
||||||
|
self.check_stats(stats)
|
||||||
|
|
||||||
|
self.print_results()
|
||||||
|
|
||||||
|
# print speeds
|
||||||
|
if not training:
|
||||||
|
t = tuple(x.t / len(self.dataloader.dataset.samples) * 1E3 for x in dt) # speeds per image
|
||||||
|
# shape = (self.dataloader.batch_size, 3, imgsz, imgsz)
|
||||||
|
self.logger.info(
|
||||||
|
'Speed: %.1fms pre-process, %.1fms inference, %.1fms loss, %.1fms post-process per image at shape ' % t)
|
||||||
|
|
||||||
|
# TODO: implement save json
|
||||||
|
|
||||||
|
return stats
|
||||||
|
|
||||||
|
def preprocess_batch(self, images, labels):
|
||||||
|
return images.to(self.device, non_blocking=True), labels.to(self.device)
|
||||||
|
|
||||||
|
def preprocess_preds(self, preds):
|
||||||
|
return preds
|
||||||
|
|
||||||
|
def init_metrics(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def update_metrics(self, preds, targets):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def get_stats(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def check_stats(self, stats):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def print_results(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def set_desc(self):
|
||||||
|
pass
|
@ -0,0 +1,18 @@
|
|||||||
|
import torch
|
||||||
|
|
||||||
|
from ultralytics import yolo
|
||||||
|
|
||||||
|
|
||||||
|
class ClassificationValidator(yolo.BaseValidator):
|
||||||
|
|
||||||
|
def init_metrics(self):
|
||||||
|
self.correct = torch.tensor([])
|
||||||
|
|
||||||
|
def update_metrics(self, preds, targets):
|
||||||
|
correct_in_batch = (targets[:, None] == preds).float()
|
||||||
|
self.correct = torch.cat((self.correct, correct_in_batch))
|
||||||
|
|
||||||
|
def get_stats(self):
|
||||||
|
acc = torch.stack((self.correct[:, 0], self.correct.max(1).values), dim=1) # (top1, top5) accuracy
|
||||||
|
top1, top5 = acc.mean(0).tolist()
|
||||||
|
return {"top1": top1, "top5": top5, "fitness": top5}
|
Loading…
Reference in new issue