You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
231 lines
10 KiB
231 lines
10 KiB
1 year ago
|
# TODO: license
|
||
|
|
||
|
import torch
|
||
|
import torch.nn as nn
|
||
|
import torch.nn.functional as F
|
||
|
from scipy.optimize import linear_sum_assignment
|
||
|
|
||
|
from ultralytics.yolo.utils.metrics import bbox_iou
|
||
|
from ultralytics.yolo.utils.ops import xywh2xyxy, xyxy2xywh
|
||
|
|
||
|
|
||
|
class HungarianMatcher(nn.Module):
    """
    One-to-one assignment between predicted queries and ground-truth objects.

    The matching cost is a weighted sum of a classification cost, an L1 box cost and a GIoU box cost, plus
    (optionally) a mask BCE cost and a dice cost evaluated on randomly sampled points. The assignment of each
    image is solved independently with `scipy.optimize.linear_sum_assignment`.
    """

    def __init__(self, cost_gain=None, use_fl=True, with_mask=False, num_sample_points=12544, alpha=0.25, gamma=2.0):
        """
        Args:
            cost_gain (dict | None): Weight of each cost term, keyed by 'class', 'bbox', 'giou', 'mask' and 'dice'.
                When None, {'class': 1, 'bbox': 5, 'giou': 2, 'mask': 1, 'dice': 1} is used.
            use_fl (bool): If True, use a focal-style classification cost on sigmoid scores; otherwise use the
                negated softmax probability.
            with_mask (bool): If True, add the mask and dice costs computed by `_cost_mask`.
            num_sample_points (int): Number of random points at which masks are sampled for the mask costs.
            alpha (float): Focal-cost balancing factor (only used when `use_fl` is True).
            gamma (float): Focal-cost focusing exponent (only used when `use_fl` is True).
        """
        super().__init__()
        if cost_gain is None:
            cost_gain = {'class': 1, 'bbox': 5, 'giou': 2, 'mask': 1, 'dice': 1}
        self.cost_gain = cost_gain
        self.use_fl = use_fl
        self.with_mask = with_mask
        self.num_sample_points = num_sample_points
        self.alpha = alpha
        self.gamma = gamma

    def forward(self, pred_bboxes, pred_scores, gt_bboxes, gt_cls, gt_groups, masks=None, gt_mask=None):
        """
        Compute the optimal query-to-target assignment for every image in the batch.

        Args:
            pred_bboxes (torch.Tensor): Predicted boxes, [b, query, 4], in the xywh layout passed to
                `bbox_iou(..., xywh=True)`.
            pred_scores (torch.Tensor): Predicted class logits, [b, query, num_classes].
            gt_bboxes (torch.Tensor): Ground-truth boxes of the whole batch concatenated, [num_gts, 4].
            gt_cls (torch.Tensor): Ground-truth class indices, [num_gts, ].
            gt_groups (List[int]): A batch-size-length list with the number of gts of each image.
            masks (torch.Tensor | None): Predicted mask logits, [b, query, h, w]; required when `with_mask` is True.
            gt_mask (List[torch.Tensor] | None): Per-image gt masks, list of [n, H, W]; required when `with_mask`
                is True.

        Returns:
            (List[Tuple[torch.Tensor, torch.Tensor]]): One (index_i, index_j) pair of int64 tensors per image:
                - index_i holds the indices of the selected predictions (in order);
                - index_j holds the indices of the matched targets (in order), offset into the
                  batch-concatenated `gt_bboxes` / `gt_cls`.
                For each image, len(index_i) == len(index_j) == min(num_queries, num_gts_of_that_image).
        """
        bs, nq, nc = pred_scores.shape

        if sum(gt_groups) == 0:
            # int64 so the (empty) indices can be used directly for tensor indexing
            return [(torch.tensor([], dtype=torch.long), torch.tensor([], dtype=torch.long)) for _ in range(bs)]

        # Flatten the batch so all cost matrices are computed at once
        pred_scores = pred_scores.detach().view(-1, nc)  # [bs * nq, nc]
        pred_scores = torch.sigmoid(pred_scores) if self.use_fl else F.softmax(pred_scores, dim=-1)
        pred_bboxes = pred_bboxes.detach().view(-1, 4)  # [bs * nq, 4]

        # Classification cost: score of every query for every gt's class, [bs * nq, num_gts]
        pred_scores = pred_scores[:, gt_cls]
        if self.use_fl:
            neg_cost_class = (1 - self.alpha) * (pred_scores ** self.gamma) * (-(1 - pred_scores + 1e-8).log())
            pos_cost_class = self.alpha * ((1 - pred_scores) ** self.gamma) * (-(pred_scores + 1e-8).log())
            cost_class = pos_cost_class - neg_cost_class
        else:
            cost_class = -pred_scores

        # L1 cost between boxes, [bs * nq, num_gts]
        cost_bbox = (pred_bboxes.unsqueeze(1) - gt_bboxes.unsqueeze(0)).abs().sum(-1)

        # GIoU cost between boxes, [bs * nq, num_gts]
        cost_giou = 1.0 - bbox_iou(pred_bboxes.unsqueeze(1), gt_bboxes.unsqueeze(0), xywh=True, GIoU=True).squeeze(-1)

        # Final cost matrix
        C = (self.cost_gain['class'] * cost_class
             + self.cost_gain['bbox'] * cost_bbox
             + self.cost_gain['giou'] * cost_giou)
        # Mask cost and dice cost
        if self.with_mask:
            C += self._cost_mask(bs, gt_groups, masks, gt_mask)

        # linear_sum_assignment raises ValueError on NaN/inf entries (e.g. produced by degenerate boxes)
        C[C.isnan() | C.isinf()] = 0.0

        # Solve each image independently: its own queries against its own slice of the gts
        C = C.view(bs, nq, -1).cpu()
        indices = [linear_sum_assignment(c[i]) for i, c in enumerate(C.split(gt_groups, -1))]
        # Offset of each image's first gt inside the batch-concatenated gts
        gt_offsets = torch.as_tensor([0, *gt_groups[:-1]]).cumsum_(0)
        # (idx for queries, idx for gt)
        return [(torch.tensor(i, dtype=torch.long), torch.tensor(j, dtype=torch.long) + gt_offsets[k])
                for k, (i, j) in enumerate(indices)]

    def _cost_mask(self, bs, num_gts, masks=None, gt_mask=None):
        """
        Weighted mask (BCE) + dice matching cost evaluated on randomly sampled points.

        Args:
            bs (int): Batch size.
            num_gts (List[int]): Number of gts of each image.
            masks (torch.Tensor): Predicted mask logits, [bs, nq, h, w].
            gt_mask (List[torch.Tensor]): Per-image gt masks, each [n_i, H, W].

        Returns:
            (torch.Tensor): Cost matrix of shape [bs * nq, sum(num_gts)].
        """
        assert masks is not None and gt_mask is not None, 'Make sure the input has `mask` and `gt_mask`'
        # All masks share the same set of points for efficient matching. The points must live on the same
        # device/dtype as the masks, because F.grid_sample requires input and grid to agree.
        sample_points = torch.rand([bs, 1, self.num_sample_points, 2], device=masks.device, dtype=masks.dtype)
        sample_points = 2.0 * sample_points - 1.0  # map [0, 1) onto grid_sample's [-1, 1) coordinate range

        out_mask = F.grid_sample(masks.detach(), sample_points, align_corners=False).squeeze(-2)  # [bs, nq, P]
        out_mask = out_mask.flatten(0, 1)  # [bs * nq, P]

        # Match device/dtype of the grid (also turns bool/uint8 masks into floats grid_sample accepts)
        tgt_mask = torch.cat(gt_mask).unsqueeze(1).to(sample_points)  # [num_gts, 1, H, W]
        # Repeat each image's points once per gt of that image; images without gts contribute nothing
        tgt_points = torch.cat([a.repeat(b, 1, 1, 1) for a, b in zip(sample_points, num_gts) if b > 0])
        tgt_mask = F.grid_sample(tgt_mask, tgt_points, align_corners=False).flatten(1)  # [num_gts, P]

        with torch.cuda.amp.autocast(False):
            # Binary cross-entropy cost
            pos_cost_mask = F.binary_cross_entropy_with_logits(out_mask, torch.ones_like(out_mask), reduction='none')
            neg_cost_mask = F.binary_cross_entropy_with_logits(out_mask, torch.zeros_like(out_mask), reduction='none')
            cost_mask = torch.matmul(pos_cost_mask, tgt_mask.T) + torch.matmul(neg_cost_mask, 1 - tgt_mask.T)
            cost_mask /= self.num_sample_points

            # Dice cost (+1 smoothing keeps empty masks finite)
            out_mask = F.sigmoid(out_mask)
            numerator = 2 * torch.matmul(out_mask, tgt_mask.T)
            denominator = out_mask.sum(-1, keepdim=True) + tgt_mask.sum(-1).unsqueeze(0)
            cost_dice = 1 - (numerator + 1) / (denominator + 1)

        C = self.cost_gain['mask'] * cost_mask + self.cost_gain['dice'] * cost_dice
        return C
