fix overlap_mask (#651)

Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
single_channel
Laughing 2 years ago committed by GitHub
parent 15b3b0365a
commit dc9502c700
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -145,7 +145,9 @@ class YOLODataset(BaseDataset):
normalize=True, normalize=True,
return_mask=self.use_segments, return_mask=self.use_segments,
return_keypoint=self.use_keypoints, return_keypoint=self.use_keypoints,
batch_idx=True)) batch_idx=True,
mask_ratio=hyp.mask_ratio,
mask_overlap=hyp.overlap_mask))
return transforms return transforms
def close_mosaic(self, hyp): def close_mosaic(self, hyp):
@ -155,7 +157,9 @@ class YOLODataset(BaseDataset):
normalize=True, normalize=True,
return_mask=self.use_segments, return_mask=self.use_segments,
return_keypoint=self.use_keypoints, return_keypoint=self.use_keypoints,
batch_idx=True)) batch_idx=True,
mask_ratio=hyp.mask_ratio,
mask_overlap=hyp.overlap_mask))
def update_labels_info(self, label): def update_labels_info(self, label):
"""custom your label format here""" """custom your label format here"""

@ -110,11 +110,11 @@ class SegLoss(Loss):
target_scores, target_scores_sum, fg_mask) target_scores, target_scores_sum, fg_mask)
for i in range(batch_size): for i in range(batch_size):
if fg_mask[i].sum(): if fg_mask[i].sum():
mask_idx = target_gt_idx[i][fg_mask[i]] + 1 mask_idx = target_gt_idx[i][fg_mask[i]]
if self.overlap: if self.overlap:
gt_mask = torch.where(masks[[i]] == mask_idx.view(-1, 1, 1), 1.0, 0.0) gt_mask = torch.where(masks[[i]] == (mask_idx + 1).view(-1, 1, 1), 1.0, 0.0)
else: else:
gt_mask = masks[batch_idx == i][mask_idx] gt_mask = masks[batch_idx.view(-1) == i][mask_idx]
xyxyn = target_bboxes[i][fg_mask[i]] / imgsz[[1, 0, 1, 0]] xyxyn = target_bboxes[i][fg_mask[i]] / imgsz[[1, 0, 1, 0]]
marea = xyxy2xywh(xyxyn)[:, 2:].prod(1) marea = xyxy2xywh(xyxyn)[:, 2:].prod(1)
mxyxy = xyxyn * torch.tensor([mask_w, mask_h, mask_w, mask_h], device=self.device) mxyxy = xyxyn * torch.tensor([mask_w, mask_h, mask_w, mask_h], device=self.device)

Loading…
Cancel
Save