From 305cde69d09aeae2884c27bbd9e557aa47ffa5b0 Mon Sep 17 00:00:00 2001 From: Laughing <61612323+Laughing-q@users.noreply.github.com> Date: Wed, 31 May 2023 01:41:44 +0800 Subject: [PATCH] Update `generate_ddp_file` for improved `overrides` (#2909) Co-authored-by: Glenn Jocher --- ultralytics/yolo/engine/trainer.py | 4 ++-- ultralytics/yolo/utils/dist.py | 11 ++++++----- ultralytics/yolo/v8/classify/train.py | 2 +- 3 files changed, 9 insertions(+), 8 deletions(-) diff --git a/ultralytics/yolo/engine/trainer.py b/ultralytics/yolo/engine/trainer.py index c69a7b7..7966526 100644 --- a/ultralytics/yolo/engine/trainer.py +++ b/ultralytics/yolo/engine/trainer.py @@ -182,7 +182,7 @@ class BaseTrainer: # Command cmd, file = generate_ddp_command(world_size, self) try: - LOGGER.info(f'Running DDP command {cmd}') + LOGGER.info(f'DDP command: {cmd}') subprocess.run(cmd, check=True) except Exception as e: raise e @@ -195,7 +195,7 @@ class BaseTrainer: """Initializes and sets the DistributedDataParallel parameters for training.""" torch.cuda.set_device(RANK) self.device = torch.device('cuda', RANK) - LOGGER.info(f'DDP settings: RANK {RANK}, WORLD_SIZE {world_size}, DEVICE {self.device}') + LOGGER.info(f'DDP info: RANK {RANK}, WORLD_SIZE {world_size}, DEVICE {self.device}') os.environ['NCCL_BLOCKING_WAIT'] = '1' # set to enforce timeout dist.init_process_group('nccl' if dist.is_nccl_available() else 'gloo', timeout=timedelta(seconds=3600), diff --git a/ultralytics/yolo/utils/dist.py b/ultralytics/yolo/utils/dist.py index edd484a..6de029f 100644 --- a/ultralytics/yolo/utils/dist.py +++ b/ultralytics/yolo/utils/dist.py @@ -27,10 +27,13 @@ def generate_ddp_file(trainer): """Generates a DDP file and returns its file name.""" module, name = f'{trainer.__class__.__module__}.{trainer.__class__.__name__}'.rsplit('.', 1) - content = f'''cfg = {vars(trainer.args)} \nif __name__ == "__main__": + content = f'''overrides = {vars(trainer.args)} \nif __name__ == "__main__": from {module} import {name} + from ultralytics.yolo.utils import DEFAULT_CFG_DICT - trainer = {name}(cfg=cfg) + cfg = DEFAULT_CFG_DICT.copy() + cfg.update(save_dir='') # handle the extra key 'save_dir' + trainer = {name}(cfg=cfg, overrides=overrides) trainer.train()''' (USER_CONFIG_DIR / 'DDP').mkdir(exist_ok=True) with tempfile.NamedTemporaryFile(prefix='_temp_', @@ -54,9 +57,7 @@ def generate_ddp_command(world_size, trainer): file = generate_ddp_file(trainer) dist_cmd = 'torch.distributed.run' if TORCH_1_9 else 'torch.distributed.launch' port = find_free_network_port() - exclude_args = ['save_dir'] - args = [f'{k}={v}' for k, v in vars(trainer.args).items() if k not in exclude_args] - cmd = [sys.executable, '-m', dist_cmd, '--nproc_per_node', f'{world_size}', '--master_port', f'{port}', file] + args + cmd = [sys.executable, '-m', dist_cmd, '--nproc_per_node', f'{world_size}', '--master_port', f'{port}', file] return cmd, file diff --git a/ultralytics/yolo/v8/classify/train.py b/ultralytics/yolo/v8/classify/train.py index cc752b8..2949644 100644 --- a/ultralytics/yolo/v8/classify/train.py +++ b/ultralytics/yolo/v8/classify/train.py @@ -19,7 +19,7 @@ class ClassificationTrainer(BaseTrainer): if overrides is None: overrides = {} overrides['task'] = 'classify' - if overrides.get('imgsz') is None and cfg['imgsz'] == DEFAULT_CFG.imgsz == 640: + if overrides.get('imgsz') is None: overrides['imgsz'] = 224 super().__init__(cfg, overrides, _callbacks)