diff --git a/ultralytics/yolo/engine/trainer.py b/ultralytics/yolo/engine/trainer.py index 31de0e0..53e376c 100644 --- a/ultralytics/yolo/engine/trainer.py +++ b/ultralytics/yolo/engine/trainer.py @@ -3,6 +3,8 @@ Simple training loop; Boilerplate that could apply to any arbitrary neural netwo """ import os +import subprocess +import sys import time from collections import defaultdict from copy import deepcopy @@ -26,6 +28,7 @@ from ultralytics.yolo.data.utils import check_dataset, check_dataset_yaml from ultralytics.yolo.utils import LOGGER, ROOT, TQDM_BAR_FORMAT, colorstr from ultralytics.yolo.utils.checks import check_file, print_args from ultralytics.yolo.utils.configs import get_config +from ultralytics.yolo.utils.dist import ddp_cleanup, generate_ddp_command from ultralytics.yolo.utils.files import get_latest_run, increment_path, save_yaml from ultralytics.yolo.utils.torch_utils import ModelEMA, de_parallel, init_seeds, one_cycle, strip_optimizer @@ -103,15 +106,16 @@ class BaseTrainer: def train(self): world_size = torch.cuda.device_count() - if world_size > 1: - mp.spawn(self._do_train, args=(world_size,), nprocs=world_size, join=True) + if world_size > 1 and not ("LOCAL_RANK" in os.environ): + command = generate_ddp_command(world_size, self) + subprocess.Popen(command) + ddp_cleanup(command, self) else: - # self._do_train(int(os.getenv("RANK", -1)), world_size) - self._do_train() + self._do_train(int(os.getenv("RANK", -1)), world_size) def _setup_ddp(self, rank, world_size): - os.environ['MASTER_ADDR'] = 'localhost' - os.environ['MASTER_PORT'] = '9020' + # os.environ['MASTER_ADDR'] = 'localhost' + # os.environ['MASTER_PORT'] = '9020' torch.cuda.set_device(rank) self.device = torch.device('cuda', rank) self.console.info(f"RANK - WORLD_SIZE - DEVICE: {rank} - {world_size} - {self.device} ") @@ -146,7 +150,7 @@ class BaseTrainer: self.scheduler.last_epoch = self.start_epoch - 1 # do not move # dataloaders - batch_size = self.batch_size // world_size + batch_size = self.batch_size // world_size if world_size > 1 else self.batch_size self.train_loader = self.get_dataloader(self.trainset, batch_size=batch_size, rank=rank, mode="train") if rank in {0, -1}: self.test_loader = self.get_dataloader(self.testset, batch_size=batch_size * 2, rank=-1, mode="val") @@ -258,7 +262,7 @@ class BaseTrainer: self.plot_metrics() self.log(f"\nTraining complete ({(time.time() - self.train_time_start) / 3600:.3f} hours)") self.trigger_callbacks('on_train_end') - dist.destroy_process_group() if world_size != 1 else None + dist.destroy_process_group() if world_size > 1 else None torch.cuda.empty_cache() def save_model(self): diff --git a/ultralytics/yolo/utils/dist.py b/ultralytics/yolo/utils/dist.py new file mode 100644 index 0000000..ae4f2a5 --- /dev/null +++ b/ultralytics/yolo/utils/dist.py @@ -0,0 +1,63 @@ +import os +import shutil +import socket +import sys +import tempfile +import time + + +def find_free_network_port() -> int: + # https://github.com/Lightning-AI/lightning/blob/master/src/lightning_lite/plugins/environments/lightning.py + """Finds a free port on localhost. + + It is useful in single-node training when we don't want to connect to a real main node but have to set the + `MASTER_PORT` environment variable. + """ + s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + s.bind(("", 0)) + port = s.getsockname()[1] + s.close() + return port + + +def generate_ddp_file(trainer): + import_path = '.'.join(str(trainer.__class__).split(".")[1:-1]) + + # remove the save_dir + shutil.rmtree(trainer.save_dir) + content = f'''overrides = {dict(trainer.args)} \nif __name__ == "__main__": + from ultralytics.{import_path} import {trainer.__class__.__name__} + + trainer = {trainer.__class__.__name__}(overrides=overrides) + trainer.train()''' + with tempfile.NamedTemporaryFile(prefix="_temp_", + suffix=f"{id(trainer)}.py", + mode="w+", + encoding='utf-8', + dir=os.path.curdir, + delete=False) as file: + file.write(content) + return file.name + + +def generate_ddp_command(world_size, trainer): + import __main__ # local import to avoid https://github.com/Lightning-AI/lightning/issues/15218 + file_name = os.path.abspath(sys.argv[0]) + using_cli = not file_name.endswith(".py") + if using_cli: + file_name = generate_ddp_file(trainer) + return [ + sys.executable, "-m", "torch.distributed.launch", "--nproc_per_node", f"{world_size}", "--master_port", + f"{find_free_network_port()}", file_name] + sys.argv[1:] + + +def ddp_cleanup(command, trainer): + # delete temp file if created + # TODO: this is a temp solution in case the file is deleted before DDP launching + time.sleep(5) + tempfile_suffix = str(id(trainer)) + ".py" + if tempfile_suffix in "".join(command): + for chunk in command: + if tempfile_suffix in chunk: + os.remove(chunk) + break diff --git a/ultralytics/yolo/v8/detect/val.py b/ultralytics/yolo/v8/detect/val.py index 546219d..3ab71f6 100644 --- a/ultralytics/yolo/v8/detect/val.py +++ b/ultralytics/yolo/v8/detect/val.py @@ -25,11 +25,8 @@ class DetectionValidator(BaseValidator): self.class_map = None self.targets = None self.metrics = DetMetrics(save_dir=self.save_dir, plot=self.args.plots) - self.iouv = torch.linspace(0.5, 0.95, 10, device=self.device) # iou vector for mAP@0.5:0.95 + self.iouv = torch.linspace(0.5, 0.95, 10) # iou vector for mAP@0.5:0.95 self.niou = self.iouv.numel() - self.seen = 0 - self.jdict = [] - self.stats = [] def preprocess(self, batch): batch["img"] = batch["img"].to(self.device, non_blocking=True) @@ -56,6 +53,9 @@ class DetectionValidator(BaseValidator): self.names = dict(enumerate(self.names)) self.metrics.names = self.names self.confusion_matrix = ConfusionMatrix(nc=self.nc) + self.seen = 0 + self.jdict = [] + self.stats = [] def get_desc(self): return ('%22s' + '%11s' * 6) % ('Class', 'Images', 'Instances', 'Box(P', "R", "mAP50", "mAP50-95)") @@ -98,7 +98,7 @@ class DetectionValidator(BaseValidator): tbox = ops.xywh2xyxy(labels[:, 1:5]) # target boxes ops.scale_boxes(batch["img"][si].shape[1:], tbox, shape) # native-space labels labelsn = torch.cat((labels[:, 0:1], tbox), 1) # native-space labels - correct_bboxes = self._process_batch(predn, labelsn, self.iouv) + correct_bboxes = self._process_batch(predn, labelsn) # TODO: maybe remove these `self.` arguments as they already are member variable if self.args.plots: self.confusion_matrix.process_batch(predn, labelsn) @@ -139,7 +139,7 @@ class DetectionValidator(BaseValidator): if self.args.plots: self.confusion_matrix.plot(save_dir=self.save_dir, names=list(self.names.values())) - def _process_batch(self, detections, labels, iouv): + def _process_batch(self, detections, labels): """ Return correct prediction matrix Arguments: @@ -149,10 +149,10 @@ class DetectionValidator(BaseValidator): correct (array[N, 10]), for 10 IoU levels """ iou = box_iou(labels[:, 1:], detections[:, :4]) - correct = np.zeros((detections.shape[0], iouv.shape[0])).astype(bool) + correct = np.zeros((detections.shape[0], self.iouv.shape[0])).astype(bool) correct_class = labels[:, 0:1] == detections[:, 5] - for i in range(len(iouv)): - x = torch.where((iou >= iouv[i]) & correct_class) # IoU > threshold and classes match + for i in range(len(self.iouv)): + x = torch.where((iou >= self.iouv[i]) & correct_class) # IoU > threshold and classes match if x[0].shape[0]: matches = torch.cat((torch.stack(x, 1), iou[x[0], x[1]][:, None]), 1).cpu().numpy() # [label, detect, iou] @@ -162,7 +162,7 @@ class DetectionValidator(BaseValidator): # matches = matches[matches[:, 2].argsort()[::-1]] matches = matches[np.unique(matches[:, 0], return_index=True)[1]] correct[matches[:, 1].astype(int), i] = True - return torch.tensor(correct, dtype=torch.bool, device=iouv.device) + return torch.tensor(correct, dtype=torch.bool, device=detections.device) def get_dataloader(self, dataset_path, batch_size): # TODO: manage splits differently diff --git a/ultralytics/yolo/v8/segment/val.py b/ultralytics/yolo/v8/segment/val.py index 6c0082c..9f77707 100644 --- a/ultralytics/yolo/v8/segment/val.py +++ b/ultralytics/yolo/v8/segment/val.py @@ -5,13 +5,11 @@ import numpy as np import torch import torch.nn.functional as F -from ultralytics.yolo.data import build_dataloader from ultralytics.yolo.engine.trainer import DEFAULT_CONFIG from ultralytics.yolo.utils import ops from ultralytics.yolo.utils.checks import check_requirements from ultralytics.yolo.utils.metrics import ConfusionMatrix, SegmentMetrics, box_iou, mask_iou from ultralytics.yolo.utils.plotting import output_to_target, plot_images -from ultralytics.yolo.utils.torch_utils import de_parallel from ..detect import DetectionValidator @@ -55,6 +53,9 @@ class SegmentationValidator(DetectionValidator): self.metrics.names = self.names self.confusion_matrix = ConfusionMatrix(nc=self.nc) self.plot_masks = [] + self.seen = 0 + self.jdict = [] + self.stats = [] def get_desc(self): return ('%22s' + '%11s' * 10) % ('Class', 'Images', 'Instances', 'Box(P', "R", "mAP50", "mAP50-95)", "Mask(P", @@ -106,11 +107,10 @@ class SegmentationValidator(DetectionValidator): tbox = ops.xywh2xyxy(labels[:, 1:5]) # target boxes ops.scale_boxes(batch["img"][si].shape[1:], tbox, shape) # native-space labels labelsn = torch.cat((labels[:, 0:1], tbox), 1) # native-space labels - correct_bboxes = self._process_batch(predn, labelsn, self.iouv) + correct_bboxes = self._process_batch(predn, labelsn) # TODO: maybe remove these `self.` arguments as they already are member variable correct_masks = self._process_batch(predn, labelsn, - self.iouv, pred_masks, gt_masks, overlap=self.args.overlap_mask, @@ -135,7 +135,7 @@ class SegmentationValidator(DetectionValidator): # callbacks.run('on_val_image_end', pred, predn, path, names, im[si]) ''' - def _process_batch(self, detections, labels, iouv, pred_masks=None, gt_masks=None, overlap=False, masks=False): + def _process_batch(self, detections, labels, pred_masks=None, gt_masks=None, overlap=False, masks=False): """ Return correct prediction matrix Arguments: @@ -157,10 +157,10 @@ class SegmentationValidator(DetectionValidator): else: # boxes iou = box_iou(labels[:, 1:], detections[:, :4]) - correct = np.zeros((detections.shape[0], iouv.shape[0])).astype(bool) + correct = np.zeros((detections.shape[0], self.iouv.shape[0])).astype(bool) correct_class = labels[:, 0:1] == detections[:, 5] - for i in range(len(iouv)): - x = torch.where((iou >= iouv[i]) & correct_class) # IoU > threshold and classes match + for i in range(len(self.iouv)): + x = torch.where((iou >= self.iouv[i]) & correct_class) # IoU > threshold and classes match if x[0].shape[0]: matches = torch.cat((torch.stack(x, 1), iou[x[0], x[1]][:, None]), 1).cpu().numpy() # [label, detect, iou] @@ -170,7 +170,7 @@ class SegmentationValidator(DetectionValidator): # matches = matches[matches[:, 2].argsort()[::-1]] matches = matches[np.unique(matches[:, 0], return_index=True)[1]] correct[matches[:, 1].astype(int), i] = True - return torch.tensor(correct, dtype=torch.bool, device=iouv.device) + return torch.tensor(correct, dtype=torch.bool, device=detections.device) # TODO: probably add this to class Metrics @property