|
|
|
@ -307,8 +307,6 @@ class RTDETRDecoder(nn.Module):
|
|
|
|
|
features = self.enc_output(valid_mask * feats) # bs, h*w, 256
|
|
|
|
|
|
|
|
|
|
enc_outputs_scores = self.enc_score_head(features) # (bs, h*w, nc)
|
|
|
|
|
# dynamic anchors + static content
|
|
|
|
|
enc_outputs_bboxes = self.enc_bbox_head(features) + anchors # (bs, h*w, 4)
|
|
|
|
|
|
|
|
|
|
# query selection
|
|
|
|
|
# (bs, num_queries)
|
|
|
|
@ -316,22 +314,23 @@ class RTDETRDecoder(nn.Module):
|
|
|
|
|
# (bs, num_queries)
|
|
|
|
|
batch_ind = torch.arange(end=bs, dtype=topk_ind.dtype).unsqueeze(-1).repeat(1, self.num_queries).view(-1)
|
|
|
|
|
|
|
|
|
|
# Unsigmoided
|
|
|
|
|
refer_bbox = enc_outputs_bboxes[batch_ind, topk_ind].view(bs, self.num_queries, -1)
|
|
|
|
|
# refer_bbox = torch.gather(enc_outputs_bboxes, 1, topk_ind.reshape(bs, self.num_queries).unsqueeze(-1).repeat(1, 1, 4))
|
|
|
|
|
# (bs, num_queries, 256)
|
|
|
|
|
top_k_features = features[batch_ind, topk_ind].view(bs, self.num_queries, -1)
|
|
|
|
|
# (bs, num_queries, 4)
|
|
|
|
|
top_k_anchors = anchors[:, topk_ind].view(bs, self.num_queries, -1)
|
|
|
|
|
|
|
|
|
|
# dynamic anchors + static content
|
|
|
|
|
refer_bbox = self.enc_bbox_head(top_k_features) + top_k_anchors
|
|
|
|
|
|
|
|
|
|
enc_bboxes = refer_bbox.sigmoid()
|
|
|
|
|
if dn_bbox is not None:
|
|
|
|
|
refer_bbox = torch.cat([dn_bbox, refer_bbox], 1)
|
|
|
|
|
if self.training:
|
|
|
|
|
refer_bbox = refer_bbox.detach()
|
|
|
|
|
enc_scores = enc_outputs_scores[batch_ind, topk_ind].view(bs, self.num_queries, -1)
|
|
|
|
|
|
|
|
|
|
if self.learnt_init_query:
|
|
|
|
|
embeddings = self.tgt_embed.weight.unsqueeze(0).repeat(bs, 1, 1)
|
|
|
|
|
else:
|
|
|
|
|
embeddings = features[batch_ind, topk_ind].view(bs, self.num_queries, -1)
|
|
|
|
|
embeddings = self.tgt_embed.weight.unsqueeze(0).repeat(bs, 1, 1) if self.learnt_init_query else top_k_features
|
|
|
|
|
if self.training:
|
|
|
|
|
refer_bbox = refer_bbox.detach()
|
|
|
|
|
if not self.learnt_init_query:
|
|
|
|
|
embeddings = embeddings.detach()
|
|
|
|
|
if dn_embed is not None:
|
|
|
|
|
embeddings = torch.cat([dn_embed, embeddings], 1)
|
|
|
|
|