ultralytics 8.0.128
FastSAM autodownload and super() init (#3552)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
@@ -21,6 +21,13 @@ from .predict import FastSAMPredictor
 class FastSAM(YOLO):
 
+    def __init__(self, model='FastSAM-x.pt'):
+        # Call the __init__ method of the parent class (YOLO) with the updated default model
+        if model == 'FastSAM.pt':
+            model = 'FastSAM-x.pt'
+        super().__init__(model=model)
+        # any additional initialization code for FastSAM
+
     @smart_inference_mode()
     def predict(self, source=None, stream=False, **kwargs):
         """
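For reference, a minimal usage sketch of the new constructor behaviour: the legacy 'FastSAM.pt' name is remapped to 'FastSAM-x.pt' before the parent YOLO initializer runs, and the weight name also matches an auto-downloadable release asset (see the downloads change at the end). The import path and image path below are assumptions, not copied from this release.

# Hypothetical usage sketch; import path and image path are placeholders.
from ultralytics.yolo.fastsam import FastSAM  # assumed module location for this release

model = FastSAM('FastSAM.pt')                 # remapped internally to 'FastSAM-x.pt'
results = model.predict('path/to/image.jpg', stream=False)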
@@ -1,3 +1,5 @@
 # Ultralytics YOLO 🚀, AGPL-3.0 license
 
+import torch
+
 from ultralytics.yolo.engine.results import Results
@@ -1,3 +1,5 @@
 # Ultralytics YOLO 🚀, AGPL-3.0 license
 
+import os
+
 import cv2
@@ -6,15 +8,6 @@ import numpy as np
 import torch
 from PIL import Image
 
-try:
-    import clip  # for linear_assignment
-
-except (ImportError, AssertionError, AttributeError):
-    from ultralytics.yolo.utils.checks import check_requirements
-
-    check_requirements('git+https://github.com/openai/CLIP.git')  # required before installing lap from source
-    import clip
-
 
 class FastSAMPrompt:
 
@@ -25,7 +18,17 @@ class FastSAMPrompt:
         self.img_path = img_path
         self.ori_img = cv2.imread(img_path)
 
-    def _segment_image(self, image, bbox):
+        # Import and assign clip
+        try:
+            import clip  # for linear_assignment
+        except ImportError:
+            from ultralytics.yolo.utils.checks import check_requirements
+            check_requirements('git+https://github.com/openai/CLIP.git')  # required before installing lap from source
+            import clip
+        self.clip = clip
+
+    @staticmethod
+    def _segment_image(image, bbox):
         image_array = np.array(image)
         segmented_image_array = np.zeros_like(image_array)
         x1, y1, x2, y2 = bbox
@@ -39,39 +42,40 @@ class FastSAMPrompt:
         black_image.paste(segmented_image, mask=transparency_mask_image)
         return black_image
 
-    def _format_results(self, result, filter=0):
+    @staticmethod
+    def _format_results(result, filter=0):
         annotations = []
         n = len(result.masks.data)
         for i in range(n):
-            annotation = {}
             mask = result.masks.data[i] == 1.0
 
             if torch.sum(mask) < filter:
                 continue
-            annotation['id'] = i
-            annotation['segmentation'] = mask.cpu().numpy()
-            annotation['bbox'] = result.boxes.data[i]
-            annotation['score'] = result.boxes.conf[i]
+            annotation = {
+                'id': i,
+                'segmentation': mask.cpu().numpy(),
+                'bbox': result.boxes.data[i],
+                'score': result.boxes.conf[i]}
             annotation['area'] = annotation['segmentation'].sum()
             annotations.append(annotation)
         return annotations
 
-    def filter_masks(annotations):  # filte the overlap mask
+    @staticmethod
+    def filter_masks(annotations):  # filter the overlap mask
         annotations.sort(key=lambda x: x['area'], reverse=True)
         to_remove = set()
-        for i in range(0, len(annotations)):
+        for i in range(len(annotations)):
             a = annotations[i]
             for j in range(i + 1, len(annotations)):
                 b = annotations[j]
-                if i != j and j not in to_remove:
-                    # check if
-                    if b['area'] < a['area']:
-                        if (a['segmentation'] & b['segmentation']).sum() / b['segmentation'].sum() > 0.8:
-                            to_remove.add(j)
+                if i != j and j not in to_remove and b['area'] < a['area'] and \
+                        (a['segmentation'] & b['segmentation']).sum() / b['segmentation'].sum() > 0.8:
+                    to_remove.add(j)
 
         return [a for i, a in enumerate(annotations) if i not in to_remove], to_remove
 
-    def _get_bbox_from_mask(self, mask):
+    @staticmethod
+    def _get_bbox_from_mask(mask):
         mask = mask.astype(np.uint8)
         contours, hierarchy = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
         x1, y1, w, h = cv2.boundingRect(contours[0])
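The rewritten filter_masks condition keeps the original rule: a lower-ranked (smaller) mask is dropped when more than 80% of its area is covered by a larger mask. A small standalone sketch of that criterion:

# Minimal sketch of the overlap rule used in filter_masks: a smaller mask is
# discarded when > 80% of its area lies inside a larger, higher-ranked mask.
import numpy as np

a = np.zeros((8, 8), dtype=bool)
a[0:6, 0:6] = True                      # larger mask, area 36
b = np.zeros((8, 8), dtype=bool)
b[1:4, 1:4] = True                      # smaller mask, area 9, fully inside a

overlap_ratio = (a & b).sum() / b.sum()
print(overlap_ratio)                    # 1.0 > 0.8, so b's index would go into to_remove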
@@ -105,7 +109,7 @@ class FastSAMPrompt:
         image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
         original_h = image.shape[0]
         original_w = image.shape[1]
-        # for MacOS only
+        # for macOS only
         # plt.switch_backend('TkAgg')
         plt.figure(figsize=(original_w / 100, original_h / 100))
         # Add subplot with no margin.
@@ -164,10 +168,9 @@ class FastSAMPrompt:
                     interpolation=cv2.INTER_NEAREST,
                 )
                 contours, hierarchy = cv2.findContours(annotation, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
-                for contour in contours:
-                    contour_all.append(contour)
+                contour_all.extend(iter(contours))
             cv2.drawContours(temp, contour_all, -1, (255, 255, 255), 2)
-            color = np.array([0 / 255, 0 / 255, 255 / 255, 0.8])
+            color = np.array([0 / 255, 0 / 255, 1.0, 0.8])
             contour_mask = temp / 255 * color.reshape(1, 1, -1)
             plt.imshow(contour_mask)
 
@@ -212,7 +215,7 @@ class FastSAMPrompt:
         if random_color:
             color = np.random.random((msak_sum, 1, 1, 3))
         else:
-            color = np.ones((msak_sum, 1, 1, 3)) * np.array([30 / 255, 144 / 255, 255 / 255])
+            color = np.ones((msak_sum, 1, 1, 3)) * np.array([30 / 255, 144 / 255, 1.0])
         transparency = np.ones((msak_sum, 1, 1, 1)) * 0.6
         visual = np.concatenate([color, transparency], axis=-1)
         mask_image = np.expand_dims(annotation, -1) * visual
@@ -267,8 +270,8 @@ class FastSAMPrompt:
         if random_color:
             color = torch.rand((msak_sum, 1, 1, 3)).to(annotation.device)
         else:
-            color = torch.ones((msak_sum, 1, 1, 3)).to(annotation.device) * torch.tensor([
-                30 / 255, 144 / 255, 255 / 255]).to(annotation.device)
+            color = torch.ones((msak_sum, 1, 1, 3)).to(annotation.device) * torch.tensor([30 / 255, 144 / 255, 1.0]).to(
+                annotation.device)
         transparency = torch.ones((msak_sum, 1, 1, 1)).to(annotation.device) * 0.6
         visual = torch.cat([color, transparency], dim=-1)
         mask_image = torch.unsqueeze(annotation, -1) * visual
@@ -304,7 +307,7 @@ class FastSAMPrompt:
     @torch.no_grad()
     def retrieve(self, model, preprocess, elements, search_text: str, device) -> int:
         preprocessed_images = [preprocess(image).to(device) for image in elements]
-        tokenized_text = clip.tokenize([search_text]).to(device)
+        tokenized_text = self.clip.tokenize([search_text]).to(device)
         stacked_images = torch.stack(preprocessed_images)
         image_features = model.encode_image(stacked_images)
         text_features = model.encode_text(tokenized_text)
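For readers unfamiliar with the CLIP call pattern used by retrieve, here is a simplified standalone sketch of image-text scoring with the openai/CLIP package; the normalization and softmax at the end follow the usual CLIP recipe and are not copied from the method body.

# Simplified CLIP image-text scoring sketch (requires the openai/CLIP package).
import clip
import torch
from PIL import Image

device = 'cpu'
model, preprocess = clip.load('ViT-B/32', device=device)

images = [Image.new('RGB', (224, 224))]                        # stand-in for cropped regions
image_batch = torch.stack([preprocess(im) for im in images]).to(device)
text = clip.tokenize(['a photo of a dog']).to(device)

with torch.no_grad():
    image_features = model.encode_image(image_batch)
    text_features = model.encode_text(text)
    image_features /= image_features.norm(dim=-1, keepdim=True)
    text_features /= text_features.norm(dim=-1, keepdim=True)
    scores = (100.0 * image_features @ text_features.T).softmax(dim=0)  # one score per image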
@@ -352,10 +355,10 @@ class FastSAMPrompt:
                 int(bbox[1] * h / target_height),
                 int(bbox[2] * w / target_width),
                 int(bbox[3] * h / target_height), ]
-        bbox[0] = round(bbox[0]) if round(bbox[0]) > 0 else 0
-        bbox[1] = round(bbox[1]) if round(bbox[1]) > 0 else 0
-        bbox[2] = round(bbox[2]) if round(bbox[2]) < w else w
-        bbox[3] = round(bbox[3]) if round(bbox[3]) < h else h
+        bbox[0] = max(round(bbox[0]), 0)
+        bbox[1] = max(round(bbox[1]), 0)
+        bbox[2] = min(round(bbox[2]), w)
+        bbox[3] = min(round(bbox[3]), h)
 
         # IoUs = torch.zeros(len(masks), dtype=torch.float32)
         bbox_area = (bbox[3] - bbox[1]) * (bbox[2] - bbox[0])
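The conditional rounding above is replaced by an equivalent max/min clamp; a quick check with illustrative values:

# Equivalence check for the clamp refactor (illustrative values only).
x, w = -3.4, 640
assert max(round(x), 0) == (round(x) if round(x) > 0 else 0)                  # lower bound
assert min(round(640.7), w) == (round(640.7) if round(640.7) < w else w)      # upper bound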
@@ -380,10 +383,7 @@ class FastSAMPrompt:
         points = [[int(point[0] * w / target_width), int(point[1] * h / target_height)] for point in points]
         onemask = np.zeros((h, w))
         for i, annotation in enumerate(masks):
-            if type(annotation) == dict:
-                mask = annotation['segmentation']
-            else:
-                mask = annotation
+            mask = annotation['segmentation'] if type(annotation) == dict else annotation
             for i, point in enumerate(points):
                 if mask[point[1], point[0]] == 1 and pointlabel[i] == 1:
                     onemask += mask
@@ -395,7 +395,7 @@ class FastSAMPrompt:
     def text_prompt(self, text):
         format_results = self._format_results(self.results[0], 0)
         cropped_boxes, cropped_images, not_crop, filter_id, annotations = self._crop_image(format_results)
-        clip_model, preprocess = clip.load('ViT-B/32', device=self.device)
+        clip_model, preprocess = self.clip.load('ViT-B/32', device=self.device)
         scores = self.retrieve(clip_model, preprocess, cropped_boxes, text, device=self.device)
         max_idx = scores.argsort()
         max_idx = max_idx[-1]
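A hypothetical end-to-end sketch of the prompt workflow that exercises the deferred CLIP import; the FastSAMPrompt constructor arguments (img_path, results, device) and the import path are assumptions inferred from the attributes used in this file, not confirmed signatures.

# Hypothetical usage sketch; constructor signature and import path are assumed.
from ultralytics.yolo.fastsam import FastSAM, FastSAMPrompt

model = FastSAM('FastSAM-x.pt')
results = model.predict('path/to/image.jpg')
prompt = FastSAMPrompt('path/to/image.jpg', results, device='cpu')
annotations = prompt.text_prompt('a dog')   # clip is imported lazily in __init__ and reused via self.clip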
@@ -1,16 +1,20 @@
 # Ultralytics YOLO 🚀, AGPL-3.0 license
 
 import torch
 
 
 def adjust_bboxes_to_image_border(boxes, image_shape, threshold=20):
-    '''Adjust bounding boxes to stick to image border if they are within a certain threshold.
+    """
+    Adjust bounding boxes to stick to image border if they are within a certain threshold.
 
     Args:
-    boxes: (n, 4)
-    image_shape: (height, width)
-    threshold: pixel threshold
+        boxes: (n, 4)
+        image_shape: (height, width)
+        threshold: pixel threshold
 
     Returns:
-    adjusted_boxes: adjusted bounding boxes
-    '''
+        adjusted_boxes: adjusted bounding boxes
+    """
 
     # Image dimensions
     h, w = image_shape
@@ -25,14 +29,16 @@ def adjust_bboxes_to_image_border(boxes, image_shape, threshold=20):
 
 
 def bbox_iou(box1, boxes, iou_thres=0.9, image_shape=(640, 640), raw_output=False):
-    '''Compute the Intersection-Over-Union of a bounding box with respect to an array of other bounding boxes.
+    """
+    Compute the Intersection-Over-Union of a bounding box with respect to an array of other bounding boxes.
 
     Args:
-    box1: (4, )
-    boxes: (n, 4)
+        box1: (4, )
+        boxes: (n, 4)
 
     Returns:
-    high_iou_indices: Indices of boxes with IoU > thres
-    '''
+        high_iou_indices: Indices of boxes with IoU > thres
+    """
     boxes = adjust_bboxes_to_image_border(boxes, image_shape)
     # obtain coordinates for intersections
     x1 = torch.max(box1[0], boxes[:, 0])
@@ -53,11 +59,7 @@ def bbox_iou(box1, boxes, iou_thres=0.9, image_shape=(640, 640), raw_output=False):
     # compute the IoU
     iou = intersection / union  # Should be shape (n, )
     if raw_output:
-        if iou.numel() == 0:
-            return 0
-        return iou
+        return 0 if iou.numel() == 0 else iou
 
-    # get indices of boxes with IoU > thres
-    high_iou_indices = torch.nonzero(iou > iou_thres).flatten()
-
-    return high_iou_indices
+    # return indices of boxes with IoU > thres
+    return torch.nonzero(iou > iou_thres).flatten()
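A small usage sketch of the refactored IoU helper; the module path is assumed from this release's layout, and the boxes are placed away from the image border so the 20-pixel snapping in adjust_bboxes_to_image_border does not change them.

# Usage sketch for bbox_iou (import path assumed).
import torch
from ultralytics.yolo.fastsam.utils import bbox_iou

box1 = torch.tensor([100.0, 100.0, 200.0, 200.0])           # x1, y1, x2, y2
boxes = torch.tensor([[102.0, 101.0, 198.0, 202.0],         # ~0.93 IoU with box1
                      [300.0, 300.0, 400.0, 400.0]])        # no overlap

idx = bbox_iou(box1, boxes, iou_thres=0.9, image_shape=(640, 640))   # -> tensor([0])
raw = bbox_iou(box1, boxes, raw_output=True)                         # raw IoU values, or 0 if empty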
@@ -20,6 +20,7 @@ GITHUB_ASSET_NAMES = [f'yolov8{k}{suffix}.pt' for k in 'nsmlx' for suffix in (''
                       [f'yolov3{k}u.pt' for k in ('', '-spp', '-tiny')] + \
                       [f'yolo_nas_{k}.pt' for k in 'sml'] + \
                       [f'sam_{k}.pt' for k in 'bl'] + \
+                      [f'FastSAM-{k}.pt' for k in 'sx'] + \
                       [f'rtdetr-{k}.pt' for k in 'lx']
 GITHUB_ASSET_STEMS = [Path(k).stem for k in GITHUB_ASSET_NAMES]
 
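With the FastSAM checkpoints added to the asset list, the standard checkpoint auto-download by name can now resolve them. A quick check against the constants touched here (module path as in this release's ultralytics.yolo.utils package):

# Confirm the FastSAM weights are now recognized GitHub release assets.
from ultralytics.yolo.utils.downloads import GITHUB_ASSET_NAMES, GITHUB_ASSET_STEMS

print('FastSAM-x.pt' in GITHUB_ASSET_NAMES)   # True
print('FastSAM-s' in GITHUB_ASSET_STEMS)      # True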