Integration of v8 segmentation (#107)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
This commit is contained in:
@ -3,7 +3,7 @@
|
||||
|
||||
# Task and Mode
|
||||
task: "classify" # choices=['detect', 'segment', 'classify', 'init'] # init is a special case
|
||||
mode: "train" # choice=['train', 'val', 'infer']
|
||||
mode: "train" # choice=['train', 'val', 'predict']
|
||||
|
||||
# Train settings -------------------------------------------------------------------------------------------------------
|
||||
model: null # i.e. yolov5s.pt, yolo.yaml
|
||||
|
@ -86,7 +86,8 @@ class TaskAlignedAssigner(nn.Module):
|
||||
if self.n_max_boxes == 0:
|
||||
device = gt_bboxes.device
|
||||
return (torch.full_like(pd_scores[..., 0], self.bg_idx).to(device), torch.zeros_like(pd_bboxes).to(device),
|
||||
torch.zeros_like(pd_scores).to(device), torch.zeros_like(pd_scores[..., 0]).to(device))
|
||||
torch.zeros_like(pd_scores).to(device), torch.zeros_like(pd_scores[..., 0]).to(device),
|
||||
torch.zeros_like(pd_scores[..., 0]).to(device))
|
||||
|
||||
mask_pos, align_metric, overlaps = self.get_pos_mask(pd_scores, pd_bboxes, gt_labels, gt_bboxes, anc_points,
|
||||
mask_gt)
|
||||
@ -103,7 +104,7 @@ class TaskAlignedAssigner(nn.Module):
|
||||
norm_align_metric = (align_metric * pos_overlaps / (pos_align_metrics + self.eps)).amax(-2).unsqueeze(-1)
|
||||
target_scores = target_scores * norm_align_metric
|
||||
|
||||
return target_labels, target_bboxes, target_scores, fg_mask.bool()
|
||||
return target_labels, target_bboxes, target_scores, fg_mask.bool(), target_gt_idx
|
||||
|
||||
def get_pos_mask(self, pd_scores, pd_bboxes, gt_labels, gt_bboxes, anc_points, mask_gt):
|
||||
# get anchor_align metric, (b, max_num_obj, h*w)
|
||||
@ -146,9 +147,6 @@ class TaskAlignedAssigner(nn.Module):
|
||||
# (b, max_num_obj, topk, h*w) -> (b, max_num_obj, h*w)
|
||||
is_in_topk = F.one_hot(topk_idxs, num_anchors).sum(-2)
|
||||
# filter invalid bboxes
|
||||
# assigned topk should be unique, this is for dealing with empty labels
|
||||
# since empty labels will generate index `0` through `F.one_hot`
|
||||
# NOTE: but what if the topk_idxs include `0`?
|
||||
is_in_topk = torch.where(is_in_topk > 1, 0, is_in_topk)
|
||||
return is_in_topk.to(metrics.dtype)
|
||||
|
||||
|
Reference in New Issue
Block a user