Update generate_ddp_file for improved overrides (#2909)
Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
@@ -182,7 +182,7 @@ class BaseTrainer:
             # Command
             cmd, file = generate_ddp_command(world_size, self)
             try:
-                LOGGER.info(f'Running DDP command {cmd}')
+                LOGGER.info(f'DDP command: {cmd}')
                 subprocess.run(cmd, check=True)
             except Exception as e:
                 raise e
@@ -195,7 +195,7 @@ class BaseTrainer:
         """Initializes and sets the DistributedDataParallel parameters for training."""
         torch.cuda.set_device(RANK)
         self.device = torch.device('cuda', RANK)
-        LOGGER.info(f'DDP settings: RANK {RANK}, WORLD_SIZE {world_size}, DEVICE {self.device}')
+        LOGGER.info(f'DDP info: RANK {RANK}, WORLD_SIZE {world_size}, DEVICE {self.device}')
         os.environ['NCCL_BLOCKING_WAIT'] = '1'  # set to enforce timeout
         dist.init_process_group('nccl' if dist.is_nccl_available() else 'gloo',
                                 timeout=timedelta(seconds=3600),
@@ -27,10 +27,13 @@ def generate_ddp_file(trainer):
     """Generates a DDP file and returns its file name."""
     module, name = f'{trainer.__class__.__module__}.{trainer.__class__.__name__}'.rsplit('.', 1)

-    content = f'''cfg = {vars(trainer.args)} \nif __name__ == "__main__":
+    content = f'''overrides = {vars(trainer.args)} \nif __name__ == "__main__":
     from {module} import {name}
+    from ultralytics.yolo.utils import DEFAULT_CFG_DICT

-    trainer = {name}(cfg=cfg)
+    cfg = DEFAULT_CFG_DICT.copy()
+    cfg.update(save_dir='')  # handle the extra key 'save_dir'
+    trainer = {name}(cfg=cfg, overrides=overrides)
     trainer.train()'''
     (USER_CONFIG_DIR / 'DDP').mkdir(exist_ok=True)
     with tempfile.NamedTemporaryFile(prefix='_temp_',
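For orientation, a minimal sketch of what a generated temp DDP file could now contain, assuming a DetectionTrainer and a handful of hypothetical override values (the module path and settings below are illustrative, not taken from this commit):

# Sketch of a generated DDP file after this change (override values are hypothetical)
overrides = {'task': 'detect', 'model': 'yolov8n.pt', 'data': 'coco128.yaml', 'imgsz': 640}
if __name__ == "__main__":
    from ultralytics.yolo.v8.detect.train import DetectionTrainer
    from ultralytics.yolo.utils import DEFAULT_CFG_DICT

    cfg = DEFAULT_CFG_DICT.copy()
    cfg.update(save_dir='')  # handle the extra key 'save_dir'
    trainer = DetectionTrainer(cfg=cfg, overrides=overrides)
    trainer.train()

The key difference from the previous template is that the full trainer args travel as overrides on top of DEFAULT_CFG_DICT, rather than being passed wholesale as cfg.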
@@ -54,9 +57,7 @@ def generate_ddp_command(world_size, trainer):
     file = generate_ddp_file(trainer)
     dist_cmd = 'torch.distributed.run' if TORCH_1_9 else 'torch.distributed.launch'
     port = find_free_network_port()
-    exclude_args = ['save_dir']
-    args = [f'{k}={v}' for k, v in vars(trainer.args).items() if k not in exclude_args]
-    cmd = [sys.executable, '-m', dist_cmd, '--nproc_per_node', f'{world_size}', '--master_port', f'{port}', file] + args
+    cmd = [sys.executable, '-m', dist_cmd, '--nproc_per_node', f'{world_size}', '--master_port', f'{port}', file]
     return cmd, file
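With the per-key arguments dropped from the launcher invocation, the returned command reduces to roughly the shape sketched below; the interpreter, node count, port, and temp-file path are illustrative placeholders, since the overrides now travel inside the generated file instead:

# Hypothetical shape of the command returned by generate_ddp_command after this change,
# assuming torch>=1.9, world_size=2, and a free port of 12345
cmd = ['python', '-m', 'torch.distributed.run',
       '--nproc_per_node', '2',
       '--master_port', '12345',
       '~/.config/Ultralytics/DDP/_temp_xxxxxxxx.py']  # illustrative temp-file path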
@@ -19,7 +19,7 @@ class ClassificationTrainer(BaseTrainer):
         if overrides is None:
             overrides = {}
         overrides['task'] = 'classify'
-        if overrides.get('imgsz') is None and cfg['imgsz'] == DEFAULT_CFG.imgsz == 640:
+        if overrides.get('imgsz') is None:
             overrides['imgsz'] = 224
         super().__init__(cfg, overrides, _callbacks)
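The simplified condition means an explicit imgsz override always wins, and classification otherwise falls back to 224 regardless of the detection default of 640. A standalone sketch of that logic follows; the helper name is hypothetical and not part of the library:

def resolve_imgsz(overrides):
    # Mirrors the simplified check: keep any explicit value, otherwise default to 224
    if overrides.get('imgsz') is None:
        overrides['imgsz'] = 224
    return overrides['imgsz']

assert resolve_imgsz({}) == 224               # no override -> classification default
assert resolve_imgsz({'imgsz': 320}) == 320   # explicit override is preserved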