Avoid CUDA round-trip for relevant export formats (#3727)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
Glenn Jocher
2023-07-14 20:38:31 +02:00
committed by GitHub
parent c5991d7cd8
commit 135a10f1fa
5 changed files with 40 additions and 32 deletions

View File

@ -18,7 +18,9 @@ from .build import build_sam
class Predictor(BasePredictor):
def __init__(self, cfg=DEFAULT_CFG, overrides={}, _callbacks=None):
def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
if overrides is None:
overrides = {}
overrides.update(dict(task='segment', mode='predict', imgsz=1024))
super().__init__(cfg, overrides, _callbacks)
# SAM needs retina_masks=True, or the results would be a mess.
@ -90,7 +92,7 @@ class Predictor(BasePredictor):
of masks and H=W=256. These low resolution logits can be passed to
a subsequent iteration as mask input.
"""
if all([i is None for i in [bboxes, points, masks]]):
if all(i is None for i in [bboxes, points, masks]):
return self.generate(im, *args, **kwargs)
return self.prompt_inference(im, bboxes, points, labels, masks, multimask_output)
@ -284,7 +286,7 @@ class Predictor(BasePredictor):
return pred_masks, pred_scores, pred_bboxes
def setup_model(self, model):
def setup_model(self, model, verbose=True):
"""Set up YOLO model with specified thresholds and device."""
device = select_device(self.args.device)
if model is None:
@ -306,7 +308,7 @@ class Predictor(BasePredictor):
# (N, 1, H, W), (N, 1)
pred_masks, pred_scores = preds[:2]
pred_bboxes = preds[2] if self.segment_all else None
names = dict(enumerate([str(i) for i in range(len(pred_masks))]))
names = dict(enumerate(str(i) for i in range(len(pred_masks))))
results = []
for i, masks in enumerate([pred_masks]):
orig_img = orig_imgs[i] if isinstance(orig_imgs, list) else orig_imgs