[RTDETR]Fix val loss (#3280)

single_channel
Laughing 1 year ago committed by GitHub
parent d8701b42ca
commit 9d1e5567de
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -158,6 +158,7 @@ class Classify(nn.Module):
class RTDETRDecoder(nn.Module): class RTDETRDecoder(nn.Module):
export = False # export mode
def __init__( def __init__(
self, self,
@ -246,9 +247,12 @@ class RTDETRDecoder(nn.Module):
self.dec_score_head, self.dec_score_head,
self.query_pos_head, self.query_pos_head,
attn_mask=attn_mask) attn_mask=attn_mask)
if not self.training: x = dec_bboxes, dec_scores, enc_bboxes, enc_scores, dn_meta
dec_scores = dec_scores.sigmoid_() if self.training:
return dec_bboxes, dec_scores, enc_bboxes, enc_scores, dn_meta return x
# (bs, 300, 4+nc)
y = torch.cat((dec_bboxes.squeeze(0), dec_scores.squeeze(0).sigmoid()), -1)
return y if self.export else (y, x)
def _generate_anchors(self, shapes, grid_size=0.05, dtype=torch.float32, device='cpu', eps=1e-2): def _generate_anchors(self, shapes, grid_size=0.05, dtype=torch.float32, device='cpu', eps=1e-2):
anchors = [] anchors = []

@ -432,7 +432,7 @@ class RTDETRDetectionModel(DetectionModel):
'gt_groups': gt_groups} 'gt_groups': gt_groups}
preds = self.predict(img, batch=targets) if preds is None else preds preds = self.predict(img, batch=targets) if preds is None else preds
dec_bboxes, dec_scores, enc_bboxes, enc_scores, dn_meta = preds dec_bboxes, dec_scores, enc_bboxes, enc_scores, dn_meta = preds if self.training else preds[1]
if dn_meta is None: if dn_meta is None:
dn_bboxes, dn_scores = None, None dn_bboxes, dn_scores = None, None
else: else:

@ -12,8 +12,8 @@ class RTDETRPredictor(BasePredictor):
def postprocess(self, preds, img, orig_imgs): def postprocess(self, preds, img, orig_imgs):
"""Postprocess predictions and returns a list of Results objects.""" """Postprocess predictions and returns a list of Results objects."""
bboxes, scores = preds[:2] # (1, bs, 300, 4), (1, bs, 300, nc) nd = preds[0].shape[-1]
bboxes, scores = bboxes.squeeze_(0), scores.squeeze_(0) bboxes, scores = preds[0].split((4, nd - 4), dim=-1)
results = [] results = []
for i, bbox in enumerate(bboxes): # (300, 4) for i, bbox in enumerate(bboxes): # (300, 4)
bbox = ops.xywh2xyxy(bbox) bbox = ops.xywh2xyxy(bbox)

@ -89,9 +89,8 @@ class RTDETRValidator(DetectionValidator):
def postprocess(self, preds): def postprocess(self, preds):
"""Apply Non-maximum suppression to prediction outputs.""" """Apply Non-maximum suppression to prediction outputs."""
bboxes, scores = preds[:2] # (1, bs, 300, 4), (1, bs, 300, nc) bs, _, nd = preds[0].shape
bboxes, scores = bboxes.squeeze_(0), scores.squeeze_(0) # (bs, 300, 4) bboxes, scores = preds[0].split((4, nd - 4), dim=-1)
bs = len(bboxes)
outputs = [torch.zeros((0, 6), device=bboxes.device)] * bs outputs = [torch.zeros((0, 6), device=bboxes.device)] * bs
for i, bbox in enumerate(bboxes): # (300, 4) for i, bbox in enumerate(bboxes): # (300, 4)
bbox = ops.xywh2xyxy(bbox) bbox = ops.xywh2xyxy(bbox)

Loading…
Cancel
Save