From ff211f403729fe52d1fbf96fad05f162a6a3cddb Mon Sep 17 00:00:00 2001 From: Laughing <61612323+Laughing-q@users.noreply.github.com> Date: Mon, 15 May 2023 18:09:34 +0800 Subject: [PATCH] fix loss (#2614) --- ultralytics/yolo/utils/loss.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ultralytics/yolo/utils/loss.py b/ultralytics/yolo/utils/loss.py index 50c3860..73aba68 100644 --- a/ultralytics/yolo/utils/loss.py +++ b/ultralytics/yolo/utils/loss.py @@ -34,7 +34,7 @@ class BboxLoss(nn.Module): def forward(self, pred_dist, pred_bboxes, anchor_points, target_bboxes, target_scores, target_scores_sum, fg_mask): """IoU loss.""" - weight = target_scores.sum(-1)[fg_mask] + weight = target_scores.sum(-1)[fg_mask].unsqueeze(-1) iou = bbox_iou(pred_bboxes[fg_mask], target_bboxes[fg_mask], xywh=False, CIoU=True) loss_iou = ((1.0 - iou) * weight).sum() / target_scores_sum