From 51d8cfa9c37b7b2b98b3d3ec5a6f1a9ff6b38359 Mon Sep 17 00:00:00 2001 From: Kayzwer <68285002+Kayzwer@users.noreply.github.com> Date: Sat, 24 Jun 2023 22:34:24 +0800 Subject: [PATCH] Fix RT-DETR exported onnx model (#3317) Co-authored-by: Glenn Jocher Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- ultralytics/nn/modules/head.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/ultralytics/nn/modules/head.py b/ultralytics/nn/modules/head.py index 14d6b4f..afad816 100644 --- a/ultralytics/nn/modules/head.py +++ b/ultralytics/nn/modules/head.py @@ -270,7 +270,7 @@ class RTDETRDecoder(nn.Module): anchors = torch.cat(anchors, 1) # (1, h*w*nl, 4) valid_mask = ((anchors > eps) * (anchors < 1 - eps)).all(-1, keepdim=True) # 1, h*w*nl, 1 anchors = torch.log(anchors / (1 - anchors)) - anchors = torch.where(valid_mask, anchors, torch.inf) + anchors = anchors.masked_fill(~valid_mask, float('inf')) return anchors, valid_mask def _get_encoder_input(self, x): @@ -294,7 +294,7 @@ class RTDETRDecoder(nn.Module): bs = len(feats) # prepare input for decoder anchors, valid_mask = self._generate_anchors(shapes, dtype=feats.dtype, device=feats.device) - features = self.enc_output(torch.where(valid_mask, feats, 0)) # bs, h*w, 256 + features = self.enc_output(valid_mask * feats) # bs, h*w, 256 enc_outputs_scores = self.enc_score_head(features) # (bs, h*w, nc) # dynamic anchors + static content