Segmentation support & other enhancements (#40)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
This commit is contained in:
Ayush Chaurasia
2022-11-08 20:57:57 +05:30
committed by GitHub
parent c617ee1c79
commit f56c9bcc26
17 changed files with 1320 additions and 47 deletions

View File

@ -1,12 +1,17 @@
"""
Simple training loop; boilerplate that could apply to any arbitrary neural network.
"""
# TODOs
# 1. finish _set_model_attributes
# 2. allow num_class update for both pretrained and csv_loaded models
# 3. save
import os
import time
from collections import defaultdict
from datetime import datetime
from pathlib import Path
from telnetlib import TLS
from typing import Dict, Union
import torch
@ -52,6 +57,8 @@ class BaseTrainer:
# Model and Dataloaders.
self.trainset, self.testset = self.get_dataset(self.args.data)
if self.args.cfg is not None:
self.model = self.load_cfg(self.args.cfg)
if self.args.model is not None:
self.model = self.get_model(self.args.model, self.args.pretrained).to(self.device)
@ -133,6 +140,20 @@ class BaseTrainer:
self.test_loader = self.get_dataloader(self.testset, batch_size=self.args.batch_size * 2, rank=rank)
self.validator = self.get_validator()
print("created testloader :", rank)
self.console.info(self.progress_string())
def _set_model_attributes(self):
    # Placeholder: the body below is only a string literal, so calling this method
    # currently has no effect.
    # TODO: fix and use after self.data_dict is available
    # NOTE(review): the disabled snippet refers to names that are not defined inside
    # this method (nl, model, nc, hyp, dataset, names, device, labels_to_class_weights);
    # they must be resolved before the snippet is enabled.
    '''
    head = utils.torch_utils.de_parallel(self.model).model[-1]
    self.args.box *= 3 / head.nl  # scale to layers
    self.args.cls *= head.nc / 80 * 3 / head.nl  # scale to classes and layers
    self.args.obj *= (self.args.img_size / 640) ** 2 * 3 / nl  # scale to image size and layers
    model.nc = nc  # attach number of classes to model
    model.hyp = hyp  # attach hyperparameters to model
    model.class_weights = labels_to_class_weights(dataset.labels, nc).to(device) * nc  # attach class weights
    model.names = names
    '''
def _do_train(self, rank, world_size):
if world_size > 1:
@ -153,13 +174,17 @@ class BaseTrainer:
pbar = tqdm(enumerate(self.train_loader),
total=len(self.train_loader),
bar_format='{l_bar}{bar:10}{r_bar}{bar:-10b}')
tloss = 0
for i, (images, labels) in pbar:
tloss = None
for i, batch in pbar:
# img, label (classification)/ img, targets, paths, _, masks(detection)
# callback hook. on_batch_start
# forward
images, labels = self.preprocess_batch(images, labels)
self.loss = self.criterion(self.model(images), labels)
tloss = (tloss * i + self.loss.item()) / (i + 1)
batch = self.preprocess_batch(batch)
# TODO: warmup, multiscale
preds = self.model(batch["img"])
self.loss, self.loss_items = self.criterion(preds, batch)
tloss = (tloss * i + self.loss_items) / (i + 1) if tloss is not None else self.loss_items
# backward
self.model.zero_grad(set_to_none=True)
@ -170,9 +195,13 @@ class BaseTrainer:
self.trigger_callbacks('on_batch_end')
# log
mem = '%.3gG' % (torch.cuda.memory_reserved() / 1E9 if torch.cuda.is_available() else 0) # (GB)
mem = (torch.cuda.memory_reserved() / 1E9 if torch.cuda.is_available() else 0) # (GB)
loss_len = tloss.shape[0] if len(tloss.size()) else 1
losses = tloss if loss_len > 1 else torch.unsqueeze(tloss, 0)
if rank in {-1, 0}:
pbar.desc = f"{f'{epoch + 1}/{self.args.epochs}':>10}{mem:>10}{tloss:>12.3g}" + ' ' * 36
pbar.set_description(
(" {} " + "{:.3f} " * (2 + loss_len)).format(f'{epoch + 1}/{self.args.epochs}', mem, *losses,
batch["img"].shape[-1]))
if rank in [-1, 0]:
# validation
@ -240,6 +269,9 @@ class BaseTrainer:
return model
def load_cfg(self, cfg):
    """
    Build a model from a cfg file. Not supported by the base trainer.

    The `cfg` argument is not inspected; this implementation always raises
    NotImplementedError.
    """
    message = "This task trainer doesn't support loading cfg files"
    raise NotImplementedError(message)
def get_validator(self):
pass
@ -250,11 +282,11 @@ class BaseTrainer:
self.scaler.update()
self.optimizer.zero_grad()
def preprocess_batch(self, images, labels):
def preprocess_batch(self, batch):
"""
Allows custom preprocessing model inputs and ground truths depending on task type
"""
return images.to(self.device, non_blocking=True), labels.to(self.device)
return batch
def validate(self):
"""
@ -270,14 +302,17 @@ class BaseTrainer:
def build_targets(self, preds, targets):
pass
def criterion(self, preds, targets):
def criterion(self, preds, batch):
"""
Returns loss and individual loss items as Tensor
"""
pass
def progress_string(self):
    """
    Return the task-dependent progress string.

    The base implementation has nothing to report and yields an empty string.
    """
    return ''
def usage_help(self):
"""

View File

@ -1,8 +1,10 @@
import logging
import torch
from omegaconf import DictConfig, OmegaConf
from tqdm import tqdm
from ultralytics.yolo.engine.trainer import DEFAULT_CONFIG
from ultralytics.yolo.utils.ops import Profile
from ultralytics.yolo.utils.torch_utils import select_device
@ -12,12 +14,15 @@ class BaseValidator:
Base validator class.
"""
def __init__(self, dataloader, device='', half=False, pbar=None, logger=None):
def __init__(self, dataloader, pbar=None, logger=None, args=None):
self.dataloader = dataloader
self.half = half
self.device = select_device(device, dataloader.batch_size)
self.pbar = pbar
self.logger = logger or logging.getLogger()
self.args = args or OmegaConf.load(DEFAULT_CONFIG)
self.device = select_device(self.args.device, dataloader.batch_size)
self.cuda = self.device.type != 'cpu'
self.batch_i = None
self.training = True
def __call__(self, trainer=None, model=None):
"""
@ -25,45 +30,48 @@ class BaseValidator:
if trainer is passed (trainer gets priority).
"""
training = trainer is not None
self.training = training
# trainer = trainer or self.trainer_class.get_trainer()
assert training or model is not None, "Either trainer or model is needed for validation"
if training:
model = trainer.model
self.half &= self.device.type != 'cpu'
model = model.half() if self.half else model
self.args.half &= self.device.type != 'cpu'
model = model.half() if self.args.half else model
else: # TODO: handle this when detectMultiBackend is supported
# model = DetectMultiBackend(model)
pass
# TODO: implement init_model_attributes()
model.eval()
dt = Profile(), Profile(), Profile(), Profile()
loss = 0
n_batches = len(self.dataloader)
desc = self.set_desc()
desc = self.get_desc()
bar = tqdm(self.dataloader, desc, n_batches, not training, bar_format='{l_bar}{bar:10}{r_bar}{bar:-10b}')
self.init_metrics()
self.init_metrics(model)
with torch.cuda.amp.autocast(enabled=self.device.type != 'cpu'):
for images, labels in bar:
for batch_i, batch in enumerate(bar):
self.batch_i = batch_i
# pre-process
with dt[0]:
images, labels = self.preprocess_batch(images, labels)
batch = self.preprocess_batch(batch)
# inference
with dt[1]:
preds = model(images)
preds = model(batch["img"])
# TODO: remember to add native augmentation support when implementing model, like:
# preds, train_out = model(im, augment=augment)
# loss
with dt[2]:
if training:
loss += trainer.criterion(preds, labels) / images.shape[0]
loss += trainer.criterion(preds, batch)[0]
# pre-process predictions
with dt[3]:
preds = self.preprocess_preds(preds)
self.update_metrics(preds, labels)
self.update_metrics(preds, batch)
stats = self.get_stats()
self.check_stats(stats)
@ -81,8 +89,8 @@ class BaseValidator:
return stats
def preprocess_batch(self, images, labels):
return images.to(self.device, non_blocking=True), labels.to(self.device)
def preprocess_batch(self, batch):
return batch
def preprocess_preds(self, preds):
    """
    Hook applied to model predictions before metrics are updated.

    The base implementation hands back the very same object it was given.
    """
    unchanged = preds
    return unchanged
@ -90,7 +98,7 @@ class BaseValidator:
def init_metrics(self, model=None):
    """
    Hook for setting up metric state; the base implementation does nothing.

    Args:
        model: the model being validated. Optional, so that both the
            `self.init_metrics(model)` call made from `__call__` and older
            zero-argument calls are accepted.
    """
    # The previous signature took no `model`, so the call in `__call__` raised
    # TypeError unless a subclass overrode this method.
    pass
def update_metrics(self, preds, targets):
def update_metrics(self, preds, batch):
pass
def get_stats(self):
@ -102,5 +110,5 @@ class BaseValidator:
def print_results(self):
    """
    Hook for reporting validation results; the base implementation does nothing.
    """
    return None
def set_desc(self):
def get_desc(self):
pass