DDP and new dataloader Fix (#95)

Co-authored-by: Laughing-q <1185102784@qq.com>
Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Author: Ayush Chaurasia (branch: single_channel), committed via GitHub
parent 16e3c08883, commit 4fb04be20b

@@ -51,15 +51,15 @@ repos:
         additional_dependencies:
           - mdformat-gfm
           - mdformat-black
-        exclude: "README.md|README_cn.md| CONTRIBUTING.md"
-  - repo: https://github.com/asottile/yesqa
-    rev: v1.4.0
-    hooks:
-      - id: yesqa
+        exclude: "README.md|README.zh-CN.md|CONTRIBUTING.md"
   - repo: https://github.com/PyCQA/flake8
     rev: 5.0.4
     hooks:
       - id: flake8
         name: PEP8
+  #- repo: https://github.com/asottile/yesqa
+  #  rev: v1.4.0
+  #  hooks:
+  #    - id: yesqa
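Note on the hunk above: the old exclude pattern could never match CONTRIBUTING.md because its third alternative began with a stray space, and README_cn.md had been renamed to README.zh-CN.md; the yesqa hook is parked in a commented-out block rather than deleted. A quick check of the two patterns (a sketch; pre-commit applies exclude as a Python regex via re.search):

import re

old = re.compile(r"README.md|README_cn.md| CONTRIBUTING.md")
new = re.compile(r"README.md|README.zh-CN.md|CONTRIBUTING.md")

print(bool(old.search("CONTRIBUTING.md")))  # False: the leading space never matches
print(bool(new.search("CONTRIBUTING.md")))  # True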

@@ -1,7 +1,7 @@
 import torch
 import yaml
-from ultralytics import yolo  # (required for python usage)
+from ultralytics import yolo  # noqa required for python usage
 # from ultralytics.yolo.data.utils import check_dataset, check_dataset_yaml
 from ultralytics.yolo.engine.trainer import DEFAULT_CONFIG
 from ultralytics.yolo.utils import LOGGER

@@ -13,7 +13,6 @@ from pathlib import Path
 import numpy as np
 import torch
 import torch.distributed as dist
-import torch.multiprocessing as mp
 import torch.nn as nn
 from omegaconf import OmegaConf
 from torch.cuda import amp
@@ -111,8 +110,12 @@ class BaseTrainer:
         if world_size > 1 and "LOCAL_RANK" not in os.environ:
             command = generate_ddp_command(world_size, self)
             print('DDP command: ', command)
-            subprocess.Popen(command)
-            # ddp_cleanup(command, self) # TODO: uncomment and fix
+            try:
+                subprocess.run(command)
+            except Exception as e:
+                self.console(e)
+            finally:
+                ddp_cleanup(command, self)
         else:
             self._do_train(int(os.getenv("RANK", -1)), world_size)
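Why this fixes the DDP launch: subprocess.Popen returns immediately, so the generated temp training file could not be deleted safely afterwards (hence the commented-out ddp_cleanup call and its TODO), whereas subprocess.run blocks until torch.distributed.run and its workers exit, making cleanup in a finally block safe. A minimal sketch of the pattern, with illustrative names rather than the ultralytics API:

import os
import subprocess

def launch_and_clean(command, temp_file):
    try:
        subprocess.run(command)  # blocks until the DDP launcher process exits
    except Exception as e:
        print(e)
    finally:
        if os.path.exists(temp_file):
            os.remove(temp_file)  # safe now: training already finished (or failed to start)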
@@ -122,7 +125,6 @@ class BaseTrainer:
         torch.cuda.set_device(rank)
         self.device = torch.device('cuda', rank)
         self.console.info(f"RANK - WORLD_SIZE - DEVICE: {rank} - {world_size} - {self.device} ")
-        mp.set_start_method('spawn', force=True)
         dist.init_process_group("nccl" if dist.is_nccl_available() else "gloo", rank=rank, world_size=world_size)

     def _setup_train(self, rank, world_size):
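mp.set_start_method('spawn') is no longer needed because workers are not forked from this process at all: torch.distributed.run starts them as fresh interpreters and hands each one its coordinates via environment variables. Roughly what each worker does on startup (a sketch, not the exact trainer code):

import os

import torch
import torch.distributed as dist

rank = int(os.environ["RANK"])              # set by torch.distributed.run
world_size = int(os.environ["WORLD_SIZE"])  # set by torch.distributed.run
torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
dist.init_process_group("nccl" if dist.is_nccl_available() else "gloo",
                        rank=rank, world_size=world_size)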
@@ -159,8 +161,8 @@ class BaseTrainer:
         if rank in {0, -1}:
             self.test_loader = self.get_dataloader(self.testset, batch_size=batch_size * 2, rank=-1, mode="val")
             self.validator = self.get_validator()
-            # metric_keys = self.validator.metric_keys + self.label_loss_items(prefix="val")
-            # self.metrics = dict(zip(metric_keys, [0] * len(metric_keys)))  # TODO: init metrics for plot_results()?
+            metric_keys = self.validator.metric_keys + self.label_loss_items(prefix="val")
+            self.metrics = dict(zip(metric_keys, [0] * len(metric_keys)))  # TODO: init metrics for plot_results()?
             self.ema = ModelEMA(self.model)
         self.trigger_callbacks("on_pretrain_routine_end")
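Un-commenting the metric initialization gives self.metrics a complete, zeroed dict before the first validation pass, so plot_results() and loggers see every key from epoch zero. For example, with illustrative keys:

metric_keys = ["metrics/mAP50", "metrics/mAP50-95"] + ["val/box_loss", "val/cls_loss"]
metrics = dict(zip(metric_keys, [0] * len(metric_keys)))
# {'metrics/mAP50': 0, 'metrics/mAP50-95': 0, 'val/box_loss': 0, 'val/cls_loss': 0}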

@@ -3,7 +3,8 @@ import shutil
 import socket
 import sys
 import tempfile
-import time
+
+from . import USER_CONFIG_DIR


 def find_free_network_port() -> int:
@@ -23,25 +24,25 @@ def find_free_network_port() -> int:

 def generate_ddp_file(trainer):
     import_path = '.'.join(str(trainer.__class__).split(".")[1:-1])

-    # remove the save_dir
-    shutil.rmtree(trainer.save_dir)
+    shutil.rmtree(trainer.save_dir)  # remove the save_dir
     content = f'''overrides = {dict(trainer.args)} \nif __name__ == "__main__":
     from ultralytics.{import_path} import {trainer.__class__.__name__}
     trainer = {trainer.__class__.__name__}(overrides=overrides)
     trainer.train()'''
+    (USER_CONFIG_DIR / 'DDP').mkdir(exist_ok=True)
     with tempfile.NamedTemporaryFile(prefix="_temp_",
                                      suffix=f"{id(trainer)}.py",
                                      mode="w+",
                                      encoding='utf-8',
-                                     dir=os.path.curdir,
+                                     dir=USER_CONFIG_DIR / 'DDP',
                                      delete=False) as file:
         file.write(content)
     return file.name


 def generate_ddp_command(world_size, trainer):
-    import __main__  # local import to avoid https://github.com/Lightning-AI/lightning/issues/15218
+    import __main__  # noqa local import to avoid https://github.com/Lightning-AI/lightning/issues/15218
     file_name = os.path.abspath(sys.argv[0])
     using_cli = not file_name.endswith(".py")
     if using_cli:
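For context, generate_ddp_file() writes a small bootstrap script that re-creates the trainer inside the process started by torch.distributed.run; relocating it from os.path.curdir to USER_CONFIG_DIR / 'DDP' keeps these artifacts out of the user's working directory. A generated file looks roughly like this (the overrides and trainer class are illustrative):

overrides = {'task': 'detect', 'data': 'coco128.yaml', 'device': '0,1'}
if __name__ == "__main__":
    from ultralytics.yolo.v8.detect.train import DetectionTrainer
    trainer = DetectionTrainer(overrides=overrides)
    trainer.train()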
@@ -52,9 +53,7 @@ def generate_ddp_command(world_size, trainer):

 def ddp_cleanup(command, trainer):
-    # delete temp file if created
-    # TODO: this is a temp solution in case the file is deleted before DDP launching
-    time.sleep(5)
+    # delete temp file if created
     tempfile_suffix = f"{id(trainer)}.py"
     if tempfile_suffix in "".join(command):
         for chunk in command:
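With subprocess.run blocking until training ends and ddp_cleanup moved into the trainer's finally block, the time.sleep(5) stopgap (and the time import) can go: the temp file can no longer be deleted before DDP has launched. The surviving logic, sketched out since the hunk above is truncated:

import os

def cleanup_temp_file(command, trainer_id):
    suffix = f"{trainer_id}.py"
    if suffix in "".join(command):  # a temp bootstrap file was generated for this run
        for chunk in command:
            if suffix in chunk:
                os.remove(chunk)
                break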

@@ -58,8 +58,6 @@ class DetectionTrainer(BaseTrainer):
         model = DetectionModel(model_cfg or weights["model"].yaml, ch=3, nc=self.data["nc"])
         if weights:
             model.load(weights)
-        for _, v in model.named_parameters():
-            v.requires_grad = True  # train all layers
         return model

     def get_validator(self):

@@ -21,8 +21,6 @@ class SegmentationTrainer(DetectionTrainer):
         model = SegmentationModel(model_cfg or weights["model"].yaml, ch=3, nc=self.data["nc"])
         if weights:
             model.load(weights)
-        for _, v in model.named_parameters():
-            v.requires_grad = True  # train all layers
         return model

     def get_validator(self):
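The deleted loops in both trainers were effectively redundant: parameters of freshly constructed torch modules default to requires_grad=True, so forcing it only mattered if loaded weights had been explicitly frozen beforehand. A quick check:

import torch.nn as nn

m = nn.Conv2d(3, 16, kernel_size=3)
print(all(p.requires_grad for p in m.parameters()))  # True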
