From fbeeb5d1e10ae01950fdf3b000fb806ee8ebf5ee Mon Sep 17 00:00:00 2001
From: Laughing <61612323+Laughing-q@users.noreply.github.com>
Date: Mon, 5 Dec 2022 20:56:41 -0600
Subject: [PATCH] add resuming (#63)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
---
 ultralytics/yolo/engine/trainer.py          | 82 +++++++++++++++------
 ultralytics/yolo/utils/configs/default.yaml |  1 +
 ultralytics/yolo/utils/files.py             |  7 ++
 ultralytics/yolo/v8/classify/train.py       | 16 +++-
 ultralytics/yolo/v8/detect/train.py         |  4 +-
 ultralytics/yolo/v8/segment/train.py        |  4 +-
 ultralytics/yolo/v8/segment/val.py          |  2 +-
 7 files changed, 86 insertions(+), 30 deletions(-)

diff --git a/ultralytics/yolo/engine/trainer.py b/ultralytics/yolo/engine/trainer.py
index 0899941..8d461be 100644
--- a/ultralytics/yolo/engine/trainer.py
+++ b/ultralytics/yolo/engine/trainer.py
@@ -26,8 +26,7 @@ import ultralytics.yolo.utils.callbacks as callbacks
 from ultralytics.yolo.data.utils import check_dataset, check_dataset_yaml
 from ultralytics.yolo.utils import LOGGER, ROOT, TQDM_BAR_FORMAT, colorstr
 from ultralytics.yolo.utils.checks import check_file, print_args
-from ultralytics.yolo.utils.files import increment_path, save_yaml
-from ultralytics.yolo.utils.modeling import get_model
+from ultralytics.yolo.utils.files import get_latest_run, increment_path, save_yaml
 from ultralytics.yolo.utils.torch_utils import ModelEMA, de_parallel, init_seeds, one_cycle, strip_optimizer
 
 DEFAULT_CONFIG = ROOT / "yolo/utils/configs/default.yaml"
@@ -38,6 +37,7 @@ class BaseTrainer:
 
     def __init__(self, config=DEFAULT_CONFIG, overrides={}):
         self.args = self._get_config(config, overrides)
+        self.check_resume()
         init_seeds(self.args.seed + 1 + RANK, deterministic=self.args.deterministic)
         self.console = LOGGER
@@ -50,6 +50,7 @@ class BaseTrainer:
         self.last, self.best = self.wdir / 'last.pt', self.wdir / 'best.pt'  # checkpoint paths
         self.batch_size = self.args.batch_size
         self.epochs = self.args.epochs
+        self.start_epoch = 0
         print_args(dict(self.args))
 
         # Save run settings
@@ -66,8 +67,6 @@ class BaseTrainer:
         else:
             self.data = check_dataset(self.data)
         self.trainset, self.testset = self.get_dataset(self.data)
-        if self.args.model:
-            self.model = self.get_model(self.args.model)
         self.ema = None
 
         # Optimization utils init
@@ -136,15 +135,17 @@ class BaseTrainer:
             self.console.info(f"RANK - WORLD_SIZE - DEVICE: {rank} - {world_size} - {self.device} ")
 
         dist.init_process_group("nccl" if dist.is_nccl_available() else "gloo", rank=rank, world_size=world_size)
-        self.model = self.model.to(self.device)
-        self.model = DDP(self.model, device_ids=[rank])
 
     def _setup_train(self, rank, world_size):
         """
         Builds dataloaders and optimizer on correct rank process
         """
-        # Optimizer
+        # model
+        ckpt = self.setup_model()
         self.set_model_attributes()
+        if world_size > 1:
+            self.model = DDP(self.model, device_ids=[rank])
+        # Optimizer
         self.accumulate = max(round(self.args.nbs / self.batch_size), 1)  # accumulate loss before optimizing
         self.args.weight_decay *= self.batch_size * self.accumulate / self.args.nbs  # scale weight_decay
         self.optimizer = build_optimizer(model=self.model,
@@ -158,6 +159,8 @@ class BaseTrainer:
         else:
             self.lf = lambda x: (1 - x / self.epochs) * (1.0 - self.args.lrf) + self.args.lrf  # linear
         self.scheduler = lr_scheduler.LambdaLR(self.optimizer, lr_lambda=self.lf)
+        self.resume_training(ckpt)
+        self.scheduler.last_epoch = self.start_epoch - 1  # do not move
 
         # dataloaders
         batch_size = self.batch_size // world_size
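Why the `# do not move` comment above matters: `resume_training(ckpt)` restores `self.start_epoch` from the checkpoint first, and only then is the scheduler rewound, so the first `scheduler.step()` of a resumed run yields the learning rate for `start_epoch` rather than for epoch 0. A minimal standalone sketch of that fast-forward behavior (the numbers are illustrative, not from this patch):

    import torch

    model = torch.nn.Linear(1, 1)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
    lf = lambda x: (1 - x / 100) * (1.0 - 0.01) + 0.01  # linear schedule, as in trainer.py
    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lf)

    start_epoch = 42                        # i.e. ckpt['epoch'] + 1
    scheduler.last_epoch = start_epoch - 1  # mirrors the line marked 'do not move'
    scheduler.step()                        # LR now matches epoch 42, not epoch 0
    print(optimizer.param_groups[0]['lr'])  # 0.01 * lf(42)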
@@ -174,20 +177,18 @@ class BaseTrainer:
     def _do_train(self, rank=-1, world_size=1):
         if world_size > 1:
             self._setup_ddp(rank, world_size)
-        else:
-            self.model = self.model.to(self.device)
 
-        self.trigger_callbacks("before_train")
         self._setup_train(rank, world_size)
+        self.trigger_callbacks("before_train")
 
-        self.epoch = 0
         self.epoch_time = None
         self.epoch_time_start = time.time()
         self.train_time_start = time.time()
         nb = len(self.train_loader)  # number of batches
         nw = max(round(self.args.warmup_epochs * nb), 100)  # number of warmup iterations
         last_opt_step = -1
-        for epoch in range(self.epochs):
+        for epoch in range(self.start_epoch, self.epochs):
+            self.epoch = epoch
             self.trigger_callbacks("on_epoch_start")
             self.model.train()
             if rank != -1:
@@ -257,11 +258,10 @@ class BaseTrainer:
                 self.save_metrics(metrics=log_vals)
 
                 # save model
-                if (not self.args.nosave) or (self.epoch + 1 == self.epochs):
+                if (not self.args.nosave) or (epoch + 1 == self.epochs):
                     self.save_model()
                     self.trigger_callbacks('on_model_save')
 
-            self.epoch += 1
             tnow = time.time()
             self.epoch_time = tnow - self.epoch_time_start
             self.epoch_time_start = tnow
@@ -301,17 +301,21 @@ class BaseTrainer:
         """
         return data["train"], data.get("val") or data.get("test")
 
-    def get_model(self, model: Union[str, Path]):
+    def setup_model(self):
         """
         load/create/download model for any task
         """
-        pretrained = True
-        if str(model).endswith(".yaml"):
+        model = self.args.model
+        pretrained = not (str(model).endswith(".yaml"))
+        # config
+        if not pretrained:
             model = check_file(model)
-            pretrained = False
-        return self.load_model(model_cfg=None if pretrained else model,
-                               weights=get_model(model) if pretrained else None,
-                               data=self.data)  # model
+        ckpt = self.load_ckpt(model) if pretrained else None
+        self.model = self.load_model(model_cfg=None if pretrained else model, weights=ckpt).to(self.device)  # model
+        return ckpt
+
+    def load_ckpt(self, ckpt):
+        return torch.load(ckpt, map_location='cpu')
 
     def optimizer_step(self):
         self.scaler.unscale_(self.optimizer)  # unscale gradients
@@ -350,7 +354,7 @@ class BaseTrainer:
         if rank in {-1, 0}:
             self.console.info(text)
 
-    def load_model(self, model_cfg, weights, data):
+    def load_model(self, model_cfg, weights):
         raise NotImplementedError("This task trainer doesn't support loading cfg files")
 
     def get_validator(self):
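Taken together, `setup_model()`, `load_ckpt()` and `resume_training()` (next hunk) define the checkpoint contract this patch relies on. A hedged sketch of the keys read from a resumed checkpoint, assuming a `last.pt` written by this trainer (the save side is unchanged by this diff):

    import torch

    ckpt = torch.load('last.pt', map_location='cpu')  # what load_ckpt() returns
    cfg = ckpt['model'].yaml              # load_model() fallback when no cfg is given
    start_epoch = ckpt['epoch'] + 1       # where resume_training() restarts
    best_fitness = ckpt['best_fitness']   # restored alongside ckpt['optimizer']
    if ckpt.get('ema'):
        ema_state, updates = ckpt['ema'], ckpt['updates']  # EMA weights + update count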
@@ -409,6 +413,40 @@ class BaseTrainer:
             if f is self.best:
                 self.console.info(f'\nValidating {f}...')
 
+    def check_resume(self):
+        resume = self.args.resume
+        if resume:
+            last = Path(check_file(resume) if isinstance(resume, str) else get_latest_run())
+            args_yaml = last.parent.parent / 'args.yaml'  # train options yaml
+            if args_yaml.is_file():
+                args = self._get_config(args_yaml)  # replace
+                args.model, args.resume, args.exist_ok = str(last), True, True  # reinstate
+                self.args = args
+
+    def resume_training(self, ckpt):
+        if ckpt is None:
+            return
+        best_fitness = 0.0
+        start_epoch = ckpt['epoch'] + 1
+        if ckpt['optimizer'] is not None:
+            self.optimizer.load_state_dict(ckpt['optimizer'])  # optimizer
+            best_fitness = ckpt['best_fitness']
+        if self.ema and ckpt.get('ema'):
+            self.ema.ema.load_state_dict(ckpt['ema'].float().state_dict())  # EMA
+            self.ema.updates = ckpt['updates']
+        if self.args.resume:
+            assert start_epoch > 0, f'{self.args.model} training to {self.epochs} epochs is finished, nothing to resume.\n' \
+                                    f"Start a new training without --resume, i.e. 'yolo task=... mode=train model={self.args.model}'"
+            LOGGER.info(
+                f'Resuming training from {self.args.model} from epoch {start_epoch} to {self.epochs} total epochs')
+        if self.epochs < start_epoch:
+            LOGGER.info(
+                f"{self.args.model} has been trained for {ckpt['epoch']} epochs. Fine-tuning for {self.epochs} more epochs."
+            )
+            self.epochs += ckpt['epoch']  # finetune additional epochs
+        self.best_fitness = best_fitness
+        self.start_epoch = start_epoch
+
 
 def build_optimizer(model, name='Adam', lr=0.001, momentum=0.9, decay=1e-5):
     # TODO: 1. docstring with example? 2. Move this inside Trainer? or utils?
diff --git a/ultralytics/yolo/utils/configs/default.yaml b/ultralytics/yolo/utils/configs/default.yaml
index 9dd3ab7..348b397 100644
--- a/ultralytics/yolo/utils/configs/default.yaml
+++ b/ultralytics/yolo/utils/configs/default.yaml
@@ -33,6 +33,7 @@ overlap_mask: True  # masks overlap
 mask_ratio: 4  # mask downsample ratio
 
 # Classification
 dropout: False  # use dropout
+resume: False
 
 # Val/Test settings ----------------------------------------------------------------------------------------------------
diff --git a/ultralytics/yolo/utils/files.py b/ultralytics/yolo/utils/files.py
index 2ae9812..0e97491 100644
--- a/ultralytics/yolo/utils/files.py
+++ b/ultralytics/yolo/utils/files.py
@@ -1,4 +1,5 @@
 import contextlib
+import glob
 import os
 from datetime import datetime
 from pathlib import Path
@@ -74,3 +75,9 @@ def file_date(path=__file__):
     # Return human-readable file modification date, i.e. '2021-3-26'
     t = datetime.fromtimestamp(Path(path).stat().st_mtime)
     return f'{t.year}-{t.month}-{t.day}'
+
+
+def get_latest_run(search_dir='.'):
+    # Return path to most recent 'last.pt' in /runs (i.e. to --resume from)
+    last_list = glob.glob(f'{search_dir}/**/last*.pt', recursive=True)
+    return max(last_list, key=os.path.getctime) if last_list else ''
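A short usage sketch for the new helper; the `runs` directory name is illustrative (the function defaults to searching from `.`):

    from ultralytics.yolo.utils.files import get_latest_run

    # Recursively finds last*.pt files and returns the newest by creation time,
    # or '' when nothing has been trained yet, so callers can fall through.
    last = get_latest_run(search_dir='runs')
    if last:
        print(f'would resume from {last}')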
diff --git a/ultralytics/yolo/v8/classify/train.py b/ultralytics/yolo/v8/classify/train.py
index 370c4ad..813278d 100644
--- a/ultralytics/yolo/v8/classify/train.py
+++ b/ultralytics/yolo/v8/classify/train.py
@@ -4,6 +4,7 @@ import torch
 from ultralytics.yolo import v8
 from ultralytics.yolo.data import build_classification_dataloader
 from ultralytics.yolo.engine.trainer import DEFAULT_CONFIG, BaseTrainer
+from ultralytics.yolo.utils.modeling import get_model
 from ultralytics.yolo.utils.modeling.tasks import ClassificationModel
 
 
@@ -12,13 +13,13 @@ class ClassificationTrainer(BaseTrainer):
     def set_model_attributes(self):
         self.model.names = self.data["names"]
 
-    def load_model(self, model_cfg, weights, data):
+    def load_model(self, model_cfg, weights):
         # TODO: why treat clf models as unique. We should have clf yamls?
         if weights and not weights.__class__.__name__.startswith("yolo"):  # torchvision
             model = weights
         else:
-            model = ClassificationModel(model_cfg, weights, data["nc"])
-            ClassificationModel.reshape_outputs(model, data["nc"])
+            model = ClassificationModel(model_cfg, weights, self.data["nc"])
+            ClassificationModel.reshape_outputs(model, self.data["nc"])
         for m in model.modules():
             if not weights and hasattr(m, 'reset_parameters'):
                 m.reset_parameters()
@@ -28,6 +29,9 @@ class ClassificationTrainer(BaseTrainer):
             p.requires_grad = True  # for training
         return model
 
+    def load_ckpt(self, ckpt):
+        return get_model(ckpt)
+
     def get_dataloader(self, dataset_path, batch_size, rank=0, mode="train"):
         return build_classification_dataloader(path=dataset_path,
                                                imgsz=self.args.img_size,
@@ -46,6 +50,12 @@ class ClassificationTrainer(BaseTrainer):
         loss = torch.nn.functional.cross_entropy(preds, batch["cls"])
         return loss, loss
 
+    def check_resume(self):
+        pass
+
+    def resume_training(self, ckpt):
+        pass
+
 
 @hydra.main(version_base=None, config_path=DEFAULT_CONFIG.parent, config_name=DEFAULT_CONFIG.name)
 def train(cfg):
diff --git a/ultralytics/yolo/v8/detect/train.py b/ultralytics/yolo/v8/detect/train.py
index 22e9e36..8021b78 100644
--- a/ultralytics/yolo/v8/detect/train.py
+++ b/ultralytics/yolo/v8/detect/train.py
@@ -15,10 +15,10 @@ from .val import DetectionValidator
 # BaseTrainer python usage
 class DetectionTrainer(SegmentationTrainer):
 
-    def load_model(self, model_cfg, weights, data):
+    def load_model(self, model_cfg, weights):
         model = DetectionModel(model_cfg or weights["model"].yaml,
                                ch=3,
-                               nc=data["nc"],
+                               nc=self.data["nc"],
                                anchors=self.args.get("anchors"))
         if weights:
             model.load(weights)
diff --git a/ultralytics/yolo/v8/segment/train.py b/ultralytics/yolo/v8/segment/train.py
index a5481d2..95bc417 100644
--- a/ultralytics/yolo/v8/segment/train.py
+++ b/ultralytics/yolo/v8/segment/train.py
@@ -26,10 +26,10 @@ class SegmentationTrainer(BaseTrainer):
         batch["img"] = batch["img"].to(self.device, non_blocking=True).float() / 255
         return batch
 
-    def load_model(self, model_cfg, weights, data):
+    def load_model(self, model_cfg, weights):
         model = SegmentationModel(model_cfg or weights["model"].yaml,
                                   ch=3,
-                                  nc=data["nc"],
+                                  nc=self.data["nc"],
                                   anchors=self.args.get("anchors"))
         if weights:
             model.load(weights)
diff --git a/ultralytics/yolo/v8/segment/val.py b/ultralytics/yolo/v8/segment/val.py
index 3784fd3..7ada26c 100644
--- a/ultralytics/yolo/v8/segment/val.py
+++ b/ultralytics/yolo/v8/segment/val.py
@@ -242,7 +242,7 @@ class SegmentationValidator(BaseValidator):
                               cls,
                               bboxes,
                               masks,
-                              paths,
+                              paths=paths,
                               fname=self.save_dir / f"val_batch{ni}_labels.jpg",
                               names=self.names)
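End to end, resuming after this patch is triggered by setting `resume` (to `True` for auto-discovery, or to an explicit checkpoint path). A hedged sketch, assuming the `train()` entry point and override mechanism of this revision (neither is shown in this diff):

    from ultralytics.yolo.v8.detect.train import DetectionTrainer

    # resume=True -> check_resume() locates the newest runs/**/last*.pt via
    # get_latest_run(), reloads that run's args.yaml, and resume_training()
    # restores optimizer/EMA state so training continues at ckpt['epoch'] + 1.
    trainer = DetectionTrainer(overrides={'resume': True})
    trainer.train()

    # CLI equivalent, following the hint in resume_training's assert message:
    #   yolo task=detect mode=train resume=True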