DDP and new dataloader Fix (#95)
Co-authored-by: Laughing-q <1185102784@qq.com> Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
@ -1,7 +1,7 @@
|
||||
import torch
|
||||
import yaml
|
||||
|
||||
from ultralytics import yolo # (required for python usage)
|
||||
from ultralytics import yolo # noqa required for python usage
|
||||
# from ultralytics.yolo.data.utils import check_dataset, check_dataset_yaml
|
||||
from ultralytics.yolo.engine.trainer import DEFAULT_CONFIG
|
||||
from ultralytics.yolo.utils import LOGGER
|
||||
|
@ -13,7 +13,6 @@ from pathlib import Path
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import torch.multiprocessing as mp
|
||||
import torch.nn as nn
|
||||
from omegaconf import OmegaConf
|
||||
from torch.cuda import amp
|
||||
@ -111,8 +110,12 @@ class BaseTrainer:
|
||||
if world_size > 1 and "LOCAL_RANK" not in os.environ:
|
||||
command = generate_ddp_command(world_size, self)
|
||||
print('DDP command: ', command)
|
||||
subprocess.Popen(command)
|
||||
# ddp_cleanup(command, self) # TODO: uncomment and fix
|
||||
try:
|
||||
subprocess.run(command)
|
||||
except Exception as e:
|
||||
self.console(e)
|
||||
finally:
|
||||
ddp_cleanup(command, self)
|
||||
else:
|
||||
self._do_train(int(os.getenv("RANK", -1)), world_size)
|
||||
|
||||
@ -122,7 +125,6 @@ class BaseTrainer:
|
||||
torch.cuda.set_device(rank)
|
||||
self.device = torch.device('cuda', rank)
|
||||
self.console.info(f"RANK - WORLD_SIZE - DEVICE: {rank} - {world_size} - {self.device} ")
|
||||
mp.set_start_method('spawn', force=True)
|
||||
dist.init_process_group("nccl" if dist.is_nccl_available() else "gloo", rank=rank, world_size=world_size)
|
||||
|
||||
def _setup_train(self, rank, world_size):
|
||||
@ -159,8 +161,8 @@ class BaseTrainer:
|
||||
if rank in {0, -1}:
|
||||
self.test_loader = self.get_dataloader(self.testset, batch_size=batch_size * 2, rank=-1, mode="val")
|
||||
self.validator = self.get_validator()
|
||||
# metric_keys = self.validator.metric_keys + self.label_loss_items(prefix="val")
|
||||
# self.metrics = dict(zip(metric_keys, [0] * len(metric_keys))) # TODO: init metrics for plot_results()?
|
||||
metric_keys = self.validator.metric_keys + self.label_loss_items(prefix="val")
|
||||
self.metrics = dict(zip(metric_keys, [0] * len(metric_keys))) # TODO: init metrics for plot_results()?
|
||||
self.ema = ModelEMA(self.model)
|
||||
self.trigger_callbacks("on_pretrain_routine_end")
|
||||
|
||||
|
Reference in New Issue
Block a user