Add EMA and model checkpointing (#49)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
Branch: single_channel
Authored by Ayush Chaurasia, committed via GitHub
parent 27d6545117
commit 4291b9c31c

@@ -9,6 +9,7 @@ Simple training loop; Boilerplate that could apply to any arbitrary neural network
import os
import time
from collections import defaultdict
from copy import deepcopy
from datetime import datetime
from pathlib import Path
from typing import Dict, Union
@@ -29,6 +30,7 @@ from ultralytics.yolo.utils import LOGGER, ROOT, TQDM_BAR_FORMAT
from ultralytics.yolo.utils.checks import print_args
from ultralytics.yolo.utils.files import increment_path, save_yaml
from ultralytics.yolo.utils.modeling import get_model
from ultralytics.yolo.utils.torch_utils import ModelEMA, de_parallel
DEFAULT_CONFIG = ROOT / "yolo/utils/configs/default.yaml"
@@ -63,6 +65,7 @@ class BaseTrainer:
self.trainset, self.testset = self.get_dataset(self.data)
if self.args.model:
self.model = self.get_model(self.args.model)
self.ema = None
# epoch level metrics
self.metrics = {} # handle metrics returned by validator
@@ -144,6 +147,7 @@
self.validator = self.get_validator()
print("created testloader :", rank)
self.console.info(self.progress_string())
self.ema = ModelEMA(self.model)
def _do_train(self, rank=-1, world_size=1):
if world_size > 1:
@@ -196,6 +200,7 @@
if rank in [-1, 0]:
# validation
# callback: on_val_start()
self.ema.update_attr(self.model, include=['yaml', 'nc', 'args', 'names', 'stride', 'class_weights'])
self.validate()
# callback: on_val_end()
@@ -220,10 +225,10 @@
ckpt = {
'epoch': self.epoch,
'best_fitness': self.best_fitness,
'model': None, # deepcopy(ema.ema).half(), # deepcopy(de_parallel(model)).half(),
'ema': None, # deepcopy(ema.ema).half(),
'updates': None, # ema.updates,
'optimizer': None, # optimizer.state_dict(),
'model': deepcopy(de_parallel(self.model)).half(),
'ema': deepcopy(self.ema.ema).half(),
'updates': self.ema.updates,
'optimizer': self.optimizer.state_dict(),
'train_args': self.args,
'date': datetime.now().isoformat()}
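This is where checkpointing becomes real: the previously stubbed-out `None` entries are replaced by FP16 copies of the model and its EMA plus the optimizer state. A rough sketch of how such a checkpoint could be consumed (the `last.pt` path is illustrative, not part of this diff):

```python
import torch

# Saving (e.g. once per epoch); 'last.pt' is an illustrative path.
torch.save(ckpt, 'last.pt')

# Loading for inference or fine-tuning: prefer the EMA weights when present and
# cast back to FP32, since 'model' and 'ema' were stored as half-precision copies.
ckpt = torch.load('last.pt', map_location='cpu')
model = (ckpt.get('ema') or ckpt['model']).float().eval()

# Resuming training state (optimizer created beforehand):
# optimizer.load_state_dict(ckpt['optimizer'])
# start_epoch = ckpt['epoch'] + 1
```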
@@ -266,6 +271,8 @@
self.scaler.step(self.optimizer)
self.scaler.update()
self.optimizer.zero_grad()
if self.ema:
self.ema.update(self.model)
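The EMA update runs immediately after the optimizer step, so the averaged copy always tracks the just-updated weights. How strongly it averages is set by the decay ramp defined in the `ModelEMA` class further down; a quick worked example with the defaults used there (`decay=0.9999`, `tau=2000`):

```python
import math

decay, tau = 0.9999, 2000
d = lambda updates: decay * (1 - math.exp(-updates / tau))

print(round(d(100), 4))    # 0.0488 -> early updates mostly copy the live weights into the EMA
print(round(d(2000), 4))   # 0.6321 -> ramping up
print(round(d(20000), 4))  # 0.9999 -> long-run behaviour approaches the nominal decay
```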
def preprocess_batch(self, batch):
"""

@@ -30,19 +30,16 @@ class BaseValidator:
Supports validation of a pre-trained model if passed or a model being trained
if trainer is passed (trainer gets priority).
"""
training = trainer is not None
self.training = training
# trainer = trainer or self.trainer_class.get_trainer()
assert training or model is not None, "Either trainer or model is needed for validation"
if training:
model = trainer.model
self.training = trainer is not None
if self.training:
model = trainer.ema.ema or trainer.model
self.args.half &= self.device.type != 'cpu'
# NOTE: half() inference in evaluation will make training stuck,
# so I comment it out for now, I think we can reuse half mode after we add EMA.
# model = model.half() if self.args.half else model
model = model.half() if self.args.half else model.float()
else: # TODO: handle this when detectMultiBackend is supported
assert model is not None, "Either trainer or model is needed for validation"
# model = DetectMultiBacked(model)
pass
# TODO: implement init_model_attributes()
model.eval()
@@ -50,7 +47,7 @@ class BaseValidator:
loss = 0
n_batches = len(self.dataloader)
desc = self.get_desc()
bar = tqdm(self.dataloader, desc, n_batches, not training, bar_format=TQDM_BAR_FORMAT)
bar = tqdm(self.dataloader, desc, n_batches, not self.training, bar_format=TQDM_BAR_FORMAT)
self.init_metrics(de_parallel(model))
with torch.no_grad():
for batch_i, batch in enumerate(bar):
@@ -67,7 +64,7 @@
# loss
with dt[2]:
if training:
if self.training:
loss += trainer.criterion(preds, batch)[0]
# pre-process predictions
@@ -82,7 +79,7 @@
self.print_results()
# print speeds
if not training:
if not self.training:
t = tuple(x.t / len(self.dataloader.dataset.samples) * 1E3 for x in dt) # speeds per image
# shape = (self.dataloader.batch_size, 3, imgsz, imgsz)
self.logger.info(

@@ -232,4 +232,4 @@ class ClassificationModel(BaseModel):
elif nn.Conv2d in types:
i = types.index(nn.Conv2d) # nn.Conv2d index
if m[i].out_channels != nc:
m[i] = nn.Conv2d(m[i].in_channels, nc, m[i].kernel_size, m[i].stride, bias=m[i].bias)
m[i] = nn.Conv2d(m[i].in_channels, nc, m[i].kernel_size, m[i].stride, bias=m[i].bias is not None)
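The last line is a genuine bug fix: `nn.Conv2d` takes `bias` as a bool, so the old code passed the previous layer's bias *tensor* where a flag was expected, which typically fails with a tensor-truthiness error on recent PyTorch when a bias is present. A small sketch of rebuilding a classification head for a new class count, with illustrative channel sizes:

```python
import torch.nn as nn

old = nn.Conv2d(1280, 1000, kernel_size=1)  # e.g. an ImageNet-style head (illustrative sizes)
nc = 10                                     # new number of classes

# Pass whether a bias exists, not the bias tensor itself:
new = nn.Conv2d(old.in_channels, nc, old.kernel_size, old.stride, bias=old.bias is not None)
```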

@@ -192,3 +192,34 @@ def is_parallel(model):
def de_parallel(model):
# De-parallelize a model: returns single-GPU model if model is of type DP or DDP
return model.module if is_parallel(model) else model
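`is_parallel` itself is named in the hunk header but its body is not shown in this diff; it presumably follows the usual YOLOv5-style check (sketched here as an assumption, not quoted from the change):

```python
import torch.nn as nn

def is_parallel(model):
    # True if the model is wrapped in DataParallel or DistributedDataParallel
    return isinstance(model, (nn.parallel.DataParallel, nn.parallel.DistributedDataParallel))
```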
class ModelEMA:
""" Updated Exponential Moving Average (EMA) from https://github.com/rwightman/pytorch-image-models
Keeps a moving average of everything in the model state_dict (parameters and buffers)
For EMA details see https://www.tensorflow.org/api_docs/python/tf/train/ExponentialMovingAverage
"""
def __init__(self, model, decay=0.9999, tau=2000, updates=0):
# Create EMA
self.ema = deepcopy(de_parallel(model)).eval() # FP32 EMA
self.updates = updates # number of EMA updates
self.decay = lambda x: decay * (1 - math.exp(-x / tau)) # decay exponential ramp (to help early epochs)
for p in self.ema.parameters():
p.requires_grad_(False)
def update(self, model):
# Update EMA parameters
self.updates += 1
d = self.decay(self.updates)
msd = de_parallel(model).state_dict() # model state_dict
for k, v in self.ema.state_dict().items():
if v.dtype.is_floating_point: # true for FP16 and FP32
v *= d
v += (1 - d) * msd[k].detach()
# assert v.dtype == msd[k].dtype == torch.float32, f'{k}: EMA {v.dtype} and model {msd[k].dtype} must be FP32'
def update_attr(self, model, include=(), exclude=('process_group', 'reducer')):
# Update EMA attributes
copy_attr(self.ema, model, include, exclude)
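`update_attr` leans on a `copy_attr` helper that is not part of this diff; assuming it matches the YOLOv5-style utility of the same name, it copies plain (non-private) attributes such as `nc`, `names` and `stride` from the live model onto the EMA copy, which is why the trainer calls `update_attr` right before validation:

```python
def copy_attr(a, b, include=(), exclude=()):
    # Copy attributes from b to a, optionally restricted to `include` and filtered by `exclude`
    # (sketch of the assumed helper; private attributes are skipped).
    for k, v in b.__dict__.items():
        if (len(include) and k not in include) or k.startswith('_') or k in exclude:
            continue
        setattr(a, k, v)
```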

@@ -159,11 +159,11 @@ class SegmentationTrainer(BaseTrainer):
return tcls, tbox, indices, anch, tidxs, xywhn
if self.model.training:
if len(preds) == 2: # eval
p, proto, = preds
else:
p, proto, train_out = preds
p = train_out
else: # len(3) train
_, proto, p = preds
targets = torch.cat((batch["batch_idx"].view(-1, 1), batch["cls"].view(-1, 1), batch["bboxes"]), 1)
masks = batch["masks"]
targets, masks = targets.to(self.device), masks.to(self.device).float()
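For reference, the concatenation above yields one row per labelled object in the form `[batch_idx, cls, x, y, w, h]` (boxes presumably in normalised xywh, the YOLO convention), which is the layout the target builders expect. A tiny illustration with dummy data:

```python
import torch

batch = {
    'batch_idx': torch.tensor([0., 0., 1.]),  # image index of each object within the batch
    'cls': torch.tensor([3., 5., 1.]),        # class ids
    'bboxes': torch.rand(3, 4),               # normalised xywh boxes (dummy values)
}
targets = torch.cat((batch['batch_idx'].view(-1, 1), batch['cls'].view(-1, 1), batch['bboxes']), 1)
print(targets.shape)  # torch.Size([3, 6]) -> [batch_idx, cls, x, y, w, h] per object
```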

@@ -1,5 +1,4 @@
import os
from pathlib import Path
import numpy as np
import torch
