Add warmup and accumulation (#52)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
single_channel
Ayush Chaurasia 2 years ago committed by GitHub
parent 298287298d
commit d7df1770fa
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -10,6 +10,7 @@ from datetime import datetime
from pathlib import Path
from typing import Dict, Union
import numpy as np
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
@ -17,6 +18,7 @@ import torch.nn as nn
from omegaconf import DictConfig, OmegaConf
from torch.cuda import amp
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.optim import lr_scheduler
from tqdm import tqdm
import ultralytics.yolo.utils as utils
@ -26,7 +28,7 @@ from ultralytics.yolo.utils import LOGGER, ROOT, TQDM_BAR_FORMAT
from ultralytics.yolo.utils.checks import print_args
from ultralytics.yolo.utils.files import increment_path, save_yaml
from ultralytics.yolo.utils.modeling import get_model
from ultralytics.yolo.utils.torch_utils import ModelEMA, de_parallel
from ultralytics.yolo.utils.torch_utils import ModelEMA, de_parallel, one_cycle
DEFAULT_CONFIG = ROOT / "yolo/utils/configs/default.yaml"
@ -63,6 +65,10 @@ class BaseTrainer:
self.model = self.get_model(self.args.model)
self.ema = None
# Optimization utils init
self.lf = None
self.scheduler = None
# epoch level metrics
self.metrics = {} # handle metrics returned by validator
self.best_fitness = None
@ -131,12 +137,23 @@ class BaseTrainer:
"""
Builds dataloaders and optimizer on correct rank process
"""
# Optimizer
self.set_model_attributes()
accumulate = max(round(self.args.nbs / self.args.batch_size), 1) # accumulate loss before optimizing
self.args.weight_decay *= self.args.batch_size * accumulate / self.args.nbs # scale weight_decay
self.optimizer = build_optimizer(model=self.model,
name=self.args.optimizer,
lr=self.args.lr0,
momentum=self.args.momentum,
decay=self.args.weight_decay)
# Scheduler
if self.args.cos_lr:
self.lf = one_cycle(1, self.args.lrf, self.args.epochs) # cosine 1->hyp['lrf']
else:
self.lf = lambda x: (1 - x / self.args.epochs) * (1.0 - self.args.lrf + self.args.lrf) # linear
self.scheduler = lr_scheduler.LambdaLR(self.optimizer, lr_lambda=self.lf)
# dataloaders
self.train_loader = self.get_dataloader(self.trainset, batch_size=self.args.batch_size, rank=rank)
if rank in {0, -1}:
print(" Creating testloader rank :", rank)
@ -154,10 +171,13 @@ class BaseTrainer:
self.trigger_callbacks("before_train")
self._setup_train(rank)
self.epoch = 1
self.epoch = 0
self.epoch_time = None
self.epoch_time_start = time.time()
self.train_time_start = time.time()
nb = len(self.train_loader) # number of batches
nw = max(round(self.args.warmup_epochs * nb), 100) # number of warmup iterations
last_opt_step = -1
for epoch in range(self.args.epochs):
self.trigger_callbacks("on_epoch_start")
self.model.train()
@ -170,7 +190,18 @@ class BaseTrainer:
# forward
batch = self.preprocess_batch(batch)
# TODO: warmup, multiscale
# warmup
ni = i + nb * epoch
if ni <= nw:
xi = [0, nw] # x interp
accumulate = max(1, np.interp(ni, xi, [1, self.args.nbs / self.args.batch_size]).round())
for j, x in enumerate(self.optimizer.param_groups):
# bias lr falls from 0.1 to lr0, all other lrs rise from 0.0 to lr0
x['lr'] = np.interp(
ni, xi, [self.args.warmup_bias_lr if j == 0 else 0.0, x['initial_lr'] * self.lf(epoch)])
if 'momentum' in x:
x['momentum'] = np.interp(ni, xi, [self.args.warmup_momentum, self.args.momentum])
preds = self.model(batch["img"])
self.loss, self.loss_items = self.criterion(preds, batch)
self.tloss = (self.tloss * i + self.loss_items) / (i + 1) if self.tloss is not None \
@ -181,7 +212,9 @@ class BaseTrainer:
self.scaler.scale(self.loss).backward()
# optimize
self.optimizer_step()
if ni - last_opt_step >= accumulate:
self.optimizer_step()
last_opt_step = ni
# log
mem = (torch.cuda.memory_reserved() / 1E9 if torch.cuda.is_available() else 0) # (GB)

@ -27,6 +27,7 @@ single_cls: False # train multi-class data as single-class
image_weights: False # use weighted image selection for training
shuffle: True
rect: False # support rectangular training
cos_lr: False # Use cosine LR scheduler
overlap_mask: True # Segmentation masks overlap
mask_ratio: 4 # Segmentation mask downsample ratio
@ -71,6 +72,7 @@ mosaic: 1.0 # image mosaic (probability)
mixup: 0.0 # image mixup (probability)
copy_paste: 0.0 # segment copy-paste (probability)
label_smoothing: 0.0
nbs: 64 # nominal batch size
# anchors: 3
# Hydra configs --------------------------------------------------------------------------------------------------------

@ -194,6 +194,11 @@ def de_parallel(model):
return model.module if is_parallel(model) else model
def one_cycle(y1=0.0, y2=1.0, steps=100):
# lambda function for sinusoidal ramp from y1 to y2 https://arxiv.org/pdf/1812.01187.pdf
return lambda x: ((1 - math.cos(x * math.pi / steps)) / 2) * (y2 - y1) + y1
class ModelEMA:
""" Updated Exponential Moving Average (EMA) from https://github.com/rwightman/pytorch-image-models
Keeps a moving average of everything in the model state_dict (parameters and buffers)

Loading…
Cancel
Save