`ultralytics 8.0.127` add FastSAM model (#3390)

Co-authored-by: dingwenchao <12962189468@163.com>
Co-authored-by: 丁文超 <dingwenchao@dingwenchaodeMacBook-Pro.local>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
Co-authored-by: Ayush Chaurasia <ayush.chaurarsia@gmail.com>
single_channel
Wenchao Ding 1 year ago committed by GitHub
parent 91905b4b0b
commit 400f3f72a1
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -32,9 +32,68 @@ FastSAM is designed to address the limitations of the [Segment Anything Model (S
## Usage
FastSAM is not yet available within the [`ultralytics` package](../quickstart.md), but it is available directly from the [https://github.com/CASIA-IVA-Lab/FastSAM](https://github.com/CASIA-IVA-Lab/FastSAM) repository. Here is a brief overview of the typical steps you might take to use FastSAM:
### Python API
### Installation
The FastSAM models are easy to integrate into your Python applications. Ultralytics provides a user-friendly Python API to streamline the process.
#### Predict Usage
To perform object detection on an image, use the `predict` method as shown below:
```python
from ultralytics import FastSAM
from ultralytics.yolo.fastsam import FastSAMPrompt
IMAGE_PATH = 'images/dog.jpg'
DEVICE = 'cpu'
model = FastSAM('FastSAM.pt')
results = model(
IMAGE_PATH,
device=DEVICE,
retina_masks=True,
imgsz=1024,
conf=0.4,
iou=0.9,
)
prompt_process = FastSAMPrompt(IMAGE_PATH, everything_results, device=DEVICE)
# Everything prompt
ann = prompt_process.everything_prompt()
# Bbox default shape [0,0,0,0] -> [x1,y1,x2,y2]
ann = prompt_process.box_prompt(bbox=[200, 200, 300, 300])
# Text prompt
ann = prompt_process.text_prompt(text='a photo of a dog')
# Point prompt
# points default [[0,0]] [[x1,y1],[x2,y2]]
# point_label default [0] [1,0] 0:background, 1:foreground
ann = prompt_process.point_prompt(points=[[200, 200]], pointlabel=[1])
prompt_process.plot(annotations=ann, output='./')
```
This snippet demonstrates the simplicity of loading a pre-trained model and running a prediction on an image.
#### Val Usage
Validation of the model on a dataset can be done as follows:
```python
from ultralytics import FastSAM
model = FastSAM('FastSAM.pt')
results = model.val(data='coco8-seg.yaml)
```
Please note that FastSAM only supports detection and segmentation of a single class of object. This means it will recognize and segment all objects as the same class. Therefore, when preparing the dataset, you need to convert all object category IDs to 0.
### FastSAM official Usage
FastSAM is also available directly from the [https://github.com/CASIA-IVA-Lab/FastSAM](https://github.com/CASIA-IVA-Lab/FastSAM) repository. Here is a brief overview of the typical steps you might take to use FastSAM:
#### Installation
1. Clone the FastSAM repository:
```shell
@ -58,7 +117,7 @@ FastSAM is not yet available within the [`ultralytics` package](../quickstart.md
pip install git+https://github.com/openai/CLIP.git
```
### Example Usage
#### Example Usage
1. Download a [model checkpoint](https://drive.google.com/file/d/1m1sjY4ihXBU1fZXdQ-Xdj-mDltW-2Rqv/view?usp=sharing).
@ -101,4 +160,4 @@ We would like to acknowledge the FastSAM authors for their significant contribut
}
```
The original FastSAM paper can be found on [arXiv](https://arxiv.org/abs/2306.12156). The authors have made their work publicly available, and the codebase can be accessed on [GitHub](https://github.com/CASIA-IVA-Lab/FastSAM). We appreciate their efforts in advancing the field and making their work accessible to the broader community.
The original FastSAM paper can be found on [arXiv](https://arxiv.org/abs/2306.12156). The authors have made their work publicly available, and the codebase can be accessed on [GitHub](https://github.com/CASIA-IVA-Lab/FastSAM). We appreciate their efforts in advancing the field and making their work accessible to the broader community.

@ -1,13 +1,14 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
__version__ = '8.0.126'
__version__ = '8.0.127'
from ultralytics.hub import start
from ultralytics.vit.rtdetr import RTDETR
from ultralytics.vit.sam import SAM
from ultralytics.yolo.engine.model import YOLO
from ultralytics.yolo.fastsam import FastSAM
from ultralytics.yolo.nas import NAS
from ultralytics.yolo.utils.checks import check_yolo as checks
from ultralytics.yolo.utils.downloads import download
__all__ = '__version__', 'YOLO', 'NAS', 'SAM', 'RTDETR', 'checks', 'start', 'download' # allow simpler import
__all__ = '__version__', 'YOLO', 'NAS', 'SAM', 'RTDETR', 'checks', 'start', 'download', 'FastSAM' # allow simpler import

@ -0,0 +1,8 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
from .model import FastSAM
from .predict import FastSAMPredictor
from .prompt import FastSAMPrompt
from .val import FastSAMValidator
__all__ = 'FastSAMPredictor', 'FastSAM', 'FastSAMPrompt', 'FastSAMValidator'

@ -0,0 +1,104 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
"""
FastSAM model interface.
Usage - Predict:
from ultralytics import FastSAM
model = FastSAM('last.pt')
results = model.predict('ultralytics/assets/bus.jpg')
"""
from ultralytics.yolo.cfg import get_cfg
from ultralytics.yolo.engine.exporter import Exporter
from ultralytics.yolo.engine.model import YOLO
from ultralytics.yolo.utils import DEFAULT_CFG, LOGGER, ROOT, is_git_dir
from ultralytics.yolo.utils.checks import check_imgsz
from ...yolo.utils.torch_utils import model_info, smart_inference_mode
from .predict import FastSAMPredictor
class FastSAM(YOLO):
@smart_inference_mode()
def predict(self, source=None, stream=False, **kwargs):
"""
Perform prediction using the YOLO model.
Args:
source (str | int | PIL | np.ndarray): The source of the image to make predictions on.
Accepts all source types accepted by the YOLO model.
stream (bool): Whether to stream the predictions or not. Defaults to False.
**kwargs : Additional keyword arguments passed to the predictor.
Check the 'configuration' section in the documentation for all available options.
Returns:
(List[ultralytics.yolo.engine.results.Results]): The prediction results.
"""
if source is None:
source = ROOT / 'assets' if is_git_dir() else 'https://ultralytics.com/images/bus.jpg'
LOGGER.warning(f"WARNING ⚠️ 'source' is missing. Using 'source={source}'.")
overrides = self.overrides.copy()
overrides['conf'] = 0.25
overrides.update(kwargs) # prefer kwargs
overrides['mode'] = kwargs.get('mode', 'predict')
assert overrides['mode'] in ['track', 'predict']
overrides['save'] = kwargs.get('save', False) # do not save by default if called in Python
self.predictor = FastSAMPredictor(overrides=overrides)
self.predictor.setup_model(model=self.model, verbose=False)
return self.predictor(source, stream=stream)
def train(self, **kwargs):
"""Function trains models but raises an error as FastSAM models do not support training."""
raise NotImplementedError("FastSAM models don't support training")
def val(self, **kwargs):
"""Run validation given dataset."""
overrides = dict(task='segment', mode='val')
overrides.update(kwargs) # prefer kwargs
args = get_cfg(cfg=DEFAULT_CFG, overrides=overrides)
args.imgsz = check_imgsz(args.imgsz, max_dim=1)
validator = FastSAM(args=args)
validator(model=self.model)
self.metrics = validator.metrics
return validator.metrics
@smart_inference_mode()
def export(self, **kwargs):
"""
Export model.
Args:
**kwargs : Any other args accepted by the predictors. To see all args check 'configuration' section in docs
"""
overrides = dict(task='detect')
overrides.update(kwargs)
overrides['mode'] = 'export'
args = get_cfg(cfg=DEFAULT_CFG, overrides=overrides)
args.task = self.task
if args.imgsz == DEFAULT_CFG.imgsz:
args.imgsz = self.model.args['imgsz'] # use trained imgsz unless custom value is passed
if args.batch == DEFAULT_CFG.batch:
args.batch = 1 # default to 1 if not modified
return Exporter(overrides=args)(model=self.model)
def info(self, detailed=False, verbose=True):
"""
Logs model info.
Args:
detailed (bool): Show detailed information about model.
verbose (bool): Controls verbosity.
"""
return model_info(self.model, detailed=detailed, verbose=verbose, imgsz=640)
def __call__(self, source=None, stream=False, **kwargs):
"""Calls the 'predict' function with given arguments to perform object detection."""
return self.predict(source, stream, **kwargs)
def __getattr__(self, attr):
"""Raises error if object has no requested attribute."""
name = self.__class__.__name__
raise AttributeError(f"'{name}' object has no attribute '{attr}'. See valid attributes below.\n{self.__doc__}")

@ -0,0 +1,51 @@
import torch
from ultralytics.yolo.engine.results import Results
from ultralytics.yolo.fastsam.utils import bbox_iou
from ultralytics.yolo.utils import DEFAULT_CFG, ops
from ultralytics.yolo.v8.detect.predict import DetectionPredictor
class FastSAMPredictor(DetectionPredictor):
def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
super().__init__(cfg, overrides, _callbacks)
self.args.task = 'segment'
def postprocess(self, preds, img, orig_imgs):
"""TODO: filter by classes."""
p = ops.non_max_suppression(preds[0],
self.args.conf,
self.args.iou,
agnostic=self.args.agnostic_nms,
max_det=self.args.max_det,
nc=len(self.model.names),
classes=self.args.classes)
full_box = torch.zeros_like(p[0][0])
full_box[2], full_box[3], full_box[4], full_box[6:] = img.shape[3], img.shape[2], 1.0, 1.0
full_box = full_box.view(1, -1)
critical_iou_index = bbox_iou(full_box[0][:4], p[0][:, :4], iou_thres=0.9, image_shape=img.shape[2:])
if critical_iou_index.numel() != 0:
full_box[0][4] = p[0][critical_iou_index][:, 4]
full_box[0][6:] = p[0][critical_iou_index][:, 6:]
p[0][critical_iou_index] = full_box
results = []
proto = preds[1][-1] if len(preds[1]) == 3 else preds[1] # second output is len 3 if pt, but only 1 if exported
for i, pred in enumerate(p):
orig_img = orig_imgs[i] if isinstance(orig_imgs, list) else orig_imgs
path = self.batch[0]
img_path = path[i] if isinstance(path, list) else path
if not len(pred): # save empty boxes
results.append(Results(orig_img=orig_img, path=img_path, names=self.model.names, boxes=pred[:, :6]))
continue
if self.args.retina_masks:
if not isinstance(orig_imgs, torch.Tensor):
pred[:, :4] = ops.scale_boxes(img.shape[2:], pred[:, :4], orig_img.shape)
masks = ops.process_mask_native(proto[i], pred[:, 6:], pred[:, :4], orig_img.shape[:2]) # HWC
else:
masks = ops.process_mask(proto[i], pred[:, 6:], pred[:, :4], img.shape[2:], upsample=True) # HWC
if not isinstance(orig_imgs, torch.Tensor):
pred[:, :4] = ops.scale_boxes(img.shape[2:], pred[:, :4], orig_img.shape)
results.append(
Results(orig_img=orig_img, path=img_path, names=self.model.names, boxes=pred[:, :6], masks=masks))
return results

@ -0,0 +1,406 @@
import os
import cv2
import matplotlib.pyplot as plt
import numpy as np
import torch
from PIL import Image
try:
import clip # for linear_assignment
except (ImportError, AssertionError, AttributeError):
from ultralytics.yolo.utils.checks import check_requirements
check_requirements('git+https://github.com/openai/CLIP.git') # required before installing lap from source
import clip
class FastSAMPrompt:
def __init__(self, img_path, results, device='cuda') -> None:
# self.img_path = img_path
self.device = device
self.results = results
self.img_path = img_path
self.ori_img = cv2.imread(img_path)
def _segment_image(self, image, bbox):
image_array = np.array(image)
segmented_image_array = np.zeros_like(image_array)
x1, y1, x2, y2 = bbox
segmented_image_array[y1:y2, x1:x2] = image_array[y1:y2, x1:x2]
segmented_image = Image.fromarray(segmented_image_array)
black_image = Image.new('RGB', image.size, (255, 255, 255))
# transparency_mask = np.zeros_like((), dtype=np.uint8)
transparency_mask = np.zeros((image_array.shape[0], image_array.shape[1]), dtype=np.uint8)
transparency_mask[y1:y2, x1:x2] = 255
transparency_mask_image = Image.fromarray(transparency_mask, mode='L')
black_image.paste(segmented_image, mask=transparency_mask_image)
return black_image
def _format_results(self, result, filter=0):
annotations = []
n = len(result.masks.data)
for i in range(n):
annotation = {}
mask = result.masks.data[i] == 1.0
if torch.sum(mask) < filter:
continue
annotation['id'] = i
annotation['segmentation'] = mask.cpu().numpy()
annotation['bbox'] = result.boxes.data[i]
annotation['score'] = result.boxes.conf[i]
annotation['area'] = annotation['segmentation'].sum()
annotations.append(annotation)
return annotations
def filter_masks(annotations): # filte the overlap mask
annotations.sort(key=lambda x: x['area'], reverse=True)
to_remove = set()
for i in range(0, len(annotations)):
a = annotations[i]
for j in range(i + 1, len(annotations)):
b = annotations[j]
if i != j and j not in to_remove:
# check if
if b['area'] < a['area']:
if (a['segmentation'] & b['segmentation']).sum() / b['segmentation'].sum() > 0.8:
to_remove.add(j)
return [a for i, a in enumerate(annotations) if i not in to_remove], to_remove
def _get_bbox_from_mask(self, mask):
mask = mask.astype(np.uint8)
contours, hierarchy = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
x1, y1, w, h = cv2.boundingRect(contours[0])
x2, y2 = x1 + w, y1 + h
if len(contours) > 1:
for b in contours:
x_t, y_t, w_t, h_t = cv2.boundingRect(b)
# 将多个bbox合并成一个
x1 = min(x1, x_t)
y1 = min(y1, y_t)
x2 = max(x2, x_t + w_t)
y2 = max(y2, y_t + h_t)
h = y2 - y1
w = x2 - x1
return [x1, y1, x2, y2]
def plot(self,
annotations,
output,
bbox=None,
points=None,
point_label=None,
mask_random_color=True,
better_quality=True,
retina=False,
withContours=True):
if isinstance(annotations[0], dict):
annotations = [annotation['segmentation'] for annotation in annotations]
result_name = os.path.basename(self.img_path)
image = self.ori_img
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
original_h = image.shape[0]
original_w = image.shape[1]
# for MacOS only
# plt.switch_backend('TkAgg')
plt.figure(figsize=(original_w / 100, original_h / 100))
# Add subplot with no margin.
plt.subplots_adjust(top=1, bottom=0, right=1, left=0, hspace=0, wspace=0)
plt.margins(0, 0)
plt.gca().xaxis.set_major_locator(plt.NullLocator())
plt.gca().yaxis.set_major_locator(plt.NullLocator())
plt.imshow(image)
if better_quality:
if isinstance(annotations[0], torch.Tensor):
annotations = np.array(annotations.cpu())
for i, mask in enumerate(annotations):
mask = cv2.morphologyEx(mask.astype(np.uint8), cv2.MORPH_CLOSE, np.ones((3, 3), np.uint8))
annotations[i] = cv2.morphologyEx(mask.astype(np.uint8), cv2.MORPH_OPEN, np.ones((8, 8), np.uint8))
if self.device == 'cpu':
annotations = np.array(annotations)
self.fast_show_mask(
annotations,
plt.gca(),
random_color=mask_random_color,
bbox=bbox,
points=points,
pointlabel=point_label,
retinamask=retina,
target_height=original_h,
target_width=original_w,
)
else:
if isinstance(annotations[0], np.ndarray):
annotations = torch.from_numpy(annotations)
self.fast_show_mask_gpu(
annotations,
plt.gca(),
random_color=mask_random_color,
bbox=bbox,
points=points,
pointlabel=point_label,
retinamask=retina,
target_height=original_h,
target_width=original_w,
)
if isinstance(annotations, torch.Tensor):
annotations = annotations.cpu().numpy()
if withContours:
contour_all = []
temp = np.zeros((original_h, original_w, 1))
for i, mask in enumerate(annotations):
if type(mask) == dict:
mask = mask['segmentation']
annotation = mask.astype(np.uint8)
if not retina:
annotation = cv2.resize(
annotation,
(original_w, original_h),
interpolation=cv2.INTER_NEAREST,
)
contours, hierarchy = cv2.findContours(annotation, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
for contour in contours:
contour_all.append(contour)
cv2.drawContours(temp, contour_all, -1, (255, 255, 255), 2)
color = np.array([0 / 255, 0 / 255, 255 / 255, 0.8])
contour_mask = temp / 255 * color.reshape(1, 1, -1)
plt.imshow(contour_mask)
save_path = output
if not os.path.exists(save_path):
os.makedirs(save_path)
plt.axis('off')
fig = plt.gcf()
plt.draw()
try:
buf = fig.canvas.tostring_rgb()
except AttributeError:
fig.canvas.draw()
buf = fig.canvas.tostring_rgb()
cols, rows = fig.canvas.get_width_height()
img_array = np.frombuffer(buf, dtype=np.uint8).reshape(rows, cols, 3)
cv2.imwrite(os.path.join(save_path, result_name), cv2.cvtColor(img_array, cv2.COLOR_RGB2BGR))
# CPU post process
def fast_show_mask(
self,
annotation,
ax,
random_color=False,
bbox=None,
points=None,
pointlabel=None,
retinamask=True,
target_height=960,
target_width=960,
):
msak_sum = annotation.shape[0]
height = annotation.shape[1]
weight = annotation.shape[2]
# 将annotation 按照面积 排序
areas = np.sum(annotation, axis=(1, 2))
sorted_indices = np.argsort(areas)
annotation = annotation[sorted_indices]
index = (annotation != 0).argmax(axis=0)
if random_color:
color = np.random.random((msak_sum, 1, 1, 3))
else:
color = np.ones((msak_sum, 1, 1, 3)) * np.array([30 / 255, 144 / 255, 255 / 255])
transparency = np.ones((msak_sum, 1, 1, 1)) * 0.6
visual = np.concatenate([color, transparency], axis=-1)
mask_image = np.expand_dims(annotation, -1) * visual
show = np.zeros((height, weight, 4))
h_indices, w_indices = np.meshgrid(np.arange(height), np.arange(weight), indexing='ij')
indices = (index[h_indices, w_indices], h_indices, w_indices, slice(None))
# 使用向量化索引更新show的值
show[h_indices, w_indices, :] = mask_image[indices]
if bbox is not None:
x1, y1, x2, y2 = bbox
ax.add_patch(plt.Rectangle((x1, y1), x2 - x1, y2 - y1, fill=False, edgecolor='b', linewidth=1))
# draw point
if points is not None:
plt.scatter(
[point[0] for i, point in enumerate(points) if pointlabel[i] == 1],
[point[1] for i, point in enumerate(points) if pointlabel[i] == 1],
s=20,
c='y',
)
plt.scatter(
[point[0] for i, point in enumerate(points) if pointlabel[i] == 0],
[point[1] for i, point in enumerate(points) if pointlabel[i] == 0],
s=20,
c='m',
)
if not retinamask:
show = cv2.resize(show, (target_width, target_height), interpolation=cv2.INTER_NEAREST)
ax.imshow(show)
def fast_show_mask_gpu(
self,
annotation,
ax,
random_color=False,
bbox=None,
points=None,
pointlabel=None,
retinamask=True,
target_height=960,
target_width=960,
):
msak_sum = annotation.shape[0]
height = annotation.shape[1]
weight = annotation.shape[2]
areas = torch.sum(annotation, dim=(1, 2))
sorted_indices = torch.argsort(areas, descending=False)
annotation = annotation[sorted_indices]
# 找每个位置第一个非零值下标
index = (annotation != 0).to(torch.long).argmax(dim=0)
if random_color:
color = torch.rand((msak_sum, 1, 1, 3)).to(annotation.device)
else:
color = torch.ones((msak_sum, 1, 1, 3)).to(annotation.device) * torch.tensor([
30 / 255, 144 / 255, 255 / 255]).to(annotation.device)
transparency = torch.ones((msak_sum, 1, 1, 1)).to(annotation.device) * 0.6
visual = torch.cat([color, transparency], dim=-1)
mask_image = torch.unsqueeze(annotation, -1) * visual
# 按index取数index指每个位置选哪个batch的数把mask_image转成一个batch的形式
show = torch.zeros((height, weight, 4)).to(annotation.device)
h_indices, w_indices = torch.meshgrid(torch.arange(height), torch.arange(weight), indexing='ij')
indices = (index[h_indices, w_indices], h_indices, w_indices, slice(None))
# 使用向量化索引更新show的值
show[h_indices, w_indices, :] = mask_image[indices]
show_cpu = show.cpu().numpy()
if bbox is not None:
x1, y1, x2, y2 = bbox
ax.add_patch(plt.Rectangle((x1, y1), x2 - x1, y2 - y1, fill=False, edgecolor='b', linewidth=1))
# draw point
if points is not None:
plt.scatter(
[point[0] for i, point in enumerate(points) if pointlabel[i] == 1],
[point[1] for i, point in enumerate(points) if pointlabel[i] == 1],
s=20,
c='y',
)
plt.scatter(
[point[0] for i, point in enumerate(points) if pointlabel[i] == 0],
[point[1] for i, point in enumerate(points) if pointlabel[i] == 0],
s=20,
c='m',
)
if not retinamask:
show_cpu = cv2.resize(show_cpu, (target_width, target_height), interpolation=cv2.INTER_NEAREST)
ax.imshow(show_cpu)
# clip
@torch.no_grad()
def retrieve(self, model, preprocess, elements, search_text: str, device) -> int:
preprocessed_images = [preprocess(image).to(device) for image in elements]
tokenized_text = clip.tokenize([search_text]).to(device)
stacked_images = torch.stack(preprocessed_images)
image_features = model.encode_image(stacked_images)
text_features = model.encode_text(tokenized_text)
image_features /= image_features.norm(dim=-1, keepdim=True)
text_features /= text_features.norm(dim=-1, keepdim=True)
probs = 100.0 * image_features @ text_features.T
return probs[:, 0].softmax(dim=0)
def _crop_image(self, format_results):
image = Image.fromarray(cv2.cvtColor(self.ori_img, cv2.COLOR_BGR2RGB))
ori_w, ori_h = image.size
annotations = format_results
mask_h, mask_w = annotations[0]['segmentation'].shape
if ori_w != mask_w or ori_h != mask_h:
image = image.resize((mask_w, mask_h))
cropped_boxes = []
cropped_images = []
not_crop = []
filter_id = []
# annotations, _ = filter_masks(annotations)
# filter_id = list(_)
for _, mask in enumerate(annotations):
if np.sum(mask['segmentation']) <= 100:
filter_id.append(_)
continue
bbox = self._get_bbox_from_mask(mask['segmentation']) # mask 的 bbox
cropped_boxes.append(self._segment_image(image, bbox)) # 保存裁剪的图片
# cropped_boxes.append(segment_image(image,mask["segmentation"]))
cropped_images.append(bbox) # 保存裁剪的图片的bbox
return cropped_boxes, cropped_images, not_crop, filter_id, annotations
def box_prompt(self, bbox):
assert (bbox[2] != 0 and bbox[3] != 0)
masks = self.results[0].masks.data
target_height = self.ori_img.shape[0]
target_width = self.ori_img.shape[1]
h = masks.shape[1]
w = masks.shape[2]
if h != target_height or w != target_width:
bbox = [
int(bbox[0] * w / target_width),
int(bbox[1] * h / target_height),
int(bbox[2] * w / target_width),
int(bbox[3] * h / target_height), ]
bbox[0] = round(bbox[0]) if round(bbox[0]) > 0 else 0
bbox[1] = round(bbox[1]) if round(bbox[1]) > 0 else 0
bbox[2] = round(bbox[2]) if round(bbox[2]) < w else w
bbox[3] = round(bbox[3]) if round(bbox[3]) < h else h
# IoUs = torch.zeros(len(masks), dtype=torch.float32)
bbox_area = (bbox[3] - bbox[1]) * (bbox[2] - bbox[0])
masks_area = torch.sum(masks[:, bbox[1]:bbox[3], bbox[0]:bbox[2]], dim=(1, 2))
orig_masks_area = torch.sum(masks, dim=(1, 2))
union = bbox_area + orig_masks_area - masks_area
IoUs = masks_area / union
max_iou_index = torch.argmax(IoUs)
return np.array([masks[max_iou_index].cpu().numpy()])
def point_prompt(self, points, pointlabel): # numpy 处理
masks = self._format_results(self.results[0], 0)
target_height = self.ori_img.shape[0]
target_width = self.ori_img.shape[1]
h = masks[0]['segmentation'].shape[0]
w = masks[0]['segmentation'].shape[1]
if h != target_height or w != target_width:
points = [[int(point[0] * w / target_width), int(point[1] * h / target_height)] for point in points]
onemask = np.zeros((h, w))
for i, annotation in enumerate(masks):
if type(annotation) == dict:
mask = annotation['segmentation']
else:
mask = annotation
for i, point in enumerate(points):
if mask[point[1], point[0]] == 1 and pointlabel[i] == 1:
onemask += mask
if mask[point[1], point[0]] == 1 and pointlabel[i] == 0:
onemask -= mask
onemask = onemask >= 1
return np.array([onemask])
def text_prompt(self, text):
format_results = self._format_results(self.results[0], 0)
cropped_boxes, cropped_images, not_crop, filter_id, annotations = self._crop_image(format_results)
clip_model, preprocess = clip.load('ViT-B/32', device=self.device)
scores = self.retrieve(clip_model, preprocess, cropped_boxes, text, device=self.device)
max_idx = scores.argsort()
max_idx = max_idx[-1]
max_idx += sum(np.array(filter_id) <= int(max_idx))
return np.array([annotations[max_idx]['segmentation']])
def everything_prompt(self):
return self.results[0].masks.data

@ -0,0 +1,63 @@
import torch
def adjust_bboxes_to_image_border(boxes, image_shape, threshold=20):
'''Adjust bounding boxes to stick to image border if they are within a certain threshold.
Args:
boxes: (n, 4)
image_shape: (height, width)
threshold: pixel threshold
Returns:
adjusted_boxes: adjusted bounding boxes
'''
# Image dimensions
h, w = image_shape
# Adjust boxes
boxes[:, 0] = torch.where(boxes[:, 0] < threshold, 0, boxes[:, 0]) # x1
boxes[:, 1] = torch.where(boxes[:, 1] < threshold, 0, boxes[:, 1]) # y1
boxes[:, 2] = torch.where(boxes[:, 2] > w - threshold, w, boxes[:, 2]) # x2
boxes[:, 3] = torch.where(boxes[:, 3] > h - threshold, h, boxes[:, 3]) # y2
return boxes
def bbox_iou(box1, boxes, iou_thres=0.9, image_shape=(640, 640), raw_output=False):
'''Compute the Intersection-Over-Union of a bounding box with respect to an array of other bounding boxes.
Args:
box1: (4, )
boxes: (n, 4)
Returns:
high_iou_indices: Indices of boxes with IoU > thres
'''
boxes = adjust_bboxes_to_image_border(boxes, image_shape)
# obtain coordinates for intersections
x1 = torch.max(box1[0], boxes[:, 0])
y1 = torch.max(box1[1], boxes[:, 1])
x2 = torch.min(box1[2], boxes[:, 2])
y2 = torch.min(box1[3], boxes[:, 3])
# compute the area of intersection
intersection = (x2 - x1).clamp(0) * (y2 - y1).clamp(0)
# compute the area of both individual boxes
box1_area = (box1[2] - box1[0]) * (box1[3] - box1[1])
box2_area = (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1])
# compute the area of union
union = box1_area + box2_area - intersection
# compute the IoU
iou = intersection / union # Should be shape (n, )
if raw_output:
if iou.numel() == 0:
return 0
return iou
# get indices of boxes with IoU > thres
high_iou_indices = torch.nonzero(iou > iou_thres).flatten()
return high_iou_indices

@ -0,0 +1,244 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
from multiprocessing.pool import ThreadPool
from pathlib import Path
import numpy as np
import torch
import torch.nn.functional as F
from ultralytics.yolo.utils import LOGGER, NUM_THREADS, ops
from ultralytics.yolo.utils.checks import check_requirements
from ultralytics.yolo.utils.metrics import SegmentMetrics, box_iou, mask_iou
from ultralytics.yolo.utils.plotting import output_to_target, plot_images
from ultralytics.yolo.v8.detect import DetectionValidator
class FastSAMValidator(DetectionValidator):
def __init__(self, dataloader=None, save_dir=None, pbar=None, args=None, _callbacks=None):
"""Initialize SegmentationValidator and set task to 'segment', metrics to SegmentMetrics."""
super().__init__(dataloader, save_dir, pbar, args, _callbacks)
self.args.task = 'segment'
self.metrics = SegmentMetrics(save_dir=self.save_dir, on_plot=self.on_plot)
def preprocess(self, batch):
"""Preprocesses batch by converting masks to float and sending to device."""
batch = super().preprocess(batch)
batch['masks'] = batch['masks'].to(self.device).float()
return batch
def init_metrics(self, model):
"""Initialize metrics and select mask processing function based on save_json flag."""
super().init_metrics(model)
self.plot_masks = []
if self.args.save_json:
check_requirements('pycocotools>=2.0.6')
self.process = ops.process_mask_upsample # more accurate
else:
self.process = ops.process_mask # faster
def get_desc(self):
"""Return a formatted description of evaluation metrics."""
return ('%22s' + '%11s' * 10) % ('Class', 'Images', 'Instances', 'Box(P', 'R', 'mAP50', 'mAP50-95)', 'Mask(P',
'R', 'mAP50', 'mAP50-95)')
def postprocess(self, preds):
"""Postprocesses YOLO predictions and returns output detections with proto."""
p = ops.non_max_suppression(preds[0],
self.args.conf,
self.args.iou,
labels=self.lb,
multi_label=True,
agnostic=self.args.single_cls,
max_det=self.args.max_det,
nc=self.nc)
proto = preds[1][-1] if len(preds[1]) == 3 else preds[1] # second output is len 3 if pt, but only 1 if exported
return p, proto
def update_metrics(self, preds, batch):
"""Metrics."""
for si, (pred, proto) in enumerate(zip(preds[0], preds[1])):
idx = batch['batch_idx'] == si
cls = batch['cls'][idx]
bbox = batch['bboxes'][idx]
nl, npr = cls.shape[0], pred.shape[0] # number of labels, predictions
shape = batch['ori_shape'][si]
correct_masks = torch.zeros(npr, self.niou, dtype=torch.bool, device=self.device) # init
correct_bboxes = torch.zeros(npr, self.niou, dtype=torch.bool, device=self.device) # init
self.seen += 1
if npr == 0:
if nl:
self.stats.append((correct_bboxes, correct_masks, *torch.zeros(
(2, 0), device=self.device), cls.squeeze(-1)))
if self.args.plots:
self.confusion_matrix.process_batch(detections=None, labels=cls.squeeze(-1))
continue
# Masks
midx = [si] if self.args.overlap_mask else idx
gt_masks = batch['masks'][midx]
pred_masks = self.process(proto, pred[:, 6:], pred[:, :4], shape=batch['img'][si].shape[1:])
# Predictions
if self.args.single_cls:
pred[:, 5] = 0
predn = pred.clone()
ops.scale_boxes(batch['img'][si].shape[1:], predn[:, :4], shape,
ratio_pad=batch['ratio_pad'][si]) # native-space pred
# Evaluate
if nl:
height, width = batch['img'].shape[2:]
tbox = ops.xywh2xyxy(bbox) * torch.tensor(
(width, height, width, height), device=self.device) # target boxes
ops.scale_boxes(batch['img'][si].shape[1:], tbox, shape,
ratio_pad=batch['ratio_pad'][si]) # native-space labels
labelsn = torch.cat((cls, tbox), 1) # native-space labels
correct_bboxes = self._process_batch(predn, labelsn)
# TODO: maybe remove these `self.` arguments as they already are member variable
correct_masks = self._process_batch(predn,
labelsn,
pred_masks,
gt_masks,
overlap=self.args.overlap_mask,
masks=True)
if self.args.plots:
self.confusion_matrix.process_batch(predn, labelsn)
# Append correct_masks, correct_boxes, pconf, pcls, tcls
self.stats.append((correct_bboxes, correct_masks, pred[:, 4], pred[:, 5], cls.squeeze(-1)))
pred_masks = torch.as_tensor(pred_masks, dtype=torch.uint8)
if self.args.plots and self.batch_i < 3:
self.plot_masks.append(pred_masks[:15].cpu()) # filter top 15 to plot
# Save
if self.args.save_json:
pred_masks = ops.scale_image(pred_masks.permute(1, 2, 0).contiguous().cpu().numpy(),
shape,
ratio_pad=batch['ratio_pad'][si])
self.pred_to_json(predn, batch['im_file'][si], pred_masks)
# if self.args.save_txt:
# save_one_txt(predn, save_conf, shape, file=save_dir / 'labels' / f'{path.stem}.txt')
def finalize_metrics(self, *args, **kwargs):
"""Sets speed and confusion matrix for evaluation metrics."""
self.metrics.speed = self.speed
self.metrics.confusion_matrix = self.confusion_matrix
def _process_batch(self, detections, labels, pred_masks=None, gt_masks=None, overlap=False, masks=False):
"""
Return correct prediction matrix
Arguments:
detections (array[N, 6]), x1, y1, x2, y2, conf, class
labels (array[M, 5]), class, x1, y1, x2, y2
Returns:
correct (array[N, 10]), for 10 IoU levels
"""
if masks:
if overlap:
nl = len(labels)
index = torch.arange(nl, device=gt_masks.device).view(nl, 1, 1) + 1
gt_masks = gt_masks.repeat(nl, 1, 1) # shape(1,640,640) -> (n,640,640)
gt_masks = torch.where(gt_masks == index, 1.0, 0.0)
if gt_masks.shape[1:] != pred_masks.shape[1:]:
gt_masks = F.interpolate(gt_masks[None], pred_masks.shape[1:], mode='bilinear', align_corners=False)[0]
gt_masks = gt_masks.gt_(0.5)
iou = mask_iou(gt_masks.view(gt_masks.shape[0], -1), pred_masks.view(pred_masks.shape[0], -1))
else: # boxes
iou = box_iou(labels[:, 1:], detections[:, :4])
correct = np.zeros((detections.shape[0], self.iouv.shape[0])).astype(bool)
correct_class = labels[:, 0:1] == detections[:, 5]
for i in range(len(self.iouv)):
x = torch.where((iou >= self.iouv[i]) & correct_class) # IoU > threshold and classes match
if x[0].shape[0]:
matches = torch.cat((torch.stack(x, 1), iou[x[0], x[1]][:, None]),
1).cpu().numpy() # [label, detect, iou]
if x[0].shape[0] > 1:
matches = matches[matches[:, 2].argsort()[::-1]]
matches = matches[np.unique(matches[:, 1], return_index=True)[1]]
# matches = matches[matches[:, 2].argsort()[::-1]]
matches = matches[np.unique(matches[:, 0], return_index=True)[1]]
correct[matches[:, 1].astype(int), i] = True
return torch.tensor(correct, dtype=torch.bool, device=detections.device)
def plot_val_samples(self, batch, ni):
"""Plots validation samples with bounding box labels."""
plot_images(batch['img'],
batch['batch_idx'],
batch['cls'].squeeze(-1),
batch['bboxes'],
batch['masks'],
paths=batch['im_file'],
fname=self.save_dir / f'val_batch{ni}_labels.jpg',
names=self.names,
on_plot=self.on_plot)
def plot_predictions(self, batch, preds, ni):
"""Plots batch predictions with masks and bounding boxes."""
plot_images(
batch['img'],
*output_to_target(preds[0], max_det=15), # not set to self.args.max_det due to slow plotting speed
torch.cat(self.plot_masks, dim=0) if len(self.plot_masks) else self.plot_masks,
paths=batch['im_file'],
fname=self.save_dir / f'val_batch{ni}_pred.jpg',
names=self.names,
on_plot=self.on_plot) # pred
self.plot_masks.clear()
def pred_to_json(self, predn, filename, pred_masks):
"""Save one JSON result."""
# Example result = {"image_id": 42, "category_id": 18, "bbox": [258.15, 41.29, 348.26, 243.78], "score": 0.236}
from pycocotools.mask import encode # noqa
def single_encode(x):
"""Encode predicted masks as RLE and append results to jdict."""
rle = encode(np.asarray(x[:, :, None], order='F', dtype='uint8'))[0]
rle['counts'] = rle['counts'].decode('utf-8')
return rle
stem = Path(filename).stem
image_id = int(stem) if stem.isnumeric() else stem
box = ops.xyxy2xywh(predn[:, :4]) # xywh
box[:, :2] -= box[:, 2:] / 2 # xy center to top-left corner
pred_masks = np.transpose(pred_masks, (2, 0, 1))
with ThreadPool(NUM_THREADS) as pool:
rles = pool.map(single_encode, pred_masks)
for i, (p, b) in enumerate(zip(predn.tolist(), box.tolist())):
self.jdict.append({
'image_id': image_id,
'category_id': self.class_map[int(p[5])],
'bbox': [round(x, 3) for x in b],
'score': round(p[4], 5),
'segmentation': rles[i]})
def eval_json(self, stats):
"""Return COCO-style object detection evaluation metrics."""
if self.args.save_json and self.is_coco and len(self.jdict):
anno_json = self.data['path'] / 'annotations/instances_val2017.json' # annotations
pred_json = self.save_dir / 'predictions.json' # predictions
LOGGER.info(f'\nEvaluating pycocotools mAP using {pred_json} and {anno_json}...')
try: # https://github.com/cocodataset/cocoapi/blob/master/PythonAPI/pycocoEvalDemo.ipynb
check_requirements('pycocotools>=2.0.6')
from pycocotools.coco import COCO # noqa
from pycocotools.cocoeval import COCOeval # noqa
for x in anno_json, pred_json:
assert x.is_file(), f'{x} file not found'
anno = COCO(str(anno_json)) # init annotations api
pred = anno.loadRes(str(pred_json)) # init predictions api (must pass string, not Path)
for i, eval in enumerate([COCOeval(anno, pred, 'bbox'), COCOeval(anno, pred, 'segm')]):
if self.is_coco:
eval.params.imgIds = [int(Path(x).stem) for x in self.dataloader.dataset.im_files] # im to eval
eval.evaluate()
eval.accumulate()
eval.summarize()
idx = i * 4 + 2
stats[self.metrics.keys[idx + 1]], stats[
self.metrics.keys[idx]] = eval.stats[:2] # update mAP50-95 and mAP50
except Exception as e:
LOGGER.warning(f'pycocotools unable to run: {e}')
return stats
Loading…
Cancel
Save