Add CLI support for SAM, RTDETR (#3253)

single_channel
Glenn Jocher 1 year ago committed by GitHub
parent 4c2033d7c3
commit 58bccb1a9f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -17,6 +17,7 @@ class SAM:
# Should raise AssertionError instead? # Should raise AssertionError instead?
raise NotImplementedError('Segment anything prediction requires pre-trained checkpoint') raise NotImplementedError('Segment anything prediction requires pre-trained checkpoint')
self.model = build_sam(model) self.model = build_sam(model)
self.task = 'segment' # required
self.predictor = None # reuse predictor self.predictor = None # reuse predictor
def predict(self, source, stream=False, **kwargs): def predict(self, source, stream=False, **kwargs):

@ -366,9 +366,16 @@ def entrypoint(debug=''):
if model is None: if model is None:
model = 'yolov8n.pt' model = 'yolov8n.pt'
LOGGER.warning(f"WARNING ⚠️ 'model' is missing. Using default 'model={model}'.") LOGGER.warning(f"WARNING ⚠️ 'model' is missing. Using default 'model={model}'.")
from ultralytics.yolo.engine.model import YOLO
overrides['model'] = model overrides['model'] = model
model = YOLO(model, task=task) if 'rtdetr' in model.lower(): # guess architecture
from ultralytics import RTDETR
model = RTDETR(model) # no task argument
elif 'sam' in model.lower():
from ultralytics import SAM
model = SAM(model)
else:
from ultralytics import YOLO
model = YOLO(model, task=task)
if isinstance(overrides.get('pretrained'), str): if isinstance(overrides.get('pretrained'), str):
model.load(overrides['pretrained']) model.load(overrides['pretrained'])

Loading…
Cancel
Save