@ -26,7 +26,7 @@ class FastSAMPrompt:
import clip # for linear_assignment
import clip # for linear_assignment
except ImportError :
except ImportError :
from ultralytics . utils . checks import check_requirements
from ultralytics . utils . checks import check_requirements
check_requirements ( ' git+https://github.com/openai/CLIP.git ' ) # required before installing lap from source
check_requirements ( ' git+https://github.com/openai/CLIP.git ' )
import clip
import clip
self . clip = clip
self . clip = clip
@ -91,8 +91,6 @@ class FastSAMPrompt:
y1 = min ( y1 , y_t )
y1 = min ( y1 , y_t )
x2 = max ( x2 , x_t + w_t )
x2 = max ( x2 , x_t + w_t )
y2 = max ( y2 , y_t + h_t )
y2 = max ( y2 , y_t + h_t )
h = y2 - y1
w = x2 - x1
return [ x1 , y1 , x2 , y2 ]
return [ x1 , y1 , x2 , y2 ]
def plot ( self ,
def plot ( self ,
@ -104,9 +102,11 @@ class FastSAMPrompt:
mask_random_color = True ,
mask_random_color = True ,
better_quality = True ,
better_quality = True ,
retina = False ,
retina = False ,
with Contou rs= True ) :
with _countoue rs= True ) :
if isinstance ( annotations [ 0 ] , dict ) :
if isinstance ( annotations [ 0 ] , dict ) :
annotations = [ annotation [ ' segmentation ' ] for annotation in annotations ]
annotations = [ annotation [ ' segmentation ' ] for annotation in annotations ]
if isinstance ( annotations , torch . Tensor ) :
annotations = annotations . cpu ( ) . numpy ( )
result_name = os . path . basename ( self . img_path )
result_name = os . path . basename ( self . img_path )
image = self . ori_img
image = self . ori_img
image = cv2 . cvtColor ( image , cv2 . COLOR_BGR2RGB )
image = cv2 . cvtColor ( image , cv2 . COLOR_BGR2RGB )
@ -123,41 +123,22 @@ class FastSAMPrompt:
plt . imshow ( image )
plt . imshow ( image )
if better_quality :
if better_quality :
if isinstance ( annotations [ 0 ] , torch . Tensor ) :
annotations = np . array ( annotations . cpu ( ) )
for i , mask in enumerate ( annotations ) :
for i , mask in enumerate ( annotations ) :
mask = cv2 . morphologyEx ( mask . astype ( np . uint8 ) , cv2 . MORPH_CLOSE , np . ones ( ( 3 , 3 ) , np . uint8 ) )
mask = cv2 . morphologyEx ( mask . astype ( np . uint8 ) , cv2 . MORPH_CLOSE , np . ones ( ( 3 , 3 ) , np . uint8 ) )
annotations [ i ] = cv2 . morphologyEx ( mask . astype ( np . uint8 ) , cv2 . MORPH_OPEN , np . ones ( ( 8 , 8 ) , np . uint8 ) )
annotations [ i ] = cv2 . morphologyEx ( mask . astype ( np . uint8 ) , cv2 . MORPH_OPEN , np . ones ( ( 8 , 8 ) , np . uint8 ) )
if self . device == ' cpu ' :
self . fast_show_mask (
annotations = np . array ( annotations )
annotations ,
self . fast_show_mask (
plt . gca ( ) ,
annotations ,
random_color = mask_random_color ,
plt . gca ( ) ,
bbox = bbox ,
random_color = mask_random_color ,
points = points ,
bbox = bbox ,
pointlabel = point_label ,
points = points ,
retinamask = retina ,
pointlabel = point_label ,
target_height = original_h ,
retinamask = retina ,
target_width = original_w ,
target_height = original_h ,
)
target_width = original_w ,
)
if with_countouers :
else :
if isinstance ( annotations [ 0 ] , np . ndarray ) :
annotations = torch . from_numpy ( annotations )
self . fast_show_mask_gpu (
annotations ,
plt . gca ( ) ,
random_color = mask_random_color ,
bbox = bbox ,
points = points ,
pointlabel = point_label ,
retinamask = retina ,
target_height = original_h ,
target_width = original_w ,
)
if isinstance ( annotations , torch . Tensor ) :
annotations = annotations . cpu ( ) . numpy ( )
if withContours :
contour_all = [ ]
contour_all = [ ]
temp = np . zeros ( ( original_h , original_w , 1 ) )
temp = np . zeros ( ( original_h , original_w , 1 ) )
for i , mask in enumerate ( annotations ) :
for i , mask in enumerate ( annotations ) :
@ -184,8 +165,8 @@ class FastSAMPrompt:
LOGGER . info ( f ' Saved to { save_path . absolute ( ) } ' )
LOGGER . info ( f ' Saved to { save_path . absolute ( ) } ' )
# CPU post process
# CPU post process
@staticmethod
def fast_show_mask (
def fast_show_mask (
self ,
annotation ,
annotation ,
ax ,
ax ,
random_color = False ,
random_color = False ,
@ -196,32 +177,29 @@ class FastSAMPrompt:
target_height = 960 ,
target_height = 960 ,
target_width = 960 ,
target_width = 960 ,
) :
) :
msak_sum = annotation . shape [ 0 ]
n , h , w = annotation . shape # batch, height, width
height = annotation . shape [ 1 ]
weight = annotation . shape [ 2 ]
# 将annotation 按照面积 排序
areas = np . sum ( annotation , axis = ( 1 , 2 ) )
areas = np . sum ( annotation , axis = ( 1 , 2 ) )
sorted_indices = np . argsort ( areas )
annotation = annotation [ np . argsort ( areas ) ]
annotation = annotation [ sorted_indices ]
index = ( annotation != 0 ) . argmax ( axis = 0 )
index = ( annotation != 0 ) . argmax ( axis = 0 )
if random_color :
if random_color :
color = np . random . random ( ( msak_sum , 1 , 1 , 3 ) )
color = np . random . random ( ( n , 1 , 1 , 3 ) )
else :
else :
color = np . ones ( ( msak_sum , 1 , 1 , 3 ) ) * np . array ( [ 30 / 255 , 144 / 255 , 1.0 ] )
color = np . ones ( ( n , 1 , 1 , 3 ) ) * np . array ( [ 30 / 255 , 144 / 255 , 1.0 ] )
transparency = np . ones ( ( msak_sum , 1 , 1 , 1 ) ) * 0.6
transparency = np . ones ( ( n , 1 , 1 , 1 ) ) * 0.6
visual = np . concatenate ( [ color , transparency ] , axis = - 1 )
visual = np . concatenate ( [ color , transparency ] , axis = - 1 )
mask_image = np . expand_dims ( annotation , - 1 ) * visual
mask_image = np . expand_dims ( annotation , - 1 ) * visual
show = np . zeros ( ( h eight , w eight , 4 ) )
show = np . zeros ( ( h , w , 4 ) )
h_indices , w_indices = np . meshgrid ( np . arange ( h eight ) , np . arange ( w eight ) , indexing = ' ij ' )
h_indices , w_indices = np . meshgrid ( np . arange ( h ) , np . arange ( w ) , indexing = ' ij ' )
indices = ( index [ h_indices , w_indices ] , h_indices , w_indices , slice ( None ) )
indices = ( index [ h_indices , w_indices ] , h_indices , w_indices , slice ( None ) )
# 使用向量化索引更新show的值
show [ h_indices , w_indices , : ] = mask_image [ indices ]
show [ h_indices , w_indices , : ] = mask_image [ indices ]
if bbox is not None :
if bbox is not None :
x1 , y1 , x2 , y2 = bbox
x1 , y1 , x2 , y2 = bbox
ax . add_patch ( plt . Rectangle ( ( x1 , y1 ) , x2 - x1 , y2 - y1 , fill = False , edgecolor = ' b ' , linewidth = 1 ) )
ax . add_patch ( plt . Rectangle ( ( x1 , y1 ) , x2 - x1 , y2 - y1 , fill = False , edgecolor = ' b ' , linewidth = 1 ) )
# d raw point
# D raw point
if points is not None :
if points is not None :
plt . scatter (
plt . scatter (
[ point [ 0 ] for i , point in enumerate ( points ) if pointlabel [ i ] == 1 ] ,
[ point [ 0 ] for i , point in enumerate ( points ) if pointlabel [ i ] == 1 ] ,
@ -240,63 +218,6 @@ class FastSAMPrompt:
show = cv2 . resize ( show , ( target_width , target_height ) , interpolation = cv2 . INTER_NEAREST )
show = cv2 . resize ( show , ( target_width , target_height ) , interpolation = cv2 . INTER_NEAREST )
ax . imshow ( show )
ax . imshow ( show )
def fast_show_mask_gpu (
self ,
annotation ,
ax ,
random_color = False ,
bbox = None ,
points = None ,
pointlabel = None ,
retinamask = True ,
target_height = 960 ,
target_width = 960 ,
) :
msak_sum = annotation . shape [ 0 ]
height = annotation . shape [ 1 ]
weight = annotation . shape [ 2 ]
areas = torch . sum ( annotation , dim = ( 1 , 2 ) )
sorted_indices = torch . argsort ( areas , descending = False )
annotation = annotation [ sorted_indices ]
# 找每个位置第一个非零值下标
index = ( annotation != 0 ) . to ( torch . long ) . argmax ( dim = 0 )
if random_color :
color = torch . rand ( ( msak_sum , 1 , 1 , 3 ) ) . to ( annotation . device )
else :
color = torch . ones ( ( msak_sum , 1 , 1 , 3 ) ) . to ( annotation . device ) * torch . tensor ( [ 30 / 255 , 144 / 255 , 1.0 ] ) . to (
annotation . device )
transparency = torch . ones ( ( msak_sum , 1 , 1 , 1 ) ) . to ( annotation . device ) * 0.6
visual = torch . cat ( [ color , transparency ] , dim = - 1 )
mask_image = torch . unsqueeze ( annotation , - 1 ) * visual
# 按index取数, index指每个位置选哪个batch的数, 把mask_image转成一个batch的形式
show = torch . zeros ( ( height , weight , 4 ) ) . to ( annotation . device )
h_indices , w_indices = torch . meshgrid ( torch . arange ( height ) , torch . arange ( weight ) , indexing = ' ij ' )
indices = ( index [ h_indices , w_indices ] , h_indices , w_indices , slice ( None ) )
# 使用向量化索引更新show的值
show [ h_indices , w_indices , : ] = mask_image [ indices ]
show_cpu = show . cpu ( ) . numpy ( )
if bbox is not None :
x1 , y1 , x2 , y2 = bbox
ax . add_patch ( plt . Rectangle ( ( x1 , y1 ) , x2 - x1 , y2 - y1 , fill = False , edgecolor = ' b ' , linewidth = 1 ) )
# draw point
if points is not None :
plt . scatter (
[ point [ 0 ] for i , point in enumerate ( points ) if pointlabel [ i ] == 1 ] ,
[ point [ 1 ] for i , point in enumerate ( points ) if pointlabel [ i ] == 1 ] ,
s = 20 ,
c = ' y ' ,
)
plt . scatter (
[ point [ 0 ] for i , point in enumerate ( points ) if pointlabel [ i ] == 0 ] ,
[ point [ 1 ] for i , point in enumerate ( points ) if pointlabel [ i ] == 0 ] ,
s = 20 ,
c = ' m ' ,
)
if not retinamask :
show_cpu = cv2 . resize ( show_cpu , ( target_width , target_height ) , interpolation = cv2 . INTER_NEAREST )
ax . imshow ( show_cpu )
# clip
@torch.no_grad ( )
@torch.no_grad ( )
def retrieve ( self , model , preprocess , elements , search_text : str , device ) - > int :
def retrieve ( self , model , preprocess , elements , search_text : str , device ) - > int :
preprocessed_images = [ preprocess ( image ) . to ( device ) for image in elements ]
preprocessed_images = [ preprocess ( image ) . to ( device ) for image in elements ]