@ -1,3 +1,5 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
import os
import cv2
@ -6,15 +8,6 @@ import numpy as np
import torch
from PIL import Image
try :
import clip # for linear_assignment
except ( ImportError , AssertionError , AttributeError ) :
from ultralytics . yolo . utils . checks import check_requirements
check_requirements ( ' git+https://github.com/openai/CLIP.git ' ) # required before installing lap from source
import clip
class FastSAMPrompt :
@ -25,7 +18,17 @@ class FastSAMPrompt:
self . img_path = img_path
self . ori_img = cv2 . imread ( img_path )
def _segment_image ( self , image , bbox ) :
# Import and assign clip
try :
import clip # for linear_assignment
except ImportError :
from ultralytics . yolo . utils . checks import check_requirements
check_requirements ( ' git+https://github.com/openai/CLIP.git ' ) # required before installing lap from source
import clip
self . clip = clip
@staticmethod
def _segment_image ( image , bbox ) :
image_array = np . array ( image )
segmented_image_array = np . zeros_like ( image_array )
x1 , y1 , x2 , y2 = bbox
@ -39,39 +42,40 @@ class FastSAMPrompt:
black_image . paste ( segmented_image , mask = transparency_mask_image )
return black_image
def _format_results ( self , result , filter = 0 ) :
@staticmethod
def _format_results ( result , filter = 0 ) :
annotations = [ ]
n = len ( result . masks . data )
for i in range ( n ) :
annotation = { }
mask = result . masks . data [ i ] == 1.0
if torch . sum ( mask ) < filter :
continue
annotation [ ' id ' ] = i
annotation [ ' segmentation ' ] = mask . cpu ( ) . numpy ( )
annotation [ ' bbox ' ] = result . boxes . data [ i ]
annotation [ ' score ' ] = result . boxes . conf [ i ]
annotation = {
' id ' : i ,
' segmentation ' : mask . cpu ( ) . numpy ( ) ,
' bbox ' : result . boxes . data [ i ] ,
' score ' : result . boxes . conf [ i ] }
annotation [ ' area ' ] = annotation [ ' segmentation ' ] . sum ( )
annotations . append ( annotation )
return annotations
def filter_masks ( annotations ) : # filte the overlap mask
@staticmethod
def filter_masks ( annotations ) : # filter the overlap mask
annotations . sort ( key = lambda x : x [ ' area ' ] , reverse = True )
to_remove = set ( )
for i in range ( 0 , len ( annotations ) ) :
for i in range ( len ( annotations ) ) :
a = annotations [ i ]
for j in range ( i + 1 , len ( annotations ) ) :
b = annotations [ j ]
if i != j and j not in to_remove :
# check if
if b [ ' area ' ] < a [ ' area ' ] :
if ( a [ ' segmentation ' ] & b [ ' segmentation ' ] ) . sum ( ) / b [ ' segmentation ' ] . sum ( ) > 0.8 :
to_remove . add ( j )
if i != j and j not in to_remove and b [ ' area ' ] < a [ ' area ' ] and \
( a [ ' segmentation ' ] & b [ ' segmentation ' ] ) . sum ( ) / b [ ' segmentation ' ] . sum ( ) > 0.8 :
to_remove . add ( j )
return [ a for i , a in enumerate ( annotations ) if i not in to_remove ] , to_remove
def _get_bbox_from_mask ( self , mask ) :
@staticmethod
def _get_bbox_from_mask ( mask ) :
mask = mask . astype ( np . uint8 )
contours , hierarchy = cv2 . findContours ( mask , cv2 . RETR_EXTERNAL , cv2 . CHAIN_APPROX_SIMPLE )
x1 , y1 , w , h = cv2 . boundingRect ( contours [ 0 ] )
@ -105,7 +109,7 @@ class FastSAMPrompt:
image = cv2 . cvtColor ( image , cv2 . COLOR_BGR2RGB )
original_h = image . shape [ 0 ]
original_w = image . shape [ 1 ]
# for M acOS only
# for m acOS only
# plt.switch_backend('TkAgg')
plt . figure ( figsize = ( original_w / 100 , original_h / 100 ) )
# Add subplot with no margin.
@ -164,10 +168,9 @@ class FastSAMPrompt:
interpolation = cv2 . INTER_NEAREST ,
)
contours , hierarchy = cv2 . findContours ( annotation , cv2 . RETR_TREE , cv2 . CHAIN_APPROX_SIMPLE )
for contour in contours :
contour_all . append ( contour )
contour_all . extend ( iter ( contours ) )
cv2 . drawContours ( temp , contour_all , - 1 , ( 255 , 255 , 255 ) , 2 )
color = np . array ( [ 0 / 255 , 0 / 255 , 255 / 255 , 0.8 ] )
color = np . array ( [ 0 / 255 , 0 / 255 , 1.0 , 0.8 ] )
contour_mask = temp / 255 * color . reshape ( 1 , 1 , - 1 )
plt . imshow ( contour_mask )
@ -212,7 +215,7 @@ class FastSAMPrompt:
if random_color :
color = np . random . random ( ( msak_sum , 1 , 1 , 3 ) )
else :
color = np . ones ( ( msak_sum , 1 , 1 , 3 ) ) * np . array ( [ 30 / 255 , 144 / 255 , 255 / 255 ] )
color = np . ones ( ( msak_sum , 1 , 1 , 3 ) ) * np . array ( [ 30 / 255 , 144 / 255 , 1.0 ] )
transparency = np . ones ( ( msak_sum , 1 , 1 , 1 ) ) * 0.6
visual = np . concatenate ( [ color , transparency ] , axis = - 1 )
mask_image = np . expand_dims ( annotation , - 1 ) * visual
@ -267,8 +270,8 @@ class FastSAMPrompt:
if random_color :
color = torch . rand ( ( msak_sum , 1 , 1 , 3 ) ) . to ( annotation . device )
else :
color = torch . ones ( ( msak_sum , 1 , 1 , 3 ) ) . to ( annotation . device ) * torch . tensor ( [
30 / 255 , 144 / 255 , 255 / 255 ] ) . to ( annotation . device )
color = torch . ones ( ( msak_sum , 1 , 1 , 3 ) ) . to ( annotation . device ) * torch . tensor ( [ 30 / 255 , 144 / 255 , 1.0 ] ) . to (
annotation . device )
transparency = torch . ones ( ( msak_sum , 1 , 1 , 1 ) ) . to ( annotation . device ) * 0.6
visual = torch . cat ( [ color , transparency ] , dim = - 1 )
mask_image = torch . unsqueeze ( annotation , - 1 ) * visual
@ -304,7 +307,7 @@ class FastSAMPrompt:
@torch.no_grad ( )
def retrieve ( self , model , preprocess , elements , search_text : str , device ) - > int :
preprocessed_images = [ preprocess ( image ) . to ( device ) for image in elements ]
tokenized_text = clip . tokenize ( [ search_text ] ) . to ( device )
tokenized_text = self . clip . tokenize ( [ search_text ] ) . to ( device )
stacked_images = torch . stack ( preprocessed_images )
image_features = model . encode_image ( stacked_images )
text_features = model . encode_text ( tokenized_text )
@ -352,10 +355,10 @@ class FastSAMPrompt:
int ( bbox [ 1 ] * h / target_height ) ,
int ( bbox [ 2 ] * w / target_width ) ,
int ( bbox [ 3 ] * h / target_height ) , ]
bbox [ 0 ] = round( bbox [ 0 ] ) if round ( bbox [ 0 ] ) > 0 else 0
bbox [ 1 ] = round( bbox [ 1 ] ) if round ( bbox [ 1 ] ) > 0 else 0
bbox [ 2 ] = round( bbox [ 2 ] ) if round ( bbox [ 2 ] ) < w else w
bbox [ 3 ] = round( bbox [ 3 ] ) if round ( bbox [ 3 ] ) < h else h
bbox [ 0 ] = max( round ( bbox [ 0 ] ) , 0 )
bbox [ 1 ] = max( round ( bbox [ 1 ] ) , 0 )
bbox [ 2 ] = min( round ( bbox [ 2 ] ) , w )
bbox [ 3 ] = min( round ( bbox [ 3 ] ) , h )
# IoUs = torch.zeros(len(masks), dtype=torch.float32)
bbox_area = ( bbox [ 3 ] - bbox [ 1 ] ) * ( bbox [ 2 ] - bbox [ 0 ] )
@ -380,10 +383,7 @@ class FastSAMPrompt:
points = [ [ int ( point [ 0 ] * w / target_width ) , int ( point [ 1 ] * h / target_height ) ] for point in points ]
onemask = np . zeros ( ( h , w ) )
for i , annotation in enumerate ( masks ) :
if type ( annotation ) == dict :
mask = annotation [ ' segmentation ' ]
else :
mask = annotation
mask = annotation [ ' segmentation ' ] if type ( annotation ) == dict else annotation
for i , point in enumerate ( points ) :
if mask [ point [ 1 ] , point [ 0 ] ] == 1 and pointlabel [ i ] == 1 :
onemask + = mask
@ -395,7 +395,7 @@ class FastSAMPrompt:
def text_prompt ( self , text ) :
format_results = self . _format_results ( self . results [ 0 ] , 0 )
cropped_boxes , cropped_images , not_crop , filter_id , annotations = self . _crop_image ( format_results )
clip_model , preprocess = clip . load ( ' ViT-B/32 ' , device = self . device )
clip_model , preprocess = self . clip . load ( ' ViT-B/32 ' , device = self . device )
scores = self . retrieve ( clip_model , preprocess , cropped_boxes , text , device = self . device )
max_idx = scores . argsort ( )
max_idx = max_idx [ - 1 ]