[RTDETR]Fix val loss (#3280)

This commit is contained in:
Laughing
2023-06-20 23:59:04 +08:00
committed by GitHub
parent d8701b42ca
commit 9d1e5567de
4 changed files with 12 additions and 9 deletions

View File

@ -158,6 +158,7 @@ class Classify(nn.Module):
class RTDETRDecoder(nn.Module):
export = False # export mode
def __init__(
self,
@ -246,9 +247,12 @@ class RTDETRDecoder(nn.Module):
self.dec_score_head,
self.query_pos_head,
attn_mask=attn_mask)
if not self.training:
dec_scores = dec_scores.sigmoid_()
return dec_bboxes, dec_scores, enc_bboxes, enc_scores, dn_meta
x = dec_bboxes, dec_scores, enc_bboxes, enc_scores, dn_meta
if self.training:
return x
# (bs, 300, 4+nc)
y = torch.cat((dec_bboxes.squeeze(0), dec_scores.squeeze(0).sigmoid()), -1)
return y if self.export else (y, x)
def _generate_anchors(self, shapes, grid_size=0.05, dtype=torch.float32, device='cpu', eps=1e-2):
anchors = []

View File

@ -432,7 +432,7 @@ class RTDETRDetectionModel(DetectionModel):
'gt_groups': gt_groups}
preds = self.predict(img, batch=targets) if preds is None else preds
dec_bboxes, dec_scores, enc_bboxes, enc_scores, dn_meta = preds
dec_bboxes, dec_scores, enc_bboxes, enc_scores, dn_meta = preds if self.training else preds[1]
if dn_meta is None:
dn_bboxes, dn_scores = None, None
else: