From 1c753cbce6493654d5c40575d935babdcd463d11 Mon Sep 17 00:00:00 2001
From: DennisJ <106725464+DennisJcy@users.noreply.github.com>
Date: Sat, 12 Aug 2023 23:27:52 +0800
Subject: [PATCH] ORT_CPP add CUDA FP16 inference (#4320)

Co-authored-by: Glenn Jocher
---
 .../YOLOv8-ONNXRuntime-CPP/CMakeLists.txt     |  8 +++-
 examples/YOLOv8-ONNXRuntime-CPP/README.md     | 10 ++---
 examples/YOLOv8-ONNXRuntime-CPP/inference.cpp | 45 ++++++++++++++-----
 examples/YOLOv8-ONNXRuntime-CPP/inference.h   |  6 ++-
 examples/YOLOv8-ONNXRuntime-CPP/main.cpp      | 12 ++---
 5 files changed, 57 insertions(+), 24 deletions(-)

diff --git a/examples/YOLOv8-ONNXRuntime-CPP/CMakeLists.txt b/examples/YOLOv8-ONNXRuntime-CPP/CMakeLists.txt
index 97a9d19..41a4148 100644
--- a/examples/YOLOv8-ONNXRuntime-CPP/CMakeLists.txt
+++ b/examples/YOLOv8-ONNXRuntime-CPP/CMakeLists.txt
@@ -16,6 +16,10 @@
 find_package(OpenCV REQUIRED)
 include_directories(${OpenCV_INCLUDE_DIRS})
 
+# -------------- Compile CUDA for FP16 inference if needed ------------------#
+find_package(CUDA REQUIRED)
+include_directories(${CUDA_INCLUDE_DIRS})
+
 # ONNXRUNTIME
 
 
@@ -51,9 +55,9 @@ set(PROJECT_SOURCES
 add_executable(${PROJECT_NAME} ${PROJECT_SOURCES})
 
 if(WIN32)
-    target_link_libraries(${PROJECT_NAME} ${OpenCV_LIBS} ${ONNXRUNTIME_ROOT}/lib/onnxruntime.lib)
+    target_link_libraries(${PROJECT_NAME} ${OpenCV_LIBS} ${ONNXRUNTIME_ROOT}/lib/onnxruntime.lib ${CUDA_LIBRARIES})
 elseif(LINUX)
-    target_link_libraries(${PROJECT_NAME} ${OpenCV_LIBS} ${ONNXRUNTIME_ROOT}/lib/libonnxruntime.so)
+    target_link_libraries(${PROJECT_NAME} ${OpenCV_LIBS} ${ONNXRUNTIME_ROOT}/lib/libonnxruntime.so ${CUDA_LIBRARIES})
 elseif(APPLE)
     target_link_libraries(${PROJECT_NAME} ${OpenCV_LIBS} ${ONNXRUNTIME_ROOT}/lib/libonnxruntime.dylib)
 endif()

diff --git a/examples/YOLOv8-ONNXRuntime-CPP/README.md b/examples/YOLOv8-ONNXRuntime-CPP/README.md
index b5e02e0..87326a4 100644
--- a/examples/YOLOv8-ONNXRuntime-CPP/README.md
+++ b/examples/YOLOv8-ONNXRuntime-CPP/README.md
@@ -6,8 +6,7 @@ This example demonstrates how to perform inference using YOLOv8 in C++ with ONNX
 
 - Friendly for deployment in the industrial sector.
 - Faster than OpenCV's DNN inference on both CPU and GPU.
-- Supports CUDA acceleration.
-- Easy to add FP16 inference (using template functions).
+- Supports FP32 and FP16 CUDA acceleration.
 
 ## Exporting YOLOv8 Models
 
@@ -47,13 +46,12 @@ Note: The dependency on C++17 is due to the usage of the C++17 filesystem featur
 DCSP_INIT_PARAM params{ model_path, YOLO_ORIGIN_V8, {imgsz_w, imgsz_h}, 0.1, 0.5, false};
 // GPU inference
 DCSP_INIT_PARAM params{ model_path, YOLO_ORIGIN_V8, {imgsz_w, imgsz_h}, 0.1, 0.5, true};
-
 // Load your image
 cv::Mat img = cv::imread(img_path);
 
+// Init Inference Session
+char* ret = yoloDetector->CreateSession(params);
-char* ret = p1->CreateSession(params);
-
-ret = p->RunSession(img, res);
+ret = yoloDetector->RunSession(img, res);
 ```
 
 This repository should also work for YOLOv5, which needs a permute operator for the output of the YOLOv5 model, but this has not been implemented yet.
diff --git a/examples/YOLOv8-ONNXRuntime-CPP/inference.cpp b/examples/YOLOv8-ONNXRuntime-CPP/inference.cpp
index 953fa70..7e67cd5 100644
--- a/examples/YOLOv8-ONNXRuntime-CPP/inference.cpp
+++ b/examples/YOLOv8-ONNXRuntime-CPP/inference.cpp
@@ -15,6 +15,13 @@ DCSP_CORE::~DCSP_CORE()
 }
 
+namespace Ort
+{
+    template<>
+    struct TypeToTensorType<half> { static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16; };
+}
+
+
 
 template<typename T>
 char* BlobFromImage(cv::Mat& iImg, T& iBlob)
 {
@@ -56,7 +63,7 @@ char* DCSP_CORE::CreateSession(DCSP_INIT_PARAM &iParams)
     bool result = std::regex_search(iParams.ModelPath, pattern);
     if (result)
     {
-        Ret = "[DCSP_ONNX]:model path error.change your model path without chinese characters.";
+        Ret = "[DCSP_ONNX]:Model path error. Change your model path to one without Chinese characters.";
         std::cout << Ret << std::endl;
         return Ret;
     }
@@ -109,9 +116,7 @@ char* DCSP_CORE::CreateSession(DCSP_INIT_PARAM &iParams)
         }
         options = Ort::RunOptions{ nullptr };
         WarmUpSession();
-        //std::cout << OrtGetApiBase()->GetVersionString() << std::endl;;
-        Ret = RET_OK;
-        return Ret;
+        return RET_OK;
     }
     catch (const std::exception& e)
     {
@@ -122,7 +127,6 @@ char* DCSP_CORE::CreateSession(DCSP_INIT_PARAM &iParams)
         std::strcpy(merged, result.c_str());
         std::cout << merged << std::endl;
         delete[] merged;
-        //return merged;
         return "[DCSP_ONNX]:Create session failed.";
     }
 
@@ -145,6 +149,13 @@ char* DCSP_CORE::RunSession(cv::Mat &iImg, std::vector<DCSP_RESULT>& oResult)
         std::vector<int64_t> inputNodeDims = { 1,3,imgSize.at(0),imgSize.at(1) };
         TensorProcess(starttime_1, iImg, blob, inputNodeDims, oResult);
     }
+    else
+    {
+        half* blob = new half[processedImg.total() * 3];
+        BlobFromImage(processedImg, blob);
+        std::vector<int64_t> inputNodeDims = { 1,3,imgSize.at(0),imgSize.at(1) };
+        TensorProcess(starttime_1, iImg, blob, inputNodeDims, oResult);
+    }
 
     return Ret;
 }
@@ -169,7 +180,8 @@ char* DCSP_CORE::TensorProcess(clock_t& starttime_1, cv::Mat& iImg, N& blob, std
     delete blob;
     switch (modelType)
     {
-    case 1:
+    case 1://V8_ORIGIN_FP32
+    case 4://V8_ORIGIN_FP16
     {
         int strideNum = outputNodeDims[2];
         int signalResultNum = outputNodeDims[1];
@@ -243,15 +255,13 @@ char* DCSP_CORE::TensorProcess(clock_t& starttime_1, cv::Mat& iImg, N& blob, std
         break;
     }
     }
-    char* Ret = RET_OK;
-    return Ret;
+    return RET_OK;
 }
 
 
 char* DCSP_CORE::WarmUpSession()
 {
     clock_t starttime_1 = clock();
-    char* Ret = RET_OK;
     cv::Mat iImg = cv::Mat(cv::Size(imgSize.at(0), imgSize.at(1)), CV_8UC3);
     cv::Mat processedImg;
     PostProcess(iImg, imgSize, processedImg);
@@ -270,5 +280,20 @@ char* DCSP_CORE::WarmUpSession()
             std::cout << "[DCSP_ONNX(CUDA)]: " << "Cuda warm-up cost " << post_process_time << " ms. " << std::endl;
         }
     }
-    return Ret;
+    else
+    {
+        half* blob = new half[iImg.total() * 3];
+        BlobFromImage(processedImg, blob);
+        std::vector<int64_t> YOLO_input_node_dims = { 1,3,imgSize.at(0),imgSize.at(1) };
+        Ort::Value input_tensor = Ort::Value::CreateTensor(Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU), blob, 3 * imgSize.at(0) * imgSize.at(1), YOLO_input_node_dims.data(), YOLO_input_node_dims.size());
+        auto output_tensors = session->Run(options, inputNodeNames.data(), &input_tensor, 1, outputNodeNames.data(), outputNodeNames.size());
+        delete[] blob;
+        clock_t starttime_4 = clock();
+        double post_process_time = (double)(starttime_4 - starttime_1) / CLOCKS_PER_SEC * 1000;
+        if (cudaEnable)
+        {
+            std::cout << "[DCSP_ONNX(CUDA)]: " << "Cuda warm-up cost " << post_process_time << " ms. " << std::endl;
+        }
+    }
+    return RET_OK;
 }
diff --git a/examples/YOLOv8-ONNXRuntime-CPP/inference.h b/examples/YOLOv8-ONNXRuntime-CPP/inference.h
index b30f9f0..a1db199 100644
--- a/examples/YOLOv8-ONNXRuntime-CPP/inference.h
+++ b/examples/YOLOv8-ONNXRuntime-CPP/inference.h
@@ -13,6 +13,7 @@
 #include <cstdio>
 #include <opencv2/opencv.hpp>
 #include "onnxruntime_cxx_api.h"
+#include <cuda_fp16.h>
 
 
 enum MODEL_TYPE
@@ -21,7 +22,10 @@
     YOLO_ORIGIN_V5 = 0,
     YOLO_ORIGIN_V8 = 1,//only support v8 detector currently
     YOLO_POSE_V8 = 2,
-    YOLO_CLS_V8 = 3
+    YOLO_CLS_V8 = 3,
+    YOLO_ORIGIN_V8_HALF = 4,
+    YOLO_POSE_V8_HALF = 5,
+    YOLO_CLS_V8_HALF = 6
 };
 
 
diff --git a/examples/YOLOv8-ONNXRuntime-CPP/main.cpp b/examples/YOLOv8-ONNXRuntime-CPP/main.cpp
index f4ba03e..c2839fd 100644
--- a/examples/YOLOv8-ONNXRuntime-CPP/main.cpp
+++ b/examples/YOLOv8-ONNXRuntime-CPP/main.cpp
@@ -82,13 +82,15 @@ int read_coco_yaml(DCSP_CORE*& p)
 
 int main()
 {
-    DCSP_CORE* p1 = new DCSP_CORE;
+    DCSP_CORE* yoloDetector = new DCSP_CORE;
     std::string model_path = "yolov8n.onnx";
-    read_coco_yaml(p1);
-    // GPU inference
+    read_coco_yaml(yoloDetector);
+    // GPU FP32 inference
     DCSP_INIT_PARAM params{ model_path, YOLO_ORIGIN_V8, {640, 640}, 0.1, 0.5, true };
+    // GPU FP16 inference
+    // DCSP_INIT_PARAM params{ model_path, YOLO_ORIGIN_V8_HALF, {640, 640}, 0.1, 0.5, true };
     // CPU inference
     // DCSP_INIT_PARAM params{ model_path, YOLO_ORIGIN_V8, {640, 640}, 0.1, 0.5, false };
-    p1->CreateSession(params);
-    file_iterator(p1);
+    yoloDetector->CreateSession(params);
+    file_iterator(yoloDetector);
 }