[RTDETR]Fix val loss (#3280)

This commit is contained in:
Laughing
2023-06-20 23:59:04 +08:00
committed by GitHub
parent d8701b42ca
commit 9d1e5567de
4 changed files with 12 additions and 9 deletions

View File

@ -12,8 +12,8 @@ class RTDETRPredictor(BasePredictor):
def postprocess(self, preds, img, orig_imgs):
"""Postprocess predictions and returns a list of Results objects."""
bboxes, scores = preds[:2] # (1, bs, 300, 4), (1, bs, 300, nc)
bboxes, scores = bboxes.squeeze_(0), scores.squeeze_(0)
nd = preds[0].shape[-1]
bboxes, scores = preds[0].split((4, nd - 4), dim=-1)
results = []
for i, bbox in enumerate(bboxes): # (300, 4)
bbox = ops.xywh2xyxy(bbox)

View File

@ -89,9 +89,8 @@ class RTDETRValidator(DetectionValidator):
def postprocess(self, preds):
"""Apply Non-maximum suppression to prediction outputs."""
bboxes, scores = preds[:2] # (1, bs, 300, 4), (1, bs, 300, nc)
bboxes, scores = bboxes.squeeze_(0), scores.squeeze_(0) # (bs, 300, 4)
bs = len(bboxes)
bs, _, nd = preds[0].shape
bboxes, scores = preds[0].split((4, nd - 4), dim=-1)
outputs = [torch.zeros((0, 6), device=bboxes.device)] * bs
for i, bbox in enumerate(bboxes): # (300, 4)
bbox = ops.xywh2xyxy(bbox)