add resuming (#63)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
		| @ -26,8 +26,7 @@ import ultralytics.yolo.utils.callbacks as callbacks | ||||
| from ultralytics.yolo.data.utils import check_dataset, check_dataset_yaml | ||||
| from ultralytics.yolo.utils import LOGGER, ROOT, TQDM_BAR_FORMAT, colorstr | ||||
| from ultralytics.yolo.utils.checks import check_file, print_args | ||||
| from ultralytics.yolo.utils.files import increment_path, save_yaml | ||||
| from ultralytics.yolo.utils.modeling import get_model | ||||
| from ultralytics.yolo.utils.files import get_latest_run, increment_path, save_yaml | ||||
| from ultralytics.yolo.utils.torch_utils import ModelEMA, de_parallel, init_seeds, one_cycle, strip_optimizer | ||||
|  | ||||
| DEFAULT_CONFIG = ROOT / "yolo/utils/configs/default.yaml" | ||||
| @ -38,6 +37,7 @@ class BaseTrainer: | ||||
|  | ||||
|     def __init__(self, config=DEFAULT_CONFIG, overrides={}): | ||||
|         self.args = self._get_config(config, overrides) | ||||
|         self.check_resume() | ||||
|         init_seeds(self.args.seed + 1 + RANK, deterministic=self.args.deterministic) | ||||
|  | ||||
|         self.console = LOGGER | ||||
| @ -50,6 +50,7 @@ class BaseTrainer: | ||||
|         self.last, self.best = self.wdir / 'last.pt', self.wdir / 'best.pt'  # checkpoint paths | ||||
|         self.batch_size = self.args.batch_size | ||||
|         self.epochs = self.args.epochs | ||||
|         self.start_epoch = 0 | ||||
|         print_args(dict(self.args)) | ||||
|  | ||||
|         # Save run settings | ||||
| @ -66,8 +67,6 @@ class BaseTrainer: | ||||
|         else: | ||||
|             self.data = check_dataset(self.data) | ||||
|         self.trainset, self.testset = self.get_dataset(self.data) | ||||
|         if self.args.model: | ||||
|             self.model = self.get_model(self.args.model) | ||||
|         self.ema = None | ||||
|  | ||||
|         # Optimization utils init | ||||
| @ -136,15 +135,17 @@ class BaseTrainer: | ||||
|         self.console.info(f"RANK - WORLD_SIZE - DEVICE: {rank} - {world_size} - {self.device} ") | ||||
|  | ||||
|         dist.init_process_group("nccl" if dist.is_nccl_available() else "gloo", rank=rank, world_size=world_size) | ||||
|         self.model = self.model.to(self.device) | ||||
|         self.model = DDP(self.model, device_ids=[rank]) | ||||
|  | ||||
|     def _setup_train(self, rank, world_size): | ||||
|         """ | ||||
|         Builds dataloaders and optimizer on correct rank process | ||||
|         """ | ||||
|         # Optimizer | ||||
|         # model | ||||
|         ckpt = self.setup_model() | ||||
|         self.set_model_attributes() | ||||
|         if world_size > 1: | ||||
|             self.model = DDP(self.model, device_ids=[rank]) | ||||
|         # Optimizer | ||||
|         self.accumulate = max(round(self.args.nbs / self.batch_size), 1)  # accumulate loss before optimizing | ||||
|         self.args.weight_decay *= self.batch_size * self.accumulate / self.args.nbs  # scale weight_decay | ||||
|         self.optimizer = build_optimizer(model=self.model, | ||||
| @ -158,6 +159,8 @@ class BaseTrainer: | ||||
|         else: | ||||
|             self.lf = lambda x: (1 - x / self.epochs) * (1.0 - self.args.lrf) + self.args.lrf  # linear | ||||
|         self.scheduler = lr_scheduler.LambdaLR(self.optimizer, lr_lambda=self.lf) | ||||
|         self.resume_training(ckpt) | ||||
|         self.scheduler.last_epoch = self.start_epoch - 1  # do not move | ||||
|  | ||||
|         # dataloaders | ||||
|         batch_size = self.batch_size // world_size | ||||
| @ -174,20 +177,18 @@ class BaseTrainer: | ||||
|     def _do_train(self, rank=-1, world_size=1): | ||||
|         if world_size > 1: | ||||
|             self._setup_ddp(rank, world_size) | ||||
|         else: | ||||
|             self.model = self.model.to(self.device) | ||||
|  | ||||
|         self.trigger_callbacks("before_train") | ||||
|         self._setup_train(rank, world_size) | ||||
|         self.trigger_callbacks("before_train") | ||||
|  | ||||
|         self.epoch = 0 | ||||
|         self.epoch_time = None | ||||
|         self.epoch_time_start = time.time() | ||||
|         self.train_time_start = time.time() | ||||
|         nb = len(self.train_loader)  # number of batches | ||||
|         nw = max(round(self.args.warmup_epochs * nb), 100)  # number of warmup iterations | ||||
|         last_opt_step = -1 | ||||
|         for epoch in range(self.epochs): | ||||
|         for epoch in range(self.start_epoch, self.epochs): | ||||
|             self.epoch = epoch | ||||
|             self.trigger_callbacks("on_epoch_start") | ||||
|             self.model.train() | ||||
|             if rank != -1: | ||||
| @ -257,11 +258,10 @@ class BaseTrainer: | ||||
|                 self.save_metrics(metrics=log_vals) | ||||
|  | ||||
|                 # save model | ||||
|                 if (not self.args.nosave) or (self.epoch + 1 == self.epochs): | ||||
|                 if (not self.args.nosave) or (epoch + 1 == self.epochs): | ||||
|                     self.save_model() | ||||
|                     self.trigger_callbacks('on_model_save') | ||||
|  | ||||
|             self.epoch += 1 | ||||
|             tnow = time.time() | ||||
|             self.epoch_time = tnow - self.epoch_time_start | ||||
|             self.epoch_time_start = tnow | ||||
| @ -301,17 +301,21 @@ class BaseTrainer: | ||||
|         """ | ||||
|         return data["train"], data.get("val") or data.get("test") | ||||
|  | ||||
|     def get_model(self, model: Union[str, Path]): | ||||
|     def setup_model(self): | ||||
|         """ | ||||
|         load/create/download model for any task | ||||
|         """ | ||||
|         pretrained = True | ||||
|         if str(model).endswith(".yaml"): | ||||
|         model = self.args.model | ||||
|         pretrained = not (str(model).endswith(".yaml")) | ||||
|         # config | ||||
|         if not pretrained: | ||||
|             model = check_file(model) | ||||
|             pretrained = False | ||||
|         return self.load_model(model_cfg=None if pretrained else model, | ||||
|                                weights=get_model(model) if pretrained else None, | ||||
|                                data=self.data)  # model | ||||
|         ckpt = self.load_ckpt(model) if pretrained else None | ||||
|         self.model = self.load_model(model_cfg=None if pretrained else model, weights=ckpt).to(self.device)  # model | ||||
|         return ckpt | ||||
|  | ||||
|     def load_ckpt(self, ckpt): | ||||
|         return torch.load(ckpt, map_location='cpu') | ||||
|  | ||||
|     def optimizer_step(self): | ||||
|         self.scaler.unscale_(self.optimizer)  # unscale gradients | ||||
| @ -350,7 +354,7 @@ class BaseTrainer: | ||||
|         if rank in {-1, 0}: | ||||
|             self.console.info(text) | ||||
|  | ||||
|     def load_model(self, model_cfg, weights, data): | ||||
|     def load_model(self, model_cfg, weights): | ||||
|         raise NotImplementedError("This task trainer doesn't support loading cfg files") | ||||
|  | ||||
|     def get_validator(self): | ||||
| @ -409,6 +413,40 @@ class BaseTrainer: | ||||
|                 if f is self.best: | ||||
|                     self.console.info(f'\nValidating {f}...') | ||||
|  | ||||
|     def check_resume(self): | ||||
|         resume = self.args.resume | ||||
|         if resume: | ||||
|             last = Path(check_file(resume) if isinstance(resume, str) else get_latest_run()) | ||||
|             args_yaml = last.parent.parent / 'args.yaml'  # train options yaml | ||||
|             if args_yaml.is_file(): | ||||
|                 args = self._get_config(args_yaml)  # replace | ||||
|             args.model, args.resume, args.exist_ok = str(last), True, True  # reinstate | ||||
|             self.args = args | ||||
|  | ||||
|     def resume_training(self, ckpt): | ||||
|         if ckpt is None: | ||||
|             return | ||||
|         best_fitness = 0.0 | ||||
|         start_epoch = ckpt['epoch'] + 1 | ||||
|         if ckpt['optimizer'] is not None: | ||||
|             self.optimizer.load_state_dict(ckpt['optimizer'])  # optimizer | ||||
|             best_fitness = ckpt['best_fitness'] | ||||
|         if self.ema and ckpt.get('ema'): | ||||
|             self.ema.ema.load_state_dict(ckpt['ema'].float().state_dict())  # EMA | ||||
|             self.ema.updates = ckpt['updates'] | ||||
|         if self.args.resume: | ||||
|             assert start_epoch > 0, f'{self.args.model} training to {self.epochs} epochs is finished, nothing to resume.\n' \ | ||||
|                                     f"Start a new training without --resume, i.e. 'yolo task=... mode=train model={self.args.model}'" | ||||
|             LOGGER.info( | ||||
|                 f'Resuming training from {self.args.model} from epoch {start_epoch} to {self.epochs} total epochs') | ||||
|         if self.epochs < start_epoch: | ||||
|             LOGGER.info( | ||||
|                 f"{self.args.model} has been trained for {ckpt['epoch']} epochs. Fine-tuning for {self.epochs} more epochs." | ||||
|             ) | ||||
|             self.epochs += ckpt['epoch']  # finetune additional epochs | ||||
|         self.best_fitness = best_fitness | ||||
|         self.start_epoch = start_epoch | ||||
|  | ||||
|  | ||||
| def build_optimizer(model, name='Adam', lr=0.001, momentum=0.9, decay=1e-5): | ||||
|     # TODO: 1. docstring with example? 2. Move this inside Trainer? or utils? | ||||
|  | ||||
| @ -33,6 +33,7 @@ overlap_mask: True  # masks overlap | ||||
| mask_ratio: 4  # mask downsample ratio | ||||
| # Classification | ||||
| dropout: False # use dropout | ||||
| resume: False | ||||
|  | ||||
|  | ||||
| # Val/Test settings ---------------------------------------------------------------------------------------------------- | ||||
|  | ||||
| @ -1,4 +1,5 @@ | ||||
| import contextlib | ||||
| import glob | ||||
| import os | ||||
| from datetime import datetime | ||||
| from pathlib import Path | ||||
| @ -74,3 +75,9 @@ def file_date(path=__file__): | ||||
|     # Return human-readable file modification date, i.e. '2021-3-26' | ||||
|     t = datetime.fromtimestamp(Path(path).stat().st_mtime) | ||||
|     return f'{t.year}-{t.month}-{t.day}' | ||||
|  | ||||
|  | ||||
| def get_latest_run(search_dir='.'): | ||||
|     # Return path to most recent 'last.pt' in /runs (i.e. to --resume from) | ||||
|     last_list = glob.glob(f'{search_dir}/**/last*.pt', recursive=True) | ||||
|     return max(last_list, key=os.path.getctime) if last_list else '' | ||||
|  | ||||
| @ -4,6 +4,7 @@ import torch | ||||
| from ultralytics.yolo import v8 | ||||
| from ultralytics.yolo.data import build_classification_dataloader | ||||
| from ultralytics.yolo.engine.trainer import DEFAULT_CONFIG, BaseTrainer | ||||
| from ultralytics.yolo.utils.modeling import get_model | ||||
| from ultralytics.yolo.utils.modeling.tasks import ClassificationModel | ||||
|  | ||||
|  | ||||
| @ -12,13 +13,13 @@ class ClassificationTrainer(BaseTrainer): | ||||
|     def set_model_attributes(self): | ||||
|         self.model.names = self.data["names"] | ||||
|  | ||||
|     def load_model(self, model_cfg, weights, data): | ||||
|     def load_model(self, model_cfg, weights): | ||||
|         # TODO: why treat clf models as unique. We should have clf yamls? | ||||
|         if weights and not weights.__class__.__name__.startswith("yolo"):  # torchvision | ||||
|             model = weights | ||||
|         else: | ||||
|             model = ClassificationModel(model_cfg, weights, data["nc"]) | ||||
|         ClassificationModel.reshape_outputs(model, data["nc"]) | ||||
|             model = ClassificationModel(model_cfg, weights, self.data["nc"]) | ||||
|         ClassificationModel.reshape_outputs(model, self.data["nc"]) | ||||
|         for m in model.modules(): | ||||
|             if not weights and hasattr(m, 'reset_parameters'): | ||||
|                 m.reset_parameters() | ||||
| @ -28,6 +29,9 @@ class ClassificationTrainer(BaseTrainer): | ||||
|             p.requires_grad = True  # for training | ||||
|         return model | ||||
|  | ||||
|     def load_ckpt(self, ckpt): | ||||
|         return get_model(ckpt) | ||||
|  | ||||
|     def get_dataloader(self, dataset_path, batch_size, rank=0, mode="train"): | ||||
|         return build_classification_dataloader(path=dataset_path, | ||||
|                                                imgsz=self.args.img_size, | ||||
| @ -46,6 +50,12 @@ class ClassificationTrainer(BaseTrainer): | ||||
|         loss = torch.nn.functional.cross_entropy(preds, batch["cls"]) | ||||
|         return loss, loss | ||||
|  | ||||
|     def check_resume(self): | ||||
|         pass | ||||
|  | ||||
|     def resume_training(self, ckpt): | ||||
|         pass | ||||
|  | ||||
|  | ||||
| @hydra.main(version_base=None, config_path=DEFAULT_CONFIG.parent, config_name=DEFAULT_CONFIG.name) | ||||
| def train(cfg): | ||||
|  | ||||
| @ -15,10 +15,10 @@ from .val import DetectionValidator | ||||
| # BaseTrainer python usage | ||||
| class DetectionTrainer(SegmentationTrainer): | ||||
|  | ||||
|     def load_model(self, model_cfg, weights, data): | ||||
|     def load_model(self, model_cfg, weights): | ||||
|         model = DetectionModel(model_cfg or weights["model"].yaml, | ||||
|                                ch=3, | ||||
|                                nc=data["nc"], | ||||
|                                nc=self.data["nc"], | ||||
|                                anchors=self.args.get("anchors")) | ||||
|         if weights: | ||||
|             model.load(weights) | ||||
|  | ||||
| @ -26,10 +26,10 @@ class SegmentationTrainer(BaseTrainer): | ||||
|         batch["img"] = batch["img"].to(self.device, non_blocking=True).float() / 255 | ||||
|         return batch | ||||
|  | ||||
|     def load_model(self, model_cfg, weights, data): | ||||
|     def load_model(self, model_cfg, weights): | ||||
|         model = SegmentationModel(model_cfg or weights["model"].yaml, | ||||
|                                   ch=3, | ||||
|                                   nc=data["nc"], | ||||
|                                   nc=self.data["nc"], | ||||
|                                   anchors=self.args.get("anchors")) | ||||
|         if weights: | ||||
|             model.load(weights) | ||||
|  | ||||
| @ -242,7 +242,7 @@ class SegmentationValidator(BaseValidator): | ||||
|                               cls, | ||||
|                               bboxes, | ||||
|                               masks, | ||||
|                               paths, | ||||
|                               paths=paths, | ||||
|                               fname=self.save_dir / f"val_batch{ni}_labels.jpg", | ||||
|                               names=self.names) | ||||
|  | ||||
|  | ||||
		Reference in New Issue
	
	Block a user