diff --git a/ultralytics/yolo/utils/loss.py b/ultralytics/yolo/utils/loss.py index 50c3860..73aba68 100644 --- a/ultralytics/yolo/utils/loss.py +++ b/ultralytics/yolo/utils/loss.py @@ -34,7 +34,7 @@ class BboxLoss(nn.Module): def forward(self, pred_dist, pred_bboxes, anchor_points, target_bboxes, target_scores, target_scores_sum, fg_mask): """IoU loss.""" - weight = target_scores.sum(-1)[fg_mask] + weight = target_scores.sum(-1)[fg_mask].unsqueeze(-1) iou = bbox_iou(pred_bboxes[fg_mask], target_bboxes[fg_mask], xywh=False, CIoU=True) loss_iou = ((1.0 - iou) * weight).sum() / target_scores_sum