Add EMA and model checkpointing (#49)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
Authored by Ayush Chaurasia on 2022-11-19 23:37:26 +05:30, committed by GitHub
commit 4291b9c31c, parent 27d6545117
6 changed files with 55 additions and 21 deletions
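Background for the change: an exponential moving average (EMA) keeps a shadow copy of the model whose weights are a decayed average of the training weights; the shadow copy typically validates more smoothly and slightly better than the raw weights. The ModelEMA imported below lives in ultralytics.yolo.utils.torch_utils and is not shown in this diff, so the following is only a minimal sketch of the underlying update rule (shadow = decay * shadow + (1 - decay) * weight), not the actual implementation:

# Minimal EMA sketch (illustration only, not the ModelEMA from torch_utils).
import copy

import torch


class SimpleEMA:
    def __init__(self, model: torch.nn.Module, decay: float = 0.9999):
        self.ema = copy.deepcopy(model).eval()  # shadow copy of the model
        self.decay = decay
        for p in self.ema.parameters():
            p.requires_grad_(False)

    @torch.no_grad()
    def update(self, model: torch.nn.Module):
        msd = model.state_dict()
        for k, v in self.ema.state_dict().items():
            if v.dtype.is_floating_point:  # skip int buffers, e.g. num_batches_tracked
                v.mul_(self.decay).add_(msd[k].detach(), alpha=1 - self.decay)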

ultralytics/yolo/engine/trainer.py

@@ -9,6 +9,7 @@ Simple training loop; Boilerplate that could apply to any arbitrary neural network
 import os
 import time
 from collections import defaultdict
+from copy import deepcopy
 from datetime import datetime
 from pathlib import Path
 from typing import Dict, Union
@@ -29,6 +30,7 @@ from ultralytics.yolo.utils import LOGGER, ROOT, TQDM_BAR_FORMAT
 from ultralytics.yolo.utils.checks import print_args
 from ultralytics.yolo.utils.files import increment_path, save_yaml
 from ultralytics.yolo.utils.modeling import get_model
+from ultralytics.yolo.utils.torch_utils import ModelEMA, de_parallel

 DEFAULT_CONFIG = ROOT / "yolo/utils/configs/default.yaml"
@@ -63,6 +65,7 @@ class BaseTrainer:
         self.trainset, self.testset = self.get_dataset(self.data)
         if self.args.model:
             self.model = self.get_model(self.args.model)
+        self.ema = None

         # epoch level metrics
         self.metrics = {}  # handle metrics returned by validator
@@ -144,6 +147,7 @@
         self.validator = self.get_validator()
         print("created testloader :", rank)
         self.console.info(self.progress_string())
+        self.ema = ModelEMA(self.model)

     def _do_train(self, rank=-1, world_size=1):
         if world_size > 1:
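The EMA wrapper is created only once the model and dataloaders exist. In the YOLOv5 lineage this helper does not use a constant decay: it ramps the decay from roughly 0 toward its nominal value so the shadow weights are not anchored to the random initialization. A sketch of that schedule, assuming the YOLOv5-style constants (decay=0.9999, tau=2000):

import math

def ramped_decay(updates: int, decay: float = 0.9999, tau: float = 2000.0) -> float:
    # Assumed YOLOv5-style ramp: near 0 for the first updates, approaching `decay` later.
    return decay * (1 - math.exp(-updates / tau))

# ramped_decay(100) ~= 0.049, ramped_decay(10000) ~= 0.993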
@@ -196,6 +200,7 @@
             if rank in [-1, 0]:
                 # validation
                 # callback: on_val_start()
+                self.ema.update_attr(self.model, include=['yaml', 'nc', 'args', 'names', 'stride', 'class_weights'])
                 self.validate()
                 # callback: on_val_end()
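Before validation, update_attr mirrors plain (non-tensor) attributes such as class names and strides from the live model onto the EMA copy, so the validator sees up-to-date metadata on whichever model it receives. Roughly, as a hedged sketch (the real helper lives in torch_utils and is not shown here):

def copy_attrs(dst, src, include=(), exclude=()):
    # Copy selected plain attributes from src to dst, skipping private names.
    for k, v in src.__dict__.items():
        if (include and k not in include) or k.startswith('_') or k in exclude:
            continue
        setattr(dst, k, v)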
@@ -220,10 +225,10 @@
             ckpt = {
                 'epoch': self.epoch,
                 'best_fitness': self.best_fitness,
-                'model': None,  # deepcopy(ema.ema).half(),  # deepcopy(de_parallel(model)).half(),
-                'ema': None,  # deepcopy(ema.ema).half(),
-                'updates': None,  # ema.updates,
-                'optimizer': None,  # optimizer.state_dict(),
+                'model': deepcopy(de_parallel(self.model)).half(),
+                'ema': deepcopy(self.ema.ema).half(),
+                'updates': self.ema.updates,
+                'optimizer': self.optimizer.state_dict(),
                 'train_args': self.args,
                 'date': datetime.now().isoformat()}
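With the placeholders replaced, checkpoints now carry both raw and EMA weights (stored in half precision) plus the optimizer state and EMA update count needed to resume. A sketch of consuming such a file; the path and usage here are assumptions, not part of this diff:

import torch

ckpt = torch.load('last.pt', map_location='cpu')  # hypothetical checkpoint path
model = (ckpt['ema'] or ckpt['model']).float()    # prefer EMA weights; restore fp32
model.eval()
# ckpt['optimizer'] and ckpt['updates'] would be used to resume training.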
@@ -266,6 +271,8 @@
             self.scaler.step(self.optimizer)
             self.scaler.update()
             self.optimizer.zero_grad()
+            if self.ema:
+                self.ema.update(self.model)

     def preprocess_batch(self, batch):
         """

ultralytics/yolo/engine/validator.py

@@ -30,19 +30,16 @@ class BaseValidator:
         Supports validation of a pre-trained model if passed or a model being trained
         if trainer is passed (trainer gets priority).
         """
-        training = trainer is not None
-        self.training = training
-        # trainer = trainer or self.trainer_class.get_trainer()
-        assert training or model is not None, "Either trainer or model is needed for validation"
-        if training:
-            model = trainer.model
+        self.training = trainer is not None
+        if self.training:
+            model = trainer.ema.ema or trainer.model
             self.args.half &= self.device.type != 'cpu'
-            # NOTE: half() inference in evaluation will make training stuck,
-            # so I comment it out for now, I think we can reuse half mode after we add EMA.
-            # model = model.half() if self.args.half else model
+            model = model.half() if self.args.half else model.float()
         else:  # TODO: handle this when detectMultiBackend is supported
+            assert model is not None, "Either trainer or model is needed for validation"
             # model = DetectMultiBacked(model)
-            pass

         # TODO: implement init_model_attributes()
         model.eval()
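During training, validation now runs on the EMA weights when available (trainer.ema.ema), falling back to the raw model, and half-precision evaluation is re-enabled but restricted to non-CPU devices. The selection logic, factored out as a standalone sketch with hypothetical names:

import torch

def pick_eval_model(model: torch.nn.Module, ema, device: torch.device, half: bool):
    m = ema or model                       # prefer the smoother EMA weights
    half = half and device.type != 'cpu'   # fp16 eval only off-CPU
    return m.half() if half else m.float()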
@@ -50,7 +47,7 @@
         loss = 0
         n_batches = len(self.dataloader)
         desc = self.get_desc()
-        bar = tqdm(self.dataloader, desc, n_batches, not training, bar_format=TQDM_BAR_FORMAT)
+        bar = tqdm(self.dataloader, desc, n_batches, not self.training, bar_format=TQDM_BAR_FORMAT)
         self.init_metrics(de_parallel(model))
         with torch.no_grad():
             for batch_i, batch in enumerate(bar):
@@ -67,7 +64,7 @@
                 # loss
                 with dt[2]:
-                    if training:
+                    if self.training:
                         loss += trainer.criterion(preds, batch)[0]

                 # pre-process predictions
@@ -82,7 +79,7 @@
         self.print_results()

         # print speeds
-        if not training:
+        if not self.training:
             t = tuple(x.t / len(self.dataloader.dataset.samples) * 1E3 for x in dt)  # speeds per image
             # shape = (self.dataloader.batch_size, 3, imgsz, imgsz)
             self.logger.info(
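The speed line divides each profiling stage's accumulated seconds by the dataset size and scales by 1E3, i.e. milliseconds per image. A worked example with made-up numbers:

n_images = 5000
stage_seconds = (6.0, 42.5, 3.1)  # assumed stages: pre-process, inference, loss
speeds_ms = tuple(t / n_images * 1e3 for t in stage_seconds)
print(speeds_ms)  # (1.2, 8.5, 0.62) ms per image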