From 58bccb1a9feba5e2d21b0c91cccee91beb6bafc8 Mon Sep 17 00:00:00 2001 From: Glenn Jocher Date: Sun, 18 Jun 2023 22:47:33 +0200 Subject: [PATCH] Add CLI support for SAM, RTDETR (#3253) --- ultralytics/vit/sam/model.py | 1 + ultralytics/yolo/cfg/__init__.py | 11 +++++++++-- 2 files changed, 10 insertions(+), 2 deletions(-) diff --git a/ultralytics/vit/sam/model.py b/ultralytics/vit/sam/model.py index 420d6a6..83861f4 100644 --- a/ultralytics/vit/sam/model.py +++ b/ultralytics/vit/sam/model.py @@ -17,6 +17,7 @@ class SAM: # Should raise AssertionError instead? raise NotImplementedError('Segment anything prediction requires pre-trained checkpoint') self.model = build_sam(model) + self.task = 'segment' # required self.predictor = None # reuse predictor def predict(self, source, stream=False, **kwargs): diff --git a/ultralytics/yolo/cfg/__init__.py b/ultralytics/yolo/cfg/__init__.py index cd919f7..746e458 100644 --- a/ultralytics/yolo/cfg/__init__.py +++ b/ultralytics/yolo/cfg/__init__.py @@ -366,9 +366,16 @@ def entrypoint(debug=''): if model is None: model = 'yolov8n.pt' LOGGER.warning(f"WARNING ⚠️ 'model' is missing. Using default 'model={model}'.") - from ultralytics.yolo.engine.model import YOLO overrides['model'] = model - model = YOLO(model, task=task) + if 'rtdetr' in model.lower(): # guess architecture + from ultralytics import RTDETR + model = RTDETR(model) # no task argument + elif 'sam' in model.lower(): + from ultralytics import SAM + model = SAM(model) + else: + from ultralytics import YOLO + model = YOLO(model, task=task) if isinstance(overrides.get('pretrained'), str): model.load(overrides['pretrained'])