|
||
|
|
||
|
|
||
|
def get_cdn_group(batch,
                  num_classes,
                  num_queries,
                  class_embed,
                  num_dn=100,
                  cls_noise_ratio=0.5,
                  box_noise_scale=1.0,
                  training=False):
    """
    Get contrastive denoising (CDN) training group.

    Every gt is copied 2 * num_group times: the first num_group copies are positive queries, the remaining
    num_group copies are negative queries with larger box noise.

    Args:
        batch (dict): A dict that includes:
            'cls' (torch.Tensor): Integer class indices of all gts in the batch, [num_gts, ].
            'bboxes' (torch.Tensor): Boxes of all gts in the batch, [num_gts, 4], in xywh format, presumably
                normalized to [0, 1] (noised boxes are clipped to that range).
            'batch_idx' (torch.Tensor): Integer image index of every gt, [num_gts, ].
            'gt_groups' (List[int]): A batch-size-length list with the number of gts of each image.
        num_classes (int): Number of classes (range of the random replacement labels).
        num_queries (int): Number of matching queries.
        class_embed (torch.Tensor): Embedding weights whose rows are indexed by class id.
        num_dn (int): Requested number of denoising queries; determines the number of groups.
        cls_noise_ratio (float): Each label is replaced by a random class with probability cls_noise_ratio / 2.
        box_noise_scale (float): Noise scale for bbox, relative to half the box width/height.
        training (bool): If it's training or not.

    Returns:
        (Tuple): (padding_cls, padding_bbox, attn_mask, dn_meta), where, with
            total_dn = 2 * num_group * max(gt_groups):
            padding_cls (torch.Tensor): Noised class embeddings of the denoising queries, [bs, total_dn, embed_dim].
            padding_bbox (torch.Tensor): Denoising boxes in inverse-sigmoid space, [bs, total_dn, 4].
            attn_mask (torch.Tensor): Bool mask of shape [total_dn + num_queries, total_dn + num_queries];
                True marks a query/key pair that must not attend.
            dn_meta (dict): 'dn_pos_idx' (per-image positive query indices), 'dn_num_group' and
                'dn_num_split' ([total_dn, num_queries]).
        All four are None when not training, when `num_dn <= 0`, or when the batch has no gts.
    """
    if (not training) or num_dn <= 0:
        return None, None, None, None
    gt_groups = batch['gt_groups']
    total_num = sum(gt_groups)
    max_nums = max(gt_groups)
    if max_nums == 0:
        return None, None, None, None

    num_group = max(num_dn // max_nums, 1)  # at least one denoising group
    bs = len(gt_groups)
    gt_cls = batch['cls']  # (bs*num, )
    gt_bbox = batch['bboxes']  # bs*num, 4
    b_idx = batch['batch_idx']

    # Each group has positive and negative queries.
    dn_cls = gt_cls.repeat(2 * num_group)  # (2*num_group*bs*num, )
    dn_bbox = gt_bbox.repeat(2 * num_group, 1)  # 2*num_group*bs*num, 4
    dn_b_idx = b_idx.repeat(2 * num_group).view(-1)  # (2*num_group*bs*num, )

    # (bs*num*num_group, ), the second total_num*num_group part are the negative samples
    neg_idx = torch.arange(total_num * num_group, dtype=torch.long, device=gt_bbox.device) + num_group * total_num

    if cls_noise_ratio > 0:
        # Half of the ratio: pick labels to replace
        mask = torch.rand(dn_cls.shape) < (cls_noise_ratio * 0.5)
        idx = torch.nonzero(mask).squeeze(-1)
        # Randomly put a new label there
        new_label = torch.randint_like(idx, 0, num_classes, dtype=dn_cls.dtype, device=dn_cls.device)
        dn_cls[idx] = new_label

    if box_noise_scale > 0:
        known_bbox = xywh2xyxy(dn_bbox)
        # Maximum corner shift: half the box width/height, scaled
        diff = (dn_bbox[..., 2:] * 0.5).repeat(1, 2) * box_noise_scale  # 2*num_group*bs*num, 4
        rand_sign = torch.randint_like(dn_bbox, 0, 2) * 2.0 - 1.0
        rand_part = torch.rand_like(dn_bbox)
        # Negatives get a noise magnitude in [1, 2) instead of [0, 1), pushing them further from the gt
        rand_part[neg_idx] += 1.0
        rand_part *= rand_sign
        known_bbox += rand_part * diff
        known_bbox.clip_(min=0.0, max=1.0)
        dn_bbox = xyxy2xywh(known_bbox)
    # Boxes are always returned in inverse-sigmoid space, also when no box noise is applied
    dn_bbox = inverse_sigmoid(dn_bbox)

    # Total denoising queries
    num_dn = int(max_nums * 2 * num_group)
    dn_cls_embed = class_embed[dn_cls]  # bs*num * 2 * num_group, embed_dim
    padding_cls = torch.zeros(bs, num_dn, dn_cls_embed.shape[-1], device=gt_cls.device)
    padding_bbox = torch.zeros(bs, num_dn, 4, device=gt_bbox.device)

    # Slot of every gt within its image's padded block of max_nums queries
    map_indices = torch.cat([torch.arange(num, dtype=torch.long) for num in gt_groups])
    pos_idx = torch.stack([map_indices + max_nums * i for i in range(num_group)], dim=0)

    map_indices = torch.cat([map_indices + max_nums * i for i in range(2 * num_group)])
    padding_cls[(dn_b_idx, map_indices)] = dn_cls_embed
    padding_bbox[(dn_b_idx, map_indices)] = dn_bbox

    tgt_size = num_dn + num_queries
    attn_mask = torch.zeros([tgt_size, tgt_size], dtype=torch.bool)
    # Matching queries cannot see the reconstruction (denoising) queries
    attn_mask[num_dn:, :num_dn] = True
    # Reconstruction groups (consecutive blocks of 2*max_nums queries) cannot see each other
    group_id = torch.arange(num_dn) // (2 * max_nums)
    attn_mask[:num_dn, :num_dn] = group_id.unsqueeze(1) != group_id.unsqueeze(0)

    dn_meta = {
        'dn_pos_idx': [p.reshape(-1) for p in pos_idx.cpu().split(list(gt_groups), dim=1)],
        'dn_num_group': num_group,
        'dn_num_split': [num_dn, num_queries]}

    return padding_cls.to(class_embed.device), padding_bbox.to(class_embed.device), attn_mask.to(
        class_embed.device), dn_meta
|
||
|
|
||
|
|
||
|
def inverse_sigmoid(x, eps=1e-6):
    """
    Numerically guarded inverse of the sigmoid function.

    The input is clamped to [0, 1] first; `eps` keeps the division and the logarithm finite at the
    endpoints, so 0 maps to log(eps) and 1 maps to approximately log(1 / eps).

    Args:
        x (torch.Tensor): Values to invert.
        eps (float): Small stabilising constant.

    Returns:
        (torch.Tensor): log(x / (1 - x + eps) + eps), computed on the clamped input.
    """
    p = torch.clamp(x, min=0.0, max=1.0)
    odds = p / (1 - p + eps)
    return (odds + eps).log()
|