diff --git a/ultralytics/nn/modules/head.py b/ultralytics/nn/modules/head.py index 4e3661b..fffd102 100644 --- a/ultralytics/nn/modules/head.py +++ b/ultralytics/nn/modules/head.py @@ -307,8 +307,6 @@ class RTDETRDecoder(nn.Module): features = self.enc_output(valid_mask * feats) # bs, h*w, 256 enc_outputs_scores = self.enc_score_head(features) # (bs, h*w, nc) - # dynamic anchors + static content - enc_outputs_bboxes = self.enc_bbox_head(features) + anchors # (bs, h*w, 4) # query selection # (bs, num_queries) @@ -316,22 +314,23 @@ class RTDETRDecoder(nn.Module): # (bs, num_queries) batch_ind = torch.arange(end=bs, dtype=topk_ind.dtype).unsqueeze(-1).repeat(1, self.num_queries).view(-1) - # Unsigmoided - refer_bbox = enc_outputs_bboxes[batch_ind, topk_ind].view(bs, self.num_queries, -1) - # refer_bbox = torch.gather(enc_outputs_bboxes, 1, topk_ind.reshape(bs, self.num_queries).unsqueeze(-1).repeat(1, 1, 4)) + # (bs, num_queries, 256) + top_k_features = features[batch_ind, topk_ind].view(bs, self.num_queries, -1) + # (bs, num_queries, 4) + top_k_anchors = anchors[:, topk_ind].view(bs, self.num_queries, -1) + + # dynamic anchors + static content + refer_bbox = self.enc_bbox_head(top_k_features) + top_k_anchors enc_bboxes = refer_bbox.sigmoid() if dn_bbox is not None: refer_bbox = torch.cat([dn_bbox, refer_bbox], 1) - if self.training: - refer_bbox = refer_bbox.detach() enc_scores = enc_outputs_scores[batch_ind, topk_ind].view(bs, self.num_queries, -1) - if self.learnt_init_query: - embeddings = self.tgt_embed.weight.unsqueeze(0).repeat(bs, 1, 1) - else: - embeddings = features[batch_ind, topk_ind].view(bs, self.num_queries, -1) - if self.training: + embeddings = self.tgt_embed.weight.unsqueeze(0).repeat(bs, 1, 1) if self.learnt_init_query else top_k_features + if self.training: + refer_bbox = refer_bbox.detach() + if not self.learnt_init_query: embeddings = embeddings.detach() if dn_embed is not None: embeddings = torch.cat([dn_embed, embeddings], 1)