default.yaml type comments (#3237)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
@@ -22,20 +22,15 @@ class MaskDecoder(nn.Module):
         iou_head_hidden_dim: int = 256,
     ) -> None:
         """
-        Predicts masks given an image and prompt embeddings, using a
-        transformer architecture.
+        Predicts masks given an image and prompt embeddings, using a transformer architecture.
 
         Arguments:
-          transformer_dim (int): the channel dimension of the transformer
-          transformer (nn.Module): the transformer used to predict masks
-          num_multimask_outputs (int): the number of masks to predict
-            when disambiguating masks
-          activation (nn.Module): the type of activation to use when
-            upscaling masks
-          iou_head_depth (int): the depth of the MLP used to predict
-            mask quality
-          iou_head_hidden_dim (int): the hidden dimension of the MLP
-            used to predict mask quality
+            transformer_dim (int): the channel dimension of the transformer module
+            transformer (nn.Module): the transformer used to predict masks
+            num_multimask_outputs (int): the number of masks to predict when disambiguating masks
+            activation (nn.Module): the type of activation to use when upscaling masks
+            iou_head_depth (int): the depth of the MLP used to predict mask quality
+            iou_head_hidden_dim (int): the hidden dimension of the MLP used to predict mask quality
         """
         super().__init__()
         self.transformer_dim = transformer_dim
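For context, a minimal sketch of constructing this decoder with the arguments documented above. The `TwoWayTransformer` import path and hyperparameter values are assumptions mirroring the upstream segment-anything defaults, not part of this diff:

    import torch
    from torch import nn
    from segment_anything.modeling.mask_decoder import MaskDecoder
    from segment_anything.modeling.transformer import TwoWayTransformer

    # Build a decoder with the documented defaults (transformer_dim=256, 3 masks, etc.)
    decoder = MaskDecoder(
        transformer_dim=256,
        transformer=TwoWayTransformer(depth=2, embedding_dim=256, mlp_dim=2048, num_heads=8),
        num_multimask_outputs=3,
        activation=nn.GELU,
        iou_head_depth=3,
        iou_head_hidden_dim=256,
    )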
@@ -71,16 +66,15 @@ class MaskDecoder(nn.Module):
         Predict masks given image and prompt embeddings.
 
         Arguments:
-          image_embeddings (torch.Tensor): the embeddings from the image encoder
-          image_pe (torch.Tensor): positional encoding with the shape of image_embeddings
-          sparse_prompt_embeddings (torch.Tensor): the embeddings of the points and boxes
-          dense_prompt_embeddings (torch.Tensor): the embeddings of the mask inputs
-          multimask_output (bool): Whether to return multiple masks or a single
-            mask.
+            image_embeddings (torch.Tensor): the embeddings from the image encoder
+            image_pe (torch.Tensor): positional encoding with the shape of image_embeddings
+            sparse_prompt_embeddings (torch.Tensor): the embeddings of the points and boxes
+            dense_prompt_embeddings (torch.Tensor): the embeddings of the mask inputs
+            multimask_output (bool): Whether to return multiple masks or a single mask.
 
         Returns:
-          torch.Tensor: batched predicted masks
-          torch.Tensor: batched predictions of mask quality
+            torch.Tensor: batched predicted masks
+            torch.Tensor: batched predictions of mask quality
         """
         masks, iou_pred = self.predict_masks(
             image_embeddings=image_embeddings,
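A hedged usage sketch for the forward pass documented above, using the `decoder` built in the earlier snippet and dummy tensors; the 64x64 embedding grid and 256-channel width are assumptions matching SAM's defaults:

    b, c, h, w = 1, 256, 64, 64
    masks, iou_pred = decoder(
        image_embeddings=torch.randn(b, c, h, w),       # from the image encoder
        image_pe=torch.randn(1, c, h, w),               # positional encoding, same shape
        sparse_prompt_embeddings=torch.randn(b, 2, c),  # e.g. two point prompts
        dense_prompt_embeddings=torch.randn(b, c, h, w),
        multimask_output=True,                          # return all disambiguation masks
    )
    # masks: (b, num_multimask_outputs, 4*h, 4*w); iou_pred: (b, num_multimask_outputs)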
@@ -136,9 +130,11 @@ class MaskDecoder(nn.Module):
         return masks, iou_pred
 
 
-# Lightly adapted from
-# https://github.com/facebookresearch/MaskFormer/blob/main/mask_former/modeling/transformer/transformer_predictor.py # noqa
 class MLP(nn.Module):
+    """
+    Lightly adapted from
+    https://github.com/facebookresearch/MaskFormer/blob/main/mask_former/modeling/transformer/transformer_predictor.py
+    """
 
     def __init__(
         self,
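For reference, a self-contained sketch of the MaskFormer-style MLP that the class above adapts, reconstructed from the linked file; any deviations in names or type hints are this sketch's assumptions:

    import torch
    import torch.nn.functional as F
    from torch import nn

    class MLP(nn.Module):
        """Simple multi-layer perceptron: Linear layers with ReLU between them."""

        def __init__(self, input_dim: int, hidden_dim: int, output_dim: int, num_layers: int) -> None:
            super().__init__()
            self.num_layers = num_layers
            h = [hidden_dim] * (num_layers - 1)
            # Chain input_dim -> hidden_dim -> ... -> output_dim
            self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]))

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            for i, layer in enumerate(self.layers):
                x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
            return x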
@@ -249,7 +249,7 @@ def get_cdn_group(batch,
         attn_mask[max_nums * 2 * i:max_nums * 2 * (i + 1), max_nums * 2 * (i + 1):num_dn] = True
         attn_mask[max_nums * 2 * i:max_nums * 2 * (i + 1), :max_nums * 2 * i] = True
     dn_meta = {
-        'dn_pos_idx': [p.reshape(-1) for p in pos_idx.cpu().split([n for n in gt_groups], dim=1)],
+        'dn_pos_idx': [p.reshape(-1) for p in pos_idx.cpu().split(list(gt_groups), dim=1)],
         'dn_num_group': num_group,
         'dn_num_split': [num_dn, num_queries]}
 
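The one-line change above replaces `[n for n in gt_groups]` with the equivalent but clearer `list(gt_groups)`; `Tensor.split` accepts a list of chunk sizes along the given dim. A small self-contained illustration with made-up values:

    import torch

    pos_idx = torch.arange(12).reshape(2, 6)     # hypothetical (num_group, total_gt) index tensor
    gt_groups = [2, 1, 3]                        # hypothetical per-image ground-truth counts, summing to 6
    parts = pos_idx.split(list(gt_groups), dim=1)    # one chunk per image: widths 2, 1, 3
    dn_pos_idx = [p.reshape(-1) for p in parts]      # flatten each chunk, as in dn_meta above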
@@ -258,5 +258,6 @@ def get_cdn_group(batch,
 
 
 def inverse_sigmoid(x, eps=1e-6):
+    """Inverse sigmoid function."""
     x = x.clip(min=0., max=1.)
     return torch.log(x / (1 - x + eps) + eps)
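A quick sanity check of the newly documented `inverse_sigmoid`: it is the logit function log(x / (1 - x)), with `eps` keeping the log finite at the endpoints 0 and 1. The function is copied from the hunk above so the snippet is self-contained:

    import torch

    def inverse_sigmoid(x, eps=1e-6):
        """Inverse sigmoid function."""
        x = x.clip(min=0., max=1.)
        return torch.log(x / (1 - x + eps) + eps)

    t = torch.tensor([-4.0, 0.0, 4.0])
    print(inverse_sigmoid(torch.sigmoid(t)))          # ~[-4., 0., 4.], up to eps-sized error
    print(inverse_sigmoid(torch.tensor([0.0, 1.0])))  # finite values instead of -inf/+inf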