DDP and new dataloader Fix (#95)

Co-authored-by: Laughing-q <1185102784@qq.com>
Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Author: Ayush Chaurasia
Date: 2022-12-26 20:05:49 +05:30 (committed by GitHub)
Parent: 16e3c08883
Commit: 4fb04be20b
6 changed files with 22 additions and 25 deletions

View File

@@ -1,7 +1,7 @@
 import torch
 import yaml
-from ultralytics import yolo # (required for python usage)
+from ultralytics import yolo # noqa required for python usage
 # from ultralytics.yolo.data.utils import check_dataset, check_dataset_yaml
 from ultralytics.yolo.engine.trainer import DEFAULT_CONFIG
 from ultralytics.yolo.utils import LOGGER

View File

@@ -13,7 +13,6 @@ from pathlib import Path
 import numpy as np
 import torch
 import torch.distributed as dist
-import torch.multiprocessing as mp
 import torch.nn as nn
 from omegaconf import OmegaConf
 from torch.cuda import amp
@@ -111,8 +110,12 @@ class BaseTrainer:
         if world_size > 1 and "LOCAL_RANK" not in os.environ:
             command = generate_ddp_command(world_size, self)
             print('DDP command: ', command)
-            subprocess.Popen(command)
-            # ddp_cleanup(command, self) # TODO: uncomment and fix
+            try:
+                subprocess.run(command)
+            except Exception as e:
+                self.console(e)
+            finally:
+                ddp_cleanup(command, self)
         else:
             self._do_train(int(os.getenv("RANK", -1)), world_size)
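For reference, a minimal standalone sketch of the launch pattern in the hunk above (the helper names are hypothetical, not part of this diff): subprocess.run blocks until every DDP worker process exits, unlike the old subprocess.Popen call, which returned immediately, and the finally clause guarantees the temporary DDP launch file is removed whether training succeeds or fails.

# Sketch only: blocking launch plus guaranteed cleanup (names hypothetical).
import subprocess

def launch_and_cleanup(command, cleanup):
    try:
        subprocess.run(command)  # blocks until all DDP ranks exit
    except Exception as e:
        print(e)  # the trainer logs via self.console instead
    finally:
        cleanup()  # always delete the temporary launch file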
@@ -122,7 +125,6 @@ class BaseTrainer:
         torch.cuda.set_device(rank)
         self.device = torch.device('cuda', rank)
         self.console.info(f"RANK - WORLD_SIZE - DEVICE: {rank} - {world_size} - {self.device} ")
-        mp.set_start_method('spawn', force=True)
         dist.init_process_group("nccl" if dist.is_nccl_available() else "gloo", rank=rank, world_size=world_size)

     def _setup_train(self, rank, world_size):
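As context for the hunk above, a self-contained sketch of per-rank initialization, assuming the standard env:// rendezvous (MASTER_ADDR and MASTER_PORT set by the launcher). The mp.set_start_method('spawn') call could be dropped because ranks are now launched as separate processes via subprocess rather than forked through torch.multiprocessing.

# Sketch only: per-rank DDP setup with an NCCL-to-Gloo fallback.
import torch
import torch.distributed as dist

def init_ddp(rank, world_size):
    torch.cuda.set_device(rank)  # pin this process to its GPU
    backend = "nccl" if dist.is_nccl_available() else "gloo"
    dist.init_process_group(backend, rank=rank, world_size=world_size)
    return torch.device('cuda', rank)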
@@ -159,8 +161,8 @@ class BaseTrainer:
         if rank in {0, -1}:
             self.test_loader = self.get_dataloader(self.testset, batch_size=batch_size * 2, rank=-1, mode="val")
             self.validator = self.get_validator()
-            # metric_keys = self.validator.metric_keys + self.label_loss_items(prefix="val")
-            # self.metrics = dict(zip(metric_keys, [0] * len(metric_keys))) # TODO: init metrics for plot_results()?
+            metric_keys = self.validator.metric_keys + self.label_loss_items(prefix="val")
+            self.metrics = dict(zip(metric_keys, [0] * len(metric_keys))) # TODO: init metrics for plot_results()?
             self.ema = ModelEMA(self.model)
         self.trigger_callbacks("on_pretrain_routine_end")
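The newly uncommented lines above amount to the following (the key names here are illustrative, not the validator's real keys): every metric starts at 0, so consumers such as plot_results() see a fully populated dict even before the first validation pass.

metric_keys = ['metrics/mAP50', 'val/box_loss']  # illustrative keys only
metrics = dict(zip(metric_keys, [0] * len(metric_keys)))
print(metrics)  # {'metrics/mAP50': 0, 'val/box_loss': 0}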