ultralytics 8.0.128 FastSAM autodownload and super() init (#3552)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Author: Glenn Jocher
Date: 2023-07-06 01:11:37 +02:00
Committed by: GitHub
Parent: 400f3f72a1
Commit: ad99246ff1
7 changed files with 90 additions and 72 deletions


@@ -1,6 +1,6 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
__version__ = '8.0.127'
__version__ = '8.0.128'
from ultralytics.hub import start
from ultralytics.vit.rtdetr import RTDETR


@@ -21,6 +21,13 @@ from .predict import FastSAMPredictor
class FastSAM(YOLO):
def __init__(self, model='FastSAM-x.pt'):
# Call the __init__ method of the parent class (YOLO) with the updated default model
if model == 'FastSAM.pt':
model = 'FastSAM-x.pt'
super().__init__(model=model)
# any additional initialization code for FastSAM
@smart_inference_mode()
def predict(self, source=None, stream=False, **kwargs):
"""


@@ -1,3 +1,5 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
import torch
from ultralytics.yolo.engine.results import Results


@@ -1,3 +1,5 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
import os
import cv2
@@ -6,15 +8,6 @@ import numpy as np
import torch
from PIL import Image
try:
import clip # for linear_assignment
except (ImportError, AssertionError, AttributeError):
from ultralytics.yolo.utils.checks import check_requirements
check_requirements('git+https://github.com/openai/CLIP.git') # required before installing lap from source
import clip
class FastSAMPrompt:
@@ -25,7 +18,17 @@ class FastSAMPrompt:
self.img_path = img_path
self.ori_img = cv2.imread(img_path)
def _segment_image(self, image, bbox):
# Import and assign clip
try:
import clip  # OpenAI CLIP model, used for text prompts
except ImportError:
from ultralytics.yolo.utils.checks import check_requirements
check_requirements('git+https://github.com/openai/CLIP.git')  # install CLIP from source if missing
import clip
self.clip = clip
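Deferring the CLIP import from module scope into __init__ means the dependency is only fetched when a FastSAMPrompt is actually constructed, and pinning the module on the instance makes the text-prompt path easy to stub out. A hedged test sketch, where FakeClip is a hypothetical stand-in:

class FakeClip:  # hypothetical stub for the real CLIP package
    @staticmethod
    def tokenize(texts):
        raise NotImplementedError  # return a token tensor stub in a real test

prompt = FastSAMPrompt.__new__(FastSAMPrompt)  # bypass __init__ to avoid the import
prompt.clip = FakeClip()  # retrieve() and text_prompt() now use the stub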
@staticmethod
def _segment_image(image, bbox):
image_array = np.array(image)
segmented_image_array = np.zeros_like(image_array)
x1, y1, x2, y2 = bbox
@@ -39,39 +42,40 @@ class FastSAMPrompt:
black_image.paste(segmented_image, mask=transparency_mask_image)
return black_image
def _format_results(self, result, filter=0):
@staticmethod
def _format_results(result, filter=0):
annotations = []
n = len(result.masks.data)
for i in range(n):
annotation = {}
mask = result.masks.data[i] == 1.0
if torch.sum(mask) < filter:
continue
annotation['id'] = i
annotation['segmentation'] = mask.cpu().numpy()
annotation['bbox'] = result.boxes.data[i]
annotation['score'] = result.boxes.conf[i]
annotation = {
'id': i,
'segmentation': mask.cpu().numpy(),
'bbox': result.boxes.data[i],
'score': result.boxes.conf[i]}
annotation['area'] = annotation['segmentation'].sum()
annotations.append(annotation)
return annotations
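For reference, each entry returned above is a plain dict; an illustrative entry with dummy values (the real code stores tensors from result.masks and result.boxes):

import numpy as np

mask = np.zeros((480, 640), dtype=bool)  # stand-in for result.masks.data[i] == 1.0
entry = {
    'id': 0,
    'segmentation': mask,  # HxW boolean mask
    'bbox': [50, 40, 200, 180],  # a row of result.boxes.data in the real code
    'score': 0.87,  # result.boxes.conf[i] in the real code
    'area': mask.sum()}  # pixel count, consumed by filter_masks below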
def filter_masks(annotations): # filte the overlap mask
@staticmethod
def filter_masks(annotations): # filter out overlapping masks
annotations.sort(key=lambda x: x['area'], reverse=True)
to_remove = set()
for i in range(0, len(annotations)):
for i in range(len(annotations)):
a = annotations[i]
for j in range(i + 1, len(annotations)):
b = annotations[j]
if i != j and j not in to_remove:
# check if
if b['area'] < a['area']:
if (a['segmentation'] & b['segmentation']).sum() / b['segmentation'].sum() > 0.8:
to_remove.add(j)
if i != j and j not in to_remove and b['area'] < a['area'] and \
(a['segmentation'] & b['segmentation']).sum() / b['segmentation'].sum() > 0.8:
to_remove.add(j)
return [a for i, a in enumerate(annotations) if i not in to_remove], to_remove
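A minimal numeric check of the removal rule above: a smaller mask is dropped once more than 80% of its pixels are covered by a larger one.

import numpy as np

a = np.zeros((4, 4), dtype=bool); a[:3, :3] = True  # larger mask, area 9
b = np.zeros((4, 4), dtype=bool); b[:2, :2] = True  # smaller mask, area 4, inside a
print((a & b).sum() / b.sum())  # 1.0 > 0.8, so b's index would land in to_remove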
def _get_bbox_from_mask(self, mask):
@staticmethod
def _get_bbox_from_mask(mask):
mask = mask.astype(np.uint8)
contours, hierarchy = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
x1, y1, w, h = cv2.boundingRect(contours[0])
@@ -105,7 +109,7 @@ class FastSAMPrompt:
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
original_h = image.shape[0]
original_w = image.shape[1]
# for MacOS only
# for macOS only
# plt.switch_backend('TkAgg')
plt.figure(figsize=(original_w / 100, original_h / 100))
# Add subplot with no margin.
@@ -164,10 +168,9 @@ class FastSAMPrompt:
interpolation=cv2.INTER_NEAREST,
)
contours, hierarchy = cv2.findContours(annotation, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
for contour in contours:
contour_all.append(contour)
contour_all.extend(iter(contours))
cv2.drawContours(temp, contour_all, -1, (255, 255, 255), 2)
color = np.array([0 / 255, 0 / 255, 255 / 255, 0.8])
color = np.array([0 / 255, 0 / 255, 1.0, 0.8])
contour_mask = temp / 255 * color.reshape(1, 1, -1)
plt.imshow(contour_mask)
@@ -212,7 +215,7 @@ class FastSAMPrompt:
if random_color:
color = np.random.random((msak_sum, 1, 1, 3))
else:
color = np.ones((msak_sum, 1, 1, 3)) * np.array([30 / 255, 144 / 255, 255 / 255])
color = np.ones((msak_sum, 1, 1, 3)) * np.array([30 / 255, 144 / 255, 1.0])
transparency = np.ones((msak_sum, 1, 1, 1)) * 0.6
visual = np.concatenate([color, transparency], axis=-1)
mask_image = np.expand_dims(annotation, -1) * visual
@@ -267,8 +270,8 @@ class FastSAMPrompt:
if random_color:
color = torch.rand((msak_sum, 1, 1, 3)).to(annotation.device)
else:
color = torch.ones((msak_sum, 1, 1, 3)).to(annotation.device) * torch.tensor([
30 / 255, 144 / 255, 255 / 255]).to(annotation.device)
color = torch.ones((msak_sum, 1, 1, 3)).to(annotation.device) * torch.tensor([30 / 255, 144 / 255, 1.0]).to(
annotation.device)
transparency = torch.ones((msak_sum, 1, 1, 1)).to(annotation.device) * 0.6
visual = torch.cat([color, transparency], dim=-1)
mask_image = torch.unsqueeze(annotation, -1) * visual
@@ -304,7 +307,7 @@ class FastSAMPrompt:
@torch.no_grad()
def retrieve(self, model, preprocess, elements, search_text: str, device) -> int:
preprocessed_images = [preprocess(image).to(device) for image in elements]
tokenized_text = clip.tokenize([search_text]).to(device)
tokenized_text = self.clip.tokenize([search_text]).to(device)
stacked_images = torch.stack(preprocessed_images)
image_features = model.encode_image(stacked_images)
text_features = model.encode_text(tokenized_text)
@@ -352,10 +355,10 @@ class FastSAMPrompt:
int(bbox[1] * h / target_height),
int(bbox[2] * w / target_width),
int(bbox[3] * h / target_height), ]
bbox[0] = round(bbox[0]) if round(bbox[0]) > 0 else 0
bbox[1] = round(bbox[1]) if round(bbox[1]) > 0 else 0
bbox[2] = round(bbox[2]) if round(bbox[2]) < w else w
bbox[3] = round(bbox[3]) if round(bbox[3]) < h else h
bbox[0] = max(round(bbox[0]), 0)
bbox[1] = max(round(bbox[1]), 0)
bbox[2] = min(round(bbox[2]), w)
bbox[3] = min(round(bbox[3]), h)
# IoUs = torch.zeros(len(masks), dtype=torch.float32)
bbox_area = (bbox[3] - bbox[1]) * (bbox[2] - bbox[0])
@@ -380,10 +383,7 @@ class FastSAMPrompt:
points = [[int(point[0] * w / target_width), int(point[1] * h / target_height)] for point in points]
onemask = np.zeros((h, w))
for i, annotation in enumerate(masks):
if type(annotation) == dict:
mask = annotation['segmentation']
else:
mask = annotation
mask = annotation['segmentation'] if type(annotation) == dict else annotation
for i, point in enumerate(points):
if mask[point[1], point[0]] == 1 and pointlabel[i] == 1:
onemask += mask
@@ -395,7 +395,7 @@ class FastSAMPrompt:
def text_prompt(self, text):
format_results = self._format_results(self.results[0], 0)
cropped_boxes, cropped_images, not_crop, filter_id, annotations = self._crop_image(format_results)
clip_model, preprocess = clip.load('ViT-B/32', device=self.device)
clip_model, preprocess = self.clip.load('ViT-B/32', device=self.device)
scores = self.retrieve(clip_model, preprocess, cropped_boxes, text, device=self.device)
max_idx = scores.argsort()
max_idx = max_idx[-1]
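Continuing the earlier FastSAM sketch, the text-prompt path ties these pieces together; the FastSAMPrompt constructor arguments are assumed from this file:

prompt = FastSAMPrompt('bus.jpg', everything_results, device='cpu')  # signature assumed
ann = prompt.text_prompt('a photo of a bus')  # CLIP scores cropped masks against the text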


@@ -1,16 +1,20 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
import torch
def adjust_bboxes_to_image_border(boxes, image_shape, threshold=20):
'''Adjust bounding boxes to stick to image border if they are within a certain threshold.
"""
Adjust bounding boxes to stick to image border if they are within a certain threshold.
Args:
boxes: (n, 4)
image_shape: (height, width)
threshold: pixel threshold
boxes: (n, 4)
image_shape: (height, width)
threshold: pixel threshold
Returns:
adjusted_boxes: adjusted bounding boxes
'''
adjusted_boxes: adjusted bounding boxes
"""
# Image dimensions
h, w = image_shape
@@ -25,14 +29,16 @@ def adjust_bboxes_to_image_border(boxes, image_shape, threshold=20):
def bbox_iou(box1, boxes, iou_thres=0.9, image_shape=(640, 640), raw_output=False):
'''Compute the Intersection-Over-Union of a bounding box with respect to an array of other bounding boxes.
"""
Compute the Intersection-Over-Union of a bounding box with respect to an array of other bounding boxes.
Args:
box1: (4, )
boxes: (n, 4)
box1: (4, )
boxes: (n, 4)
Returns:
high_iou_indices: Indices of boxes with IoU > thres
'''
high_iou_indices: Indices of boxes with IoU > thres
"""
boxes = adjust_bboxes_to_image_border(boxes, image_shape)
# obtain coordinates for intersections
x1 = torch.max(box1[0], boxes[:, 0])
@@ -53,11 +59,7 @@ def bbox_iou(box1, boxes, iou_thres=0.9, image_shape=(640, 640), raw_output=False):
# compute the IoU
iou = intersection / union # Should be shape (n, )
if raw_output:
if iou.numel() == 0:
return 0
return iou
return 0 if iou.numel() == 0 else iou
# get indices of boxes with IoU > thres
high_iou_indices = torch.nonzero(iou > iou_thres).flatten()
return high_iou_indices
# return indices of boxes with IoU > thres
return torch.nonzero(iou > iou_thres).flatten()
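A hedged usage sketch for bbox_iou: with raw_output left False it returns the indices of boxes clearing the threshold, and note that boxes (but not box1) are first snapped to the image border:

import torch

box1 = torch.tensor([10., 10., 50., 50.])
boxes = torch.tensor([[12., 11., 49., 52.],  # snapped toward the border, IoU ~0.6 with box1
                      [100., 100., 140., 140.]])  # disjoint from box1
print(bbox_iou(box1, boxes, iou_thres=0.5, image_shape=(640, 640)))  # expected: tensor([0])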


@@ -20,6 +20,7 @@ GITHUB_ASSET_NAMES = [f'yolov8{k}{suffix}.pt' for k in 'nsmlx' for suffix in (''
[f'yolov3{k}u.pt' for k in ('', '-spp', '-tiny')] + \
[f'yolo_nas_{k}.pt' for k in 'sml'] + \
[f'sam_{k}.pt' for k in 'bl'] + \
[f'FastSAM-{k}.pt' for k in 'sx'] + \
[f'rtdetr-{k}.pt' for k in 'lx']
GITHUB_ASSET_STEMS = [Path(k).stem for k in GITHUB_ASSET_NAMES]
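With the FastSAM-s/x names registered, the bare filenames can now be resolved by the standard release-asset autodownload; a hedged sketch (the function name is assumed for this release):

from ultralytics.yolo.utils.downloads import attempt_download_asset  # name assumed

weights = attempt_download_asset('FastSAM-x.pt')  # fetched from GitHub release assets if missing locally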