ultralytics 8.0.128
FastSAM autodownload and super() init (#3552)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
@ -44,17 +44,20 @@ To perform object detection on an image, use the `predict` method as shown below
|
||||
from ultralytics import FastSAM
|
||||
from ultralytics.yolo.fastsam import FastSAMPrompt
|
||||
|
||||
IMAGE_PATH = 'images/dog.jpg'
|
||||
# Define image path and inference device
|
||||
IMAGE_PATH = 'ultralytics/assets/bus.jpg'
|
||||
DEVICE = 'cpu'
|
||||
model = FastSAM('FastSAM.pt')
|
||||
results = model(
|
||||
IMAGE_PATH,
|
||||
device=DEVICE,
|
||||
retina_masks=True,
|
||||
imgsz=1024,
|
||||
conf=0.4,
|
||||
iou=0.9,
|
||||
)
|
||||
|
||||
# Create a FastSAM model
|
||||
model = FastSAM('FastSAM-s.pt') # or FastSAM-x.pt
|
||||
|
||||
# Run inference on an image
|
||||
everything_results = model(IMAGE_PATH,
|
||||
device=DEVICE,
|
||||
retina_masks=True,
|
||||
imgsz=1024,
|
||||
conf=0.4,
|
||||
iou=0.9)
|
||||
|
||||
prompt_process = FastSAMPrompt(IMAGE_PATH, everything_results, device=DEVICE)
|
||||
|
||||
@ -83,8 +86,11 @@ Validation of the model on a dataset can be done as follows:
|
||||
```python
|
||||
from ultralytics import FastSAM
|
||||
|
||||
model = FastSAM('FastSAM.pt')
|
||||
results = model.val(data='coco8-seg.yaml)
|
||||
# Create a FastSAM model
|
||||
model = FastSAM('FastSAM-s.pt') # or FastSAM-x.pt
|
||||
|
||||
# Validate the model
|
||||
results = model.val(data='coco8-seg.yaml')
|
||||
```
|
||||
|
||||
Please note that FastSAM only supports detection and segmentation of a single class of object. This means it will recognize and segment all objects as the same class. Therefore, when preparing the dataset, you need to convert all object category IDs to 0.
|
||||
|
@ -1,6 +1,6 @@
|
||||
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
||||
|
||||
__version__ = '8.0.127'
|
||||
__version__ = '8.0.128'
|
||||
|
||||
from ultralytics.hub import start
|
||||
from ultralytics.vit.rtdetr import RTDETR
|
||||
|
@ -21,6 +21,13 @@ from .predict import FastSAMPredictor
|
||||
|
||||
class FastSAM(YOLO):
|
||||
|
||||
def __init__(self, model='FastSAM-x.pt'):
|
||||
# Call the __init__ method of the parent class (YOLO) with the updated default model
|
||||
if model == 'FastSAM.pt':
|
||||
model = 'FastSAM-x.pt'
|
||||
super().__init__(model=model)
|
||||
# any additional initialization code for FastSAM
|
||||
|
||||
@smart_inference_mode()
|
||||
def predict(self, source=None, stream=False, **kwargs):
|
||||
"""
|
||||
|
@ -1,3 +1,5 @@
|
||||
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
||||
|
||||
import torch
|
||||
|
||||
from ultralytics.yolo.engine.results import Results
|
||||
|
@ -1,3 +1,5 @@
|
||||
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
||||
|
||||
import os
|
||||
|
||||
import cv2
|
||||
@ -6,15 +8,6 @@ import numpy as np
|
||||
import torch
|
||||
from PIL import Image
|
||||
|
||||
try:
|
||||
import clip # for linear_assignment
|
||||
|
||||
except (ImportError, AssertionError, AttributeError):
|
||||
from ultralytics.yolo.utils.checks import check_requirements
|
||||
|
||||
check_requirements('git+https://github.com/openai/CLIP.git') # required before installing lap from source
|
||||
import clip
|
||||
|
||||
|
||||
class FastSAMPrompt:
|
||||
|
||||
@ -25,7 +18,17 @@ class FastSAMPrompt:
|
||||
self.img_path = img_path
|
||||
self.ori_img = cv2.imread(img_path)
|
||||
|
||||
def _segment_image(self, image, bbox):
|
||||
# Import and assign clip
|
||||
try:
|
||||
import clip # for linear_assignment
|
||||
except ImportError:
|
||||
from ultralytics.yolo.utils.checks import check_requirements
|
||||
check_requirements('git+https://github.com/openai/CLIP.git') # required before installing lap from source
|
||||
import clip
|
||||
self.clip = clip
|
||||
|
||||
@staticmethod
|
||||
def _segment_image(image, bbox):
|
||||
image_array = np.array(image)
|
||||
segmented_image_array = np.zeros_like(image_array)
|
||||
x1, y1, x2, y2 = bbox
|
||||
@ -39,39 +42,40 @@ class FastSAMPrompt:
|
||||
black_image.paste(segmented_image, mask=transparency_mask_image)
|
||||
return black_image
|
||||
|
||||
def _format_results(self, result, filter=0):
|
||||
@staticmethod
|
||||
def _format_results(result, filter=0):
|
||||
annotations = []
|
||||
n = len(result.masks.data)
|
||||
for i in range(n):
|
||||
annotation = {}
|
||||
mask = result.masks.data[i] == 1.0
|
||||
|
||||
if torch.sum(mask) < filter:
|
||||
continue
|
||||
annotation['id'] = i
|
||||
annotation['segmentation'] = mask.cpu().numpy()
|
||||
annotation['bbox'] = result.boxes.data[i]
|
||||
annotation['score'] = result.boxes.conf[i]
|
||||
annotation = {
|
||||
'id': i,
|
||||
'segmentation': mask.cpu().numpy(),
|
||||
'bbox': result.boxes.data[i],
|
||||
'score': result.boxes.conf[i]}
|
||||
annotation['area'] = annotation['segmentation'].sum()
|
||||
annotations.append(annotation)
|
||||
return annotations
|
||||
|
||||
def filter_masks(annotations): # filte the overlap mask
|
||||
@staticmethod
|
||||
def filter_masks(annotations): # filter the overlap mask
|
||||
annotations.sort(key=lambda x: x['area'], reverse=True)
|
||||
to_remove = set()
|
||||
for i in range(0, len(annotations)):
|
||||
for i in range(len(annotations)):
|
||||
a = annotations[i]
|
||||
for j in range(i + 1, len(annotations)):
|
||||
b = annotations[j]
|
||||
if i != j and j not in to_remove:
|
||||
# check if
|
||||
if b['area'] < a['area']:
|
||||
if (a['segmentation'] & b['segmentation']).sum() / b['segmentation'].sum() > 0.8:
|
||||
to_remove.add(j)
|
||||
if i != j and j not in to_remove and b['area'] < a['area'] and \
|
||||
(a['segmentation'] & b['segmentation']).sum() / b['segmentation'].sum() > 0.8:
|
||||
to_remove.add(j)
|
||||
|
||||
return [a for i, a in enumerate(annotations) if i not in to_remove], to_remove
|
||||
|
||||
def _get_bbox_from_mask(self, mask):
|
||||
@staticmethod
|
||||
def _get_bbox_from_mask(mask):
|
||||
mask = mask.astype(np.uint8)
|
||||
contours, hierarchy = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
|
||||
x1, y1, w, h = cv2.boundingRect(contours[0])
|
||||
@ -105,7 +109,7 @@ class FastSAMPrompt:
|
||||
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
|
||||
original_h = image.shape[0]
|
||||
original_w = image.shape[1]
|
||||
# for MacOS only
|
||||
# for macOS only
|
||||
# plt.switch_backend('TkAgg')
|
||||
plt.figure(figsize=(original_w / 100, original_h / 100))
|
||||
# Add subplot with no margin.
|
||||
@ -164,10 +168,9 @@ class FastSAMPrompt:
|
||||
interpolation=cv2.INTER_NEAREST,
|
||||
)
|
||||
contours, hierarchy = cv2.findContours(annotation, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
|
||||
for contour in contours:
|
||||
contour_all.append(contour)
|
||||
contour_all.extend(iter(contours))
|
||||
cv2.drawContours(temp, contour_all, -1, (255, 255, 255), 2)
|
||||
color = np.array([0 / 255, 0 / 255, 255 / 255, 0.8])
|
||||
color = np.array([0 / 255, 0 / 255, 1.0, 0.8])
|
||||
contour_mask = temp / 255 * color.reshape(1, 1, -1)
|
||||
plt.imshow(contour_mask)
|
||||
|
||||
@ -212,7 +215,7 @@ class FastSAMPrompt:
|
||||
if random_color:
|
||||
color = np.random.random((msak_sum, 1, 1, 3))
|
||||
else:
|
||||
color = np.ones((msak_sum, 1, 1, 3)) * np.array([30 / 255, 144 / 255, 255 / 255])
|
||||
color = np.ones((msak_sum, 1, 1, 3)) * np.array([30 / 255, 144 / 255, 1.0])
|
||||
transparency = np.ones((msak_sum, 1, 1, 1)) * 0.6
|
||||
visual = np.concatenate([color, transparency], axis=-1)
|
||||
mask_image = np.expand_dims(annotation, -1) * visual
|
||||
@ -267,8 +270,8 @@ class FastSAMPrompt:
|
||||
if random_color:
|
||||
color = torch.rand((msak_sum, 1, 1, 3)).to(annotation.device)
|
||||
else:
|
||||
color = torch.ones((msak_sum, 1, 1, 3)).to(annotation.device) * torch.tensor([
|
||||
30 / 255, 144 / 255, 255 / 255]).to(annotation.device)
|
||||
color = torch.ones((msak_sum, 1, 1, 3)).to(annotation.device) * torch.tensor([30 / 255, 144 / 255, 1.0]).to(
|
||||
annotation.device)
|
||||
transparency = torch.ones((msak_sum, 1, 1, 1)).to(annotation.device) * 0.6
|
||||
visual = torch.cat([color, transparency], dim=-1)
|
||||
mask_image = torch.unsqueeze(annotation, -1) * visual
|
||||
@ -304,7 +307,7 @@ class FastSAMPrompt:
|
||||
@torch.no_grad()
|
||||
def retrieve(self, model, preprocess, elements, search_text: str, device) -> int:
|
||||
preprocessed_images = [preprocess(image).to(device) for image in elements]
|
||||
tokenized_text = clip.tokenize([search_text]).to(device)
|
||||
tokenized_text = self.clip.tokenize([search_text]).to(device)
|
||||
stacked_images = torch.stack(preprocessed_images)
|
||||
image_features = model.encode_image(stacked_images)
|
||||
text_features = model.encode_text(tokenized_text)
|
||||
@ -352,10 +355,10 @@ class FastSAMPrompt:
|
||||
int(bbox[1] * h / target_height),
|
||||
int(bbox[2] * w / target_width),
|
||||
int(bbox[3] * h / target_height), ]
|
||||
bbox[0] = round(bbox[0]) if round(bbox[0]) > 0 else 0
|
||||
bbox[1] = round(bbox[1]) if round(bbox[1]) > 0 else 0
|
||||
bbox[2] = round(bbox[2]) if round(bbox[2]) < w else w
|
||||
bbox[3] = round(bbox[3]) if round(bbox[3]) < h else h
|
||||
bbox[0] = max(round(bbox[0]), 0)
|
||||
bbox[1] = max(round(bbox[1]), 0)
|
||||
bbox[2] = min(round(bbox[2]), w)
|
||||
bbox[3] = min(round(bbox[3]), h)
|
||||
|
||||
# IoUs = torch.zeros(len(masks), dtype=torch.float32)
|
||||
bbox_area = (bbox[3] - bbox[1]) * (bbox[2] - bbox[0])
|
||||
@ -380,10 +383,7 @@ class FastSAMPrompt:
|
||||
points = [[int(point[0] * w / target_width), int(point[1] * h / target_height)] for point in points]
|
||||
onemask = np.zeros((h, w))
|
||||
for i, annotation in enumerate(masks):
|
||||
if type(annotation) == dict:
|
||||
mask = annotation['segmentation']
|
||||
else:
|
||||
mask = annotation
|
||||
mask = annotation['segmentation'] if type(annotation) == dict else annotation
|
||||
for i, point in enumerate(points):
|
||||
if mask[point[1], point[0]] == 1 and pointlabel[i] == 1:
|
||||
onemask += mask
|
||||
@ -395,7 +395,7 @@ class FastSAMPrompt:
|
||||
def text_prompt(self, text):
|
||||
format_results = self._format_results(self.results[0], 0)
|
||||
cropped_boxes, cropped_images, not_crop, filter_id, annotations = self._crop_image(format_results)
|
||||
clip_model, preprocess = clip.load('ViT-B/32', device=self.device)
|
||||
clip_model, preprocess = self.clip.load('ViT-B/32', device=self.device)
|
||||
scores = self.retrieve(clip_model, preprocess, cropped_boxes, text, device=self.device)
|
||||
max_idx = scores.argsort()
|
||||
max_idx = max_idx[-1]
|
||||
|
@ -1,16 +1,20 @@
|
||||
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
def adjust_bboxes_to_image_border(boxes, image_shape, threshold=20):
|
||||
'''Adjust bounding boxes to stick to image border if they are within a certain threshold.
|
||||
"""
|
||||
Adjust bounding boxes to stick to image border if they are within a certain threshold.
|
||||
|
||||
Args:
|
||||
boxes: (n, 4)
|
||||
image_shape: (height, width)
|
||||
threshold: pixel threshold
|
||||
boxes: (n, 4)
|
||||
image_shape: (height, width)
|
||||
threshold: pixel threshold
|
||||
|
||||
Returns:
|
||||
adjusted_boxes: adjusted bounding boxes
|
||||
'''
|
||||
adjusted_boxes: adjusted bounding boxes
|
||||
"""
|
||||
|
||||
# Image dimensions
|
||||
h, w = image_shape
|
||||
@ -25,14 +29,16 @@ def adjust_bboxes_to_image_border(boxes, image_shape, threshold=20):
|
||||
|
||||
|
||||
def bbox_iou(box1, boxes, iou_thres=0.9, image_shape=(640, 640), raw_output=False):
|
||||
'''Compute the Intersection-Over-Union of a bounding box with respect to an array of other bounding boxes.
|
||||
"""
|
||||
Compute the Intersection-Over-Union of a bounding box with respect to an array of other bounding boxes.
|
||||
|
||||
Args:
|
||||
box1: (4, )
|
||||
boxes: (n, 4)
|
||||
box1: (4, )
|
||||
boxes: (n, 4)
|
||||
|
||||
Returns:
|
||||
high_iou_indices: Indices of boxes with IoU > thres
|
||||
'''
|
||||
high_iou_indices: Indices of boxes with IoU > thres
|
||||
"""
|
||||
boxes = adjust_bboxes_to_image_border(boxes, image_shape)
|
||||
# obtain coordinates for intersections
|
||||
x1 = torch.max(box1[0], boxes[:, 0])
|
||||
@ -53,11 +59,7 @@ def bbox_iou(box1, boxes, iou_thres=0.9, image_shape=(640, 640), raw_output=Fals
|
||||
# compute the IoU
|
||||
iou = intersection / union # Should be shape (n, )
|
||||
if raw_output:
|
||||
if iou.numel() == 0:
|
||||
return 0
|
||||
return iou
|
||||
return 0 if iou.numel() == 0 else iou
|
||||
|
||||
# get indices of boxes with IoU > thres
|
||||
high_iou_indices = torch.nonzero(iou > iou_thres).flatten()
|
||||
|
||||
return high_iou_indices
|
||||
# return indices of boxes with IoU > thres
|
||||
return torch.nonzero(iou > iou_thres).flatten()
|
||||
|
@ -20,6 +20,7 @@ GITHUB_ASSET_NAMES = [f'yolov8{k}{suffix}.pt' for k in 'nsmlx' for suffix in (''
|
||||
[f'yolov3{k}u.pt' for k in ('', '-spp', '-tiny')] + \
|
||||
[f'yolo_nas_{k}.pt' for k in 'sml'] + \
|
||||
[f'sam_{k}.pt' for k in 'bl'] + \
|
||||
[f'FastSAM-{k}.pt' for k in 'sx'] + \
|
||||
[f'rtdetr-{k}.pt' for k in 'lx']
|
||||
GITHUB_ASSET_STEMS = [Path(k).stem for k in GITHUB_ASSET_NAMES]
|
||||
|
||||
|
Reference in New Issue
Block a user