Fix RTDETR val_batch_pred (#3392)

single_channel
Laughing 1 year ago committed by GitHub
parent 4e08e12256
commit e93a5fbff6
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -91,6 +91,7 @@ class RTDETRValidator(DetectionValidator):
"""Apply Non-maximum suppression to prediction outputs."""
bs, _, nd = preds[0].shape
bboxes, scores = preds[0].split((4, nd - 4), dim=-1)
bboxes *= self.args.imgsz
outputs = [torch.zeros((0, 6), device=bboxes.device)] * bs
for i, bbox in enumerate(bboxes): # (300, 4)
bbox = ops.xywh2xyxy(bbox)
@ -126,8 +127,8 @@ class RTDETRValidator(DetectionValidator):
if self.args.single_cls:
pred[:, 5] = 0
predn = pred.clone()
predn[..., [0, 2]] *= shape[1] # native-space pred
predn[..., [1, 3]] *= shape[0] # native-space pred
predn[..., [0, 2]] *= shape[1] / self.args.imgsz # native-space pred
predn[..., [1, 3]] *= shape[0] / self.args.imgsz # native-space pred
# Evaluate
if nl:

Loading…
Cancel
Save