Update `generate_ddp_file` for improved `overrides` (#2909)

Author: Laughing · committed 2 years ago by GitHub
Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
Branch: single_channel
commit 305cde69d0 (parent facb7861cf)

@@ -182,7 +182,7 @@ class BaseTrainer:
             # Command
             cmd, file = generate_ddp_command(world_size, self)
             try:
-                LOGGER.info(f'Running DDP command {cmd}')
+                LOGGER.info(f'DDP command: {cmd}')
                 subprocess.run(cmd, check=True)
             except Exception as e:
                 raise e
@@ -195,7 +195,7 @@ class BaseTrainer:
         """Initializes and sets the DistributedDataParallel parameters for training."""
         torch.cuda.set_device(RANK)
         self.device = torch.device('cuda', RANK)
-        LOGGER.info(f'DDP settings: RANK {RANK}, WORLD_SIZE {world_size}, DEVICE {self.device}')
+        LOGGER.info(f'DDP info: RANK {RANK}, WORLD_SIZE {world_size}, DEVICE {self.device}')
         os.environ['NCCL_BLOCKING_WAIT'] = '1'  # set to enforce timeout
         dist.init_process_group('nccl' if dist.is_nccl_available() else 'gloo',
                                 timeout=timedelta(seconds=3600),
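
For reference, a minimal standalone sketch of the environment this initialization relies on; the env-var names are the standard ones that torch.distributed.run exports, and the snippet is an illustration, not the actual BaseTrainer._setup_ddp:

# Minimal DDP-init sketch (assumes a torchrun-style launcher set the env vars)
import os
from datetime import timedelta

import torch
import torch.distributed as dist

RANK = int(os.environ.get('RANK', -1))              # provided by torch.distributed.run
WORLD_SIZE = int(os.environ.get('WORLD_SIZE', 1))   # total number of processes

torch.cuda.set_device(RANK)
dist.init_process_group(backend='nccl' if dist.is_nccl_available() else 'gloo',
                        timeout=timedelta(seconds=3600),  # fail instead of hanging forever
                        rank=RANK,
                        world_size=WORLD_SIZE)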

@@ -27,10 +27,13 @@ def generate_ddp_file(trainer):
     """Generates a DDP file and returns its file name."""
     module, name = f'{trainer.__class__.__module__}.{trainer.__class__.__name__}'.rsplit('.', 1)

-    content = f'''cfg = {vars(trainer.args)} \nif __name__ == "__main__":
+    content = f'''overrides = {vars(trainer.args)} \nif __name__ == "__main__":
     from {module} import {name}
+    from ultralytics.yolo.utils import DEFAULT_CFG_DICT

-    trainer = {name}(cfg=cfg)
+    cfg = DEFAULT_CFG_DICT.copy()
+    cfg.update(save_dir='')   # handle the extra key 'save_dir'
+    trainer = {name}(cfg=cfg, overrides=overrides)
     trainer.train()'''
     (USER_CONFIG_DIR / 'DDP').mkdir(exist_ok=True)
     with tempfile.NamedTemporaryFile(prefix='_temp_',
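
To make the new behavior concrete, here is roughly what a generated temp file would contain after this change, for a hypothetical detection trainer; the dict values and module path are illustrative examples, not output captured from this commit:

# Illustrative generated-file content (assumed example values; the real dict
# is the full vars(trainer.args) of the launching trainer)
overrides = {'task': 'detect', 'mode': 'train', 'model': 'yolov8n.pt', 'imgsz': 640}
if __name__ == "__main__":
    from ultralytics.yolo.v8.detect.train import DetectionTrainer
    from ultralytics.yolo.utils import DEFAULT_CFG_DICT

    cfg = DEFAULT_CFG_DICT.copy()
    cfg.update(save_dir='')   # handle the extra key 'save_dir'
    trainer = DetectionTrainer(cfg=cfg, overrides=overrides)
    trainer.train()

Each spawned rank executes this file, so the parent trainer's full argument set now rides along as overrides while cfg stays the pristine defaults.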
@@ -54,9 +57,7 @@ def generate_ddp_command(world_size, trainer):
     file = generate_ddp_file(trainer)
     dist_cmd = 'torch.distributed.run' if TORCH_1_9 else 'torch.distributed.launch'
     port = find_free_network_port()
-    exclude_args = ['save_dir']
-    args = [f'{k}={v}' for k, v in vars(trainer.args).items() if k not in exclude_args]
-    cmd = [sys.executable, '-m', dist_cmd, '--nproc_per_node', f'{world_size}', '--master_port', f'{port}', file] + args
+    cmd = [sys.executable, '-m', dist_cmd, '--nproc_per_node', f'{world_size}', '--master_port', f'{port}', file]
     return cmd, file
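
With the overrides embedded in the generated file, the launch command shrinks to the launcher invocation plus the file path; a hypothetical value of cmd for world_size=2 (executable, port, and temp path are made-up examples):

# Illustrative only: the 'key=value' tail that used to follow the file path
# is gone, since those settings now live inside the generated file itself
cmd = ['/usr/bin/python3', '-m', 'torch.distributed.run',
       '--nproc_per_node', '2', '--master_port', '51234',
       '/home/user/.config/Ultralytics/DDP/_temp_a1b2c3_139920.py']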

@@ -19,7 +19,7 @@ class ClassificationTrainer(BaseTrainer):
         if overrides is None:
             overrides = {}
         overrides['task'] = 'classify'
-        if overrides.get('imgsz') is None and cfg['imgsz'] == DEFAULT_CFG.imgsz == 640:
+        if overrides.get('imgsz') is None:
             overrides['imgsz'] = 224
         super().__init__(cfg, overrides, _callbacks)
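
The simpler guard is safe because, with the new DDP file, cfg is always the untouched DEFAULT_CFG_DICT and any user-chosen image size arrives through overrides; a quick sketch of the two cases (values are illustrative):

# Case 1: user gave no imgsz, so the classification default of 224 applies
overrides = {'task': 'classify'}
if overrides.get('imgsz') is None:
    overrides['imgsz'] = 224          # -> imgsz becomes 224

# Case 2: user chose an imgsz, which is preserved untouched
overrides = {'task': 'classify', 'imgsz': 128}
if overrides.get('imgsz') is None:
    overrides['imgsz'] = 224          # not executed; imgsz stays 128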
