Avoid CUDA round-trip for relevant export formats (#3727)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
@ -18,7 +18,9 @@ from .build import build_sam
|
||||
|
||||
class Predictor(BasePredictor):
|
||||
|
||||
def __init__(self, cfg=DEFAULT_CFG, overrides={}, _callbacks=None):
|
||||
def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
|
||||
if overrides is None:
|
||||
overrides = {}
|
||||
overrides.update(dict(task='segment', mode='predict', imgsz=1024))
|
||||
super().__init__(cfg, overrides, _callbacks)
|
||||
# SAM needs retina_masks=True, or the results would be a mess.
|
||||
@ -90,7 +92,7 @@ class Predictor(BasePredictor):
|
||||
of masks and H=W=256. These low resolution logits can be passed to
|
||||
a subsequent iteration as mask input.
|
||||
"""
|
||||
if all([i is None for i in [bboxes, points, masks]]):
|
||||
if all(i is None for i in [bboxes, points, masks]):
|
||||
return self.generate(im, *args, **kwargs)
|
||||
return self.prompt_inference(im, bboxes, points, labels, masks, multimask_output)
|
||||
|
||||
@ -284,7 +286,7 @@ class Predictor(BasePredictor):
|
||||
|
||||
return pred_masks, pred_scores, pred_bboxes
|
||||
|
||||
def setup_model(self, model):
|
||||
def setup_model(self, model, verbose=True):
|
||||
"""Set up YOLO model with specified thresholds and device."""
|
||||
device = select_device(self.args.device)
|
||||
if model is None:
|
||||
@ -306,7 +308,7 @@ class Predictor(BasePredictor):
|
||||
# (N, 1, H, W), (N, 1)
|
||||
pred_masks, pred_scores = preds[:2]
|
||||
pred_bboxes = preds[2] if self.segment_all else None
|
||||
names = dict(enumerate([str(i) for i in range(len(pred_masks))]))
|
||||
names = dict(enumerate(str(i) for i in range(len(pred_masks))))
|
||||
results = []
|
||||
for i, masks in enumerate([pred_masks]):
|
||||
orig_img = orig_imgs[i] if isinstance(orig_imgs, list) else orig_imgs
|
||||
|
Reference in New Issue
Block a user