@@ -43,16 +43,18 @@ def bbox_ioa(box1, box2, eps=1e-7):
 def box_iou(box1, box2, eps=1e-7):
-    # https://github.com/pytorch/vision/blob/master/torchvision/ops/boxes.py
     """
     Return intersection-over-union (Jaccard index) of boxes.
     Both sets of boxes are expected to be in (x1, y1, x2, y2) format.
+    Based on https://github.com/pytorch/vision/blob/master/torchvision/ops/boxes.py
+
     Arguments:
         box1 (Tensor[N, 4])
         box2 (Tensor[M, 4])
+        eps
+
     Returns:
-        iou (Tensor[N, M]): the NxM matrix containing the pairwise
-            IoU values for every element in boxes1 and boxes2
+        iou (Tensor[N, M]): the NxM matrix containing the pairwise IoU values for every element in boxes1 and boxes2
     """

     # inter(N,M) = (rb(N,M,2) - lt(N,M,2)).clamp(0).prod(2)
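For reference, a minimal sketch of the pairwise IoU computation the box_iou docstring above describes, assuming (x1, y1, x2, y2) torch tensors; the helper name pairwise_box_iou and the sample boxes are illustrative only, not part of the module.

import torch

def pairwise_box_iou(box1, box2, eps=1e-7):
    # box1: (N, 4), box2: (M, 4) -> (N, M) IoU matrix via broadcasting
    (a1, a2), (b1, b2) = box1.unsqueeze(1).chunk(2, 2), box2.unsqueeze(0).chunk(2, 2)
    inter = (torch.min(a2, b2) - torch.max(a1, b1)).clamp(0).prod(2)  # overlap area, (N, M)
    area1 = (box1[:, 2] - box1[:, 0]) * (box1[:, 3] - box1[:, 1])     # (N,)
    area2 = (box2[:, 2] - box2[:, 0]) * (box2[:, 3] - box2[:, 1])     # (M,)
    return inter / (area1[:, None] + area2[None, :] - inter + eps)

boxes_a = torch.tensor([[0., 0., 10., 10.]])
boxes_b = torch.tensor([[0., 0., 10., 10.], [5., 5., 15., 15.]])
print(pairwise_box_iou(boxes_a, boxes_b))  # ~[[1.0000, 0.1429]]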
@@ -109,7 +111,7 @@ def mask_iou(mask1, mask2, eps=1e-7):
     mask1: [N, n] m1 means number of predicted objects
     mask2: [M, n] m2 means number of gt objects
     Note: n means image_w x image_h
-    return: masks iou, [N, M]
+    Returns: masks iou, [N, M]
     """
     intersection = torch.matmul(mask1, mask2.t()).clamp(0)
     union = (mask1.sum(1)[:, None] + mask2.sum(1)[None]) - intersection  # (area1 + area2) - intersection
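The matmul/clamp lines shown as context above compute an (N, M) IoU matrix between flattened binary masks; a self-contained toy run of the same arithmetic (values chosen for illustration):

import torch

mask1 = torch.tensor([[1., 1., 0., 0.]])             # 1 predicted mask over a 4-pixel image, area 2
mask2 = torch.tensor([[1., 1., 1., 0.],              # gt mask, area 3
                      [0., 0., 0., 1.]])             # gt mask, area 1
intersection = torch.matmul(mask1, mask2.t()).clamp(0)                # (N, M) overlapping pixel counts
union = (mask1.sum(1)[:, None] + mask2.sum(1)[None]) - intersection   # (N, M)
print(intersection / (union + 1e-7))  # ~[[0.6667, 0.0000]]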
@@ -121,7 +123,7 @@ def masks_iou(mask1, mask2, eps=1e-7):
     mask1: [N, n] m1 means number of predicted objects
     mask2: [N, n] m2 means number of gt objects
     Note: n means image_w x image_h
-    return: masks iou, (N, )
+    Returns: masks iou, (N, )
     """
     intersection = (mask1 * mask2).sum(1).clamp(0)  # (N, )
     union = (mask1.sum(1) + mask2.sum(1))[None] - intersection  # (area1 + area2) - intersection
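masks_iou is the paired counterpart of mask_iou: when the two sets have equal length and are already matched by index, its per-pair result equals the diagonal of the full (N, M) matrix. A toy check of that relationship (inputs illustrative):

import torch

m1 = torch.tensor([[1., 1., 0., 0.], [0., 0., 1., 1.]])
m2 = torch.tensor([[1., 1., 1., 0.], [0., 0., 1., 0.]])
inter_pair = (m1 * m2).sum(1)
paired = inter_pair / (m1.sum(1) + m2.sum(1) - inter_pair + 1e-7)                 # (N,)
inter_full = torch.matmul(m1, m2.t())
full = inter_full / (m1.sum(1)[:, None] + m2.sum(1)[None] - inter_full + 1e-7)    # (N, M)
print(paired)            # tensor([0.6667, 0.5000])
print(torch.diag(full))  # same values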
@@ -317,10 +319,10 @@ def plot_mc_curve(px, py, save_dir=Path('mc_curve.png'), names=(), xlabel='Confi
 def compute_ap(recall, precision):
     """ Compute the average precision, given the recall and precision curves
-    # Arguments
+    Arguments:
         recall:    The recall curve (list)
         precision: The precision curve (list)
-    # Returns
+    Returns:
         Average precision, precision curve, recall curve
     """
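compute_ap's body is not shown in this hunk; for context, a sketch of the standard approach used in the YOLO code (pad the curves, take the monotonic precision envelope, integrate a 101-point interpolation), written as a standalone function rather than the module's exact implementation:

import numpy as np

def compute_ap_sketch(recall, precision):
    mrec = np.concatenate(([0.0], recall, [1.0]))          # add sentinel recall values
    mpre = np.concatenate(([1.0], precision, [0.0]))
    mpre = np.flip(np.maximum.accumulate(np.flip(mpre)))   # monotonic precision envelope
    x = np.linspace(0, 1, 101)                             # 101-point interpolation (COCO style)
    ap = np.trapz(np.interp(x, mrec, mpre), x)
    return ap, mpre, mrec

print(compute_ap_sketch(np.array([0.1, 0.4, 0.8]), np.array([1.0, 0.8, 0.6]))[0])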
@@ -344,17 +346,30 @@ def compute_ap(recall, precision):
 def ap_per_class(tp, conf, pred_cls, target_cls, plot=False, save_dir=Path(), names=(), eps=1e-16, prefix=''):
-    """ Compute the average precision, given the recall and precision curves.
-    Source: https://github.com/rafaelpadilla/Object-Detection-Metrics.
-    # Arguments
-        tp: True positives (nparray, nx1 or nx10).
-        conf: Objectness value from 0-1 (nparray).
-        pred_cls: Predicted object classes (nparray).
-        target_cls: True object classes (nparray).
-        plot: Plot precision-recall curve at mAP@0.5
-        save_dir: Plot save directory
-    # Returns
-        The average precision as computed in py-faster-rcnn.
+    """
+    Computes the average precision per class for object detection evaluation.
+
+    Args:
+        tp (np.ndarray): Binary array indicating whether the detection is correct (True) or not (False).
+        conf (np.ndarray): Array of confidence scores of the detections.
+        pred_cls (np.ndarray): Array of predicted classes of the detections.
+        target_cls (np.ndarray): Array of true classes of the detections.
+        plot (bool, optional): Whether to plot PR curves or not. Defaults to False.
+        save_dir (Path, optional): Directory to save the PR curves. Defaults to an empty path.
+        names (tuple, optional): Tuple of class names to plot PR curves. Defaults to an empty tuple.
+        eps (float, optional): A small value to avoid division by zero. Defaults to 1e-16.
+        prefix (str, optional): A prefix string for saving the plot files. Defaults to an empty string.
+
+    Returns:
+        (tuple): A tuple of six arrays and one array of unique classes, where:
+            tp (np.ndarray): True positive counts for each class.
+            fp (np.ndarray): False positive counts for each class.
+            p (np.ndarray): Precision values at each confidence threshold.
+            r (np.ndarray): Recall values at each confidence threshold.
+            f1 (np.ndarray): F1-score values at each confidence threshold.
+            ap (np.ndarray): Average precision for each class at different IoU thresholds.
+            unique_classes (np.ndarray): An array of unique classes that have data.
     """

     # Sort by objectness
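A hypothetical call with toy arrays, illustrating the shapes the new Args/Returns sections describe; the class names are passed here as an index-to-name mapping, as the validators do, and all values are made up.

import numpy as np
# ap_per_class as defined in this module (import path depends on the package layout)

tp = np.array([[True], [False], [True]])   # per-detection correctness at one IoU threshold
conf = np.array([0.9, 0.8, 0.7])           # detection confidences
pred_cls = np.array([0, 0, 1])             # predicted class per detection
target_cls = np.array([0, 1])              # class of every ground-truth object

tp_c, fp_c, p, r, f1, ap, unique_classes = ap_per_class(
    tp, conf, pred_cls, target_cls, plot=False, names={0: 'person', 1: 'car'})
print(ap.shape, unique_classes)            # (2, 1) and the classes that have data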
@@ -411,6 +426,32 @@ def ap_per_class(tp, conf, pred_cls, target_cls, plot=False, save_dir=Path(), na


 class Metric:
+    """
+    Class for computing evaluation metrics for YOLOv8 model.
+
+    Attributes:
+        p (list): Precision for each class. Shape: (nc,).
+        r (list): Recall for each class. Shape: (nc,).
+        f1 (list): F1 score for each class. Shape: (nc,).
+        all_ap (list): AP scores for all classes and all IoU thresholds. Shape: (nc, 10).
+        ap_class_index (list): Index of class for each AP score. Shape: (nc,).
+        nc (int): Number of classes.
+
+    Methods:
+        ap50(): AP at IoU threshold of 0.5 for all classes. Returns: List of AP scores. Shape: (nc,) or [].
+        ap(): AP at IoU thresholds from 0.5 to 0.95 for all classes. Returns: List of AP scores. Shape: (nc,) or [].
+        mp(): Mean precision of all classes. Returns: Float.
+        mr(): Mean recall of all classes. Returns: Float.
+        map50(): Mean AP at IoU threshold of 0.5 for all classes. Returns: Float.
+        map75(): Mean AP at IoU threshold of 0.75 for all classes. Returns: Float.
+        map(): Mean AP at IoU thresholds from 0.5 to 0.95 for all classes. Returns: Float.
+        mean_results(): Mean of results, returns mp, mr, map50, map.
+        class_result(i): Class-aware result, returns p[i], r[i], ap50[i], ap[i].
+        maps(): mAP of each class. Returns: Array of mAP scores, shape: (nc,).
+        fitness(): Model fitness as a weighted combination of metrics. Returns: Float.
+        update(results): Update metric attributes with new evaluation results.
+    """

     def __init__(self) -> None:
         self.p = []  # (nc, )
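A small numpy sketch of how the documented properties read the (nc, 10) all_ap table (column 0 is AP@0.5, column 5 is AP@0.75); the numbers are illustrative only.

import numpy as np

all_ap = np.array([[0.90, 0.88, 0.85, 0.80, 0.75, 0.70, 0.60, 0.50, 0.40, 0.30],
                   [0.70, 0.68, 0.65, 0.60, 0.55, 0.50, 0.40, 0.30, 0.20, 0.10]])
ap50 = all_ap[:, 0]           # per-class AP@0.5
map50 = all_ap[:, 0].mean()   # mean AP@0.5
map75 = all_ap[:, 5].mean()   # mean AP@0.75 (6th of the 10 thresholds 0.5:0.05:0.95)
map5095 = all_ap.mean()       # mean AP@0.5:0.95
print(ap50, map50, map75, map5095)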
@@ -420,10 +461,14 @@ class Metric:
         self.ap_class_index = []  # (nc, )
         self.nc = 0

+    def __getattr__(self, attr):
+        name = self.__class__.__name__
+        raise AttributeError(f"'{name}' object has no attribute '{attr}'. See valid attributes below.\n{self.__doc__}")
+
     @property
     def ap50(self):
         """AP@0.5 of all classes.
-        Return:
+        Returns:
             (nc, ) or [].
         """
         return self.all_ap[:, 0] if len(self.all_ap) else []
@@ -431,7 +476,7 @@ class Metric:
     @property
     def ap(self):
         """AP@0.5:0.95
-        Return:
+        Returns:
             (nc, ) or [].
         """
         return self.all_ap.mean(1) if len(self.all_ap) else []
@@ -439,7 +484,7 @@ class Metric:
     @property
     def mp(self):
         """mean precision of all classes.
-        Return:
+        Returns:
             float.
         """
         return self.p.mean() if len(self.p) else 0.0
@@ -447,7 +492,7 @@ class Metric:
     @property
     def mr(self):
         """mean recall of all classes.
-        Return:
+        Returns:
             float.
         """
         return self.r.mean() if len(self.r) else 0.0
@@ -455,7 +500,7 @@ class Metric:
     @property
     def map50(self):
         """Mean AP@0.5 of all classes.
-        Return:
+        Returns:
             float.
         """
         return self.all_ap[:, 0].mean() if len(self.all_ap) else 0.0
@@ -463,7 +508,7 @@ class Metric:
     @property
     def map75(self):
         """Mean AP@0.75 of all classes.
-        Return:
+        Returns:
             float.
         """
         return self.all_ap[:, 5].mean() if len(self.all_ap) else 0.0
@@ -471,7 +516,7 @@ class Metric:
     @property
     def map(self):
         """Mean AP@0.5:0.95 of all classes.
-        Return:
+        Returns:
             float.
         """
         return self.all_ap.mean() if len(self.all_ap) else 0.0
@@ -506,6 +551,32 @@ class Metric:


 class DetMetrics:
+    """
+    This class is a utility class for computing detection metrics such as precision, recall, and mean average precision
+    (mAP) of an object detection model.
+
+    Args:
+        save_dir (Path): A path to the directory where the output plots will be saved. Defaults to current directory.
+        plot (bool): A flag that indicates whether to plot precision-recall curves for each class. Defaults to False.
+        names (tuple of str): A tuple of strings that represents the names of the classes. Defaults to an empty tuple.
+
+    Attributes:
+        save_dir (Path): A path to the directory where the output plots will be saved.
+        plot (bool): A flag that indicates whether to plot the precision-recall curves for each class.
+        names (tuple of str): A tuple of strings that represents the names of the classes.
+        box (Metric): An instance of the Metric class for storing the results of the detection metrics.
+        speed (dict): A dictionary for storing the execution time of different parts of the detection process.
+
+    Methods:
+        process(tp, conf, pred_cls, target_cls): Updates the metric results with the latest batch of predictions.
+        keys: Returns a list of keys for accessing the computed detection metrics.
+        mean_results: Returns a list of mean values for the computed detection metrics.
+        class_result(i): Returns a list of values for the computed detection metrics for a specific class.
+        maps: Returns a dictionary of mean average precision (mAP) values for different IoU thresholds.
+        fitness: Computes the fitness score based on the computed detection metrics.
+        ap_class_index: Returns a list of class indices sorted by their average precision (AP) values.
+        results_dict: Returns a dictionary that maps detection metric keys to their computed values.
+    """

     def __init__(self, save_dir=Path('.'), plot=False, names=()) -> None:
         self.save_dir = save_dir
@@ -514,6 +585,10 @@ class DetMetrics:
         self.box = Metric()
         self.speed = {'preprocess': 0.0, 'inference': 0.0, 'loss': 0.0, 'postprocess': 0.0}

+    def __getattr__(self, attr):
+        name = self.__class__.__name__
+        raise AttributeError(f"'{name}' object has no attribute '{attr}'. See valid attributes below.\n{self.__doc__}")
+
     def process(self, tp, conf, pred_cls, target_cls):
         results = ap_per_class(tp, conf, pred_cls, target_cls, plot=self.plot, save_dir=self.save_dir,
                                names=self.names)[2:]
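Hypothetical usage of DetMetrics as documented above, with toy arrays and class names given as an index-to-name dict:

import numpy as np
from pathlib import Path
# DetMetrics as defined in this module (import path depends on the package layout)

dm = DetMetrics(save_dir=Path('.'), plot=False, names={0: 'person', 1: 'car'})
dm.process(tp=np.array([[True], [False], [True]]),
           conf=np.array([0.9, 0.8, 0.7]),
           pred_cls=np.array([0, 0, 1]),
           target_cls=np.array([0, 1]))
print(dm.keys)            # metric keys such as 'metrics/precision(B)', 'metrics/mAP50(B)', ...
print(dm.mean_results())  # [mp, mr, map50, map]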
@@ -548,6 +623,31 @@ class DetMetrics:


 class SegmentMetrics:
+    """
+    Calculates and aggregates detection and segmentation metrics over a given set of classes.
+
+    Args:
+        save_dir (Path): Path to the directory where the output plots should be saved. Default is the current directory.
+        plot (bool): Whether to save the detection and segmentation plots. Default is False.
+        names (list): List of class names. Default is an empty list.
+
+    Attributes:
+        save_dir (Path): Path to the directory where the output plots should be saved.
+        plot (bool): Whether to save the detection and segmentation plots.
+        names (list): List of class names.
+        box (Metric): An instance of the Metric class to calculate box detection metrics.
+        seg (Metric): An instance of the Metric class to calculate mask segmentation metrics.
+        speed (dict): Dictionary to store the time taken in different phases of inference.
+
+    Methods:
+        process(tp_m, tp_b, conf, pred_cls, target_cls): Processes metrics over the given set of predictions.
+        mean_results(): Returns the mean of the detection and segmentation metrics over all the classes.
+        class_result(i): Returns the detection and segmentation metrics of class `i`.
+        maps: Returns the mean Average Precision (mAP) scores for IoU thresholds ranging from 0.50 to 0.95.
+        fitness: Returns the fitness scores, which are a single weighted combination of metrics.
+        ap_class_index: Returns the list of indices of classes used to compute Average Precision (AP).
+        results_dict: Returns the dictionary containing all the detection and segmentation metrics and fitness score.
+    """

     def __init__(self, save_dir=Path('.'), plot=False, names=()) -> None:
         self.save_dir = save_dir
@@ -557,7 +657,22 @@ class SegmentMetrics:
         self.seg = Metric()
         self.speed = {'preprocess': 0.0, 'inference': 0.0, 'loss': 0.0, 'postprocess': 0.0}

+    def __getattr__(self, attr):
+        name = self.__class__.__name__
+        raise AttributeError(f"'{name}' object has no attribute '{attr}'. See valid attributes below.\n{self.__doc__}")
+
     def process(self, tp_m, tp_b, conf, pred_cls, target_cls):
+        """
+        Processes the detection and segmentation metrics over the given set of predictions.
+
+        Args:
+            tp_m (list): List of True Positive masks.
+            tp_b (list): List of True Positive boxes.
+            conf (list): List of confidence scores.
+            pred_cls (list): List of predicted classes.
+            target_cls (list): List of target classes.
+        """
+
         results_mask = ap_per_class(tp_m,
                                     conf,
                                     pred_cls,
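The segmentation variant scores the same detections twice, once against boxes and once against masks; a hypothetical call with toy inputs:

import numpy as np
# SegmentMetrics as defined in this module (import path depends on the package layout)

sm = SegmentMetrics(plot=False, names={0: 'person', 1: 'car'})
sm.process(tp_m=np.array([[True], [False], [True]]),   # mask-level correctness
           tp_b=np.array([[True], [True], [True]]),    # box-level correctness
           conf=np.array([0.9, 0.8, 0.7]),
           pred_cls=np.array([0, 0, 1]),
           target_cls=np.array([0, 1]))
print(sm.mean_results())  # eight values: mean P, R, mAP50, mAP50-95 for boxes and for masks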
@@ -610,12 +725,32 @@ class SegmentMetrics:


 class ClassifyMetrics:
+    """
+    Class for computing classification metrics including top-1 and top-5 accuracy.
+
+    Attributes:
+        top1 (float): The top-1 accuracy.
+        top5 (float): The top-5 accuracy.
+        speed (Dict[str, float]): A dictionary containing the time taken for each step in the pipeline.
+
+    Properties:
+        fitness (float): The fitness of the model, which is equal to top-5 accuracy.
+        results_dict (Dict[str, Union[float, str]]): A dictionary containing the classification metrics and fitness.
+        keys (List[str]): A list of keys for the results_dict.
+
+    Methods:
+        process(targets, pred): Processes the targets and predictions to compute classification metrics.
+    """

     def __init__(self) -> None:
         self.top1 = 0
         self.top5 = 0
         self.speed = {'preprocess': 0.0, 'inference': 0.0, 'loss': 0.0, 'postprocess': 0.0}

+    def __getattr__(self, attr):
+        name = self.__class__.__name__
+        raise AttributeError(f"'{name}' object has no attribute '{attr}'. See valid attributes below.\n{self.__doc__}")
+
     def process(self, targets, pred):
         # target classes and predicted classes
         pred, targets = torch.cat(pred), torch.cat(targets)
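For context, a minimal sketch of the top-1/top-5 accuracy computation the ClassifyMetrics docstring describes, starting from raw logits; process() itself receives per-batch top-5 class indices, and the values here are illustrative only.

import torch

logits = torch.randn(4, 10)                     # 4 samples, 10 classes
targets = torch.tensor([3, 1, 7, 0])
top5 = logits.topk(5, dim=1).indices            # (4, 5) predicted classes, best first
correct = (targets[:, None] == top5).float()    # (4, 5) hit matrix
top1_acc = correct[:, 0].mean().item()          # hit in the first column only
top5_acc = correct.max(1).values.mean().item()  # hit anywhere in the top five
print(top1_acc, top5_acc)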