# Ultralytics YOLO 🚀, AGPL-3.0 license

from typing import Any, Optional, Tuple, Type

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from ultralytics.nn.modules import LayerNorm2d, MLPBlock


# This class and its supporting functions below lightly adapted from the ViTDet backbone available at: https://github.com/facebookresearch/detectron2/blob/main/detectron2/modeling/backbone/vit.py # noqa
class ImageEncoderViT(nn.Module):
    """
    Vision Transformer (ViT) image encoder: embeds an image into patches, runs a stack of transformer blocks
    (optionally with windowed attention), and projects the result to a dense feature map through a convolutional neck.
    """

    def __init__(
        self,
        img_size: int = 1024,
        patch_size: int = 16,
        in_chans: int = 3,
        embed_dim: int = 768,
        depth: int = 12,
        num_heads: int = 12,
        mlp_ratio: float = 4.0,
        out_chans: int = 256,
        qkv_bias: bool = True,
        norm_layer: Type[nn.Module] = nn.LayerNorm,
        act_layer: Type[nn.Module] = nn.GELU,
        use_abs_pos: bool = True,
        use_rel_pos: bool = False,
        rel_pos_zero_init: bool = True,
        window_size: int = 0,
        global_attn_indexes: Tuple[int, ...] = (),
    ) -> None:
"""
Args :
img_size ( int ) : Input image size .
patch_size ( int ) : Patch size .
in_chans ( int ) : Number of input image channels .
embed_dim ( int ) : Patch embedding dimension .
depth ( int ) : Depth of ViT .
num_heads ( int ) : Number of attention heads in each ViT block .
mlp_ratio ( float ) : Ratio of mlp hidden dim to embedding dim .
qkv_bias ( bool ) : If True , add a learnable bias to query , key , value .
norm_layer ( nn . Module ) : Normalization layer .
act_layer ( nn . Module ) : Activation layer .
use_abs_pos ( bool ) : If True , use absolute positional embeddings .
use_rel_pos ( bool ) : If True , add relative positional embeddings to the attention map .
rel_pos_zero_init ( bool ) : If True , zero initialize relative positional parameters .
window_size ( int ) : Window size for window attention blocks .
global_attn_indexes ( list ) : Indexes for blocks using global attention .
"""
        super().__init__()
        self.img_size = img_size

        self.patch_embed = PatchEmbed(
            kernel_size=(patch_size, patch_size),
            stride=(patch_size, patch_size),
            in_chans=in_chans,
            embed_dim=embed_dim,
        )

        self.pos_embed: Optional[nn.Parameter] = None
        if use_abs_pos:
            # Initialize absolute positional embedding with pretrain image size.
            self.pos_embed = nn.Parameter(torch.zeros(1, img_size // patch_size, img_size // patch_size, embed_dim))

        self.blocks = nn.ModuleList()
        for i in range(depth):
            block = Block(
                dim=embed_dim,
                num_heads=num_heads,
                mlp_ratio=mlp_ratio,
                qkv_bias=qkv_bias,
                norm_layer=norm_layer,
                act_layer=act_layer,
                use_rel_pos=use_rel_pos,
                rel_pos_zero_init=rel_pos_zero_init,
                window_size=window_size if i not in global_attn_indexes else 0,
                input_size=(img_size // patch_size, img_size // patch_size),
            )
            self.blocks.append(block)

        self.neck = nn.Sequential(
            nn.Conv2d(
                embed_dim,
                out_chans,
                kernel_size=1,
                bias=False,
            ),
            LayerNorm2d(out_chans),
            nn.Conv2d(
                out_chans,
                out_chans,
                kernel_size=3,
                padding=1,
                bias=False,
            ),
            LayerNorm2d(out_chans),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Embeds the image into patches, applies the transformer blocks, and projects through the neck."""
        x = self.patch_embed(x)
        if self.pos_embed is not None:
            x = x + self.pos_embed
        for blk in self.blocks:
            x = blk(x)
        x = self.neck(x.permute(0, 3, 1, 2))
        return x
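
# Illustrative usage sketch (kept as a comment so nothing runs at import time). The settings below are
# assumptions chosen to mirror a typical ViT-B style configuration, not values taken from this file:
#
#     encoder = ImageEncoderViT(img_size=1024, patch_size=16, embed_dim=768, depth=12, num_heads=12,
#                               window_size=14, global_attn_indexes=(2, 5, 8, 11))
#     feats = encoder(torch.zeros(1, 3, 1024, 1024))
#     # feats.shape == (1, 256, 64, 64): out_chans x (img_size // patch_size) x (img_size // patch_size)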


class PromptEncoder(nn.Module):

    def __init__(
        self,
        embed_dim: int,
        image_embedding_size: Tuple[int, int],
        input_image_size: Tuple[int, int],
        mask_in_chans: int,
        activation: Type[nn.Module] = nn.GELU,
    ) -> None:
        """
        Encodes prompts for input to SAM's mask decoder.

        Arguments:
          embed_dim (int): The prompts' embedding dimension
          image_embedding_size (tuple(int, int)): The spatial size of the
            image embedding, as (H, W).
          input_image_size (tuple(int, int)): The padded size of the image as input
            to the image encoder, as (H, W).
          mask_in_chans (int): The number of hidden channels used for
            encoding input masks.
          activation (nn.Module): The activation to use when encoding
            input masks.
        """
        super().__init__()
        self.embed_dim = embed_dim
        self.input_image_size = input_image_size
        self.image_embedding_size = image_embedding_size
        self.pe_layer = PositionEmbeddingRandom(embed_dim // 2)

        self.num_point_embeddings: int = 4  # pos/neg point + 2 box corners
        point_embeddings = [nn.Embedding(1, embed_dim) for _ in range(self.num_point_embeddings)]
        self.point_embeddings = nn.ModuleList(point_embeddings)
        self.not_a_point_embed = nn.Embedding(1, embed_dim)

        self.mask_input_size = (4 * image_embedding_size[0], 4 * image_embedding_size[1])
        self.mask_downscaling = nn.Sequential(
            nn.Conv2d(1, mask_in_chans // 4, kernel_size=2, stride=2),
            LayerNorm2d(mask_in_chans // 4),
            activation(),
            nn.Conv2d(mask_in_chans // 4, mask_in_chans, kernel_size=2, stride=2),
            LayerNorm2d(mask_in_chans),
            activation(),
            nn.Conv2d(mask_in_chans, embed_dim, kernel_size=1),
        )
        self.no_mask_embed = nn.Embedding(1, embed_dim)

    def get_dense_pe(self) -> torch.Tensor:
        """
        Returns the positional encoding used to encode point prompts,
        applied to a dense set of points the shape of the image encoding.

        Returns:
          torch.Tensor: Positional encoding with shape
            1x(embed_dim)x(embedding_h)x(embedding_w)
        """
        return self.pe_layer(self.image_embedding_size).unsqueeze(0)

    def _embed_points(
        self,
        points: torch.Tensor,
        labels: torch.Tensor,
        pad: bool,
    ) -> torch.Tensor:
        """Embeds point prompts."""
        points = points + 0.5  # Shift to center of pixel
        if pad:
            padding_point = torch.zeros((points.shape[0], 1, 2), device=points.device)
            padding_label = -torch.ones((labels.shape[0], 1), device=labels.device)
            points = torch.cat([points, padding_point], dim=1)
            labels = torch.cat([labels, padding_label], dim=1)
        point_embedding = self.pe_layer.forward_with_coords(points, self.input_image_size)
        point_embedding[labels == -1] = 0.0
        point_embedding[labels == -1] += self.not_a_point_embed.weight
        point_embedding[labels == 0] += self.point_embeddings[0].weight
        point_embedding[labels == 1] += self.point_embeddings[1].weight
        return point_embedding

    def _embed_boxes(self, boxes: torch.Tensor) -> torch.Tensor:
        """Embeds box prompts."""
        boxes = boxes + 0.5  # Shift to center of pixel
        coords = boxes.reshape(-1, 2, 2)
        corner_embedding = self.pe_layer.forward_with_coords(coords, self.input_image_size)
        corner_embedding[:, 0, :] += self.point_embeddings[2].weight
        corner_embedding[:, 1, :] += self.point_embeddings[3].weight
        return corner_embedding

    def _embed_masks(self, masks: torch.Tensor) -> torch.Tensor:
        """Embeds mask inputs."""
        return self.mask_downscaling(masks)

    def _get_batch_size(
        self,
        points: Optional[Tuple[torch.Tensor, torch.Tensor]],
        boxes: Optional[torch.Tensor],
        masks: Optional[torch.Tensor],
    ) -> int:
        """Gets the batch size of the output given the batch size of the input prompts."""
        if points is not None:
            return points[0].shape[0]
        elif boxes is not None:
            return boxes.shape[0]
        elif masks is not None:
            return masks.shape[0]
        else:
            return 1

    def _get_device(self) -> torch.device:
        """Returns the device on which the prompt embedding weights are stored."""
        return self.point_embeddings[0].weight.device

    def forward(
        self,
        points: Optional[Tuple[torch.Tensor, torch.Tensor]],
        boxes: Optional[torch.Tensor],
        masks: Optional[torch.Tensor],
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Embeds different types of prompts, returning both sparse and dense
        embeddings.

        Arguments:
          points (tuple(torch.Tensor, torch.Tensor), None): point coordinates
            and labels to embed.
          boxes (torch.Tensor, None): boxes to embed
          masks (torch.Tensor, None): masks to embed

        Returns:
          torch.Tensor: sparse embeddings for the points and boxes, with shape
            BxNx(embed_dim), where N is determined by the number of input points
            and boxes.
          torch.Tensor: dense embeddings for the masks, in the shape
            Bx(embed_dim)x(embed_H)x(embed_W)
        """
        bs = self._get_batch_size(points, boxes, masks)
        sparse_embeddings = torch.empty((bs, 0, self.embed_dim), device=self._get_device())
        if points is not None:
            coords, labels = points
            point_embeddings = self._embed_points(coords, labels, pad=(boxes is None))
            sparse_embeddings = torch.cat([sparse_embeddings, point_embeddings], dim=1)
        if boxes is not None:
            box_embeddings = self._embed_boxes(boxes)
            sparse_embeddings = torch.cat([sparse_embeddings, box_embeddings], dim=1)

        if masks is not None:
            dense_embeddings = self._embed_masks(masks)
        else:
            dense_embeddings = self.no_mask_embed.weight.reshape(1, -1, 1, 1).expand(
                bs, -1, self.image_embedding_size[0], self.image_embedding_size[1])

        return sparse_embeddings, dense_embeddings
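
# Illustrative usage sketch (comment only, not executed). The sizes are assumptions matching typical
# SAM-style defaults, not values defined in this file:
#
#     pe = PromptEncoder(embed_dim=256, image_embedding_size=(64, 64), input_image_size=(1024, 1024), mask_in_chans=16)
#     coords = torch.tensor([[[500.0, 375.0]]])   # (B=1, N=1, 2) pixel coordinates
#     labels = torch.tensor([[1]])                # 1 = foreground click, 0 = background click
#     sparse, dense = pe(points=(coords, labels), boxes=None, masks=None)
#     # sparse.shape == (1, 2, 256)  (the point plus one padding token added when no box is given)
#     # dense.shape  == (1, 256, 64, 64)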


class PositionEmbeddingRandom(nn.Module):
    """
    Positional encoding using random spatial frequencies.
    """

    def __init__(self, num_pos_feats: int = 64, scale: Optional[float] = None) -> None:
        """Builds a 2 x num_pos_feats matrix of random Gaussian spatial frequencies, scaled by `scale`."""
        super().__init__()
        if scale is None or scale <= 0.0:
            scale = 1.0
        self.register_buffer(
            'positional_encoding_gaussian_matrix',
            scale * torch.randn((2, num_pos_feats)),
        )

    def _pe_encoding(self, coords: torch.Tensor) -> torch.Tensor:
        """Positionally encode points that are normalized to [0,1]."""
        # assuming coords are in [0, 1]^2 square and have d_1 x ... x d_n x 2 shape
        coords = 2 * coords - 1
        coords = coords @ self.positional_encoding_gaussian_matrix
        coords = 2 * np.pi * coords
        # outputs d_1 x ... x d_n x C shape
        return torch.cat([torch.sin(coords), torch.cos(coords)], dim=-1)

    def forward(self, size: Tuple[int, int]) -> torch.Tensor:
        """Generate positional encoding for a grid of the specified size."""
        h, w = size
        device: Any = self.positional_encoding_gaussian_matrix.device
        grid = torch.ones((h, w), device=device, dtype=torch.float32)
        y_embed = grid.cumsum(dim=0) - 0.5
        x_embed = grid.cumsum(dim=1) - 0.5
        y_embed = y_embed / h
        x_embed = x_embed / w

        pe = self._pe_encoding(torch.stack([x_embed, y_embed], dim=-1))
        return pe.permute(2, 0, 1)  # C x H x W

    def forward_with_coords(self, coords_input: torch.Tensor, image_size: Tuple[int, int]) -> torch.Tensor:
        """Positionally encode points that are not normalized to [0,1]."""
        coords = coords_input.clone()
        coords[:, :, 0] = coords[:, :, 0] / image_size[1]
        coords[:, :, 1] = coords[:, :, 1] / image_size[0]
        return self._pe_encoding(coords.to(torch.float))  # B x N x C
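
# Illustrative sketch (comment only): with num_pos_feats=128 the module produces a 256-channel encoding,
# since sine and cosine components are concatenated. The sizes below are assumptions for illustration:
#
#     pe_layer = PositionEmbeddingRandom(num_pos_feats=128)
#     dense_pe = pe_layer((64, 64))                             # (256, 64, 64), one encoding per grid cell
#     pts = torch.rand(1, 3, 2) * 1024                          # pixel coordinates in a 1024x1024 image
#     pt_pe = pe_layer.forward_with_coords(pts, (1024, 1024))   # (1, 3, 256)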


class Block(nn.Module):
    """Transformer blocks with support for window attention and residual propagation blocks."""

    def __init__(
        self,
        dim: int,
        num_heads: int,
        mlp_ratio: float = 4.0,
        qkv_bias: bool = True,
        norm_layer: Type[nn.Module] = nn.LayerNorm,
        act_layer: Type[nn.Module] = nn.GELU,
        use_rel_pos: bool = False,
        rel_pos_zero_init: bool = True,
        window_size: int = 0,
        input_size: Optional[Tuple[int, int]] = None,
    ) -> None:
        """
        Args:
            dim (int): Number of input channels.
            num_heads (int): Number of attention heads in each ViT block.
            mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
            qkv_bias (bool): If True, add a learnable bias to query, key, value.
            norm_layer (nn.Module): Normalization layer.
            act_layer (nn.Module): Activation layer.
            use_rel_pos (bool): If True, add relative positional embeddings to the attention map.
            rel_pos_zero_init (bool): If True, zero initialize relative positional parameters.
            window_size (int): Window size for window attention blocks. If it equals 0, then
                use global attention.
            input_size (tuple(int, int), None): Input resolution for calculating the relative
                positional parameter size.
        """
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.attn = Attention(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            use_rel_pos=use_rel_pos,
            rel_pos_zero_init=rel_pos_zero_init,
            input_size=input_size if window_size == 0 else (window_size, window_size),
        )
        self.norm2 = norm_layer(dim)
        self.mlp = MLPBlock(embedding_dim=dim, mlp_dim=int(dim * mlp_ratio), act=act_layer)

        self.window_size = window_size

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Applies windowed (or global) self-attention and an MLP, each with a residual connection."""
        shortcut = x
        x = self.norm1(x)
        # Window partition
        if self.window_size > 0:
            H, W = x.shape[1], x.shape[2]
            x, pad_hw = window_partition(x, self.window_size)

        x = self.attn(x)
        # Reverse window partition
        if self.window_size > 0:
            x = window_unpartition(x, self.window_size, pad_hw, (H, W))

        x = shortcut + x
        x = x + self.mlp(self.norm2(x))

        return x


class Attention(nn.Module):
    """Multi-head Attention block with relative position embeddings."""

    def __init__(
        self,
        dim: int,
        num_heads: int = 8,
        qkv_bias: bool = True,
        use_rel_pos: bool = False,
        rel_pos_zero_init: bool = True,
        input_size: Optional[Tuple[int, int]] = None,
    ) -> None:
        """
        Args:
            dim (int): Number of input channels.
            num_heads (int): Number of attention heads.
            qkv_bias (bool): If True, add a learnable bias to query, key, value.
            use_rel_pos (bool): If True, add relative positional embeddings to the attention map.
            rel_pos_zero_init (bool): If True, zero initialize relative positional parameters.
            input_size (tuple(int, int), None): Input resolution for calculating the relative
                positional parameter size.
        """
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = head_dim ** -0.5

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.proj = nn.Linear(dim, dim)

        self.use_rel_pos = use_rel_pos
        if self.use_rel_pos:
            assert (input_size is not None), 'Input size must be provided if using relative positional encoding.'
            # initialize relative positional embeddings
            self.rel_pos_h = nn.Parameter(torch.zeros(2 * input_size[0] - 1, head_dim))
            self.rel_pos_w = nn.Parameter(torch.zeros(2 * input_size[1] - 1, head_dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Applies multi-head self-attention, optionally adding decomposed relative positional embeddings."""
        B, H, W, _ = x.shape
        # qkv with shape (3, B, nHead, H * W, C)
        qkv = self.qkv(x).reshape(B, H * W, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
        # q, k, v with shape (B * nHead, H * W, C)
        q, k, v = qkv.reshape(3, B * self.num_heads, H * W, -1).unbind(0)

        attn = (q * self.scale) @ k.transpose(-2, -1)

        if self.use_rel_pos:
            attn = add_decomposed_rel_pos(attn, q, self.rel_pos_h, self.rel_pos_w, (H, W), (H, W))

        attn = attn.softmax(dim=-1)
        x = (attn @ v).view(B, self.num_heads, H, W, -1).permute(0, 2, 3, 1, 4).reshape(B, H, W, -1)
        x = self.proj(x)

        return x
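
# Illustrative shape walk-through (comment only), assuming dim=768, num_heads=12 and 14x14 windowed inputs:
#
#     attn = Attention(dim=768, num_heads=12, use_rel_pos=True, input_size=(14, 14))
#     y = attn(torch.zeros(25, 14, 14, 768))   # windows from window_partition keep their (B, H, W, C) layout
#     # y.shape == (25, 14, 14, 768); internally q, k, v each have shape (25 * 12, 196, 64)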


def window_partition(x: torch.Tensor, window_size: int) -> Tuple[torch.Tensor, Tuple[int, int]]:
    """
    Partition into non-overlapping windows with padding if needed.

    Args:
        x (tensor): input tokens with [B, H, W, C].
        window_size (int): window size.

    Returns:
        windows: windows after partition with [B * num_windows, window_size, window_size, C].
        (Hp, Wp): padded height and width before partition
    """
    B, H, W, C = x.shape

    pad_h = (window_size - H % window_size) % window_size
    pad_w = (window_size - W % window_size) % window_size
    if pad_h > 0 or pad_w > 0:
        x = F.pad(x, (0, 0, 0, pad_w, 0, pad_h))
    Hp, Wp = H + pad_h, W + pad_w

    x = x.view(B, Hp // window_size, window_size, Wp // window_size, window_size, C)
    windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
    return windows, (Hp, Wp)


def window_unpartition(windows: torch.Tensor, window_size: int, pad_hw: Tuple[int, int],
                       hw: Tuple[int, int]) -> torch.Tensor:
    """
    Window unpartition into original sequences and removing padding.

    Args:
        windows (tensor): input tokens with [B * num_windows, window_size, window_size, C].
        window_size (int): window size.
        pad_hw (Tuple): padded height and width (Hp, Wp).
        hw (Tuple): original height and width (H, W) before padding.

    Returns:
        x: unpartitioned sequences with [B, H, W, C].
    """
    Hp, Wp = pad_hw
    H, W = hw
    B = windows.shape[0] // (Hp * Wp // window_size // window_size)
    x = windows.view(B, Hp // window_size, Wp // window_size, window_size, window_size, -1)
    x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, Hp, Wp, -1)

    if Hp > H or Wp > W:
        x = x[:, :H, :W, :].contiguous()
    return x
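
# Illustrative round trip (comment only), using an assumed 64x64 feature map and window_size=14:
#
#     x = torch.zeros(1, 64, 64, 768)
#     windows, pad_hw = window_partition(x, 14)   # pads 64 -> 70, so windows.shape == (25, 14, 14, 768)
#     y = window_unpartition(windows, 14, pad_hw, (64, 64))
#     # y.shape == (1, 64, 64, 768), padding stripped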


def get_rel_pos(q_size: int, k_size: int, rel_pos: torch.Tensor) -> torch.Tensor:
    """
    Get relative positional embeddings according to the relative positions of
    query and key sizes.

    Args:
        q_size (int): size of query q.
        k_size (int): size of key k.
        rel_pos (Tensor): relative position embeddings (L, C).

    Returns:
        Extracted positional embeddings according to relative positions.
    """
    max_rel_dist = int(2 * max(q_size, k_size) - 1)
    # Interpolate rel pos if needed.
    if rel_pos.shape[0] != max_rel_dist:
        # Interpolate rel pos.
        rel_pos_resized = F.interpolate(
            rel_pos.reshape(1, rel_pos.shape[0], -1).permute(0, 2, 1),
            size=max_rel_dist,
            mode='linear',
        )
        rel_pos_resized = rel_pos_resized.reshape(-1, max_rel_dist).permute(1, 0)
    else:
        rel_pos_resized = rel_pos

    # Scale the coords with short length if shapes for q and k are different.
    q_coords = torch.arange(q_size)[:, None] * max(k_size / q_size, 1.0)
    k_coords = torch.arange(k_size)[None, :] * max(q_size / k_size, 1.0)
    relative_coords = (q_coords - k_coords) + (k_size - 1) * max(q_size / k_size, 1.0)

    return rel_pos_resized[relative_coords.long()]


def add_decomposed_rel_pos(
    attn: torch.Tensor,
    q: torch.Tensor,
    rel_pos_h: torch.Tensor,
    rel_pos_w: torch.Tensor,
    q_size: Tuple[int, int],
    k_size: Tuple[int, int],
) -> torch.Tensor:
    """
    Calculate decomposed Relative Positional Embeddings from :paper:`mvitv2`.
    https://github.com/facebookresearch/mvit/blob/19786631e330df9f3622e5402b4a419a263a2c80/mvit/models/attention.py   # noqa B950

    Args:
        attn (Tensor): attention map.
        q (Tensor): query q in the attention layer with shape (B, q_h * q_w, C).
        rel_pos_h (Tensor): relative position embeddings (Lh, C) for height axis.
        rel_pos_w (Tensor): relative position embeddings (Lw, C) for width axis.
        q_size (Tuple): spatial sequence size of query q with (q_h, q_w).
        k_size (Tuple): spatial sequence size of key k with (k_h, k_w).

    Returns:
        attn (Tensor): attention map with added relative positional embeddings.
    """
    q_h, q_w = q_size
    k_h, k_w = k_size
    Rh = get_rel_pos(q_h, k_h, rel_pos_h)
    Rw = get_rel_pos(q_w, k_w, rel_pos_w)

    B, _, dim = q.shape
    r_q = q.reshape(B, q_h, q_w, dim)
    rel_h = torch.einsum('bhwc,hkc->bhwk', r_q, Rh)
    rel_w = torch.einsum('bhwc,wkc->bhwk', r_q, Rw)

    attn = (attn.view(B, q_h, q_w, k_h, k_w) + rel_h[:, :, :, :, None] + rel_w[:, :, :, None, :]).view(
        B, q_h * q_w, k_h * k_w)

    return attn
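
# Illustrative shape check (comment only), with assumed 14x14 query/key grids and head_dim=64:
#
#     attn = torch.zeros(300, 196, 196)    # (B * nHead, q_h * q_w, k_h * k_w)
#     q = torch.zeros(300, 196, 64)
#     rel_pos_h = torch.zeros(27, 64)      # 2 * 14 - 1 relative offsets along the height axis
#     rel_pos_w = torch.zeros(27, 64)
#     out = add_decomposed_rel_pos(attn, q, rel_pos_h, rel_pos_w, (14, 14), (14, 14))
#     # out.shape == (300, 196, 196); height and width biases are computed and added separately (decomposed).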


class PatchEmbed(nn.Module):
    """
    Image to Patch Embedding.
    """

    def __init__(
        self,
        kernel_size: Tuple[int, int] = (16, 16),
        stride: Tuple[int, int] = (16, 16),
        padding: Tuple[int, int] = (0, 0),
        in_chans: int = 3,
        embed_dim: int = 768,
    ) -> None:
        """
        Args:
            kernel_size (Tuple): kernel size of the projection layer.
            stride (Tuple): stride of the projection layer.
            padding (Tuple): padding size of the projection layer.
            in_chans (int): Number of input image channels.
            embed_dim (int): Patch embedding dimension.
        """
        super().__init__()

        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=kernel_size, stride=stride, padding=padding)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Projects the image into patch embeddings and moves channels to the last dimension."""
        x = self.proj(x)
        # B C H W -> B H W C
        x = x.permute(0, 2, 3, 1)
        return x
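
# Illustrative shape check (comment only), assuming the default 16x16 patches on a 1024x1024 RGB image:
#
#     patch_embed = PatchEmbed(kernel_size=(16, 16), stride=(16, 16), in_chans=3, embed_dim=768)
#     tokens = patch_embed(torch.zeros(1, 3, 1024, 1024))
#     # tokens.shape == (1, 64, 64, 768): a 64x64 grid of 768-dim patch embeddings in (B, H, W, C) layout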