`ultralytics 8.0.128` FastSAM autodownload and super() init (#3552)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Glenn Jocher 1 year ago committed by GitHub
parent 400f3f72a1
commit ad99246ff1

@@ -44,17 +44,20 @@ To perform object detection on an image, use the `predict` method as shown below
 from ultralytics import FastSAM
 from ultralytics.yolo.fastsam import FastSAMPrompt

-IMAGE_PATH = 'images/dog.jpg'
+# Define image path and inference device
+IMAGE_PATH = 'ultralytics/assets/bus.jpg'
 DEVICE = 'cpu'
-model = FastSAM('FastSAM.pt')
-results = model(
-    IMAGE_PATH,
-    device=DEVICE,
-    retina_masks=True,
-    imgsz=1024,
-    conf=0.4,
-    iou=0.9,
-)
+
+# Create a FastSAM model
+model = FastSAM('FastSAM-s.pt')  # or FastSAM-x.pt
+
+# Run inference on an image
+everything_results = model(IMAGE_PATH,
+                           device=DEVICE,
+                           retina_masks=True,
+                           imgsz=1024,
+                           conf=0.4,
+                           iou=0.9)

 prompt_process = FastSAMPrompt(IMAGE_PATH, everything_results, device=DEVICE)
@@ -83,8 +86,11 @@ Validation of the model on a dataset can be done as follows:
 ```python
 from ultralytics import FastSAM

-model = FastSAM('FastSAM.pt')
-results = model.val(data='coco8-seg.yaml)
+# Create a FastSAM model
+model = FastSAM('FastSAM-s.pt')  # or FastSAM-x.pt
+
+# Validate the model
+results = model.val(data='coco8-seg.yaml')
 ```

 Please note that FastSAM only supports detection and segmentation of a single class of object. This means it will recognize and segment all objects as the same class. Therefore, when preparing the dataset, you need to convert all object category IDs to 0.
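As an illustration of that last note, here is a minimal sketch of the category-ID conversion for YOLO-format label files. The dataset path is hypothetical, and it assumes one `.txt` label file per image with the class ID as the first field on each line:

```python
from pathlib import Path

# Rewrite every label line so the object category ID becomes 0 (single-class dataset).
# 'datasets/my-seg-dataset/labels' is an illustrative path; point it at your own labels folder.
for label_file in Path('datasets/my-seg-dataset/labels').rglob('*.txt'):
    lines = [line for line in label_file.read_text().splitlines() if line.strip()]
    converted = ['0 ' + line.split(maxsplit=1)[1] for line in lines]
    label_file.write_text('\n'.join(converted) + '\n')
```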

@@ -1,6 +1,6 @@
 # Ultralytics YOLO 🚀, AGPL-3.0 license

-__version__ = '8.0.127'
+__version__ = '8.0.128'

 from ultralytics.hub import start
 from ultralytics.vit.rtdetr import RTDETR

@@ -21,6 +21,13 @@ from .predict import FastSAMPredictor

 class FastSAM(YOLO):

+    def __init__(self, model='FastSAM-x.pt'):
+        # Call the __init__ method of the parent class (YOLO) with the updated default model
+        if model == 'FastSAM.pt':
+            model = 'FastSAM-x.pt'
+        super().__init__(model=model)
+        # any additional initialization code for FastSAM
+
     @smart_inference_mode()
     def predict(self, source=None, stream=False, **kwargs):
         """

@@ -1,3 +1,5 @@
+# Ultralytics YOLO 🚀, AGPL-3.0 license
+
 import torch

 from ultralytics.yolo.engine.results import Results

@@ -1,3 +1,5 @@
+# Ultralytics YOLO 🚀, AGPL-3.0 license
+
 import os

 import cv2
@@ -6,15 +8,6 @@ import numpy as np
 import torch
 from PIL import Image

-try:
-    import clip  # for linear_assignment
-except (ImportError, AssertionError, AttributeError):
-    from ultralytics.yolo.utils.checks import check_requirements
-    check_requirements('git+https://github.com/openai/CLIP.git')  # required before installing lap from source
-    import clip
-

 class FastSAMPrompt:
@@ -25,7 +18,17 @@ class FastSAMPrompt:
         self.img_path = img_path
         self.ori_img = cv2.imread(img_path)

-    def _segment_image(self, image, bbox):
+        # Import and assign clip
+        try:
+            import clip  # for linear_assignment
+        except ImportError:
+            from ultralytics.yolo.utils.checks import check_requirements
+            check_requirements('git+https://github.com/openai/CLIP.git')  # required before installing lap from source
+            import clip
+        self.clip = clip
+
+    @staticmethod
+    def _segment_image(image, bbox):
         image_array = np.array(image)
         segmented_image_array = np.zeros_like(image_array)
         x1, y1, x2, y2 = bbox
@@ -39,39 +42,40 @@ class FastSAMPrompt:
         black_image.paste(segmented_image, mask=transparency_mask_image)
         return black_image

-    def _format_results(self, result, filter=0):
+    @staticmethod
+    def _format_results(result, filter=0):
         annotations = []
         n = len(result.masks.data)
         for i in range(n):
-            annotation = {}
             mask = result.masks.data[i] == 1.0
             if torch.sum(mask) < filter:
                 continue
-            annotation['id'] = i
-            annotation['segmentation'] = mask.cpu().numpy()
-            annotation['bbox'] = result.boxes.data[i]
-            annotation['score'] = result.boxes.conf[i]
+            annotation = {
+                'id': i,
+                'segmentation': mask.cpu().numpy(),
+                'bbox': result.boxes.data[i],
+                'score': result.boxes.conf[i]}
             annotation['area'] = annotation['segmentation'].sum()
             annotations.append(annotation)
         return annotations

-    def filter_masks(annotations):  # filte the overlap mask
+    @staticmethod
+    def filter_masks(annotations):  # filter the overlap mask
         annotations.sort(key=lambda x: x['area'], reverse=True)
         to_remove = set()
-        for i in range(0, len(annotations)):
+        for i in range(len(annotations)):
             a = annotations[i]
             for j in range(i + 1, len(annotations)):
                 b = annotations[j]
-                if i != j and j not in to_remove:
-                    # check if
-                    if b['area'] < a['area']:
-                        if (a['segmentation'] & b['segmentation']).sum() / b['segmentation'].sum() > 0.8:
-                            to_remove.add(j)
+                if i != j and j not in to_remove and b['area'] < a['area'] and \
+                        (a['segmentation'] & b['segmentation']).sum() / b['segmentation'].sum() > 0.8:
+                    to_remove.add(j)

         return [a for i, a in enumerate(annotations) if i not in to_remove], to_remove

-    def _get_bbox_from_mask(self, mask):
+    @staticmethod
+    def _get_bbox_from_mask(mask):
         mask = mask.astype(np.uint8)
         contours, hierarchy = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
         x1, y1, w, h = cv2.boundingRect(contours[0])
@@ -105,7 +109,7 @@ class FastSAMPrompt:
         image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
         original_h = image.shape[0]
         original_w = image.shape[1]
-        # for MacOS only
+        # for macOS only
         # plt.switch_backend('TkAgg')
         plt.figure(figsize=(original_w / 100, original_h / 100))
         # Add subplot with no margin.
@@ -164,10 +168,9 @@ class FastSAMPrompt:
                     interpolation=cv2.INTER_NEAREST,
                 )
                 contours, hierarchy = cv2.findContours(annotation, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
-                for contour in contours:
-                    contour_all.append(contour)
+                contour_all.extend(iter(contours))
             cv2.drawContours(temp, contour_all, -1, (255, 255, 255), 2)
-            color = np.array([0 / 255, 0 / 255, 255 / 255, 0.8])
+            color = np.array([0 / 255, 0 / 255, 1.0, 0.8])
             contour_mask = temp / 255 * color.reshape(1, 1, -1)
             plt.imshow(contour_mask)
@@ -212,7 +215,7 @@ class FastSAMPrompt:
         if random_color:
             color = np.random.random((msak_sum, 1, 1, 3))
         else:
-            color = np.ones((msak_sum, 1, 1, 3)) * np.array([30 / 255, 144 / 255, 255 / 255])
+            color = np.ones((msak_sum, 1, 1, 3)) * np.array([30 / 255, 144 / 255, 1.0])
         transparency = np.ones((msak_sum, 1, 1, 1)) * 0.6
         visual = np.concatenate([color, transparency], axis=-1)
         mask_image = np.expand_dims(annotation, -1) * visual
@@ -267,8 +270,8 @@ class FastSAMPrompt:
         if random_color:
             color = torch.rand((msak_sum, 1, 1, 3)).to(annotation.device)
         else:
-            color = torch.ones((msak_sum, 1, 1, 3)).to(annotation.device) * torch.tensor([
-                30 / 255, 144 / 255, 255 / 255]).to(annotation.device)
+            color = torch.ones((msak_sum, 1, 1, 3)).to(annotation.device) * torch.tensor([30 / 255, 144 / 255, 1.0]).to(
+                annotation.device)
         transparency = torch.ones((msak_sum, 1, 1, 1)).to(annotation.device) * 0.6
         visual = torch.cat([color, transparency], dim=-1)
         mask_image = torch.unsqueeze(annotation, -1) * visual
@@ -304,7 +307,7 @@ class FastSAMPrompt:
     @torch.no_grad()
     def retrieve(self, model, preprocess, elements, search_text: str, device) -> int:
         preprocessed_images = [preprocess(image).to(device) for image in elements]
-        tokenized_text = clip.tokenize([search_text]).to(device)
+        tokenized_text = self.clip.tokenize([search_text]).to(device)
         stacked_images = torch.stack(preprocessed_images)
         image_features = model.encode_image(stacked_images)
         text_features = model.encode_text(tokenized_text)
@@ -352,10 +355,10 @@ class FastSAMPrompt:
             int(bbox[1] * h / target_height),
             int(bbox[2] * w / target_width),
             int(bbox[3] * h / target_height), ]
-        bbox[0] = round(bbox[0]) if round(bbox[0]) > 0 else 0
-        bbox[1] = round(bbox[1]) if round(bbox[1]) > 0 else 0
-        bbox[2] = round(bbox[2]) if round(bbox[2]) < w else w
-        bbox[3] = round(bbox[3]) if round(bbox[3]) < h else h
+        bbox[0] = max(round(bbox[0]), 0)
+        bbox[1] = max(round(bbox[1]), 0)
+        bbox[2] = min(round(bbox[2]), w)
+        bbox[3] = min(round(bbox[3]), h)

         # IoUs = torch.zeros(len(masks), dtype=torch.float32)
         bbox_area = (bbox[3] - bbox[1]) * (bbox[2] - bbox[0])
@@ -380,10 +383,7 @@ class FastSAMPrompt:
         points = [[int(point[0] * w / target_width), int(point[1] * h / target_height)] for point in points]
         onemask = np.zeros((h, w))
         for i, annotation in enumerate(masks):
-            if type(annotation) == dict:
-                mask = annotation['segmentation']
-            else:
-                mask = annotation
+            mask = annotation['segmentation'] if type(annotation) == dict else annotation
             for i, point in enumerate(points):
                 if mask[point[1], point[0]] == 1 and pointlabel[i] == 1:
                     onemask += mask
@@ -395,7 +395,7 @@ class FastSAMPrompt:
     def text_prompt(self, text):
         format_results = self._format_results(self.results[0], 0)
         cropped_boxes, cropped_images, not_crop, filter_id, annotations = self._crop_image(format_results)
-        clip_model, preprocess = clip.load('ViT-B/32', device=self.device)
+        clip_model, preprocess = self.clip.load('ViT-B/32', device=self.device)
         scores = self.retrieve(clip_model, preprocess, cropped_boxes, text, device=self.device)
         max_idx = scores.argsort()
         max_idx = max_idx[-1]

@@ -1,8 +1,12 @@
+# Ultralytics YOLO 🚀, AGPL-3.0 license
+
 import torch


 def adjust_bboxes_to_image_border(boxes, image_shape, threshold=20):
-    '''Adjust bounding boxes to stick to image border if they are within a certain threshold.
+    """
+    Adjust bounding boxes to stick to image border if they are within a certain threshold.

     Args:
         boxes: (n, 4)
         image_shape: (height, width)
@@ -10,7 +14,7 @@ def adjust_bboxes_to_image_border(boxes, image_shape, threshold=20):

     Returns:
         adjusted_boxes: adjusted bounding boxes
-    '''
+    """

     # Image dimensions
     h, w = image_shape
@@ -25,14 +29,16 @@ def adjust_bboxes_to_image_border(boxes, image_shape, threshold=20):

 def bbox_iou(box1, boxes, iou_thres=0.9, image_shape=(640, 640), raw_output=False):
-    '''Compute the Intersection-Over-Union of a bounding box with respect to an array of other bounding boxes.
+    """
+    Compute the Intersection-Over-Union of a bounding box with respect to an array of other bounding boxes.

     Args:
         box1: (4, )
         boxes: (n, 4)

     Returns:
         high_iou_indices: Indices of boxes with IoU > thres
-    '''
+    """
     boxes = adjust_bboxes_to_image_border(boxes, image_shape)
     # obtain coordinates for intersections
     x1 = torch.max(box1[0], boxes[:, 0])
@@ -53,11 +59,7 @@ def bbox_iou(box1, boxes, iou_thres=0.9, image_shape=(640, 640), raw_output=False):
     # compute the IoU
     iou = intersection / union  # Should be shape (n, )
     if raw_output:
-        if iou.numel() == 0:
-            return 0
-        return iou
-
-    # get indices of boxes with IoU > thres
-    high_iou_indices = torch.nonzero(iou > iou_thres).flatten()
-    return high_iou_indices
+        return 0 if iou.numel() == 0 else iou
+
+    # return indices of boxes with IoU > thres
+    return torch.nonzero(iou > iou_thres).flatten()
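For context on the refactor above, a minimal usage sketch of `bbox_iou` (the module path and sample boxes are assumptions for illustration, not taken from this diff): with `raw_output=True` it returns the raw IoU tensor (or `0` when that tensor is empty), otherwise the indices of boxes whose IoU with `box1` exceeds `iou_thres`.

```python
import torch

from ultralytics.yolo.fastsam.utils import bbox_iou  # assumed module path for this helper

box1 = torch.tensor([100.0, 100.0, 200.0, 200.0])     # query box in xyxy format
boxes = torch.tensor([[105.0, 105.0, 195.0, 195.0],   # strong overlap with box1
                      [300.0, 300.0, 400.0, 400.0]])  # no overlap with box1

high_iou_idx = bbox_iou(box1, boxes, iou_thres=0.9, image_shape=(640, 640))  # indices of boxes over the threshold
raw_iou = bbox_iou(box1, boxes, image_shape=(640, 640), raw_output=True)     # raw IoU values, shape (n,)
```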

@@ -20,6 +20,7 @@ GITHUB_ASSET_NAMES = [f'yolov8{k}{suffix}.pt' for k in 'nsmlx' for suffix in (''
                      [f'yolov3{k}u.pt' for k in ('', '-spp', '-tiny')] + \
                      [f'yolo_nas_{k}.pt' for k in 'sml'] + \
                      [f'sam_{k}.pt' for k in 'bl'] + \
+                     [f'FastSAM-{k}.pt' for k in 'sx'] + \
                      [f'rtdetr-{k}.pt' for k in 'lx']
 GITHUB_ASSET_STEMS = [Path(k).stem for k in GITHUB_ASSET_NAMES]
