From 4291b9c31ca7a3789d052ede91152b7e1e953332 Mon Sep 17 00:00:00 2001 From: Ayush Chaurasia Date: Sat, 19 Nov 2022 23:37:26 +0530 Subject: [PATCH] Add EMA and model checkpointing (#49) Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Glenn Jocher --- ultralytics/yolo/engine/trainer.py | 15 +++++++++--- ultralytics/yolo/engine/validator.py | 19 ++++++--------- ultralytics/yolo/utils/modeling/tasks.py | 2 +- ultralytics/yolo/utils/torch_utils.py | 31 ++++++++++++++++++++++++ ultralytics/yolo/v8/segment/train.py | 8 +++--- ultralytics/yolo/v8/segment/val.py | 1 - 6 files changed, 55 insertions(+), 21 deletions(-) diff --git a/ultralytics/yolo/engine/trainer.py b/ultralytics/yolo/engine/trainer.py index b007de5..9795957 100644 --- a/ultralytics/yolo/engine/trainer.py +++ b/ultralytics/yolo/engine/trainer.py @@ -9,6 +9,7 @@ Simple training loop; Boilerplate that could apply to any arbitrary neural netwo import os import time from collections import defaultdict +from copy import deepcopy from datetime import datetime from pathlib import Path from typing import Dict, Union @@ -29,6 +30,7 @@ from ultralytics.yolo.utils import LOGGER, ROOT, TQDM_BAR_FORMAT from ultralytics.yolo.utils.checks import print_args from ultralytics.yolo.utils.files import increment_path, save_yaml from ultralytics.yolo.utils.modeling import get_model +from ultralytics.yolo.utils.torch_utils import ModelEMA, de_parallel DEFAULT_CONFIG = ROOT / "yolo/utils/configs/default.yaml" @@ -63,6 +65,7 @@ class BaseTrainer: self.trainset, self.testset = self.get_dataset(self.data) if self.args.model: self.model = self.get_model(self.args.model) + self.ema = None # epoch level metrics self.metrics = {} # handle metrics returned by validator @@ -144,6 +147,7 @@ class BaseTrainer: self.validator = self.get_validator() print("created testloader :", rank) self.console.info(self.progress_string()) + self.ema = ModelEMA(self.model) def _do_train(self, rank=-1, world_size=1): if world_size > 1: @@ -196,6 +200,7 @@ class BaseTrainer: if rank in [-1, 0]: # validation # callback: on_val_start() + self.ema.update_attr(self.model, include=['yaml', 'nc', 'args', 'names', 'stride', 'class_weights']) self.validate() # callback: on_val_end() @@ -220,10 +225,10 @@ class BaseTrainer: ckpt = { 'epoch': self.epoch, 'best_fitness': self.best_fitness, - 'model': None, # deepcopy(ema.ema).half(), # deepcopy(de_parallel(model)).half(), - 'ema': None, # deepcopy(ema.ema).half(), - 'updates': None, # ema.updates, - 'optimizer': None, # optimizer.state_dict(), + 'model': deepcopy(de_parallel(self.model)).half(), + 'ema': deepcopy(self.ema.ema).half(), + 'updates': self.ema.updates, + 'optimizer': self.optimizer.state_dict(), 'train_args': self.args, 'date': datetime.now().isoformat()} @@ -266,6 +271,8 @@ class BaseTrainer: self.scaler.step(self.optimizer) self.scaler.update() self.optimizer.zero_grad() + if self.ema: + self.ema.update(self.model) def preprocess_batch(self, batch): """ diff --git a/ultralytics/yolo/engine/validator.py b/ultralytics/yolo/engine/validator.py index eeda7bf..e60f086 100644 --- a/ultralytics/yolo/engine/validator.py +++ b/ultralytics/yolo/engine/validator.py @@ -30,19 +30,16 @@ class BaseValidator: Supports validation of a pre-trained model if passed or a model being trained if trainer is passed (trainer gets priority). """ - training = trainer is not None - self.training = training - # trainer = trainer or self.trainer_class.get_trainer() - assert training or model is not None, "Either trainer or model is needed for validation" - if training: - model = trainer.model + self.training = trainer is not None + if self.training: + model = trainer.ema.ema or trainer.model self.args.half &= self.device.type != 'cpu' # NOTE: half() inference in evaluation will make training stuck, # so I comment it out for now, I think we can reuse half mode after we add EMA. - # model = model.half() if self.args.half else model + model = model.half() if self.args.half else model.float() else: # TODO: handle this when detectMultiBackend is supported + assert model is not None, "Either trainer or model is needed for validation" # model = DetectMultiBacked(model) - pass # TODO: implement init_model_attributes() model.eval() @@ -50,7 +47,7 @@ class BaseValidator: loss = 0 n_batches = len(self.dataloader) desc = self.get_desc() - bar = tqdm(self.dataloader, desc, n_batches, not training, bar_format=TQDM_BAR_FORMAT) + bar = tqdm(self.dataloader, desc, n_batches, not self.training, bar_format=TQDM_BAR_FORMAT) self.init_metrics(de_parallel(model)) with torch.no_grad(): for batch_i, batch in enumerate(bar): @@ -67,7 +64,7 @@ class BaseValidator: # loss with dt[2]: - if training: + if self.training: loss += trainer.criterion(preds, batch)[0] # pre-process predictions @@ -82,7 +79,7 @@ class BaseValidator: self.print_results() # print speeds - if not training: + if not self.training: t = tuple(x.t / len(self.dataloader.dataset.samples) * 1E3 for x in dt) # speeds per image # shape = (self.dataloader.batch_size, 3, imgsz, imgsz) self.logger.info( diff --git a/ultralytics/yolo/utils/modeling/tasks.py b/ultralytics/yolo/utils/modeling/tasks.py index c6c82b5..70ef354 100644 --- a/ultralytics/yolo/utils/modeling/tasks.py +++ b/ultralytics/yolo/utils/modeling/tasks.py @@ -232,4 +232,4 @@ class ClassificationModel(BaseModel): elif nn.Conv2d in types: i = types.index(nn.Conv2d) # nn.Conv2d index if m[i].out_channels != nc: - m[i] = nn.Conv2d(m[i].in_channels, nc, m[i].kernel_size, m[i].stride, bias=m[i].bias) + m[i] = nn.Conv2d(m[i].in_channels, nc, m[i].kernel_size, m[i].stride, bias=m[i].bias is not None) diff --git a/ultralytics/yolo/utils/torch_utils.py b/ultralytics/yolo/utils/torch_utils.py index 795e572..3dec581 100644 --- a/ultralytics/yolo/utils/torch_utils.py +++ b/ultralytics/yolo/utils/torch_utils.py @@ -192,3 +192,34 @@ def is_parallel(model): def de_parallel(model): # De-parallelize a model: returns single-GPU model if model is of type DP or DDP return model.module if is_parallel(model) else model + + +class ModelEMA: + """ Updated Exponential Moving Average (EMA) from https://github.com/rwightman/pytorch-image-models + Keeps a moving average of everything in the model state_dict (parameters and buffers) + For EMA details see https://www.tensorflow.org/api_docs/python/tf/train/ExponentialMovingAverage + """ + + def __init__(self, model, decay=0.9999, tau=2000, updates=0): + # Create EMA + self.ema = deepcopy(de_parallel(model)).eval() # FP32 EMA + self.updates = updates # number of EMA updates + self.decay = lambda x: decay * (1 - math.exp(-x / tau)) # decay exponential ramp (to help early epochs) + for p in self.ema.parameters(): + p.requires_grad_(False) + + def update(self, model): + # Update EMA parameters + self.updates += 1 + d = self.decay(self.updates) + + msd = de_parallel(model).state_dict() # model state_dict + for k, v in self.ema.state_dict().items(): + if v.dtype.is_floating_point: # true for FP16 and FP32 + v *= d + v += (1 - d) * msd[k].detach() + # assert v.dtype == msd[k].dtype == torch.float32, f'{k}: EMA {v.dtype} and model {msd[k].dtype} must be FP32' + + def update_attr(self, model, include=(), exclude=('process_group', 'reducer')): + # Update EMA attributes + copy_attr(self.ema, model, include, exclude) diff --git a/ultralytics/yolo/v8/segment/train.py b/ultralytics/yolo/v8/segment/train.py index 8372b1b..9949629 100644 --- a/ultralytics/yolo/v8/segment/train.py +++ b/ultralytics/yolo/v8/segment/train.py @@ -159,11 +159,11 @@ class SegmentationTrainer(BaseTrainer): return tcls, tbox, indices, anch, tidxs, xywhn - if self.model.training: + if len(preds) == 2: # eval p, proto, = preds - else: - p, proto, train_out = preds - p = train_out + else: # len(3) train + _, proto, p = preds + targets = torch.cat((batch["batch_idx"].view(-1, 1), batch["cls"].view(-1, 1), batch["bboxes"]), 1) masks = batch["masks"] targets, masks = targets.to(self.device), masks.to(self.device).float() diff --git a/ultralytics/yolo/v8/segment/val.py b/ultralytics/yolo/v8/segment/val.py index f4a526f..372f306 100644 --- a/ultralytics/yolo/v8/segment/val.py +++ b/ultralytics/yolo/v8/segment/val.py @@ -1,5 +1,4 @@ import os -from pathlib import Path import numpy as np import torch