|
|
|
@ -18,7 +18,9 @@ from .build import build_sam
|
|
|
|
|
|
|
|
|
|
class Predictor(BasePredictor):
|
|
|
|
|
|
|
|
|
|
def __init__(self, cfg=DEFAULT_CFG, overrides={}, _callbacks=None):
|
|
|
|
|
def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
|
|
|
|
|
if overrides is None:
|
|
|
|
|
overrides = {}
|
|
|
|
|
overrides.update(dict(task='segment', mode='predict', imgsz=1024))
|
|
|
|
|
super().__init__(cfg, overrides, _callbacks)
|
|
|
|
|
# SAM needs retina_masks=True, or the results would be a mess.
|
|
|
|
@ -90,7 +92,7 @@ class Predictor(BasePredictor):
|
|
|
|
|
of masks and H=W=256. These low resolution logits can be passed to
|
|
|
|
|
a subsequent iteration as mask input.
|
|
|
|
|
"""
|
|
|
|
|
if all([i is None for i in [bboxes, points, masks]]):
|
|
|
|
|
if all(i is None for i in [bboxes, points, masks]):
|
|
|
|
|
return self.generate(im, *args, **kwargs)
|
|
|
|
|
return self.prompt_inference(im, bboxes, points, labels, masks, multimask_output)
|
|
|
|
|
|
|
|
|
@ -284,7 +286,7 @@ class Predictor(BasePredictor):
|
|
|
|
|
|
|
|
|
|
return pred_masks, pred_scores, pred_bboxes
|
|
|
|
|
|
|
|
|
|
def setup_model(self, model):
|
|
|
|
|
def setup_model(self, model, verbose=True):
|
|
|
|
|
"""Set up YOLO model with specified thresholds and device."""
|
|
|
|
|
device = select_device(self.args.device)
|
|
|
|
|
if model is None:
|
|
|
|
@ -306,7 +308,7 @@ class Predictor(BasePredictor):
|
|
|
|
|
# (N, 1, H, W), (N, 1)
|
|
|
|
|
pred_masks, pred_scores = preds[:2]
|
|
|
|
|
pred_bboxes = preds[2] if self.segment_all else None
|
|
|
|
|
names = dict(enumerate([str(i) for i in range(len(pred_masks))]))
|
|
|
|
|
names = dict(enumerate(str(i) for i in range(len(pred_masks))))
|
|
|
|
|
results = []
|
|
|
|
|
for i, masks in enumerate([pred_masks]):
|
|
|
|
|
orig_img = orig_imgs[i] if isinstance(orig_imgs, list) else orig_imgs
|
|
|
|
|