ultralytics 8.0.134 add MobileSAM support (#3474)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Ayush Chaurasia <ayush.chaurarsia@gmail.com>
Co-authored-by: Laughing <61612323+Laughing-q@users.noreply.github.com>
Co-authored-by: Laughing-q <1185102784@qq.com>
Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
This commit is contained in:
Chaoning Zhang
2023-07-13 20:25:56 +08:00
committed by GitHub
parent c55a98ab8e
commit 201e69e4e4
32 changed files with 1472 additions and 841 deletions

View File

@ -131,6 +131,11 @@ class BasePredictor:
img /= 255 # 0 - 255 to 0.0 - 1.0
return img
def inference(self, im, *args, **kwargs):
visualize = increment_path(self.save_dir / Path(self.batch[0][0]).stem,
mkdir=True) if self.args.visualize and (not self.source_type.tensor) else False
return self.model(im, augment=self.args.augment, visualize=visualize)
def pre_transform(self, im):
"""Pre-transform input image before inference.
@ -181,13 +186,13 @@ class BasePredictor:
"""Post-processes predictions for an image and returns them."""
return preds
def __call__(self, source=None, model=None, stream=False):
def __call__(self, source=None, model=None, stream=False, *args, **kwargs):
"""Performs inference on an image or stream."""
self.stream = stream
if stream:
return self.stream_inference(source, model)
return self.stream_inference(source, model, *args, **kwargs)
else:
return list(self.stream_inference(source, model)) # merge list of Result into one
return list(self.stream_inference(source, model, *args, **kwargs)) # merge list of Result into one
def predict_cli(self, source=None, model=None):
"""Method used for CLI prediction. It uses always generator as outputs as not required by CLI mode."""
@ -209,7 +214,7 @@ class BasePredictor:
self.vid_path, self.vid_writer = [None] * self.dataset.bs, [None] * self.dataset.bs
@smart_inference_mode()
def stream_inference(self, source=None, model=None):
def stream_inference(self, source=None, model=None, *args, **kwargs):
"""Streams real-time inference on camera feed and saves results to file."""
if self.args.verbose:
LOGGER.info('')
@ -236,8 +241,6 @@ class BasePredictor:
self.run_callbacks('on_predict_batch_start')
self.batch = batch
path, im0s, vid_cap, s = batch
visualize = increment_path(self.save_dir / Path(path[0]).stem,
mkdir=True) if self.args.visualize and (not self.source_type.tensor) else False
# Preprocess
with profilers[0]:
@ -245,7 +248,7 @@ class BasePredictor:
# Inference
with profilers[1]:
preds = self.model(im, augment=self.args.augment, visualize=visualize)
preds = self.inference(im, *args, **kwargs)
# Postprocess
with profilers[2]:

View File

@ -170,7 +170,7 @@ class Results(SimpleClass):
font='Arial.ttf',
pil=False,
img=None,
img_gpu=None,
im_gpu=None,
kpt_line=True,
labels=True,
boxes=True,
@ -188,7 +188,7 @@ class Results(SimpleClass):
font (str): The font to use for the text.
pil (bool): Whether to return the image as a PIL Image.
img (numpy.ndarray): Plot to another image. if not, plot to original image.
img_gpu (torch.Tensor): Normalized image in gpu with shape (1, 3, 640, 640), for faster mask plotting.
im_gpu (torch.Tensor): Normalized image in gpu with shape (1, 3, 640, 640), for faster mask plotting.
kpt_line (bool): Whether to draw lines connecting keypoints.
labels (bool): Whether to plot the label of bounding boxes.
boxes (bool): Whether to plot the bounding boxes.
@ -226,12 +226,12 @@ class Results(SimpleClass):
# Plot Segment results
if pred_masks and show_masks:
if img_gpu is None:
if im_gpu is None:
img = LetterBox(pred_masks.shape[1:])(image=annotator.result())
img_gpu = torch.as_tensor(img, dtype=torch.float16, device=pred_masks.data.device).permute(
im_gpu = torch.as_tensor(img, dtype=torch.float16, device=pred_masks.data.device).permute(
2, 0, 1).flip(0).contiguous() / 255
idx = pred_boxes.cls if pred_boxes else range(len(pred_masks))
annotator.masks(pred_masks.data, colors=[colors(x, True) for x in idx], im_gpu=img_gpu)
annotator.masks(pred_masks.data, colors=[colors(x, True) for x in idx], im_gpu=im_gpu)
# Plot Detect results
if pred_boxes and show_boxes: