|
|
|
@ -71,7 +71,7 @@ class HungarianMatcher(nn.Module):
|
|
|
|
|
bs, nq, nc = pred_scores.shape
|
|
|
|
|
|
|
|
|
|
if sum(gt_groups) == 0:
|
|
|
|
|
return [(torch.tensor([], dtype=torch.int32), torch.tensor([], dtype=torch.int32)) for _ in range(bs)]
|
|
|
|
|
return [(torch.tensor([], dtype=torch.long), torch.tensor([], dtype=torch.long)) for _ in range(bs)]
|
|
|
|
|
|
|
|
|
|
# We flatten to compute the cost matrices in a batch
|
|
|
|
|
# [batch_size * num_queries, num_classes]
|
|
|
|
@ -107,7 +107,7 @@ class HungarianMatcher(nn.Module):
|
|
|
|
|
indices = [linear_sum_assignment(c[i]) for i, c in enumerate(C.split(gt_groups, -1))]
|
|
|
|
|
gt_groups = torch.as_tensor([0, *gt_groups[:-1]]).cumsum_(0)
|
|
|
|
|
# (idx for queries, idx for gt)
|
|
|
|
|
return [(torch.tensor(i, dtype=torch.int32), torch.tensor(j, dtype=torch.int32) + gt_groups[k])
|
|
|
|
|
return [(torch.tensor(i, dtype=torch.long), torch.tensor(j, dtype=torch.long) + gt_groups[k])
|
|
|
|
|
for k, (i, j) in enumerate(indices)]
|
|
|
|
|
|
|
|
|
|
def _cost_mask(self, bs, num_gts, masks=None, gt_mask=None):
|
|
|
|
|