ultralytics 8.0.127
add FastSAM model (#3390)
Co-authored-by: dingwenchao <12962189468@163.com>
Co-authored-by: 丁文超 <dingwenchao@dingwenchaodeMacBook-Pro.local>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
Co-authored-by: Ayush Chaurasia <ayush.chaurarsia@gmail.com>
8    ultralytics/yolo/fastsam/__init__.py    Normal file
@@ -0,0 +1,8 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license

from .model import FastSAM
from .predict import FastSAMPredictor
from .prompt import FastSAMPrompt
from .val import FastSAMValidator

__all__ = 'FastSAMPredictor', 'FastSAM', 'FastSAMPrompt', 'FastSAMValidator'
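The package re-exports its four public classes; a minimal import smoke test (a sketch, assuming ultralytics 8.0.127 is installed):

    from ultralytics.yolo.fastsam import FastSAM, FastSAMPredictor, FastSAMPrompt, FastSAMValidator

    print(FastSAM, FastSAMPredictor, FastSAMPrompt, FastSAMValidator)  # all four names resolve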
104    ultralytics/yolo/fastsam/model.py    Normal file
@@ -0,0 +1,104 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
"""
FastSAM model interface.

Usage - Predict:
    from ultralytics import FastSAM

    model = FastSAM('last.pt')
    results = model.predict('ultralytics/assets/bus.jpg')
"""

from ultralytics.yolo.cfg import get_cfg
from ultralytics.yolo.engine.exporter import Exporter
from ultralytics.yolo.engine.model import YOLO
from ultralytics.yolo.utils import DEFAULT_CFG, LOGGER, ROOT, is_git_dir
from ultralytics.yolo.utils.checks import check_imgsz

from ...yolo.utils.torch_utils import model_info, smart_inference_mode
from .predict import FastSAMPredictor
from .val import FastSAMValidator


class FastSAM(YOLO):

    @smart_inference_mode()
    def predict(self, source=None, stream=False, **kwargs):
        """
        Perform prediction using the FastSAM model.

        Args:
            source (str | int | PIL | np.ndarray): The source of the image to make predictions on.
                Accepts all source types accepted by the YOLO model.
            stream (bool): Whether to stream the predictions or not. Defaults to False.
            **kwargs: Additional keyword arguments passed to the predictor.
                Check the 'configuration' section in the documentation for all available options.

        Returns:
            (List[ultralytics.yolo.engine.results.Results]): The prediction results.
        """
        if source is None:
            source = ROOT / 'assets' if is_git_dir() else 'https://ultralytics.com/images/bus.jpg'
            LOGGER.warning(f"WARNING ⚠️ 'source' is missing. Using 'source={source}'.")
        overrides = self.overrides.copy()
        overrides['conf'] = 0.25
        overrides.update(kwargs)  # prefer kwargs
        overrides['mode'] = kwargs.get('mode', 'predict')
        assert overrides['mode'] in ['track', 'predict']
        overrides['save'] = kwargs.get('save', False)  # do not save by default if called in Python
        self.predictor = FastSAMPredictor(overrides=overrides)
        self.predictor.setup_model(model=self.model, verbose=False)

        return self.predictor(source, stream=stream)

    def train(self, **kwargs):
        """Raises NotImplementedError, as FastSAM models do not support training."""
        raise NotImplementedError("FastSAM models don't support training")

    def val(self, **kwargs):
        """Run validation on the given dataset."""
        overrides = dict(task='segment', mode='val')
        overrides.update(kwargs)  # prefer kwargs
        args = get_cfg(cfg=DEFAULT_CFG, overrides=overrides)
        args.imgsz = check_imgsz(args.imgsz, max_dim=1)
        validator = FastSAMValidator(args=args)
        validator(model=self.model)
        self.metrics = validator.metrics
        return validator.metrics

    @smart_inference_mode()
    def export(self, **kwargs):
        """
        Export the model to a deployable format.

        Args:
            **kwargs: Any other args accepted by the Exporter. To see all args check the 'configuration' section in the docs.
        """
        overrides = dict(task='detect')
        overrides.update(kwargs)
        overrides['mode'] = 'export'
        args = get_cfg(cfg=DEFAULT_CFG, overrides=overrides)
        args.task = self.task
        if args.imgsz == DEFAULT_CFG.imgsz:
            args.imgsz = self.model.args['imgsz']  # use trained imgsz unless custom value is passed
        if args.batch == DEFAULT_CFG.batch:
            args.batch = 1  # default to 1 if not modified
        return Exporter(overrides=args)(model=self.model)

    def info(self, detailed=False, verbose=True):
        """
        Logs model info.

        Args:
            detailed (bool): Show detailed information about the model.
            verbose (bool): Controls verbosity.
        """
        return model_info(self.model, detailed=detailed, verbose=verbose, imgsz=640)

    def __call__(self, source=None, stream=False, **kwargs):
        """Calls the 'predict' function with given arguments to perform segmentation."""
        return self.predict(source, stream, **kwargs)

    def __getattr__(self, attr):
        """Raises an error if the object has no requested attribute."""
        name = self.__class__.__name__
        raise AttributeError(f"'{name}' object has no attribute '{attr}'. See valid attributes below.\n{self.__doc__}")
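A usage sketch for the interface above; the checkpoint name 'FastSAM.pt' is an assumption (any FastSAM weights path works), while the entry point `from ultralytics import FastSAM` follows the module docstring:

    from ultralytics import FastSAM

    model = FastSAM('FastSAM.pt')  # hypothetical checkpoint path
    results = model.predict('ultralytics/assets/bus.jpg', conf=0.4, iou=0.9)  # kwargs override the 0.25 default conf
    model.info(detailed=False, verbose=True)  # layer/parameter summary at imgsz=640
    # model.export(format='onnx')  # hands off to Exporter with mode='export'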
51    ultralytics/yolo/fastsam/predict.py    Normal file
@@ -0,0 +1,51 @@
import torch

from ultralytics.yolo.engine.results import Results
from ultralytics.yolo.fastsam.utils import bbox_iou
from ultralytics.yolo.utils import DEFAULT_CFG, ops
from ultralytics.yolo.v8.detect.predict import DetectionPredictor


class FastSAMPredictor(DetectionPredictor):

    def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
        super().__init__(cfg, overrides, _callbacks)
        self.args.task = 'segment'

    def postprocess(self, preds, img, orig_imgs):
        """Postprocess FastSAM predictions into Results objects. TODO: filter by classes."""
        p = ops.non_max_suppression(preds[0],
                                    self.args.conf,
                                    self.args.iou,
                                    agnostic=self.args.agnostic_nms,
                                    max_det=self.args.max_det,
                                    nc=len(self.model.names),
                                    classes=self.args.classes)
        # Build a box covering the full image; any prediction overlapping it at IoU > 0.9 is
        # replaced by this exact full-image box while keeping its score and mask coefficients
        full_box = torch.zeros_like(p[0][0])
        full_box[2], full_box[3], full_box[4], full_box[6:] = img.shape[3], img.shape[2], 1.0, 1.0
        full_box = full_box.view(1, -1)
        critical_iou_index = bbox_iou(full_box[0][:4], p[0][:, :4], iou_thres=0.9, image_shape=img.shape[2:])
        if critical_iou_index.numel() != 0:
            full_box[0][4] = p[0][critical_iou_index][:, 4]
            full_box[0][6:] = p[0][critical_iou_index][:, 6:]
            p[0][critical_iou_index] = full_box
        results = []
        proto = preds[1][-1] if len(preds[1]) == 3 else preds[1]  # second output is len 3 if pt, but only 1 if exported
        for i, pred in enumerate(p):
            orig_img = orig_imgs[i] if isinstance(orig_imgs, list) else orig_imgs
            path = self.batch[0]
            img_path = path[i] if isinstance(path, list) else path
            if not len(pred):  # save empty boxes
                results.append(Results(orig_img=orig_img, path=img_path, names=self.model.names, boxes=pred[:, :6]))
                continue
            if self.args.retina_masks:
                if not isinstance(orig_imgs, torch.Tensor):
                    pred[:, :4] = ops.scale_boxes(img.shape[2:], pred[:, :4], orig_img.shape)
                masks = ops.process_mask_native(proto[i], pred[:, 6:], pred[:, :4], orig_img.shape[:2])  # HWC
            else:
                masks = ops.process_mask(proto[i], pred[:, 6:], pred[:, :4], img.shape[2:], upsample=True)  # HWC
                if not isinstance(orig_imgs, torch.Tensor):
                    pred[:, :4] = ops.scale_boxes(img.shape[2:], pred[:, :4], orig_img.shape)
            results.append(
                Results(orig_img=orig_img, path=img_path, names=self.model.names, boxes=pred[:, :6], masks=masks))
        return results
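The full_box logic above compensates for FastSAM producing a whole-image mask among its outputs. A standalone construction sketch (normally FastSAM.predict builds the predictor; the override values here are illustrative):

    from ultralytics.yolo.fastsam.predict import FastSAMPredictor

    overrides = dict(mode='predict', conf=0.25, save=False)  # illustrative overrides
    predictor = FastSAMPredictor(overrides=overrides)  # task is forced to 'segment' in __init__
    # predictor.setup_model(model=...)  # requires a loaded FastSAM model, as in FastSAM.predict
    # results = predictor('ultralytics/assets/bus.jpg')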
406    ultralytics/yolo/fastsam/prompt.py    Normal file
@@ -0,0 +1,406 @@
import os

import cv2
import matplotlib.pyplot as plt
import numpy as np
import torch
from PIL import Image

try:
    import clip  # required for the text prompt
except (ImportError, AssertionError, AttributeError):
    from ultralytics.yolo.utils.checks import check_requirements

    check_requirements('git+https://github.com/openai/CLIP.git')  # install CLIP from source
    import clip


class FastSAMPrompt:
    """Interactive prompting (everything, box, point and text prompts) over FastSAM results."""

    def __init__(self, img_path, results, device='cuda') -> None:
        self.device = device
        self.results = results
        self.img_path = img_path
        self.ori_img = cv2.imread(img_path)

    def _segment_image(self, image, bbox):
        """Paste the bbox region of `image` onto a white canvas of the same size."""
        image_array = np.array(image)
        segmented_image_array = np.zeros_like(image_array)
        x1, y1, x2, y2 = bbox
        segmented_image_array[y1:y2, x1:x2] = image_array[y1:y2, x1:x2]
        segmented_image = Image.fromarray(segmented_image_array)
        black_image = Image.new('RGB', image.size, (255, 255, 255))
        transparency_mask = np.zeros((image_array.shape[0], image_array.shape[1]), dtype=np.uint8)
        transparency_mask[y1:y2, x1:x2] = 255
        transparency_mask_image = Image.fromarray(transparency_mask, mode='L')
        black_image.paste(segmented_image, mask=transparency_mask_image)
        return black_image

    def _format_results(self, result, filter=0):
        """Convert a Results object into a list of annotation dicts, skipping masks smaller than `filter` pixels."""
        annotations = []
        n = len(result.masks.data)
        for i in range(n):
            annotation = {}
            mask = result.masks.data[i] == 1.0

            if torch.sum(mask) < filter:
                continue
            annotation['id'] = i
            annotation['segmentation'] = mask.cpu().numpy()
            annotation['bbox'] = result.boxes.data[i]
            annotation['score'] = result.boxes.conf[i]
            annotation['area'] = annotation['segmentation'].sum()
            annotations.append(annotation)
        return annotations

    @staticmethod
    def filter_masks(annotations):  # filter overlapping masks
        annotations.sort(key=lambda x: x['area'], reverse=True)
        to_remove = set()
        for i in range(len(annotations)):
            a = annotations[i]
            for j in range(i + 1, len(annotations)):
                b = annotations[j]
                if i != j and j not in to_remove:
                    # Remove a smaller mask that is mostly (>80%) covered by a larger one
                    if b['area'] < a['area']:
                        if (a['segmentation'] & b['segmentation']).sum() / b['segmentation'].sum() > 0.8:
                            to_remove.add(j)

        return [a for i, a in enumerate(annotations) if i not in to_remove], to_remove

    def _get_bbox_from_mask(self, mask):
        """Return the xyxy bounding box covering all contours of a binary mask."""
        mask = mask.astype(np.uint8)
        contours, hierarchy = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
        x1, y1, w, h = cv2.boundingRect(contours[0])
        x2, y2 = x1 + w, y1 + h
        if len(contours) > 1:
            for b in contours:
                x_t, y_t, w_t, h_t = cv2.boundingRect(b)
                # Merge multiple bboxes into one
                x1 = min(x1, x_t)
                y1 = min(y1, y_t)
                x2 = max(x2, x_t + w_t)
                y2 = max(y2, y_t + h_t)
            h = y2 - y1
            w = x2 - x1
        return [x1, y1, x2, y2]

    def plot(self,
             annotations,
             output,
             bbox=None,
             points=None,
             point_label=None,
             mask_random_color=True,
             better_quality=True,
             retina=False,
             withContours=True):
        if isinstance(annotations[0], dict):
            annotations = [annotation['segmentation'] for annotation in annotations]
        result_name = os.path.basename(self.img_path)
        image = self.ori_img
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        original_h = image.shape[0]
        original_w = image.shape[1]
        # For macOS only
        # plt.switch_backend('TkAgg')
        plt.figure(figsize=(original_w / 100, original_h / 100))
        # Add subplot with no margin
        plt.subplots_adjust(top=1, bottom=0, right=1, left=0, hspace=0, wspace=0)
        plt.margins(0, 0)
        plt.gca().xaxis.set_major_locator(plt.NullLocator())
        plt.gca().yaxis.set_major_locator(plt.NullLocator())

        plt.imshow(image)
        if better_quality:
            if isinstance(annotations[0], torch.Tensor):
                annotations = np.array(annotations.cpu())
            for i, mask in enumerate(annotations):
                mask = cv2.morphologyEx(mask.astype(np.uint8), cv2.MORPH_CLOSE, np.ones((3, 3), np.uint8))
                annotations[i] = cv2.morphologyEx(mask.astype(np.uint8), cv2.MORPH_OPEN, np.ones((8, 8), np.uint8))
        if self.device == 'cpu':
            annotations = np.array(annotations)
            self.fast_show_mask(
                annotations,
                plt.gca(),
                random_color=mask_random_color,
                bbox=bbox,
                points=points,
                pointlabel=point_label,
                retinamask=retina,
                target_height=original_h,
                target_width=original_w,
            )
        else:
            if isinstance(annotations[0], np.ndarray):
                annotations = torch.from_numpy(annotations)
            self.fast_show_mask_gpu(
                annotations,
                plt.gca(),
                random_color=mask_random_color,
                bbox=bbox,
                points=points,
                pointlabel=point_label,
                retinamask=retina,
                target_height=original_h,
                target_width=original_w,
            )
        if isinstance(annotations, torch.Tensor):
            annotations = annotations.cpu().numpy()
        if withContours:
            contour_all = []
            temp = np.zeros((original_h, original_w, 1))
            for i, mask in enumerate(annotations):
                if type(mask) == dict:
                    mask = mask['segmentation']
                annotation = mask.astype(np.uint8)
                if not retina:
                    annotation = cv2.resize(
                        annotation,
                        (original_w, original_h),
                        interpolation=cv2.INTER_NEAREST,
                    )
                contours, hierarchy = cv2.findContours(annotation, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
                for contour in contours:
                    contour_all.append(contour)
            cv2.drawContours(temp, contour_all, -1, (255, 255, 255), 2)
            color = np.array([0 / 255, 0 / 255, 255 / 255, 0.8])
            contour_mask = temp / 255 * color.reshape(1, 1, -1)
            plt.imshow(contour_mask)

        save_path = output
        if not os.path.exists(save_path):
            os.makedirs(save_path)
        plt.axis('off')
        fig = plt.gcf()
        plt.draw()

        try:
            buf = fig.canvas.tostring_rgb()
        except AttributeError:
            fig.canvas.draw()
            buf = fig.canvas.tostring_rgb()
        cols, rows = fig.canvas.get_width_height()
        img_array = np.frombuffer(buf, dtype=np.uint8).reshape(rows, cols, 3)
        cv2.imwrite(os.path.join(save_path, result_name), cv2.cvtColor(img_array, cv2.COLOR_RGB2BGR))

    # CPU post-processing
    def fast_show_mask(
        self,
        annotation,
        ax,
        random_color=False,
        bbox=None,
        points=None,
        pointlabel=None,
        retinamask=True,
        target_height=960,
        target_width=960,
    ):
        n_masks = annotation.shape[0]
        height = annotation.shape[1]
        width = annotation.shape[2]
        # Sort annotations by area
        areas = np.sum(annotation, axis=(1, 2))
        sorted_indices = np.argsort(areas)
        annotation = annotation[sorted_indices]
        # Find the index of the first non-zero mask at each pixel
        index = (annotation != 0).argmax(axis=0)
        if random_color:
            color = np.random.random((n_masks, 1, 1, 3))
        else:
            color = np.ones((n_masks, 1, 1, 3)) * np.array([30 / 255, 144 / 255, 255 / 255])
        transparency = np.ones((n_masks, 1, 1, 1)) * 0.6
        visual = np.concatenate([color, transparency], axis=-1)
        mask_image = np.expand_dims(annotation, -1) * visual

        show = np.zeros((height, width, 4))
        h_indices, w_indices = np.meshgrid(np.arange(height), np.arange(width), indexing='ij')
        indices = (index[h_indices, w_indices], h_indices, w_indices, slice(None))
        # Update the values of show with vectorized indexing
        show[h_indices, w_indices, :] = mask_image[indices]
        if bbox is not None:
            x1, y1, x2, y2 = bbox
            ax.add_patch(plt.Rectangle((x1, y1), x2 - x1, y2 - y1, fill=False, edgecolor='b', linewidth=1))
        # Draw points
        if points is not None:
            plt.scatter(
                [point[0] for i, point in enumerate(points) if pointlabel[i] == 1],
                [point[1] for i, point in enumerate(points) if pointlabel[i] == 1],
                s=20,
                c='y',
            )
            plt.scatter(
                [point[0] for i, point in enumerate(points) if pointlabel[i] == 0],
                [point[1] for i, point in enumerate(points) if pointlabel[i] == 0],
                s=20,
                c='m',
            )

        if not retinamask:
            show = cv2.resize(show, (target_width, target_height), interpolation=cv2.INTER_NEAREST)
        ax.imshow(show)

    def fast_show_mask_gpu(
        self,
        annotation,
        ax,
        random_color=False,
        bbox=None,
        points=None,
        pointlabel=None,
        retinamask=True,
        target_height=960,
        target_width=960,
    ):
        n_masks = annotation.shape[0]
        height = annotation.shape[1]
        width = annotation.shape[2]
        # Sort annotations by area
        areas = torch.sum(annotation, dim=(1, 2))
        sorted_indices = torch.argsort(areas, descending=False)
        annotation = annotation[sorted_indices]
        # Find the index of the first non-zero mask at each pixel
        index = (annotation != 0).to(torch.long).argmax(dim=0)
        if random_color:
            color = torch.rand((n_masks, 1, 1, 3)).to(annotation.device)
        else:
            color = torch.ones((n_masks, 1, 1, 3)).to(annotation.device) * torch.tensor([
                30 / 255, 144 / 255, 255 / 255]).to(annotation.device)
        transparency = torch.ones((n_masks, 1, 1, 1)).to(annotation.device) * 0.6
        visual = torch.cat([color, transparency], dim=-1)
        mask_image = torch.unsqueeze(annotation, -1) * visual
        # Gather by index: for each pixel, pick the mask chosen by `index`, collapsing mask_image to one image
        show = torch.zeros((height, width, 4)).to(annotation.device)
        h_indices, w_indices = torch.meshgrid(torch.arange(height), torch.arange(width), indexing='ij')
        indices = (index[h_indices, w_indices], h_indices, w_indices, slice(None))
        # Update the values of show with vectorized indexing
        show[h_indices, w_indices, :] = mask_image[indices]
        show_cpu = show.cpu().numpy()
        if bbox is not None:
            x1, y1, x2, y2 = bbox
            ax.add_patch(plt.Rectangle((x1, y1), x2 - x1, y2 - y1, fill=False, edgecolor='b', linewidth=1))
        # Draw points
        if points is not None:
            plt.scatter(
                [point[0] for i, point in enumerate(points) if pointlabel[i] == 1],
                [point[1] for i, point in enumerate(points) if pointlabel[i] == 1],
                s=20,
                c='y',
            )
            plt.scatter(
                [point[0] for i, point in enumerate(points) if pointlabel[i] == 0],
                [point[1] for i, point in enumerate(points) if pointlabel[i] == 0],
                s=20,
                c='m',
            )
        if not retinamask:
            show_cpu = cv2.resize(show_cpu, (target_width, target_height), interpolation=cv2.INTER_NEAREST)
        ax.imshow(show_cpu)

    # CLIP text-image retrieval
    @torch.no_grad()
    def retrieve(self, model, preprocess, elements, search_text: str, device) -> torch.Tensor:
        preprocessed_images = [preprocess(image).to(device) for image in elements]
        tokenized_text = clip.tokenize([search_text]).to(device)
        stacked_images = torch.stack(preprocessed_images)
        image_features = model.encode_image(stacked_images)
        text_features = model.encode_text(tokenized_text)
        image_features /= image_features.norm(dim=-1, keepdim=True)
        text_features /= text_features.norm(dim=-1, keepdim=True)
        probs = 100.0 * image_features @ text_features.T
        return probs[:, 0].softmax(dim=0)

    def _crop_image(self, format_results):

        image = Image.fromarray(cv2.cvtColor(self.ori_img, cv2.COLOR_BGR2RGB))
        ori_w, ori_h = image.size
        annotations = format_results
        mask_h, mask_w = annotations[0]['segmentation'].shape
        if ori_w != mask_w or ori_h != mask_h:
            image = image.resize((mask_w, mask_h))
        cropped_boxes = []
        cropped_images = []
        not_crop = []
        filter_id = []
        # annotations, _ = filter_masks(annotations)
        # filter_id = list(_)
        for _, mask in enumerate(annotations):
            if np.sum(mask['segmentation']) <= 100:
                filter_id.append(_)
                continue
            bbox = self._get_bbox_from_mask(mask['segmentation'])  # bbox of the mask
            cropped_boxes.append(self._segment_image(image, bbox))  # save the cropped image
            cropped_images.append(bbox)  # save the bbox of the cropped image

        return cropped_boxes, cropped_images, not_crop, filter_id, annotations

    def box_prompt(self, bbox):

        assert bbox[2] != 0 and bbox[3] != 0
        masks = self.results[0].masks.data
        target_height = self.ori_img.shape[0]
        target_width = self.ori_img.shape[1]
        h = masks.shape[1]
        w = masks.shape[2]
        if h != target_height or w != target_width:
            bbox = [
                int(bbox[0] * w / target_width),
                int(bbox[1] * h / target_height),
                int(bbox[2] * w / target_width),
                int(bbox[3] * h / target_height)]
        bbox[0] = max(round(bbox[0]), 0)
        bbox[1] = max(round(bbox[1]), 0)
        bbox[2] = min(round(bbox[2]), w)
        bbox[3] = min(round(bbox[3]), h)

        bbox_area = (bbox[3] - bbox[1]) * (bbox[2] - bbox[0])

        masks_area = torch.sum(masks[:, bbox[1]:bbox[3], bbox[0]:bbox[2]], dim=(1, 2))
        orig_masks_area = torch.sum(masks, dim=(1, 2))

        union = bbox_area + orig_masks_area - masks_area
        IoUs = masks_area / union
        max_iou_index = torch.argmax(IoUs)

        return np.array([masks[max_iou_index].cpu().numpy()])

    def point_prompt(self, points, pointlabel):  # numpy processing

        masks = self._format_results(self.results[0], 0)
        target_height = self.ori_img.shape[0]
        target_width = self.ori_img.shape[1]
        h = masks[0]['segmentation'].shape[0]
        w = masks[0]['segmentation'].shape[1]
        if h != target_height or w != target_width:
            points = [[int(point[0] * w / target_width), int(point[1] * h / target_height)] for point in points]
        onemask = np.zeros((h, w))
        for annotation in masks:
            mask = annotation['segmentation'] if type(annotation) == dict else annotation
            for i, point in enumerate(points):
                if mask[point[1], point[0]] == 1 and pointlabel[i] == 1:
                    onemask += mask
                if mask[point[1], point[0]] == 1 and pointlabel[i] == 0:
                    onemask -= mask
        onemask = onemask >= 1
        return np.array([onemask])

    def text_prompt(self, text):
        format_results = self._format_results(self.results[0], 0)
        cropped_boxes, cropped_images, not_crop, filter_id, annotations = self._crop_image(format_results)
        clip_model, preprocess = clip.load('ViT-B/32', device=self.device)
        scores = self.retrieve(clip_model, preprocess, cropped_boxes, text, device=self.device)
        max_idx = scores.argsort()
        max_idx = max_idx[-1]
        max_idx += sum(np.array(filter_id) <= int(max_idx))
        return np.array([annotations[max_idx]['segmentation']])

    def everything_prompt(self):
        return self.results[0].masks.data
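An end-to-end sketch of the prompt workflow defined above (everything, box, point and text prompts); the checkpoint name is an assumption, and device='cpu' selects the NumPy renderer fast_show_mask:

    from ultralytics import FastSAM
    from ultralytics.yolo.fastsam import FastSAMPrompt

    model = FastSAM('FastSAM.pt')  # hypothetical checkpoint path
    results = model('ultralytics/assets/bus.jpg', retina_masks=True, imgsz=1024, conf=0.4, iou=0.9)
    prompt_process = FastSAMPrompt('ultralytics/assets/bus.jpg', results, device='cpu')

    ann = prompt_process.everything_prompt()  # all masks
    # ann = prompt_process.box_prompt(bbox=[200, 200, 300, 300])  # xyxy box prompt
    # ann = prompt_process.point_prompt(points=[[400, 400]], pointlabel=[1])  # foreground point
    # ann = prompt_process.text_prompt(text='a photo of a dog')  # CLIP-scored text prompt
    prompt_process.plot(annotations=ann, output='./output/')  # writes the overlay into ./output/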
63    ultralytics/yolo/fastsam/utils.py    Normal file
@@ -0,0 +1,63 @@
import torch


def adjust_bboxes_to_image_border(boxes, image_shape, threshold=20):
    """Adjust bounding boxes to stick to the image border if they are within a certain threshold.

    Args:
        boxes: (n, 4)
        image_shape: (height, width)
        threshold: pixel threshold

    Returns:
        adjusted_boxes: adjusted bounding boxes
    """

    # Image dimensions
    h, w = image_shape

    # Adjust boxes
    boxes[:, 0] = torch.where(boxes[:, 0] < threshold, 0, boxes[:, 0])  # x1
    boxes[:, 1] = torch.where(boxes[:, 1] < threshold, 0, boxes[:, 1])  # y1
    boxes[:, 2] = torch.where(boxes[:, 2] > w - threshold, w, boxes[:, 2])  # x2
    boxes[:, 3] = torch.where(boxes[:, 3] > h - threshold, h, boxes[:, 3])  # y2

    return boxes


def bbox_iou(box1, boxes, iou_thres=0.9, image_shape=(640, 640), raw_output=False):
    """Compute the Intersection-over-Union of a bounding box with respect to an array of other bounding boxes.

    Args:
        box1: (4, )
        boxes: (n, 4)
        iou_thres: IoU threshold for the returned indices
        image_shape: (height, width), used to snap boxes to the image border
        raw_output: if True, return the raw IoU values instead of indices

    Returns:
        high_iou_indices: indices of boxes with IoU > iou_thres
    """
    boxes = adjust_bboxes_to_image_border(boxes, image_shape)
    # Obtain coordinates for intersections
    x1 = torch.max(box1[0], boxes[:, 0])
    y1 = torch.max(box1[1], boxes[:, 1])
    x2 = torch.min(box1[2], boxes[:, 2])
    y2 = torch.min(box1[3], boxes[:, 3])

    # Compute the area of intersection
    intersection = (x2 - x1).clamp(0) * (y2 - y1).clamp(0)

    # Compute the area of both individual boxes
    box1_area = (box1[2] - box1[0]) * (box1[3] - box1[1])
    box2_area = (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1])

    # Compute the area of union
    union = box1_area + box2_area - intersection

    # Compute the IoU
    iou = intersection / union  # should be shape (n, )
    if raw_output:
        return 0 if iou.numel() == 0 else iou

    # Get indices of boxes with IoU > iou_thres
    high_iou_indices = torch.nonzero(iou > iou_thres).flatten()

    return high_iou_indices
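A hand-checked toy example of bbox_iou and its border snapping: box1 nearly fills a 640x640 image, so the full-image candidate in boxes clears the 0.9 threshold while the small one does not:

    import torch
    from ultralytics.yolo.fastsam.utils import bbox_iou

    box1 = torch.tensor([5., 5., 635., 635.])
    boxes = torch.tensor([[0., 0., 640., 640.], [100., 100., 200., 200.]])
    print(bbox_iou(box1, boxes, iou_thres=0.9, image_shape=(640, 640)))  # tensor([0]), the near-full-image box
    print(bbox_iou(box1, boxes, image_shape=(640, 640), raw_output=True))  # raw IoU values, shape (2,)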
244    ultralytics/yolo/fastsam/val.py    Normal file
@@ -0,0 +1,244 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license

from multiprocessing.pool import ThreadPool
from pathlib import Path

import numpy as np
import torch
import torch.nn.functional as F

from ultralytics.yolo.utils import LOGGER, NUM_THREADS, ops
from ultralytics.yolo.utils.checks import check_requirements
from ultralytics.yolo.utils.metrics import SegmentMetrics, box_iou, mask_iou
from ultralytics.yolo.utils.plotting import output_to_target, plot_images
from ultralytics.yolo.v8.detect import DetectionValidator


class FastSAMValidator(DetectionValidator):
    """Segmentation validation for FastSAM, built on DetectionValidator."""

    def __init__(self, dataloader=None, save_dir=None, pbar=None, args=None, _callbacks=None):
        """Initialize FastSAMValidator: set task to 'segment' and metrics to SegmentMetrics."""
        super().__init__(dataloader, save_dir, pbar, args, _callbacks)
        self.args.task = 'segment'
        self.metrics = SegmentMetrics(save_dir=self.save_dir, on_plot=self.on_plot)

    def preprocess(self, batch):
        """Preprocesses batch by converting masks to float and sending them to the device."""
        batch = super().preprocess(batch)
        batch['masks'] = batch['masks'].to(self.device).float()
        return batch

    def init_metrics(self, model):
        """Initialize metrics and select the mask processing function based on the save_json flag."""
        super().init_metrics(model)
        self.plot_masks = []
        if self.args.save_json:
            check_requirements('pycocotools>=2.0.6')
            self.process = ops.process_mask_upsample  # more accurate
        else:
            self.process = ops.process_mask  # faster

    def get_desc(self):
        """Return a formatted description of evaluation metrics."""
        return ('%22s' + '%11s' * 10) % ('Class', 'Images', 'Instances', 'Box(P', 'R', 'mAP50', 'mAP50-95)', 'Mask(P',
                                         'R', 'mAP50', 'mAP50-95)')

    def postprocess(self, preds):
        """Postprocesses YOLO predictions and returns output detections with proto."""
        p = ops.non_max_suppression(preds[0],
                                    self.args.conf,
                                    self.args.iou,
                                    labels=self.lb,
                                    multi_label=True,
                                    agnostic=self.args.single_cls,
                                    max_det=self.args.max_det,
                                    nc=self.nc)
        proto = preds[1][-1] if len(preds[1]) == 3 else preds[1]  # second output is len 3 if pt, but only 1 if exported
        return p, proto

    def update_metrics(self, preds, batch):
        """Update metrics for a batch of predictions."""
        for si, (pred, proto) in enumerate(zip(preds[0], preds[1])):
            idx = batch['batch_idx'] == si
            cls = batch['cls'][idx]
            bbox = batch['bboxes'][idx]
            nl, npr = cls.shape[0], pred.shape[0]  # number of labels, predictions
            shape = batch['ori_shape'][si]
            correct_masks = torch.zeros(npr, self.niou, dtype=torch.bool, device=self.device)  # init
            correct_bboxes = torch.zeros(npr, self.niou, dtype=torch.bool, device=self.device)  # init
            self.seen += 1

            if npr == 0:
                if nl:
                    self.stats.append((correct_bboxes, correct_masks, *torch.zeros(
                        (2, 0), device=self.device), cls.squeeze(-1)))
                    if self.args.plots:
                        self.confusion_matrix.process_batch(detections=None, labels=cls.squeeze(-1))
                continue

            # Masks
            midx = [si] if self.args.overlap_mask else idx
            gt_masks = batch['masks'][midx]
            pred_masks = self.process(proto, pred[:, 6:], pred[:, :4], shape=batch['img'][si].shape[1:])

            # Predictions
            if self.args.single_cls:
                pred[:, 5] = 0
            predn = pred.clone()
            ops.scale_boxes(batch['img'][si].shape[1:], predn[:, :4], shape,
                            ratio_pad=batch['ratio_pad'][si])  # native-space pred

            # Evaluate
            if nl:
                height, width = batch['img'].shape[2:]
                tbox = ops.xywh2xyxy(bbox) * torch.tensor(
                    (width, height, width, height), device=self.device)  # target boxes
                ops.scale_boxes(batch['img'][si].shape[1:], tbox, shape,
                                ratio_pad=batch['ratio_pad'][si])  # native-space labels
                labelsn = torch.cat((cls, tbox), 1)  # native-space labels
                correct_bboxes = self._process_batch(predn, labelsn)
                # TODO: maybe remove these `self.` arguments as they already are member variables
                correct_masks = self._process_batch(predn,
                                                    labelsn,
                                                    pred_masks,
                                                    gt_masks,
                                                    overlap=self.args.overlap_mask,
                                                    masks=True)
                if self.args.plots:
                    self.confusion_matrix.process_batch(predn, labelsn)

            # Append correct_masks, correct_boxes, pconf, pcls, tcls
            self.stats.append((correct_bboxes, correct_masks, pred[:, 4], pred[:, 5], cls.squeeze(-1)))

            pred_masks = torch.as_tensor(pred_masks, dtype=torch.uint8)
            if self.args.plots and self.batch_i < 3:
                self.plot_masks.append(pred_masks[:15].cpu())  # filter top 15 to plot

            # Save
            if self.args.save_json:
                pred_masks = ops.scale_image(pred_masks.permute(1, 2, 0).contiguous().cpu().numpy(),
                                             shape,
                                             ratio_pad=batch['ratio_pad'][si])
                self.pred_to_json(predn, batch['im_file'][si], pred_masks)
            # if self.args.save_txt:
            #     save_one_txt(predn, save_conf, shape, file=save_dir / 'labels' / f'{path.stem}.txt')

    def finalize_metrics(self, *args, **kwargs):
        """Sets speed and confusion matrix for evaluation metrics."""
        self.metrics.speed = self.speed
        self.metrics.confusion_matrix = self.confusion_matrix

    def _process_batch(self, detections, labels, pred_masks=None, gt_masks=None, overlap=False, masks=False):
        """
        Return a correct prediction matrix.

        Arguments:
            detections (array[N, 6]): x1, y1, x2, y2, conf, class
            labels (array[M, 5]): class, x1, y1, x2, y2

        Returns:
            correct (array[N, 10]): for 10 IoU levels
        """
        if masks:
            if overlap:
                nl = len(labels)
                index = torch.arange(nl, device=gt_masks.device).view(nl, 1, 1) + 1
                gt_masks = gt_masks.repeat(nl, 1, 1)  # shape(1,640,640) -> (n,640,640)
                gt_masks = torch.where(gt_masks == index, 1.0, 0.0)
            if gt_masks.shape[1:] != pred_masks.shape[1:]:
                gt_masks = F.interpolate(gt_masks[None], pred_masks.shape[1:], mode='bilinear', align_corners=False)[0]
                gt_masks = gt_masks.gt_(0.5)
            iou = mask_iou(gt_masks.view(gt_masks.shape[0], -1), pred_masks.view(pred_masks.shape[0], -1))
        else:  # boxes
            iou = box_iou(labels[:, 1:], detections[:, :4])

        correct = np.zeros((detections.shape[0], self.iouv.shape[0])).astype(bool)
        correct_class = labels[:, 0:1] == detections[:, 5]
        for i in range(len(self.iouv)):
            x = torch.where((iou >= self.iouv[i]) & correct_class)  # IoU > threshold and classes match
            if x[0].shape[0]:
                matches = torch.cat((torch.stack(x, 1), iou[x[0], x[1]][:, None]),
                                    1).cpu().numpy()  # [label, detect, iou]
                if x[0].shape[0] > 1:
                    matches = matches[matches[:, 2].argsort()[::-1]]
                    matches = matches[np.unique(matches[:, 1], return_index=True)[1]]
                    # matches = matches[matches[:, 2].argsort()[::-1]]
                    matches = matches[np.unique(matches[:, 0], return_index=True)[1]]
                correct[matches[:, 1].astype(int), i] = True
        return torch.tensor(correct, dtype=torch.bool, device=detections.device)

    def plot_val_samples(self, batch, ni):
        """Plots validation samples with bounding-box labels."""
        plot_images(batch['img'],
                    batch['batch_idx'],
                    batch['cls'].squeeze(-1),
                    batch['bboxes'],
                    batch['masks'],
                    paths=batch['im_file'],
                    fname=self.save_dir / f'val_batch{ni}_labels.jpg',
                    names=self.names,
                    on_plot=self.on_plot)

    def plot_predictions(self, batch, preds, ni):
        """Plots batch predictions with masks and bounding boxes."""
        plot_images(
            batch['img'],
            *output_to_target(preds[0], max_det=15),  # not set to self.args.max_det due to slow plotting speed
            torch.cat(self.plot_masks, dim=0) if len(self.plot_masks) else self.plot_masks,
            paths=batch['im_file'],
            fname=self.save_dir / f'val_batch{ni}_pred.jpg',
            names=self.names,
            on_plot=self.on_plot)  # pred
        self.plot_masks.clear()

    def pred_to_json(self, predn, filename, pred_masks):
        """Save one JSON result."""
        # Example result = {"image_id": 42, "category_id": 18, "bbox": [258.15, 41.29, 348.26, 243.78], "score": 0.236}
        from pycocotools.mask import encode  # noqa

        def single_encode(x):
            """Encode one predicted mask as COCO RLE."""
            rle = encode(np.asarray(x[:, :, None], order='F', dtype='uint8'))[0]
            rle['counts'] = rle['counts'].decode('utf-8')
            return rle

        stem = Path(filename).stem
        image_id = int(stem) if stem.isnumeric() else stem
        box = ops.xyxy2xywh(predn[:, :4])  # xywh
        box[:, :2] -= box[:, 2:] / 2  # xy center to top-left corner
        pred_masks = np.transpose(pred_masks, (2, 0, 1))
        with ThreadPool(NUM_THREADS) as pool:
            rles = pool.map(single_encode, pred_masks)
        for i, (p, b) in enumerate(zip(predn.tolist(), box.tolist())):
            self.jdict.append({
                'image_id': image_id,
                'category_id': self.class_map[int(p[5])],
                'bbox': [round(x, 3) for x in b],
                'score': round(p[4], 5),
                'segmentation': rles[i]})

    def eval_json(self, stats):
        """Return COCO-style object detection evaluation metrics."""
        if self.args.save_json and self.is_coco and len(self.jdict):
            anno_json = self.data['path'] / 'annotations/instances_val2017.json'  # annotations
            pred_json = self.save_dir / 'predictions.json'  # predictions
            LOGGER.info(f'\nEvaluating pycocotools mAP using {pred_json} and {anno_json}...')
            try:  # https://github.com/cocodataset/cocoapi/blob/master/PythonAPI/pycocoEvalDemo.ipynb
                check_requirements('pycocotools>=2.0.6')
                from pycocotools.coco import COCO  # noqa
                from pycocotools.cocoeval import COCOeval  # noqa

                for x in anno_json, pred_json:
                    assert x.is_file(), f'{x} file not found'
                anno = COCO(str(anno_json))  # init annotations api
                pred = anno.loadRes(str(pred_json))  # init predictions api (must pass string, not Path)
                for i, eval in enumerate([COCOeval(anno, pred, 'bbox'), COCOeval(anno, pred, 'segm')]):
                    if self.is_coco:
                        eval.params.imgIds = [int(Path(x).stem) for x in self.dataloader.dataset.im_files]  # im to eval
                    eval.evaluate()
                    eval.accumulate()
                    eval.summarize()
                    idx = i * 4 + 2
                    stats[self.metrics.keys[idx + 1]], stats[
                        self.metrics.keys[idx]] = eval.stats[:2]  # update mAP50-95 and mAP50
            except Exception as e:
                LOGGER.warning(f'pycocotools unable to run: {e}')
        return stats
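A sketch of driving this validator through FastSAM.val (the dataset YAML name is an assumption; any YOLO segmentation dataset works):

    from ultralytics import FastSAM

    model = FastSAM('FastSAM.pt')  # hypothetical checkpoint path
    metrics = model.val(data='coco8-seg.yaml', imgsz=640)  # runs FastSAMValidator with task='segment'
    print(metrics.box.map, metrics.seg.map)  # box and mask mAP50-95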