Metrics and loss structure (#28)
Co-authored-by: Ayush Chaurasia <ayush.chuararsia@gmail.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
This commit is contained in:
@ -28,20 +28,11 @@ DEFAULT_CONFIG = "defaults.yaml"
|
||||
|
||||
class BaseTrainer:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model: str,
|
||||
data: str,
|
||||
criterion, # Should we create our own base loss classes? yolo.losses -> v8.losses.clfLoss
|
||||
validator=None,
|
||||
config=CONFIG_PATH_ABS / DEFAULT_CONFIG):
|
||||
def __init__(self, config=CONFIG_PATH_ABS / DEFAULT_CONFIG):
|
||||
self.console = LOGGER
|
||||
self.model = model
|
||||
self.data = data
|
||||
self.criterion = criterion # ComputeLoss object TODO: create yolo.Loss classes
|
||||
self.validator = val # Dummy validator
|
||||
self.model, self.data, self.train, self.hyps = self._get_config(config)
|
||||
self.validator = None
|
||||
self.callbacks = defaultdict(list)
|
||||
self.train, self.hyps = self._get_config(config)
|
||||
self.console.info(f"Training config: \n train: \n {self.train} \n hyps: \n {self.hyps}") # to debug
|
||||
# Directories
|
||||
self.save_dir = utils.increment_path(Path(self.train.project) / self.train.name, exist_ok=self.train.exist_ok)
|
||||
@ -57,7 +48,7 @@ class BaseTrainer:
|
||||
self.console.info(f"running on device {self.device}")
|
||||
self.scaler = amp.GradScaler(enabled=self.device.type != 'cpu')
|
||||
|
||||
# Model and Dataloaders. TBD: Should we move this inside trainer?
|
||||
# Model and Dataloaders.
|
||||
self.trainset, self.testset = self.get_dataset() # initialize dataset before as nc is needed for model
|
||||
self.model = self.get_model()
|
||||
self.model = self.model.to(self.device)
|
||||
@ -80,9 +71,9 @@ class BaseTrainer:
|
||||
try:
|
||||
if isinstance(config, (str, Path)):
|
||||
config = OmegaConf.load(config)
|
||||
return config.train, config.hyps
|
||||
return config.model, config.data, config.train, config.hyps
|
||||
except KeyError as e:
|
||||
raise Exception("Missing key(s) in config") from e
|
||||
raise KeyError("Missing key(s) in config") from e
|
||||
|
||||
def add_callback(self, onevent: str, callback):
|
||||
"""
|
||||
@ -131,10 +122,9 @@ class BaseTrainer:
|
||||
self.train_loader = self.get_dataloader(self.trainset, batch_size=self.train.batch_size, rank=rank)
|
||||
if rank in {0, -1}:
|
||||
print(" Creating testloader rank :", rank)
|
||||
# self.test_loader = self.get_dataloader(self.testset,
|
||||
# batch_size=self.train.batch_size*2,
|
||||
# rank=rank)
|
||||
# print("created testloader :", rank)
|
||||
self.test_loader = self.get_dataloader(self.testset, batch_size=self.train.batch_size * 2, rank=rank)
|
||||
self.validator = self.get_validator()
|
||||
print("created testloader :", rank)
|
||||
|
||||
def _do_train(self, rank, world_size):
|
||||
if world_size > 1:
|
||||
@ -235,11 +225,8 @@ class BaseTrainer:
|
||||
"""
|
||||
pass
|
||||
|
||||
def set_criterion(self, criterion):
|
||||
"""
|
||||
:param criterion: yolo.Loss object.
|
||||
"""
|
||||
self.criterion = criterion
|
||||
def get_validator(self):
|
||||
pass
|
||||
|
||||
def optimizer_step(self):
|
||||
self.scaler.unscale_(self.optimizer) # unscale gradients
|
||||
@ -265,6 +252,12 @@ class BaseTrainer:
|
||||
if not self.best_fitness or self.best_fitness < self.fitness:
|
||||
self.best_fitness = self.fitness
|
||||
|
||||
def build_targets(self, preds, targets):
|
||||
pass
|
||||
|
||||
def criterion(self, preds, targets):
|
||||
pass
|
||||
|
||||
def progress_string(self):
|
||||
"""
|
||||
Returns progress string depending on task type.
|
||||
|
@ -0,0 +1,105 @@
|
||||
import logging
|
||||
|
||||
import torch
|
||||
from tqdm import tqdm
|
||||
|
||||
from ultralytics.yolo.utils import Profile, select_device
|
||||
|
||||
|
||||
class BaseValidator:
|
||||
"""
|
||||
Base validator class.
|
||||
"""
|
||||
|
||||
def __init__(self, dataloader, device='', half=False, pbar=None, logger=None):
|
||||
self.dataloader = dataloader
|
||||
self.half = half
|
||||
self.device = select_device(device, dataloader.batch_size)
|
||||
self.pbar = pbar
|
||||
self.logger = logger or logging.getLogger()
|
||||
|
||||
def __call__(self, trainer=None, model=None):
|
||||
"""
|
||||
Supports validation of a pre-trained model if passed or a model being trained
|
||||
if trainer is passed (trainer gets priority).
|
||||
"""
|
||||
training = trainer is not None
|
||||
# trainer = trainer or self.trainer_class.get_trainer()
|
||||
assert training or model is not None, "Either trainer or model is needed for validation"
|
||||
if training:
|
||||
model = trainer.model
|
||||
self.half &= self.device.type != 'cpu'
|
||||
model = model.half() if self.half else model
|
||||
else: # TODO: handle this when detectMultiBackend is supported
|
||||
# model = DetectMultiBacked(model)
|
||||
pass
|
||||
|
||||
model.eval()
|
||||
dt = Profile(), Profile(), Profile(), Profile()
|
||||
loss = 0
|
||||
n_batches = len(self.dataloader)
|
||||
desc = self.set_desc()
|
||||
bar = tqdm(self.dataloader, desc, n_batches, not training, bar_format='{l_bar}{bar:10}{r_bar}{bar:-10b}')
|
||||
self.init_metrics()
|
||||
with torch.cuda.amp.autocast(enabled=self.device.type != 'cpu'):
|
||||
for images, labels in bar:
|
||||
# pre-process
|
||||
with dt[0]:
|
||||
images, labels = self.preprocess_batch(images, labels)
|
||||
|
||||
# inference
|
||||
with dt[1]:
|
||||
preds = model(images)
|
||||
# TODO: remember to add native augmentation support when implementing model, like:
|
||||
# preds, train_out = model(im, augment=augment)
|
||||
|
||||
# loss
|
||||
with dt[2]:
|
||||
if training:
|
||||
loss += trainer.criterion(preds, labels) / images.shape[0]
|
||||
|
||||
# pre-process predictions
|
||||
with dt[3]:
|
||||
preds = self.preprocess_preds(preds)
|
||||
|
||||
self.update_metrics(preds, labels)
|
||||
|
||||
stats = self.get_stats()
|
||||
self.check_stats(stats)
|
||||
|
||||
self.print_results()
|
||||
|
||||
# print speeds
|
||||
if not training:
|
||||
t = tuple(x.t / len(self.dataloader.dataset.samples) * 1E3 for x in dt) # speeds per image
|
||||
# shape = (self.dataloader.batch_size, 3, imgsz, imgsz)
|
||||
self.logger.info(
|
||||
'Speed: %.1fms pre-process, %.1fms inference, %.1fms loss, %.1fms post-process per image at shape ' % t)
|
||||
|
||||
# TODO: implement save json
|
||||
|
||||
return stats
|
||||
|
||||
def preprocess_batch(self, images, labels):
|
||||
return images.to(self.device, non_blocking=True), labels.to(self.device)
|
||||
|
||||
def preprocess_preds(self, preds):
|
||||
return preds
|
||||
|
||||
def init_metrics(self):
|
||||
pass
|
||||
|
||||
def update_metrics(self, preds, targets):
|
||||
pass
|
||||
|
||||
def get_stats(self):
|
||||
pass
|
||||
|
||||
def check_stats(self, stats):
|
||||
pass
|
||||
|
||||
def print_results(self):
|
||||
pass
|
||||
|
||||
def set_desc(self):
|
||||
pass
|
||||
|
Reference in New Issue
Block a user