Fix `save_hybrid` (#4245)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
single_channel
Laughing 1 year ago committed by GitHub
parent 7dfdb63cde
commit e9c9b82c42
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -26,6 +26,7 @@ class DetectionValidator(BaseValidator):
self.metrics = DetMetrics(save_dir=self.save_dir, on_plot=self.on_plot) self.metrics = DetMetrics(save_dir=self.save_dir, on_plot=self.on_plot)
self.iouv = torch.linspace(0.5, 0.95, 10) # iou vector for mAP@0.5:0.95 self.iouv = torch.linspace(0.5, 0.95, 10) # iou vector for mAP@0.5:0.95
self.niou = self.iouv.numel() self.niou = self.iouv.numel()
self.lb = [] # for autolabelling
def preprocess(self, batch): def preprocess(self, batch):
"""Preprocesses batch of images for YOLO training.""" """Preprocesses batch of images for YOLO training."""
@ -34,9 +35,13 @@ class DetectionValidator(BaseValidator):
for k in ['batch_idx', 'cls', 'bboxes']: for k in ['batch_idx', 'cls', 'bboxes']:
batch[k] = batch[k].to(self.device) batch[k] = batch[k].to(self.device)
nb = len(batch['img']) if self.args.save_hybrid:
self.lb = [torch.cat([batch['cls'], batch['bboxes']], dim=-1)[batch['batch_idx'] == i] height, width = batch['img'].shape[2:]
for i in range(nb)] if self.args.save_hybrid else [] # for autolabelling nb = len(batch['img'])
bboxes = batch['bboxes'] * torch.tensor((width, height, width, height), device=self.device)
self.lb = [
torch.cat([batch['cls'][batch['batch_idx'] == i], bboxes[batch['batch_idx'] == i]], dim=-1)
for i in range(nb)] if self.args.save_hybrid else [] # for autolabelling
return batch return batch

@ -225,8 +225,8 @@ def non_max_suppression(
# Cat apriori labels if autolabelling # Cat apriori labels if autolabelling
if labels and len(labels[xi]): if labels and len(labels[xi]):
lb = labels[xi] lb = labels[xi]
v = torch.zeros((len(lb), nc + nm + 5), device=x.device) v = torch.zeros((len(lb), nc + nm + 4), device=x.device)
v[:, :4] = lb[:, 1:5] # box v[:, :4] = xywh2xyxy(lb[:, 1:5]) # box
v[range(len(lb)), lb[:, 0].long() + 4] = 1.0 # cls v[range(len(lb)), lb[:, 0].long() + 4] = 1.0 # cls
x = torch.cat((x, v), 0) x = torch.cat((x, v), 0)

Loading…
Cancel
Save