DDP and new dataloader Fix (#95)

Co-authored-by: Laughing-q <1185102784@qq.com>
Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
branch: single_channel
Ayush Chaurasia, committed by GitHub
parent 16e3c08883
commit 4fb04be20b

@@ -51,15 +51,15 @@ repos:
         additional_dependencies:
           - mdformat-gfm
           - mdformat-black
-        exclude: "README.md|README_cn.md| CONTRIBUTING.md"
+        exclude: "README.md|README.zh-CN.md|CONTRIBUTING.md"
-  - repo: https://github.com/asottile/yesqa
-    rev: v1.4.0
-    hooks:
-      - id: yesqa
   - repo: https://github.com/PyCQA/flake8
     rev: 5.0.4
     hooks:
       - id: flake8
         name: PEP8
+  #- repo: https://github.com/asottile/yesqa
+  #  rev: v1.4.0
+  #  hooks:
+  #    - id: yesqa

@@ -1,7 +1,7 @@
 import torch
 import yaml
-from ultralytics import yolo  # (required for python usage)
+from ultralytics import yolo  # noqa required for python usage
 # from ultralytics.yolo.data.utils import check_dataset, check_dataset_yaml
 from ultralytics.yolo.engine.trainer import DEFAULT_CONFIG
 from ultralytics.yolo.utils import LOGGER

@@ -13,7 +13,6 @@ from pathlib import Path
 import numpy as np
 import torch
 import torch.distributed as dist
-import torch.multiprocessing as mp
 import torch.nn as nn
 from omegaconf import OmegaConf
 from torch.cuda import amp
@@ -111,8 +110,12 @@ class BaseTrainer:
         if world_size > 1 and "LOCAL_RANK" not in os.environ:
             command = generate_ddp_command(world_size, self)
             print('DDP command: ', command)
-            subprocess.Popen(command)
-            # ddp_cleanup(command, self)  # TODO: uncomment and fix
+            try:
+                subprocess.run(command)
+            except Exception as e:
+                self.console(e)
+            finally:
+                ddp_cleanup(command, self)
         else:
             self._do_train(int(os.getenv("RANK", -1)), world_size)
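
The switch from subprocess.Popen to subprocess.run matters here: run blocks until the spawned DDP launcher exits, so the finally branch can safely delete the generated temp script, whereas Popen returned immediately, which is why ddp_cleanup had to stay commented out before. A minimal sketch of the blocking behavior (the child command is a stand-in, not this repo's actual DDP command):

```python
import subprocess
import sys

# subprocess.Popen returns immediately; cleaning up files the child still
# needs would race with it. subprocess.run waits for the child to exit:
proc = subprocess.run([sys.executable, "-c", "print('worker done')"])
print(proc.returncode)  # 0 only after the child has fully exited
```
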
@@ -122,7 +125,6 @@ class BaseTrainer:
         torch.cuda.set_device(rank)
         self.device = torch.device('cuda', rank)
         self.console.info(f"RANK - WORLD_SIZE - DEVICE: {rank} - {world_size} - {self.device} ")
-        mp.set_start_method('spawn', force=True)
         dist.init_process_group("nccl" if dist.is_nccl_available() else "gloo", rank=rank, world_size=world_size)

     def _setup_train(self, rank, world_size):
@@ -159,8 +161,8 @@ class BaseTrainer:
         if rank in {0, -1}:
             self.test_loader = self.get_dataloader(self.testset, batch_size=batch_size * 2, rank=-1, mode="val")
             self.validator = self.get_validator()
-            # metric_keys = self.validator.metric_keys + self.label_loss_items(prefix="val")
-            # self.metrics = dict(zip(metric_keys, [0] * len(metric_keys)))  # TODO: init metrics for plot_results()?
+            metric_keys = self.validator.metric_keys + self.label_loss_items(prefix="val")
+            self.metrics = dict(zip(metric_keys, [0] * len(metric_keys)))  # TODO: init metrics for plot_results()?
             self.ema = ModelEMA(self.model)
         self.trigger_callbacks("on_pretrain_routine_end")
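
The un-commented metric initialization simply zero-fills a dict over the validator's metric names plus the validation loss labels, so every column exists from epoch 0. A tiny sketch with hypothetical key names:

```python
metric_keys = ["metrics/mAP50", "metrics/mAP50-95", "val/box_loss"]  # hypothetical names
metrics = dict(zip(metric_keys, [0] * len(metric_keys)))
print(metrics)  # {'metrics/mAP50': 0, 'metrics/mAP50-95': 0, 'val/box_loss': 0}
```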

@@ -3,7 +3,8 @@ import shutil
 import socket
 import sys
 import tempfile
-import time
+
+from . import USER_CONFIG_DIR


 def find_free_network_port() -> int:
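
find_free_network_port itself is unchanged and its body is not part of this diff; the conventional implementation of such a helper (a sketch, not verified against this file) binds to port 0 so the OS hands back a free ephemeral port to use as MASTER_PORT on a single node:

```python
import socket

def find_free_network_port() -> int:
    """Return a free port by letting the OS choose an ephemeral one."""
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        s.bind(("127.0.0.1", 0))  # port 0 -> OS picks any free port
        return s.getsockname()[1]
```
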
@@ -23,25 +24,25 @@ def find_free_network_port() -> int:
 def generate_ddp_file(trainer):
     import_path = '.'.join(str(trainer.__class__).split(".")[1:-1])
-    # remove the save_dir
-    shutil.rmtree(trainer.save_dir)
+    shutil.rmtree(trainer.save_dir)  # remove the save_dir
     content = f'''overrides = {dict(trainer.args)} \nif __name__ == "__main__":
     from ultralytics.{import_path} import {trainer.__class__.__name__}
     trainer = {trainer.__class__.__name__}(overrides=overrides)
     trainer.train()'''
+    (USER_CONFIG_DIR / 'DDP').mkdir(exist_ok=True)
     with tempfile.NamedTemporaryFile(prefix="_temp_",
                                      suffix=f"{id(trainer)}.py",
                                      mode="w+",
                                      encoding='utf-8',
-                                     dir=os.path.curdir,
+                                     dir=USER_CONFIG_DIR / 'DDP',
                                      delete=False) as file:
         file.write(content)
     return file.name


 def generate_ddp_command(world_size, trainer):
-    import __main__  # local import to avoid https://github.com/Lightning-AI/lightning/issues/15218
+    import __main__  # noqa local import to avoid https://github.com/Lightning-AI/lightning/issues/15218
     file_name = os.path.abspath(sys.argv[0])
     using_cli = not file_name.endswith(".py")
     if using_cli:
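
Rendered, generate_ddp_file now writes a small self-contained training script under USER_CONFIG_DIR / 'DDP' instead of the current working directory, which keeps the launch from depending on where the user invoked Python. Roughly what the generated file looks like (path, overrides, and import path are hypothetical):

```python
# e.g. ~/.config/Ultralytics/DDP/_temp_140234.py
overrides = {'model': 'yolov8n.yaml', 'data': 'coco128.yaml', 'epochs': 3}
if __name__ == "__main__":
    from ultralytics.yolo.v8.detect.train import DetectionTrainer
    trainer = DetectionTrainer(overrides=overrides)
    trainer.train()
```
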
@@ -53,8 +54,6 @@ def generate_ddp_command(world_size, trainer):
 def ddp_cleanup(command, trainer):
     # delete temp file if created
-    # TODO: this is a temp solution in case the file is deleted before DDP launching
-    time.sleep(5)
     tempfile_suffix = f"{id(trainer)}.py"
     if tempfile_suffix in "".join(command):
         for chunk in command:
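
The hunk cuts off mid-loop; completing the shown logic (a sketch of the natural continuation, not verified against the full file), ddp_cleanup deletes whichever command chunk carries the trainer-id suffix. With subprocess.run now blocking above, the time.sleep(5) guard against deleting the file before the launcher reads it is no longer needed:

```python
import os

def ddp_cleanup(command, trainer):
    # delete the generated temp training script, if one was created
    tempfile_suffix = f"{id(trainer)}.py"
    if tempfile_suffix in "".join(command):
        for chunk in command:
            if tempfile_suffix in chunk:
                os.remove(chunk)  # remove the temp script named after this trainer
                break
```
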

@@ -58,8 +58,6 @@ class DetectionTrainer(BaseTrainer):
         model = DetectionModel(model_cfg or weights["model"].yaml, ch=3, nc=self.data["nc"])
         if weights:
             model.load(weights)
-        for _, v in model.named_parameters():
-            v.requires_grad = True  # train all layers
         return model

     def get_validator(self):

@@ -21,8 +21,6 @@ class SegmentationTrainer(DetectionTrainer):
         model = SegmentationModel(model_cfg or weights["model"].yaml, ch=3, nc=self.data["nc"])
         if weights:
             model.load(weights)
-        for _, v in model.named_parameters():
-            v.requires_grad = True  # train all layers
         return model

     def get_validator(self):
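
Dropping the requires_grad loop from both trainers is safe because parameters of a freshly constructed nn.Module default to requires_grad=True, and loading a state dict copies tensor values without touching that flag. A quick check:

```python
import torch.nn as nn

model = nn.Sequential(nn.Conv2d(3, 16, 3), nn.BatchNorm2d(16))
assert all(p.requires_grad for p in model.parameters())  # True by default
```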
