ultralytics 8.0.105
classification hyp fix and new onplot
callbacks (#2684)
Co-authored-by: ayush chaurasia <ayush.chaurarsia@gmail.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Ivan Shcheklein <shcheklein@gmail.com>
This commit is contained in:
@ -300,7 +300,7 @@ class ConfusionMatrix:
|
||||
|
||||
@TryExcept('WARNING ⚠️ ConfusionMatrix plot failure')
|
||||
@plt_settings()
|
||||
def plot(self, normalize=True, save_dir='', names=()):
|
||||
def plot(self, normalize=True, save_dir='', names=(), on_plot=None):
|
||||
"""
|
||||
Plot the confusion matrix using seaborn and save it to a file.
|
||||
|
||||
@ -308,6 +308,7 @@ class ConfusionMatrix:
|
||||
normalize (bool): Whether to normalize the confusion matrix.
|
||||
save_dir (str): Directory where the plot will be saved.
|
||||
names (tuple): Names of classes, used as labels on the plot.
|
||||
on_plot (func): An optional callback to pass plots path and data when they are rendered.
|
||||
"""
|
||||
import seaborn as sn
|
||||
|
||||
@ -336,8 +337,11 @@ class ConfusionMatrix:
|
||||
ax.set_xlabel('True')
|
||||
ax.set_ylabel('Predicted')
|
||||
ax.set_title(title)
|
||||
fig.savefig(Path(save_dir) / f'{title.lower().replace(" ", "_")}.png', dpi=250)
|
||||
plot_fname = Path(save_dir) / f'{title.lower().replace(" ", "_")}.png'
|
||||
fig.savefig(plot_fname, dpi=250)
|
||||
plt.close(fig)
|
||||
if on_plot:
|
||||
on_plot(plot_fname)
|
||||
|
||||
def print(self):
|
||||
"""
|
||||
@ -356,7 +360,7 @@ def smooth(y, f=0.05):
|
||||
|
||||
|
||||
@plt_settings()
|
||||
def plot_pr_curve(px, py, ap, save_dir=Path('pr_curve.png'), names=()):
|
||||
def plot_pr_curve(px, py, ap, save_dir=Path('pr_curve.png'), names=(), on_plot=None):
|
||||
"""Plots a precision-recall curve."""
|
||||
fig, ax = plt.subplots(1, 1, figsize=(9, 6), tight_layout=True)
|
||||
py = np.stack(py, axis=1)
|
||||
@ -376,10 +380,12 @@ def plot_pr_curve(px, py, ap, save_dir=Path('pr_curve.png'), names=()):
|
||||
ax.set_title('Precision-Recall Curve')
|
||||
fig.savefig(save_dir, dpi=250)
|
||||
plt.close(fig)
|
||||
if on_plot:
|
||||
on_plot(save_dir)
|
||||
|
||||
|
||||
@plt_settings()
|
||||
def plot_mc_curve(px, py, save_dir=Path('mc_curve.png'), names=(), xlabel='Confidence', ylabel='Metric'):
|
||||
def plot_mc_curve(px, py, save_dir=Path('mc_curve.png'), names=(), xlabel='Confidence', ylabel='Metric', on_plot=None):
|
||||
"""Plots a metric-confidence curve."""
|
||||
fig, ax = plt.subplots(1, 1, figsize=(9, 6), tight_layout=True)
|
||||
|
||||
@ -399,6 +405,8 @@ def plot_mc_curve(px, py, save_dir=Path('mc_curve.png'), names=(), xlabel='Confi
|
||||
ax.set_title(f'{ylabel}-Confidence Curve')
|
||||
fig.savefig(save_dir, dpi=250)
|
||||
plt.close(fig)
|
||||
if on_plot:
|
||||
on_plot(save_dir)
|
||||
|
||||
|
||||
def compute_ap(recall, precision):
|
||||
@ -434,7 +442,16 @@ def compute_ap(recall, precision):
|
||||
return ap, mpre, mrec
|
||||
|
||||
|
||||
def ap_per_class(tp, conf, pred_cls, target_cls, plot=False, save_dir=Path(), names=(), eps=1e-16, prefix=''):
|
||||
def ap_per_class(tp,
|
||||
conf,
|
||||
pred_cls,
|
||||
target_cls,
|
||||
plot=False,
|
||||
on_plot=None,
|
||||
save_dir=Path(),
|
||||
names=(),
|
||||
eps=1e-16,
|
||||
prefix=''):
|
||||
"""
|
||||
Computes the average precision per class for object detection evaluation.
|
||||
|
||||
@ -444,6 +461,7 @@ def ap_per_class(tp, conf, pred_cls, target_cls, plot=False, save_dir=Path(), na
|
||||
pred_cls (np.ndarray): Array of predicted classes of the detections.
|
||||
target_cls (np.ndarray): Array of true classes of the detections.
|
||||
plot (bool, optional): Whether to plot PR curves or not. Defaults to False.
|
||||
on_plot (func, optional): A callback to pass plots path and data when they are rendered. Defaults to None.
|
||||
save_dir (Path, optional): Directory to save the PR curves. Defaults to an empty path.
|
||||
names (tuple, optional): Tuple of class names to plot PR curves. Defaults to an empty tuple.
|
||||
eps (float, optional): A small value to avoid division by zero. Defaults to 1e-16.
|
||||
@ -502,10 +520,10 @@ def ap_per_class(tp, conf, pred_cls, target_cls, plot=False, save_dir=Path(), na
|
||||
names = [v for k, v in names.items() if k in unique_classes] # list: only classes that have data
|
||||
names = dict(enumerate(names)) # to dict
|
||||
if plot:
|
||||
plot_pr_curve(px, py, ap, save_dir / f'{prefix}PR_curve.png', names)
|
||||
plot_mc_curve(px, f1, save_dir / f'{prefix}F1_curve.png', names, ylabel='F1')
|
||||
plot_mc_curve(px, p, save_dir / f'{prefix}P_curve.png', names, ylabel='Precision')
|
||||
plot_mc_curve(px, r, save_dir / f'{prefix}R_curve.png', names, ylabel='Recall')
|
||||
plot_pr_curve(px, py, ap, save_dir / f'{prefix}PR_curve.png', names, on_plot=on_plot)
|
||||
plot_mc_curve(px, f1, save_dir / f'{prefix}F1_curve.png', names, ylabel='F1', on_plot=on_plot)
|
||||
plot_mc_curve(px, p, save_dir / f'{prefix}P_curve.png', names, ylabel='Precision', on_plot=on_plot)
|
||||
plot_mc_curve(px, r, save_dir / f'{prefix}R_curve.png', names, ylabel='Recall', on_plot=on_plot)
|
||||
|
||||
i = smooth(f1.mean(0), 0.1).argmax() # max F1 index
|
||||
p, r, f1 = p[:, i], r[:, i], f1[:, i]
|
||||
@ -657,11 +675,13 @@ class DetMetrics(SimpleClass):
|
||||
Args:
|
||||
save_dir (Path): A path to the directory where the output plots will be saved. Defaults to current directory.
|
||||
plot (bool): A flag that indicates whether to plot precision-recall curves for each class. Defaults to False.
|
||||
on_plot (func): An optional callback to pass plots path and data when they are rendered. Defaults to None.
|
||||
names (tuple of str): A tuple of strings that represents the names of the classes. Defaults to an empty tuple.
|
||||
|
||||
Attributes:
|
||||
save_dir (Path): A path to the directory where the output plots will be saved.
|
||||
plot (bool): A flag that indicates whether to plot the precision-recall curves for each class.
|
||||
on_plot (func): An optional callback to pass plots path and data when they are rendered.
|
||||
names (tuple of str): A tuple of strings that represents the names of the classes.
|
||||
box (Metric): An instance of the Metric class for storing the results of the detection metrics.
|
||||
speed (dict): A dictionary for storing the execution time of different parts of the detection process.
|
||||
@ -677,9 +697,10 @@ class DetMetrics(SimpleClass):
|
||||
results_dict: Returns a dictionary that maps detection metric keys to their computed values.
|
||||
"""
|
||||
|
||||
def __init__(self, save_dir=Path('.'), plot=False, names=()) -> None:
|
||||
def __init__(self, save_dir=Path('.'), plot=False, on_plot=None, names=()) -> None:
|
||||
self.save_dir = save_dir
|
||||
self.plot = plot
|
||||
self.on_plot = on_plot
|
||||
self.names = names
|
||||
self.box = Metric()
|
||||
self.speed = {'preprocess': 0.0, 'inference': 0.0, 'loss': 0.0, 'postprocess': 0.0}
|
||||
@ -732,11 +753,13 @@ class SegmentMetrics(SimpleClass):
|
||||
Args:
|
||||
save_dir (Path): Path to the directory where the output plots should be saved. Default is the current directory.
|
||||
plot (bool): Whether to save the detection and segmentation plots. Default is False.
|
||||
on_plot (func): An optional callback to pass plots path and data when they are rendered. Defaults to None.
|
||||
names (list): List of class names. Default is an empty list.
|
||||
|
||||
Attributes:
|
||||
save_dir (Path): Path to the directory where the output plots should be saved.
|
||||
plot (bool): Whether to save the detection and segmentation plots.
|
||||
on_plot (func): An optional callback to pass plots path and data when they are rendered.
|
||||
names (list): List of class names.
|
||||
box (Metric): An instance of the Metric class to calculate box detection metrics.
|
||||
seg (Metric): An instance of the Metric class to calculate mask segmentation metrics.
|
||||
@ -752,9 +775,10 @@ class SegmentMetrics(SimpleClass):
|
||||
results_dict: Returns the dictionary containing all the detection and segmentation metrics and fitness score.
|
||||
"""
|
||||
|
||||
def __init__(self, save_dir=Path('.'), plot=False, names=()) -> None:
|
||||
def __init__(self, save_dir=Path('.'), plot=False, on_plot=None, names=()) -> None:
|
||||
self.save_dir = save_dir
|
||||
self.plot = plot
|
||||
self.on_plot = on_plot
|
||||
self.names = names
|
||||
self.box = Metric()
|
||||
self.seg = Metric()
|
||||
@ -777,6 +801,7 @@ class SegmentMetrics(SimpleClass):
|
||||
pred_cls,
|
||||
target_cls,
|
||||
plot=self.plot,
|
||||
on_plot=self.on_plot,
|
||||
save_dir=self.save_dir,
|
||||
names=self.names,
|
||||
prefix='Mask')[2:]
|
||||
@ -787,6 +812,7 @@ class SegmentMetrics(SimpleClass):
|
||||
pred_cls,
|
||||
target_cls,
|
||||
plot=self.plot,
|
||||
on_plot=self.on_plot,
|
||||
save_dir=self.save_dir,
|
||||
names=self.names,
|
||||
prefix='Box')[2:]
|
||||
@ -836,11 +862,13 @@ class PoseMetrics(SegmentMetrics):
|
||||
Args:
|
||||
save_dir (Path): Path to the directory where the output plots should be saved. Default is the current directory.
|
||||
plot (bool): Whether to save the detection and segmentation plots. Default is False.
|
||||
on_plot (func): An optional callback to pass plots path and data when they are rendered. Defaults to None.
|
||||
names (list): List of class names. Default is an empty list.
|
||||
|
||||
Attributes:
|
||||
save_dir (Path): Path to the directory where the output plots should be saved.
|
||||
plot (bool): Whether to save the detection and segmentation plots.
|
||||
on_plot (func): An optional callback to pass plots path and data when they are rendered.
|
||||
names (list): List of class names.
|
||||
box (Metric): An instance of the Metric class to calculate box detection metrics.
|
||||
pose (Metric): An instance of the Metric class to calculate mask segmentation metrics.
|
||||
@ -856,10 +884,11 @@ class PoseMetrics(SegmentMetrics):
|
||||
results_dict: Returns the dictionary containing all the detection and segmentation metrics and fitness score.
|
||||
"""
|
||||
|
||||
def __init__(self, save_dir=Path('.'), plot=False, names=()) -> None:
|
||||
def __init__(self, save_dir=Path('.'), plot=False, on_plot=None, names=()) -> None:
|
||||
super().__init__(save_dir, plot, names)
|
||||
self.save_dir = save_dir
|
||||
self.plot = plot
|
||||
self.on_plot = on_plot
|
||||
self.names = names
|
||||
self.box = Metric()
|
||||
self.pose = Metric()
|
||||
@ -887,6 +916,7 @@ class PoseMetrics(SegmentMetrics):
|
||||
pred_cls,
|
||||
target_cls,
|
||||
plot=self.plot,
|
||||
on_plot=self.on_plot,
|
||||
save_dir=self.save_dir,
|
||||
names=self.names,
|
||||
prefix='Pose')[2:]
|
||||
@ -897,6 +927,7 @@ class PoseMetrics(SegmentMetrics):
|
||||
pred_cls,
|
||||
target_cls,
|
||||
plot=self.plot,
|
||||
on_plot=self.on_plot,
|
||||
save_dir=self.save_dir,
|
||||
names=self.names,
|
||||
prefix='Box')[2:]
|
||||
|
@ -228,7 +228,7 @@ class Annotator:
|
||||
|
||||
@TryExcept() # known issue https://github.com/ultralytics/yolov5/issues/5395
|
||||
@plt_settings()
|
||||
def plot_labels(boxes, cls, names=(), save_dir=Path('')):
|
||||
def plot_labels(boxes, cls, names=(), save_dir=Path(''), on_plot=None):
|
||||
"""Save and plot image with no axis or spines."""
|
||||
import pandas as pd
|
||||
import seaborn as sn
|
||||
@ -271,8 +271,11 @@ def plot_labels(boxes, cls, names=(), save_dir=Path('')):
|
||||
for s in ['top', 'right', 'left', 'bottom']:
|
||||
ax[a].spines[s].set_visible(False)
|
||||
|
||||
plt.savefig(save_dir / 'labels.jpg', dpi=200)
|
||||
fname = save_dir / 'labels.jpg'
|
||||
plt.savefig(fname, dpi=200)
|
||||
plt.close()
|
||||
if on_plot:
|
||||
on_plot(fname)
|
||||
|
||||
|
||||
def save_one_box(xyxy, im, file=Path('im.jpg'), gain=1.02, pad=10, square=False, BGR=False, save=True):
|
||||
@ -301,7 +304,8 @@ def plot_images(images,
|
||||
kpts=np.zeros((0, 51), dtype=np.float32),
|
||||
paths=None,
|
||||
fname='images.jpg',
|
||||
names=None):
|
||||
names=None,
|
||||
on_plot=None):
|
||||
# Plot image grid with labels
|
||||
if isinstance(images, torch.Tensor):
|
||||
images = images.cpu().float().numpy()
|
||||
@ -419,10 +423,12 @@ def plot_images(images,
|
||||
im[y:y + h, x:x + w, :][mask] = im[y:y + h, x:x + w, :][mask] * 0.4 + np.array(color) * 0.6
|
||||
annotator.fromarray(im)
|
||||
annotator.im.save(fname) # save
|
||||
if on_plot:
|
||||
on_plot(fname)
|
||||
|
||||
|
||||
@plt_settings()
|
||||
def plot_results(file='path/to/results.csv', dir='', segment=False, pose=False, classify=False):
|
||||
def plot_results(file='path/to/results.csv', dir='', segment=False, pose=False, classify=False, on_plot=None):
|
||||
"""Plot training results.csv. Usage: from utils.plots import *; plot_results('path/to/results.csv')."""
|
||||
import pandas as pd
|
||||
save_dir = Path(file).parent if file else Path(dir)
|
||||
@ -456,8 +462,11 @@ def plot_results(file='path/to/results.csv', dir='', segment=False, pose=False,
|
||||
except Exception as e:
|
||||
LOGGER.warning(f'WARNING: Plotting error for {f}: {e}')
|
||||
ax[1].legend()
|
||||
fig.savefig(save_dir / 'results.png', dpi=200)
|
||||
fname = save_dir / 'results.png'
|
||||
fig.savefig(fname, dpi=200)
|
||||
plt.close()
|
||||
if on_plot:
|
||||
on_plot(fname)
|
||||
|
||||
|
||||
def output_to_target(output, max_det=300):
|
||||
|
Reference in New Issue
Block a user