ultralytics 8.0.134 add MobileSAM support (#3474)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Ayush Chaurasia <ayush.chaurarsia@gmail.com>
Co-authored-by: Laughing <61612323+Laughing-q@users.noreply.github.com>
Co-authored-by: Laughing-q <1185102784@qq.com>
Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
Chaoning Zhang
2023-07-13 20:25:56 +08:00
committed by GitHub
parent c55a98ab8e
commit 201e69e4e4
32 changed files with 1472 additions and 841 deletions

View File

@@ -1,8 +1,6 @@
from pathlib import Path
from ultralytics import YOLO
from ultralytics.vit.sam import PromptPredictor, build_sam
from ultralytics.yolo.utils.torch_utils import select_device
from ultralytics import SAM, YOLO
def auto_annotate(data, det_model='yolov8x.pt', sam_model='sam_b.pt', device='', output_dir=None):
@@ -16,33 +14,21 @@ def auto_annotate(data, det_model='yolov8x.pt', sam_model='sam_b.pt', device='',
output_dir (str | None, optional): Directory to save the annotated results.
    Defaults to a 'labels' folder in the same directory as 'data'.
"""
device = select_device(device)
det_model = YOLO(det_model)
sam_model = build_sam(sam_model)
det_model.to(device)
sam_model.to(device)
sam_model = SAM(sam_model)
if not output_dir:
output_dir = Path(str(data)).parent / 'labels'
Path(output_dir).mkdir(exist_ok=True, parents=True)
prompt_predictor = PromptPredictor(sam_model)
det_results = det_model(data, stream=True)
det_results = det_model(data, stream=True, device=device)
for result in det_results:
boxes = result.boxes.xyxy # Boxes object for bbox outputs
class_ids = result.boxes.cls.int().tolist() # noqa
if len(class_ids):
prompt_predictor.set_image(result.orig_img)
masks, _, _ = prompt_predictor.predict_torch(
point_coords=None,
point_labels=None,
boxes=prompt_predictor.transform.apply_boxes_torch(boxes, result.orig_shape[:2]),
multimask_output=False,
)
result.update(masks=masks.squeeze(1))
segments = result.masks.xyn # noqa
sam_results = sam_model(result.orig_img, bboxes=boxes, verbose=False, save=False, device=device)
segments = sam_results[0].masks.xyn # noqa
with open(str(Path(output_dir) / Path(result.path).stem) + '.txt', 'w') as f:
for i in range(len(segments)):
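
This hunk replaces the manual PromptPredictor pipeline with the unified SAM wrapper, which handles device placement and prompt transforms internally. A minimal standalone sketch of the new flow, using only calls shown in this diff (the image path and device string are placeholders):

from ultralytics import SAM, YOLO

det_model = YOLO('yolov8x.pt')  # detector that supplies the box prompts
sam_model = SAM('sam_b.pt')     # or SAM('mobile_sam.pt') with this release

for result in det_model('path/to/images', stream=True, device='cpu'):
    boxes = result.boxes.xyxy   # (n, 4) xyxy box prompts
    if len(boxes):
        sam_results = sam_model(result.orig_img, bboxes=boxes, verbose=False, save=False, device='cpu')
        segments = sam_results[0].masks.xyn  # normalized polygon segments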

View File

@@ -538,13 +538,14 @@ class RandomFlip:
class LetterBox:
"""Resize image and padding for detection, instance segmentation, pose."""
def __init__(self, new_shape=(640, 640), auto=False, scaleFill=False, scaleup=True, stride=32):
def __init__(self, new_shape=(640, 640), auto=False, scaleFill=False, scaleup=True, center=True, stride=32):
"""Initialize LetterBox object with specific parameters."""
self.new_shape = new_shape
self.auto = auto
self.scaleFill = scaleFill
self.scaleup = scaleup
self.stride = stride
self.center = center # Put the image in the middle or top-left
def __call__(self, labels=None, image=None):
"""Return updated labels and image with added border."""
@@ -572,15 +573,16 @@ class LetterBox:
new_unpad = (new_shape[1], new_shape[0])
ratio = new_shape[1] / shape[1], new_shape[0] / shape[0] # width, height ratios
dw /= 2 # divide padding into 2 sides
dh /= 2
if self.center:
dw /= 2 # divide padding into 2 sides
dh /= 2
if labels.get('ratio_pad'):
labels['ratio_pad'] = (labels['ratio_pad'], (dw, dh)) # for evaluation
if shape[::-1] != new_unpad: # resize
img = cv2.resize(img, new_unpad, interpolation=cv2.INTER_LINEAR)
top, bottom = int(round(dh - 0.1)), int(round(dh + 0.1))
left, right = int(round(dw - 0.1)), int(round(dw + 0.1))
top, bottom = int(round(dh - 0.1)) if self.center else 0, int(round(dh + 0.1))
left, right = int(round(dw - 0.1)) if self.center else 0, int(round(dw + 0.1))
img = cv2.copyMakeBorder(img, top, bottom, left, right, cv2.BORDER_CONSTANT,
value=(114, 114, 114)) # add border
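
The new center flag controls whether LetterBox splits padding across both sides (the YOLO default) or anchors the image at the top-left with all padding at the bottom/right, as SAM-style preprocessing expects. A quick sketch, assuming a BGR numpy image:

import numpy as np
from ultralytics.yolo.data.augment import LetterBox

im = np.zeros((480, 640, 3), dtype=np.uint8)
letterbox = LetterBox(new_shape=(1024, 1024), center=False)  # pad bottom/right only
out = letterbox(image=im)  # (1024, 1024, 3), image anchored at the top-left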

View File

@@ -131,6 +131,11 @@ class BasePredictor:
img /= 255 # 0 - 255 to 0.0 - 1.0
return img
def inference(self, im, *args, **kwargs):
visualize = increment_path(self.save_dir / Path(self.batch[0][0]).stem,
mkdir=True) if self.args.visualize and (not self.source_type.tensor) else False
return self.model(im, augment=self.args.augment, visualize=visualize)
def pre_transform(self, im):
"""Pre-transform input image before inference.
@@ -181,13 +186,13 @@ class BasePredictor:
"""Post-processes predictions for an image and returns them."""
return preds
def __call__(self, source=None, model=None, stream=False):
def __call__(self, source=None, model=None, stream=False, *args, **kwargs):
"""Performs inference on an image or stream."""
self.stream = stream
if stream:
return self.stream_inference(source, model)
return self.stream_inference(source, model, *args, **kwargs)
else:
return list(self.stream_inference(source, model)) # merge list of Result into one
return list(self.stream_inference(source, model, *args, **kwargs)) # merge list of Result into one
def predict_cli(self, source=None, model=None):
"""Method used for CLI prediction. It uses always generator as outputs as not required by CLI mode."""
@@ -209,7 +214,7 @@ class BasePredictor:
self.vid_path, self.vid_writer = [None] * self.dataset.bs, [None] * self.dataset.bs
@smart_inference_mode()
def stream_inference(self, source=None, model=None):
def stream_inference(self, source=None, model=None, *args, **kwargs):
"""Streams real-time inference on camera feed and saves results to file."""
if self.args.verbose:
LOGGER.info('')
@@ -236,8 +241,6 @@
self.run_callbacks('on_predict_batch_start')
self.batch = batch
path, im0s, vid_cap, s = batch
visualize = increment_path(self.save_dir / Path(path[0]).stem,
mkdir=True) if self.args.visualize and (not self.source_type.tensor) else False
# Preprocess
with profilers[0]:
@@ -245,7 +248,7 @@
# Inference
with profilers[1]:
preds = self.model(im, augment=self.args.augment, visualize=visualize)
preds = self.inference(im, *args, **kwargs)
# Postprocess
with profilers[2]:
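
Moving the forward pass into an overridable inference() hook, and threading *args/**kwargs from __call__() through stream_inference(), lets prompt-driven predictors accept per-call arguments without modifying the base loop. A sketch of how a subclass might use the hook (the class name and bboxes keyword are illustrative, not from this diff):

from ultralytics.yolo.engine.predictor import BasePredictor

class PromptablePredictor(BasePredictor):
    """Hypothetical predictor that consumes prompts forwarded from __call__()."""

    def inference(self, im, bboxes=None, *args, **kwargs):
        # bboxes arrives via predictor(source, stream=True, bboxes=...)
        return self.model(im, bboxes=bboxes)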

View File

@@ -170,7 +170,7 @@ class Results(SimpleClass):
font='Arial.ttf',
pil=False,
img=None,
img_gpu=None,
im_gpu=None,
kpt_line=True,
labels=True,
boxes=True,
@@ -188,7 +188,7 @@
font (str): The font to use for the text.
pil (bool): Whether to return the image as a PIL Image.
img (numpy.ndarray): Plot to another image. if not, plot to original image.
img_gpu (torch.Tensor): Normalized image on the GPU, with shape (1, 3, 640, 640), for faster mask plotting.
im_gpu (torch.Tensor): Normalized image on the GPU, with shape (1, 3, 640, 640), for faster mask plotting.
kpt_line (bool): Whether to draw lines connecting keypoints.
labels (bool): Whether to plot the label of bounding boxes.
boxes (bool): Whether to plot the bounding boxes.
@@ -226,12 +226,12 @@
# Plot Segment results
if pred_masks and show_masks:
if img_gpu is None:
if im_gpu is None:
img = LetterBox(pred_masks.shape[1:])(image=annotator.result())
img_gpu = torch.as_tensor(img, dtype=torch.float16, device=pred_masks.data.device).permute(
im_gpu = torch.as_tensor(img, dtype=torch.float16, device=pred_masks.data.device).permute(
2, 0, 1).flip(0).contiguous() / 255
idx = pred_boxes.cls if pred_boxes else range(len(pred_masks))
annotator.masks(pred_masks.data, colors=[colors(x, True) for x in idx], im_gpu=img_gpu)
annotator.masks(pred_masks.data, colors=[colors(x, True) for x in idx], im_gpu=im_gpu)
# Plot Detect results
if pred_boxes and show_boxes:
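
The img_gpu to im_gpu rename aligns the Results.plot() keyword with the Annotator.masks(im_gpu=...) parameter it feeds. Callers need a one-word update; for example (the image path is a placeholder):

from ultralytics import YOLO

results = YOLO('yolov8n-seg.pt')('path/to/image.jpg')
im_array = results[0].plot(im_gpu=None)  # keyword renamed from img_gpu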

View File

@@ -8,12 +8,12 @@ def adjust_bboxes_to_image_border(boxes, image_shape, threshold=20):
Adjust bounding boxes to stick to image border if they are within a certain threshold.
Args:
boxes: (n, 4)
image_shape: (height, width)
threshold: pixel threshold
boxes (torch.Tensor): (n, 4)
image_shape (tuple): (height, width)
threshold (int): pixel threshold
Returns:
adjusted_boxes: adjusted bounding boxes
adjusted_boxes (torch.Tensor): adjusted bounding boxes
"""
# Image dimensions
@@ -32,11 +32,11 @@ def bbox_iou(box1, boxes, iou_thres=0.9, image_shape=(640, 640), raw_output=Fals
Compute the Intersection-Over-Union of a bounding box with respect to an array of other bounding boxes.
Args:
box1: (4, )
boxes: (n, 4)
box1 (torch.Tensor): (4, )
boxes (torch.Tensor): (n, 4)
Returns:
high_iou_indices: Indices of boxes with IoU > thres
high_iou_indices (torch.Tensor): Indices of boxes with IoU > thres
"""
boxes = adjust_bboxes_to_image_border(boxes, image_shape)
# obtain coordinates for intersections
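
Beyond the typed docstrings, the helper's behavior is unchanged: box edges within threshold pixels of the image border snap to it. An illustration (the import path for the FastSAM utils module is assumed):

import torch
from ultralytics.yolo.fastsam.utils import adjust_bboxes_to_image_border

boxes = torch.tensor([[5., 8., 630., 635.]])  # within 20 px of the 640x640 border
snapped = adjust_bboxes_to_image_border(boxes, image_shape=(640, 640), threshold=20)
print(snapped)  # tensor([[  0.,   0., 640., 640.]]) edges snap to the border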

View File

@@ -21,7 +21,8 @@ GITHUB_ASSET_NAMES = [f'yolov8{k}{suffix}.pt' for k in 'nsmlx' for suffix in (''
[f'yolo_nas_{k}.pt' for k in 'sml'] + \
[f'sam_{k}.pt' for k in 'bl'] + \
[f'FastSAM-{k}.pt' for k in 'sx'] + \
[f'rtdetr-{k}.pt' for k in 'lx']
[f'rtdetr-{k}.pt' for k in 'lx'] + \
['mobile_sam.pt']
GITHUB_ASSET_STEMS = [Path(k).stem for k in GITHUB_ASSET_NAMES]
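
Registering 'mobile_sam.pt' as a release asset means the MobileSAM checkpoint auto-downloads on first use, like every other entry in GITHUB_ASSET_NAMES. For example (the image path and box prompt are placeholders):

from ultralytics import SAM

model = SAM('mobile_sam.pt')  # fetched from GitHub release assets if not present
model('path/to/image.jpg', bboxes=[100, 200, 300, 400])  # hypothetical box prompt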

View File

@@ -20,6 +20,7 @@ def _ntuple(n):
return parse
to_2tuple = _ntuple(2)
to_4tuple = _ntuple(4)
# `xyxy` means (left, top) and (right, bottom) corners
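
For context, _ntuple is the timm-style helper used for patch and window sizes in the new MobileSAM TinyViT code: it broadcasts a scalar into an n-tuple and passes iterables through. A self-contained sketch, assuming the body matches the standard timm helper:

import collections.abc
from itertools import repeat

def _ntuple(n):
    """Return a parser: scalar -> n-tuple, iterable passes through unchanged."""
    def parse(x):
        return x if isinstance(x, collections.abc.Iterable) else tuple(repeat(x, n))
    return parse

to_2tuple = _ntuple(2)
print(to_2tuple(7))       # (7, 7)
print(to_2tuple((4, 8)))  # (4, 8)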

View File

@@ -92,7 +92,7 @@ def segment2box(segment, width=640, height=640):
4, dtype=segment.dtype) # xyxy
def scale_boxes(img1_shape, boxes, img0_shape, ratio_pad=None):
def scale_boxes(img1_shape, boxes, img0_shape, ratio_pad=None, padding=True):
"""
Rescales bounding boxes (in the format of xyxy) from the shape of the image they were originally specified in
(img1_shape) to the shape of a different image (img0_shape).
@@ -103,6 +103,8 @@ def scale_boxes(img1_shape, boxes, img0_shape, ratio_pad=None):
img0_shape (tuple): the shape of the target image, in the format of (height, width).
ratio_pad (tuple): a tuple of (ratio, pad) for scaling the boxes. If not provided, the ratio and pad will be
calculated based on the size difference between the two images.
padding (bool): If True, assume the boxes are based on an image augmented with YOLO-style letterboxing. If False,
    do regular rescaling.
Returns:
boxes (torch.Tensor): The scaled bounding boxes, in the format of (x1, y1, x2, y2)
@@ -115,8 +117,9 @@
gain = ratio_pad[0][0]
pad = ratio_pad[1]
boxes[..., [0, 2]] -= pad[0] # x padding
boxes[..., [1, 3]] -= pad[1] # y padding
if padding:
boxes[..., [0, 2]] -= pad[0] # x padding
boxes[..., [1, 3]] -= pad[1] # y padding
boxes[..., :4] /= gain
clip_boxes(boxes, img0_shape)
return boxes
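
With padding=False, scale_boxes skips the letterbox offset and only undoes the resize gain, which suits inputs produced without centered padding (for example LetterBox(center=False)). A worked sketch:

import torch
from ultralytics.yolo.utils.ops import scale_boxes

# Boxes predicted on a 1024x1024 input with no centered padding to subtract
boxes = torch.tensor([[100., 100., 500., 400.]])
restored = scale_boxes((1024, 1024), boxes, (480, 640), padding=False)
# gain = min(1024/480, 1024/640) = 1.6 -> [[62.5, 62.5, 312.5, 250.0]]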
@@ -552,7 +555,7 @@
It takes a mask and a bounding box, and returns a mask that is cropped to the bounding box
Args:
masks (torch.Tensor): [h, w, n] tensor of masks
masks (torch.Tensor): [n, h, w] tensor of masks
boxes (torch.Tensor): [n, 4] tensor of bbox coordinates in relative point form
Returns:
@@ -634,18 +637,36 @@ def process_mask_native(protos, masks_in, bboxes, shape):
"""
c, mh, mw = protos.shape # CHW
masks = (masks_in @ protos.float().view(c, -1)).sigmoid().view(-1, mh, mw)
gain = min(mh / shape[0], mw / shape[1]) # gain = old / new
pad = (mw - shape[1] * gain) / 2, (mh - shape[0] * gain) / 2 # wh padding
top, left = int(pad[1]), int(pad[0]) # y, x
bottom, right = int(mh - pad[1]), int(mw - pad[0])
masks = masks[:, top:bottom, left:right]
masks = F.interpolate(masks[None], shape, mode='bilinear', align_corners=False)[0] # CHW
masks = scale_masks(masks[None], shape)[0] # CHW
masks = crop_mask(masks, bboxes) # CHW
return masks.gt_(0.5)
def scale_coords(img1_shape, coords, img0_shape, ratio_pad=None, normalize=False):
def scale_masks(masks, shape, padding=True):
"""
Rescale segment masks to shape.
Args:
masks (torch.Tensor): (N, C, H, W).
shape (tuple): Height and width.
padding (bool): If True, assume the masks are based on an image augmented with YOLO-style letterboxing. If False,
    do regular rescaling.
"""
mh, mw = masks.shape[2:]
gain = min(mh / shape[0], mw / shape[1]) # gain = old / new
pad = [mw - shape[1] * gain, mh - shape[0] * gain] # wh padding
if padding:
pad[0] /= 2
pad[1] /= 2
top, left = (int(pad[1]), int(pad[0])) if padding else (0, 0) # y, x
bottom, right = (int(mh - pad[1]), int(mw - pad[0]))
masks = masks[..., top:bottom, left:right]
masks = F.interpolate(masks, shape, mode='bilinear', align_corners=False) # NCHW
return masks
def scale_coords(img1_shape, coords, img0_shape, ratio_pad=None, normalize=False, padding=True):
"""
Rescale segment coordinates (xy) from img1_shape to img0_shape.
@@ -655,6 +676,8 @@ def scale_coords(img1_shape, coords, img0_shape, ratio_pad=None, normalize=False
img0_shape (tuple): the shape of the image that the segmentation is being applied to
ratio_pad (tuple): the ratio of the image size to the padded image size.
normalize (bool): If True, the coordinates will be normalized to the range [0, 1]. Defaults to False
padding (bool): If True, assume the coordinates are based on an image augmented with YOLO-style letterboxing. If
    False, do regular rescaling.
Returns:
coords (torch.Tensor): the rescaled coordinates.
@@ -666,8 +689,9 @@
gain = ratio_pad[0][0]
pad = ratio_pad[1]
coords[..., 0] -= pad[0] # x padding
coords[..., 1] -= pad[1] # y padding
if padding:
coords[..., 0] -= pad[0] # x padding
coords[..., 1] -= pad[1] # y padding
coords[..., 0] /= gain
coords[..., 1] /= gain
clip_coords(coords, img0_shape)
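
scale_masks is the new shared helper (also used by process_mask_native above) for mapping mask tensors back to a target shape, with the same padding switch. A quick sketch:

import torch
from ultralytics.yolo.utils.ops import scale_masks

masks = torch.rand(1, 1, 1024, 1024)                      # NCHW model-space masks
restored = scale_masks(masks, (480, 640), padding=False)  # -> (1, 1, 480, 640)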