Integration of v8 segmentation (#107)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
This commit is contained in:
Laughing
2022-12-28 23:01:38 +08:00
committed by GitHub
parent 384f0ef1c6
commit 8406b49b49
16 changed files with 422 additions and 224 deletions

View File

@ -3,7 +3,7 @@
# Task and Mode
task: "classify" # choices=['detect', 'segment', 'classify', 'init'] # init is a special case
mode: "train" # choice=['train', 'val', 'infer']
mode: "train" # choice=['train', 'val', 'predict']
# Train settings -------------------------------------------------------------------------------------------------------
model: null # i.e. yolov5s.pt, yolo.yaml

View File

@ -86,7 +86,8 @@ class TaskAlignedAssigner(nn.Module):
if self.n_max_boxes == 0:
device = gt_bboxes.device
return (torch.full_like(pd_scores[..., 0], self.bg_idx).to(device), torch.zeros_like(pd_bboxes).to(device),
torch.zeros_like(pd_scores).to(device), torch.zeros_like(pd_scores[..., 0]).to(device))
torch.zeros_like(pd_scores).to(device), torch.zeros_like(pd_scores[..., 0]).to(device),
torch.zeros_like(pd_scores[..., 0]).to(device))
mask_pos, align_metric, overlaps = self.get_pos_mask(pd_scores, pd_bboxes, gt_labels, gt_bboxes, anc_points,
mask_gt)
@ -103,7 +104,7 @@ class TaskAlignedAssigner(nn.Module):
norm_align_metric = (align_metric * pos_overlaps / (pos_align_metrics + self.eps)).amax(-2).unsqueeze(-1)
target_scores = target_scores * norm_align_metric
return target_labels, target_bboxes, target_scores, fg_mask.bool()
return target_labels, target_bboxes, target_scores, fg_mask.bool(), target_gt_idx
def get_pos_mask(self, pd_scores, pd_bboxes, gt_labels, gt_bboxes, anc_points, mask_gt):
# get anchor_align metric, (b, max_num_obj, h*w)
@ -146,9 +147,6 @@ class TaskAlignedAssigner(nn.Module):
# (b, max_num_obj, topk, h*w) -> (b, max_num_obj, h*w)
is_in_topk = F.one_hot(topk_idxs, num_anchors).sum(-2)
# filter invalid bboxes
# assigned topk should be unique, this is for dealing with empty labels
# since empty labels will generate index `0` through `F.one_hot`
# NOTE: but what if the topk_idxs include `0`?
is_in_topk = torch.where(is_in_topk > 1, 0, is_in_topk)
return is_in_topk.to(metrics.dtype)