ultralytics 8.0.96
TAL speed and memory improvements (#2484)
Signed-off-by: Evangelos Petrongonas <e.petrongonas@hellenicdrones.com> Co-authored-by: Evangelos Petrongonas <24351757+vpetrog@users.noreply.github.com> Co-authored-by: JF Chen <k-2feng@hotmail.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
@ -84,9 +84,15 @@ def on_pretrain_routine_start(trainer):
|
||||
|
||||
|
||||
def on_train_epoch_end(trainer):
|
||||
"""Logs debug samples for the first epoch of YOLO training."""
|
||||
if trainer.epoch == 1 and Task.current_task():
|
||||
_log_debug_samples(sorted(trainer.save_dir.glob('train_batch*.jpg')), 'Mosaic')
|
||||
task = Task.current_task()
|
||||
|
||||
if task:
|
||||
"""Logs debug samples for the first epoch of YOLO training."""
|
||||
if trainer.epoch == 1:
|
||||
_log_debug_samples(sorted(trainer.save_dir.glob('train_batch*.jpg')), 'Mosaic')
|
||||
"""Report the current training progress."""
|
||||
for k, v in trainer.validator.metrics.results_dict.items():
|
||||
task.get_logger().report_scalar('train', k, v, iteration=trainer.epoch)
|
||||
|
||||
|
||||
def on_fit_epoch_end(trainer):
|
||||
@ -119,7 +125,9 @@ def on_train_end(trainer):
|
||||
task = Task.current_task()
|
||||
if task:
|
||||
# Log final results, CM matrix + PR plots
|
||||
files = ['results.png', 'confusion_matrix.png', *(f'{x}_curve.png' for x in ('F1', 'PR', 'P', 'R'))]
|
||||
files = [
|
||||
'results.png', 'confusion_matrix.png', 'confusion_matrix_normalized.png',
|
||||
*(f'{x}_curve.png' for x in ('F1', 'PR', 'P', 'R'))]
|
||||
files = [(trainer.save_dir / f) for f in files if (trainer.save_dir / f).exists()] # filter
|
||||
for f in files:
|
||||
_log_plot(title=f.stem, plot_path=f)
|
||||
|
@ -87,7 +87,9 @@ def on_train_end(trainer):
|
||||
"""Callback function called at end of training."""
|
||||
if run:
|
||||
# Log final results, CM matrix + PR plots
|
||||
files = ['results.png', 'confusion_matrix.png', *(f'{x}_curve.png' for x in ('F1', 'PR', 'P', 'R'))]
|
||||
files = [
|
||||
'results.png', 'confusion_matrix.png', 'confusion_matrix_normalized.png',
|
||||
*(f'{x}_curve.png' for x in ('F1', 'PR', 'P', 'R'))]
|
||||
files = [(trainer.save_dir / f) for f in files if (trainer.save_dir / f).exists()] # filter
|
||||
for f in files:
|
||||
_log_plot(title=f.stem, plot_path=f)
|
||||
|
@ -321,17 +321,18 @@ class ConfusionMatrix:
|
||||
ticklabels = (list(names) + ['background']) if labels else 'auto'
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter('ignore') # suppress empty matrix RuntimeWarning: All-NaN slice encountered
|
||||
sn.heatmap(array,
|
||||
ax=ax,
|
||||
annot=nc < 30,
|
||||
annot_kws={
|
||||
'size': 8},
|
||||
cmap='Blues',
|
||||
fmt='.2f',
|
||||
square=True,
|
||||
vmin=0.0,
|
||||
xticklabels=ticklabels,
|
||||
yticklabels=ticklabels).set_facecolor((1, 1, 1))
|
||||
sn.heatmap(
|
||||
array,
|
||||
ax=ax,
|
||||
annot=nc < 30,
|
||||
annot_kws={
|
||||
'size': 8},
|
||||
cmap='Blues',
|
||||
fmt='.2f' if normalize else '%d', # float if normalize else integer
|
||||
square=True,
|
||||
vmin=0.0,
|
||||
xticklabels=ticklabels,
|
||||
yticklabels=ticklabels).set_facecolor((1, 1, 1))
|
||||
title = 'Confusion Matrix' + ' Normalized' * normalize
|
||||
ax.set_xlabel('True')
|
||||
ax.set_ylabel('Predicted')
|
||||
|
@ -2,7 +2,6 @@
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from .checks import check_version
|
||||
from .metrics import bbox_iou
|
||||
@ -44,9 +43,11 @@ def select_highest_overlaps(mask_pos, overlaps, n_max_boxes):
|
||||
if fg_mask.max() > 1: # one anchor is assigned to multiple gt_bboxes
|
||||
mask_multi_gts = (fg_mask.unsqueeze(1) > 1).repeat([1, n_max_boxes, 1]) # (b, n_max_boxes, h*w)
|
||||
max_overlaps_idx = overlaps.argmax(1) # (b, h*w)
|
||||
is_max_overlaps = F.one_hot(max_overlaps_idx, n_max_boxes) # (b, h*w, n_max_boxes)
|
||||
is_max_overlaps = is_max_overlaps.permute(0, 2, 1).to(overlaps.dtype) # (b, n_max_boxes, h*w)
|
||||
mask_pos = torch.where(mask_multi_gts, is_max_overlaps, mask_pos) # (b, n_max_boxes, h*w)
|
||||
|
||||
is_max_overlaps = torch.zeros(mask_pos.shape, dtype=mask_pos.dtype, device=mask_pos.device)
|
||||
is_max_overlaps.scatter_(1, max_overlaps_idx.unsqueeze(1), 1)
|
||||
|
||||
mask_pos = torch.where(mask_multi_gts, is_max_overlaps, mask_pos).float() # (b, n_max_boxes, h*w)
|
||||
fg_mask = mask_pos.sum(-2)
|
||||
# Find each grid serve which gt(index)
|
||||
target_gt_idx = mask_pos.argmax(-2) # (b, h*w)
|
||||
@ -175,21 +176,23 @@ class TaskAlignedAssigner(nn.Module):
|
||||
(Tensor): A tensor of shape (b, max_num_obj, h*w) containing the selected top-k candidates.
|
||||
"""
|
||||
|
||||
num_anchors = metrics.shape[-1] # h*w
|
||||
# (b, max_num_obj, topk)
|
||||
topk_metrics, topk_idxs = torch.topk(metrics, self.topk, dim=-1, largest=largest)
|
||||
if topk_mask is None:
|
||||
topk_mask = (topk_metrics.max(-1, keepdim=True) > self.eps).tile([1, 1, self.topk])
|
||||
topk_mask = (topk_metrics.max(-1, keepdim=True)[0] > self.eps).expand_as(topk_idxs)
|
||||
# (b, max_num_obj, topk)
|
||||
topk_idxs[~topk_mask] = 0
|
||||
topk_idxs.masked_fill_(~topk_mask, 0)
|
||||
|
||||
# (b, max_num_obj, topk, h*w) -> (b, max_num_obj, h*w)
|
||||
is_in_topk = torch.zeros(metrics.shape, dtype=torch.long, device=metrics.device)
|
||||
for it in range(self.topk):
|
||||
is_in_topk += F.one_hot(topk_idxs[:, :, it], num_anchors)
|
||||
# is_in_topk = F.one_hot(topk_idxs, num_anchors).sum(-2)
|
||||
count_tensor = torch.zeros(metrics.shape, dtype=torch.int8, device=topk_idxs.device)
|
||||
ones = torch.ones_like(topk_idxs[:, :, :1], dtype=torch.int8, device=topk_idxs.device)
|
||||
for k in range(self.topk):
|
||||
# Expand topk_idxs for each value of k and add 1 at the specified positions
|
||||
count_tensor.scatter_add_(-1, topk_idxs[:, :, k:k + 1], ones)
|
||||
# filter invalid bboxes
|
||||
is_in_topk = torch.where(is_in_topk > 1, 0, is_in_topk)
|
||||
return is_in_topk.to(metrics.dtype)
|
||||
count_tensor.masked_fill_(count_tensor > 1, 0)
|
||||
|
||||
return count_tensor.to(metrics.dtype)
|
||||
|
||||
def get_targets(self, gt_labels, gt_bboxes, target_gt_idx, fg_mask):
|
||||
"""
|
||||
@ -226,7 +229,13 @@ class TaskAlignedAssigner(nn.Module):
|
||||
|
||||
# Assigned target scores
|
||||
target_labels.clamp_(0)
|
||||
target_scores = F.one_hot(target_labels, self.num_classes) # (b, h*w, 80)
|
||||
|
||||
# 10x faster than F.one_hot()
|
||||
target_scores = torch.zeros((target_labels.shape[0], target_labels.shape[1], self.num_classes),
|
||||
dtype=torch.int64,
|
||||
device=target_labels.device) # (b, h*w, 80)
|
||||
target_scores.scatter_(2, target_labels.unsqueeze(-1), 1)
|
||||
|
||||
fg_scores_mask = fg_mask[:, :, None].repeat(1, 1, self.num_classes) # (b, h*w, 80)
|
||||
target_scores = torch.where(fg_scores_mask > 0, target_scores, 0)
|
||||
|
||||
|
Reference in New Issue
Block a user