# Ultralytics YOLO 🚀, AGPL-3.0 license

import numpy as np
import torch
import torch.nn.functional as F
import torchvision

from ultralytics.data.augment import LetterBox
from ultralytics.engine.predictor import BasePredictor
from ultralytics.engine.results import Results
from ultralytics.utils import DEFAULT_CFG, ops
from ultralytics.utils.torch_utils import select_device

from .amg import (batch_iterator, batched_mask_to_box, build_all_layer_point_grids, calculate_stability_score,
                  generate_crop_boxes, is_box_near_crop_edge, remove_small_regions, uncrop_boxes_xyxy, uncrop_masks)
from .build import build_sam

class Predictor(BasePredictor):
    """Predictor for the Segment Anything Model (SAM).

    `inference` dispatches between two modes:
      * prompted segmentation (`prompt_inference`) when boxes, points or mask prompts are given;
      * "segment everything" (`generate`) when none are given, which prompts the model with a grid
        of points over one or more image crops and de-duplicates the results with NMS.

    Image features can be computed once with `set_image` and reused by later prompt calls until
    `reset_image` clears them.
    """

    def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
        """Initialize the predictor with SAM-specific settings.

        Args:
            cfg: Configuration forwarded to `BasePredictor`.
            overrides (dict, optional): Config overrides. `task`, `mode` and `imgsz` are always forced to
                'segment', 'predict' and 1024. The caller's dict is not modified.
            _callbacks: Callbacks forwarded to `BasePredictor`.
        """
        # Copy so the forced keys below do not leak into the caller's dict.
        overrides = {} if overrides is None else dict(overrides)
        overrides.update(dict(task='segment', mode='predict', imgsz=1024))
        super().__init__(cfg, overrides, _callbacks)
        # SAM needs retina_masks=True, or the results would be a mess.
        self.args.retina_masks = True
        # Args for set_image
        self.im = None
        self.features = None
        # Args for segment everything
        self.segment_all = False

    def preprocess(self, im):
        """Prepare an input image for inference.

        If an image was cached by `set_image`, that preprocessed tensor is returned unchanged.

        Args:
            im (torch.Tensor | List(np.ndarray)): BCHW for tensor, [(HWC) x B] for list.

        Returns:
            (torch.Tensor): Image on `self.device`, fp16 or fp32 depending on `self.model.fp16`. List (numpy)
                inputs are additionally letterboxed and normalized with `self.mean` / `self.std`.
        """
        if self.im is not None:
            return self.im
        not_tensor = not isinstance(im, torch.Tensor)
        if not_tensor:
            im = np.stack(self.pre_transform(im))
            im = im[..., ::-1].transpose((0, 3, 1, 2))  # BGR to RGB, BHWC to BCHW, (n, 3, h, w)
            im = np.ascontiguousarray(im)  # contiguous
            im = torch.from_numpy(im)

        img = im.to(self.device)
        img = img.half() if self.model.fp16 else img.float()  # uint8 to fp16/32
        if not_tensor:
            img = (img - self.mean) / self.std
        return img

    def pre_transform(self, im):
        """Letterbox input images to `self.args.imgsz` before inference.

        Args:
            im (List(np.ndarray)): (N, 3, h, w) for tensor, [(h, w, 3) x N] for list.

        Returns:
            (list): A list of transformed images.
        """
        assert len(im) == 1, 'SAM model has not supported batch inference yet!'
        return [LetterBox(self.args.imgsz, auto=False, center=False)(image=x) for x in im]

    def inference(self, im, bboxes=None, points=None, labels=None, masks=None, multimask_output=False, *args, **kwargs):
        """Predict masks for the given prompts, or segment the whole image when no prompt is given.

        Args:
            im (torch.Tensor): The preprocessed image, (N, C, H, W).
            bboxes (np.ndarray | List, None): (N, 4), in XYXY format.
            points (np.ndarray | List, None): (N, 2), Each point is in (X,Y) in pixels.
            labels (np.ndarray | List, None): (N, ), labels for the point prompts.
                1 indicates a foreground point and 0 indicates a background point.
            masks (np.ndarray, None): A low resolution mask input to the model, typically
                coming from a previous prediction iteration. Has form (N, H, W), where
                for SAM, H=W=256.
            multimask_output (bool): If true, the model will return three masks per prompt.
                For ambiguous input prompts (such as a single click), this will often
                produce better masks than a single prediction. For non-ambiguous prompts,
                such as multiple input prompts, multimask_output=False can give better results.
            *args, **kwargs: Forwarded to `generate` when no prompt is given.

        Returns:
            (tuple): `(pred_masks, pred_scores)` from `prompt_inference` when any of `bboxes`, `points`
                or `masks` is given; otherwise `(pred_masks, pred_scores, pred_bboxes)` from `generate`.
        """
        if all(i is None for i in [bboxes, points, masks]):
            return self.generate(im, *args, **kwargs)
        return self.prompt_inference(im, bboxes, points, labels, masks, multimask_output)

    def prompt_inference(self, im, bboxes=None, points=None, labels=None, masks=None, multimask_output=False):
        """Predict masks for the given input prompts, reusing cached image features when available.

        Args:
            im (torch.Tensor): The preprocessed image, (N, C, H, W).
            bboxes (np.ndarray | List, None): (N, 4), in XYXY format, in original-image pixels.
            points (np.ndarray | List, None): (N, 2), Each point is in (X,Y) in original-image pixels.
            labels (np.ndarray | List, None): (N, ), labels for the point prompts.
                1 indicates a foreground point and 0 indicates a background point.
                Defaults to all-foreground when omitted.
            masks (np.ndarray, None): A low resolution mask input to the model, typically
                coming from a previous prediction iteration. Has form (N, H, W), where
                for SAM, H=W=256.
            multimask_output (bool): If true, the model will return three masks per prompt.

        Returns:
            pred_masks (torch.Tensor): Mask logits from `self.model.mask_decoder`, flattened from
                (N, d, H, W) to (N*d, H, W); `d` is 3 or 1 depending on `multimask_output`.
                They are not resized here.
            pred_scores (torch.Tensor): Predicted quality per mask, flattened to (N*d, ).
        """
        features = self.model.image_encoder(im) if self.features is None else self.features

        src_shape, dst_shape = self.batch[1][0].shape[:2], im.shape[2:]
        # Prompts are given in original-image pixels; rescale them to the letterboxed input,
        # except in segment-all mode where the grid points are already in input pixels.
        r = 1.0 if self.segment_all else min(dst_shape[0] / src_shape[0], dst_shape[1] / src_shape[1])
        # Transform input prompts
        if points is not None:
            points = torch.as_tensor(points, dtype=torch.float32, device=self.device)
            points = points[None] if points.ndim == 1 else points
            # Assuming labels are all positive if users don't pass labels.
            if labels is None:
                labels = np.ones(points.shape[0])
            labels = torch.as_tensor(labels, dtype=torch.int32, device=self.device)
            # Out-of-place: `as_tensor` may share memory with the caller's array/tensor.
            points = points * r
            # (N, 2) --> (N, 1, 2), (N, ) --> (N, 1)
            points, labels = points[:, None, :], labels[:, None]
        if bboxes is not None:
            bboxes = torch.as_tensor(bboxes, dtype=torch.float32, device=self.device)
            bboxes = bboxes[None] if bboxes.ndim == 1 else bboxes
            # Out-of-place for the same reason as `points`.
            bboxes = bboxes * r
        if masks is not None:
            masks = torch.as_tensor(masks, dtype=torch.float32, device=self.device)
            masks = masks[:, None, :, :]

        points = (points, labels) if points is not None else None
        # Embed prompts
        sparse_embeddings, dense_embeddings = self.model.prompt_encoder(
            points=points,
            boxes=bboxes,
            masks=masks,
        )

        # Predict masks
        pred_masks, pred_scores = self.model.mask_decoder(
            image_embeddings=features,
            image_pe=self.model.prompt_encoder.get_dense_pe(),
            sparse_prompt_embeddings=sparse_embeddings,
            dense_prompt_embeddings=dense_embeddings,
            multimask_output=multimask_output,
        )

        # (N, d, H, W) --> (N*d, H, W), (N, d) --> (N*d, )
        # `d` could be 1 or 3 depends on `multimask_output`.
        return pred_masks.flatten(0, 1), pred_scores.flatten(0, 1)

    def generate(self,
                 im,
                 crop_n_layers=0,
                 crop_overlap_ratio=512 / 1500,
                 crop_downscale_factor=1,
                 point_grids=None,
                 points_stride=32,
                 points_batch_size=64,
                 conf_thres=0.88,
                 stability_score_thresh=0.95,
                 stability_score_offset=0.95,
                 crop_nms_thresh=0.7):
        """Segment the whole image by prompting the model with a grid of points.

        Args:
            im (torch.Tensor): The preprocessed image, (N, C, H, W).
            crop_n_layers (int): If >0, mask prediction will be run again on
                crops of the image. Sets the number of layers to run, where each
                layer has 2**i_layer number of image crops.
            crop_overlap_ratio (float): Sets the degree to which crops overlap.
                In the first crop layer, crops will overlap by this fraction of
                the image length. Later layers with more crops scale down this overlap.
            crop_downscale_factor (int): The number of points-per-side
                sampled in layer n is scaled down by crop_downscale_factor**n.
            point_grids (list(np.ndarray), None): A list over explicit grids
                of points used for sampling, normalized to [0,1]. The nth grid in the
                list is used in the nth crop layer. When given, points_stride is ignored.
            points_stride (int, None): The number of points to be sampled
                along one side of the image. The total number of points is
                points_stride**2. If None, 'point_grids' must provide explicit
                point sampling.
            points_batch_size (int): Sets the number of points run simultaneously
                by the model. Higher numbers may be faster but use more GPU memory.
            conf_thres (float): A filtering threshold in [0,1], using the
                model's predicted mask quality.
            stability_score_thresh (float): A filtering threshold in [0,1], using
                the stability of the mask under changes to the cutoff used to binarize
                the model's mask predictions.
            stability_score_offset (float): The amount to shift the cutoff when
                calculating the stability score.
            crop_nms_thresh (float): The box IoU cutoff used by non-maximal
                suppression to filter duplicate masks between different crops.
                (NMS within a single crop uses `self.args.iou`.)

        Returns:
            pred_masks (torch.Tensor): Boolean masks, uncropped back to the (ih, iw) input frame.
            pred_scores (torch.Tensor): Predicted quality score per kept mask.
            pred_bboxes (torch.Tensor): XYXY box per kept mask, in input-image coordinates.
        """
        self.segment_all = True
        ih, iw = im.shape[2:]
        crop_regions, layer_idxs = generate_crop_boxes((ih, iw), crop_n_layers, crop_overlap_ratio)
        if point_grids is None:
            point_grids = build_all_layer_point_grids(
                points_stride,
                crop_n_layers,
                crop_downscale_factor,
            )
        pred_masks, pred_scores, pred_bboxes, region_areas = [], [], [], []
        for crop_region, layer_idx in zip(crop_regions, layer_idxs):
            x1, y1, x2, y2 = crop_region
            w, h = x2 - x1, y2 - y1
            area = torch.tensor(w * h, device=im.device)
            # The crop is resized to (ih, iw) below, so the normalized grid must be scaled to that
            # size (not to the crop's own (w, h)) for the points to cover the whole resized crop.
            points_scale = np.array([[iw, ih]])  # w, h of the resized crop
            # Crop image and interpolate to input size
            crop_im = F.interpolate(im[..., y1:y2, x1:x2], (ih, iw), mode='bilinear', align_corners=False)
            # (num_points, 2)
            points_for_image = point_grids[layer_idx] * points_scale
            crop_masks, crop_scores, crop_bboxes = [], [], []
            for (points, ) in batch_iterator(points_batch_size, points_for_image):
                pred_mask, pred_score = self.prompt_inference(crop_im, points=points, multimask_output=True)
                # Interpolate predicted masks back to the crop's own size (h, w)
                pred_mask = F.interpolate(pred_mask[None], (h, w), mode='bilinear', align_corners=False)[0]
                idx = pred_score > conf_thres
                pred_mask, pred_score = pred_mask[idx], pred_score[idx]

                stability_score = calculate_stability_score(pred_mask, self.model.mask_threshold,
                                                            stability_score_offset)
                idx = stability_score > stability_score_thresh
                pred_mask, pred_score = pred_mask[idx], pred_score[idx]
                # Bool type is much more memory-efficient.
                pred_mask = pred_mask > self.model.mask_threshold
                # (N, 4)
                pred_bbox = batched_mask_to_box(pred_mask).float()
                # Drop masks touching a crop edge that is not also an image edge.
                keep_mask = ~is_box_near_crop_edge(pred_bbox, crop_region, [0, 0, iw, ih])
                if not torch.all(keep_mask):
                    pred_bbox = pred_bbox[keep_mask]
                    pred_mask = pred_mask[keep_mask]
                    pred_score = pred_score[keep_mask]

                crop_masks.append(pred_mask)
                crop_bboxes.append(pred_bbox)
                crop_scores.append(pred_score)

            # Do nms within this crop
            crop_masks = torch.cat(crop_masks)
            crop_bboxes = torch.cat(crop_bboxes)
            crop_scores = torch.cat(crop_scores)
            keep = torchvision.ops.nms(crop_bboxes, crop_scores, self.args.iou)  # NMS
            crop_bboxes = uncrop_boxes_xyxy(crop_bboxes[keep], crop_region)
            crop_masks = uncrop_masks(crop_masks[keep], crop_region, ih, iw)
            crop_scores = crop_scores[keep]

            pred_masks.append(crop_masks)
            pred_bboxes.append(crop_bboxes)
            pred_scores.append(crop_scores)
            # Remember which crop area each mask came from, for cross-crop NMS below.
            region_areas.append(area.expand(len(crop_masks)))

        pred_masks = torch.cat(pred_masks)
        pred_bboxes = torch.cat(pred_bboxes)
        pred_scores = torch.cat(pred_scores)
        region_areas = torch.cat(region_areas)

        # Remove duplicate masks between crops
        if len(crop_regions) > 1:
            # Score = 1 / crop area, so NMS prefers masks from smaller crops.
            scores = 1 / region_areas
            keep = torchvision.ops.nms(pred_bboxes, scores, crop_nms_thresh)
            pred_masks = pred_masks[keep]
            pred_bboxes = pred_bboxes[keep]
            pred_scores = pred_scores[keep]

        return pred_masks, pred_scores, pred_bboxes

    def setup_model(self, model, verbose=True):
        """Set up the SAM model on the configured device with normalization stats and compatibility flags.

        Args:
            model (torch.nn.Module | None): SAM model; built from `self.args.model` when None.
            verbose (bool): Not used here; kept for signature compatibility with `BasePredictor`.
        """
        device = select_device(self.args.device)
        if model is None:
            model = build_sam(self.args.model)
        model.eval()
        self.model = model.to(device)
        self.device = device
        # Per-channel RGB mean/std in 0-255 pixel units, shaped (3, 1, 1) to broadcast over CHW.
        self.mean = torch.tensor([123.675, 116.28, 103.53]).view(-1, 1, 1).to(device)
        self.std = torch.tensor([58.395, 57.12, 57.375]).view(-1, 1, 1).to(device)

        # TODO: Temporary settings for compatibility
        self.model.pt = False
        self.model.triton = False
        self.model.stride = 32
        self.model.fp16 = False
        self.done_warmup = True

    def postprocess(self, preds, img, orig_imgs):
        """Convert raw predictions into `Results`, scaling masks (and boxes) to the original image.

        Args:
            preds (tuple): `(masks, scores)` or, in segment-all mode, `(masks, scores, bboxes)`.
            img (torch.Tensor): The preprocessed input image, (N, C, H, W).
            orig_imgs (np.ndarray | List(np.ndarray)): The original image(s).

        Returns:
            (List[Results]): One `Results` object; segment-all mode is reset afterwards.
        """
        # (N, 1, H, W), (N, 1)
        pred_masks, pred_scores = preds[:2]
        pred_bboxes = preds[2] if self.segment_all else None
        names = dict(enumerate(str(i) for i in range(len(pred_masks))))
        results = []
        for i, masks in enumerate([pred_masks]):
            orig_img = orig_imgs[i] if isinstance(orig_imgs, list) else orig_imgs
            if pred_bboxes is not None:
                pred_bboxes = ops.scale_boxes(img.shape[2:], pred_bboxes.float(), orig_img.shape, padding=False)
                cls = torch.arange(len(pred_masks), dtype=torch.int32, device=pred_masks.device)
                # Append score and a per-mask class id: (M, 4) -> (M, 6)
                pred_bboxes = torch.cat([pred_bboxes, pred_scores[:, None], cls[:, None]], dim=-1)

            masks = ops.scale_masks(masks[None].float(), orig_img.shape[:2], padding=False)[0]
            masks = masks > self.model.mask_threshold  # to bool
            path = self.batch[0]
            img_path = path[i] if isinstance(path, list) else path
            results.append(Results(orig_img=orig_img, path=img_path, names=names, masks=masks, boxes=pred_bboxes))
        # Reset segment-all mode.
        self.segment_all = False
        return results

    def setup_source(self, source):
        """Set up the data source, leaving the current source untouched when `source` is None."""
        if source is not None:
            super().setup_source(source)

    def set_image(self, image):
        """Preprocess one image and cache its encoder features for subsequent prompt calls.

        Args:
            image (str | np.ndarray): image file path or np.ndarray image by cv2.
        """
        # Drop any previously cached image/features; otherwise `preprocess` would return the
        # old cached tensor and the new image would be silently ignored.
        self.reset_image()
        if self.model is None:
            model = build_sam(self.args.model)
            self.setup_model(model)
        self.setup_source(image)
        assert len(self.dataset) == 1, '`set_image` only supports setting one image!'
        for batch in self.dataset:
            im = self.preprocess(batch[1])
            self.features = self.model.image_encoder(im)
            self.im = im
            break

    def reset_image(self):
        """Clear the image and encoder features cached by `set_image`."""
        self.im = None
        self.features = None

    @staticmethod
    def remove_small_regions(masks, min_area=0, nms_thresh=0.7):
        """Remove small disconnected regions and holes in masks, then rerun box NMS to drop new duplicates.

        Requires open-cv as a dependency.

        Args:
            masks (torch.Tensor): Masks, (N, H, W).
            min_area (int): Minimum area threshold.
            nms_thresh (float): NMS threshold.

        Returns:
            (torch.Tensor): The masks kept by NMS. Note that entries of `masks` changed by the cleanup
                are overwritten in place before indexing.
        """
        if len(masks) == 0:
            return masks

        # Filter small disconnected regions and holes.
        # `remove_small_regions` here is the module-level helper imported from `.amg`; a method body
        # does not look names up in the class namespace, so this does not recurse.
        new_masks = []
        scores = []
        for mask in masks:
            mask = mask.cpu().numpy()
            mask, changed = remove_small_regions(mask, min_area, mode='holes')
            unchanged = not changed
            mask, changed = remove_small_regions(mask, min_area, mode='islands')
            unchanged = unchanged and not changed

            new_masks.append(torch.as_tensor(mask).unsqueeze(0))
            # Give score=0 to changed masks and score=1 to unchanged masks
            # so NMS will prefer ones that didn't need postprocessing
            scores.append(float(unchanged))

        # Recalculate boxes and remove any new duplicates
        new_masks = torch.cat(new_masks, dim=0)
        boxes = batched_mask_to_box(new_masks)
        keep = torchvision.ops.nms(
            boxes.float(),
            torch.as_tensor(scores),
            nms_thresh,
        )

        # Only recalculate masks for masks that have changed
        for i in keep:
            if scores[i] == 0.0:
                masks[i] = new_masks[i]

        return masks[keep]