ultralytics 8.0.120
CLI support for SAM, RTDETR (#3273)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
@ -5,6 +5,8 @@ RT-DETR model interface
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
import torch.nn as nn
|
||||
|
||||
from ultralytics.nn.tasks import RTDETRDetectionModel, attempt_load_one_weight, yaml_model_load
|
||||
from ultralytics.yolo.cfg import get_cfg
|
||||
from ultralytics.yolo.engine.exporter import Exporter
|
||||
@ -37,7 +39,7 @@ class RTDETR:
|
||||
self.task = 'detect'
|
||||
self.model = RTDETRDetectionModel(cfg_dict, verbose=verbose) # build model
|
||||
|
||||
# Below added to allow export from yamls
|
||||
# Below added to allow export from YAMLs
|
||||
self.model.args = DEFAULT_CFG_DICT # attach args to model
|
||||
self.model.task = self.task
|
||||
|
||||
@ -125,6 +127,23 @@ class RTDETR:
|
||||
"""Get model info"""
|
||||
return model_info(self.model, verbose=verbose)
|
||||
|
||||
def _check_is_pytorch_model(self):
|
||||
"""
|
||||
Raises TypeError is model is not a PyTorch model
|
||||
"""
|
||||
pt_str = isinstance(self.model, (str, Path)) and Path(self.model).suffix == '.pt'
|
||||
pt_module = isinstance(self.model, nn.Module)
|
||||
if not (pt_module or pt_str):
|
||||
raise TypeError(f"model='{self.model}' must be a *.pt PyTorch model, but is a different type. "
|
||||
f'PyTorch models can be used to train, val, predict and export, i.e. '
|
||||
f"'yolo export model=yolov8n.pt', but exported formats like ONNX, TensorRT etc. only "
|
||||
f"support 'predict' and 'val' modes, i.e. 'yolo predict model=yolov8n.onnx'.")
|
||||
|
||||
def fuse(self):
|
||||
"""Fuse PyTorch Conv2d and BatchNorm2d layers."""
|
||||
self._check_is_pytorch_model()
|
||||
self.model.fuse()
|
||||
|
||||
@smart_inference_mode()
|
||||
def export(self, **kwargs):
|
||||
"""
|
||||
|
Reference in New Issue
Block a user