ultralytics 8.0.134
add MobileSAM support (#3474)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Ayush Chaurasia <ayush.chaurarsia@gmail.com> Co-authored-by: Laughing <61612323+Laughing-q@users.noreply.github.com> Co-authored-by: Laughing-q <1185102784@qq.com> Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
This commit is contained in:
@ -131,6 +131,11 @@ class BasePredictor:
|
||||
img /= 255 # 0 - 255 to 0.0 - 1.0
|
||||
return img
|
||||
|
||||
def inference(self, im, *args, **kwargs):
|
||||
visualize = increment_path(self.save_dir / Path(self.batch[0][0]).stem,
|
||||
mkdir=True) if self.args.visualize and (not self.source_type.tensor) else False
|
||||
return self.model(im, augment=self.args.augment, visualize=visualize)
|
||||
|
||||
def pre_transform(self, im):
|
||||
"""Pre-transform input image before inference.
|
||||
|
||||
@ -181,13 +186,13 @@ class BasePredictor:
|
||||
"""Post-processes predictions for an image and returns them."""
|
||||
return preds
|
||||
|
||||
def __call__(self, source=None, model=None, stream=False):
|
||||
def __call__(self, source=None, model=None, stream=False, *args, **kwargs):
|
||||
"""Performs inference on an image or stream."""
|
||||
self.stream = stream
|
||||
if stream:
|
||||
return self.stream_inference(source, model)
|
||||
return self.stream_inference(source, model, *args, **kwargs)
|
||||
else:
|
||||
return list(self.stream_inference(source, model)) # merge list of Result into one
|
||||
return list(self.stream_inference(source, model, *args, **kwargs)) # merge list of Result into one
|
||||
|
||||
def predict_cli(self, source=None, model=None):
|
||||
"""Method used for CLI prediction. It uses always generator as outputs as not required by CLI mode."""
|
||||
@ -209,7 +214,7 @@ class BasePredictor:
|
||||
self.vid_path, self.vid_writer = [None] * self.dataset.bs, [None] * self.dataset.bs
|
||||
|
||||
@smart_inference_mode()
|
||||
def stream_inference(self, source=None, model=None):
|
||||
def stream_inference(self, source=None, model=None, *args, **kwargs):
|
||||
"""Streams real-time inference on camera feed and saves results to file."""
|
||||
if self.args.verbose:
|
||||
LOGGER.info('')
|
||||
@ -236,8 +241,6 @@ class BasePredictor:
|
||||
self.run_callbacks('on_predict_batch_start')
|
||||
self.batch = batch
|
||||
path, im0s, vid_cap, s = batch
|
||||
visualize = increment_path(self.save_dir / Path(path[0]).stem,
|
||||
mkdir=True) if self.args.visualize and (not self.source_type.tensor) else False
|
||||
|
||||
# Preprocess
|
||||
with profilers[0]:
|
||||
@ -245,7 +248,7 @@ class BasePredictor:
|
||||
|
||||
# Inference
|
||||
with profilers[1]:
|
||||
preds = self.model(im, augment=self.args.augment, visualize=visualize)
|
||||
preds = self.inference(im, *args, **kwargs)
|
||||
|
||||
# Postprocess
|
||||
with profilers[2]:
|
||||
|
@ -170,7 +170,7 @@ class Results(SimpleClass):
|
||||
font='Arial.ttf',
|
||||
pil=False,
|
||||
img=None,
|
||||
img_gpu=None,
|
||||
im_gpu=None,
|
||||
kpt_line=True,
|
||||
labels=True,
|
||||
boxes=True,
|
||||
@ -188,7 +188,7 @@ class Results(SimpleClass):
|
||||
font (str): The font to use for the text.
|
||||
pil (bool): Whether to return the image as a PIL Image.
|
||||
img (numpy.ndarray): Plot to another image. if not, plot to original image.
|
||||
img_gpu (torch.Tensor): Normalized image in gpu with shape (1, 3, 640, 640), for faster mask plotting.
|
||||
im_gpu (torch.Tensor): Normalized image in gpu with shape (1, 3, 640, 640), for faster mask plotting.
|
||||
kpt_line (bool): Whether to draw lines connecting keypoints.
|
||||
labels (bool): Whether to plot the label of bounding boxes.
|
||||
boxes (bool): Whether to plot the bounding boxes.
|
||||
@ -226,12 +226,12 @@ class Results(SimpleClass):
|
||||
|
||||
# Plot Segment results
|
||||
if pred_masks and show_masks:
|
||||
if img_gpu is None:
|
||||
if im_gpu is None:
|
||||
img = LetterBox(pred_masks.shape[1:])(image=annotator.result())
|
||||
img_gpu = torch.as_tensor(img, dtype=torch.float16, device=pred_masks.data.device).permute(
|
||||
im_gpu = torch.as_tensor(img, dtype=torch.float16, device=pred_masks.data.device).permute(
|
||||
2, 0, 1).flip(0).contiguous() / 255
|
||||
idx = pred_boxes.cls if pred_boxes else range(len(pred_masks))
|
||||
annotator.masks(pred_masks.data, colors=[colors(x, True) for x in idx], im_gpu=img_gpu)
|
||||
annotator.masks(pred_masks.data, colors=[colors(x, True) for x in idx], im_gpu=im_gpu)
|
||||
|
||||
# Plot Detect results
|
||||
if pred_boxes and show_boxes:
|
||||
|
Reference in New Issue
Block a user