|
|
@ -89,9 +89,8 @@ class RTDETRValidator(DetectionValidator):
|
|
|
|
|
|
|
|
|
|
|
|
def postprocess(self, preds):
|
|
|
|
def postprocess(self, preds):
|
|
|
|
"""Apply Non-maximum suppression to prediction outputs."""
|
|
|
|
"""Apply Non-maximum suppression to prediction outputs."""
|
|
|
|
bboxes, scores = preds[:2] # (1, bs, 300, 4), (1, bs, 300, nc)
|
|
|
|
bs, _, nd = preds[0].shape
|
|
|
|
bboxes, scores = bboxes.squeeze_(0), scores.squeeze_(0) # (bs, 300, 4)
|
|
|
|
bboxes, scores = preds[0].split((4, nd - 4), dim=-1)
|
|
|
|
bs = len(bboxes)
|
|
|
|
|
|
|
|
outputs = [torch.zeros((0, 6), device=bboxes.device)] * bs
|
|
|
|
outputs = [torch.zeros((0, 6), device=bboxes.device)] * bs
|
|
|
|
for i, bbox in enumerate(bboxes): # (300, 4)
|
|
|
|
for i, bbox in enumerate(bboxes): # (300, 4)
|
|
|
|
bbox = ops.xywh2xyxy(bbox)
|
|
|
|
bbox = ops.xywh2xyxy(bbox)
|
|
|
|