Fix resume (#138)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
		| @ -293,6 +293,8 @@ def attempt_load_weights(weights, device=None, inplace=True, fuse=False): | ||||
|  | ||||
|         # Model compatibility updates | ||||
|         ckpt.args = {k: v for k, v in args.items() if k in DEFAULT_CONFIG_KEYS} | ||||
|         if not hasattr(ckpt, 'stride'): | ||||
|             ckpt.stride = torch.tensor([32.]) | ||||
|  | ||||
|         # Append | ||||
|         model.append(ckpt.fuse().eval() if fuse and hasattr(ckpt, 'fuse') else ckpt.eval())  # model in eval mode | ||||
|  | ||||
| @ -136,6 +136,15 @@ class YOLODataset(BaseDataset): | ||||
|                    batch_idx=True)) | ||||
|         return transforms | ||||
|  | ||||
|     def close_mosaic(self, hyp): | ||||
|         # Replace this dataset's transform pipeline in place; returns nothing. | ||||
|         # BaseTrainer calls this with hyp=self.args when epoch == epochs - close_mosaic. | ||||
|         # NOTE(review): affine_transforms is defined elsewhere; presumably it builds the | ||||
|         # augmentation pipeline without the mosaic step -- confirm against its definition. | ||||
|         self.transforms = affine_transforms(self.imgsz, hyp) | ||||
|         # Append the Format step, configured for normalized "xywh" boxes with batch_idx enabled; | ||||
|         # mask / keypoint output follows self.use_segments / self.use_keypoints. | ||||
|         self.transforms.append( | ||||
|             Format(bbox_format="xywh", | ||||
|                    normalize=True, | ||||
|                    return_mask=self.use_segments, | ||||
|                    return_keypoint=self.use_keypoints, | ||||
|                    batch_idx=True)) | ||||
|  | ||||
|     def update_labels_info(self, label): | ||||
|         """custom your label format here""" | ||||
|         # NOTE: cls is not with bboxes now, classification and semantic segmentation need an independent cls label | ||||
|  | ||||
| @ -15,6 +15,7 @@ import torch | ||||
| import torch.distributed as dist | ||||
| import torch.nn as nn | ||||
| from omegaconf import OmegaConf  # noqa | ||||
| from omegaconf import open_dict | ||||
| from torch.cuda import amp | ||||
| from torch.nn.parallel import DistributedDataParallel as DDP | ||||
| from torch.optim import lr_scheduler | ||||
| @ -90,10 +91,15 @@ class BaseTrainer: | ||||
|         # Dirs | ||||
|         project = self.args.project or f"runs/{self.args.task}" | ||||
|         name = self.args.name or f"{self.args.mode}" | ||||
|         self.save_dir = increment_path(Path(project) / name, exist_ok=self.args.exist_ok if RANK in {-1, 0} else True) | ||||
|         self.save_dir = Path( | ||||
|             self.args.get( | ||||
|                 "save_dir", | ||||
|                 increment_path(Path(project) / name, exist_ok=self.args.exist_ok if RANK in {-1, 0} else True))) | ||||
|         self.wdir = self.save_dir / 'weights'  # weights dir | ||||
|         if RANK in {-1, 0}: | ||||
|             self.wdir.mkdir(parents=True, exist_ok=True)  # make dir | ||||
|             with open_dict(self.args): | ||||
|                 self.args.save_dir = str(self.save_dir) | ||||
|             yaml_save(self.save_dir / 'args.yaml', OmegaConf.to_container(self.args, resolve=True))  # save run args | ||||
|         self.last, self.best = self.wdir / 'last.pt', self.wdir / 'best.pt'  # checkpoint paths | ||||
|  | ||||
| @ -131,6 +137,7 @@ class BaseTrainer: | ||||
|         self.tloss = None | ||||
|         self.loss_names = None | ||||
|         self.csv = self.save_dir / 'results.csv' | ||||
|         self.plot_idx = [0, 1, 2] | ||||
|  | ||||
|         # Callbacks | ||||
|         self.callbacks = defaultdict(list, {k: [v] for k, v in callbacks.default_callbacks.items()})  # add callbacks | ||||
| @ -199,7 +206,6 @@ class BaseTrainer: | ||||
|         else: | ||||
|             self.lf = lambda x: (1 - x / self.epochs) * (1.0 - self.args.lrf) + self.args.lrf  # linear | ||||
|         self.scheduler = lr_scheduler.LambdaLR(self.optimizer, lr_lambda=self.lf) | ||||
|         self.resume_training(ckpt) | ||||
|         self.scheduler.last_epoch = self.start_epoch - 1  # do not move | ||||
|  | ||||
|         # dataloaders | ||||
| @ -211,6 +217,7 @@ class BaseTrainer: | ||||
|             metric_keys = self.validator.metric_keys + self.label_loss_items(prefix="val") | ||||
|             self.metrics = dict(zip(metric_keys, [0] * len(metric_keys)))  # TODO: init metrics for plot_results()? | ||||
|             self.ema = ModelEMA(self.model) | ||||
|         self.resume_training(ckpt) | ||||
|         self.run_callbacks("on_pretrain_routine_end") | ||||
|  | ||||
|     def _do_train(self, rank=-1, world_size=1): | ||||
| @ -230,6 +237,9 @@ class BaseTrainer: | ||||
|                  f'Using {self.train_loader.num_workers * (world_size or 1)} dataloader workers\n' | ||||
|                  f"Logging results to {colorstr('bold', self.save_dir)}\n" | ||||
|                  f"Starting training for {self.epochs} epochs...") | ||||
|         if self.args.close_mosaic: | ||||
|             base_idx = (self.epochs - self.args.close_mosaic) * nb | ||||
|             self.plot_idx.extend([base_idx, base_idx + 1, base_idx + 2]) | ||||
|         for epoch in range(self.start_epoch, self.epochs): | ||||
|             self.epoch = epoch | ||||
|             self.run_callbacks("on_train_epoch_start") | ||||
| @ -237,19 +247,21 @@ class BaseTrainer: | ||||
|             if rank != -1: | ||||
|                 self.train_loader.sampler.set_epoch(epoch) | ||||
|             pbar = enumerate(self.train_loader) | ||||
|             # Update dataloader attributes (optional) | ||||
|             if epoch == (self.epochs - self.args.close_mosaic): | ||||
|                 self.console.info("Closing dataloader mosaic") | ||||
|                 if hasattr(self.train_loader.dataset, 'mosaic'): | ||||
|                     self.train_loader.dataset.mosaic = False | ||||
|                 if hasattr(self.train_loader.dataset, 'close_mosaic'): | ||||
|                     self.train_loader.dataset.close_mosaic(hyp=self.args) | ||||
|  | ||||
|             if rank in {-1, 0}: | ||||
|                 self.console.info(self.progress_string()) | ||||
|                 pbar = tqdm(enumerate(self.train_loader), total=len(self.train_loader), bar_format=TQDM_BAR_FORMAT) | ||||
|                 pbar = tqdm(enumerate(self.train_loader), total=nb, bar_format=TQDM_BAR_FORMAT) | ||||
|             self.tloss = None | ||||
|             self.optimizer.zero_grad() | ||||
|             for i, batch in pbar: | ||||
|                 self.run_callbacks("on_train_batch_start") | ||||
|  | ||||
|                 # Update dataloader attributes (optional) | ||||
|                 if epoch == (self.epochs - self.args.close_mosaic) and hasattr(self.train_loader.dataset, 'mosaic'): | ||||
|                     LOGGER.info("Closing dataloader mosaic") | ||||
|                     self.train_loader.dataset.mosaic = False | ||||
|  | ||||
|                 # Warmup | ||||
|                 ni = i + nb * epoch | ||||
|                 if ni <= nw: | ||||
| @ -289,7 +301,7 @@ class BaseTrainer: | ||||
|                         ('%11s' * 2 + '%11.4g' * (2 + loss_len)) % | ||||
|                         (f'{epoch + 1}/{self.epochs}', mem, *losses, batch["cls"].shape[0], batch["img"].shape[-1])) | ||||
|                     self.run_callbacks('on_batch_end') | ||||
|                     if self.args.plots and ni < 3: | ||||
|                     if self.args.plots and ni in self.plot_idx: | ||||
|                         self.plot_training_samples(batch, ni) | ||||
|  | ||||
|                 self.run_callbacks("on_train_batch_end") | ||||
| @ -367,7 +379,8 @@ class BaseTrainer: | ||||
|         if not pretrained: | ||||
|             model = check_file(model) | ||||
|         ckpt = self.load_ckpt(model) if pretrained else None | ||||
|         self.model = self.load_model(model_cfg=None if pretrained else model, weights=ckpt["model"])  # model | ||||
|         weights = ckpt["model"] if isinstance(ckpt, dict) else ckpt  # torchvision weights are not dicts | ||||
|         self.model = self.load_model(model_cfg=None if pretrained else model, weights=weights) | ||||
|         return ckpt | ||||
|  | ||||
|     def load_ckpt(self, ckpt): | ||||
| @ -479,8 +492,9 @@ class BaseTrainer: | ||||
|             args_yaml = last.parent.parent / 'args.yaml'  # train options yaml | ||||
|             if args_yaml.is_file(): | ||||
|                 args = get_config(args_yaml)  # replace | ||||
|             args.model, args.resume, args.exist_ok = str(last), True, True  # reinstate | ||||
|             args.model, resume = str(last), True  # reinstate | ||||
|             self.args = args | ||||
|         self.resume = resume | ||||
|  | ||||
|     def resume_training(self, ckpt): | ||||
|         if ckpt is None: | ||||
| @ -493,7 +507,7 @@ class BaseTrainer: | ||||
|         if self.ema and ckpt.get('ema'): | ||||
|             self.ema.ema.load_state_dict(ckpt['ema'].float().state_dict())  # EMA | ||||
|             self.ema.updates = ckpt['updates'] | ||||
|         if self.args.resume: | ||||
|         if self.resume: | ||||
|             assert start_epoch > 0, \ | ||||
|                 f'{self.args.model} training to {self.epochs} epochs is finished, nothing to resume.\n' \ | ||||
|                 f"Start a new training without --resume, i.e. 'yolo task=... mode=train model={self.args.model}'" | ||||
|  | ||||
| @ -111,6 +111,7 @@ class BaseValidator: | ||||
|                 self.args.workers = 0  # faster CPU val as time dominated by inference, not dataloading | ||||
|             self.dataloader = self.dataloader or \ | ||||
|                               self.get_dataloader(data.get("val") or data.set("test"), self.args.batch_size) | ||||
|             self.data = data | ||||
|  | ||||
|         model.eval() | ||||
|  | ||||
|  | ||||
| @ -24,11 +24,12 @@ def find_free_network_port() -> int: | ||||
| def generate_ddp_file(trainer): | ||||
|     import_path = '.'.join(str(trainer.__class__).split(".")[1:-1]) | ||||
|  | ||||
|     shutil.rmtree(trainer.save_dir)  # remove the save_dir | ||||
|     content = f'''overrides = {dict(trainer.args)} \nif __name__ == "__main__": | ||||
|     if not trainer.resume: | ||||
|         shutil.rmtree(trainer.save_dir)  # remove the save_dir | ||||
|     content = f'''config = {dict(trainer.args)} \nif __name__ == "__main__": | ||||
|     from ultralytics.{import_path} import {trainer.__class__.__name__} | ||||
|  | ||||
|     trainer = {trainer.__class__.__name__}(overrides=overrides) | ||||
|     trainer = {trainer.__class__.__name__}(config=config) | ||||
|     trainer.train()''' | ||||
|     (USER_CONFIG_DIR / 'DDP').mkdir(exist_ok=True) | ||||
|     with tempfile.NamedTemporaryFile(prefix="_temp_", | ||||
|  | ||||
| @ -1,3 +1,5 @@ | ||||
| from copy import copy | ||||
|  | ||||
| import hydra | ||||
| import torch | ||||
| import torch.nn as nn | ||||
| @ -64,7 +66,7 @@ class DetectionTrainer(BaseTrainer): | ||||
|         return v8.detect.DetectionValidator(self.test_loader, | ||||
|                                             save_dir=self.save_dir, | ||||
|                                             logger=self.console, | ||||
|                                             args=self.args) | ||||
|                                             args=copy(self.args)) | ||||
|  | ||||
|     def criterion(self, preds, batch): | ||||
|         if not hasattr(self, 'compute_loss'): | ||||
|  | ||||
| @ -42,10 +42,9 @@ class DetectionValidator(BaseValidator): | ||||
|  | ||||
|     def init_metrics(self, model): | ||||
|         head = model.model[-1] if self.training else model.model.model[-1] | ||||
|         if self.data: | ||||
|             self.is_coco = self.data.get('val', '').endswith(f'coco{os.sep}val2017.txt')  # is COCO dataset | ||||
|             self.class_map = ops.coco80_to_coco91_class() if self.is_coco else list(range(1000)) | ||||
|             self.args.save_json |= self.is_coco and not self.training  # run on final val if training COCO | ||||
|         self.is_coco = self.data.get('val', '').endswith(f'coco{os.sep}val2017.txt')  # is COCO dataset | ||||
|         self.class_map = ops.coco80_to_coco91_class() if self.is_coco else list(range(1000)) | ||||
|         self.args.save_json |= self.is_coco and not self.training  # run on final val if training COCO | ||||
|         self.nc = head.nc | ||||
|         self.names = model.names | ||||
|         self.metrics.names = self.names | ||||
|  | ||||
| @ -1,3 +1,5 @@ | ||||
| from copy import copy | ||||
|  | ||||
| import hydra | ||||
| import torch | ||||
| import torch.nn as nn | ||||
| @ -27,7 +29,7 @@ class SegmentationTrainer(v8.detect.DetectionTrainer): | ||||
|         return v8.segment.SegmentationValidator(self.test_loader, | ||||
|                                                 save_dir=self.save_dir, | ||||
|                                                 logger=self.console, | ||||
|                                                 args=self.args) | ||||
|                                                 args=copy(self.args)) | ||||
|  | ||||
|     def criterion(self, preds, batch): | ||||
|         if not hasattr(self, 'compute_loss'): | ||||
|  | ||||
| @ -37,10 +37,9 @@ class SegmentationValidator(DetectionValidator): | ||||
|  | ||||
|     def init_metrics(self, model): | ||||
|         head = model.model[-1] if self.training else model.model.model[-1] | ||||
|         if self.data: | ||||
|             self.is_coco = self.data.get('val', '').endswith(f'coco{os.sep}val2017.txt')  # is COCO dataset | ||||
|             self.class_map = ops.coco80_to_coco91_class() if self.is_coco else list(range(1000)) | ||||
|             self.args.save_json |= self.is_coco and not self.training  # run on final val if training COCO | ||||
|         self.is_coco = self.data.get('val', '').endswith(f'coco{os.sep}val2017.txt')  # is COCO dataset | ||||
|         self.class_map = ops.coco80_to_coco91_class() if self.is_coco else list(range(1000)) | ||||
|         self.args.save_json |= self.is_coco and not self.training  # run on final val if training COCO | ||||
|         self.nc = head.nc | ||||
|         self.nm = head.nm if hasattr(head, "nm") else 32 | ||||
|         self.names = model.names | ||||
|  | ||||
		Reference in New Issue
	
	Block a user