Add warmup and accumulation (#52)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
single_channel
Ayush Chaurasia 2 years ago committed by GitHub
parent 298287298d
commit d7df1770fa

@@ -10,6 +10,7 @@ from datetime import datetime
 from pathlib import Path
 from typing import Dict, Union

+import numpy as np
 import torch
 import torch.distributed as dist
 import torch.multiprocessing as mp
@@ -17,6 +18,7 @@ import torch.nn as nn
 from omegaconf import DictConfig, OmegaConf
 from torch.cuda import amp
 from torch.nn.parallel import DistributedDataParallel as DDP
+from torch.optim import lr_scheduler
 from tqdm import tqdm

 import ultralytics.yolo.utils as utils
@@ -26,7 +28,7 @@ from ultralytics.yolo.utils import LOGGER, ROOT, TQDM_BAR_FORMAT
 from ultralytics.yolo.utils.checks import print_args
 from ultralytics.yolo.utils.files import increment_path, save_yaml
 from ultralytics.yolo.utils.modeling import get_model
-from ultralytics.yolo.utils.torch_utils import ModelEMA, de_parallel
+from ultralytics.yolo.utils.torch_utils import ModelEMA, de_parallel, one_cycle

 DEFAULT_CONFIG = ROOT / "yolo/utils/configs/default.yaml"
@@ -63,6 +65,10 @@ class BaseTrainer:
         self.model = self.get_model(self.args.model)
         self.ema = None

+        # Optimization utils init
+        self.lf = None
+        self.scheduler = None
+
         # epoch level metrics
         self.metrics = {}  # handle metrics returned by validator
         self.best_fitness = None
@@ -131,12 +137,23 @@ class BaseTrainer:
         """
         Builds dataloaders and optimizer on correct rank process
         """
+        # Optimizer
         self.set_model_attributes()
+        accumulate = max(round(self.args.nbs / self.args.batch_size), 1)  # accumulate loss before optimizing
+        self.args.weight_decay *= self.args.batch_size * accumulate / self.args.nbs  # scale weight_decay
         self.optimizer = build_optimizer(model=self.model,
                                          name=self.args.optimizer,
                                          lr=self.args.lr0,
                                          momentum=self.args.momentum,
                                          decay=self.args.weight_decay)
+        # Scheduler
+        if self.args.cos_lr:
+            self.lf = one_cycle(1, self.args.lrf, self.args.epochs)  # cosine 1->hyp['lrf']
+        else:
+            self.lf = lambda x: (1 - x / self.args.epochs) * (1.0 - self.args.lrf) + self.args.lrf  # linear
+        self.scheduler = lr_scheduler.LambdaLR(self.optimizer, lr_lambda=self.lf)
+
+        # dataloaders
         self.train_loader = self.get_dataloader(self.trainset, batch_size=self.args.batch_size, rank=rank)
         if rank in {0, -1}:
             print(" Creating testloader rank :", rank)
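
Note: the lambdas above only define a per-epoch LR multiplier; the trainer is expected to call scheduler.step() once per epoch elsewhere. A standalone sketch of how the cosine (one_cycle) and linear multipliers drive LambdaLR, using illustrative values and a dummy parameter rather than the trainer's own model:

    import math

    import torch
    from torch.optim import SGD, lr_scheduler

    lr0, lrf, epochs = 0.01, 0.01, 100  # illustrative hyperparameters
    cos_lr = True

    def one_cycle(y1=0.0, y2=1.0, steps=100):
        # sinusoidal ramp from y1 to y2 over `steps` (copied from the torch_utils change below)
        return lambda x: ((1 - math.cos(x * math.pi / steps)) / 2) * (y2 - y1) + y1

    lf = one_cycle(1, lrf, epochs) if cos_lr else (lambda x: (1 - x / epochs) * (1.0 - lrf) + lrf)

    optimizer = SGD([torch.zeros(1, requires_grad=True)], lr=lr0)
    scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lf)

    for epoch in range(epochs):
        optimizer.step()    # placeholder for one epoch of training
        scheduler.step()    # lr becomes lr0 * lf(epoch + 1)

    print(optimizer.param_groups[0]['lr'])  # ~ lr0 * lrf = 1e-4 at the end of training
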
@@ -154,10 +171,13 @@
         self.trigger_callbacks("before_train")
         self._setup_train(rank)

-        self.epoch = 1
+        self.epoch = 0
         self.epoch_time = None
         self.epoch_time_start = time.time()
         self.train_time_start = time.time()
+        nb = len(self.train_loader)  # number of batches
+        nw = max(round(self.args.warmup_epochs * nb), 100)  # number of warmup iterations
+        last_opt_step = -1
         for epoch in range(self.args.epochs):
             self.trigger_callbacks("on_epoch_start")
             self.model.train()
@@ -170,7 +190,18 @@
                 # forward
                 batch = self.preprocess_batch(batch)
-                # TODO: warmup, multiscale
+
+                # warmup
+                ni = i + nb * epoch
+                if ni <= nw:
+                    xi = [0, nw]  # x interp
+                    accumulate = max(1, np.interp(ni, xi, [1, self.args.nbs / self.args.batch_size]).round())
+                    for j, x in enumerate(self.optimizer.param_groups):
+                        # bias lr falls from 0.1 to lr0, all other lrs rise from 0.0 to lr0
+                        x['lr'] = np.interp(
+                            ni, xi, [self.args.warmup_bias_lr if j == 0 else 0.0, x['initial_lr'] * self.lf(epoch)])
+                        if 'momentum' in x:
+                            x['momentum'] = np.interp(ni, xi, [self.args.warmup_momentum, self.args.momentum])
+
                 preds = self.model(batch["img"])
                 self.loss, self.loss_items = self.criterion(preds, batch)
                 self.tloss = (self.tloss * i + self.loss_items) / (i + 1) if self.tloss is not None \
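
Note: for the first nw iterations the hyperparameters are interpolated with np.interp over the integrated batch index ni: accumulate ramps from 1 up to nbs / batch_size, the first param group's (bias) LR falls from warmup_bias_lr while the other groups rise from 0.0, and momentum ramps from warmup_momentum to momentum. A standalone sketch with illustrative numbers (the real warmup target is x['initial_lr'] * self.lf(epoch); it is simplified to lr0 here):

    import numpy as np

    nbs, batch_size = 64, 16
    warmup_bias_lr, lr0, warmup_momentum, momentum = 0.1, 0.01, 0.8, 0.937
    nw = 100          # warmup iterations
    xi = [0, nw]      # x interpolation range

    for ni in (0, 50, 100):  # integrated batch index
        accumulate = max(1, np.interp(ni, xi, [1, nbs / batch_size]).round())
        bias_lr = np.interp(ni, xi, [warmup_bias_lr, lr0])    # falls 0.1 -> 0.01
        other_lr = np.interp(ni, xi, [0.0, lr0])              # rises 0.0 -> 0.01
        mom = np.interp(ni, xi, [warmup_momentum, momentum])  # rises 0.8 -> 0.937
        print(ni, accumulate, round(bias_lr, 4), round(other_lr, 4), round(mom, 3))
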
@@ -181,7 +212,9 @@
                 self.scaler.scale(self.loss).backward()

                 # optimize
-                self.optimizer_step()
+                if ni - last_opt_step >= accumulate:
+                    self.optimizer_step()
+                    last_opt_step = ni

                 # log
                 mem = (torch.cuda.memory_reserved() / 1E9 if torch.cuda.is_available() else 0)  # (GB)
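
Note: the backward pass still runs on every batch, so gradients keep accumulating; the optimizer only steps once at least `accumulate` batches have passed since the last step, emulating the nominal batch size on small-memory GPUs. A minimal sketch of the pattern with plain SGD (no AMP, EMA, or warmup):

    import torch

    model = torch.nn.Linear(10, 1)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
    nbs, batch_size = 64, 16
    accumulate = max(round(nbs / batch_size), 1)  # step every 4 batches -> effective batch of 64
    last_opt_step = -1

    for ni in range(12):  # ni: integrated batch counter
        x, y = torch.randn(batch_size, 10), torch.randn(batch_size, 1)
        loss = torch.nn.functional.mse_loss(model(x), y)
        loss.backward()                           # gradients accumulate across batches
        if ni - last_opt_step >= accumulate:      # steps at ni = 3, 7, 11
            optimizer.step()
            optimizer.zero_grad()
            last_opt_step = ni
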

@@ -27,6 +27,7 @@ single_cls: False # train multi-class data as single-class
 image_weights: False # use weighted image selection for training
 shuffle: True
 rect: False # support rectangular training
+cos_lr: False # Use cosine LR scheduler
 overlap_mask: True # Segmentation masks overlap
 mask_ratio: 4 # Segmentation mask downsample ratio
@@ -71,6 +72,7 @@ mosaic: 1.0 # image mosaic (probability)
 mixup: 0.0 # image mixup (probability)
 copy_paste: 0.0 # segment copy-paste (probability)
 label_smoothing: 0.0
+nbs: 64 # nominal batch size
 # anchors: 3

 # Hydra configs --------------------------------------------------------------------------------------------------------
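
Note: nbs is the nominal batch size the hyperparameters are tuned for; the trainer derives accumulate from it and rescales weight decay so the decay applied per optimizer step stays consistent whatever the actual batch size. Illustrative arithmetic (values assumed, not prescribed by this commit):

    nbs, batch_size, weight_decay = 64, 16, 0.0005
    accumulate = max(round(nbs / batch_size), 1)                 # 4 batches per optimizer step
    scaled_decay = weight_decay * batch_size * accumulate / nbs  # 0.0005 * 16 * 4 / 64 = 0.0005
    # with batch_size = 128: accumulate = 1, scaled_decay = 0.0005 * 128 * 1 / 64 = 0.001
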

@@ -194,6 +194,11 @@ def de_parallel(model):
     return model.module if is_parallel(model) else model


+def one_cycle(y1=0.0, y2=1.0, steps=100):
+    # lambda function for sinusoidal ramp from y1 to y2 https://arxiv.org/pdf/1812.01187.pdf
+    return lambda x: ((1 - math.cos(x * math.pi / steps)) / 2) * (y2 - y1) + y1
+
+
 class ModelEMA:
     """ Updated Exponential Moving Average (EMA) from https://github.com/rwightman/pytorch-image-models
     Keeps a moving average of everything in the model state_dict (parameters and buffers)
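
Note: one_cycle returns a multiplier that follows a half cosine from y1 to y2 over `steps`. A quick check, assuming the function added above is in scope:

    lf = one_cycle(1.0, 0.01, steps=100)
    print(lf(0))    # 1.0   (start of training)
    print(lf(50))   # 0.505 (midpoint of the cosine)
    print(lf(100))  # 0.01  (end of training)
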
