Update save_dir rank check (#114)

Co-authored-by: Laughing-q <1185102784@qq.com>
single_channel
Ayush Chaurasia 2 years ago committed by GitHub
parent 401bc15345
commit c1b38428bc
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -49,9 +49,12 @@ class BaseTrainer:
# dirs # dirs
project = self.args.project or f"runs/{self.args.task}" project = self.args.project or f"runs/{self.args.task}"
name = self.args.name or f"{self.args.mode}" name = self.args.name or f"{self.args.mode}"
self.save_dir = increment_path(Path(project) / name, exist_ok=self.args.exist_ok if RANK == -1 else True) self.save_dir = increment_path(Path(project) / name, exist_ok=self.args.exist_ok if RANK in {-1, 0} else True)
self.wdir = self.save_dir / 'weights' # weights dir self.wdir = self.save_dir / 'weights' # weights dir
self.wdir.mkdir(parents=True, exist_ok=True) # make dir if RANK in {-1, 0}:
self.wdir.mkdir(parents=True, exist_ok=True) # make dir
# Save run settings
save_yaml(self.save_dir / 'args.yaml', OmegaConf.to_container(self.args, resolve=True))
self.last, self.best = self.wdir / 'last.pt', self.wdir / 'best.pt' # checkpoint paths self.last, self.best = self.wdir / 'last.pt', self.wdir / 'best.pt' # checkpoint paths
self.batch_size = self.args.batch_size self.batch_size = self.args.batch_size
@ -60,9 +63,6 @@ class BaseTrainer:
if RANK == -1: if RANK == -1:
print_args(dict(self.args)) print_args(dict(self.args))
# Save run settings
save_yaml(self.save_dir / 'args.yaml', OmegaConf.to_container(self.args, resolve=True))
# device # device
self.device = utils.torch_utils.select_device(self.args.device, self.batch_size) self.device = utils.torch_utils.select_device(self.args.device, self.batch_size)
self.amp = self.device.type != 'cpu' self.amp = self.device.type != 'cpu'

@ -35,7 +35,7 @@ class BaseValidator:
project = self.args.project or f"runs/{self.args.task}" project = self.args.project or f"runs/{self.args.task}"
name = self.args.name or f"{self.args.mode}" name = self.args.name or f"{self.args.mode}"
self.save_dir = save_dir or increment_path(Path(project) / name, self.save_dir = save_dir or increment_path(Path(project) / name,
exist_ok=self.args.exist_ok if RANK == -1 else True) exist_ok=self.args.exist_ok if RANK in {-1, 0} else True)
(self.save_dir / 'labels' if self.args.save_txt else self.save_dir).mkdir(parents=True, exist_ok=True) (self.save_dir / 'labels' if self.args.save_txt else self.save_dir).mkdir(parents=True, exist_ok=True)
@smart_inference_mode() @smart_inference_mode()

Loading…
Cancel
Save