add resuming (#63)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Laughing
2022-12-05 20:56:41 -06:00
committed by GitHub
parent de3e6ca54d
commit fbeeb5d1e1
7 changed files with 86 additions and 30 deletions

@@ -26,8 +26,7 @@ import ultralytics.yolo.utils.callbacks as callbacks
from ultralytics.yolo.data.utils import check_dataset, check_dataset_yaml
from ultralytics.yolo.utils import LOGGER, ROOT, TQDM_BAR_FORMAT, colorstr
from ultralytics.yolo.utils.checks import check_file, print_args
from ultralytics.yolo.utils.files import increment_path, save_yaml
from ultralytics.yolo.utils.modeling import get_model
from ultralytics.yolo.utils.files import get_latest_run, increment_path, save_yaml
from ultralytics.yolo.utils.torch_utils import ModelEMA, de_parallel, init_seeds, one_cycle, strip_optimizer
DEFAULT_CONFIG = ROOT / "yolo/utils/configs/default.yaml"
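The resume flow leans on the new get_latest_run import: when resume=True is given without an explicit path, the trainer has to locate the most recent last.pt on disk by itself. A minimal sketch of such a helper, assuming it simply globs for the newest last*.pt under the current directory (the real implementation in ultralytics.yolo.utils.files may differ):

import glob
import os

def get_latest_run(search_dir='.'):
    # newest 'last*.pt' found anywhere under search_dir, or '' if none exists yet
    last_list = glob.glob(f'{search_dir}/**/last*.pt', recursive=True)
    return max(last_list, key=os.path.getctime) if last_list else ''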
@@ -38,6 +37,7 @@ class BaseTrainer:
def __init__(self, config=DEFAULT_CONFIG, overrides={}):
self.args = self._get_config(config, overrides)
self.check_resume()
init_seeds(self.args.seed + 1 + RANK, deterministic=self.args.deterministic)
self.console = LOGGER
@@ -50,6 +50,7 @@ class BaseTrainer:
self.last, self.best = self.wdir / 'last.pt', self.wdir / 'best.pt' # checkpoint paths
self.batch_size = self.args.batch_size
self.epochs = self.args.epochs
self.start_epoch = 0
print_args(dict(self.args))
# Save run settings
@@ -66,8 +67,6 @@
else:
self.data = check_dataset(self.data)
self.trainset, self.testset = self.get_dataset(self.data)
if self.args.model:
self.model = self.get_model(self.args.model)
self.ema = None
# Optimization utils init
@@ -136,15 +135,17 @@
self.console.info(f"RANK - WORLD_SIZE - DEVICE: {rank} - {world_size} - {self.device} ")
dist.init_process_group("nccl" if dist.is_nccl_available() else "gloo", rank=rank, world_size=world_size)
self.model = self.model.to(self.device)
self.model = DDP(self.model, device_ids=[rank])
def _setup_train(self, rank, world_size):
"""
Builds dataloaders and optimizer on correct rank process
"""
# Optimizer
# model
ckpt = self.setup_model()
self.set_model_attributes()
if world_size > 1:
self.model = DDP(self.model, device_ids=[rank])
# Optimizer
self.accumulate = max(round(self.args.nbs / self.batch_size), 1) # accumulate loss before optimizing
self.args.weight_decay *= self.batch_size * self.accumulate / self.args.nbs # scale weight_decay
self.optimizer = build_optimizer(model=self.model,
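Above, accumulate and the rescaled weight_decay keep optimization consistent with the nominal batch size nbs even when the real batch is smaller. A quick numeric check, with nbs=64 and weight_decay=0.0005 chosen here purely for illustration:

nbs, batch_size, weight_decay = 64, 16, 0.0005  # illustrative values, not taken from this diff
accumulate = max(round(nbs / batch_size), 1)    # 4 -> step the optimizer every 4 batches
weight_decay *= batch_size * accumulate / nbs   # 16 * 4 / 64 == 1.0, so decay stays at 0.0005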
@@ -158,6 +159,8 @@
else:
self.lf = lambda x: (1 - x / self.epochs) * (1.0 - self.args.lrf) + self.args.lrf # linear
self.scheduler = lr_scheduler.LambdaLR(self.optimizer, lr_lambda=self.lf)
self.resume_training(ckpt)
self.scheduler.last_epoch = self.start_epoch - 1 # do not move
# dataloaders
batch_size = self.batch_size // world_size
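The "do not move" line matters because LambdaLR derives the learning rate from its internal last_epoch counter; pointing it at start_epoch - 1 right after resume_training means the first scheduler.step() of a resumed run advances to exactly start_epoch. A standalone sketch of that behaviour (all values assumed for illustration):

import torch
from torch.optim.lr_scheduler import LambdaLR

epochs, lrf = 100, 0.01                              # illustrative hyperparameters
model = torch.nn.Linear(1, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
lf = lambda x: (1 - x / epochs) * (1.0 - lrf) + lrf  # same linear schedule as above
scheduler = LambdaLR(optimizer, lr_lambda=lf)

start_epoch = 37                                     # e.g. ckpt['epoch'] + 1
scheduler.last_epoch = start_epoch - 1               # do not move: next step() lands on epoch 37
for epoch in range(start_epoch, epochs):
    optimizer.step()                                 # stand-in for one training epoch
    scheduler.step()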
@@ -174,20 +177,18 @@
def _do_train(self, rank=-1, world_size=1):
if world_size > 1:
self._setup_ddp(rank, world_size)
else:
self.model = self.model.to(self.device)
self.trigger_callbacks("before_train")
self._setup_train(rank, world_size)
self.trigger_callbacks("before_train")
self.epoch = 0
self.epoch_time = None
self.epoch_time_start = time.time()
self.train_time_start = time.time()
nb = len(self.train_loader) # number of batches
nw = max(round(self.args.warmup_epochs * nb), 100) # number of warmup iterations
last_opt_step = -1
for epoch in range(self.epochs):
for epoch in range(self.start_epoch, self.epochs):
self.epoch = epoch
self.trigger_callbacks("on_epoch_start")
self.model.train()
if rank != -1:
@@ -257,11 +258,10 @@
self.save_metrics(metrics=log_vals)
# save model
if (not self.args.nosave) or (self.epoch + 1 == self.epochs):
if (not self.args.nosave) or (epoch + 1 == self.epochs):
self.save_model()
self.trigger_callbacks('on_model_save')
self.epoch += 1
tnow = time.time()
self.epoch_time = tnow - self.epoch_time_start
self.epoch_time_start = tnow
@@ -301,17 +301,21 @@
"""
return data["train"], data.get("val") or data.get("test")
def get_model(self, model: Union[str, Path]):
def setup_model(self):
"""
load/create/download model for any task
"""
pretrained = True
if str(model).endswith(".yaml"):
model = self.args.model
pretrained = not (str(model).endswith(".yaml"))
# config
if not pretrained:
model = check_file(model)
pretrained = False
return self.load_model(model_cfg=None if pretrained else model,
weights=get_model(model) if pretrained else None,
data=self.data) # model
ckpt = self.load_ckpt(model) if pretrained else None
self.model = self.load_model(model_cfg=None if pretrained else model, weights=ckpt).to(self.device) # model
return ckpt
def load_ckpt(self, ckpt):
return torch.load(ckpt, map_location='cpu')
def optimizer_step(self):
self.scaler.unscale_(self.optimizer) # unscale gradients
@@ -350,7 +354,7 @@
if rank in {-1, 0}:
self.console.info(text)
def load_model(self, model_cfg, weights, data):
def load_model(self, model_cfg, weights):
raise NotImplementedError("This task trainer doesn't support loading cfg files")
def get_validator(self):
@@ -409,6 +413,40 @@
if f is self.best:
self.console.info(f'\nValidating {f}...')
def check_resume(self):
resume = self.args.resume
if resume:
last = Path(check_file(resume) if isinstance(resume, str) else get_latest_run())
args_yaml = last.parent.parent / 'args.yaml' # train options yaml
if args_yaml.is_file():
args = self._get_config(args_yaml) # replace
args.model, args.resume, args.exist_ok = str(last), True, True # reinstate
self.args = args
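With check_resume in place, resuming is driven entirely by the resume override: True picks up the newest last.pt found by get_latest_run(), while a path resumes that specific run together with its saved args.yaml. A hedged usage sketch, with DetectionTrainer assumed as a BaseTrainer subclass and trainer.train() assumed as the public entry point (import path and checkpoint path are illustrative):

from ultralytics.yolo.v8.detect import DetectionTrainer  # assumed task trainer; any BaseTrainer subclass works

# resume the most recent run found on disk
trainer = DetectionTrainer(overrides={'resume': True})

# or resume one specific checkpoint
# trainer = DetectionTrainer(overrides={'resume': 'runs/detect/train/weights/last.pt'})

trainer.train()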
def resume_training(self, ckpt):
if ckpt is None:
return
best_fitness = 0.0
start_epoch = ckpt['epoch'] + 1
if ckpt['optimizer'] is not None:
self.optimizer.load_state_dict(ckpt['optimizer']) # optimizer
best_fitness = ckpt['best_fitness']
if self.ema and ckpt.get('ema'):
self.ema.ema.load_state_dict(ckpt['ema'].float().state_dict()) # EMA
self.ema.updates = ckpt['updates']
if self.args.resume:
assert start_epoch > 0, f'{self.args.model} training to {self.epochs} epochs is finished, nothing to resume.\n' \
f"Start a new training without --resume, i.e. 'yolo task=... mode=train model={self.args.model}'"
LOGGER.info(
f'Resuming training from {self.args.model} from epoch {start_epoch} to {self.epochs} total epochs')
if self.epochs < start_epoch:
LOGGER.info(
f"{self.args.model} has been trained for {ckpt['epoch']} epochs. Fine-tuning for {self.epochs} more epochs."
)
self.epochs += ckpt['epoch'] # finetune additional epochs
self.best_fitness = best_fitness
self.start_epoch = start_epoch
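resume_training only works if the checkpoint carries the keys it reads back: epoch, best_fitness, optimizer, ema and updates alongside the weights. A sketch of writing a checkpoint with that layout (key names taken from resume_training above; save_model presumably writes the equivalent, and every value below is a stand-in):

import torch

model = torch.nn.Linear(1, 1)                        # stand-in for the real model
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

ckpt = {
    'epoch': 36,                                     # last completed epoch -> start_epoch = 37
    'best_fitness': 0.52,                            # restored into self.best_fitness
    'model': model.state_dict(),                     # weights; exact serialization here is an assumption
    'ema': None,                                     # EMA module when present, read back via .float().state_dict()
    'updates': 0,                                    # EMA update counter
    'optimizer': optimizer.state_dict(),             # or None to skip restoring optimizer state
}
torch.save(ckpt, 'last.pt')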
def build_optimizer(model, name='Adam', lr=0.001, momentum=0.9, decay=1e-5):
# TODO: 1. docstring with example? 2. Move this inside Trainer? or utils?