Add EMA and model checkpointing (#49)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
Ayush Chaurasia, committed by GitHub
parent 27d6545117
commit 4291b9c31c

@@ -9,6 +9,7 @@ Simple training loop; Boilerplate that could apply to any arbitrary neural network
 import os
 import time
 from collections import defaultdict
+from copy import deepcopy
 from datetime import datetime
 from pathlib import Path
 from typing import Dict, Union
@@ -29,6 +30,7 @@ from ultralytics.yolo.utils import LOGGER, ROOT, TQDM_BAR_FORMAT
 from ultralytics.yolo.utils.checks import print_args
 from ultralytics.yolo.utils.files import increment_path, save_yaml
 from ultralytics.yolo.utils.modeling import get_model
+from ultralytics.yolo.utils.torch_utils import ModelEMA, de_parallel

 DEFAULT_CONFIG = ROOT / "yolo/utils/configs/default.yaml"
@@ -63,6 +65,7 @@ class BaseTrainer:
         self.trainset, self.testset = self.get_dataset(self.data)
         if self.args.model:
             self.model = self.get_model(self.args.model)
+        self.ema = None

         # epoch level metrics
         self.metrics = {}  # handle metrics returned by validator
@@ -144,6 +147,7 @@ class BaseTrainer:
             self.validator = self.get_validator()
             print("created testloader :", rank)
             self.console.info(self.progress_string())
+            self.ema = ModelEMA(self.model)

     def _do_train(self, rank=-1, world_size=1):
         if world_size > 1:
@@ -196,6 +200,7 @@ class BaseTrainer:
             if rank in [-1, 0]:
                 # validation
                 # callback: on_val_start()
+                self.ema.update_attr(self.model, include=['yaml', 'nc', 'args', 'names', 'stride', 'class_weights'])
                 self.validate()
                 # callback: on_val_end()
@@ -220,10 +225,10 @@ class BaseTrainer:
         ckpt = {
             'epoch': self.epoch,
             'best_fitness': self.best_fitness,
-            'model': None,  # deepcopy(ema.ema).half(), # deepcopy(de_parallel(model)).half(),
-            'ema': None,  # deepcopy(ema.ema).half(),
-            'updates': None,  # ema.updates,
-            'optimizer': None,  # optimizer.state_dict(),
+            'model': deepcopy(de_parallel(self.model)).half(),
+            'ema': deepcopy(self.ema.ema).half(),
+            'updates': self.ema.updates,
+            'optimizer': self.optimizer.state_dict(),
             'train_args': self.args,
             'date': datetime.now().isoformat()}
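
For context, a checkpoint dict of this shape is normally written with torch.save and read back with torch.load. The minimal sketch below is illustrative only; the 'last.pt' filename and CPU map_location are assumptions, not code from this commit:

import torch

# Illustrative only: persist the checkpoint dict built above, then restore it for inference.
torch.save(ckpt, 'last.pt')                       # 'last.pt' is a placeholder path

ckpt = torch.load('last.pt', map_location='cpu')
model = ckpt.get('ema') or ckpt['model']          # prefer the EMA weights when they were saved
model = model.float().eval()                      # weights were stored in half precision
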
@@ -266,6 +271,8 @@ class BaseTrainer:
         self.scaler.step(self.optimizer)
         self.scaler.update()
         self.optimizer.zero_grad()
+        if self.ema:
+            self.ema.update(self.model)

     def preprocess_batch(self, batch):
         """

@@ -30,19 +30,16 @@ class BaseValidator:
         Supports validation of a pre-trained model if passed or a model being trained
         if trainer is passed (trainer gets priority).
         """
-        training = trainer is not None
-        self.training = training
-        # trainer = trainer or self.trainer_class.get_trainer()
-        assert training or model is not None, "Either trainer or model is needed for validation"
-        if training:
-            model = trainer.model
+        self.training = trainer is not None
+        if self.training:
+            model = trainer.ema.ema or trainer.model
             self.args.half &= self.device.type != 'cpu'
             # NOTE: half() inference in evaluation will make training stuck,
             # so I comment it out for now, I think we can reuse half mode after we add EMA.
-            # model = model.half() if self.args.half else model
+            model = model.half() if self.args.half else model.float()
         else:  # TODO: handle this when detectMultiBackend is supported
+            assert model is not None, "Either trainer or model is needed for validation"
             # model = DetectMultiBacked(model)
-            pass

         # TODO: implement init_model_attributes()
         model.eval()
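
Pulled out of the interleaved diff for readability, the new model-selection logic is roughly equivalent to the standalone sketch below; the function name and signature are invented for illustration, not the validator's actual API:

import torch

def select_eval_model(trainer=None, model=None, half=False, device=torch.device('cpu')):
    # Illustrative rewrite of the validator logic above.
    training = trainer is not None
    if training:
        model = trainer.ema.ema or trainer.model        # validate the EMA copy, not the live weights
        half &= device.type != 'cpu'                    # half precision only makes sense off-CPU
        model = model.half() if half else model.float()
    else:
        assert model is not None, "Either trainer or model is needed for validation"
    return model.eval()
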
@@ -50,7 +47,7 @@ class BaseValidator:
         loss = 0
         n_batches = len(self.dataloader)
         desc = self.get_desc()
-        bar = tqdm(self.dataloader, desc, n_batches, not training, bar_format=TQDM_BAR_FORMAT)
+        bar = tqdm(self.dataloader, desc, n_batches, not self.training, bar_format=TQDM_BAR_FORMAT)
         self.init_metrics(de_parallel(model))
         with torch.no_grad():
             for batch_i, batch in enumerate(bar):
@@ -67,7 +64,7 @@ class BaseValidator:
                 # loss
                 with dt[2]:
-                    if training:
+                    if self.training:
                         loss += trainer.criterion(preds, batch)[0]

                 # pre-process predictions
@@ -82,7 +79,7 @@ class BaseValidator:
         self.print_results()

         # print speeds
-        if not training:
+        if not self.training:
             t = tuple(x.t / len(self.dataloader.dataset.samples) * 1E3 for x in dt)  # speeds per image
             # shape = (self.dataloader.batch_size, 3, imgsz, imgsz)
             self.logger.info(

@@ -232,4 +232,4 @@ class ClassificationModel(BaseModel):
         elif nn.Conv2d in types:
             i = types.index(nn.Conv2d)  # nn.Conv2d index
             if m[i].out_channels != nc:
-                m[i] = nn.Conv2d(m[i].in_channels, nc, m[i].kernel_size, m[i].stride, bias=m[i].bias)
+                m[i] = nn.Conv2d(m[i].in_channels, nc, m[i].kernel_size, m[i].stride, bias=m[i].bias is not None)
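
The bias change above matters because the nn.Conv2d constructor expects a bool flag, not the bias tensor itself; passing the old layer's bias Parameter can raise an ambiguous-truth-value error instead of recreating the bias. A small illustration with arbitrary sizes (not code from this commit):

import torch.nn as nn

old = nn.Conv2d(16, 1000, kernel_size=1)   # arbitrary existing classification head
nc = 10                                    # new class count

# Correct: pass a bool describing whether the old layer actually had a bias.
new = nn.Conv2d(old.in_channels, nc, old.kernel_size, old.stride, bias=old.bias is not None)
# Pre-fix form passed bias=old.bias, i.e. a Parameter where a bool is expected.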

@@ -192,3 +192,34 @@ def is_parallel(model):
 def de_parallel(model):
     # De-parallelize a model: returns single-GPU model if model is of type DP or DDP
     return model.module if is_parallel(model) else model
+
+
+class ModelEMA:
+    """ Updated Exponential Moving Average (EMA) from https://github.com/rwightman/pytorch-image-models
+    Keeps a moving average of everything in the model state_dict (parameters and buffers)
+    For EMA details see https://www.tensorflow.org/api_docs/python/tf/train/ExponentialMovingAverage
+    """
+
+    def __init__(self, model, decay=0.9999, tau=2000, updates=0):
+        # Create EMA
+        self.ema = deepcopy(de_parallel(model)).eval()  # FP32 EMA
+        self.updates = updates  # number of EMA updates
+        self.decay = lambda x: decay * (1 - math.exp(-x / tau))  # decay exponential ramp (to help early epochs)
+        for p in self.ema.parameters():
+            p.requires_grad_(False)
+
+    def update(self, model):
+        # Update EMA parameters
+        self.updates += 1
+        d = self.decay(self.updates)
+
+        msd = de_parallel(model).state_dict()  # model state_dict
+        for k, v in self.ema.state_dict().items():
+            if v.dtype.is_floating_point:  # true for FP16 and FP32
+                v *= d
+                v += (1 - d) * msd[k].detach()
+        # assert v.dtype == msd[k].dtype == torch.float32, f'{k}: EMA {v.dtype} and model {msd[k].dtype} must be FP32'
+
+    def update_attr(self, model, include=(), exclude=('process_group', 'reducer')):
+        # Update EMA attributes
+        copy_attr(self.ema, model, include, exclude)
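
Two notes on the class above. First, the decay lambda ramps from roughly 0 at the first update to about 0.9999 * (1 - e^-1) ≈ 0.63 after 2000 updates, approaching 0.9999 asymptotically, so early EMA weights track the raw weights closely while late updates barely move them. Second, update_attr relies on a copy_attr helper that is not shown in this diff; in YOLOv5-style torch_utils it simply copies plain (non-underscore) attributes from the source model onto the EMA copy, roughly as sketched below (a paraphrase, not this file's actual contents):

def copy_attr(a, b, include=(), exclude=()):
    # Copy attributes from b to a, optionally restricted to 'include' and filtered by 'exclude'.
    for k, v in b.__dict__.items():
        if (len(include) and k not in include) or k.startswith('_') or k in exclude:
            continue
        setattr(a, k, v)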

@@ -159,11 +159,11 @@ class SegmentationTrainer(BaseTrainer):
             return tcls, tbox, indices, anch, tidxs, xywhn

-        if self.model.training:
+        if len(preds) == 2:  # eval
             p, proto, = preds
-        else:
-            p, proto, train_out = preds
-            p = train_out
+        else:  # len(3) train
+            _, proto, p = preds
         targets = torch.cat((batch["batch_idx"].view(-1, 1), batch["cls"].view(-1, 1), batch["bboxes"]), 1)
         masks = batch["masks"]
         targets, masks = targets.to(self.device), masks.to(self.device).float()
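
The new branch keys off the length of the prediction tuple rather than self.model.training, presumably so the loss also unpacks correctly when predictions come from the eval-mode EMA copy during validation. A toy illustration of the two layouts handled above (placeholder tuple contents, not real head outputs):

# Placeholder tuples only: eval-mode heads return 2 items, train-mode heads return 3.
eval_preds = ('p', 'proto')
train_preds = ('extra', 'proto', 'p')

for preds in (eval_preds, train_preds):
    if len(preds) == 2:    # eval
        p, proto = preds
    else:                  # train: the per-scale predictions come last
        _, proto, p = preds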

@@ -1,5 +1,4 @@
 import os
-from pathlib import Path

 import numpy as np
 import torch
