@@ -13,7 +13,6 @@ from pathlib import Path
 import numpy as np
 import torch
 import torch.distributed as dist
-import torch.multiprocessing as mp
 import torch.nn as nn
 from omegaconf import OmegaConf
 from torch.cuda import amp
@@ -111,8 +110,12 @@ class BaseTrainer:
         if world_size > 1 and "LOCAL_RANK" not in os.environ:
             command = generate_ddp_command(world_size, self)
             print('DDP command: ', command)
-            subprocess.Popen(command)
-            # ddp_cleanup(command, self) # TODO: uncomment and fix
+            try:
+                subprocess.run(command)
+            except Exception as e:
+                self.console(e)
+            finally:
+                ddp_cleanup(command, self)
         else:
             self._do_train(int(os.getenv("RANK", -1)), world_size)
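For orientation, this hunk swaps the fire-and-forget `subprocess.Popen` for a blocking `subprocess.run` wrapped in `try/finally`, so `ddp_cleanup` always runs once the worker processes exit. A minimal sketch of that launch pattern, with a hypothetical stand-in for `generate_ddp_command` (the real helper also serializes the trainer's settings, so details may differ):

```python
import subprocess
import sys


def generate_ddp_command_sketch(world_size: int, script: str) -> list:
    # Hypothetical stand-in for generate_ddp_command(): launch `script`
    # under torch.distributed.run with one worker process per GPU.
    return [sys.executable, "-m", "torch.distributed.run",
            f"--nproc_per_node={world_size}", script]


def launch_ddp(world_size: int, script: str) -> None:
    command = generate_ddp_command_sketch(world_size, script)
    print("DDP command: ", command)
    try:
        # run() blocks until every worker exits, unlike Popen(), which
        # returned immediately and gave cleanup no chance to execute.
        subprocess.run(command, check=True)
    except Exception as e:
        print(e)
    finally:
        # Stand-in for ddp_cleanup(): remove temporary launch artifacts here.
        pass
```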
@@ -122,7 +125,6 @@ class BaseTrainer:
         torch.cuda.set_device(rank)
         self.device = torch.device('cuda', rank)
         self.console.info(f"RANK - WORLD_SIZE - DEVICE: {rank} - {world_size} - {self.device} ")
-        mp.set_start_method('spawn', force=True)
         dist.init_process_group("nccl" if dist.is_nccl_available() else "gloo", rank=rank, world_size=world_size)

     def _setup_train(self, rank, world_size):
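With the explicit `mp.set_start_method('spawn', force=True)` call dropped, `_setup_ddp` relies on the launcher's environment and only picks a backend: NCCL when available, Gloo as the fallback. A self-contained sketch of that per-rank setup, assuming `MASTER_ADDR`/`MASTER_PORT` are exported by `torch.distributed.run`:

```python
import torch
import torch.distributed as dist


def setup_ddp_sketch(rank: int, world_size: int) -> torch.device:
    # Pin this process to its own GPU before creating the process group.
    torch.cuda.set_device(rank)
    device = torch.device("cuda", rank)
    # NCCL for GPU collectives when built in; Gloo as the CPU fallback.
    backend = "nccl" if dist.is_nccl_available() else "gloo"
    dist.init_process_group(backend, rank=rank, world_size=world_size)
    return device
```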
@@ -159,8 +161,8 @@ class BaseTrainer:
         if rank in {0, -1}:
             self.test_loader = self.get_dataloader(self.testset, batch_size=batch_size * 2, rank=-1, mode="val")
             self.validator = self.get_validator()
-            # metric_keys = self.validator.metric_keys + self.label_loss_items(prefix="val")
-            # self.metrics = dict(zip(metric_keys, [0] * len(metric_keys))) # TODO: init metrics for plot_results()?
+            metric_keys = self.validator.metric_keys + self.label_loss_items(prefix="val")
+            self.metrics = dict(zip(metric_keys, [0] * len(metric_keys))) # TODO: init metrics for plot_results()?
             self.ema = ModelEMA(self.model)
         self.trigger_callbacks("on_pretrain_routine_end")
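The previously commented-out metric initialization is now active, so `self.metrics` holds a zero for every validation key before the first epoch, keeping logging and plotting keys consistent from the start. A tiny illustration of the pattern, with made-up key names:

```python
# Hypothetical keys; the trainer concatenates validator.metric_keys with
# the "val"-prefixed loss names from label_loss_items().
metric_keys = ["metrics/precision", "metrics/recall", "val/box_loss"]
metrics = dict(zip(metric_keys, [0] * len(metric_keys)))
print(metrics)  # {'metrics/precision': 0, 'metrics/recall': 0, 'val/box_loss': 0}
```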