`ultralytics 8.0.144` fix SAM `predict()` results (#4027)

Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
single_channel
Laughing 2 years ago committed by GitHub
parent b3ddd9d09c
commit dbdea24955
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -1,6 +1,6 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license # Ultralytics YOLO 🚀, AGPL-3.0 license
__version__ = '8.0.143' __version__ = '8.0.144'
from ultralytics.hub import start from ultralytics.hub import start
from ultralytics.models import RTDETR, SAM, YOLO from ultralytics.models import RTDETR, SAM, YOLO

@ -29,7 +29,7 @@ class SAM(Model):
overrides = dict(conf=0.25, task='segment', mode='predict', imgsz=1024) overrides = dict(conf=0.25, task='segment', mode='predict', imgsz=1024)
kwargs.update(overrides) kwargs.update(overrides)
prompts = dict(bboxes=bboxes, points=points, labels=labels) prompts = dict(bboxes=bboxes, points=points, labels=labels)
super().predict(source, stream, prompts=prompts, **kwargs) return super().predict(source, stream, prompts=prompts, **kwargs)
def __call__(self, source=None, stream=False, bboxes=None, points=None, labels=None, **kwargs): def __call__(self, source=None, stream=False, bboxes=None, points=None, labels=None, **kwargs):
"""Calls the 'predict' function with given arguments to perform object detection.""" """Calls the 'predict' function with given arguments to perform object detection."""

@ -294,7 +294,7 @@ class Predictor(BasePredictor):
def setup_model(self, model, verbose=True): def setup_model(self, model, verbose=True):
"""Set up YOLO model with specified thresholds and device.""" """Set up YOLO model with specified thresholds and device."""
device = select_device(self.args.device) device = select_device(self.args.device, verbose=verbose)
if model is None: if model is None:
model = build_sam(self.args.model) model = build_sam(self.args.model)
model.eval() model.eval()

Loading…
Cancel
Save