Update generate_ddp_file for improved overrides (#2909)

Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
This commit is contained in:
Laughing
2023-05-31 01:41:44 +08:00
committed by GitHub
parent facb7861cf
commit 305cde69d0
3 changed files with 9 additions and 8 deletions

View File

@ -27,10 +27,13 @@ def generate_ddp_file(trainer):
"""Generates a DDP file and returns its file name."""
module, name = f'{trainer.__class__.__module__}.{trainer.__class__.__name__}'.rsplit('.', 1)
content = f'''cfg = {vars(trainer.args)} \nif __name__ == "__main__":
content = f'''overrides = {vars(trainer.args)} \nif __name__ == "__main__":
from {module} import {name}
from ultralytics.yolo.utils import DEFAULT_CFG_DICT
trainer = {name}(cfg=cfg)
cfg = DEFAULT_CFG_DICT.copy()
cfg.update(save_dir='') # handle the extra key 'save_dir'
trainer = {name}(cfg=cfg, overrides=overrides)
trainer.train()'''
(USER_CONFIG_DIR / 'DDP').mkdir(exist_ok=True)
with tempfile.NamedTemporaryFile(prefix='_temp_',
@ -54,9 +57,7 @@ def generate_ddp_command(world_size, trainer):
file = generate_ddp_file(trainer)
dist_cmd = 'torch.distributed.run' if TORCH_1_9 else 'torch.distributed.launch'
port = find_free_network_port()
exclude_args = ['save_dir']
args = [f'{k}={v}' for k, v in vars(trainer.args).items() if k not in exclude_args]
cmd = [sys.executable, '-m', dist_cmd, '--nproc_per_node', f'{world_size}', '--master_port', f'{port}', file] + args
cmd = [sys.executable, '-m', dist_cmd, '--nproc_per_node', f'{world_size}', '--master_port', f'{port}', file]
return cmd, file