From dde89c744cdcbe0cf774775a43d5f503451be3fa Mon Sep 17 00:00:00 2001 From: Andy <39454881+yermandy@users.noreply.github.com> Date: Wed, 9 Aug 2023 19:03:46 +0200 Subject: [PATCH] Refactor val code into new `self.match_predictions()` method (#4265) Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Glenn Jocher --- ultralytics/engine/validator.py | 28 ++++++++++++++++++++++++++ ultralytics/models/fastsam/val.py | 15 +------------- ultralytics/models/yolo/detect/val.py | 15 +------------- ultralytics/models/yolo/pose/val.py | 15 +------------- ultralytics/models/yolo/segment/val.py | 15 +------------- 5 files changed, 32 insertions(+), 56 deletions(-) diff --git a/ultralytics/engine/validator.py b/ultralytics/engine/validator.py index 1551cc3..4dfe679 100644 --- a/ultralytics/engine/validator.py +++ b/ultralytics/engine/validator.py @@ -22,6 +22,7 @@ import json import time from pathlib import Path +import numpy as np import torch from tqdm import tqdm @@ -199,6 +200,33 @@ class BaseValidator: LOGGER.info(f"Results saved to {colorstr('bold', self.save_dir)}") return stats + def match_predictions(self, pred_classes: torch.Tensor, true_classes: torch.Tensor, + iou: torch.Tensor) -> torch.Tensor: + """ + Matches predictions to ground truth objects (pred_classes, true_classes) using IoU. + + Args: + pred_classes (torch.Tensor): Predicted class indices of shape(N,). + true_classes (torch.Tensor): Target class indices of shape(M,). + + Returns: + (torch.Tensor): Correct tensor of shape(N,10) for 10 IoU thresholds. + """ + correct = np.zeros((pred_classes.shape[0], self.iouv.shape[0])).astype(bool) + correct_class = true_classes[:, None] == pred_classes + for i in range(len(self.iouv)): + x = torch.where((iou >= self.iouv[i]) & correct_class) # IoU > threshold and classes match + if x[0].shape[0]: + # Concatenate [label, detect, iou] + matches = torch.cat((torch.stack(x, 1), iou[x[0], x[1]][:, None]), 1).cpu().numpy() + if x[0].shape[0] > 1: + matches = matches[matches[:, 2].argsort()[::-1]] + matches = matches[np.unique(matches[:, 1], return_index=True)[1]] + # matches = matches[matches[:, 2].argsort()[::-1]] + matches = matches[np.unique(matches[:, 0], return_index=True)[1]] + correct[matches[:, 1].astype(int), i] = True + return torch.tensor(correct, dtype=torch.bool, device=pred_classes.device) + def add_callback(self, event: str, callback): """Appends the given callback.""" self.callbacks[event].append(callback) diff --git a/ultralytics/models/fastsam/val.py b/ultralytics/models/fastsam/val.py index 9bbae57..f1366f9 100644 --- a/ultralytics/models/fastsam/val.py +++ b/ultralytics/models/fastsam/val.py @@ -150,20 +150,7 @@ class FastSAMValidator(DetectionValidator): else: # boxes iou = box_iou(labels[:, 1:], detections[:, :4]) - correct = np.zeros((detections.shape[0], self.iouv.shape[0])).astype(bool) - correct_class = labels[:, 0:1] == detections[:, 5] - for i in range(len(self.iouv)): - x = torch.where((iou >= self.iouv[i]) & correct_class) # IoU > threshold and classes match - if x[0].shape[0]: - matches = torch.cat((torch.stack(x, 1), iou[x[0], x[1]][:, None]), - 1).cpu().numpy() # [label, detect, iou] - if x[0].shape[0] > 1: - matches = matches[matches[:, 2].argsort()[::-1]] - matches = matches[np.unique(matches[:, 1], return_index=True)[1]] - # matches = matches[matches[:, 2].argsort()[::-1]] - matches = matches[np.unique(matches[:, 0], return_index=True)[1]] - correct[matches[:, 1].astype(int), i] = True - return torch.tensor(correct, dtype=torch.bool, device=detections.device) + return self.match_predictions(detections[:, 5], labels[:, 0], iou) def plot_val_samples(self, batch, ni): """Plots validation samples with bounding box labels.""" diff --git a/ultralytics/models/yolo/detect/val.py b/ultralytics/models/yolo/detect/val.py index d37aa2b..dd282b9 100644 --- a/ultralytics/models/yolo/detect/val.py +++ b/ultralytics/models/yolo/detect/val.py @@ -163,20 +163,7 @@ class DetectionValidator(BaseValidator): correct (array[N, 10]), for 10 IoU levels """ iou = box_iou(labels[:, 1:], detections[:, :4]) - correct = np.zeros((detections.shape[0], self.iouv.shape[0])).astype(bool) - correct_class = labels[:, 0:1] == detections[:, 5] - for i in range(len(self.iouv)): - x = torch.where((iou >= self.iouv[i]) & correct_class) # IoU > threshold and classes match - if x[0].shape[0]: - matches = torch.cat((torch.stack(x, 1), iou[x[0], x[1]][:, None]), - 1).cpu().numpy() # [label, detect, iou] - if x[0].shape[0] > 1: - matches = matches[matches[:, 2].argsort()[::-1]] - matches = matches[np.unique(matches[:, 1], return_index=True)[1]] - # matches = matches[matches[:, 2].argsort()[::-1]] - matches = matches[np.unique(matches[:, 0], return_index=True)[1]] - correct[matches[:, 1].astype(int), i] = True - return torch.tensor(correct, dtype=torch.bool, device=detections.device) + return self.match_predictions(detections[:, 5], labels[:, 0], iou) def build_dataset(self, img_path, mode='val', batch=None): """Build YOLO Dataset diff --git a/ultralytics/models/yolo/pose/val.py b/ultralytics/models/yolo/pose/val.py index 34390a3..cbf2668 100644 --- a/ultralytics/models/yolo/pose/val.py +++ b/ultralytics/models/yolo/pose/val.py @@ -128,20 +128,7 @@ class PoseValidator(DetectionValidator): else: # boxes iou = box_iou(labels[:, 1:], detections[:, :4]) - correct = np.zeros((detections.shape[0], self.iouv.shape[0])).astype(bool) - correct_class = labels[:, 0:1] == detections[:, 5] - for i in range(len(self.iouv)): - x = torch.where((iou >= self.iouv[i]) & correct_class) # IoU > threshold and classes match - if x[0].shape[0]: - matches = torch.cat((torch.stack(x, 1), iou[x[0], x[1]][:, None]), - 1).cpu().numpy() # [label, detect, iou] - if x[0].shape[0] > 1: - matches = matches[matches[:, 2].argsort()[::-1]] - matches = matches[np.unique(matches[:, 1], return_index=True)[1]] - # matches = matches[matches[:, 2].argsort()[::-1]] - matches = matches[np.unique(matches[:, 0], return_index=True)[1]] - correct[matches[:, 1].astype(int), i] = True - return torch.tensor(correct, dtype=torch.bool, device=detections.device) + return self.match_predictions(detections[:, 5], labels[:, 0], iou) def plot_val_samples(self, batch, ni): """Plots and saves validation set samples with predicted bounding boxes and keypoints.""" diff --git a/ultralytics/models/yolo/segment/val.py b/ultralytics/models/yolo/segment/val.py index 5735d3f..1a295cd 100644 --- a/ultralytics/models/yolo/segment/val.py +++ b/ultralytics/models/yolo/segment/val.py @@ -150,20 +150,7 @@ class SegmentationValidator(DetectionValidator): else: # boxes iou = box_iou(labels[:, 1:], detections[:, :4]) - correct = np.zeros((detections.shape[0], self.iouv.shape[0])).astype(bool) - correct_class = labels[:, 0:1] == detections[:, 5] - for i in range(len(self.iouv)): - x = torch.where((iou >= self.iouv[i]) & correct_class) # IoU > threshold and classes match - if x[0].shape[0]: - matches = torch.cat((torch.stack(x, 1), iou[x[0], x[1]][:, None]), - 1).cpu().numpy() # [label, detect, iou] - if x[0].shape[0] > 1: - matches = matches[matches[:, 2].argsort()[::-1]] - matches = matches[np.unique(matches[:, 1], return_index=True)[1]] - # matches = matches[matches[:, 2].argsort()[::-1]] - matches = matches[np.unique(matches[:, 0], return_index=True)[1]] - correct[matches[:, 1].astype(int), i] = True - return torch.tensor(correct, dtype=torch.bool, device=detections.device) + return self.match_predictions(detections[:, 5], labels[:, 0], iou) def plot_val_samples(self, batch, ni): """Plots validation samples with bounding box labels."""