From 9d1e5567de48453f168013ff1032810bd95d39fe Mon Sep 17 00:00:00 2001 From: Laughing <61612323+Laughing-q@users.noreply.github.com> Date: Tue, 20 Jun 2023 23:59:04 +0800 Subject: [PATCH] [RTDETR]Fix val loss (#3280) --- ultralytics/nn/modules/head.py | 10 +++++++--- ultralytics/nn/tasks.py | 2 +- ultralytics/vit/rtdetr/predict.py | 4 ++-- ultralytics/vit/rtdetr/val.py | 5 ++--- 4 files changed, 12 insertions(+), 9 deletions(-) diff --git a/ultralytics/nn/modules/head.py b/ultralytics/nn/modules/head.py index 8b55fd0..14d6b4f 100644 --- a/ultralytics/nn/modules/head.py +++ b/ultralytics/nn/modules/head.py @@ -158,6 +158,7 @@ class Classify(nn.Module): class RTDETRDecoder(nn.Module): + export = False # export mode def __init__( self, @@ -246,9 +247,12 @@ class RTDETRDecoder(nn.Module): self.dec_score_head, self.query_pos_head, attn_mask=attn_mask) - if not self.training: - dec_scores = dec_scores.sigmoid_() - return dec_bboxes, dec_scores, enc_bboxes, enc_scores, dn_meta + x = dec_bboxes, dec_scores, enc_bboxes, enc_scores, dn_meta + if self.training: + return x + # (bs, 300, 4+nc) + y = torch.cat((dec_bboxes.squeeze(0), dec_scores.squeeze(0).sigmoid()), -1) + return y if self.export else (y, x) def _generate_anchors(self, shapes, grid_size=0.05, dtype=torch.float32, device='cpu', eps=1e-2): anchors = [] diff --git a/ultralytics/nn/tasks.py b/ultralytics/nn/tasks.py index 3c2ba06..e98a5fc 100644 --- a/ultralytics/nn/tasks.py +++ b/ultralytics/nn/tasks.py @@ -432,7 +432,7 @@ class RTDETRDetectionModel(DetectionModel): 'gt_groups': gt_groups} preds = self.predict(img, batch=targets) if preds is None else preds - dec_bboxes, dec_scores, enc_bboxes, enc_scores, dn_meta = preds + dec_bboxes, dec_scores, enc_bboxes, enc_scores, dn_meta = preds if self.training else preds[1] if dn_meta is None: dn_bboxes, dn_scores = None, None else: diff --git a/ultralytics/vit/rtdetr/predict.py b/ultralytics/vit/rtdetr/predict.py index 78219b2..77c02c2 100644 --- a/ultralytics/vit/rtdetr/predict.py +++ b/ultralytics/vit/rtdetr/predict.py @@ -12,8 +12,8 @@ class RTDETRPredictor(BasePredictor): def postprocess(self, preds, img, orig_imgs): """Postprocess predictions and returns a list of Results objects.""" - bboxes, scores = preds[:2] # (1, bs, 300, 4), (1, bs, 300, nc) - bboxes, scores = bboxes.squeeze_(0), scores.squeeze_(0) + nd = preds[0].shape[-1] + bboxes, scores = preds[0].split((4, nd - 4), dim=-1) results = [] for i, bbox in enumerate(bboxes): # (300, 4) bbox = ops.xywh2xyxy(bbox) diff --git a/ultralytics/vit/rtdetr/val.py b/ultralytics/vit/rtdetr/val.py index 57376a6..682296e 100644 --- a/ultralytics/vit/rtdetr/val.py +++ b/ultralytics/vit/rtdetr/val.py @@ -89,9 +89,8 @@ class RTDETRValidator(DetectionValidator): def postprocess(self, preds): """Apply Non-maximum suppression to prediction outputs.""" - bboxes, scores = preds[:2] # (1, bs, 300, 4), (1, bs, 300, nc) - bboxes, scores = bboxes.squeeze_(0), scores.squeeze_(0) # (bs, 300, 4) - bs = len(bboxes) + bs, _, nd = preds[0].shape + bboxes, scores = preds[0].split((4, nd - 4), dim=-1) outputs = [torch.zeros((0, 6), device=bboxes.device)] * bs for i, bbox in enumerate(bboxes): # (300, 4) bbox = ops.xywh2xyxy(bbox)