ultralytics 8.0.105 classification hyp fix and new onplot callbacks (#2684)

Co-authored-by: ayush chaurasia <ayush.chaurarsia@gmail.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Ivan Shcheklein <shcheklein@gmail.com>
This commit is contained in:
Glenn Jocher
2023-05-17 19:10:20 +02:00
committed by GitHub
parent b1119d512e
commit 23fc50641c
92 changed files with 378 additions and 206 deletions

View File

@ -300,7 +300,7 @@ class ConfusionMatrix:
@TryExcept('WARNING ⚠️ ConfusionMatrix plot failure')
@plt_settings()
def plot(self, normalize=True, save_dir='', names=()):
def plot(self, normalize=True, save_dir='', names=(), on_plot=None):
"""
Plot the confusion matrix using seaborn and save it to a file.
@ -308,6 +308,7 @@ class ConfusionMatrix:
normalize (bool): Whether to normalize the confusion matrix.
save_dir (str): Directory where the plot will be saved.
names (tuple): Names of classes, used as labels on the plot.
on_plot (func): An optional callback to pass plots path and data when they are rendered.
"""
import seaborn as sn
@ -336,8 +337,11 @@ class ConfusionMatrix:
ax.set_xlabel('True')
ax.set_ylabel('Predicted')
ax.set_title(title)
fig.savefig(Path(save_dir) / f'{title.lower().replace(" ", "_")}.png', dpi=250)
plot_fname = Path(save_dir) / f'{title.lower().replace(" ", "_")}.png'
fig.savefig(plot_fname, dpi=250)
plt.close(fig)
if on_plot:
on_plot(plot_fname)
def print(self):
"""
@ -356,7 +360,7 @@ def smooth(y, f=0.05):
@plt_settings()
def plot_pr_curve(px, py, ap, save_dir=Path('pr_curve.png'), names=()):
def plot_pr_curve(px, py, ap, save_dir=Path('pr_curve.png'), names=(), on_plot=None):
"""Plots a precision-recall curve."""
fig, ax = plt.subplots(1, 1, figsize=(9, 6), tight_layout=True)
py = np.stack(py, axis=1)
@ -376,10 +380,12 @@ def plot_pr_curve(px, py, ap, save_dir=Path('pr_curve.png'), names=()):
ax.set_title('Precision-Recall Curve')
fig.savefig(save_dir, dpi=250)
plt.close(fig)
if on_plot:
on_plot(save_dir)
@plt_settings()
def plot_mc_curve(px, py, save_dir=Path('mc_curve.png'), names=(), xlabel='Confidence', ylabel='Metric'):
def plot_mc_curve(px, py, save_dir=Path('mc_curve.png'), names=(), xlabel='Confidence', ylabel='Metric', on_plot=None):
"""Plots a metric-confidence curve."""
fig, ax = plt.subplots(1, 1, figsize=(9, 6), tight_layout=True)
@ -399,6 +405,8 @@ def plot_mc_curve(px, py, save_dir=Path('mc_curve.png'), names=(), xlabel='Confi
ax.set_title(f'{ylabel}-Confidence Curve')
fig.savefig(save_dir, dpi=250)
plt.close(fig)
if on_plot:
on_plot(save_dir)
def compute_ap(recall, precision):
@ -434,7 +442,16 @@ def compute_ap(recall, precision):
return ap, mpre, mrec
def ap_per_class(tp, conf, pred_cls, target_cls, plot=False, save_dir=Path(), names=(), eps=1e-16, prefix=''):
def ap_per_class(tp,
conf,
pred_cls,
target_cls,
plot=False,
on_plot=None,
save_dir=Path(),
names=(),
eps=1e-16,
prefix=''):
"""
Computes the average precision per class for object detection evaluation.
@ -444,6 +461,7 @@ def ap_per_class(tp, conf, pred_cls, target_cls, plot=False, save_dir=Path(), na
pred_cls (np.ndarray): Array of predicted classes of the detections.
target_cls (np.ndarray): Array of true classes of the detections.
plot (bool, optional): Whether to plot PR curves or not. Defaults to False.
on_plot (func, optional): A callback to pass plots path and data when they are rendered. Defaults to None.
save_dir (Path, optional): Directory to save the PR curves. Defaults to an empty path.
names (tuple, optional): Tuple of class names to plot PR curves. Defaults to an empty tuple.
eps (float, optional): A small value to avoid division by zero. Defaults to 1e-16.
@ -502,10 +520,10 @@ def ap_per_class(tp, conf, pred_cls, target_cls, plot=False, save_dir=Path(), na
names = [v for k, v in names.items() if k in unique_classes] # list: only classes that have data
names = dict(enumerate(names)) # to dict
if plot:
plot_pr_curve(px, py, ap, save_dir / f'{prefix}PR_curve.png', names)
plot_mc_curve(px, f1, save_dir / f'{prefix}F1_curve.png', names, ylabel='F1')
plot_mc_curve(px, p, save_dir / f'{prefix}P_curve.png', names, ylabel='Precision')
plot_mc_curve(px, r, save_dir / f'{prefix}R_curve.png', names, ylabel='Recall')
plot_pr_curve(px, py, ap, save_dir / f'{prefix}PR_curve.png', names, on_plot=on_plot)
plot_mc_curve(px, f1, save_dir / f'{prefix}F1_curve.png', names, ylabel='F1', on_plot=on_plot)
plot_mc_curve(px, p, save_dir / f'{prefix}P_curve.png', names, ylabel='Precision', on_plot=on_plot)
plot_mc_curve(px, r, save_dir / f'{prefix}R_curve.png', names, ylabel='Recall', on_plot=on_plot)
i = smooth(f1.mean(0), 0.1).argmax() # max F1 index
p, r, f1 = p[:, i], r[:, i], f1[:, i]
@ -657,11 +675,13 @@ class DetMetrics(SimpleClass):
Args:
save_dir (Path): A path to the directory where the output plots will be saved. Defaults to current directory.
plot (bool): A flag that indicates whether to plot precision-recall curves for each class. Defaults to False.
on_plot (func): An optional callback to pass plots path and data when they are rendered. Defaults to None.
names (tuple of str): A tuple of strings that represents the names of the classes. Defaults to an empty tuple.
Attributes:
save_dir (Path): A path to the directory where the output plots will be saved.
plot (bool): A flag that indicates whether to plot the precision-recall curves for each class.
on_plot (func): An optional callback to pass plots path and data when they are rendered.
names (tuple of str): A tuple of strings that represents the names of the classes.
box (Metric): An instance of the Metric class for storing the results of the detection metrics.
speed (dict): A dictionary for storing the execution time of different parts of the detection process.
@ -677,9 +697,10 @@ class DetMetrics(SimpleClass):
results_dict: Returns a dictionary that maps detection metric keys to their computed values.
"""
def __init__(self, save_dir=Path('.'), plot=False, names=()) -> None:
def __init__(self, save_dir=Path('.'), plot=False, on_plot=None, names=()) -> None:
self.save_dir = save_dir
self.plot = plot
self.on_plot = on_plot
self.names = names
self.box = Metric()
self.speed = {'preprocess': 0.0, 'inference': 0.0, 'loss': 0.0, 'postprocess': 0.0}
@ -732,11 +753,13 @@ class SegmentMetrics(SimpleClass):
Args:
save_dir (Path): Path to the directory where the output plots should be saved. Default is the current directory.
plot (bool): Whether to save the detection and segmentation plots. Default is False.
on_plot (func): An optional callback to pass plots path and data when they are rendered. Defaults to None.
names (list): List of class names. Default is an empty list.
Attributes:
save_dir (Path): Path to the directory where the output plots should be saved.
plot (bool): Whether to save the detection and segmentation plots.
on_plot (func): An optional callback to pass plots path and data when they are rendered.
names (list): List of class names.
box (Metric): An instance of the Metric class to calculate box detection metrics.
seg (Metric): An instance of the Metric class to calculate mask segmentation metrics.
@ -752,9 +775,10 @@ class SegmentMetrics(SimpleClass):
results_dict: Returns the dictionary containing all the detection and segmentation metrics and fitness score.
"""
def __init__(self, save_dir=Path('.'), plot=False, names=()) -> None:
def __init__(self, save_dir=Path('.'), plot=False, on_plot=None, names=()) -> None:
self.save_dir = save_dir
self.plot = plot
self.on_plot = on_plot
self.names = names
self.box = Metric()
self.seg = Metric()
@ -777,6 +801,7 @@ class SegmentMetrics(SimpleClass):
pred_cls,
target_cls,
plot=self.plot,
on_plot=self.on_plot,
save_dir=self.save_dir,
names=self.names,
prefix='Mask')[2:]
@ -787,6 +812,7 @@ class SegmentMetrics(SimpleClass):
pred_cls,
target_cls,
plot=self.plot,
on_plot=self.on_plot,
save_dir=self.save_dir,
names=self.names,
prefix='Box')[2:]
@ -836,11 +862,13 @@ class PoseMetrics(SegmentMetrics):
Args:
save_dir (Path): Path to the directory where the output plots should be saved. Default is the current directory.
plot (bool): Whether to save the detection and segmentation plots. Default is False.
on_plot (func): An optional callback to pass plots path and data when they are rendered. Defaults to None.
names (list): List of class names. Default is an empty list.
Attributes:
save_dir (Path): Path to the directory where the output plots should be saved.
plot (bool): Whether to save the detection and segmentation plots.
on_plot (func): An optional callback to pass plots path and data when they are rendered.
names (list): List of class names.
box (Metric): An instance of the Metric class to calculate box detection metrics.
pose (Metric): An instance of the Metric class to calculate mask segmentation metrics.
@ -856,10 +884,11 @@ class PoseMetrics(SegmentMetrics):
results_dict: Returns the dictionary containing all the detection and segmentation metrics and fitness score.
"""
def __init__(self, save_dir=Path('.'), plot=False, names=()) -> None:
def __init__(self, save_dir=Path('.'), plot=False, on_plot=None, names=()) -> None:
super().__init__(save_dir, plot, names)
self.save_dir = save_dir
self.plot = plot
self.on_plot = on_plot
self.names = names
self.box = Metric()
self.pose = Metric()
@ -887,6 +916,7 @@ class PoseMetrics(SegmentMetrics):
pred_cls,
target_cls,
plot=self.plot,
on_plot=self.on_plot,
save_dir=self.save_dir,
names=self.names,
prefix='Pose')[2:]
@ -897,6 +927,7 @@ class PoseMetrics(SegmentMetrics):
pred_cls,
target_cls,
plot=self.plot,
on_plot=self.on_plot,
save_dir=self.save_dir,
names=self.names,
prefix='Box')[2:]

View File

@ -228,7 +228,7 @@ class Annotator:
@TryExcept() # known issue https://github.com/ultralytics/yolov5/issues/5395
@plt_settings()
def plot_labels(boxes, cls, names=(), save_dir=Path('')):
def plot_labels(boxes, cls, names=(), save_dir=Path(''), on_plot=None):
"""Save and plot image with no axis or spines."""
import pandas as pd
import seaborn as sn
@ -271,8 +271,11 @@ def plot_labels(boxes, cls, names=(), save_dir=Path('')):
for s in ['top', 'right', 'left', 'bottom']:
ax[a].spines[s].set_visible(False)
plt.savefig(save_dir / 'labels.jpg', dpi=200)
fname = save_dir / 'labels.jpg'
plt.savefig(fname, dpi=200)
plt.close()
if on_plot:
on_plot(fname)
def save_one_box(xyxy, im, file=Path('im.jpg'), gain=1.02, pad=10, square=False, BGR=False, save=True):
@ -301,7 +304,8 @@ def plot_images(images,
kpts=np.zeros((0, 51), dtype=np.float32),
paths=None,
fname='images.jpg',
names=None):
names=None,
on_plot=None):
# Plot image grid with labels
if isinstance(images, torch.Tensor):
images = images.cpu().float().numpy()
@ -419,10 +423,12 @@ def plot_images(images,
im[y:y + h, x:x + w, :][mask] = im[y:y + h, x:x + w, :][mask] * 0.4 + np.array(color) * 0.6
annotator.fromarray(im)
annotator.im.save(fname) # save
if on_plot:
on_plot(fname)
@plt_settings()
def plot_results(file='path/to/results.csv', dir='', segment=False, pose=False, classify=False):
def plot_results(file='path/to/results.csv', dir='', segment=False, pose=False, classify=False, on_plot=None):
"""Plot training results.csv. Usage: from utils.plots import *; plot_results('path/to/results.csv')."""
import pandas as pd
save_dir = Path(file).parent if file else Path(dir)
@ -456,8 +462,11 @@ def plot_results(file='path/to/results.csv', dir='', segment=False, pose=False,
except Exception as e:
LOGGER.warning(f'WARNING: Plotting error for {f}: {e}')
ax[1].legend()
fig.savefig(save_dir / 'results.png', dpi=200)
fname = save_dir / 'results.png'
fig.savefig(fname, dpi=200)
plt.close()
if on_plot:
on_plot(fname)
def output_to_target(output, max_det=300):