Metrics and loss structure (#28)

Co-authored-by: Ayush Chaurasia <ayush.chuararsia@gmail.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
single_channel
Ayush Chaurasia 2 years ago committed by GitHub
parent d0b3c9812b
commit c5cb76b356
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -10,8 +10,8 @@ pip install . # (dev)
# pip install ultralytics (production)
```
### Usage
```python
import ultralytics
from ultralytics import HUB, YOLO

@ -1,3 +1,4 @@
from .engine.trainer import BaseTrainer
from .engine.validator import BaseValidator
__all__ = ["BaseTrainer"] # allow simpler import
__all__ = ["BaseTrainer", "BaseValidator"] # allow simpler import

@ -2,13 +2,10 @@ from itertools import repeat
from multiprocessing.pool import Pool
from pathlib import Path
import cv2
import numpy as np
import torch
import torchvision
from tqdm import tqdm
from ..utils.general import LOGGER, NUM_THREADS
from ..utils.general import NUM_THREADS
from .augment import *
from .base import BaseDataset
from .utils import BAR_FORMAT, HELP_URL, LOCAL_RANK, get_hash, img2label_paths, verify_image_label

@ -28,20 +28,11 @@ DEFAULT_CONFIG = "defaults.yaml"
class BaseTrainer:
def __init__(
self,
model: str,
data: str,
criterion, # Should we create our own base loss classes? yolo.losses -> v8.losses.clfLoss
validator=None,
config=CONFIG_PATH_ABS / DEFAULT_CONFIG):
def __init__(self, config=CONFIG_PATH_ABS / DEFAULT_CONFIG):
self.console = LOGGER
self.model = model
self.data = data
self.criterion = criterion # ComputeLoss object TODO: create yolo.Loss classes
self.validator = val # Dummy validator
self.model, self.data, self.train, self.hyps = self._get_config(config)
self.validator = None
self.callbacks = defaultdict(list)
self.train, self.hyps = self._get_config(config)
self.console.info(f"Training config: \n train: \n {self.train} \n hyps: \n {self.hyps}") # to debug
# Directories
self.save_dir = utils.increment_path(Path(self.train.project) / self.train.name, exist_ok=self.train.exist_ok)
@ -57,7 +48,7 @@ class BaseTrainer:
self.console.info(f"running on device {self.device}")
self.scaler = amp.GradScaler(enabled=self.device.type != 'cpu')
# Model and Dataloaders. TBD: Should we move this inside trainer?
# Model and Dataloaders.
self.trainset, self.testset = self.get_dataset() # initialize dataset before as nc is needed for model
self.model = self.get_model()
self.model = self.model.to(self.device)
@ -80,9 +71,9 @@ class BaseTrainer:
try:
if isinstance(config, (str, Path)):
config = OmegaConf.load(config)
return config.train, config.hyps
return config.model, config.data, config.train, config.hyps
except KeyError as e:
raise Exception("Missing key(s) in config") from e
raise KeyError("Missing key(s) in config") from e
def add_callback(self, onevent: str, callback):
"""
@ -131,10 +122,9 @@ class BaseTrainer:
self.train_loader = self.get_dataloader(self.trainset, batch_size=self.train.batch_size, rank=rank)
if rank in {0, -1}:
print(" Creating testloader rank :", rank)
# self.test_loader = self.get_dataloader(self.testset,
# batch_size=self.train.batch_size*2,
# rank=rank)
# print("created testloader :", rank)
self.test_loader = self.get_dataloader(self.testset, batch_size=self.train.batch_size * 2, rank=rank)
self.validator = self.get_validator()
print("created testloader :", rank)
def _do_train(self, rank, world_size):
if world_size > 1:
@ -235,11 +225,8 @@ class BaseTrainer:
"""
pass
def set_criterion(self, criterion):
"""
:param criterion: yolo.Loss object.
"""
self.criterion = criterion
def get_validator(self):
pass
def optimizer_step(self):
self.scaler.unscale_(self.optimizer) # unscale gradients
@ -265,6 +252,12 @@ class BaseTrainer:
if not self.best_fitness or self.best_fitness < self.fitness:
self.best_fitness = self.fitness
def build_targets(self, preds, targets):
pass
def criterion(self, preds, targets):
pass
def progress_string(self):
"""
Returns progress string depending on task type.

@ -0,0 +1,105 @@
import logging
import torch
from tqdm import tqdm
from ultralytics.yolo.utils import Profile, select_device
class BaseValidator:
"""
Base validator class.
"""
def __init__(self, dataloader, device='', half=False, pbar=None, logger=None):
self.dataloader = dataloader
self.half = half
self.device = select_device(device, dataloader.batch_size)
self.pbar = pbar
self.logger = logger or logging.getLogger()
def __call__(self, trainer=None, model=None):
"""
Supports validation of a pre-trained model if passed or a model being trained
if trainer is passed (trainer gets priority).
"""
training = trainer is not None
# trainer = trainer or self.trainer_class.get_trainer()
assert training or model is not None, "Either trainer or model is needed for validation"
if training:
model = trainer.model
self.half &= self.device.type != 'cpu'
model = model.half() if self.half else model
else: # TODO: handle this when detectMultiBackend is supported
# model = DetectMultiBacked(model)
pass
model.eval()
dt = Profile(), Profile(), Profile(), Profile()
loss = 0
n_batches = len(self.dataloader)
desc = self.set_desc()
bar = tqdm(self.dataloader, desc, n_batches, not training, bar_format='{l_bar}{bar:10}{r_bar}{bar:-10b}')
self.init_metrics()
with torch.cuda.amp.autocast(enabled=self.device.type != 'cpu'):
for images, labels in bar:
# pre-process
with dt[0]:
images, labels = self.preprocess_batch(images, labels)
# inference
with dt[1]:
preds = model(images)
# TODO: remember to add native augmentation support when implementing model, like:
# preds, train_out = model(im, augment=augment)
# loss
with dt[2]:
if training:
loss += trainer.criterion(preds, labels) / images.shape[0]
# pre-process predictions
with dt[3]:
preds = self.preprocess_preds(preds)
self.update_metrics(preds, labels)
stats = self.get_stats()
self.check_stats(stats)
self.print_results()
# print speeds
if not training:
t = tuple(x.t / len(self.dataloader.dataset.samples) * 1E3 for x in dt) # speeds per image
# shape = (self.dataloader.batch_size, 3, imgsz, imgsz)
self.logger.info(
'Speed: %.1fms pre-process, %.1fms inference, %.1fms loss, %.1fms post-process per image at shape ' % t)
# TODO: implement save json
return stats
def preprocess_batch(self, images, labels):
return images.to(self.device, non_blocking=True), labels.to(self.device)
def preprocess_preds(self, preds):
return preds
def init_metrics(self):
pass
def update_metrics(self, preds, targets):
pass
def get_stats(self):
pass
def check_stats(self, stats):
pass
def print_results(self):
pass
def set_desc(self):
pass

@ -1,4 +1,4 @@
from .general import WorkingDirectory, check_version, download, increment_path, save_yaml
from .general import Profile, WorkingDirectory, check_version, download, increment_path, save_yaml
from .torch_utils import LOCAL_RANK, RANK, WORLD_SIZE, DDP_model, select_device, torch_distributed_zero_first
__all__ = [
@ -8,6 +8,7 @@ __all__ = [
"WorkingDirectory",
"download",
"check_version",
"Profile",
# torch
"torch_distributed_zero_first",
"LOCAL_RANK",

@ -1,3 +1,5 @@
model: null
data: null
train:
epochs: 300
batch_size: 16

@ -5,6 +5,7 @@ import logging
import os
import platform
import subprocess
import time
import urllib
from itertools import repeat
from multiprocessing.pool import ThreadPool
@ -208,7 +209,7 @@ def increment_path(path, exist_ok=False, sep='', mkdir=False):
return path
def save_yaml(file='data.yaml', data={}):
def save_yaml(file='data.yaml', data=None):
# Single-line safe yaml saving
with open(file, 'w') as f:
yaml.safe_dump({k: str(v) if isinstance(v, Path) else v for k, v in data.items()}, f, sort_keys=False)
@ -278,7 +279,6 @@ class WorkingDirectory(contextlib.ContextDecorator):
def safe_download(file, url, url2=None, min_bytes=1E0, error_msg=''):
# Attempts to download file from url or url2, checks and removes incomplete downloads < min_bytes
from utils.general import LOGGER
file = Path(file)
assert_msg = f"Downloaded file '{file}' does not exist or size is < min_bytes={min_bytes}"
@ -301,7 +301,6 @@ def safe_download(file, url, url2=None, min_bytes=1E0, error_msg=''):
def attempt_download(file, repo='ultralytics/yolov5', release='v6.2'):
# Attempt file download from GitHub release assets if not found locally. release = 'latest', 'v6.2', etc.
from utils.general import LOGGER
def github_assets(repository, version='latest'):
# Return GitHub repo tag and assets (i.e. ['yolov5s.pt', 'yolov5m.pt', ...])
@ -351,3 +350,23 @@ def attempt_download(file, repo='ultralytics/yolov5', release='v6.2'):
def get_model(model: str):
# check for local weights
pass
class Profile(contextlib.ContextDecorator):
# YOLOv5 Profile class. Usage: @Profile() decorator or 'with Profile():' context manager
def __init__(self, t=0.0):
self.t = t
self.cuda = torch.cuda.is_available()
def __enter__(self):
self.start = self.time()
return self
def __exit__(self, type, value, traceback):
self.dt = self.time() - self.start # delta-time
self.t += self.dt # accumulate dt
def time(self):
if self.cuda:
torch.cuda.synchronize()
return time.time()

@ -2,6 +2,7 @@
"""
Model validation metrics
"""
import numpy as np

@ -4,10 +4,8 @@ from pathlib import Path
import hydra
import torch
import torch.hub as hub
import torchvision
import torchvision.transforms as T
from omegaconf import DictConfig, OmegaConf
from val import ClassificationValidator
from ultralytics.yolo import BaseTrainer, utils, v8
from ultralytics.yolo.data import build_classification_dataloader
@ -15,7 +13,7 @@ from ultralytics.yolo.engine.trainer import CONFIG_PATH_ABS, DEFAULT_CONFIG
# BaseTrainer python usage
class Trainer(BaseTrainer):
class ClassificationTrainer(BaseTrainer):
def get_dataset(self):
# temporary solution. Replace with new ultralytics.yolo.ClassificationDataset module
@ -55,13 +53,18 @@ class Trainer(BaseTrainer):
return model
def get_validator(self):
return ClassificationValidator(self.test_loader, self.device, logger=self.console) # validator
def criterion(self, preds, targets):
return torch.nn.functional.cross_entropy(preds, targets)
@hydra.main(version_base=None, config_path=CONFIG_PATH_ABS, config_name=str(DEFAULT_CONFIG).split(".")[0])
def train(cfg):
model = "squeezenet1_0"
dataset = "imagenette160" # or yolo.ClassificationDataset("mnist")
criterion = torch.nn.CrossEntropyLoss() # yolo.Loss object
trainer = Trainer(model, dataset, criterion, config=cfg)
cfg.model = cfg.model or "squeezenet1_0"
cfg.data = cfg.data or "imagenette160" # or yolo.ClassificationDataset("mnist")
trainer = ClassificationTrainer(cfg)
trainer.run()

@ -0,0 +1,18 @@
import torch
from ultralytics import yolo
class ClassificationValidator(yolo.BaseValidator):
def init_metrics(self):
self.correct = torch.tensor([])
def update_metrics(self, preds, targets):
correct_in_batch = (targets[:, None] == preds).float()
self.correct = torch.cat((self.correct, correct_in_batch))
def get_stats(self):
acc = torch.stack((self.correct[:, 0], self.correct.max(1).values), dim=1) # (top1, top5) accuracy
top1, top5 = acc.mean(0).tolist()
return {"top1": top1, "top5": top5, "fitness": top5}
Loading…
Cancel
Save