// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
//
// Summary: The Ort C++ API is a header only wrapper around the Ort C API.
//
// The C++ API simplifies usage by returning values directly instead of error codes, throwing exceptions on errors
// and automatically releasing resources in the destructors. The primary purpose of C++ API is exception safety so
// all the resources follow RAII and do not leak memory.
//
// Each of the C++ wrapper classes holds only a pointer to the C internal object. Treat them like smart pointers.
// To create an empty object, pass 'nullptr' to the constructor (for example, Env e{nullptr};). However, you can't use them
// until you assign an instance that actually holds an underlying object.
//
// For Ort objects only move construction and move assignment are allowed; there are no copy constructors.
// Some objects have explicit 'Clone' methods for this purpose.
//
// ConstXXXX types are copyable since they do not own the underlying C object, so you can pass them to functions as arguments
// by value or by reference. ConstXXXX types are restricted to const only interfaces.
//
// UnownedXXXX are similar to ConstXXXX but also allow non-const interfaces.
//
// The lifetime of the corresponding owning object must eclipse the lifetimes of the ConstXXXX/UnownedXXXX types. They exist so you do not
// have to fall back to C types and the C API with its usual pitfalls. In general, do not use the C API from your C++ code.
#pragma once #include "onnxruntime_c_api.h" #include "onnxruntime_float16.h" #include #include #include #include #include #include #include #include #include #include #ifdef ORT_NO_EXCEPTIONS #include #endif /** \brief All C++ Onnxruntime APIs are defined inside this namespace * */ namespace Ort { /** \brief All C++ methods that can fail will throw an exception of this type * * If ORT_NO_EXCEPTIONS is defined, then any error will result in a call to abort() */ struct Exception : std::exception { Exception(std::string&& string, OrtErrorCode code) : message_{std::move(string)}, code_{code} {} OrtErrorCode GetOrtErrorCode() const { return code_; } const char* what() const noexcept override { return message_.c_str(); } private: std::string message_; OrtErrorCode code_; }; #ifdef ORT_NO_EXCEPTIONS // The #ifndef is for the very special case where the user of this library wants to define their own way of handling errors. // NOTE: This header expects control flow to not continue after calling ORT_CXX_API_THROW #ifndef ORT_CXX_API_THROW #define ORT_CXX_API_THROW(string, code) \ do { \ std::cerr << Ort::Exception(string, code) \ .what() \ << std::endl; \ abort(); \ } while (false) #endif #else #define ORT_CXX_API_THROW(string, code) \ throw Ort::Exception(string, code) #endif // This is used internally by the C++ API. This class holds the global variable that points to the OrtApi, // it's in a template so that we can define a global variable in a header and make // it transparent to the users of the API. template struct Global { static const OrtApi* api_; }; // If macro ORT_API_MANUAL_INIT is defined, no static initialization will be performed. Instead, user must call InitApi() before using it. template #ifdef ORT_API_MANUAL_INIT const OrtApi* Global::api_{}; inline void InitApi() noexcept { Global::api_ = OrtGetApiBase()->GetApi(ORT_API_VERSION); } // Used by custom operator libraries that are not linked to onnxruntime. 
Sets the global API object, which is // required by C++ APIs. // // Example mycustomop.cc: // // #define ORT_API_MANUAL_INIT // #include // #undef ORT_API_MANUAL_INIT // // OrtStatus* ORT_API_CALL RegisterCustomOps(OrtSessionOptions* options, const OrtApiBase* api_base) { // Ort::InitApi(api_base->GetApi(ORT_API_VERSION)); // // ... // } // inline void InitApi(const OrtApi* api) noexcept { Global::api_ = api; } #else #if defined(_MSC_VER) && !defined(__clang__) #pragma warning(push) // "Global initializer calls a non-constexpr function." Therefore you can't use ORT APIs in the other global initializers. // Please define ORT_API_MANUAL_INIT if it conerns you. #pragma warning(disable : 26426) #endif const OrtApi* Global::api_ = OrtGetApiBase()->GetApi(ORT_API_VERSION); #if defined(_MSC_VER) && !defined(__clang__) #pragma warning(pop) #endif #endif /// This returns a reference to the OrtApi interface in use inline const OrtApi& GetApi() noexcept { return *Global::api_; } /// /// This function returns the onnxruntime version string /// /// version string major.minor.rev std::string GetVersionString(); /// /// This function returns the onnxruntime build information: including git branch, /// git commit id, build type(Debug/Release/RelWithDebInfo) and cmake cpp flags. /// /// string std::string GetBuildInfoString(); /// /// This is a C++ wrapper for OrtApi::GetAvailableProviders() and /// returns a vector of strings representing the available execution providers. /// /// vector of strings std::vector GetAvailableProviders(); /** \brief IEEE 754 half-precision floating point data type * * \details This struct is used for converting float to float16 and back * so the user could feed inputs and fetch outputs using these type. * * The size of the structure should align with uint16_t and one can freely cast * uint16_t buffers to/from Ort::Float16_t to feed and retrieve data. 
* * \code{.unparsed} * // This example demonstrates converion from float to float16 * constexpr float values[] = {1.f, 2.f, 3.f, 4.f, 5.f}; * std::vector fp16_values; * fp16_values.reserve(std::size(values)); * std::transform(std::begin(values), std::end(values), std::back_inserter(fp16_values), * [](float value) { return Ort::Float16_t(value); }); * * \endcode */ struct Float16_t : onnxruntime_float16::Float16Impl { private: /// /// Constructor from a 16-bit representation of a float16 value /// No conversion is done here. /// /// 16-bit representation constexpr explicit Float16_t(uint16_t v) noexcept { val = v; } public: using Base = onnxruntime_float16::Float16Impl; /// /// Default constructor /// Float16_t() = default; /// /// Explicit conversion to uint16_t representation of float16. /// /// uint16_t bit representation of float16 /// new instance of Float16_t constexpr static Float16_t FromBits(uint16_t v) noexcept { return Float16_t(v); } /// /// __ctor from float. Float is converted into float16 16-bit representation. /// /// float value explicit Float16_t(float v) noexcept { val = Base::ToUint16Impl(v); } /// /// Converts float16 to float /// /// float representation of float16 value float ToFloat() const noexcept { return Base::ToFloatImpl(); } /// /// Checks if the value is negative /// /// true if negative using Base::IsNegative; /// /// Tests if the value is NaN /// /// true if NaN using Base::IsNaN; /// /// Tests if the value is finite /// /// true if finite using Base::IsFinite; /// /// Tests if the value represents positive infinity. /// /// true if positive infinity using Base::IsPositiveInfinity; /// /// Tests if the value represents negative infinity /// /// true if negative infinity using Base::IsNegativeInfinity; /// /// Tests if the value is either positive or negative infinity. /// /// True if absolute value is infinity using Base::IsInfinity; /// /// Tests if the value is NaN or zero. Useful for comparisons. /// /// True if NaN or zero. 
using Base::IsNaNOrZero; /// /// Tests if the value is normal (not zero, subnormal, infinite, or NaN). /// /// True if so using Base::IsNormal; /// /// Tests if the value is subnormal (denormal). /// /// True if so using Base::IsSubnormal; /// /// Creates an instance that represents absolute value. /// /// Absolute value using Base::Abs; /// /// Creates a new instance with the sign flipped. /// /// Flipped sign instance using Base::Negate; /// /// IEEE defines that positive and negative zero are equal, this gives us a quick equality check /// for two values by or'ing the private bits together and stripping the sign. They are both zero, /// and therefore equivalent, if the resulting value is still zero. /// /// first value /// second value /// True if both arguments represent zero using Base::AreZero; /// /// User defined conversion operator. Converts Float16_t to float. /// explicit operator float() const noexcept { return ToFloat(); } using Base::operator==; using Base::operator!=; using Base::operator<; }; static_assert(sizeof(Float16_t) == sizeof(uint16_t), "Sizes must match"); /** \brief bfloat16 (Brain Floating Point) data type * * \details This struct is used for converting float to bfloat16 and back * so the user could feed inputs and fetch outputs using these type. * * The size of the structure should align with uint16_t and one can freely cast * uint16_t buffers to/from Ort::BFloat16_t to feed and retrieve data. 
* * \code{.unparsed} * // This example demonstrates converion from float to float16 * constexpr float values[] = {1.f, 2.f, 3.f, 4.f, 5.f}; * std::vector bfp16_values; * bfp16_values.reserve(std::size(values)); * std::transform(std::begin(values), std::end(values), std::back_inserter(bfp16_values), * [](float value) { return Ort::BFloat16_t(value); }); * * \endcode */ struct BFloat16_t : onnxruntime_float16::BFloat16Impl { private: /// /// Constructor from a uint16_t representation of bfloat16 /// used in FromBits() to escape overload resolution issue with /// constructor from float. /// No conversion is done. /// /// 16-bit bfloat16 value constexpr explicit BFloat16_t(uint16_t v) noexcept { val = v; } public: using Base = onnxruntime_float16::BFloat16Impl; BFloat16_t() = default; /// /// Explicit conversion to uint16_t representation of bfloat16. /// /// uint16_t bit representation of bfloat16 /// new instance of BFloat16_t static constexpr BFloat16_t FromBits(uint16_t v) noexcept { return BFloat16_t(v); } /// /// __ctor from float. Float is converted into bfloat16 16-bit representation. /// /// float value explicit BFloat16_t(float v) noexcept { val = Base::ToUint16Impl(v); } /// /// Converts bfloat16 to float /// /// float representation of bfloat16 value float ToFloat() const noexcept { return Base::ToFloatImpl(); } /// /// Checks if the value is negative /// /// true if negative using Base::IsNegative; /// /// Tests if the value is NaN /// /// true if NaN using Base::IsNaN; /// /// Tests if the value is finite /// /// true if finite using Base::IsFinite; /// /// Tests if the value represents positive infinity. /// /// true if positive infinity using Base::IsPositiveInfinity; /// /// Tests if the value represents negative infinity /// /// true if negative infinity using Base::IsNegativeInfinity; /// /// Tests if the value is either positive or negative infinity. 
/// /// True if absolute value is infinity using Base::IsInfinity; /// /// Tests if the value is NaN or zero. Useful for comparisons. /// /// True if NaN or zero. using Base::IsNaNOrZero; /// /// Tests if the value is normal (not zero, subnormal, infinite, or NaN). /// /// True if so using Base::IsNormal; /// /// Tests if the value is subnormal (denormal). /// /// True if so using Base::IsSubnormal; /// /// Creates an instance that represents absolute value. /// /// Absolute value using Base::Abs; /// /// Creates a new instance with the sign flipped. /// /// Flipped sign instance using Base::Negate; /// /// IEEE defines that positive and negative zero are equal, this gives us a quick equality check /// for two values by or'ing the private bits together and stripping the sign. They are both zero, /// and therefore equivalent, if the resulting value is still zero. /// /// first value /// second value /// True if both arguments represent zero using Base::AreZero; /// /// User defined conversion operator. Converts BFloat16_t to float. /// explicit operator float() const noexcept { return ToFloat(); } // We do not have an inherited impl for the below operators // as the internal class implements them a little differently bool operator==(const BFloat16_t& rhs) const noexcept; bool operator!=(const BFloat16_t& rhs) const noexcept { return !(*this == rhs); } bool operator<(const BFloat16_t& rhs) const noexcept; }; static_assert(sizeof(BFloat16_t) == sizeof(uint16_t), "Sizes must match"); /** \brief float8e4m3fn (Float8 Floating Point) data type * \details It is necessary for type dispatching to make use of C++ API * The type is implicitly convertible to/from uint8_t. * See https://onnx.ai/onnx/technical/float8.html for further details. 
*/ struct Float8E4M3FN_t { uint8_t value; constexpr Float8E4M3FN_t() noexcept : value(0) {} constexpr Float8E4M3FN_t(uint8_t v) noexcept : value(v) {} constexpr operator uint8_t() const noexcept { return value; } // nan values are treated like any other value for operator ==, != constexpr bool operator==(const Float8E4M3FN_t& rhs) const noexcept { return value == rhs.value; }; constexpr bool operator!=(const Float8E4M3FN_t& rhs) const noexcept { return value != rhs.value; }; }; static_assert(sizeof(Float8E4M3FN_t) == sizeof(uint8_t), "Sizes must match"); /** \brief float8e4m3fnuz (Float8 Floating Point) data type * \details It is necessary for type dispatching to make use of C++ API * The type is implicitly convertible to/from uint8_t. * See https://onnx.ai/onnx/technical/float8.html for further details. */ struct Float8E4M3FNUZ_t { uint8_t value; constexpr Float8E4M3FNUZ_t() noexcept : value(0) {} constexpr Float8E4M3FNUZ_t(uint8_t v) noexcept : value(v) {} constexpr operator uint8_t() const noexcept { return value; } // nan values are treated like any other value for operator ==, != constexpr bool operator==(const Float8E4M3FNUZ_t& rhs) const noexcept { return value == rhs.value; }; constexpr bool operator!=(const Float8E4M3FNUZ_t& rhs) const noexcept { return value != rhs.value; }; }; static_assert(sizeof(Float8E4M3FNUZ_t) == sizeof(uint8_t), "Sizes must match"); /** \brief float8e5m2 (Float8 Floating Point) data type * \details It is necessary for type dispatching to make use of C++ API * The type is implicitly convertible to/from uint8_t. * See https://onnx.ai/onnx/technical/float8.html for further details. 
*/ struct Float8E5M2_t { uint8_t value; constexpr Float8E5M2_t() noexcept : value(0) {} constexpr Float8E5M2_t(uint8_t v) noexcept : value(v) {} constexpr operator uint8_t() const noexcept { return value; } // nan values are treated like any other value for operator ==, != constexpr bool operator==(const Float8E5M2_t& rhs) const noexcept { return value == rhs.value; }; constexpr bool operator!=(const Float8E5M2_t& rhs) const noexcept { return value != rhs.value; }; }; static_assert(sizeof(Float8E5M2_t) == sizeof(uint8_t), "Sizes must match"); /** \brief float8e5m2fnuz (Float8 Floating Point) data type * \details It is necessary for type dispatching to make use of C++ API * The type is implicitly convertible to/from uint8_t. * See https://onnx.ai/onnx/technical/float8.html for further details. */ struct Float8E5M2FNUZ_t { uint8_t value; constexpr Float8E5M2FNUZ_t() noexcept : value(0) {} constexpr Float8E5M2FNUZ_t(uint8_t v) noexcept : value(v) {} constexpr operator uint8_t() const noexcept { return value; } // nan values are treated like any other value for operator ==, != constexpr bool operator==(const Float8E5M2FNUZ_t& rhs) const noexcept { return value == rhs.value; }; constexpr bool operator!=(const Float8E5M2FNUZ_t& rhs) const noexcept { return value != rhs.value; }; }; static_assert(sizeof(Float8E5M2FNUZ_t) == sizeof(uint8_t), "Sizes must match"); namespace detail { // This is used internally by the C++ API. This macro is to make it easy to generate overloaded methods for all of the various OrtRelease* functions for every Ort* type // This can't be done in the C API since C doesn't have function overloading. 
#define ORT_DEFINE_RELEASE(NAME) \ inline void OrtRelease(Ort##NAME* ptr) { GetApi().Release##NAME(ptr); } ORT_DEFINE_RELEASE(Allocator); ORT_DEFINE_RELEASE(MemoryInfo); ORT_DEFINE_RELEASE(CustomOpDomain); ORT_DEFINE_RELEASE(ThreadingOptions); ORT_DEFINE_RELEASE(Env); ORT_DEFINE_RELEASE(RunOptions); ORT_DEFINE_RELEASE(Session); ORT_DEFINE_RELEASE(SessionOptions); ORT_DEFINE_RELEASE(TensorTypeAndShapeInfo); ORT_DEFINE_RELEASE(SequenceTypeInfo); ORT_DEFINE_RELEASE(MapTypeInfo); ORT_DEFINE_RELEASE(TypeInfo); ORT_DEFINE_RELEASE(Value); ORT_DEFINE_RELEASE(ModelMetadata); ORT_DEFINE_RELEASE(IoBinding); ORT_DEFINE_RELEASE(ArenaCfg); ORT_DEFINE_RELEASE(Status); ORT_DEFINE_RELEASE(OpAttr); ORT_DEFINE_RELEASE(Op); ORT_DEFINE_RELEASE(KernelInfo); #undef ORT_DEFINE_RELEASE /** \brief This is a tagging template type. Use it with Base to indicate that the C++ interface object * has no ownership of the underlying C object. */ template struct Unowned { using Type = T; }; /** \brief Used internally by the C++ API. C++ wrapper types inherit from this. * This is a zero cost abstraction to wrap the C API objects and delete them on destruction. * * All of the C++ classes * a) serve as containers for pointers to objects that are created by the underlying C API. * Their size is just a pointer size, no need to dynamically allocate them. Use them by value. * b) Each of struct XXXX, XXX instances function as smart pointers to the underlying C API objects. * they would release objects owned automatically when going out of scope, they are move-only. * c) ConstXXXX and UnownedXXX structs function as non-owning, copyable containers for the above pointers. * ConstXXXX allow calling const interfaces only. They give access to objects that are owned by somebody else * such as Onnxruntime or instances of XXXX classes. * d) serve convenient interfaces that return C++ objects and further enhance exception and type safety so they can be used * in C++ code. 
* */ /// /// This is a non-const pointer holder that is move-only. Disposes of the pointer on destruction. /// template struct Base { using contained_type = T; constexpr Base() = default; constexpr explicit Base(contained_type* p) noexcept : p_{p} {} ~Base() { OrtRelease(p_); } Base(const Base&) = delete; Base& operator=(const Base&) = delete; Base(Base&& v) noexcept : p_{v.p_} { v.p_ = nullptr; } Base& operator=(Base&& v) noexcept { OrtRelease(p_); p_ = v.release(); return *this; } constexpr operator contained_type*() const noexcept { return p_; } /// \brief Relinquishes ownership of the contained C object pointer /// The underlying object is not destroyed contained_type* release() { T* p = p_; p_ = nullptr; return p; } protected: contained_type* p_{}; }; // Undefined. For const types use Base> template struct Base; /// /// Covers unowned pointers owned by either the ORT /// or some other instance of CPP wrappers. /// Used for ConstXXX and UnownedXXXX types that are copyable. /// Also convenient to wrap raw OrtXX pointers . 
/// /// template struct Base> { using contained_type = typename Unowned::Type; constexpr Base() = default; constexpr explicit Base(contained_type* p) noexcept : p_{p} {} ~Base() = default; Base(const Base&) = default; Base& operator=(const Base&) = default; Base(Base&& v) noexcept : p_{v.p_} { v.p_ = nullptr; } Base& operator=(Base&& v) noexcept { p_ = nullptr; std::swap(p_, v.p_); return *this; } constexpr operator contained_type*() const noexcept { return p_; } protected: contained_type* p_{}; }; // Light functor to release memory with OrtAllocator struct AllocatedFree { OrtAllocator* allocator_; explicit AllocatedFree(OrtAllocator* allocator) : allocator_(allocator) {} void operator()(void* ptr) const { if (ptr) allocator_->Free(allocator_, ptr); } }; } // namespace detail struct AllocatorWithDefaultOptions; struct Env; struct TypeInfo; struct Value; struct ModelMetadata; /** \brief unique_ptr typedef used to own strings allocated by OrtAllocators * and release them at the end of the scope. The lifespan of the given allocator * must eclipse the lifespan of AllocatedStringPtr instance */ using AllocatedStringPtr = std::unique_ptr; /** \brief The Status that holds ownership of OrtStatus received from C API * Use it to safely destroy OrtStatus* returned from the C API. Use appropriate * constructors to construct an instance of a Status object from exceptions. */ struct Status : detail::Base { explicit Status(std::nullptr_t) noexcept {} ///< Create an empty object, must be assigned a valid one to be used explicit Status(OrtStatus* status) noexcept; ///< Takes ownership of OrtStatus instance returned from the C API. explicit Status(const Exception&) noexcept; ///< Creates status instance out of exception explicit Status(const std::exception&) noexcept; ///< Creates status instance out of exception Status(const char* message, OrtErrorCode code) noexcept; ///< Creates status instance out of null-terminated string message. 
std::string GetErrorMessage() const; OrtErrorCode GetErrorCode() const; bool IsOK() const noexcept; ///< Returns true if instance represents an OK (non-error) status. }; /** \brief The ThreadingOptions * * The ThreadingOptions used for set global threadpools' options of The Env. */ struct ThreadingOptions : detail::Base { /// \brief Wraps OrtApi::CreateThreadingOptions ThreadingOptions(); /// \brief Wraps OrtApi::SetGlobalIntraOpNumThreads ThreadingOptions& SetGlobalIntraOpNumThreads(int intra_op_num_threads); /// \brief Wraps OrtApi::SetGlobalInterOpNumThreads ThreadingOptions& SetGlobalInterOpNumThreads(int inter_op_num_threads); /// \brief Wraps OrtApi::SetGlobalSpinControl ThreadingOptions& SetGlobalSpinControl(int allow_spinning); /// \brief Wraps OrtApi::SetGlobalDenormalAsZero ThreadingOptions& SetGlobalDenormalAsZero(); /// \brief Wraps OrtApi::SetGlobalCustomCreateThreadFn ThreadingOptions& SetGlobalCustomCreateThreadFn(OrtCustomCreateThreadFn ort_custom_create_thread_fn); /// \brief Wraps OrtApi::SetGlobalCustomThreadCreationOptions ThreadingOptions& SetGlobalCustomThreadCreationOptions(void* ort_custom_thread_creation_options); /// \brief Wraps OrtApi::SetGlobalCustomJoinThreadFn ThreadingOptions& SetGlobalCustomJoinThreadFn(OrtCustomJoinThreadFn ort_custom_join_thread_fn); }; /** \brief The Env (Environment) * * The Env holds the logging state used by all other objects. 
* Note: One Env must be created before using any other Onnxruntime functionality */ struct Env : detail::Base { explicit Env(std::nullptr_t) {} ///< Create an empty Env object, must be assigned a valid one to be used /// \brief Wraps OrtApi::CreateEnv Env(OrtLoggingLevel logging_level = ORT_LOGGING_LEVEL_WARNING, _In_ const char* logid = ""); /// \brief Wraps OrtApi::CreateEnvWithCustomLogger Env(OrtLoggingLevel logging_level, const char* logid, OrtLoggingFunction logging_function, void* logger_param); /// \brief Wraps OrtApi::CreateEnvWithGlobalThreadPools Env(const OrtThreadingOptions* tp_options, OrtLoggingLevel logging_level = ORT_LOGGING_LEVEL_WARNING, _In_ const char* logid = ""); /// \brief Wraps OrtApi::CreateEnvWithCustomLoggerAndGlobalThreadPools Env(const OrtThreadingOptions* tp_options, OrtLoggingFunction logging_function, void* logger_param, OrtLoggingLevel logging_level = ORT_LOGGING_LEVEL_WARNING, _In_ const char* logid = ""); /// \brief C Interop Helper explicit Env(OrtEnv* p) : Base{p} {} Env& EnableTelemetryEvents(); ///< Wraps OrtApi::EnableTelemetryEvents Env& DisableTelemetryEvents(); ///< Wraps OrtApi::DisableTelemetryEvents Env& UpdateEnvWithCustomLogLevel(OrtLoggingLevel log_severity_level); ///< Wraps OrtApi::UpdateEnvWithCustomLogLevel Env& CreateAndRegisterAllocator(const OrtMemoryInfo* mem_info, const OrtArenaCfg* arena_cfg); ///< Wraps OrtApi::CreateAndRegisterAllocator Env& CreateAndRegisterAllocatorV2(const std::string& provider_type, const OrtMemoryInfo* mem_info, const std::unordered_map& options, const OrtArenaCfg* arena_cfg); ///< Wraps OrtApi::CreateAndRegisterAllocatorV2 }; /** \brief Custom Op Domain * */ struct CustomOpDomain : detail::Base { explicit CustomOpDomain(std::nullptr_t) {} ///< Create an empty CustomOpDomain object, must be assigned a valid one to be used /// \brief Wraps OrtApi::CreateCustomOpDomain explicit CustomOpDomain(const char* domain); // This does not take ownership of the op, simply registers it. 
void Add(const OrtCustomOp* op); ///< Wraps CustomOpDomain_Add }; /** \brief RunOptions * */ struct RunOptions : detail::Base { explicit RunOptions(std::nullptr_t) {} ///< Create an empty RunOptions object, must be assigned a valid one to be used RunOptions(); ///< Wraps OrtApi::CreateRunOptions RunOptions& SetRunLogVerbosityLevel(int); ///< Wraps OrtApi::RunOptionsSetRunLogVerbosityLevel int GetRunLogVerbosityLevel() const; ///< Wraps OrtApi::RunOptionsGetRunLogVerbosityLevel RunOptions& SetRunLogSeverityLevel(int); ///< Wraps OrtApi::RunOptionsSetRunLogSeverityLevel int GetRunLogSeverityLevel() const; ///< Wraps OrtApi::RunOptionsGetRunLogSeverityLevel RunOptions& SetRunTag(const char* run_tag); ///< wraps OrtApi::RunOptionsSetRunTag const char* GetRunTag() const; ///< Wraps OrtApi::RunOptionsGetRunTag RunOptions& AddConfigEntry(const char* config_key, const char* config_value); ///< Wraps OrtApi::AddRunConfigEntry /** \brief Terminates all currently executing Session::Run calls that were made using this RunOptions instance * * If a currently executing session needs to be force terminated, this can be called from another thread to force it to fail with an error * Wraps OrtApi::RunOptionsSetTerminate */ RunOptions& SetTerminate(); /** \brief Clears the terminate flag so this RunOptions instance can be used in a new Session::Run call without it instantly terminating * * Wraps OrtApi::RunOptionsUnsetTerminate */ RunOptions& UnsetTerminate(); }; namespace detail { // Utility function that returns a SessionOption config entry key for a specific custom operator. // Ex: custom_op.[custom_op_name].[config] std::string MakeCustomOpConfigEntryKey(const char* custom_op_name, const char* config); } // namespace detail /// /// Class that represents session configuration entries for one or more custom operators. 
/// /// Example: /// Ort::CustomOpConfigs op_configs; /// op_configs.AddConfig("my_custom_op", "device_type", "CPU"); /// /// Passed to Ort::SessionOptions::RegisterCustomOpsLibrary. /// struct CustomOpConfigs { CustomOpConfigs() = default; ~CustomOpConfigs() = default; CustomOpConfigs(const CustomOpConfigs&) = default; CustomOpConfigs& operator=(const CustomOpConfigs&) = default; CustomOpConfigs(CustomOpConfigs&& o) = default; CustomOpConfigs& operator=(CustomOpConfigs&& o) = default; /** \brief Adds a session configuration entry/value for a specific custom operator. * * \param custom_op_name The name of the custom operator for which to add a configuration entry. * Must match the name returned by the CustomOp's GetName() method. * \param config_key The name of the configuration entry. * \param config_value The value of the configuration entry. * \return A reference to this object to enable call chaining. */ CustomOpConfigs& AddConfig(const char* custom_op_name, const char* config_key, const char* config_value); /** \brief Returns a flattened map of custom operator configuration entries and their values. * * The keys has been flattened to include both the custom operator name and the configuration entry key name. * For example, a prior call to AddConfig("my_op", "key", "value") corresponds to the flattened key/value pair * {"my_op.key", "value"}. * * \return An unordered map of flattened configurations. 
*/ const std::unordered_map& GetFlattenedConfigs() const; private: std::unordered_map flat_configs_; }; /** \brief Options object used when creating a new Session object * * Wraps ::OrtSessionOptions object and methods */ struct SessionOptions; namespace detail { // we separate const-only methods because passing const ptr to non-const methods // is only discovered when inline methods are compiled which is counter-intuitive template struct ConstSessionOptionsImpl : Base { using B = Base; using B::B; SessionOptions Clone() const; ///< Creates and returns a copy of this SessionOptions object. Wraps OrtApi::CloneSessionOptions std::string GetConfigEntry(const char* config_key) const; ///< Wraps OrtApi::GetSessionConfigEntry bool HasConfigEntry(const char* config_key) const; ///< Wraps OrtApi::HasSessionConfigEntry std::string GetConfigEntryOrDefault(const char* config_key, const std::string& def); }; template struct SessionOptionsImpl : ConstSessionOptionsImpl { using B = ConstSessionOptionsImpl; using B::B; SessionOptionsImpl& SetIntraOpNumThreads(int intra_op_num_threads); ///< Wraps OrtApi::SetIntraOpNumThreads SessionOptionsImpl& SetInterOpNumThreads(int inter_op_num_threads); ///< Wraps OrtApi::SetInterOpNumThreads SessionOptionsImpl& SetGraphOptimizationLevel(GraphOptimizationLevel graph_optimization_level); ///< Wraps OrtApi::SetSessionGraphOptimizationLevel SessionOptionsImpl& SetDeterministicCompute(bool value); ///< Wraps OrtApi::SetDeterministicCompute SessionOptionsImpl& EnableCpuMemArena(); ///< Wraps OrtApi::EnableCpuMemArena SessionOptionsImpl& DisableCpuMemArena(); ///< Wraps OrtApi::DisableCpuMemArena SessionOptionsImpl& SetOptimizedModelFilePath(const ORTCHAR_T* optimized_model_file); ///< Wraps OrtApi::SetOptimizedModelFilePath SessionOptionsImpl& EnableProfiling(const ORTCHAR_T* profile_file_prefix); ///< Wraps OrtApi::EnableProfiling SessionOptionsImpl& DisableProfiling(); ///< Wraps OrtApi::DisableProfiling SessionOptionsImpl& 
EnableOrtCustomOps(); ///< Wraps OrtApi::EnableOrtCustomOps SessionOptionsImpl& EnableMemPattern(); ///< Wraps OrtApi::EnableMemPattern SessionOptionsImpl& DisableMemPattern(); ///< Wraps OrtApi::DisableMemPattern SessionOptionsImpl& SetExecutionMode(ExecutionMode execution_mode); ///< Wraps OrtApi::SetSessionExecutionMode SessionOptionsImpl& SetLogId(const char* logid); ///< Wraps OrtApi::SetSessionLogId SessionOptionsImpl& SetLogSeverityLevel(int level); ///< Wraps OrtApi::SetSessionLogSeverityLevel SessionOptionsImpl& Add(OrtCustomOpDomain* custom_op_domain); ///< Wraps OrtApi::AddCustomOpDomain SessionOptionsImpl& DisablePerSessionThreads(); ///< Wraps OrtApi::DisablePerSessionThreads SessionOptionsImpl& AddConfigEntry(const char* config_key, const char* config_value); ///< Wraps OrtApi::AddSessionConfigEntry SessionOptionsImpl& AddInitializer(const char* name, const OrtValue* ort_val); ///< Wraps OrtApi::AddInitializer SessionOptionsImpl& AddExternalInitializers(const std::vector& names, const std::vector& ort_values); ///< Wraps OrtApi::AddExternalInitializers SessionOptionsImpl& AddExternalInitializersFromFilesInMemory(const std::vector>& external_initializer_file_names, const std::vector& external_initializer_file_buffer_array, const std::vector& external_initializer_file_lengths); ///< Wraps OrtApi::AddExternalInitializersFromFilesInMemory SessionOptionsImpl& AppendExecutionProvider_CUDA(const OrtCUDAProviderOptions& provider_options); ///< Wraps OrtApi::SessionOptionsAppendExecutionProvider_CUDA SessionOptionsImpl& AppendExecutionProvider_CUDA_V2(const OrtCUDAProviderOptionsV2& provider_options); ///< Wraps OrtApi::SessionOptionsAppendExecutionProvider_CUDA_V2 SessionOptionsImpl& AppendExecutionProvider_ROCM(const OrtROCMProviderOptions& provider_options); ///< Wraps OrtApi::SessionOptionsAppendExecutionProvider_ROCM SessionOptionsImpl& AppendExecutionProvider_OpenVINO(const OrtOpenVINOProviderOptions& provider_options); ///< Wraps 
OrtApi::SessionOptionsAppendExecutionProvider_OpenVINO ///< Wraps OrtApi::SessionOptionsAppendExecutionProvider_OpenVINO_V2 SessionOptionsImpl& AppendExecutionProvider_OpenVINO_V2(const std::unordered_map& provider_options = {}); SessionOptionsImpl& AppendExecutionProvider_TensorRT(const OrtTensorRTProviderOptions& provider_options); ///< Wraps OrtApi::SessionOptionsAppendExecutionProvider_TensorRT SessionOptionsImpl& AppendExecutionProvider_TensorRT_V2(const OrtTensorRTProviderOptionsV2& provider_options); ///< Wraps OrtApi::SessionOptionsAppendExecutionProvider_TensorRT SessionOptionsImpl& AppendExecutionProvider_MIGraphX(const OrtMIGraphXProviderOptions& provider_options); ///< Wraps OrtApi::SessionOptionsAppendExecutionProvider_MIGraphX ///< Wraps OrtApi::SessionOptionsAppendExecutionProvider_CANN SessionOptionsImpl& AppendExecutionProvider_CANN(const OrtCANNProviderOptions& provider_options); ///< Wraps OrtApi::SessionOptionsAppendExecutionProvider_Dnnl SessionOptionsImpl& AppendExecutionProvider_Dnnl(const OrtDnnlProviderOptions& provider_options); /// Wraps OrtApi::SessionOptionsAppendExecutionProvider. Currently supports QNN, SNPE and XNNPACK. SessionOptionsImpl& AppendExecutionProvider(const std::string& provider_name, const std::unordered_map& provider_options = {}); SessionOptionsImpl& SetCustomCreateThreadFn(OrtCustomCreateThreadFn ort_custom_create_thread_fn); ///< Wraps OrtApi::SessionOptionsSetCustomCreateThreadFn SessionOptionsImpl& SetCustomThreadCreationOptions(void* ort_custom_thread_creation_options); ///< Wraps OrtApi::SessionOptionsSetCustomThreadCreationOptions SessionOptionsImpl& SetCustomJoinThreadFn(OrtCustomJoinThreadFn ort_custom_join_thread_fn); ///< Wraps OrtApi::SessionOptionsSetCustomJoinThreadFn ///< Registers the custom operator from the specified shared library via OrtApi::RegisterCustomOpsLibrary_V2. ///< The custom operator configurations are optional. 
If provided, custom operator configs are set via ///< OrtApi::AddSessionConfigEntry. SessionOptionsImpl& RegisterCustomOpsLibrary(const ORTCHAR_T* library_name, const CustomOpConfigs& custom_op_configs = {}); SessionOptionsImpl& RegisterCustomOpsUsingFunction(const char* function_name); ///< Wraps OrtApi::RegisterCustomOpsUsingFunction ///< Wraps OrtApi::SessionOptionsAppendExecutionProvider_VitisAI SessionOptionsImpl& AppendExecutionProvider_VitisAI(const std::unordered_map& provider_options = {}); }; } // namespace detail using UnownedSessionOptions = detail::SessionOptionsImpl>; using ConstSessionOptions = detail::ConstSessionOptionsImpl>; /** \brief Wrapper around ::OrtSessionOptions * */ struct SessionOptions : detail::SessionOptionsImpl { explicit SessionOptions(std::nullptr_t) {} ///< Create an empty SessionOptions object, must be assigned a valid one to be used SessionOptions(); ///< Wraps OrtApi::CreateSessionOptions explicit SessionOptions(OrtSessionOptions* p) : SessionOptionsImpl{p} {} ///< Used for interop with the C API UnownedSessionOptions GetUnowned() const { return UnownedSessionOptions{this->p_}; } ConstSessionOptions GetConst() const { return ConstSessionOptions{this->p_}; } }; /** \brief Wrapper around ::OrtModelMetadata * */ struct ModelMetadata : detail::Base { explicit ModelMetadata(std::nullptr_t) {} ///< Create an empty ModelMetadata object, must be assigned a valid one to be used explicit ModelMetadata(OrtModelMetadata* p) : Base{p} {} ///< Used for interop with the C API /** \brief Returns a copy of the producer name. * * \param allocator to allocate memory for the copy of the name returned * \return a instance of smart pointer that would deallocate the buffer when out of scope. * The OrtAllocator instances must be valid at the point of memory release. */ AllocatedStringPtr GetProducerNameAllocated(OrtAllocator* allocator) const; ///< Wraps OrtApi::ModelMetadataGetProducerName /** \brief Returns a copy of the graph name. 
* * \param allocator to allocate memory for the copy of the name returned * \return a instance of smart pointer that would deallocate the buffer when out of scope. * The OrtAllocator instances must be valid at the point of memory release. */ AllocatedStringPtr GetGraphNameAllocated(OrtAllocator* allocator) const; ///< Wraps OrtApi::ModelMetadataGetGraphName /** \brief Returns a copy of the domain name. * * \param allocator to allocate memory for the copy of the name returned * \return a instance of smart pointer that would deallocate the buffer when out of scope. * The OrtAllocator instances must be valid at the point of memory release. */ AllocatedStringPtr GetDomainAllocated(OrtAllocator* allocator) const; ///< Wraps OrtApi::ModelMetadataGetDomain /** \brief Returns a copy of the description. * * \param allocator to allocate memory for the copy of the string returned * \return a instance of smart pointer that would deallocate the buffer when out of scope. * The OrtAllocator instances must be valid at the point of memory release. */ AllocatedStringPtr GetDescriptionAllocated(OrtAllocator* allocator) const; ///< Wraps OrtApi::ModelMetadataGetDescription /** \brief Returns a copy of the graph description. * * \param allocator to allocate memory for the copy of the string returned * \return a instance of smart pointer that would deallocate the buffer when out of scope. * The OrtAllocator instances must be valid at the point of memory release. */ AllocatedStringPtr GetGraphDescriptionAllocated(OrtAllocator* allocator) const; ///< Wraps OrtApi::ModelMetadataGetGraphDescription /** \brief Returns a vector of copies of the custom metadata keys. * * \param allocator to allocate memory for the copy of the string returned * \return a instance std::vector of smart pointers that would deallocate the buffers when out of scope. * The OrtAllocator instance must be valid at the point of memory release. 
*/ std::vector GetCustomMetadataMapKeysAllocated(OrtAllocator* allocator) const; ///< Wraps OrtApi::ModelMetadataGetCustomMetadataMapKeys /** \brief Looks up a value by a key in the Custom Metadata map * * \param key zero-terminated string key to look up * \param allocator to allocate memory for the copy of the string returned * \return an instance of a smart pointer that would deallocate the buffer when out of scope. * May be nullptr if the key is not found. * * The OrtAllocator instances must be valid at the point of memory release. */ AllocatedStringPtr LookupCustomMetadataMapAllocated(const char* key, OrtAllocator* allocator) const; ///< Wraps OrtApi::ModelMetadataLookupCustomMetadataMap int64_t GetVersion() const; ///< Wraps OrtApi::ModelMetadataGetVersion }; struct IoBinding; namespace detail { // we separate const-only methods because passing const ptr to non-const methods // is only discovered when inline methods are compiled which is counter-intuitive template struct ConstSessionImpl : Base { using B = Base; using B::B; size_t GetInputCount() const; ///< Returns the number of model inputs size_t GetOutputCount() const; ///< Returns the number of model outputs size_t GetOverridableInitializerCount() const; ///< Returns the number of inputs that have defaults that can be overridden /** \brief Returns a copy of input name at the specified index. * * \param index must be less than the value returned by GetInputCount() * \param allocator to allocate memory for the copy of the name returned * \return an instance of a smart pointer that would deallocate the buffer when out of scope. * The OrtAllocator instances must be valid at the point of memory release. */ AllocatedStringPtr GetInputNameAllocated(size_t index, OrtAllocator* allocator) const; /** \brief Returns a copy of output name at the specified index.
* * \param index must be less than the value returned by GetOutputCount() * \param allocator to allocate memory for the copy of the name returned * \return an instance of a smart pointer that would deallocate the buffer when out of scope. * The OrtAllocator instances must be valid at the point of memory release. */ AllocatedStringPtr GetOutputNameAllocated(size_t index, OrtAllocator* allocator) const; /** \brief Returns a copy of the overridable initializer name at the specified index. * * \param index must be less than the value returned by GetOverridableInitializerCount() * \param allocator to allocate memory for the copy of the name returned * \return an instance of a smart pointer that would deallocate the buffer when out of scope. * The OrtAllocator instances must be valid at the point of memory release. */ AllocatedStringPtr GetOverridableInitializerNameAllocated(size_t index, OrtAllocator* allocator) const; ///< Wraps OrtApi::SessionGetOverridableInitializerName uint64_t GetProfilingStartTimeNs() const; ///< Wraps OrtApi::SessionGetProfilingStartTimeNs ModelMetadata GetModelMetadata() const; ///< Wraps OrtApi::SessionGetModelMetadata TypeInfo GetInputTypeInfo(size_t index) const; ///< Wraps OrtApi::SessionGetInputTypeInfo TypeInfo GetOutputTypeInfo(size_t index) const; ///< Wraps OrtApi::SessionGetOutputTypeInfo TypeInfo GetOverridableInitializerTypeInfo(size_t index) const; ///< Wraps OrtApi::SessionGetOverridableInitializerTypeInfo }; template struct SessionImpl : ConstSessionImpl { using B = ConstSessionImpl; using B::B; /** \brief Run the model returning results in an Ort allocated vector. * * Wraps OrtApi::Run * * The caller provides a list of inputs and a list of the desired outputs to return. * * See the output logs for more information on warnings/errors that occur while processing the model. * Common errors are..
(TODO) * * \param[in] run_options * \param[in] input_names Array of null terminated strings of length input_count that is the list of input names * \param[in] input_values Array of Value objects of length input_count that is the list of input values * \param[in] input_count Number of inputs (the size of the input_names & input_values arrays) * \param[in] output_names Array of C style strings of length output_count that is the list of output names * \param[in] output_count Number of outputs (the size of the output_names array) * \return A std::vector of Value objects that directly maps to the output_names array (eg. output_name[0] is the first entry of the returned vector) */ std::vector Run(const RunOptions& run_options, const char* const* input_names, const Value* input_values, size_t input_count, const char* const* output_names, size_t output_count); /** \brief Run the model returning results in user provided outputs * Same as Run(const RunOptions&, const char* const*, const Value*, size_t,const char* const*, size_t) */ void Run(const RunOptions& run_options, const char* const* input_names, const Value* input_values, size_t input_count, const char* const* output_names, Value* output_values, size_t output_count); void Run(const RunOptions& run_options, const IoBinding&); ///< Wraps OrtApi::RunWithBinding /** \brief Run the model asynchronously in a thread owned by intra op thread pool * * Wraps OrtApi::RunAsync * * \param[in] run_options * \param[in] input_names Array of null terminated UTF8 encoded strings of the input names * \param[in] input_values Array of Value objects of length input_count * \param[in] input_count Number of elements in the input_names and inputs arrays * \param[in] output_names Array of null terminated UTF8 encoded strings of the output names * \param[out] output_values Array of provided Values to be filled with outputs. * On calling RunAsync, output_values[i] could either be initialized by a null pointer or a preallocated OrtValue*. 
* Later, on invoking the callback, each output_values[i] that is null will be filled with an OrtValue* allocated by onnxruntime. * Then, an OrtValue** pointer will be cast from output_values and passed to the callback. * NOTE: it is the caller's responsibility to release output_values and each of its members, * regardless of whether the member (Ort::Value) is allocated by onnxruntime or preallocated by the caller. * \param[in] output_count Number of elements in the output_names and output_values arrays * \param[in] callback Callback function on model run completion * \param[in] user_data User data that is passed back to the callback */ void RunAsync(const RunOptions& run_options, const char* const* input_names, const Value* input_values, size_t input_count, const char* const* output_names, Value* output_values, size_t output_count, RunAsyncCallbackFn callback, void* user_data); /** \brief End profiling and return a copy of the profiling file name. * * \param allocator to allocate memory for the copy of the string returned * \return an instance of a smart pointer that would deallocate the buffer when out of scope. * The OrtAllocator instances must be valid at the point of memory release.
*/ AllocatedStringPtr EndProfilingAllocated(OrtAllocator* allocator); ///< Wraps OrtApi::SessionEndProfiling }; } // namespace detail using ConstSession = detail::ConstSessionImpl>; using UnownedSession = detail::SessionImpl>; /** \brief Wrapper around ::OrtSession * */ struct Session : detail::SessionImpl { explicit Session(std::nullptr_t) {} ///< Create an empty Session object, must be assigned a valid one to be used Session(const Env& env, const ORTCHAR_T* model_path, const SessionOptions& options); ///< Wraps OrtApi::CreateSession Session(const Env& env, const ORTCHAR_T* model_path, const SessionOptions& options, OrtPrepackedWeightsContainer* prepacked_weights_container); ///< Wraps OrtApi::CreateSessionWithPrepackedWeightsContainer Session(const Env& env, const void* model_data, size_t model_data_length, const SessionOptions& options); ///< Wraps OrtApi::CreateSessionFromArray Session(const Env& env, const void* model_data, size_t model_data_length, const SessionOptions& options, OrtPrepackedWeightsContainer* prepacked_weights_container); ///< Wraps OrtApi::CreateSessionFromArrayWithPrepackedWeightsContainer ConstSession GetConst() const { return ConstSession{this->p_}; } UnownedSession GetUnowned() const { return UnownedSession{this->p_}; } }; namespace detail { template struct MemoryInfoImpl : Base { using B = Base; using B::B; std::string GetAllocatorName() const; OrtAllocatorType GetAllocatorType() const; int GetDeviceId() const; OrtMemoryInfoDeviceType GetDeviceType() const; OrtMemType GetMemoryType() const; template bool operator==(const MemoryInfoImpl& o) const; }; } // namespace detail // Const object holder that does not own the underlying object using ConstMemoryInfo = detail::MemoryInfoImpl>; /** \brief Wrapper around ::OrtMemoryInfo * */ struct MemoryInfo : detail::MemoryInfoImpl { static MemoryInfo CreateCpu(OrtAllocatorType type, OrtMemType mem_type1); explicit MemoryInfo(std::nullptr_t) {} ///< No instance is created explicit 
MemoryInfo(OrtMemoryInfo* p) : MemoryInfoImpl{p} {} ///< Take ownership of a pointer created by C Api MemoryInfo(const char* name, OrtAllocatorType type, int id, OrtMemType mem_type); ConstMemoryInfo GetConst() const { return ConstMemoryInfo{this->p_}; } }; namespace detail { template struct TensorTypeAndShapeInfoImpl : Base { using B = Base; using B::B; ONNXTensorElementDataType GetElementType() const; ///< Wraps OrtApi::GetTensorElementType size_t GetElementCount() const; ///< Wraps OrtApi::GetTensorShapeElementCount size_t GetDimensionsCount() const; ///< Wraps OrtApi::GetDimensionsCount /** \deprecated use GetShape() returning std::vector * [[deprecated]] * This interface is unsafe to use */ [[deprecated("use GetShape()")]] void GetDimensions(int64_t* values, size_t values_count) const; ///< Wraps OrtApi::GetDimensions void GetSymbolicDimensions(const char** values, size_t values_count) const; ///< Wraps OrtApi::GetSymbolicDimensions std::vector GetShape() const; ///< Uses GetDimensionsCount & GetDimensions to return a std::vector of the shape }; } // namespace detail using ConstTensorTypeAndShapeInfo = detail::TensorTypeAndShapeInfoImpl>; /** \brief Wrapper around ::OrtTensorTypeAndShapeInfo * */ struct TensorTypeAndShapeInfo : detail::TensorTypeAndShapeInfoImpl { explicit TensorTypeAndShapeInfo(std::nullptr_t) {} ///< Create an empty TensorTypeAndShapeInfo object, must be assigned a valid one to be used explicit TensorTypeAndShapeInfo(OrtTensorTypeAndShapeInfo* p) : TensorTypeAndShapeInfoImpl{p} {} ///< Used for interop with the C API ConstTensorTypeAndShapeInfo GetConst() const { return ConstTensorTypeAndShapeInfo{this->p_}; } }; namespace detail { template struct SequenceTypeInfoImpl : Base { using B = Base; using B::B; TypeInfo GetSequenceElementType() const; ///< Wraps OrtApi::GetSequenceElementType }; } // namespace detail using ConstSequenceTypeInfo = detail::SequenceTypeInfoImpl>; /** \brief Wrapper around ::OrtSequenceTypeInfo * */ struct 
SequenceTypeInfo : detail::SequenceTypeInfoImpl { explicit SequenceTypeInfo(std::nullptr_t) {} ///< Create an empty SequenceTypeInfo object, must be assigned a valid one to be used explicit SequenceTypeInfo(OrtSequenceTypeInfo* p) : SequenceTypeInfoImpl{p} {} ///< Used for interop with the C API ConstSequenceTypeInfo GetConst() const { return ConstSequenceTypeInfo{this->p_}; } }; namespace detail { template struct OptionalTypeInfoImpl : Base { using B = Base; using B::B; TypeInfo GetOptionalElementType() const; ///< Wraps OrtApi::CastOptionalTypeToContainedTypeInfo }; } // namespace detail // This is always owned by the TypeInfo and can only be obtained from it. using ConstOptionalTypeInfo = detail::OptionalTypeInfoImpl>; namespace detail { template struct MapTypeInfoImpl : detail::Base { using B = Base; using B::B; ONNXTensorElementDataType GetMapKeyType() const; ///< Wraps OrtApi::GetMapKeyType TypeInfo GetMapValueType() const; ///< Wraps OrtApi::GetMapValueType }; } // namespace detail using ConstMapTypeInfo = detail::MapTypeInfoImpl>; /** \brief Wrapper around ::OrtMapTypeInfo * */ struct MapTypeInfo : detail::MapTypeInfoImpl { explicit MapTypeInfo(std::nullptr_t) {} ///< Create an empty MapTypeInfo object, must be assigned a valid one to be used explicit MapTypeInfo(OrtMapTypeInfo* p) : MapTypeInfoImpl{p} {} ///< Used for interop with the C API ConstMapTypeInfo GetConst() const { return ConstMapTypeInfo{this->p_}; } }; namespace detail { template struct TypeInfoImpl : detail::Base { using B = Base; using B::B; ConstTensorTypeAndShapeInfo GetTensorTypeAndShapeInfo() const; ///< Wraps OrtApi::CastTypeInfoToTensorInfo ConstSequenceTypeInfo GetSequenceTypeInfo() const; ///< Wraps OrtApi::CastTypeInfoToSequenceTypeInfo ConstMapTypeInfo GetMapTypeInfo() const; ///< Wraps OrtApi::CastTypeInfoToMapTypeInfo ConstOptionalTypeInfo GetOptionalTypeInfo() const; ///< wraps OrtApi::CastTypeInfoToOptionalTypeInfo ONNXType GetONNXType() const; }; } // namespace detail /// /// 
Contains a constant, unowned OrtTypeInfo that can be copied and passed around by value. /// Provides access to const OrtTypeInfo APIs. /// using ConstTypeInfo = detail::TypeInfoImpl>; /// /// Type information that may contain either TensorTypeAndShapeInfo or /// the information about contained sequence or map depending on the ONNXType. /// struct TypeInfo : detail::TypeInfoImpl { explicit TypeInfo(std::nullptr_t) {} ///< Create an empty TypeInfo object, must be assigned a valid one to be used explicit TypeInfo(OrtTypeInfo* p) : TypeInfoImpl{p} {} ///< C API Interop ConstTypeInfo GetConst() const { return ConstTypeInfo{this->p_}; } }; namespace detail { // This structure is used to feed sparse tensor values // information for use with FillSparseTensor() API // if the data type for the sparse tensor values is numeric // use data.p_data, otherwise, use data.str pointer to feed // values. data.str is an array of const char* that are zero terminated. // number of strings in the array must match shape size. // For fully sparse tensors use shape {0} and set p_data/str // to nullptr. 
struct OrtSparseValuesParam { const int64_t* values_shape; size_t values_shape_len; union { const void* p_data; const char** str; } data; }; // Provides a way to pass shape in a single // argument struct Shape { const int64_t* shape; size_t shape_len; }; template struct ConstValueImpl : Base { using B = Base; using B::B; /// /// Obtains a pointer to user-defined data for experimental purposes /// template void GetOpaqueData(const char* domain, const char* type_name, R&) const; ///< Wraps OrtApi::GetOpaqueValue bool IsTensor() const; ///< Returns true if Value is a tensor, false for other types like map/sequence/etc bool HasValue() const; ///< Returns true if the OrtValue contains data, false if the OrtValue is a None size_t GetCount() const; ///< For a non-tensor, returns 2 for a map and N for a sequence, where N is the number of elements Value GetValue(int index, OrtAllocator* allocator) const; /// /// This API returns the full length of string data contained within either a tensor or a sparse tensor. /// For a sparse tensor it returns the full length of stored non-empty strings (values). The API is useful /// for allocating necessary memory and calling GetStringTensorContent(). /// /// total length of UTF-8 encoded bytes contained. No zero terminators counted. size_t GetStringTensorDataLength() const; /// /// The API copies all of the UTF-8 encoded string data contained within a tensor or a sparse tensor /// into a supplied buffer. Use GetStringTensorDataLength() to find out the length of the buffer to allocate. /// The user must also allocate an offsets buffer with the number of entries equal to that of the contained /// strings. /// /// Strings are always assumed to be on CPU, no X-device copy. /// /// user allocated buffer /// length in bytes of the allocated buffer /// a pointer to the offsets user allocated buffer /// count of offsets, must be equal to the number of strings contained.
/// that can be obtained from the shape of the tensor or from GetSparseTensorValuesTypeAndShapeInfo() /// for sparse tensors void GetStringTensorContent(void* buffer, size_t buffer_length, size_t* offsets, size_t offsets_count) const; /// /// Returns a const typed pointer to the tensor contained data. /// No type checking is performed, the caller must ensure the type matches the tensor type. /// /// /// const pointer to data, no copies made template const R* GetTensorData() const; ///< Wraps OrtApi::GetTensorMutableData /// /// /// Returns a non-typed pointer to a tensor contained data. /// /// const pointer to data, no copies made const void* GetTensorRawData() const; /// /// The API returns type information for data contained in a tensor. For sparse /// tensors it returns type information for contained non-zero values. /// It returns dense shape for sparse tensors. /// /// TypeInfo TypeInfo GetTypeInfo() const; /// /// The API returns type information for data contained in a tensor. For sparse /// tensors it returns type information for contained non-zero values. /// It returns dense shape for sparse tensors. /// /// TensorTypeAndShapeInfo TensorTypeAndShapeInfo GetTensorTypeAndShapeInfo() const; /// /// This API returns information about the memory allocation used to hold data. /// /// Non owning instance of MemoryInfo ConstMemoryInfo GetTensorMemoryInfo() const; /// /// The API copies UTF-8 encoded bytes for the requested string element /// contained within a tensor or a sparse tensor into a provided buffer. /// Use GetStringTensorElementLength() to obtain the length of the buffer to allocate. /// /// /// /// void GetStringTensorElement(size_t buffer_length, size_t element_index, void* buffer) const; /// /// Returns string tensor UTF-8 encoded string element. /// Use of this API is recommended over GetStringTensorElement() that takes void* buffer pointer. 
/// /// /// std::string std::string GetStringTensorElement(size_t element_index) const; /// /// The API returns the byte length of a UTF-8 encoded string element /// contained in either a tensor or a sparse tensor's values. /// /// /// byte length for the specified string element size_t GetStringTensorElementLength(size_t element_index) const; #if !defined(DISABLE_SPARSE_TENSORS) /// /// The API returns the sparse data format this OrtValue holds in a sparse tensor. /// If the sparse tensor was not fully constructed, i.e. Use*() or Fill*() APIs were not used, /// the value returned is ORT_SPARSE_UNDEFINED. /// /// Format enum OrtSparseFormat GetSparseFormat() const; /// /// The API returns type and shape information for stored non-zero values of the /// sparse tensor. Use GetSparseTensorValues() to obtain the values buffer pointer. /// /// TensorTypeAndShapeInfo values information TensorTypeAndShapeInfo GetSparseTensorValuesTypeAndShapeInfo() const; /// /// The API returns type and shape information for the specified indices. Each supported /// kind of indices has its own enum value, even if a given format has more than one kind of indices. /// Use GetSparseTensorIndicesData() to obtain a pointer to the indices buffer. /// /// enum requested /// type and shape information TensorTypeAndShapeInfo GetSparseTensorIndicesTypeShapeInfo(OrtSparseIndicesFormat format) const; /// /// The API retrieves a pointer to the internal indices buffer. The API merely performs /// a convenience data type casting on the return type pointer. Make sure you are requesting /// the right type; use GetSparseTensorIndicesTypeShapeInfo(). /// /// type to cast to /// requested indices kind /// number of indices entries /// Pointer to the internal sparse tensor buffer containing indices. Do not free this pointer.
template const R* GetSparseTensorIndicesData(OrtSparseIndicesFormat indices_format, size_t& num_indices) const; /// /// Returns true if the OrtValue contains a sparse tensor /// /// bool IsSparseTensor() const; /// /// The API returns a pointer to an internal buffer of the sparse tensor /// containing non-zero values. The API merely does casting. Make sure you /// are requesting the right data type by calling GetSparseTensorValuesTypeAndShapeInfo() /// first. /// /// numeric data types only. Use GetStringTensor*() to retrieve strings. /// a pointer to the internal values buffer. Do not free this pointer. template const R* GetSparseTensorValues() const; #endif }; template struct ValueImpl : ConstValueImpl { using B = ConstValueImpl; using B::B; /// /// Returns a non-const typed pointer to an OrtValue/Tensor contained buffer. /// No type checking is performed, the caller must ensure the type matches the tensor type. /// /// non-const pointer to data, no copies made template R* GetTensorMutableData(); /// /// Returns a non-typed non-const pointer to the data contained in a tensor. /// /// pointer to data, no copies made void* GetTensorMutableRawData(); /// /// Obtain a reference to an element of data at the location specified /// by the vector of dims. /// /// /// [in] expressed by a vector of dimension offsets /// template R& At(const std::vector& location); /// /// Set all strings at once in a string tensor /// /// [in] An array of strings. Each string in this array must be null terminated.
/// [in] Count of strings in s (Must match the size of this tensor's shape) void FillStringTensor(const char* const* s, size_t s_len); /// /// Set a single string in a string tensor /// /// [in] A null terminated UTF-8 encoded string /// [in] Index of the string in the tensor to set void FillStringTensorElement(const char* s, size_t index); /// /// Allocate if necessary and obtain a pointer to a UTF-8 /// encoded string element buffer indexed by the flat element index, /// of the specified length. /// /// This API is for advanced usage. It avoids the need to construct /// an auxiliary array of string pointers, and allows writing data directly /// (do not zero terminate). /// /// /// /// a pointer to a writable buffer char* GetResizedStringTensorElementBuffer(size_t index, size_t buffer_length); #if !defined(DISABLE_SPARSE_TENSORS) /// /// Supplies COO format specific indices and marks the contained sparse tensor as being a COO format tensor. /// Values are supplied with a CreateSparseTensor() API. The supplied indices are not copied and the user /// allocated buffer's lifespan must eclipse that of the OrtValue. /// The location of the indices is assumed to be the same as specified by OrtMemoryInfo argument at the creation time. /// /// pointer to the user allocated buffer with indices. Use nullptr for fully sparse tensors. /// number of indices entries. Use 0 for fully sparse tensors void UseCooIndices(int64_t* indices_data, size_t indices_num); /// /// Supplies CSR format specific indices and marks the contained sparse tensor as being a CSR format tensor. /// Values are supplied with a CreateSparseTensor() API. The supplied indices are not copied and the user /// allocated buffer's lifespan must eclipse that of the OrtValue. /// The location of the indices is assumed to be the same as specified by OrtMemoryInfo argument at the creation time.
/// /// pointer to the user allocated buffer with inner indices or nullptr for fully sparse tensors /// number of csr inner indices or 0 for fully sparse tensors /// pointer to the user allocated buffer with outer indices or nullptr for fully sparse tensors /// number of csr outer indices or 0 for fully sparse tensors void UseCsrIndices(int64_t* inner_data, size_t inner_num, int64_t* outer_data, size_t outer_num); /// /// Supplies BlockSparse format specific indices and marks the contained sparse tensor as being a BlockSparse format tensor. /// Values are supplied with a CreateSparseTensor() API. The supplied indices are not copied and the user /// allocated buffer's lifespan must eclipse that of the OrtValue. /// The location of the indices is assumed to be the same as specified by OrtMemoryInfo argument at the creation time. /// /// indices shape or {0} for fully sparse tensors /// user allocated buffer with indices or nullptr for fully sparse tensors void UseBlockSparseIndices(const Shape& indices_shape, int32_t* indices_data); /// /// The API will allocate memory using the allocator instance supplied to the CreateSparseTensor() API /// and copy the values and COO indices into it. If data_mem_info specifies that the data is located /// on a different device than the allocator, an X-device copy will be performed if possible. /// /// specified buffer memory description /// values buffer information. /// coo indices buffer or nullptr for fully sparse data /// number of COO indices or 0 for fully sparse data void FillSparseTensorCoo(const OrtMemoryInfo* data_mem_info, const OrtSparseValuesParam& values_param, const int64_t* indices_data, size_t indices_num); /// /// The API will allocate memory using the allocator instance supplied to the CreateSparseTensor() API /// and copy the values and CSR indices into it. If data_mem_info specifies that the data is located /// on a different device than the allocator, an X-device copy will be performed if possible.
/// /// specified buffer memory description /// values buffer information /// csr inner indices pointer or nullptr for fully sparse tensors /// number of csr inner indices or 0 for fully sparse tensors /// pointer to csr indices data or nullptr for fully sparse tensors /// number of csr outer indices or 0 void FillSparseTensorCsr(const OrtMemoryInfo* data_mem_info, const OrtSparseValuesParam& values, const int64_t* inner_indices_data, size_t inner_indices_num, const int64_t* outer_indices_data, size_t outer_indices_num); /// /// The API will allocate memory using the allocator instance supplied to the CreateSparseTensor() API /// and copy the values and BlockSparse indices into it. If data_mem_info specifies that the data is located /// at difference device than the allocator, a X-device copy will be performed if possible. /// /// specified buffer memory description /// values buffer information /// indices shape. use {0} for fully sparse tensors /// pointer to indices data or nullptr for fully sparse tensors void FillSparseTensorBlockSparse(const OrtMemoryInfo* data_mem_info, const OrtSparseValuesParam& values, const Shape& indices_shape, const int32_t* indices_data); #endif }; } // namespace detail using ConstValue = detail::ConstValueImpl>; using UnownedValue = detail::ValueImpl>; /** \brief Wrapper around ::OrtValue * */ struct Value : detail::ValueImpl { using Base = detail::ValueImpl; using OrtSparseValuesParam = detail::OrtSparseValuesParam; using Shape = detail::Shape; explicit Value(std::nullptr_t) {} ///< Create an empty Value object, must be assigned a valid one to be used explicit Value(OrtValue* p) : Base{p} {} ///< Used for interop with the C API Value(Value&&) = default; Value& operator=(Value&&) = default; ConstValue GetConst() const { return ConstValue{this->p_}; } UnownedValue GetUnowned() const { return UnownedValue{this->p_}; } /** \brief Creates a tensor with a user supplied buffer. Wraps OrtApi::CreateTensorWithDataAsOrtValue. 
* \tparam T The numeric datatype. This API is not suitable for strings. * \param info Memory description of where the p_data buffer resides (CPU vs GPU etc). * \param p_data Pointer to the data buffer. * \param p_data_element_count The number of elements in the data buffer. * \param shape Pointer to the tensor shape dimensions. * \param shape_len The number of tensor shape dimensions. */ template static Value CreateTensor(const OrtMemoryInfo* info, T* p_data, size_t p_data_element_count, const int64_t* shape, size_t shape_len); /** \brief Creates a tensor with a user supplied buffer. Wraps OrtApi::CreateTensorWithDataAsOrtValue. * * \param info Memory description of where the p_data buffer resides (CPU vs GPU etc). * \param p_data Pointer to the data buffer. * \param p_data_byte_count The number of bytes in the data buffer. * \param shape Pointer to the tensor shape dimensions. * \param shape_len The number of tensor shape dimensions. * \param type The data type. */ static Value CreateTensor(const OrtMemoryInfo* info, void* p_data, size_t p_data_byte_count, const int64_t* shape, size_t shape_len, ONNXTensorElementDataType type); /** \brief Creates an OrtValue with a tensor using a supplied OrtAllocator. Wraps OrtApi::CreateTensorAsOrtValue. * This overload will allocate the buffer for the tensor according to the supplied shape and data type. * The allocated buffer will be owned by the returned OrtValue and will be freed when the OrtValue is released. * The input data would need to be copied into the allocated buffer. * This API is not suitable for strings. * * \tparam T The numeric datatype. This API is not suitable for strings. * \param allocator The allocator to use. * \param shape Pointer to the tensor shape dimensions. * \param shape_len The number of tensor shape dimensions. */ template static Value CreateTensor(OrtAllocator* allocator, const int64_t* shape, size_t shape_len); /** \brief Creates an OrtValue with a tensor using the supplied OrtAllocator. 
* Wraps OrtApi::CreateTensorAsOrtValue. * The allocated buffer will be owned by the returned OrtValue and will be freed when the OrtValue is released. * The input data would need to be copied into the allocated buffer. * This API is not suitable for strings. * * \param allocator The allocator to use. * \param shape Pointer to the tensor shape dimensions. * \param shape_len The number of tensor shape dimensions. * \param type The data type. */ static Value CreateTensor(OrtAllocator* allocator, const int64_t* shape, size_t shape_len, ONNXTensorElementDataType type); /** \brief Creates an OrtValue with a Map Onnx type representation. * The API would ref-count the supplied OrtValues and they will be released * when the returned OrtValue is released. The caller may release keys and values after the call * returns. * * \param keys an OrtValue containing a tensor with primitive data type keys. * \param values an OrtValue that may contain a tensor. Ort currently supports only primitive data type values. */ static Value CreateMap(const Value& keys, const Value& values); ///< Wraps OrtApi::CreateValue /** \brief Creates an OrtValue with a Sequence Onnx type representation. * The API would ref-count the supplied OrtValues and they will be released * when the returned OrtValue is released. The caller may release the values after the call * returns. * * \param values a vector of OrtValues that must have the same Onnx value type. */ static Value CreateSequence(const std::vector& values); ///< Wraps OrtApi::CreateValue /** \brief Creates an OrtValue wrapping an Opaque type. * This is used for experimental support of non-tensor types. * * \tparam T - the type of the value. * \param domain - zero terminated utf-8 string. Domain of the type. * \param type_name - zero terminated utf-8 string. Name of the type. * \param value - the value to be wrapped. 
*/ template static Value CreateOpaque(const char* domain, const char* type_name, const T& value); ///< Wraps OrtApi::CreateOpaqueValue #if !defined(DISABLE_SPARSE_TENSORS) /// /// This is a simple forwarding method to the other overload that helps deducing /// data type enum value from the type of the buffer. /// /// numeric datatype. This API is not suitable for strings. /// Memory description where the user buffers reside (CPU vs GPU etc) /// pointer to the user supplied buffer, use nullptr for fully sparse tensors /// a would be dense shape of the tensor /// non zero values shape. Use a single 0 shape for fully sparse tensors. /// template static Value CreateSparseTensor(const OrtMemoryInfo* info, T* p_data, const Shape& dense_shape, const Shape& values_shape); /// /// Creates an OrtValue instance containing SparseTensor. This constructs /// a sparse tensor that makes use of user allocated buffers. It does not make copies /// of the user provided data and does not modify it. The lifespan of user provided buffers should /// eclipse the life span of the resulting OrtValue. This call constructs an instance that only contain /// a pointer to non-zero values. To fully populate the sparse tensor call UseIndices() API below /// to supply a sparse format specific indices. /// This API is not suitable for string data. Use CreateSparseTensor() with allocator specified so strings /// can be properly copied into the allocated buffer. /// /// Memory description where the user buffers reside (CPU vs GPU etc) /// pointer to the user supplied buffer, use nullptr for fully sparse tensors /// a would be dense shape of the tensor /// non zero values shape. Use a single 0 shape for fully sparse tensors. 
/// data type /// Ort::Value instance containing SparseTensor static Value CreateSparseTensor(const OrtMemoryInfo* info, void* p_data, const Shape& dense_shape, const Shape& values_shape, ONNXTensorElementDataType type); /// /// This is a simple forwarding method to the below CreateSparseTensor. /// This helps to specify data type enum in terms of C++ data type. /// Use CreateSparseTensor /// /// numeric data type only. String data enum must be specified explicitly. /// allocator to use /// a would be dense shape of the tensor /// Ort::Value template static Value CreateSparseTensor(OrtAllocator* allocator, const Shape& dense_shape); /// /// Creates an instance of OrtValue containing sparse tensor. The created instance has no data. /// The data must be supplied by on of the FillSparseTensor() methods that take both non-zero values /// and indices. The data will be copied into a buffer that would be allocated using the supplied allocator. /// Use this API to create OrtValues that contain sparse tensors with all supported data types including /// strings. /// /// allocator to use. The allocator lifespan must eclipse that of the resulting OrtValue /// a would be dense shape of the tensor /// data type /// an instance of Ort::Value static Value CreateSparseTensor(OrtAllocator* allocator, const Shape& dense_shape, ONNXTensorElementDataType type); #endif // !defined(DISABLE_SPARSE_TENSORS) }; /// /// Represents native memory allocation coming from one of the /// OrtAllocators registered with OnnxRuntime. /// Use it to wrap an allocation made by an allocator /// so it can be automatically released when no longer needed. 
/// struct MemoryAllocation { MemoryAllocation(OrtAllocator* allocator, void* p, size_t size); ~MemoryAllocation(); MemoryAllocation(const MemoryAllocation&) = delete; MemoryAllocation& operator=(const MemoryAllocation&) = delete; MemoryAllocation(MemoryAllocation&&) noexcept; MemoryAllocation& operator=(MemoryAllocation&&) noexcept; void* get() { return p_; } size_t size() const { return size_; } private: OrtAllocator* allocator_; void* p_; size_t size_; }; namespace detail { template struct AllocatorImpl : Base { using B = Base; using B::B; void* Alloc(size_t size); MemoryAllocation GetAllocation(size_t size); void Free(void* p); ConstMemoryInfo GetInfo() const; }; } // namespace detail /** \brief Wrapper around ::OrtAllocator default instance that is owned by Onnxruntime * */ struct AllocatorWithDefaultOptions : detail::AllocatorImpl> { explicit AllocatorWithDefaultOptions(std::nullptr_t) {} ///< Convenience to create a class member and then replace with an instance AllocatorWithDefaultOptions(); }; /** \brief Wrapper around ::OrtAllocator * */ struct Allocator : detail::AllocatorImpl { explicit Allocator(std::nullptr_t) {} ///< Convenience to create a class member and then replace with an instance Allocator(const Session& session, const OrtMemoryInfo*); }; using UnownedAllocator = detail::AllocatorImpl>; namespace detail { namespace binding_utils { // Bring these out of template std::vector GetOutputNamesHelper(const OrtIoBinding* binding, OrtAllocator*); std::vector GetOutputValuesHelper(const OrtIoBinding* binding, OrtAllocator*); } // namespace binding_utils template struct ConstIoBindingImpl : Base { using B = Base; using B::B; std::vector GetOutputNames() const; std::vector GetOutputNames(OrtAllocator*) const; std::vector GetOutputValues() const; std::vector GetOutputValues(OrtAllocator*) const; }; template struct IoBindingImpl : ConstIoBindingImpl { using B = ConstIoBindingImpl; using B::B; void BindInput(const char* name, const Value&); void 
BindOutput(const char* name, const Value&); void BindOutput(const char* name, const OrtMemoryInfo*); void ClearBoundInputs(); void ClearBoundOutputs(); void SynchronizeInputs(); void SynchronizeOutputs(); }; } // namespace detail using ConstIoBinding = detail::ConstIoBindingImpl>; using UnownedIoBinding = detail::IoBindingImpl>; /** \brief Wrapper around ::OrtIoBinding * */ struct IoBinding : detail::IoBindingImpl { explicit IoBinding(std::nullptr_t) {} ///< Create an empty object for convenience. Sometimes, we want to initialize members later. explicit IoBinding(Session& session); ConstIoBinding GetConst() const { return ConstIoBinding{this->p_}; } UnownedIoBinding GetUnowned() const { return UnownedIoBinding{this->p_}; } }; /*! \struct Ort::ArenaCfg * \brief it is a structure that represents the configuration of an arena based allocator * \details Please see docs/C_API.md for details */ struct ArenaCfg : detail::Base { explicit ArenaCfg(std::nullptr_t) {} ///< Create an empty ArenaCfg object, must be assigned a valid one to be used /** * Wraps OrtApi::CreateArenaCfg * \param max_mem - use 0 to allow ORT to choose the default * \param arena_extend_strategy - use -1 to allow ORT to choose the default, 0 = kNextPowerOfTwo, 1 = kSameAsRequested * \param initial_chunk_size_bytes - use -1 to allow ORT to choose the default * \param max_dead_bytes_per_chunk - use -1 to allow ORT to choose the default * See docs/C_API.md for details on what the following parameters mean and how to choose these values */ ArenaCfg(size_t max_mem, int arena_extend_strategy, int initial_chunk_size_bytes, int max_dead_bytes_per_chunk); }; // // Custom OPs (only needed to implement custom OPs) // /// /// This struct provides life time management for custom op attribute /// struct OpAttr : detail::Base { OpAttr(const char* name, const void* data, int len, OrtOpAttrType type); }; /** * Macro that logs a message using the provided logger. Throws an exception if OrtApi::Logger_LogMessage fails. 
* Example: ORT_CXX_LOG(logger, ORT_LOGGING_LEVEL_INFO, "Log a message"); * * \param logger The Ort::Logger instance to use. Must be a value or reference. * \param message_severity The logging severity level of the message. * \param message A null-terminated UTF-8 message to log. */ #define ORT_CXX_LOG(logger, message_severity, message) \ do { \ if (message_severity >= logger.GetLoggingSeverityLevel()) { \ Ort::ThrowOnError(logger.LogMessage(message_severity, ORT_FILE, __LINE__, \ static_cast(__FUNCTION__), message)); \ } \ } while (false) /** * Macro that logs a message using the provided logger. Can be used in noexcept code since errors are silently ignored. * Example: ORT_CXX_LOG_NOEXCEPT(logger, ORT_LOGGING_LEVEL_INFO, "Log a message"); * * \param logger The Ort::Logger instance to use. Must be a value or reference. * \param message_severity The logging severity level of the message. * \param message A null-terminated UTF-8 message to log. */ #define ORT_CXX_LOG_NOEXCEPT(logger, message_severity, message) \ do { \ if (message_severity >= logger.GetLoggingSeverityLevel()) { \ static_cast(logger.LogMessage(message_severity, ORT_FILE, __LINE__, \ static_cast(__FUNCTION__), message)); \ } \ } while (false) /** * Macro that logs a printf-like formatted message using the provided logger. Throws an exception if * OrtApi::Logger_LogMessage fails or if a formatting error occurs. * Example: ORT_CXX_LOGF(logger, ORT_LOGGING_LEVEL_INFO, "Log an int: %d", 12); * * \param logger The Ort::Logger instance to use. Must be a value or reference. * \param message_severity The logging severity level of the message. * \param format A null-terminated UTF-8 format string forwarded to a printf-like function. * Refer to https://en.cppreference.com/w/cpp/io/c/fprintf for information on valid formats. * \param ... Zero or more variadic arguments referenced by the format string. */ #define ORT_CXX_LOGF(logger, message_severity, /*format,*/...) 
\ do { \ if (message_severity >= logger.GetLoggingSeverityLevel()) { \ Ort::ThrowOnError(logger.LogFormattedMessage(message_severity, ORT_FILE, __LINE__, \ static_cast(__FUNCTION__), __VA_ARGS__)); \ } \ } while (false) /** * Macro that logs a printf-like formatted message using the provided logger. Can be used in noexcept code since errors * are silently ignored. * Example: ORT_CXX_LOGF_NOEXCEPT(logger, ORT_LOGGING_LEVEL_INFO, "Log an int: %d", 12); * * \param logger The Ort::Logger instance to use. Must be a value or reference. * \param message_severity The logging severity level of the message. * \param format A null-terminated UTF-8 format string forwarded to a printf-like function. * Refer to https://en.cppreference.com/w/cpp/io/c/fprintf for information on valid formats. * \param ... Zero or more variadic arguments referenced by the format string. */ #define ORT_CXX_LOGF_NOEXCEPT(logger, message_severity, /*format,*/...) \ do { \ if (message_severity >= logger.GetLoggingSeverityLevel()) { \ static_cast(logger.LogFormattedMessage(message_severity, ORT_FILE, __LINE__, \ static_cast(__FUNCTION__), __VA_ARGS__)); \ } \ } while (false) /// /// This class represents an ONNX Runtime logger that can be used to log information with an /// associated severity level and source code location (file path, line number, function name). /// /// A Logger can be obtained from within custom operators by calling Ort::KernelInfo::GetLogger(). /// Instances of Ort::Logger are the size of two pointers and can be passed by value. /// /// Use the ORT_CXX_LOG macros to ensure the source code location is set properly from the callsite /// and to take advantage of a cached logging severity level that can bypass calls to the underlying C API. /// struct Logger { /** * Creates an empty Ort::Logger. Must be initialized from a valid Ort::Logger before use. */ Logger() = default; /** * Creates an empty Ort::Logger. Must be initialized from a valid Ort::Logger before use. 
*/ explicit Logger(std::nullptr_t) {} /** * Creates a logger from an ::OrtLogger instance. Caches the logger's current severity level by calling * OrtApi::Logger_GetLoggingSeverityLevel. Throws an exception if OrtApi::Logger_GetLoggingSeverityLevel fails. * * \param logger The ::OrtLogger to wrap. */ explicit Logger(const OrtLogger* logger); ~Logger() = default; Logger(const Logger&) = default; Logger& operator=(const Logger&) = default; Logger(Logger&& v) noexcept = default; Logger& operator=(Logger&& v) noexcept = default; /** * Returns the logger's current severity level from the cached member. * * \return The current ::OrtLoggingLevel. */ OrtLoggingLevel GetLoggingSeverityLevel() const noexcept; /** * Logs the provided message via OrtApi::Logger_LogMessage. Use the ORT_CXX_LOG or ORT_CXX_LOG_NOEXCEPT * macros to properly set the source code location and to use the cached severity level to potentially bypass * calls to the underlying C API. * * \param log_severity_level The message's logging severity level. * \param file_path The filepath of the file in which the message is logged. Usually the value of ORT_FILE. * \param line_number The file line number in which the message is logged. Usually the value of __LINE__. * \param func_name The name of the function in which the message is logged. Usually the value of __FUNCTION__. * \param message The message to log. * \return A Ort::Status value to indicate error or success. */ Status LogMessage(OrtLoggingLevel log_severity_level, const ORTCHAR_T* file_path, int line_number, const char* func_name, const char* message) const noexcept; /** * Logs a printf-like formatted message via OrtApi::Logger_LogMessage. Use the ORT_CXX_LOGF or ORT_CXX_LOGF_NOEXCEPT * macros to properly set the source code location and to use the cached severity level to potentially bypass * calls to the underlying C API. Returns an error status if a formatting error occurs. * * \param log_severity_level The message's logging severity level. 
* \param file_path The filepath of the file in which the message is logged. Usually the value of ORT_FILE. * \param line_number The file line number in which the message is logged. Usually the value of __LINE__. * \param func_name The name of the function in which the message is logged. Usually the value of __FUNCTION__. * \param format A null-terminated UTF-8 format string forwarded to a printf-like function. * Refer to https://en.cppreference.com/w/cpp/io/c/fprintf for information on valid formats. * \param args Zero or more variadic arguments referenced by the format string. * \return A Ort::Status value to indicate error or success. */ template Status LogFormattedMessage(OrtLoggingLevel log_severity_level, const ORTCHAR_T* file_path, int line_number, const char* func_name, const char* format, Args&&... args) const noexcept; private: const OrtLogger* logger_{}; OrtLoggingLevel cached_severity_level_{}; }; /// /// This class wraps a raw pointer OrtKernelContext* that is being passed /// to the custom kernel Compute() method. Use it to safely access context /// attributes, input and output parameters with exception safety guarantees. /// See usage example in onnxruntime/test/testdata/custom_op_library/custom_op_library.cc /// struct KernelContext { explicit KernelContext(OrtKernelContext* context); size_t GetInputCount() const; size_t GetOutputCount() const; // If input is optional and is not present, the method returns en empty ConstValue // which can be compared to nullptr. ConstValue GetInput(size_t index) const; // If outout is optional and is not present, the method returns en empty UnownedValue // which can be compared to nullptr. 
UnownedValue GetOutput(size_t index, const int64_t* dim_values, size_t dim_count) const; UnownedValue GetOutput(size_t index, const std::vector& dims) const; void* GetGPUComputeStream() const; Logger GetLogger() const; OrtAllocator* GetAllocator(const OrtMemoryInfo& memory_info) const; OrtKernelContext* GetOrtKernelContext() const { return ctx_; } void ParallelFor(void (*fn)(void*, size_t), size_t total, size_t num_batch, void* usr_data) const; private: OrtKernelContext* ctx_; }; struct KernelInfo; namespace detail { namespace attr_utils { void GetAttr(const OrtKernelInfo* p, const char* name, float&); void GetAttr(const OrtKernelInfo* p, const char* name, int64_t&); void GetAttr(const OrtKernelInfo* p, const char* name, std::string&); void GetAttrs(const OrtKernelInfo* p, const char* name, std::vector&); void GetAttrs(const OrtKernelInfo* p, const char* name, std::vector&); } // namespace attr_utils template struct KernelInfoImpl : Base { using B = Base; using B::B; KernelInfo Copy() const; template // R is only implemented for float, int64_t, and string R GetAttribute(const char* name) const { R val; attr_utils::GetAttr(this->p_, name, val); return val; } template // R is only implemented for std::vector, std::vector std::vector GetAttributes(const char* name) const { std::vector result; attr_utils::GetAttrs(this->p_, name, result); return result; } Value GetTensorAttribute(const char* name, OrtAllocator* allocator) const; size_t GetInputCount() const; size_t GetOutputCount() const; std::string GetInputName(size_t index) const; std::string GetOutputName(size_t index) const; TypeInfo GetInputTypeInfo(size_t index) const; TypeInfo GetOutputTypeInfo(size_t index) const; ConstValue GetTensorConstantInput(size_t index, int* is_constant) const; std::string GetNodeName() const; Logger GetLogger() const; }; } // namespace detail using ConstKernelInfo = detail::KernelInfoImpl>; /// /// This struct owns the OrtKernInfo* pointer when a copy is made. 
/// For convenient wrapping of OrtKernelInfo* passed to kernel constructor /// and query attributes, warp the pointer with Ort::Unowned instance /// so it does not destroy the pointer the kernel does not own. /// struct KernelInfo : detail::KernelInfoImpl { explicit KernelInfo(std::nullptr_t) {} ///< Create an empty instance to initialize later explicit KernelInfo(OrtKernelInfo* info); ///< Take ownership of the instance ConstKernelInfo GetConst() const { return ConstKernelInfo{this->p_}; } }; /// /// Create and own custom defined operation. /// struct Op : detail::Base { explicit Op(std::nullptr_t) {} ///< Create an empty Operator object, must be assigned a valid one to be used explicit Op(OrtOp*); ///< Take ownership of the OrtOp static Op Create(const OrtKernelInfo* info, const char* op_name, const char* domain, int version, const char** type_constraint_names, const ONNXTensorElementDataType* type_constraint_values, size_t type_constraint_count, const OpAttr* attr_values, size_t attr_count, size_t input_count, size_t output_count); void Invoke(const OrtKernelContext* context, const Value* input_values, size_t input_count, Value* output_values, size_t output_count); // For easier refactoring void Invoke(const OrtKernelContext* context, const OrtValue* const* input_values, size_t input_count, OrtValue* const* output_values, size_t output_count); }; /// /// Provide access to per-node attributes and input shapes, so one could compute and set output shapes. 
/// struct ShapeInferContext { struct SymbolicInteger { SymbolicInteger(int64_t i) : i_(i), is_int_(true){}; SymbolicInteger(const char* s) : s_(s), is_int_(false){}; SymbolicInteger(const SymbolicInteger&) = default; SymbolicInteger(SymbolicInteger&&) = default; SymbolicInteger& operator=(const SymbolicInteger&) = default; SymbolicInteger& operator=(SymbolicInteger&&) = default; bool operator==(const SymbolicInteger& dim) const { if (is_int_ == dim.is_int_) { if (is_int_) { return i_ == dim.i_; } else { return std::string{s_} == std::string{dim.s_}; } } return false; } bool IsInt() const { return is_int_; } int64_t AsInt() const { return i_; } const char* AsSym() const { return s_; } static constexpr int INVALID_INT_DIM = -2; private: union { int64_t i_; const char* s_; }; bool is_int_; }; using Shape = std::vector; ShapeInferContext(const OrtApi* ort_api, OrtShapeInferContext* ctx); const Shape& GetInputShape(size_t indice) const { return input_shapes_.at(indice); } size_t GetInputCount() const { return input_shapes_.size(); } Status SetOutputShape(size_t indice, const Shape& shape); int64_t GetAttrInt(const char* attr_name); using Ints = std::vector; Ints GetAttrInts(const char* attr_name); float GetAttrFloat(const char* attr_name); using Floats = std::vector; Floats GetAttrFloats(const char* attr_name); std::string GetAttrString(const char* attr_name); using Strings = std::vector; Strings GetAttrStrings(const char* attr_name); private: const OrtOpAttr* GetAttrHdl(const char* attr_name) const; const OrtApi* ort_api_; OrtShapeInferContext* ctx_; std::vector input_shapes_; }; using ShapeInferFn = Ort::Status (*)(Ort::ShapeInferContext&); #define MAX_CUSTOM_OP_END_VER (1UL << 31) - 1 template struct CustomOpBase : OrtCustomOp { CustomOpBase() { OrtCustomOp::version = ORT_API_VERSION; OrtCustomOp::GetName = [](const OrtCustomOp* this_) { return static_cast(this_)->GetName(); }; OrtCustomOp::GetExecutionProviderType = [](const OrtCustomOp* this_) { return 
static_cast(this_)->GetExecutionProviderType(); }; OrtCustomOp::GetInputTypeCount = [](const OrtCustomOp* this_) { return static_cast(this_)->GetInputTypeCount(); }; OrtCustomOp::GetInputType = [](const OrtCustomOp* this_, size_t index) { return static_cast(this_)->GetInputType(index); }; OrtCustomOp::GetInputMemoryType = [](const OrtCustomOp* this_, size_t index) { return static_cast(this_)->GetInputMemoryType(index); }; OrtCustomOp::GetOutputTypeCount = [](const OrtCustomOp* this_) { return static_cast(this_)->GetOutputTypeCount(); }; OrtCustomOp::GetOutputType = [](const OrtCustomOp* this_, size_t index) { return static_cast(this_)->GetOutputType(index); }; #if defined(_MSC_VER) && !defined(__clang__) #pragma warning(push) #pragma warning(disable : 26409) #endif OrtCustomOp::KernelDestroy = [](void* op_kernel) { delete static_cast(op_kernel); }; #if defined(_MSC_VER) && !defined(__clang__) #pragma warning(pop) #endif OrtCustomOp::GetInputCharacteristic = [](const OrtCustomOp* this_, size_t index) { return static_cast(this_)->GetInputCharacteristic(index); }; OrtCustomOp::GetOutputCharacteristic = [](const OrtCustomOp* this_, size_t index) { return static_cast(this_)->GetOutputCharacteristic(index); }; OrtCustomOp::GetVariadicInputMinArity = [](const OrtCustomOp* this_) { return static_cast(this_)->GetVariadicInputMinArity(); }; OrtCustomOp::GetVariadicInputHomogeneity = [](const OrtCustomOp* this_) { return static_cast(static_cast(this_)->GetVariadicInputHomogeneity()); }; OrtCustomOp::GetVariadicOutputMinArity = [](const OrtCustomOp* this_) { return static_cast(this_)->GetVariadicOutputMinArity(); }; OrtCustomOp::GetVariadicOutputHomogeneity = [](const OrtCustomOp* this_) { return static_cast(static_cast(this_)->GetVariadicOutputHomogeneity()); }; #ifdef __cpp_if_constexpr if constexpr (WithStatus) { #else if (WithStatus) { #endif OrtCustomOp::CreateKernelV2 = [](const OrtCustomOp* this_, const OrtApi* api, const OrtKernelInfo* info, void** op_kernel) -> 
OrtStatusPtr { return static_cast(this_)->CreateKernelV2(*api, info, op_kernel); }; OrtCustomOp::KernelComputeV2 = [](void* op_kernel, OrtKernelContext* context) -> OrtStatusPtr { return static_cast(op_kernel)->ComputeV2(context); }; } else { OrtCustomOp::CreateKernelV2 = nullptr; OrtCustomOp::KernelComputeV2 = nullptr; OrtCustomOp::CreateKernel = [](const OrtCustomOp* this_, const OrtApi* api, const OrtKernelInfo* info) { return static_cast(this_)->CreateKernel(*api, info); }; OrtCustomOp::KernelCompute = [](void* op_kernel, OrtKernelContext* context) { static_cast(op_kernel)->Compute(context); }; } SetShapeInferFn(0); OrtCustomOp::GetStartVersion = [](const OrtCustomOp* this_) { return static_cast(this_)->start_ver_; }; OrtCustomOp::GetEndVersion = [](const OrtCustomOp* this_) { return static_cast(this_)->end_ver_; }; OrtCustomOp::GetMayInplace = nullptr; OrtCustomOp::ReleaseMayInplace = nullptr; OrtCustomOp::GetAliasMap = nullptr; OrtCustomOp::ReleaseAliasMap = nullptr; } // Default implementation of GetExecutionProviderType that returns nullptr to default to the CPU provider const char* GetExecutionProviderType() const { return nullptr; } // Default implementations of GetInputCharacteristic() and GetOutputCharacteristic() below // (inputs and outputs are required by default) OrtCustomOpInputOutputCharacteristic GetInputCharacteristic(size_t /*index*/) const { return OrtCustomOpInputOutputCharacteristic::INPUT_OUTPUT_REQUIRED; } OrtCustomOpInputOutputCharacteristic GetOutputCharacteristic(size_t /*index*/) const { return OrtCustomOpInputOutputCharacteristic::INPUT_OUTPUT_REQUIRED; } // Default implemention of GetInputMemoryType() that returns OrtMemTypeDefault OrtMemType GetInputMemoryType(size_t /*index*/) const { return OrtMemTypeDefault; } // Default implementation of GetVariadicInputMinArity() returns 1 to specify that a variadic input // should expect at least 1 argument. 
int GetVariadicInputMinArity() const { return 1; } // Default implementation of GetVariadicInputHomegeneity() returns true to specify that all arguments // to a variadic input should be of the same type. bool GetVariadicInputHomogeneity() const { return true; } // Default implementation of GetVariadicOutputMinArity() returns 1 to specify that a variadic output // should produce at least 1 output value. int GetVariadicOutputMinArity() const { return 1; } // Default implementation of GetVariadicOutputHomegeneity() returns true to specify that all output values // produced by a variadic output should be of the same type. bool GetVariadicOutputHomogeneity() const { return true; } // Declare list of session config entries used by this Custom Op. // Implement this function in order to get configs from CustomOpBase::GetSessionConfigs(). // This default implementation returns an empty vector of config entries. std::vector GetSessionConfigKeys() const { return std::vector{}; } template decltype(&C::InferOutputShape) SetShapeInferFn(decltype(&C::InferOutputShape)) { OrtCustomOp::InferOutputShapeFn = [](const OrtCustomOp*, OrtShapeInferContext* ort_ctx) -> OrtStatusPtr { ShapeInferContext ctx(&GetApi(), ort_ctx); return C::InferOutputShape(ctx); }; return {}; } template void SetShapeInferFn(...) { OrtCustomOp::InferOutputShapeFn = {}; } protected: // Helper function that returns a map of session config entries specified by CustomOpBase::GetSessionConfigKeys. void GetSessionConfigs(std::unordered_map& out, ConstSessionOptions options) const; int start_ver_ = 1; int end_ver_ = MAX_CUSTOM_OP_END_VER; }; } // namespace Ort #include "onnxruntime_cxx_inline.h"