[RTDETR]Fix val loss (#3280)
This commit is contained in:
@ -158,6 +158,7 @@ class Classify(nn.Module):
|
||||
|
||||
|
||||
class RTDETRDecoder(nn.Module):
|
||||
export = False # export mode
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@ -246,9 +247,12 @@ class RTDETRDecoder(nn.Module):
|
||||
self.dec_score_head,
|
||||
self.query_pos_head,
|
||||
attn_mask=attn_mask)
|
||||
if not self.training:
|
||||
dec_scores = dec_scores.sigmoid_()
|
||||
return dec_bboxes, dec_scores, enc_bboxes, enc_scores, dn_meta
|
||||
x = dec_bboxes, dec_scores, enc_bboxes, enc_scores, dn_meta
|
||||
if self.training:
|
||||
return x
|
||||
# (bs, 300, 4+nc)
|
||||
y = torch.cat((dec_bboxes.squeeze(0), dec_scores.squeeze(0).sigmoid()), -1)
|
||||
return y if self.export else (y, x)
|
||||
|
||||
def _generate_anchors(self, shapes, grid_size=0.05, dtype=torch.float32, device='cpu', eps=1e-2):
|
||||
anchors = []
|
||||
|
Reference in New Issue
Block a user