From 340376f7a6613d50707183b6ac13c695d036cb36 Mon Sep 17 00:00:00 2001 From: Laughing <61612323+Laughing-q@users.noreply.github.com> Date: Tue, 3 Jan 2023 21:06:22 +0800 Subject: [PATCH] Fix resume (#138) Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- ultralytics/nn/tasks.py | 2 ++ ultralytics/yolo/data/dataset.py | 9 +++++++ ultralytics/yolo/engine/trainer.py | 40 +++++++++++++++++++--------- ultralytics/yolo/engine/validator.py | 1 + ultralytics/yolo/utils/dist.py | 7 ++--- ultralytics/yolo/v8/detect/train.py | 4 ++- ultralytics/yolo/v8/detect/val.py | 7 +++-- ultralytics/yolo/v8/segment/train.py | 4 ++- ultralytics/yolo/v8/segment/val.py | 7 +++-- 9 files changed, 55 insertions(+), 26 deletions(-) diff --git a/ultralytics/nn/tasks.py b/ultralytics/nn/tasks.py index 61c78ba..38063a6 100644 --- a/ultralytics/nn/tasks.py +++ b/ultralytics/nn/tasks.py @@ -293,6 +293,8 @@ def attempt_load_weights(weights, device=None, inplace=True, fuse=False): # Model compatibility updates ckpt.args = {k: v for k, v in args.items() if k in DEFAULT_CONFIG_KEYS} + if not hasattr(ckpt, 'stride'): + ckpt.stride = torch.tensor([32.]) # Append model.append(ckpt.fuse().eval() if fuse and hasattr(ckpt, 'fuse') else ckpt.eval()) # model in eval mode diff --git a/ultralytics/yolo/data/dataset.py b/ultralytics/yolo/data/dataset.py index 03d03af..d94f7c5 100644 --- a/ultralytics/yolo/data/dataset.py +++ b/ultralytics/yolo/data/dataset.py @@ -136,6 +136,15 @@ class YOLODataset(BaseDataset): batch_idx=True)) return transforms + def close_mosaic(self, hyp): + self.transforms = affine_transforms(self.imgsz, hyp) + self.transforms.append( + Format(bbox_format="xywh", + normalize=True, + return_mask=self.use_segments, + return_keypoint=self.use_keypoints, + batch_idx=True)) + def update_labels_info(self, label): """custom your label format here""" # NOTE: cls is not with bboxes now, classification and semantic segmentation need an independent cls label diff --git a/ultralytics/yolo/engine/trainer.py b/ultralytics/yolo/engine/trainer.py index f0f6a80..4006ca9 100644 --- a/ultralytics/yolo/engine/trainer.py +++ b/ultralytics/yolo/engine/trainer.py @@ -15,6 +15,7 @@ import torch import torch.distributed as dist import torch.nn as nn from omegaconf import OmegaConf # noqa +from omegaconf import open_dict from torch.cuda import amp from torch.nn.parallel import DistributedDataParallel as DDP from torch.optim import lr_scheduler @@ -90,10 +91,15 @@ class BaseTrainer: # Dirs project = self.args.project or f"runs/{self.args.task}" name = self.args.name or f"{self.args.mode}" - self.save_dir = increment_path(Path(project) / name, exist_ok=self.args.exist_ok if RANK in {-1, 0} else True) + self.save_dir = Path( + self.args.get( + "save_dir", + increment_path(Path(project) / name, exist_ok=self.args.exist_ok if RANK in {-1, 0} else True))) self.wdir = self.save_dir / 'weights' # weights dir if RANK in {-1, 0}: self.wdir.mkdir(parents=True, exist_ok=True) # make dir + with open_dict(self.args): + self.args.save_dir = str(self.save_dir) yaml_save(self.save_dir / 'args.yaml', OmegaConf.to_container(self.args, resolve=True)) # save run args self.last, self.best = self.wdir / 'last.pt', self.wdir / 'best.pt' # checkpoint paths @@ -131,6 +137,7 @@ class BaseTrainer: self.tloss = None self.loss_names = None self.csv = self.save_dir / 'results.csv' + self.plot_idx = [0, 1, 2] # Callbacks self.callbacks = defaultdict(list, {k: [v] for k, v in callbacks.default_callbacks.items()}) # add callbacks @@ -199,7 +206,6 @@ class BaseTrainer: else: self.lf = lambda x: (1 - x / self.epochs) * (1.0 - self.args.lrf) + self.args.lrf # linear self.scheduler = lr_scheduler.LambdaLR(self.optimizer, lr_lambda=self.lf) - self.resume_training(ckpt) self.scheduler.last_epoch = self.start_epoch - 1 # do not move # dataloaders @@ -211,6 +217,7 @@ class BaseTrainer: metric_keys = self.validator.metric_keys + self.label_loss_items(prefix="val") self.metrics = dict(zip(metric_keys, [0] * len(metric_keys))) # TODO: init metrics for plot_results()? self.ema = ModelEMA(self.model) + self.resume_training(ckpt) self.run_callbacks("on_pretrain_routine_end") def _do_train(self, rank=-1, world_size=1): @@ -230,6 +237,9 @@ class BaseTrainer: f'Using {self.train_loader.num_workers * (world_size or 1)} dataloader workers\n' f"Logging results to {colorstr('bold', self.save_dir)}\n" f"Starting training for {self.epochs} epochs...") + if self.args.close_mosaic: + base_idx = (self.epochs - self.args.close_mosaic) * nb + self.plot_idx.extend([base_idx, base_idx + 1, base_idx + 2]) for epoch in range(self.start_epoch, self.epochs): self.epoch = epoch self.run_callbacks("on_train_epoch_start") @@ -237,19 +247,21 @@ class BaseTrainer: if rank != -1: self.train_loader.sampler.set_epoch(epoch) pbar = enumerate(self.train_loader) + # Update dataloader attributes (optional) + if epoch == (self.epochs - self.args.close_mosaic): + self.console.info("Closing dataloader mosaic") + if hasattr(self.train_loader.dataset, 'mosaic'): + self.train_loader.dataset.mosaic = False + if hasattr(self.train_loader.dataset, 'close_mosaic'): + self.train_loader.dataset.close_mosaic(hyp=self.args) + if rank in {-1, 0}: self.console.info(self.progress_string()) - pbar = tqdm(enumerate(self.train_loader), total=len(self.train_loader), bar_format=TQDM_BAR_FORMAT) + pbar = tqdm(enumerate(self.train_loader), total=nb, bar_format=TQDM_BAR_FORMAT) self.tloss = None self.optimizer.zero_grad() for i, batch in pbar: self.run_callbacks("on_train_batch_start") - - # Update dataloader attributes (optional) - if epoch == (self.epochs - self.args.close_mosaic) and hasattr(self.train_loader.dataset, 'mosaic'): - LOGGER.info("Closing dataloader mosaic") - self.train_loader.dataset.mosaic = False - # Warmup ni = i + nb * epoch if ni <= nw: @@ -289,7 +301,7 @@ class BaseTrainer: ('%11s' * 2 + '%11.4g' * (2 + loss_len)) % (f'{epoch + 1}/{self.epochs}', mem, *losses, batch["cls"].shape[0], batch["img"].shape[-1])) self.run_callbacks('on_batch_end') - if self.args.plots and ni < 3: + if self.args.plots and ni in self.plot_idx: self.plot_training_samples(batch, ni) self.run_callbacks("on_train_batch_end") @@ -367,7 +379,8 @@ class BaseTrainer: if not pretrained: model = check_file(model) ckpt = self.load_ckpt(model) if pretrained else None - self.model = self.load_model(model_cfg=None if pretrained else model, weights=ckpt["model"]) # model + weights = ckpt["model"] if isinstance(ckpt, dict) else ckpt # torchvision weights are not dicts + self.model = self.load_model(model_cfg=None if pretrained else model, weights=weights) return ckpt def load_ckpt(self, ckpt): @@ -479,8 +492,9 @@ class BaseTrainer: args_yaml = last.parent.parent / 'args.yaml' # train options yaml if args_yaml.is_file(): args = get_config(args_yaml) # replace - args.model, args.resume, args.exist_ok = str(last), True, True # reinstate + args.model, resume = str(last), True # reinstate self.args = args + self.resume = resume def resume_training(self, ckpt): if ckpt is None: @@ -493,7 +507,7 @@ class BaseTrainer: if self.ema and ckpt.get('ema'): self.ema.ema.load_state_dict(ckpt['ema'].float().state_dict()) # EMA self.ema.updates = ckpt['updates'] - if self.args.resume: + if self.resume: assert start_epoch > 0, \ f'{self.args.model} training to {self.epochs} epochs is finished, nothing to resume.\n' \ f"Start a new training without --resume, i.e. 'yolo task=... mode=train model={self.args.model}'" diff --git a/ultralytics/yolo/engine/validator.py b/ultralytics/yolo/engine/validator.py index 61cf6cf..638e828 100644 --- a/ultralytics/yolo/engine/validator.py +++ b/ultralytics/yolo/engine/validator.py @@ -111,6 +111,7 @@ class BaseValidator: self.args.workers = 0 # faster CPU val as time dominated by inference, not dataloading self.dataloader = self.dataloader or \ self.get_dataloader(data.get("val") or data.set("test"), self.args.batch_size) + self.data = data model.eval() diff --git a/ultralytics/yolo/utils/dist.py b/ultralytics/yolo/utils/dist.py index 1bdd8cb..e99a7a5 100644 --- a/ultralytics/yolo/utils/dist.py +++ b/ultralytics/yolo/utils/dist.py @@ -24,11 +24,12 @@ def find_free_network_port() -> int: def generate_ddp_file(trainer): import_path = '.'.join(str(trainer.__class__).split(".")[1:-1]) - shutil.rmtree(trainer.save_dir) # remove the save_dir - content = f'''overrides = {dict(trainer.args)} \nif __name__ == "__main__": + if not trainer.resume: + shutil.rmtree(trainer.save_dir) # remove the save_dir + content = f'''config = {dict(trainer.args)} \nif __name__ == "__main__": from ultralytics.{import_path} import {trainer.__class__.__name__} - trainer = {trainer.__class__.__name__}(overrides=overrides) + trainer = {trainer.__class__.__name__}(config=config) trainer.train()''' (USER_CONFIG_DIR / 'DDP').mkdir(exist_ok=True) with tempfile.NamedTemporaryFile(prefix="_temp_", diff --git a/ultralytics/yolo/v8/detect/train.py b/ultralytics/yolo/v8/detect/train.py index db760d5..19a357c 100644 --- a/ultralytics/yolo/v8/detect/train.py +++ b/ultralytics/yolo/v8/detect/train.py @@ -1,3 +1,5 @@ +from copy import copy + import hydra import torch import torch.nn as nn @@ -64,7 +66,7 @@ class DetectionTrainer(BaseTrainer): return v8.detect.DetectionValidator(self.test_loader, save_dir=self.save_dir, logger=self.console, - args=self.args) + args=copy(self.args)) def criterion(self, preds, batch): if not hasattr(self, 'compute_loss'): diff --git a/ultralytics/yolo/v8/detect/val.py b/ultralytics/yolo/v8/detect/val.py index c14c097..9f49298 100644 --- a/ultralytics/yolo/v8/detect/val.py +++ b/ultralytics/yolo/v8/detect/val.py @@ -42,10 +42,9 @@ class DetectionValidator(BaseValidator): def init_metrics(self, model): head = model.model[-1] if self.training else model.model.model[-1] - if self.data: - self.is_coco = self.data.get('val', '').endswith(f'coco{os.sep}val2017.txt') # is COCO dataset - self.class_map = ops.coco80_to_coco91_class() if self.is_coco else list(range(1000)) - self.args.save_json |= self.is_coco and not self.training # run on final val if training COCO + self.is_coco = self.data.get('val', '').endswith(f'coco{os.sep}val2017.txt') # is COCO dataset + self.class_map = ops.coco80_to_coco91_class() if self.is_coco else list(range(1000)) + self.args.save_json |= self.is_coco and not self.training # run on final val if training COCO self.nc = head.nc self.names = model.names self.metrics.names = self.names diff --git a/ultralytics/yolo/v8/segment/train.py b/ultralytics/yolo/v8/segment/train.py index f8f146d..529cd29 100644 --- a/ultralytics/yolo/v8/segment/train.py +++ b/ultralytics/yolo/v8/segment/train.py @@ -1,3 +1,5 @@ +from copy import copy + import hydra import torch import torch.nn as nn @@ -27,7 +29,7 @@ class SegmentationTrainer(v8.detect.DetectionTrainer): return v8.segment.SegmentationValidator(self.test_loader, save_dir=self.save_dir, logger=self.console, - args=self.args) + args=copy(self.args)) def criterion(self, preds, batch): if not hasattr(self, 'compute_loss'): diff --git a/ultralytics/yolo/v8/segment/val.py b/ultralytics/yolo/v8/segment/val.py index 05eea88..cc90497 100644 --- a/ultralytics/yolo/v8/segment/val.py +++ b/ultralytics/yolo/v8/segment/val.py @@ -37,10 +37,9 @@ class SegmentationValidator(DetectionValidator): def init_metrics(self, model): head = model.model[-1] if self.training else model.model.model[-1] - if self.data: - self.is_coco = self.data.get('val', '').endswith(f'coco{os.sep}val2017.txt') # is COCO dataset - self.class_map = ops.coco80_to_coco91_class() if self.is_coco else list(range(1000)) - self.args.save_json |= self.is_coco and not self.training # run on final val if training COCO + self.is_coco = self.data.get('val', '').endswith(f'coco{os.sep}val2017.txt') # is COCO dataset + self.class_map = ops.coco80_to_coco91_class() if self.is_coco else list(range(1000)) + self.args.save_json |= self.is_coco and not self.training # run on final val if training COCO self.nc = head.nc self.nm = head.nm if hasattr(head, "nm") else 32 self.names = model.names