import logging

import torch
from tqdm import tqdm

from ultralytics.yolo.utils import Profile, select_device


class BaseValidator:
    """
    Base validator class.

    Runs a validation pass over ``dataloader`` and delegates the task-specific steps to
    overridable hooks. The hooks are batch pre-processing, prediction post-processing,
    metric bookkeeping and result printing. In this base class they are identity
    functions or no-ops.
    """

    def __init__(self, dataloader, device='', half=False, pbar=None, logger=None):
        """
        Args:
            dataloader: iterable yielding ``(images, labels)`` batches. It must expose
                ``batch_size`` and ``dataset`` and support ``len()``.
            device: device spec forwarded to ``select_device`` together with the batch size.
            half: request half-precision inference. It is honoured only when validating a
                trainer's model on a non-CPU device.
            pbar: stored on the instance; not used by this base class.
            logger: logger used for the speed summary; defaults to the root logger.
        """
        self.dataloader = dataloader
        self.half = half
        self.device = select_device(device, dataloader.batch_size)
        self.pbar = pbar
        self.logger = logger or logging.getLogger()

    def __call__(self, trainer=None, model=None):
        """
        Supports validation of a pre-trained model if passed or a model being trained if trainer is passed
        (trainer gets priority).

        Args:
            trainer: object exposing ``model`` and ``criterion``; takes priority over ``model``.
            model: model to validate when no trainer is given.

        Returns:
            Whatever the ``get_stats`` hook returns (``None`` in this base class).

        Raises:
            AssertionError: if neither ``trainer`` nor ``model`` is given.
        """
        training = trainer is not None
        # trainer = trainer or self.trainer_class.get_trainer()
        assert training or model is not None, "Either trainer or model is needed for validation"
        if training:
            model = trainer.model
            self.half &= self.device.type != 'cpu'
            # nn.Module.half() converts parameters in place, so this also affects trainer.model.
            # It is converted back to float in the `finally` below.
            model = model.half() if self.half else model
        else:
            # TODO: handle this when detectMultiBackend is supported
            # model = DetectMultiBacked(model)
            pass

        was_training = model.training
        model.eval()

        # Accumulated time for: pre-process, inference, loss, post-process.
        dt = Profile(), Profile(), Profile(), Profile()
        # NOTE(review): the accumulated loss is neither returned nor stored. Confirm whether
        # callers are expected to receive it.
        loss = 0
        n_batches = len(self.dataloader)
        desc = self.set_desc()
        bar = tqdm(self.dataloader, desc, n_batches, not training, bar_format='{l_bar}{bar:10}{r_bar}{bar:-10b}')
        self.init_metrics()
        shape = None  # shape of the last pre-processed image batch, reported in the speed log
        try:
            # Validation never back-propagates. Without no_grad, every forward pass (and the
            # running loss sum) would keep its autograd graph alive across batches.
            with torch.no_grad(), torch.cuda.amp.autocast(enabled=self.device.type != 'cpu'):
                for images, labels in bar:
                    # pre-process
                    with dt[0]:
                        images, labels = self.preprocess_batch(images, labels)

                    # inference
                    with dt[1]:
                        preds = model(images)
                        # TODO: remember to add native augmentation support when implementing model, like:
                        #  preds, train_out = model(im, augment=augment)

                    # loss
                    with dt[2]:
                        if training:
                            loss += trainer.criterion(preds, labels) / images.shape[0]

                    # pre-process predictions
                    with dt[3]:
                        preds = self.preprocess_preds(preds)

                    self.update_metrics(preds, labels)
                    shape = tuple(images.shape)
        finally:
            # Undo side effects on the caller's model: precision (half() was in-place) and mode.
            if training and self.half:
                model.float()
            model.train(was_training)

        stats = self.get_stats()
        self.check_stats(stats)

        self.print_results()

        # print speeds
        if not training:
            self._log_speeds(dt, shape)

        # TODO: implement save json

        return stats

    def _log_speeds(self, dt, shape):
        """Log the mean per-image time in milliseconds of each profiled stage in ``dt``."""
        n_images = max(len(self.dataloader.dataset), 1)  # avoid ZeroDivisionError on an empty dataset
        t = tuple(x.t / n_images * 1E3 for x in dt)  # speeds per image
        self.logger.info(
            'Speed: %.1fms pre-process, %.1fms inference, %.1fms loss, %.1fms post-process per image at shape %s',
            *t, shape)

    def preprocess_batch(self, images, labels):
        """Move a batch to the validation device; ``images`` is copied non-blocking."""
        return images.to(self.device, non_blocking=True), labels.to(self.device)

    def preprocess_preds(self, preds):
        """Hook to post-process raw model outputs; identity in the base class."""
        return preds

    def init_metrics(self):
        """Hook called once before iterating the dataloader; no-op in the base class."""
        pass

    def update_metrics(self, preds, targets):
        """Hook called once per batch with processed predictions and targets; no-op here."""
        pass

    def get_stats(self):
        """Hook returning the validation statistics that ``__call__`` returns; ``None`` here."""
        pass

    def check_stats(self, stats):
        """Hook to sanity-check the statistics from ``get_stats``; no-op in the base class."""
        pass

    def print_results(self):
        """Hook to report results after the loop; no-op in the base class."""
        pass

    def set_desc(self):
        """Hook returning the progress-bar description; ``None`` in the base class."""
        pass