|
|
|
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
|
|
|
|
|
|
|
import numpy as np
|
|
|
|
import torch
|
|
|
|
import torch.nn.functional as F
|
|
|
|
import torchvision
|
|
|
|
|
|
|
|
from ultralytics.data.augment import LetterBox
|
|
|
|
from ultralytics.engine.predictor import BasePredictor
|
|
|
|
from ultralytics.engine.results import Results
|
|
|
|
from ultralytics.utils import DEFAULT_CFG, ops
|
|
|
|
from ultralytics.utils.torch_utils import select_device
|
|
|
|
|
|
|
|
from .amg import (batch_iterator, batched_mask_to_box, build_all_layer_point_grids, calculate_stability_score,
|
|
|
|
generate_crop_boxes, is_box_near_crop_edge, remove_small_regions, uncrop_boxes_xyxy, uncrop_masks)
|
|
|
|
from .build import build_sam
|
|
|
|
|
|
|
|
|
|
|
|
class Predictor(BasePredictor):
    """Segment Anything (SAM) predictor supporting prompted inference and segment-everything mode.

    Prompts (bboxes, points, labels, masks) can be passed to ``inference`` directly or stored beforehand with
    ``set_prompts``; stored prompts are consumed by the next ``inference`` call. An image can be preprocessed and
    encoded once with ``set_image``; the cached image and features are reused until ``reset_image`` is called.
    When no bbox/point/mask prompt is available, ``inference`` falls back to ``generate`` (segment everything).
    """

    def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
        """Initialize the predictor with SAM-specific task, mode and image size.

        Args:
            cfg: Configuration passed to ``BasePredictor``.
            overrides (dict, optional): Configuration overrides. The caller's dict is copied, never modified;
                ``task``, ``mode`` and ``imgsz`` are always forced to SAM's values.
            _callbacks: Callbacks passed to ``BasePredictor``.
        """
        # Copy so the caller's dict is not mutated; the SAM-specific values below override user values.
        overrides = dict(overrides or {})
        overrides.update(task='segment', mode='predict', imgsz=1024)
        super().__init__(cfg, overrides, _callbacks)
        # SAM needs retina_masks=True, or the results would be a mess.
        self.args.retina_masks = True
        # State for set_image: cached preprocessed image and its image-encoder features.
        self.im = None
        self.features = None
        # State for set_prompts: prompts consumed (popped) by the next inference() call.
        self.prompts = {}
        # True while running segment-everything mode (set in generate, reset in postprocess).
        self.segment_all = False

    def preprocess(self, im):
        """Prepare the input image for inference.

        If an image was cached by ``set_image``, that cached tensor is returned and ``im`` is ignored.

        Args:
            im (torch.Tensor | List(np.ndarray)): BCHW for tensor, [(HWC) x B] for list.

        Returns:
            (torch.Tensor): Image on ``self.device``, fp16 if ``self.model.fp16`` else fp32. List inputs are
                letterboxed, channel-reversed (BGR to RGB), transposed to BCHW and normalized with
                ``self.mean`` / ``self.std``; tensor inputs are only moved and cast, not normalized.
        """
        if self.im is not None:
            return self.im
        not_tensor = not isinstance(im, torch.Tensor)
        if not_tensor:
            im = np.stack(self.pre_transform(im))
            im = im[..., ::-1].transpose((0, 3, 1, 2))  # BGR to RGB, BHWC to BCHW, (n, 3, h, w)
            im = np.ascontiguousarray(im)  # contiguous
            im = torch.from_numpy(im)

        img = im.to(self.device)
        img = img.half() if self.model.fp16 else img.float()  # uint8 to fp16/32
        if not_tensor:
            img = (img - self.mean) / self.std
        return img

    def pre_transform(self, im):
        """Letterbox input images to ``self.args.imgsz`` (no centering, padding on the bottom/right side).

        Args:
            im (List(np.ndarray)): [(h, w, 3) x N] list of images; only N == 1 is currently supported.

        Returns:
            (List(np.ndarray)): A list of transformed images.
        """
        assert len(im) == 1, 'SAM model has not supported batch inference yet!'
        return [LetterBox(self.args.imgsz, auto=False, center=False)(image=x) for x in im]

    def inference(self, im, bboxes=None, points=None, labels=None, masks=None, multimask_output=False, *args, **kwargs):
        """Predict masks for the given prompts, or segment the whole image when no prompt is given.

        Prompts stored with ``set_prompts`` take precedence over the arguments and are consumed here.

        Args:
            im (torch.Tensor): The preprocessed image, (N, C, H, W).
            bboxes (np.ndarray | List, None): (N, 4), in XYXY format.
            points (np.ndarray | List, None): (N, 2), Each point is in (X,Y) in pixels.
            labels (np.ndarray | List, None): (N, ), labels for the point prompts.
                1 indicates a foreground point and 0 indicates a background point.
            masks (np.ndarray, None): A low resolution mask input to the model, typically
                coming from a previous prediction iteration. Has form (N, H, W), where
                for SAM, H=W=256.
            multimask_output (bool): If true, the model will return three masks per prompt.
                For ambiguous input prompts (such as a single click), this will often
                produce better masks than a single prediction. For non-ambiguous prompts,
                such as multiple input prompts, multimask_output=False can give better results.
            *args, **kwargs: Forwarded to ``generate`` when no bbox/point/mask prompt is available.

        Returns:
            (tuple): ``(masks, scores)`` as returned by ``prompt_inference`` when any bbox/point/mask prompt is
                given, otherwise ``(masks, scores, bboxes)`` as returned by ``generate``.
        """
        # Get prompts from self.prompts first (labels included, so they stay paired with stored points).
        bboxes = self.prompts.pop('bboxes', bboxes)
        points = self.prompts.pop('points', points)
        labels = self.prompts.pop('labels', labels)
        masks = self.prompts.pop('masks', masks)
        if all(i is None for i in [bboxes, points, masks]):
            return self.generate(im, *args, **kwargs)
        return self.prompt_inference(im, bboxes, points, labels, masks, multimask_output)

    def prompt_inference(self, im, bboxes=None, points=None, labels=None, masks=None, multimask_output=False):
        """Predict masks for the given prompts on ``im``.

        Features cached by ``set_image`` are used when present; otherwise ``im`` is encoded.

        Args:
            im (torch.Tensor): The preprocessed image, (N, C, H, W).
            bboxes (np.ndarray | List, None): (N, 4), in XYXY format, in original-image pixels.
            points (np.ndarray | List, None): (N, 2), Each point is in (X,Y) in original-image pixels.
            labels (np.ndarray | List, None): (N, ), labels for the point prompts.
                1 indicates a foreground point and 0 indicates a background point.
                Defaults to all foreground when points are given without labels.
            masks (np.ndarray, None): A low resolution mask input to the model, typically
                coming from a previous prediction iteration. Has form (N, H, W), where
                for SAM, H=W=256.
            multimask_output (bool): If true, the model will return three masks per prompt.

        Returns:
            (torch.Tensor): Mask logits from the mask decoder, flattened to (N*d, H, W), where d is 3 if
                ``multimask_output`` else 1.
            (torch.Tensor): The decoder's predicted quality score for each mask, (N*d, ).
        """
        features = self.model.image_encoder(im) if self.features is None else self.features
        return self._decode_prompts(features, im.shape[2:], bboxes, points, labels, masks, multimask_output)

    def _decode_prompts(self, features, dst_shape, bboxes=None, points=None, labels=None, masks=None,
                        multimask_output=False):
        """Encode prompts and run the mask decoder on precomputed image ``features``.

        Args:
            features (torch.Tensor): Image-encoder output for the image the prompts refer to.
            dst_shape (tuple): (H, W) of the preprocessed image the features were computed from; used to rescale
                pixel prompts from the original image size unless in segment-everything mode.
            bboxes, points, labels, masks, multimask_output: As in ``prompt_inference``.

        Returns:
            (tuple): ``(masks, scores)`` as documented in ``prompt_inference``.
        """
        if self.segment_all:
            # generate() already supplies points in the coordinates of the preprocessed image.
            r = 1.0
        else:
            src_shape = self.batch[1][0].shape[:2]
            r = min(dst_shape[0] / src_shape[0], dst_shape[1] / src_shape[1])
        # Transform input prompts. Scaling is out of place: torch.as_tensor returns the caller's tensor itself when
        # it is already float32 on self.device, so an in-place multiply would modify the user's input.
        if points is not None:
            points = torch.as_tensor(points, dtype=torch.float32, device=self.device)
            points = points[None] if points.ndim == 1 else points
            # Assuming labels are all positive if users don't pass labels.
            if labels is None:
                labels = np.ones(points.shape[0])
            labels = torch.as_tensor(labels, dtype=torch.int32, device=self.device)
            points = points * r
            # (N, 2) --> (N, 1, 2), (N, ) --> (N, 1)
            points, labels = points[:, None, :], labels[:, None]
        if bboxes is not None:
            bboxes = torch.as_tensor(bboxes, dtype=torch.float32, device=self.device)
            bboxes = bboxes[None] if bboxes.ndim == 1 else bboxes
            bboxes = bboxes * r
        if masks is not None:
            masks = torch.as_tensor(masks, dtype=torch.float32, device=self.device)
            masks = masks[:, None, :, :]

        points = (points, labels) if points is not None else None
        # Embed prompts
        sparse_embeddings, dense_embeddings = self.model.prompt_encoder(
            points=points,
            boxes=bboxes,
            masks=masks,
        )

        # Predict masks
        pred_masks, pred_scores = self.model.mask_decoder(
            image_embeddings=features,
            image_pe=self.model.prompt_encoder.get_dense_pe(),
            sparse_prompt_embeddings=sparse_embeddings,
            dense_prompt_embeddings=dense_embeddings,
            multimask_output=multimask_output,
        )

        # (N, d, H, W) --> (N*d, H, W), (N, d) --> (N*d, )
        # `d` could be 1 or 3 depends on `multimask_output`.
        return pred_masks.flatten(0, 1), pred_scores.flatten(0, 1)

    def generate(self,
                 im,
                 crop_n_layers=0,
                 crop_overlap_ratio=512 / 1500,
                 crop_downscale_factor=1,
                 point_grids=None,
                 points_stride=32,
                 points_batch_size=64,
                 conf_thres=0.88,
                 stability_score_thresh=0.95,
                 stability_score_offset=0.95,
                 crop_nms_thresh=0.7):
        """Segment the whole image by prompting SAM with a grid of points (segment-everything mode).

        Args:
            im (torch.Tensor): The preprocessed image, (N, C, H, W).
            crop_n_layers (int): If >0, mask prediction will be run again on
                crops of the image. Sets the number of layers to run, where each
                layer has 2**i_layer number of image crops.
            crop_overlap_ratio (float): Sets the degree to which crops overlap.
                In the first crop layer, crops will overlap by this fraction of
                the image length. Later layers with more crops scale down this overlap.
            crop_downscale_factor (int): The number of points-per-side
                sampled in layer n is scaled down by crop_downscale_factor**n.
            point_grids (list(np.ndarray), None): A list over explicit grids
                of points used for sampling, normalized to [0,1]. The nth grid in the
                list is used in the nth crop layer. When given, points_stride is ignored.
            points_stride (int): The number of points to be sampled along one side of
                the image; the total number of points is points_stride**2. Used only
                when point_grids is None.
            points_batch_size (int): Sets the number of points run simultaneously
                by the model. Higher numbers may be faster but use more GPU memory.
            conf_thres (float): A filtering threshold in [0,1], using the
                model's predicted mask quality.
            stability_score_thresh (float): A filtering threshold in [0,1], using
                the stability of the mask under changes to the cutoff used to binarize
                the model's mask predictions.
            stability_score_offset (float): The amount to shift the cutoff when
                calculating the stability score.
            crop_nms_thresh (float): The box IoU cutoff used by non-maximal
                suppression to filter duplicate masks between different crops.
                NMS within a single crop uses ``self.args.iou``.

        Returns:
            (torch.Tensor): Boolean masks, (M, H, W), in the coordinates of ``im``.
            (torch.Tensor): Predicted quality score per mask, (M, ).
            (torch.Tensor): XYXY box per mask, (M, 4), in the coordinates of ``im``.
        """
        self.segment_all = True
        ih, iw = im.shape[2:]
        crop_regions, layer_idxs = generate_crop_boxes((ih, iw), crop_n_layers, crop_overlap_ratio)
        if point_grids is None:
            point_grids = build_all_layer_point_grids(
                points_stride,
                crop_n_layers,
                crop_downscale_factor,
            )
        pred_masks, pred_scores, pred_bboxes, region_areas = [], [], [], []
        for crop_region, layer_idx in zip(crop_regions, layer_idxs):
            x1, y1, x2, y2 = crop_region
            w, h = x2 - x1, y2 - y1
            area = torch.tensor(w * h, device=im.device)
            points_scale = np.array([[w, h]])  # w, h
            # Crop image and interpolate to input size
            crop_im = F.interpolate(im[..., y1:y2, x1:x2], (ih, iw), mode='bilinear', align_corners=False)
            # Encode each crop once instead of once per batch of points. Features cached by set_image describe
            # the full image, so they are only reused for the crop that covers the whole image.
            if self.features is not None and (x1, y1, x2, y2) == (0, 0, iw, ih):
                crop_features = self.features
            else:
                crop_features = self.model.image_encoder(crop_im)
            # (num_points, 2)
            points_for_image = point_grids[layer_idx] * points_scale
            crop_masks, crop_scores, crop_bboxes = [], [], []
            for (points, ) in batch_iterator(points_batch_size, points_for_image):
                pred_mask, pred_score = self._decode_prompts(crop_features,
                                                             crop_im.shape[2:],
                                                             points=points,
                                                             multimask_output=True)
                # Interpolate predicted masks to the crop size
                pred_mask = F.interpolate(pred_mask[None], (h, w), mode='bilinear', align_corners=False)[0]
                idx = pred_score > conf_thres
                pred_mask, pred_score = pred_mask[idx], pred_score[idx]

                stability_score = calculate_stability_score(pred_mask, self.model.mask_threshold,
                                                            stability_score_offset)
                idx = stability_score > stability_score_thresh
                pred_mask, pred_score = pred_mask[idx], pred_score[idx]
                # Bool type is much more memory-efficient.
                pred_mask = pred_mask > self.model.mask_threshold
                # (N, 4)
                pred_bbox = batched_mask_to_box(pred_mask).float()
                keep_mask = ~is_box_near_crop_edge(pred_bbox, crop_region, [0, 0, iw, ih])
                if not torch.all(keep_mask):
                    pred_bbox = pred_bbox[keep_mask]
                    pred_mask = pred_mask[keep_mask]
                    pred_score = pred_score[keep_mask]

                crop_masks.append(pred_mask)
                crop_bboxes.append(pred_bbox)
                crop_scores.append(pred_score)

            # Do nms within this crop
            crop_masks = torch.cat(crop_masks)
            crop_bboxes = torch.cat(crop_bboxes)
            crop_scores = torch.cat(crop_scores)
            keep = torchvision.ops.nms(crop_bboxes, crop_scores, self.args.iou)  # NMS
            crop_bboxes = uncrop_boxes_xyxy(crop_bboxes[keep], crop_region)
            crop_masks = uncrop_masks(crop_masks[keep], crop_region, ih, iw)
            crop_scores = crop_scores[keep]

            pred_masks.append(crop_masks)
            pred_bboxes.append(crop_bboxes)
            pred_scores.append(crop_scores)
            region_areas.append(area.expand(len(crop_masks)))

        pred_masks = torch.cat(pred_masks)
        pred_bboxes = torch.cat(pred_bboxes)
        pred_scores = torch.cat(pred_scores)
        region_areas = torch.cat(region_areas)

        # Remove duplicate masks between crops; 1 / area makes NMS prefer masks from smaller crops.
        if len(crop_regions) > 1:
            scores = 1 / region_areas
            keep = torchvision.ops.nms(pred_bboxes, scores, crop_nms_thresh)
            pred_masks = pred_masks[keep]
            pred_bboxes = pred_bboxes[keep]
            pred_scores = pred_scores[keep]

        return pred_masks, pred_scores, pred_bboxes

    def setup_model(self, model, verbose=True):
        """Set up the SAM model on the configured device together with its normalization constants.

        Args:
            model (torch.nn.Module | None): SAM model; built from ``self.args.model`` via ``build_sam`` when None.
            verbose (bool): Unused.
        """
        device = select_device(self.args.device)
        if model is None:
            model = build_sam(self.args.model)
        model.eval()
        self.model = model.to(device)
        self.device = device
        self.mean = torch.tensor([123.675, 116.28, 103.53]).view(-1, 1, 1).to(device)
        self.std = torch.tensor([58.395, 57.12, 57.375]).view(-1, 1, 1).to(device)
        # TODO: Temporary settings for compatibility
        self.model.pt = False
        self.model.triton = False
        self.model.stride = 32
        self.model.fp16 = False
        self.done_warmup = True

    def postprocess(self, preds, img, orig_imgs):
        """Convert raw predictions into ``Results`` with masks scaled to the original image.

        Args:
            preds (tuple): ``(masks, scores)`` from prompted inference, or ``(masks, scores, bboxes)`` from
                ``generate`` in segment-everything mode.
            img (torch.Tensor): The preprocessed input image, (N, C, H, W).
            orig_imgs (List(np.ndarray) | np.ndarray): Original image(s) before preprocessing.

        Returns:
            (List[Results]): Results with boolean masks; in segment-everything mode boxes are rows of
                [x1, y1, x2, y2, score, cls] scaled to the original image, otherwise boxes are None.
        """
        # (N, H, W), (N, )
        pred_masks, pred_scores = preds[:2]
        pred_bboxes = preds[2] if self.segment_all else None
        names = dict(enumerate(str(i) for i in range(len(pred_masks))))
        results = []
        for i, masks in enumerate([pred_masks]):
            orig_img = orig_imgs[i] if isinstance(orig_imgs, list) else orig_imgs
            if pred_bboxes is not None:
                pred_bboxes = ops.scale_boxes(img.shape[2:], pred_bboxes.float(), orig_img.shape, padding=False)
                cls = torch.arange(len(pred_masks), dtype=torch.int32, device=pred_masks.device)
                pred_bboxes = torch.cat([pred_bboxes, pred_scores[:, None], cls[:, None]], dim=-1)

            masks = ops.scale_masks(masks[None].float(), orig_img.shape[:2], padding=False)[0]
            masks = masks > self.model.mask_threshold  # to bool
            path = self.batch[0]
            img_path = path[i] if isinstance(path, list) else path
            results.append(Results(orig_img=orig_img, path=img_path, names=names, masks=masks, boxes=pred_bboxes))
        # Reset segment-all mode.
        self.segment_all = False
        return results

    def setup_source(self, source):
        """Set up the data source, keeping the current one when ``source`` is None."""
        if source is not None:
            super().setup_source(source)

    def set_image(self, image):
        """Preprocess and encode an image in advance so later prompts reuse its features.

        Any previously cached image is replaced.

        Args:
            image (str | np.ndarray): image file path or np.ndarray image by cv2.
        """
        if self.model is None:
            model = build_sam(self.args.model)
            self.setup_model(model)
        self.setup_source(image)
        assert len(self.dataset) == 1, '`set_image` only supports setting one image!'
        # preprocess() returns self.im whenever it is set, so the old cache must be cleared first.
        self.reset_image()
        for batch in self.dataset:
            im = self.preprocess(batch[1])
            # Features are only used for inference; avoid recording an autograd graph for them.
            with torch.no_grad():
                self.features = self.model.image_encoder(im)
            self.im = im
            break

    def set_prompts(self, prompts):
        """Store prompts (keys 'bboxes', 'points', 'labels', 'masks') for the next ``inference`` call.

        The dict is copied, so popping consumed prompts does not modify the caller's object.
        """
        self.prompts = dict(prompts) if prompts else {}

    def reset_image(self):
        """Clear the image and features cached by ``set_image``."""
        self.im = None
        self.features = None

    @staticmethod
    def remove_small_regions(masks, min_area=0, nms_thresh=0.7):
        """
        Removes small disconnected regions and holes in masks, then reruns
        box NMS to remove any new duplicates. Requires open-cv as a dependency.

        Args:
            masks (torch.Tensor): Masks, (N, H, W).
            min_area (int): Minimum area threshold.
            nms_thresh (float): NMS threshold.

        Returns:
            (torch.Tensor): Masks kept by NMS; masks altered by region removal are replaced by their cleaned
                version. The input tensor is not modified.
        """
        if len(masks) == 0:
            return masks

        # Filter small disconnected regions and holes
        new_masks = []
        scores = []
        for mask in masks:
            mask = mask.cpu().numpy()
            mask, changed = remove_small_regions(mask, min_area, mode='holes')
            unchanged = not changed
            mask, changed = remove_small_regions(mask, min_area, mode='islands')
            unchanged = unchanged and not changed

            new_masks.append(torch.as_tensor(mask).unsqueeze(0))
            # Give score=0 to changed masks and score=1 to unchanged masks
            # so NMS will prefer ones that didn't need postprocessing
            scores.append(float(unchanged))

        # Recalculate boxes and remove any new duplicates
        new_masks = torch.cat(new_masks, dim=0)
        boxes = batched_mask_to_box(new_masks)
        keep = torchvision.ops.nms(
            boxes.float(),
            torch.as_tensor(scores),
            nms_thresh,
        )

        # Advanced indexing returns a copy, so writing cleaned masks into it leaves the caller's tensor intact.
        kept_masks = masks[keep]
        # Only recalculate masks for masks that have changed
        for j, i in enumerate(keep.tolist()):
            if scores[i] == 0.0:
                kept_masks[j] = new_masks[i]
        return kept_masks
|