add a naive DDP for model interface (#78)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Ayush Chaurasia <ayush.chaurarsia@gmail.com>
branch: single_channel
authored by Laughing, committed by GitHub
parent 48c95ba083
commit 7690cae2fc

@@ -3,6 +3,8 @@ Simple training loop; Boilerplate that could apply to any arbitrary neural network
 """
 import os
+import subprocess
+import sys
 import time
 from collections import defaultdict
 from copy import deepcopy
@@ -26,6 +28,7 @@ from ultralytics.yolo.data.utils import check_dataset, check_dataset_yaml
 from ultralytics.yolo.utils import LOGGER, ROOT, TQDM_BAR_FORMAT, colorstr
 from ultralytics.yolo.utils.checks import check_file, print_args
 from ultralytics.yolo.utils.configs import get_config
+from ultralytics.yolo.utils.dist import ddp_cleanup, generate_ddp_command
 from ultralytics.yolo.utils.files import get_latest_run, increment_path, save_yaml
 from ultralytics.yolo.utils.torch_utils import ModelEMA, de_parallel, init_seeds, one_cycle, strip_optimizer
@@ -103,15 +106,16 @@ class BaseTrainer:
     def train(self):
         world_size = torch.cuda.device_count()
-        if world_size > 1:
-            mp.spawn(self._do_train, args=(world_size,), nprocs=world_size, join=True)
+        if world_size > 1 and not ("LOCAL_RANK" in os.environ):
+            command = generate_ddp_command(world_size, self)
+            subprocess.Popen(command)
+            ddp_cleanup(command, self)
         else:
-            # self._do_train(int(os.getenv("RANK", -1)), world_size)
-            self._do_train()
+            self._do_train(int(os.getenv("RANK", -1)), world_size)
 
     def _setup_ddp(self, rank, world_size):
-        os.environ['MASTER_ADDR'] = 'localhost'
-        os.environ['MASTER_PORT'] = '9020'
+        # os.environ['MASTER_ADDR'] = 'localhost'
+        # os.environ['MASTER_PORT'] = '9020'
         torch.cuda.set_device(rank)
         self.device = torch.device('cuda', rank)
         self.console.info(f"RANK - WORLD_SIZE - DEVICE: {rank} - {world_size} - {self.device} ")
@@ -146,7 +150,7 @@ class BaseTrainer:
         self.scheduler.last_epoch = self.start_epoch - 1  # do not move
 
         # dataloaders
-        batch_size = self.batch_size // world_size
+        batch_size = self.batch_size // world_size if world_size > 1 else self.batch_size
         self.train_loader = self.get_dataloader(self.trainset, batch_size=batch_size, rank=rank, mode="train")
         if rank in {0, -1}:
             self.test_loader = self.get_dataloader(self.testset, batch_size=batch_size * 2, rank=-1, mode="val")
@@ -258,7 +262,7 @@ class BaseTrainer:
             self.plot_metrics()
             self.log(f"\nTraining complete ({(time.time() - self.train_time_start) / 3600:.3f} hours)")
             self.trigger_callbacks('on_train_end')
-        dist.destroy_process_group() if world_size != 1 else None
+        dist.destroy_process_group() if world_size > 1 else None
         torch.cuda.empty_cache()
 
     def save_model(self):

@@ -0,0 +1,63 @@
+import os
+import shutil
+import socket
+import sys
+import tempfile
+import time
+
+
+def find_free_network_port() -> int:
+    # https://github.com/Lightning-AI/lightning/blob/master/src/lightning_lite/plugins/environments/lightning.py
+    """Finds a free port on localhost.
+
+    It is useful in single-node training when we don't want to connect to a real main node but have to set the
+    `MASTER_PORT` environment variable.
+    """
+    s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+    s.bind(("", 0))
+    port = s.getsockname()[1]
+    s.close()
+    return port
+
+
+def generate_ddp_file(trainer):
+    import_path = '.'.join(str(trainer.__class__).split(".")[1:-1])
+    # remove the save_dir
+    shutil.rmtree(trainer.save_dir)
+    content = f'''overrides = {dict(trainer.args)} \nif __name__ == "__main__":
+    from ultralytics.{import_path} import {trainer.__class__.__name__}
+    trainer = {trainer.__class__.__name__}(overrides=overrides)
+    trainer.train()'''
+    with tempfile.NamedTemporaryFile(prefix="_temp_",
+                                     suffix=f"{id(trainer)}.py",
+                                     mode="w+",
+                                     encoding='utf-8',
+                                     dir=os.path.curdir,
+                                     delete=False) as file:
+        file.write(content)
+    return file.name
+
+
+def generate_ddp_command(world_size, trainer):
+    import __main__  # local import to avoid https://github.com/Lightning-AI/lightning/issues/15218
+    file_name = os.path.abspath(sys.argv[0])
+    using_cli = not file_name.endswith(".py")
+    if using_cli:
+        file_name = generate_ddp_file(trainer)
+    return [
+        sys.executable, "-m", "torch.distributed.launch", "--nproc_per_node", f"{world_size}", "--master_port",
+        f"{find_free_network_port()}", file_name] + sys.argv[1:]
+
+
+def ddp_cleanup(command, trainer):
+    # delete temp file if created
+    # TODO: this is a temp solution in case the file is deleted before DDP launching
+    time.sleep(5)
+    tempfile_suffix = str(id(trainer)) + ".py"
+    if tempfile_suffix in "".join(command):
+        for chunk in command:
+            if tempfile_suffix in chunk:
+                os.remove(chunk)
+                break
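The list returned by `generate_ddp_command` is fed straight to `subprocess.Popen` in `BaseTrainer.train()`. As a purely illustrative sketch (the interpreter path, port, temp-file name, and trailing CLI arguments below are invented, not taken from this PR), a two-GPU run started from the CLI might produce a command like:

```python
# Hypothetical example only: what generate_ddp_command(2, trainer) could return.
command = [
    "/usr/bin/python3", "-m", "torch.distributed.launch",  # relaunch under the DDP launcher
    "--nproc_per_node", "2",                               # one process per visible GPU
    "--master_port", "54329",                              # value from find_free_network_port()
    "./_temp_ab12cd140221234567.py",                       # script written by generate_ddp_file()
    "task=detect", "mode=train",                           # original sys.argv[1:] passed through
]
```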

@@ -25,11 +25,8 @@ class DetectionValidator(BaseValidator):
         self.class_map = None
         self.targets = None
         self.metrics = DetMetrics(save_dir=self.save_dir, plot=self.args.plots)
-        self.iouv = torch.linspace(0.5, 0.95, 10, device=self.device)  # iou vector for mAP@0.5:0.95
+        self.iouv = torch.linspace(0.5, 0.95, 10)  # iou vector for mAP@0.5:0.95
         self.niou = self.iouv.numel()
-        self.seen = 0
-        self.jdict = []
-        self.stats = []
 
     def preprocess(self, batch):
         batch["img"] = batch["img"].to(self.device, non_blocking=True)
@@ -56,6 +53,9 @@ class DetectionValidator(BaseValidator):
         self.names = dict(enumerate(self.names))
         self.metrics.names = self.names
         self.confusion_matrix = ConfusionMatrix(nc=self.nc)
+        self.seen = 0
+        self.jdict = []
+        self.stats = []
 
     def get_desc(self):
         return ('%22s' + '%11s' * 6) % ('Class', 'Images', 'Instances', 'Box(P', "R", "mAP50", "mAP50-95)")
@@ -98,7 +98,7 @@ class DetectionValidator(BaseValidator):
                 tbox = ops.xywh2xyxy(labels[:, 1:5])  # target boxes
                 ops.scale_boxes(batch["img"][si].shape[1:], tbox, shape)  # native-space labels
                 labelsn = torch.cat((labels[:, 0:1], tbox), 1)  # native-space labels
-                correct_bboxes = self._process_batch(predn, labelsn, self.iouv)
+                correct_bboxes = self._process_batch(predn, labelsn)
                 # TODO: maybe remove these `self.` arguments as they already are member variable
                 if self.args.plots:
                     self.confusion_matrix.process_batch(predn, labelsn)
@@ -139,7 +139,7 @@ class DetectionValidator(BaseValidator):
         if self.args.plots:
             self.confusion_matrix.plot(save_dir=self.save_dir, names=list(self.names.values()))
 
-    def _process_batch(self, detections, labels, iouv):
+    def _process_batch(self, detections, labels):
         """
         Return correct prediction matrix
         Arguments:
@@ -149,10 +149,10 @@ class DetectionValidator(BaseValidator):
             correct (array[N, 10]), for 10 IoU levels
         """
         iou = box_iou(labels[:, 1:], detections[:, :4])
-        correct = np.zeros((detections.shape[0], iouv.shape[0])).astype(bool)
+        correct = np.zeros((detections.shape[0], self.iouv.shape[0])).astype(bool)
         correct_class = labels[:, 0:1] == detections[:, 5]
-        for i in range(len(iouv)):
-            x = torch.where((iou >= iouv[i]) & correct_class)  # IoU > threshold and classes match
+        for i in range(len(self.iouv)):
+            x = torch.where((iou >= self.iouv[i]) & correct_class)  # IoU > threshold and classes match
             if x[0].shape[0]:
                 matches = torch.cat((torch.stack(x, 1), iou[x[0], x[1]][:, None]),
                                     1).cpu().numpy()  # [label, detect, iou]
@@ -162,7 +162,7 @@ class DetectionValidator(BaseValidator):
                     # matches = matches[matches[:, 2].argsort()[::-1]]
                     matches = matches[np.unique(matches[:, 0], return_index=True)[1]]
                 correct[matches[:, 1].astype(int), i] = True
-        return torch.tensor(correct, dtype=torch.bool, device=iouv.device)
+        return torch.tensor(correct, dtype=torch.bool, device=detections.device)
 
     def get_dataloader(self, dataset_path, batch_size):
         # TODO: manage splits differently
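The hunks above move the IoU thresholds from a function argument to `self.iouv` (now built on CPU, no longer tied to `self.device`) and place the returned matrix on the detections' device. A toy, self-contained sketch of what the loop computes, with the per-label deduplication step omitted for brevity: `correct` is an N×10 boolean matrix whose column i records whether each detection matched a same-class label at IoU ≥ `self.iouv[i]`.

```python
# Toy sketch (not repo code) of the N x 10 correctness matrix built in _process_batch.
import numpy as np
import torch

iouv = torch.linspace(0.5, 0.95, 10)            # the 10 thresholds now kept on self.iouv
iou = torch.tensor([[0.72], [0.40]])            # IoU of 2 labels vs. 1 detection (toy values)
correct_class = torch.tensor([[True], [True]])  # label class matches the detection class

correct = np.zeros((1, iouv.shape[0]), dtype=bool)   # rows: detections, columns: thresholds
for i in range(len(iouv)):
    x = torch.where((iou >= iouv[i]) & correct_class)
    if x[0].shape[0]:
        correct[x[1].numpy(), i] = True              # x[1] holds the matched detection indices

print(correct[0])  # True for thresholds 0.50-0.70, False from 0.75 up
```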

@@ -5,13 +5,11 @@ import numpy as np
 import torch
 import torch.nn.functional as F
 
-from ultralytics.yolo.data import build_dataloader
 from ultralytics.yolo.engine.trainer import DEFAULT_CONFIG
 from ultralytics.yolo.utils import ops
 from ultralytics.yolo.utils.checks import check_requirements
 from ultralytics.yolo.utils.metrics import ConfusionMatrix, SegmentMetrics, box_iou, mask_iou
 from ultralytics.yolo.utils.plotting import output_to_target, plot_images
-from ultralytics.yolo.utils.torch_utils import de_parallel
 
 from ..detect import DetectionValidator
@@ -55,6 +53,9 @@ class SegmentationValidator(DetectionValidator):
         self.metrics.names = self.names
         self.confusion_matrix = ConfusionMatrix(nc=self.nc)
         self.plot_masks = []
+        self.seen = 0
+        self.jdict = []
+        self.stats = []
 
     def get_desc(self):
         return ('%22s' + '%11s' * 10) % ('Class', 'Images', 'Instances', 'Box(P', "R", "mAP50", "mAP50-95)", "Mask(P",
@@ -106,11 +107,10 @@ class SegmentationValidator(DetectionValidator):
                 tbox = ops.xywh2xyxy(labels[:, 1:5])  # target boxes
                 ops.scale_boxes(batch["img"][si].shape[1:], tbox, shape)  # native-space labels
                 labelsn = torch.cat((labels[:, 0:1], tbox), 1)  # native-space labels
-                correct_bboxes = self._process_batch(predn, labelsn, self.iouv)
+                correct_bboxes = self._process_batch(predn, labelsn)
                 # TODO: maybe remove these `self.` arguments as they already are member variable
                 correct_masks = self._process_batch(predn,
                                                     labelsn,
-                                                    self.iouv,
                                                     pred_masks,
                                                     gt_masks,
                                                     overlap=self.args.overlap_mask,
@@ -135,7 +135,7 @@ class SegmentationValidator(DetectionValidator):
            # callbacks.run('on_val_image_end', pred, predn, path, names, im[si])
        '''
 
-    def _process_batch(self, detections, labels, iouv, pred_masks=None, gt_masks=None, overlap=False, masks=False):
+    def _process_batch(self, detections, labels, pred_masks=None, gt_masks=None, overlap=False, masks=False):
         """
         Return correct prediction matrix
         Arguments:
@@ -157,10 +157,10 @@ class SegmentationValidator(DetectionValidator):
         else:  # boxes
             iou = box_iou(labels[:, 1:], detections[:, :4])
 
-        correct = np.zeros((detections.shape[0], iouv.shape[0])).astype(bool)
+        correct = np.zeros((detections.shape[0], self.iouv.shape[0])).astype(bool)
         correct_class = labels[:, 0:1] == detections[:, 5]
-        for i in range(len(iouv)):
-            x = torch.where((iou >= iouv[i]) & correct_class)  # IoU > threshold and classes match
+        for i in range(len(self.iouv)):
+            x = torch.where((iou >= self.iouv[i]) & correct_class)  # IoU > threshold and classes match
             if x[0].shape[0]:
                 matches = torch.cat((torch.stack(x, 1), iou[x[0], x[1]][:, None]),
                                     1).cpu().numpy()  # [label, detect, iou]
@@ -170,7 +170,7 @@ class SegmentationValidator(DetectionValidator):
                     # matches = matches[matches[:, 2].argsort()[::-1]]
                     matches = matches[np.unique(matches[:, 0], return_index=True)[1]]
                 correct[matches[:, 1].astype(int), i] = True
-        return torch.tensor(correct, dtype=torch.bool, device=iouv.device)
+        return torch.tensor(correct, dtype=torch.bool, device=detections.device)
 
     # TODO: probably add this to class Metrics
     @property
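`SegmentationValidator._process_batch` follows the same pattern as the detection version, except that with `masks=True` it scores overlap with `mask_iou` on flattened per-instance masks instead of `box_iou` on xyxy boxes. A toy illustration, assuming `mask_iou` (imported from the metrics module above) takes (instances × pixels) float tensors with ground-truth masks first, detections second:

```python
# Toy illustration (not repo code): mask overlap scoring used when masks=True.
import torch

from ultralytics.yolo.utils.metrics import mask_iou

gt_masks = torch.tensor([[1, 1, 0, 0],
                         [0, 0, 1, 1]], dtype=torch.float32)    # 2 labels, 4 flattened pixels each
pred_masks = torch.tensor([[1, 1, 1, 0]], dtype=torch.float32)  # 1 detection

iou = mask_iou(gt_masks, pred_masks)  # (num_labels, num_detections), like box_iou's output
print(iou)  # ~[[0.67], [0.25]]: the detection overlaps label 0 far more than label 1
```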
