Metrics and loss structure (#28)

Co-authored-by: Ayush Chaurasia <ayush.chuararsia@gmail.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
single_channel
Ayush Chaurasia 2 years ago committed by GitHub
parent d0b3c9812b
commit c5cb76b356
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -10,8 +10,8 @@ pip install . # (dev)
# pip install ultralytics (production) # pip install ultralytics (production)
``` ```
### Usage ### Usage
```python ```python
import ultralytics import ultralytics
from ultralytics import HUB, YOLO from ultralytics import HUB, YOLO

@ -1,3 +1,4 @@
from .engine.trainer import BaseTrainer from .engine.trainer import BaseTrainer
from .engine.validator import BaseValidator
__all__ = ["BaseTrainer"] # allow simpler import __all__ = ["BaseTrainer", "BaseValidator"] # allow simpler import

@ -2,13 +2,10 @@ from itertools import repeat
from multiprocessing.pool import Pool from multiprocessing.pool import Pool
from pathlib import Path from pathlib import Path
import cv2
import numpy as np
import torch
import torchvision import torchvision
from tqdm import tqdm from tqdm import tqdm
from ..utils.general import LOGGER, NUM_THREADS from ..utils.general import NUM_THREADS
from .augment import * from .augment import *
from .base import BaseDataset from .base import BaseDataset
from .utils import BAR_FORMAT, HELP_URL, LOCAL_RANK, get_hash, img2label_paths, verify_image_label from .utils import BAR_FORMAT, HELP_URL, LOCAL_RANK, get_hash, img2label_paths, verify_image_label

@ -28,20 +28,11 @@ DEFAULT_CONFIG = "defaults.yaml"
class BaseTrainer: class BaseTrainer:
def __init__( def __init__(self, config=CONFIG_PATH_ABS / DEFAULT_CONFIG):
self,
model: str,
data: str,
criterion, # Should we create our own base loss classes? yolo.losses -> v8.losses.clfLoss
validator=None,
config=CONFIG_PATH_ABS / DEFAULT_CONFIG):
self.console = LOGGER self.console = LOGGER
self.model = model self.model, self.data, self.train, self.hyps = self._get_config(config)
self.data = data self.validator = None
self.criterion = criterion # ComputeLoss object TODO: create yolo.Loss classes
self.validator = val # Dummy validator
self.callbacks = defaultdict(list) self.callbacks = defaultdict(list)
self.train, self.hyps = self._get_config(config)
self.console.info(f"Training config: \n train: \n {self.train} \n hyps: \n {self.hyps}") # to debug self.console.info(f"Training config: \n train: \n {self.train} \n hyps: \n {self.hyps}") # to debug
# Directories # Directories
self.save_dir = utils.increment_path(Path(self.train.project) / self.train.name, exist_ok=self.train.exist_ok) self.save_dir = utils.increment_path(Path(self.train.project) / self.train.name, exist_ok=self.train.exist_ok)
@ -57,7 +48,7 @@ class BaseTrainer:
self.console.info(f"running on device {self.device}") self.console.info(f"running on device {self.device}")
self.scaler = amp.GradScaler(enabled=self.device.type != 'cpu') self.scaler = amp.GradScaler(enabled=self.device.type != 'cpu')
# Model and Dataloaders. TBD: Should we move this inside trainer? # Model and Dataloaders.
self.trainset, self.testset = self.get_dataset() # initialize dataset before as nc is needed for model self.trainset, self.testset = self.get_dataset() # initialize dataset before as nc is needed for model
self.model = self.get_model() self.model = self.get_model()
self.model = self.model.to(self.device) self.model = self.model.to(self.device)
@ -80,9 +71,9 @@ class BaseTrainer:
try: try:
if isinstance(config, (str, Path)): if isinstance(config, (str, Path)):
config = OmegaConf.load(config) config = OmegaConf.load(config)
return config.train, config.hyps return config.model, config.data, config.train, config.hyps
except KeyError as e: except KeyError as e:
raise Exception("Missing key(s) in config") from e raise KeyError("Missing key(s) in config") from e
def add_callback(self, onevent: str, callback): def add_callback(self, onevent: str, callback):
""" """
@ -131,10 +122,9 @@ class BaseTrainer:
self.train_loader = self.get_dataloader(self.trainset, batch_size=self.train.batch_size, rank=rank) self.train_loader = self.get_dataloader(self.trainset, batch_size=self.train.batch_size, rank=rank)
if rank in {0, -1}: if rank in {0, -1}:
print(" Creating testloader rank :", rank) print(" Creating testloader rank :", rank)
# self.test_loader = self.get_dataloader(self.testset, self.test_loader = self.get_dataloader(self.testset, batch_size=self.train.batch_size * 2, rank=rank)
# batch_size=self.train.batch_size*2, self.validator = self.get_validator()
# rank=rank) print("created testloader :", rank)
# print("created testloader :", rank)
def _do_train(self, rank, world_size): def _do_train(self, rank, world_size):
if world_size > 1: if world_size > 1:
@ -235,11 +225,8 @@ class BaseTrainer:
""" """
pass pass
def set_criterion(self, criterion): def get_validator(self):
""" pass
:param criterion: yolo.Loss object.
"""
self.criterion = criterion
def optimizer_step(self): def optimizer_step(self):
self.scaler.unscale_(self.optimizer) # unscale gradients self.scaler.unscale_(self.optimizer) # unscale gradients
@ -265,6 +252,12 @@ class BaseTrainer:
if not self.best_fitness or self.best_fitness < self.fitness: if not self.best_fitness or self.best_fitness < self.fitness:
self.best_fitness = self.fitness self.best_fitness = self.fitness
def build_targets(self, preds, targets):
pass
def criterion(self, preds, targets):
pass
def progress_string(self): def progress_string(self):
""" """
Returns progress string depending on task type. Returns progress string depending on task type.

@ -0,0 +1,105 @@
import logging
import torch
from tqdm import tqdm
from ultralytics.yolo.utils import Profile, select_device
class BaseValidator:
"""
Base validator class.
"""
def __init__(self, dataloader, device='', half=False, pbar=None, logger=None):
self.dataloader = dataloader
self.half = half
self.device = select_device(device, dataloader.batch_size)
self.pbar = pbar
self.logger = logger or logging.getLogger()
def __call__(self, trainer=None, model=None):
"""
Supports validation of a pre-trained model if passed or a model being trained
if trainer is passed (trainer gets priority).
"""
training = trainer is not None
# trainer = trainer or self.trainer_class.get_trainer()
assert training or model is not None, "Either trainer or model is needed for validation"
if training:
model = trainer.model
self.half &= self.device.type != 'cpu'
model = model.half() if self.half else model
else: # TODO: handle this when detectMultiBackend is supported
# model = DetectMultiBacked(model)
pass
model.eval()
dt = Profile(), Profile(), Profile(), Profile()
loss = 0
n_batches = len(self.dataloader)
desc = self.set_desc()
bar = tqdm(self.dataloader, desc, n_batches, not training, bar_format='{l_bar}{bar:10}{r_bar}{bar:-10b}')
self.init_metrics()
with torch.cuda.amp.autocast(enabled=self.device.type != 'cpu'):
for images, labels in bar:
# pre-process
with dt[0]:
images, labels = self.preprocess_batch(images, labels)
# inference
with dt[1]:
preds = model(images)
# TODO: remember to add native augmentation support when implementing model, like:
# preds, train_out = model(im, augment=augment)
# loss
with dt[2]:
if training:
loss += trainer.criterion(preds, labels) / images.shape[0]
# pre-process predictions
with dt[3]:
preds = self.preprocess_preds(preds)
self.update_metrics(preds, labels)
stats = self.get_stats()
self.check_stats(stats)
self.print_results()
# print speeds
if not training:
t = tuple(x.t / len(self.dataloader.dataset.samples) * 1E3 for x in dt) # speeds per image
# shape = (self.dataloader.batch_size, 3, imgsz, imgsz)
self.logger.info(
'Speed: %.1fms pre-process, %.1fms inference, %.1fms loss, %.1fms post-process per image at shape ' % t)
# TODO: implement save json
return stats
def preprocess_batch(self, images, labels):
return images.to(self.device, non_blocking=True), labels.to(self.device)
def preprocess_preds(self, preds):
return preds
def init_metrics(self):
pass
def update_metrics(self, preds, targets):
pass
def get_stats(self):
pass
def check_stats(self, stats):
pass
def print_results(self):
pass
def set_desc(self):
pass

@ -1,4 +1,4 @@
from .general import WorkingDirectory, check_version, download, increment_path, save_yaml from .general import Profile, WorkingDirectory, check_version, download, increment_path, save_yaml
from .torch_utils import LOCAL_RANK, RANK, WORLD_SIZE, DDP_model, select_device, torch_distributed_zero_first from .torch_utils import LOCAL_RANK, RANK, WORLD_SIZE, DDP_model, select_device, torch_distributed_zero_first
__all__ = [ __all__ = [
@ -8,6 +8,7 @@ __all__ = [
"WorkingDirectory", "WorkingDirectory",
"download", "download",
"check_version", "check_version",
"Profile",
# torch # torch
"torch_distributed_zero_first", "torch_distributed_zero_first",
"LOCAL_RANK", "LOCAL_RANK",

@ -1,3 +1,5 @@
model: null
data: null
train: train:
epochs: 300 epochs: 300
batch_size: 16 batch_size: 16

@ -5,6 +5,7 @@ import logging
import os import os
import platform import platform
import subprocess import subprocess
import time
import urllib import urllib
from itertools import repeat from itertools import repeat
from multiprocessing.pool import ThreadPool from multiprocessing.pool import ThreadPool
@ -208,7 +209,7 @@ def increment_path(path, exist_ok=False, sep='', mkdir=False):
return path return path
def save_yaml(file='data.yaml', data={}): def save_yaml(file='data.yaml', data=None):
# Single-line safe yaml saving # Single-line safe yaml saving
with open(file, 'w') as f: with open(file, 'w') as f:
yaml.safe_dump({k: str(v) if isinstance(v, Path) else v for k, v in data.items()}, f, sort_keys=False) yaml.safe_dump({k: str(v) if isinstance(v, Path) else v for k, v in data.items()}, f, sort_keys=False)
@ -278,7 +279,6 @@ class WorkingDirectory(contextlib.ContextDecorator):
def safe_download(file, url, url2=None, min_bytes=1E0, error_msg=''): def safe_download(file, url, url2=None, min_bytes=1E0, error_msg=''):
# Attempts to download file from url or url2, checks and removes incomplete downloads < min_bytes # Attempts to download file from url or url2, checks and removes incomplete downloads < min_bytes
from utils.general import LOGGER
file = Path(file) file = Path(file)
assert_msg = f"Downloaded file '{file}' does not exist or size is < min_bytes={min_bytes}" assert_msg = f"Downloaded file '{file}' does not exist or size is < min_bytes={min_bytes}"
@ -301,7 +301,6 @@ def safe_download(file, url, url2=None, min_bytes=1E0, error_msg=''):
def attempt_download(file, repo='ultralytics/yolov5', release='v6.2'): def attempt_download(file, repo='ultralytics/yolov5', release='v6.2'):
# Attempt file download from GitHub release assets if not found locally. release = 'latest', 'v6.2', etc. # Attempt file download from GitHub release assets if not found locally. release = 'latest', 'v6.2', etc.
from utils.general import LOGGER
def github_assets(repository, version='latest'): def github_assets(repository, version='latest'):
# Return GitHub repo tag and assets (i.e. ['yolov5s.pt', 'yolov5m.pt', ...]) # Return GitHub repo tag and assets (i.e. ['yolov5s.pt', 'yolov5m.pt', ...])
@ -351,3 +350,23 @@ def attempt_download(file, repo='ultralytics/yolov5', release='v6.2'):
def get_model(model: str): def get_model(model: str):
# check for local weights # check for local weights
pass pass
class Profile(contextlib.ContextDecorator):
# YOLOv5 Profile class. Usage: @Profile() decorator or 'with Profile():' context manager
def __init__(self, t=0.0):
self.t = t
self.cuda = torch.cuda.is_available()
def __enter__(self):
self.start = self.time()
return self
def __exit__(self, type, value, traceback):
self.dt = self.time() - self.start # delta-time
self.t += self.dt # accumulate dt
def time(self):
if self.cuda:
torch.cuda.synchronize()
return time.time()

@ -2,6 +2,7 @@
""" """
Model validation metrics Model validation metrics
""" """
import numpy as np import numpy as np

@ -4,10 +4,8 @@ from pathlib import Path
import hydra import hydra
import torch import torch
import torch.hub as hub
import torchvision import torchvision
import torchvision.transforms as T from val import ClassificationValidator
from omegaconf import DictConfig, OmegaConf
from ultralytics.yolo import BaseTrainer, utils, v8 from ultralytics.yolo import BaseTrainer, utils, v8
from ultralytics.yolo.data import build_classification_dataloader from ultralytics.yolo.data import build_classification_dataloader
@ -15,7 +13,7 @@ from ultralytics.yolo.engine.trainer import CONFIG_PATH_ABS, DEFAULT_CONFIG
# BaseTrainer python usage # BaseTrainer python usage
class Trainer(BaseTrainer): class ClassificationTrainer(BaseTrainer):
def get_dataset(self): def get_dataset(self):
# temporary solution. Replace with new ultralytics.yolo.ClassificationDataset module # temporary solution. Replace with new ultralytics.yolo.ClassificationDataset module
@ -55,13 +53,18 @@ class Trainer(BaseTrainer):
return model return model
def get_validator(self):
return ClassificationValidator(self.test_loader, self.device, logger=self.console) # validator
def criterion(self, preds, targets):
return torch.nn.functional.cross_entropy(preds, targets)
@hydra.main(version_base=None, config_path=CONFIG_PATH_ABS, config_name=str(DEFAULT_CONFIG).split(".")[0]) @hydra.main(version_base=None, config_path=CONFIG_PATH_ABS, config_name=str(DEFAULT_CONFIG).split(".")[0])
def train(cfg): def train(cfg):
model = "squeezenet1_0" cfg.model = cfg.model or "squeezenet1_0"
dataset = "imagenette160" # or yolo.ClassificationDataset("mnist") cfg.data = cfg.data or "imagenette160" # or yolo.ClassificationDataset("mnist")
criterion = torch.nn.CrossEntropyLoss() # yolo.Loss object trainer = ClassificationTrainer(cfg)
trainer = Trainer(model, dataset, criterion, config=cfg)
trainer.run() trainer.run()

@ -0,0 +1,18 @@
import torch
from ultralytics import yolo
class ClassificationValidator(yolo.BaseValidator):
def init_metrics(self):
self.correct = torch.tensor([])
def update_metrics(self, preds, targets):
correct_in_batch = (targets[:, None] == preds).float()
self.correct = torch.cat((self.correct, correct_in_batch))
def get_stats(self):
acc = torch.stack((self.correct[:, 0], self.correct.max(1).values), dim=1) # (top1, top5) accuracy
top1, top5 = acc.mean(0).tolist()
return {"top1": top1, "top5": top5, "fitness": top5}
Loading…
Cancel
Save