From e9c9b82c426d780a04cb881e6653cd12aa586294 Mon Sep 17 00:00:00 2001 From: Laughing <61612323+Laughing-q@users.noreply.github.com> Date: Thu, 10 Aug 2023 00:51:21 +0800 Subject: [PATCH] Fix `save_hybrid` (#4245) Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- ultralytics/models/yolo/detect/val.py | 11 ++++++++--- ultralytics/utils/ops.py | 4 ++-- 2 files changed, 10 insertions(+), 5 deletions(-) diff --git a/ultralytics/models/yolo/detect/val.py b/ultralytics/models/yolo/detect/val.py index f4109f5..d37aa2b 100644 --- a/ultralytics/models/yolo/detect/val.py +++ b/ultralytics/models/yolo/detect/val.py @@ -26,6 +26,7 @@ class DetectionValidator(BaseValidator): self.metrics = DetMetrics(save_dir=self.save_dir, on_plot=self.on_plot) self.iouv = torch.linspace(0.5, 0.95, 10) # iou vector for mAP@0.5:0.95 self.niou = self.iouv.numel() + self.lb = [] # for autolabelling def preprocess(self, batch): """Preprocesses batch of images for YOLO training.""" @@ -34,9 +35,13 @@ class DetectionValidator(BaseValidator): for k in ['batch_idx', 'cls', 'bboxes']: batch[k] = batch[k].to(self.device) - nb = len(batch['img']) - self.lb = [torch.cat([batch['cls'], batch['bboxes']], dim=-1)[batch['batch_idx'] == i] - for i in range(nb)] if self.args.save_hybrid else [] # for autolabelling + if self.args.save_hybrid: + height, width = batch['img'].shape[2:] + nb = len(batch['img']) + bboxes = batch['bboxes'] * torch.tensor((width, height, width, height), device=self.device) + self.lb = [ + torch.cat([batch['cls'][batch['batch_idx'] == i], bboxes[batch['batch_idx'] == i]], dim=-1) + for i in range(nb)] if self.args.save_hybrid else [] # for autolabelling return batch diff --git a/ultralytics/utils/ops.py b/ultralytics/utils/ops.py index b3e75ad..5f65cdf 100644 --- a/ultralytics/utils/ops.py +++ b/ultralytics/utils/ops.py @@ -225,8 +225,8 @@ def non_max_suppression( # Cat apriori labels if autolabelling if labels and len(labels[xi]): lb = labels[xi] - v = torch.zeros((len(lb), nc + nm + 5), device=x.device) - v[:, :4] = lb[:, 1:5] # box + v = torch.zeros((len(lb), nc + nm + 4), device=x.device) + v[:, :4] = xywh2xyxy(lb[:, 1:5]) # box v[range(len(lb)), lb[:, 0].long() + 4] = 1.0 # cls x = torch.cat((x, v), 0)