From c1b38428bc2ba9d6a87928a20468d5ec59a207f0 Mon Sep 17 00:00:00 2001 From: Ayush Chaurasia Date: Thu, 29 Dec 2022 17:32:01 +0530 Subject: [PATCH] Update save_dir rank check (#114) Co-authored-by: Laughing-q <1185102784@qq.com> --- ultralytics/yolo/engine/trainer.py | 10 +++++----- ultralytics/yolo/engine/validator.py | 2 +- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/ultralytics/yolo/engine/trainer.py b/ultralytics/yolo/engine/trainer.py index c437e83..71c2e20 100644 --- a/ultralytics/yolo/engine/trainer.py +++ b/ultralytics/yolo/engine/trainer.py @@ -49,9 +49,12 @@ class BaseTrainer: # dirs project = self.args.project or f"runs/{self.args.task}" name = self.args.name or f"{self.args.mode}" - self.save_dir = increment_path(Path(project) / name, exist_ok=self.args.exist_ok if RANK == -1 else True) + self.save_dir = increment_path(Path(project) / name, exist_ok=self.args.exist_ok if RANK in {-1, 0} else True) self.wdir = self.save_dir / 'weights' # weights dir - self.wdir.mkdir(parents=True, exist_ok=True) # make dir + if RANK in {-1, 0}: + self.wdir.mkdir(parents=True, exist_ok=True) # make dir + # Save run settings + save_yaml(self.save_dir / 'args.yaml', OmegaConf.to_container(self.args, resolve=True)) self.last, self.best = self.wdir / 'last.pt', self.wdir / 'best.pt' # checkpoint paths self.batch_size = self.args.batch_size @@ -60,9 +63,6 @@ class BaseTrainer: if RANK == -1: print_args(dict(self.args)) - # Save run settings - save_yaml(self.save_dir / 'args.yaml', OmegaConf.to_container(self.args, resolve=True)) - # device self.device = utils.torch_utils.select_device(self.args.device, self.batch_size) self.amp = self.device.type != 'cpu' diff --git a/ultralytics/yolo/engine/validator.py b/ultralytics/yolo/engine/validator.py index ecbd273..d0b001a 100644 --- a/ultralytics/yolo/engine/validator.py +++ b/ultralytics/yolo/engine/validator.py @@ -35,7 +35,7 @@ class BaseValidator: project = self.args.project or f"runs/{self.args.task}" name = self.args.name or f"{self.args.mode}" self.save_dir = save_dir or increment_path(Path(project) / name, - exist_ok=self.args.exist_ok if RANK == -1 else True) + exist_ok=self.args.exist_ok if RANK in {-1, 0} else True) (self.save_dir / 'labels' if self.args.save_txt else self.save_dir).mkdir(parents=True, exist_ok=True) @smart_inference_mode()