`ultralytics 8.0.128` FastSAM autodownload and super() init (#3552)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
single_channel
Glenn Jocher 1 year ago committed by GitHub
parent 400f3f72a1
commit ad99246ff1
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -44,17 +44,20 @@ To perform object detection on an image, use the `predict` method as shown below
from ultralytics import FastSAM
from ultralytics.yolo.fastsam import FastSAMPrompt
IMAGE_PATH = 'images/dog.jpg'
# Define image path and inference device
IMAGE_PATH = 'ultralytics/assets/bus.jpg'
DEVICE = 'cpu'
model = FastSAM('FastSAM.pt')
results = model(
IMAGE_PATH,
# Create a FastSAM model
model = FastSAM('FastSAM-s.pt') # or FastSAM-x.pt
# Run inference on an image
everything_results = model(IMAGE_PATH,
device=DEVICE,
retina_masks=True,
imgsz=1024,
conf=0.4,
iou=0.9,
)
iou=0.9)
prompt_process = FastSAMPrompt(IMAGE_PATH, everything_results, device=DEVICE)
@ -83,8 +86,11 @@ Validation of the model on a dataset can be done as follows:
```python
from ultralytics import FastSAM
model = FastSAM('FastSAM.pt')
results = model.val(data='coco8-seg.yaml)
# Create a FastSAM model
model = FastSAM('FastSAM-s.pt') # or FastSAM-x.pt
# Validate the model
results = model.val(data='coco8-seg.yaml')
```
Please note that FastSAM only supports detection and segmentation of a single class of object. This means it will recognize and segment all objects as the same class. Therefore, when preparing the dataset, you need to convert all object category IDs to 0.

@ -1,6 +1,6 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
__version__ = '8.0.127'
__version__ = '8.0.128'
from ultralytics.hub import start
from ultralytics.vit.rtdetr import RTDETR

@ -21,6 +21,13 @@ from .predict import FastSAMPredictor
class FastSAM(YOLO):
def __init__(self, model='FastSAM-x.pt'):
# Call the __init__ method of the parent class (YOLO) with the updated default model
if model == 'FastSAM.pt':
model = 'FastSAM-x.pt'
super().__init__(model=model)
# any additional initialization code for FastSAM
@smart_inference_mode()
def predict(self, source=None, stream=False, **kwargs):
"""

@ -1,3 +1,5 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
import torch
from ultralytics.yolo.engine.results import Results

@ -1,3 +1,5 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
import os
import cv2
@ -6,15 +8,6 @@ import numpy as np
import torch
from PIL import Image
try:
import clip # for linear_assignment
except (ImportError, AssertionError, AttributeError):
from ultralytics.yolo.utils.checks import check_requirements
check_requirements('git+https://github.com/openai/CLIP.git') # required before installing lap from source
import clip
class FastSAMPrompt:
@ -25,7 +18,17 @@ class FastSAMPrompt:
self.img_path = img_path
self.ori_img = cv2.imread(img_path)
def _segment_image(self, image, bbox):
# Import and assign clip
try:
import clip # for linear_assignment
except ImportError:
from ultralytics.yolo.utils.checks import check_requirements
check_requirements('git+https://github.com/openai/CLIP.git') # required before installing lap from source
import clip
self.clip = clip
@staticmethod
def _segment_image(image, bbox):
image_array = np.array(image)
segmented_image_array = np.zeros_like(image_array)
x1, y1, x2, y2 = bbox
@ -39,39 +42,40 @@ class FastSAMPrompt:
black_image.paste(segmented_image, mask=transparency_mask_image)
return black_image
def _format_results(self, result, filter=0):
@staticmethod
def _format_results(result, filter=0):
annotations = []
n = len(result.masks.data)
for i in range(n):
annotation = {}
mask = result.masks.data[i] == 1.0
if torch.sum(mask) < filter:
continue
annotation['id'] = i
annotation['segmentation'] = mask.cpu().numpy()
annotation['bbox'] = result.boxes.data[i]
annotation['score'] = result.boxes.conf[i]
annotation = {
'id': i,
'segmentation': mask.cpu().numpy(),
'bbox': result.boxes.data[i],
'score': result.boxes.conf[i]}
annotation['area'] = annotation['segmentation'].sum()
annotations.append(annotation)
return annotations
def filter_masks(annotations): # filte the overlap mask
@staticmethod
def filter_masks(annotations): # filter the overlap mask
annotations.sort(key=lambda x: x['area'], reverse=True)
to_remove = set()
for i in range(0, len(annotations)):
for i in range(len(annotations)):
a = annotations[i]
for j in range(i + 1, len(annotations)):
b = annotations[j]
if i != j and j not in to_remove:
# check if
if b['area'] < a['area']:
if (a['segmentation'] & b['segmentation']).sum() / b['segmentation'].sum() > 0.8:
if i != j and j not in to_remove and b['area'] < a['area'] and \
(a['segmentation'] & b['segmentation']).sum() / b['segmentation'].sum() > 0.8:
to_remove.add(j)
return [a for i, a in enumerate(annotations) if i not in to_remove], to_remove
def _get_bbox_from_mask(self, mask):
@staticmethod
def _get_bbox_from_mask(mask):
mask = mask.astype(np.uint8)
contours, hierarchy = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
x1, y1, w, h = cv2.boundingRect(contours[0])
@ -105,7 +109,7 @@ class FastSAMPrompt:
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
original_h = image.shape[0]
original_w = image.shape[1]
# for MacOS only
# for macOS only
# plt.switch_backend('TkAgg')
plt.figure(figsize=(original_w / 100, original_h / 100))
# Add subplot with no margin.
@ -164,10 +168,9 @@ class FastSAMPrompt:
interpolation=cv2.INTER_NEAREST,
)
contours, hierarchy = cv2.findContours(annotation, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
for contour in contours:
contour_all.append(contour)
contour_all.extend(iter(contours))
cv2.drawContours(temp, contour_all, -1, (255, 255, 255), 2)
color = np.array([0 / 255, 0 / 255, 255 / 255, 0.8])
color = np.array([0 / 255, 0 / 255, 1.0, 0.8])
contour_mask = temp / 255 * color.reshape(1, 1, -1)
plt.imshow(contour_mask)
@ -212,7 +215,7 @@ class FastSAMPrompt:
if random_color:
color = np.random.random((msak_sum, 1, 1, 3))
else:
color = np.ones((msak_sum, 1, 1, 3)) * np.array([30 / 255, 144 / 255, 255 / 255])
color = np.ones((msak_sum, 1, 1, 3)) * np.array([30 / 255, 144 / 255, 1.0])
transparency = np.ones((msak_sum, 1, 1, 1)) * 0.6
visual = np.concatenate([color, transparency], axis=-1)
mask_image = np.expand_dims(annotation, -1) * visual
@ -267,8 +270,8 @@ class FastSAMPrompt:
if random_color:
color = torch.rand((msak_sum, 1, 1, 3)).to(annotation.device)
else:
color = torch.ones((msak_sum, 1, 1, 3)).to(annotation.device) * torch.tensor([
30 / 255, 144 / 255, 255 / 255]).to(annotation.device)
color = torch.ones((msak_sum, 1, 1, 3)).to(annotation.device) * torch.tensor([30 / 255, 144 / 255, 1.0]).to(
annotation.device)
transparency = torch.ones((msak_sum, 1, 1, 1)).to(annotation.device) * 0.6
visual = torch.cat([color, transparency], dim=-1)
mask_image = torch.unsqueeze(annotation, -1) * visual
@ -304,7 +307,7 @@ class FastSAMPrompt:
@torch.no_grad()
def retrieve(self, model, preprocess, elements, search_text: str, device) -> int:
preprocessed_images = [preprocess(image).to(device) for image in elements]
tokenized_text = clip.tokenize([search_text]).to(device)
tokenized_text = self.clip.tokenize([search_text]).to(device)
stacked_images = torch.stack(preprocessed_images)
image_features = model.encode_image(stacked_images)
text_features = model.encode_text(tokenized_text)
@ -352,10 +355,10 @@ class FastSAMPrompt:
int(bbox[1] * h / target_height),
int(bbox[2] * w / target_width),
int(bbox[3] * h / target_height), ]
bbox[0] = round(bbox[0]) if round(bbox[0]) > 0 else 0
bbox[1] = round(bbox[1]) if round(bbox[1]) > 0 else 0
bbox[2] = round(bbox[2]) if round(bbox[2]) < w else w
bbox[3] = round(bbox[3]) if round(bbox[3]) < h else h
bbox[0] = max(round(bbox[0]), 0)
bbox[1] = max(round(bbox[1]), 0)
bbox[2] = min(round(bbox[2]), w)
bbox[3] = min(round(bbox[3]), h)
# IoUs = torch.zeros(len(masks), dtype=torch.float32)
bbox_area = (bbox[3] - bbox[1]) * (bbox[2] - bbox[0])
@ -380,10 +383,7 @@ class FastSAMPrompt:
points = [[int(point[0] * w / target_width), int(point[1] * h / target_height)] for point in points]
onemask = np.zeros((h, w))
for i, annotation in enumerate(masks):
if type(annotation) == dict:
mask = annotation['segmentation']
else:
mask = annotation
mask = annotation['segmentation'] if type(annotation) == dict else annotation
for i, point in enumerate(points):
if mask[point[1], point[0]] == 1 and pointlabel[i] == 1:
onemask += mask
@ -395,7 +395,7 @@ class FastSAMPrompt:
def text_prompt(self, text):
format_results = self._format_results(self.results[0], 0)
cropped_boxes, cropped_images, not_crop, filter_id, annotations = self._crop_image(format_results)
clip_model, preprocess = clip.load('ViT-B/32', device=self.device)
clip_model, preprocess = self.clip.load('ViT-B/32', device=self.device)
scores = self.retrieve(clip_model, preprocess, cropped_boxes, text, device=self.device)
max_idx = scores.argsort()
max_idx = max_idx[-1]

@ -1,8 +1,12 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
import torch
def adjust_bboxes_to_image_border(boxes, image_shape, threshold=20):
'''Adjust bounding boxes to stick to image border if they are within a certain threshold.
"""
Adjust bounding boxes to stick to image border if they are within a certain threshold.
Args:
boxes: (n, 4)
image_shape: (height, width)
@ -10,7 +14,7 @@ def adjust_bboxes_to_image_border(boxes, image_shape, threshold=20):
Returns:
adjusted_boxes: adjusted bounding boxes
'''
"""
# Image dimensions
h, w = image_shape
@ -25,14 +29,16 @@ def adjust_bboxes_to_image_border(boxes, image_shape, threshold=20):
def bbox_iou(box1, boxes, iou_thres=0.9, image_shape=(640, 640), raw_output=False):
'''Compute the Intersection-Over-Union of a bounding box with respect to an array of other bounding boxes.
"""
Compute the Intersection-Over-Union of a bounding box with respect to an array of other bounding boxes.
Args:
box1: (4, )
boxes: (n, 4)
Returns:
high_iou_indices: Indices of boxes with IoU > thres
'''
"""
boxes = adjust_bboxes_to_image_border(boxes, image_shape)
# obtain coordinates for intersections
x1 = torch.max(box1[0], boxes[:, 0])
@ -53,11 +59,7 @@ def bbox_iou(box1, boxes, iou_thres=0.9, image_shape=(640, 640), raw_output=Fals
# compute the IoU
iou = intersection / union # Should be shape (n, )
if raw_output:
if iou.numel() == 0:
return 0
return iou
# get indices of boxes with IoU > thres
high_iou_indices = torch.nonzero(iou > iou_thres).flatten()
return 0 if iou.numel() == 0 else iou
return high_iou_indices
# return indices of boxes with IoU > thres
return torch.nonzero(iou > iou_thres).flatten()

@ -20,6 +20,7 @@ GITHUB_ASSET_NAMES = [f'yolov8{k}{suffix}.pt' for k in 'nsmlx' for suffix in (''
[f'yolov3{k}u.pt' for k in ('', '-spp', '-tiny')] + \
[f'yolo_nas_{k}.pt' for k in 'sml'] + \
[f'sam_{k}.pt' for k in 'bl'] + \
[f'FastSAM-{k}.pt' for k in 'sx'] + \
[f'rtdetr-{k}.pt' for k in 'lx']
GITHUB_ASSET_STEMS = [Path(k).stem for k in GITHUB_ASSET_NAMES]

Loading…
Cancel
Save