diff --git a/ultralytics/vit/utils/loss.py b/ultralytics/vit/utils/loss.py
index 6ba24c2..cb2de20 100644
--- a/ultralytics/vit/utils/loss.py
+++ b/ultralytics/vit/utils/loss.py
@@ -284,11 +284,11 @@ class RTDETRDetectionLoss(DETRLoss):
         idx_groups = torch.as_tensor([0, *gt_groups[:-1]]).cumsum_(0)
         for i, num_gt in enumerate(gt_groups):
             if num_gt > 0:
-                gt_idx = torch.arange(end=num_gt, dtype=torch.int32) + idx_groups[i]
+                gt_idx = torch.arange(end=num_gt, dtype=torch.long) + idx_groups[i]
                 gt_idx = gt_idx.repeat(dn_num_group)
                 assert len(dn_pos_idx[i]) == len(gt_idx), 'Expected the same length, ' \
                     f'but got {len(dn_pos_idx[i])} and {len(gt_idx)} respectively.'
                 dn_match_indices.append((dn_pos_idx[i], gt_idx))
             else:
-                dn_match_indices.append((torch.zeros([0], dtype=torch.int32), torch.zeros([0], dtype=torch.int32)))
+                dn_match_indices.append((torch.zeros([0], dtype=torch.long), torch.zeros([0], dtype=torch.long)))
         return dn_match_indices
diff --git a/ultralytics/vit/utils/ops.py b/ultralytics/vit/utils/ops.py
index e8978a5..4b37931 100644
--- a/ultralytics/vit/utils/ops.py
+++ b/ultralytics/vit/utils/ops.py
@@ -71,7 +71,7 @@ class HungarianMatcher(nn.Module):
         bs, nq, nc = pred_scores.shape

         if sum(gt_groups) == 0:
-            return [(torch.tensor([], dtype=torch.int32), torch.tensor([], dtype=torch.int32)) for _ in range(bs)]
+            return [(torch.tensor([], dtype=torch.long), torch.tensor([], dtype=torch.long)) for _ in range(bs)]

         # We flatten to compute the cost matrices in a batch
         # [batch_size * num_queries, num_classes]
@@ -107,7 +107,7 @@ class HungarianMatcher(nn.Module):
         indices = [linear_sum_assignment(c[i]) for i, c in enumerate(C.split(gt_groups, -1))]
         gt_groups = torch.as_tensor([0, *gt_groups[:-1]]).cumsum_(0)
         # (idx for queries, idx for gt)
-        return [(torch.tensor(i, dtype=torch.int32), torch.tensor(j, dtype=torch.int32) + gt_groups[k])
+        return [(torch.tensor(i, dtype=torch.long), torch.tensor(j, dtype=torch.long) + gt_groups[k])
                 for k, (i, j) in enumerate(indices)]

     def _cost_mask(self, bs, num_gts, masks=None, gt_mask=None):
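
Note (not part of the diff): a minimal sketch of why switching the match-index dtype from torch.int32 to torch.long matters. PyTorch's torch.gather accepts only int64 index tensors, so int32 indices produced by the matcher could fail when used downstream to select matched predictions; which exact op trips first in the loss code is an assumption here, but gather illustrates the dtype constraint.

# Minimal sketch, assuming the matcher's indices are later used to
# gather matched predictions (names below are illustrative only).
import torch

pred = torch.randn(4, 2)                         # e.g. 4 queries, 2 box coords
idx32 = torch.tensor([1, 3], dtype=torch.int32)  # old dtype from the matcher
idx64 = idx32.to(torch.long)                     # dtype after this change

# torch.gather(pred, 0, idx32.unsqueeze(-1).expand(-1, 2))
#   -> RuntimeError: gather(): Expected dtype int64 for index
matched = torch.gather(pred, 0, idx64.unsqueeze(-1).expand(-1, 2))  # OK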