ultralytics 8.0.96 TAL speed and memory improvements (#2484)

Signed-off-by: Evangelos Petrongonas <e.petrongonas@hellenicdrones.com>
Co-authored-by: Evangelos Petrongonas <24351757+vpetrog@users.noreply.github.com>
Co-authored-by: JF Chen <k-2feng@hotmail.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
Glenn Jocher
2023-05-08 23:41:27 +02:00
committed by GitHub
parent e21428ca4e
commit 6ee3a9a74b
13 changed files with 163 additions and 53 deletions

View File

@ -1,6 +1,6 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
__version__ = '8.0.94'
__version__ = '8.0.96'
from ultralytics.hub import start
from ultralytics.vit.sam import SAM

View File

@ -84,9 +84,15 @@ def on_pretrain_routine_start(trainer):
def on_train_epoch_end(trainer):
"""Logs debug samples for the first epoch of YOLO training."""
if trainer.epoch == 1 and Task.current_task():
_log_debug_samples(sorted(trainer.save_dir.glob('train_batch*.jpg')), 'Mosaic')
task = Task.current_task()
if task:
"""Logs debug samples for the first epoch of YOLO training."""
if trainer.epoch == 1:
_log_debug_samples(sorted(trainer.save_dir.glob('train_batch*.jpg')), 'Mosaic')
"""Report the current training progress."""
for k, v in trainer.validator.metrics.results_dict.items():
task.get_logger().report_scalar('train', k, v, iteration=trainer.epoch)
def on_fit_epoch_end(trainer):
@ -119,7 +125,9 @@ def on_train_end(trainer):
task = Task.current_task()
if task:
# Log final results, CM matrix + PR plots
files = ['results.png', 'confusion_matrix.png', *(f'{x}_curve.png' for x in ('F1', 'PR', 'P', 'R'))]
files = [
'results.png', 'confusion_matrix.png', 'confusion_matrix_normalized.png',
*(f'{x}_curve.png' for x in ('F1', 'PR', 'P', 'R'))]
files = [(trainer.save_dir / f) for f in files if (trainer.save_dir / f).exists()] # filter
for f in files:
_log_plot(title=f.stem, plot_path=f)

View File

@ -87,7 +87,9 @@ def on_train_end(trainer):
"""Callback function called at end of training."""
if run:
# Log final results, CM matrix + PR plots
files = ['results.png', 'confusion_matrix.png', *(f'{x}_curve.png' for x in ('F1', 'PR', 'P', 'R'))]
files = [
'results.png', 'confusion_matrix.png', 'confusion_matrix_normalized.png',
*(f'{x}_curve.png' for x in ('F1', 'PR', 'P', 'R'))]
files = [(trainer.save_dir / f) for f in files if (trainer.save_dir / f).exists()] # filter
for f in files:
_log_plot(title=f.stem, plot_path=f)

View File

@ -321,17 +321,18 @@ class ConfusionMatrix:
ticklabels = (list(names) + ['background']) if labels else 'auto'
with warnings.catch_warnings():
warnings.simplefilter('ignore') # suppress empty matrix RuntimeWarning: All-NaN slice encountered
sn.heatmap(array,
ax=ax,
annot=nc < 30,
annot_kws={
'size': 8},
cmap='Blues',
fmt='.2f',
square=True,
vmin=0.0,
xticklabels=ticklabels,
yticklabels=ticklabels).set_facecolor((1, 1, 1))
sn.heatmap(
array,
ax=ax,
annot=nc < 30,
annot_kws={
'size': 8},
cmap='Blues',
fmt='.2f' if normalize else '%d', # float if normalize else integer
square=True,
vmin=0.0,
xticklabels=ticklabels,
yticklabels=ticklabels).set_facecolor((1, 1, 1))
title = 'Confusion Matrix' + ' Normalized' * normalize
ax.set_xlabel('True')
ax.set_ylabel('Predicted')

View File

@ -2,7 +2,6 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from .checks import check_version
from .metrics import bbox_iou
@ -44,9 +43,11 @@ def select_highest_overlaps(mask_pos, overlaps, n_max_boxes):
if fg_mask.max() > 1: # one anchor is assigned to multiple gt_bboxes
mask_multi_gts = (fg_mask.unsqueeze(1) > 1).repeat([1, n_max_boxes, 1]) # (b, n_max_boxes, h*w)
max_overlaps_idx = overlaps.argmax(1) # (b, h*w)
is_max_overlaps = F.one_hot(max_overlaps_idx, n_max_boxes) # (b, h*w, n_max_boxes)
is_max_overlaps = is_max_overlaps.permute(0, 2, 1).to(overlaps.dtype) # (b, n_max_boxes, h*w)
mask_pos = torch.where(mask_multi_gts, is_max_overlaps, mask_pos) # (b, n_max_boxes, h*w)
is_max_overlaps = torch.zeros(mask_pos.shape, dtype=mask_pos.dtype, device=mask_pos.device)
is_max_overlaps.scatter_(1, max_overlaps_idx.unsqueeze(1), 1)
mask_pos = torch.where(mask_multi_gts, is_max_overlaps, mask_pos).float() # (b, n_max_boxes, h*w)
fg_mask = mask_pos.sum(-2)
# Find each grid serve which gt(index)
target_gt_idx = mask_pos.argmax(-2) # (b, h*w)
@ -175,21 +176,23 @@ class TaskAlignedAssigner(nn.Module):
(Tensor): A tensor of shape (b, max_num_obj, h*w) containing the selected top-k candidates.
"""
num_anchors = metrics.shape[-1] # h*w
# (b, max_num_obj, topk)
topk_metrics, topk_idxs = torch.topk(metrics, self.topk, dim=-1, largest=largest)
if topk_mask is None:
topk_mask = (topk_metrics.max(-1, keepdim=True) > self.eps).tile([1, 1, self.topk])
topk_mask = (topk_metrics.max(-1, keepdim=True)[0] > self.eps).expand_as(topk_idxs)
# (b, max_num_obj, topk)
topk_idxs[~topk_mask] = 0
topk_idxs.masked_fill_(~topk_mask, 0)
# (b, max_num_obj, topk, h*w) -> (b, max_num_obj, h*w)
is_in_topk = torch.zeros(metrics.shape, dtype=torch.long, device=metrics.device)
for it in range(self.topk):
is_in_topk += F.one_hot(topk_idxs[:, :, it], num_anchors)
# is_in_topk = F.one_hot(topk_idxs, num_anchors).sum(-2)
count_tensor = torch.zeros(metrics.shape, dtype=torch.int8, device=topk_idxs.device)
ones = torch.ones_like(topk_idxs[:, :, :1], dtype=torch.int8, device=topk_idxs.device)
for k in range(self.topk):
# Expand topk_idxs for each value of k and add 1 at the specified positions
count_tensor.scatter_add_(-1, topk_idxs[:, :, k:k + 1], ones)
# filter invalid bboxes
is_in_topk = torch.where(is_in_topk > 1, 0, is_in_topk)
return is_in_topk.to(metrics.dtype)
count_tensor.masked_fill_(count_tensor > 1, 0)
return count_tensor.to(metrics.dtype)
def get_targets(self, gt_labels, gt_bboxes, target_gt_idx, fg_mask):
"""
@ -226,7 +229,13 @@ class TaskAlignedAssigner(nn.Module):
# Assigned target scores
target_labels.clamp_(0)
target_scores = F.one_hot(target_labels, self.num_classes) # (b, h*w, 80)
# 10x faster than F.one_hot()
target_scores = torch.zeros((target_labels.shape[0], target_labels.shape[1], self.num_classes),
dtype=torch.int64,
device=target_labels.device) # (b, h*w, 80)
target_scores.scatter_(2, target_labels.unsqueeze(-1), 1)
fg_scores_mask = fg_mask[:, :, None].repeat(1, 1, self.num_classes) # (b, h*w, 80)
target_scores = torch.where(fg_scores_mask > 0, target_scores, 0)