From 91905b4b0b7b48f3ff0bf7b4d433c15a9450142c Mon Sep 17 00:00:00 2001 From: Glenn Jocher Date: Wed, 5 Jul 2023 18:25:27 +0200 Subject: [PATCH] Pass export=True to RTDETRDecoder (#3550) Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- ultralytics/yolo/engine/exporter.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/ultralytics/yolo/engine/exporter.py b/ultralytics/yolo/engine/exporter.py index dac2fcf..7fa6cf4 100644 --- a/ultralytics/yolo/engine/exporter.py +++ b/ultralytics/yolo/engine/exporter.py @@ -59,7 +59,7 @@ from pathlib import Path import torch from ultralytics.nn.autobackend import check_class_names -from ultralytics.nn.modules import C2f, Detect, Segment +from ultralytics.nn.modules import C2f, Detect, RTDETRDecoder from ultralytics.nn.tasks import DetectionModel, SegmentationModel from ultralytics.yolo.cfg import get_cfg from ultralytics.yolo.utils import (DEFAULT_CFG, LINUX, LOGGER, MACOS, __version__, callbacks, colorstr, @@ -157,13 +157,13 @@ class Exporter: # Load PyTorch model self.device = select_device('cpu' if self.args.device is None else self.args.device) + + # Checks + model.names = check_class_names(model.names) if self.args.half and onnx and self.device.type == 'cpu': LOGGER.warning('WARNING ⚠️ half=True only compatible with GPU export, i.e. use device=0') self.args.half = False assert not self.args.dynamic, 'half=True not compatible with dynamic=True, i.e. use only one.' - - # Checks - model.names = check_class_names(model.names) self.imgsz = check_imgsz(self.args.imgsz, stride=model.stride, min_dim=2) # check image size if self.args.optimize: assert self.device.type == 'cpu', '--optimize not compatible with cuda devices, i.e. use --device cpu' @@ -185,7 +185,7 @@ class Exporter: model.float() model = model.fuse() for k, m in model.named_modules(): - if isinstance(m, (Detect, Segment)): + if isinstance(m, (Detect, RTDETRDecoder)): # Segment and Pose use Detect base class m.dynamic = self.args.dynamic m.export = True m.format = self.args.format