Fix RT-DETR exported onnx model (#3317)

Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
single_channel
Kayzwer 1 year ago committed by GitHub
parent 2f58b5821a
commit 51d8cfa9c3
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -270,7 +270,7 @@ class RTDETRDecoder(nn.Module):
anchors = torch.cat(anchors, 1) # (1, h*w*nl, 4) anchors = torch.cat(anchors, 1) # (1, h*w*nl, 4)
valid_mask = ((anchors > eps) * (anchors < 1 - eps)).all(-1, keepdim=True) # 1, h*w*nl, 1 valid_mask = ((anchors > eps) * (anchors < 1 - eps)).all(-1, keepdim=True) # 1, h*w*nl, 1
anchors = torch.log(anchors / (1 - anchors)) anchors = torch.log(anchors / (1 - anchors))
anchors = torch.where(valid_mask, anchors, torch.inf) anchors = anchors.masked_fill(~valid_mask, float('inf'))
return anchors, valid_mask return anchors, valid_mask
def _get_encoder_input(self, x): def _get_encoder_input(self, x):
@ -294,7 +294,7 @@ class RTDETRDecoder(nn.Module):
bs = len(feats) bs = len(feats)
# prepare input for decoder # prepare input for decoder
anchors, valid_mask = self._generate_anchors(shapes, dtype=feats.dtype, device=feats.device) anchors, valid_mask = self._generate_anchors(shapes, dtype=feats.dtype, device=feats.device)
features = self.enc_output(torch.where(valid_mask, feats, 0)) # bs, h*w, 256 features = self.enc_output(valid_mask * feats) # bs, h*w, 256
enc_outputs_scores = self.enc_score_head(features) # (bs, h*w, nc) enc_outputs_scores = self.enc_score_head(features) # (bs, h*w, nc)
# dynamic anchors + static content # dynamic anchors + static content

Loading…
Cancel
Save