onnx runtime
This commit is contained in:
@ -9,7 +9,7 @@ SET(CMAKE_CXX_STANDARD_REQUIRED ON)
|
|||||||
# SET (OpenCV_DEBUG_DLL_FILENAME opencv_world480d.dll) # change filenames
|
# SET (OpenCV_DEBUG_DLL_FILENAME opencv_world480d.dll) # change filenames
|
||||||
# SET (OpenCV_RELEASE_DLL_FILENAME opencv_world480.dll) # change filenames
|
# SET (OpenCV_RELEASE_DLL_FILENAME opencv_world480.dll) # change filenames
|
||||||
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++17")
|
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++17")
|
||||||
SET (ONNXRUNTIME_DIR /home/xdobro23/PP1/onnxruntime/onnxruntime-linux-x64-gpu-1.18.1) # onnxruntime root
|
SET (ONNXRUNTIME_DIR "${CMAKE_CURRENT_SOURCE_DIR}/onnxruntime/onnxruntime-linux-x64-gpu-1.18.1")
|
||||||
|
|
||||||
FIND_PACKAGE(OpenCV REQUIRED)
|
FIND_PACKAGE(OpenCV REQUIRED)
|
||||||
set(CMAKE_BUILD_TYPE Debug)
|
set(CMAKE_BUILD_TYPE Debug)
|
||||||
|
@ -0,0 +1 @@
|
|||||||
|
387127404e6c1d84b3468c387d864877ed1c67fe
|
21
onnxruntime/onnxruntime-linux-x64-gpu-1.18.1/LICENSE
Normal file
21
onnxruntime/onnxruntime-linux-x64-gpu-1.18.1/LICENSE
Normal file
@ -0,0 +1,21 @@
|
|||||||
|
MIT License
|
||||||
|
|
||||||
|
Copyright (c) Microsoft Corporation
|
||||||
|
|
||||||
|
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||||
|
of this software and associated documentation files (the "Software"), to deal
|
||||||
|
in the Software without restriction, including without limitation the rights
|
||||||
|
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||||
|
copies of the Software, and to permit persons to whom the Software is
|
||||||
|
furnished to do so, subject to the following conditions:
|
||||||
|
|
||||||
|
The above copyright notice and this permission notice shall be included in all
|
||||||
|
copies or substantial portions of the Software.
|
||||||
|
|
||||||
|
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||||
|
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||||
|
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||||
|
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||||
|
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||||
|
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||||
|
SOFTWARE.
|
21
onnxruntime/onnxruntime-linux-x64-gpu-1.18.1/Privacy.md
Normal file
21
onnxruntime/onnxruntime-linux-x64-gpu-1.18.1/Privacy.md
Normal file
@ -0,0 +1,21 @@
|
|||||||
|
# Privacy
|
||||||
|
|
||||||
|
## Data Collection
|
||||||
|
The software may collect information about you and your use of the software and send it to Microsoft. Microsoft may use this information to provide services and improve our products and services. You may turn off the telemetry as described in the repository. There are also some features in the software that may enable you and Microsoft to collect data from users of your applications. If you use these features, you must comply with applicable law, including providing appropriate notices to users of your applications together with a copy of Microsoft's privacy statement. Our privacy statement is located at https://go.microsoft.com/fwlink/?LinkID=824704. You can learn more about data collection and use in the help documentation and our privacy statement. Your use of the software operates as your consent to these practices.
|
||||||
|
|
||||||
|
***
|
||||||
|
|
||||||
|
### Private Builds
|
||||||
|
No data collection is performed when using your private builds built from source code.
|
||||||
|
|
||||||
|
### Official Builds
|
||||||
|
ONNX Runtime does not maintain any independent telemetry collection mechanisms outside of what is provided by the platforms it supports. However, where applicable, ONNX Runtime will take advantage of platform-supported telemetry systems to collect trace events with the goal of improving product quality.
|
||||||
|
|
||||||
|
Currently telemetry is only implemented for Windows builds and is turned **ON** by default in the official builds distributed in their respective package management repositories ([see here](../README.md#binaries)). This may be expanded to cover other platforms in the future. Data collection is implemented via 'Platform Telemetry' per vendor platform providers (see [telemetry.h](../onnxruntime/core/platform/telemetry.h)).
|
||||||
|
|
||||||
|
#### Technical Details
|
||||||
|
The Windows provider uses the [TraceLogging](https://docs.microsoft.com/en-us/windows/win32/tracelogging/trace-logging-about) API for its implementation. This enables ONNX Runtime trace events to be collected by the operating system, and based on user consent, this data may be periodically sent to Microsoft servers following GDPR and privacy regulations for anonymity and data access controls.
|
||||||
|
|
||||||
|
Windows ML and onnxruntime C APIs allow Trace Logging to be turned on/off (see [API pages](../README.md#api-documentation) for details).
|
||||||
|
For information on how to enable and disable telemetry, see [C API: Telemetry](./C_API.md#telemetry).
|
||||||
|
There are equivalent APIs in the C#, Python, and Java language bindings as well.
|
61
onnxruntime/onnxruntime-linux-x64-gpu-1.18.1/README.md
Normal file
61
onnxruntime/onnxruntime-linux-x64-gpu-1.18.1/README.md
Normal file
@ -0,0 +1,61 @@
|
|||||||
|
<p align="center"><img width="50%" src="docs/images/ONNX_Runtime_logo_dark.png" /></p>
|
||||||
|
|
||||||
|
**ONNX Runtime is a cross-platform inference and training machine-learning accelerator**.
|
||||||
|
|
||||||
|
**ONNX Runtime inference** can enable faster customer experiences and lower costs, supporting models from deep learning frameworks such as PyTorch and TensorFlow/Keras as well as classical machine learning libraries such as scikit-learn, LightGBM, XGBoost, etc. ONNX Runtime is compatible with different hardware, drivers, and operating systems, and provides optimal performance by leveraging hardware accelerators where applicable alongside graph optimizations and transforms. [Learn more →](https://www.onnxruntime.ai/docs/#onnx-runtime-for-inferencing)
|
||||||
|
|
||||||
|
**ONNX Runtime training** can accelerate the model training time on multi-node NVIDIA GPUs for transformer models with a one-line addition for existing PyTorch training scripts. [Learn more →](https://www.onnxruntime.ai/docs/#onnx-runtime-for-training)
|
||||||
|
|
||||||
|
## Get Started & Resources
|
||||||
|
|
||||||
|
* **General Information**: [onnxruntime.ai](https://onnxruntime.ai)
|
||||||
|
|
||||||
|
* **Usage documentation and tutorials**: [onnxruntime.ai/docs](https://onnxruntime.ai/docs)
|
||||||
|
|
||||||
|
* **YouTube video tutorials**: [youtube.com/@ONNXRuntime](https://www.youtube.com/@ONNXRuntime)
|
||||||
|
|
||||||
|
* [**Upcoming Release Roadmap**](https://github.com/microsoft/onnxruntime/wiki/Upcoming-Release-Roadmap)
|
||||||
|
|
||||||
|
* **Companion sample repositories**:
|
||||||
|
- ONNX Runtime Inferencing: [microsoft/onnxruntime-inference-examples](https://github.com/microsoft/onnxruntime-inference-examples)
|
||||||
|
- ONNX Runtime Training: [microsoft/onnxruntime-training-examples](https://github.com/microsoft/onnxruntime-training-examples)
|
||||||
|
|
||||||
|
## Builtin Pipeline Status
|
||||||
|
|
||||||
|
|System|Inference|Training|
|
||||||
|
|---|---|---|
|
||||||
|
|Windows|[](https://dev.azure.com/onnxruntime/onnxruntime/_build/latest?definitionId=9)<br>[](https://dev.azure.com/onnxruntime/onnxruntime/_build/latest?definitionId=10)<br>[](https://dev.azure.com/onnxruntime/onnxruntime/_build/latest?definitionId=47)||
|
||||||
|
|Linux|[](https://dev.azure.com/onnxruntime/onnxruntime/_build/latest?definitionId=11)<br>[](https://dev.azure.com/onnxruntime/onnxruntime/_build/latest?definitionId=64)<br>[](https://dev.azure.com/onnxruntime/onnxruntime/_build/latest?definitionId=12)<br>[](https://dev.azure.com/onnxruntime/onnxruntime/_build/latest?definitionId=45)<br>[](https://dev.azure.com/onnxruntime/onnxruntime/_build/latest?definitionId=55)|[](https://dev.azure.com/onnxruntime/onnxruntime/_build/latest?definitionId=86)<br>[](https://dev.azure.com/onnxruntime/onnxruntime/_build/latest?definitionId=84)<br>[](https://dev.azure.com/onnxruntime/onnxruntime/_build/latest?definitionId=148)|
|
||||||
|
|Mac|[](https://dev.azure.com/onnxruntime/onnxruntime/_build/latest?definitionId=13)||
|
||||||
|
|Android|[](https://dev.azure.com/onnxruntime/onnxruntime/_build/latest?definitionId=53)||
|
||||||
|
|iOS|[](https://dev.azure.com/onnxruntime/onnxruntime/_build/latest?definitionId=134)||
|
||||||
|
|Web|[](https://dev.azure.com/onnxruntime/onnxruntime/_build/latest?definitionId=161)||
|
||||||
|
|Other|[](https://dev.azure.com/onnxruntime/onnxruntime/_build/latest?definitionId=187&repoName=microsoft%2Fonnxruntime)||
|
||||||
|
|
||||||
|
## Third-party Pipeline Status
|
||||||
|
|
||||||
|
|System|Inference|Training|
|
||||||
|
|---|---|---|
|
||||||
|
|Linux|[](https://github.com/Ascend/onnxruntime/actions/workflows/build-and-test.yaml)||
|
||||||
|
|
||||||
|
## Data/Telemetry
|
||||||
|
|
||||||
|
Windows distributions of this project may collect usage data and send it to Microsoft to help improve our products and services. See the [privacy statement](docs/Privacy.md) for more details.
|
||||||
|
|
||||||
|
## Contributions and Feedback
|
||||||
|
|
||||||
|
We welcome contributions! Please see the [contribution guidelines](CONTRIBUTING.md).
|
||||||
|
|
||||||
|
For feature requests or bug reports, please file a [GitHub Issue](https://github.com/Microsoft/onnxruntime/issues).
|
||||||
|
|
||||||
|
For general discussion or questions, please use [GitHub Discussions](https://github.com/microsoft/onnxruntime/discussions).
|
||||||
|
|
||||||
|
## Code of Conduct
|
||||||
|
|
||||||
|
This project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/).
|
||||||
|
For more information see the [Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/)
|
||||||
|
or contact [opencode@microsoft.com](mailto:opencode@microsoft.com) with any additional questions or comments.
|
||||||
|
|
||||||
|
## License
|
||||||
|
|
||||||
|
This project is licensed under the [MIT License](LICENSE).
|
6508
onnxruntime/onnxruntime-linux-x64-gpu-1.18.1/ThirdPartyNotices.txt
Normal file
6508
onnxruntime/onnxruntime-linux-x64-gpu-1.18.1/ThirdPartyNotices.txt
Normal file
File diff suppressed because it is too large
Load Diff
@ -0,0 +1 @@
|
|||||||
|
1.18.1
|
@ -0,0 +1,100 @@
|
|||||||
|
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||||
|
// Licensed under the MIT License.
|
||||||
|
|
||||||
|
// This header is to expose a context for cuda custom ops.
|
||||||
|
// By the context, a custom cuda operator could fetch existing resources,
|
||||||
|
// such as cuda stream and cudnn handle, for reusing.
|
||||||
|
|
||||||
|
// For concrete usage, pls find page here:
|
||||||
|
// https://onnxruntime.ai/docs/reference/operators/add-custom-op.html#custom-ops-for-cuda-and-rocm
|
||||||
|
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#define ORT_CUDA_CTX
|
||||||
|
|
||||||
|
#include "cuda_resource.h"
|
||||||
|
#include "core/providers/custom_op_context.h"
|
||||||
|
#include <cuda.h>
|
||||||
|
#include <cuda_runtime.h>
|
||||||
|
#ifndef USE_CUDA_MINIMAL
|
||||||
|
#include <cublas_v2.h>
|
||||||
|
#include <cudnn.h>
|
||||||
|
#endif
|
||||||
|
namespace Ort {
|
||||||
|
|
||||||
|
namespace Custom {
|
||||||
|
|
||||||
|
struct CudaContext : public CustomOpContext {
|
||||||
|
cudaStream_t cuda_stream = {};
|
||||||
|
cudnnHandle_t cudnn_handle = {};
|
||||||
|
cublasHandle_t cublas_handle = {};
|
||||||
|
OrtAllocator* deferred_cpu_allocator = {};
|
||||||
|
// below are cuda ep options
|
||||||
|
int16_t device_id = 0;
|
||||||
|
int32_t arena_extend_strategy = 0;
|
||||||
|
int32_t cudnn_conv_algo_search = 0;
|
||||||
|
bool cudnn_conv_use_max_workspace = true;
|
||||||
|
bool cudnn_conv1d_pad_to_nc1d = false;
|
||||||
|
bool enable_skip_layer_norm_strict_mode = false;
|
||||||
|
bool prefer_nhwc = false;
|
||||||
|
bool use_tf32 = true;
|
||||||
|
|
||||||
|
void Init(const OrtKernelContext& kernel_ctx) {
|
||||||
|
cuda_stream = FetchResource<cudaStream_t>(kernel_ctx, CudaResource::cuda_stream_t);
|
||||||
|
cudnn_handle = FetchResource<cudnnHandle_t>(kernel_ctx, CudaResource::cudnn_handle_t);
|
||||||
|
cublas_handle = FetchResource<cublasHandle_t>(kernel_ctx, CudaResource::cublas_handle_t);
|
||||||
|
deferred_cpu_allocator = FetchResource<OrtAllocator*>(kernel_ctx, CudaResource::deferred_cpu_allocator_t);
|
||||||
|
|
||||||
|
device_id = FetchResource<int16_t>(kernel_ctx, CudaResource::device_id_t);
|
||||||
|
arena_extend_strategy = FetchResource<int32_t>(kernel_ctx, CudaResource::arena_extend_strategy_t);
|
||||||
|
cudnn_conv_algo_search = FetchResource<int32_t>(kernel_ctx, CudaResource::cudnn_conv_algo_search_t);
|
||||||
|
cudnn_conv_use_max_workspace = FetchResource<bool>(kernel_ctx, CudaResource::cudnn_conv_use_max_workspace_t);
|
||||||
|
|
||||||
|
cudnn_conv1d_pad_to_nc1d = FetchResource<bool>(kernel_ctx, CudaResource::cudnn_conv1d_pad_to_nc1d_t);
|
||||||
|
enable_skip_layer_norm_strict_mode = FetchResource<bool>(kernel_ctx, CudaResource::enable_skip_layer_norm_strict_mode_t);
|
||||||
|
prefer_nhwc = FetchResource<bool>(kernel_ctx, CudaResource::prefer_nhwc_t);
|
||||||
|
use_tf32 = FetchResource<bool>(kernel_ctx, CudaResource::use_tf32_t);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
T FetchResource(const OrtKernelContext& kernel_ctx, CudaResource resource_type) {
|
||||||
|
if constexpr (sizeof(T) > sizeof(void*)) {
|
||||||
|
ORT_CXX_API_THROW("void* is not large enough to hold resource type: " + std::to_string(resource_type), OrtErrorCode::ORT_INVALID_ARGUMENT);
|
||||||
|
}
|
||||||
|
const auto& ort_api = Ort::GetApi();
|
||||||
|
void* resource = {};
|
||||||
|
OrtStatus* status = ort_api.KernelContext_GetResource(&kernel_ctx, ORT_CUDA_RESOUCE_VERSION, resource_type, &resource);
|
||||||
|
if (status) {
|
||||||
|
ORT_CXX_API_THROW("Failed to fetch cuda ep resource, resouce type: " + std::to_string(resource_type), OrtErrorCode::ORT_RUNTIME_EXCEPTION);
|
||||||
|
}
|
||||||
|
T t = {};
|
||||||
|
memcpy(&t, &resource, sizeof(T));
|
||||||
|
return t;
|
||||||
|
}
|
||||||
|
|
||||||
|
void* AllocDeferredCpuMem(size_t size) const {
|
||||||
|
if (0 == size) {
|
||||||
|
return {};
|
||||||
|
}
|
||||||
|
const auto& ort_api = Ort::GetApi();
|
||||||
|
void* mem = {};
|
||||||
|
auto status = ort_api.AllocatorAlloc(deferred_cpu_allocator, size, &mem);
|
||||||
|
if (status) {
|
||||||
|
ORT_CXX_API_THROW("failed to allocate deferred cpu memory", OrtErrorCode::ORT_RUNTIME_EXCEPTION);
|
||||||
|
}
|
||||||
|
return mem;
|
||||||
|
}
|
||||||
|
|
||||||
|
void FreeDeferredCpuMem(void* mem) const {
|
||||||
|
if (mem) {
|
||||||
|
const auto& ort_api = Ort::GetApi();
|
||||||
|
auto status = ort_api.AllocatorFree(deferred_cpu_allocator, mem);
|
||||||
|
if (status) {
|
||||||
|
ORT_CXX_API_THROW("failed to free deferred cpu memory", OrtErrorCode::ORT_RUNTIME_EXCEPTION);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace Custom
|
||||||
|
} // namespace Ort
|
@ -0,0 +1,22 @@
|
|||||||
|
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||||
|
// Licensed under the MIT License.
|
||||||
|
|
||||||
|
#include "core/providers/resource.h"
|
||||||
|
|
||||||
|
#define ORT_CUDA_RESOUCE_VERSION 3
|
||||||
|
|
||||||
|
enum CudaResource : int {
|
||||||
|
cuda_stream_t = cuda_resource_offset, // 10000
|
||||||
|
cudnn_handle_t,
|
||||||
|
cublas_handle_t,
|
||||||
|
deferred_cpu_allocator_t,
|
||||||
|
// below are cuda ep options
|
||||||
|
device_id_t, // 10004
|
||||||
|
arena_extend_strategy_t,
|
||||||
|
cudnn_conv_algo_search_t,
|
||||||
|
cudnn_conv_use_max_workspace_t,
|
||||||
|
cudnn_conv1d_pad_to_nc1d_t,
|
||||||
|
enable_skip_layer_norm_strict_mode_t,
|
||||||
|
prefer_nhwc_t,
|
||||||
|
use_tf32_t,
|
||||||
|
};
|
@ -0,0 +1,10 @@
|
|||||||
|
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||||
|
// Licensed under the MIT License.
|
||||||
|
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
// CustomOpContext defines an interface allowing a custom op to access ep-specific resources.
|
||||||
|
struct CustomOpContext {
|
||||||
|
CustomOpContext() = default;
|
||||||
|
virtual ~CustomOpContext(){};
|
||||||
|
};
|
@ -0,0 +1,14 @@
|
|||||||
|
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||||
|
// Licensed under the MIT License.
|
||||||
|
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
enum ResourceOffset {
|
||||||
|
cpu_resource_offset = 0,
|
||||||
|
cuda_resource_offset = 10000,
|
||||||
|
dml_resource_offset = 20000,
|
||||||
|
rocm_resource_offset = 30000,
|
||||||
|
// offsets for other ort eps
|
||||||
|
custom_ep_resource_offset = 10000000,
|
||||||
|
// offsets for customized eps
|
||||||
|
};
|
@ -0,0 +1,19 @@
|
|||||||
|
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||||
|
// Licensed under the MIT License.
|
||||||
|
|
||||||
|
#include "onnxruntime_c_api.h"
|
||||||
|
|
||||||
|
#ifdef __cplusplus
|
||||||
|
extern "C" {
|
||||||
|
#endif
|
||||||
|
|
||||||
|
/**
|
||||||
|
* \param use_arena zero: false. non-zero: true.
|
||||||
|
*/
|
||||||
|
ORT_EXPORT
|
||||||
|
ORT_API_STATUS(OrtSessionOptionsAppendExecutionProvider_CPU, _In_ OrtSessionOptions* options, int use_arena)
|
||||||
|
ORT_ALL_ARGS_NONNULL;
|
||||||
|
|
||||||
|
#ifdef __cplusplus
|
||||||
|
}
|
||||||
|
#endif
|
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
@ -0,0 +1,540 @@
|
|||||||
|
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||||
|
// Licensed under the MIT License.
|
||||||
|
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include <stdint.h>
|
||||||
|
#include <cmath>
|
||||||
|
#include <cstring>
|
||||||
|
#include <limits>
|
||||||
|
|
||||||
|
namespace onnxruntime_float16 {
|
||||||
|
|
||||||
|
namespace detail {
|
||||||
|
|
||||||
|
enum class endian {
|
||||||
|
#if defined(_WIN32)
|
||||||
|
little = 0,
|
||||||
|
big = 1,
|
||||||
|
native = little,
|
||||||
|
#elif defined(__GNUC__) || defined(__clang__)
|
||||||
|
little = __ORDER_LITTLE_ENDIAN__,
|
||||||
|
big = __ORDER_BIG_ENDIAN__,
|
||||||
|
native = __BYTE_ORDER__,
|
||||||
|
#else
|
||||||
|
#error onnxruntime_float16::detail::endian is not implemented in this environment.
|
||||||
|
#endif
|
||||||
|
};
|
||||||
|
|
||||||
|
static_assert(
|
||||||
|
endian::native == endian::little || endian::native == endian::big,
|
||||||
|
"Only little-endian or big-endian native byte orders are supported.");
|
||||||
|
|
||||||
|
} // namespace detail
|
||||||
|
|
||||||
|
/// <summary>
|
||||||
|
/// Shared implementation between public and internal classes. CRTP pattern.
|
||||||
|
/// </summary>
|
||||||
|
template <class Derived>
|
||||||
|
struct Float16Impl {
|
||||||
|
protected:
|
||||||
|
/// <summary>
|
||||||
|
/// Converts from float to uint16_t float16 representation
|
||||||
|
/// </summary>
|
||||||
|
/// <param name="v"></param>
|
||||||
|
/// <returns></returns>
|
||||||
|
constexpr static uint16_t ToUint16Impl(float v) noexcept;
|
||||||
|
|
||||||
|
/// <summary>
|
||||||
|
/// Converts float16 to float
|
||||||
|
/// </summary>
|
||||||
|
/// <returns>float representation of float16 value</returns>
|
||||||
|
float ToFloatImpl() const noexcept;
|
||||||
|
|
||||||
|
/// <summary>
|
||||||
|
/// Creates an instance that represents absolute value.
|
||||||
|
/// </summary>
|
||||||
|
/// <returns>Absolute value</returns>
|
||||||
|
uint16_t AbsImpl() const noexcept {
|
||||||
|
return static_cast<uint16_t>(val & ~kSignMask);
|
||||||
|
}
|
||||||
|
|
||||||
|
/// <summary>
|
||||||
|
/// Creates a new instance with the sign flipped.
|
||||||
|
/// </summary>
|
||||||
|
/// <returns>Flipped sign instance</returns>
|
||||||
|
uint16_t NegateImpl() const noexcept {
|
||||||
|
return IsNaN() ? val : static_cast<uint16_t>(val ^ kSignMask);
|
||||||
|
}
|
||||||
|
|
||||||
|
public:
|
||||||
|
// uint16_t special values
|
||||||
|
static constexpr uint16_t kSignMask = 0x8000U;
|
||||||
|
static constexpr uint16_t kBiasedExponentMask = 0x7C00U;
|
||||||
|
static constexpr uint16_t kPositiveInfinityBits = 0x7C00U;
|
||||||
|
static constexpr uint16_t kNegativeInfinityBits = 0xFC00U;
|
||||||
|
static constexpr uint16_t kPositiveQNaNBits = 0x7E00U;
|
||||||
|
static constexpr uint16_t kNegativeQNaNBits = 0xFE00U;
|
||||||
|
static constexpr uint16_t kEpsilonBits = 0x4170U;
|
||||||
|
static constexpr uint16_t kMinValueBits = 0xFBFFU; // Minimum normal number
|
||||||
|
static constexpr uint16_t kMaxValueBits = 0x7BFFU; // Largest normal number
|
||||||
|
static constexpr uint16_t kOneBits = 0x3C00U;
|
||||||
|
static constexpr uint16_t kMinusOneBits = 0xBC00U;
|
||||||
|
|
||||||
|
uint16_t val{0};
|
||||||
|
|
||||||
|
Float16Impl() = default;
|
||||||
|
|
||||||
|
/// <summary>
|
||||||
|
/// Checks if the value is negative
|
||||||
|
/// </summary>
|
||||||
|
/// <returns>true if negative</returns>
|
||||||
|
bool IsNegative() const noexcept {
|
||||||
|
return static_cast<int16_t>(val) < 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
/// <summary>
|
||||||
|
/// Tests if the value is NaN
|
||||||
|
/// </summary>
|
||||||
|
/// <returns>true if NaN</returns>
|
||||||
|
bool IsNaN() const noexcept {
|
||||||
|
return AbsImpl() > kPositiveInfinityBits;
|
||||||
|
}
|
||||||
|
|
||||||
|
/// <summary>
|
||||||
|
/// Tests if the value is finite
|
||||||
|
/// </summary>
|
||||||
|
/// <returns>true if finite</returns>
|
||||||
|
bool IsFinite() const noexcept {
|
||||||
|
return AbsImpl() < kPositiveInfinityBits;
|
||||||
|
}
|
||||||
|
|
||||||
|
/// <summary>
|
||||||
|
/// Tests if the value represents positive infinity.
|
||||||
|
/// </summary>
|
||||||
|
/// <returns>true if positive infinity</returns>
|
||||||
|
bool IsPositiveInfinity() const noexcept {
|
||||||
|
return val == kPositiveInfinityBits;
|
||||||
|
}
|
||||||
|
|
||||||
|
/// <summary>
|
||||||
|
/// Tests if the value represents negative infinity
|
||||||
|
/// </summary>
|
||||||
|
/// <returns>true if negative infinity</returns>
|
||||||
|
bool IsNegativeInfinity() const noexcept {
|
||||||
|
return val == kNegativeInfinityBits;
|
||||||
|
}
|
||||||
|
|
||||||
|
/// <summary>
|
||||||
|
/// Tests if the value is either positive or negative infinity.
|
||||||
|
/// </summary>
|
||||||
|
/// <returns>True if absolute value is infinity</returns>
|
||||||
|
bool IsInfinity() const noexcept {
|
||||||
|
return AbsImpl() == kPositiveInfinityBits;
|
||||||
|
}
|
||||||
|
|
||||||
|
/// <summary>
|
||||||
|
/// Tests if the value is NaN or zero. Useful for comparisons.
|
||||||
|
/// </summary>
|
||||||
|
/// <returns>True if NaN or zero.</returns>
|
||||||
|
bool IsNaNOrZero() const noexcept {
|
||||||
|
auto abs = AbsImpl();
|
||||||
|
return (abs == 0 || abs > kPositiveInfinityBits);
|
||||||
|
}
|
||||||
|
|
||||||
|
/// <summary>
|
||||||
|
/// Tests if the value is normal (not zero, subnormal, infinite, or NaN).
|
||||||
|
/// </summary>
|
||||||
|
/// <returns>True if so</returns>
|
||||||
|
bool IsNormal() const noexcept {
|
||||||
|
auto abs = AbsImpl();
|
||||||
|
return (abs < kPositiveInfinityBits) // is finite
|
||||||
|
&& (abs != 0) // is not zero
|
||||||
|
&& ((abs & kBiasedExponentMask) != 0); // is not subnormal (has a non-zero exponent)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// <summary>
|
||||||
|
/// Tests if the value is subnormal (denormal).
|
||||||
|
/// </summary>
|
||||||
|
/// <returns>True if so</returns>
|
||||||
|
bool IsSubnormal() const noexcept {
|
||||||
|
auto abs = AbsImpl();
|
||||||
|
return (abs < kPositiveInfinityBits) // is finite
|
||||||
|
&& (abs != 0) // is not zero
|
||||||
|
&& ((abs & kBiasedExponentMask) == 0); // is subnormal (has a zero exponent)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// <summary>
|
||||||
|
/// Creates an instance that represents absolute value.
|
||||||
|
/// </summary>
|
||||||
|
/// <returns>Absolute value</returns>
|
||||||
|
Derived Abs() const noexcept { return Derived::FromBits(AbsImpl()); }
|
||||||
|
|
||||||
|
/// <summary>
|
||||||
|
/// Creates a new instance with the sign flipped.
|
||||||
|
/// </summary>
|
||||||
|
/// <returns>Flipped sign instance</returns>
|
||||||
|
Derived Negate() const noexcept { return Derived::FromBits(NegateImpl()); }
|
||||||
|
|
||||||
|
/// <summary>
|
||||||
|
/// IEEE defines that positive and negative zero are equal, this gives us a quick equality check
|
||||||
|
/// for two values by or'ing the private bits together and stripping the sign. They are both zero,
|
||||||
|
/// and therefore equivalent, if the resulting value is still zero.
|
||||||
|
/// </summary>
|
||||||
|
/// <param name="lhs">first value</param>
|
||||||
|
/// <param name="rhs">second value</param>
|
||||||
|
/// <returns>True if both arguments represent zero</returns>
|
||||||
|
static bool AreZero(const Float16Impl& lhs, const Float16Impl& rhs) noexcept {
|
||||||
|
return static_cast<uint16_t>((lhs.val | rhs.val) & ~kSignMask) == 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool operator==(const Float16Impl& rhs) const noexcept {
|
||||||
|
if (IsNaN() || rhs.IsNaN()) {
|
||||||
|
// IEEE defines that NaN is not equal to anything, including itself.
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
return val == rhs.val;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool operator!=(const Float16Impl& rhs) const noexcept { return !(*this == rhs); }
|
||||||
|
|
||||||
|
bool operator<(const Float16Impl& rhs) const noexcept {
|
||||||
|
if (IsNaN() || rhs.IsNaN()) {
|
||||||
|
// IEEE defines that NaN is unordered with respect to everything, including itself.
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
const bool left_is_negative = IsNegative();
|
||||||
|
if (left_is_negative != rhs.IsNegative()) {
|
||||||
|
// When the signs of left and right differ, we know that left is less than right if it is
|
||||||
|
// the negative value. The exception to this is if both values are zero, in which case IEEE
|
||||||
|
// says they should be equal, even if the signs differ.
|
||||||
|
return left_is_negative && !AreZero(*this, rhs);
|
||||||
|
}
|
||||||
|
return (val != rhs.val) && ((val < rhs.val) ^ left_is_negative);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
// The following Float16_t conversions are based on the code from
|
||||||
|
// Eigen library.
|
||||||
|
|
||||||
|
// The conversion routines are Copyright (c) Fabian Giesen, 2016.
|
||||||
|
// The original license follows:
|
||||||
|
//
|
||||||
|
// Copyright (c) Fabian Giesen, 2016
|
||||||
|
// All rights reserved.
|
||||||
|
// Redistribution and use in source and binary forms, with or without
|
||||||
|
// modification, are permitted.
|
||||||
|
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
|
||||||
|
// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
|
||||||
|
// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
|
||||||
|
// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
|
||||||
|
// HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
|
||||||
|
// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
|
||||||
|
// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
|
||||||
|
// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
|
||||||
|
// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
|
||||||
|
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||||
|
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||||
|
|
||||||
|
namespace detail {
|
||||||
|
union float32_bits {
|
||||||
|
unsigned int u;
|
||||||
|
float f;
|
||||||
|
};
|
||||||
|
} // namespace detail
|
||||||
|
|
||||||
|
template <class Derived>
|
||||||
|
inline constexpr uint16_t Float16Impl<Derived>::ToUint16Impl(float v) noexcept {
|
||||||
|
detail::float32_bits f{};
|
||||||
|
f.f = v;
|
||||||
|
|
||||||
|
constexpr detail::float32_bits f32infty = {255 << 23};
|
||||||
|
constexpr detail::float32_bits f16max = {(127 + 16) << 23};
|
||||||
|
constexpr detail::float32_bits denorm_magic = {((127 - 15) + (23 - 10) + 1) << 23};
|
||||||
|
constexpr unsigned int sign_mask = 0x80000000u;
|
||||||
|
uint16_t val = static_cast<uint16_t>(0x0u);
|
||||||
|
|
||||||
|
unsigned int sign = f.u & sign_mask;
|
||||||
|
f.u ^= sign;
|
||||||
|
|
||||||
|
// NOTE all the integer compares in this function can be safely
|
||||||
|
// compiled into signed compares since all operands are below
|
||||||
|
// 0x80000000. Important if you want fast straight SSE2 code
|
||||||
|
// (since there's no unsigned PCMPGTD).
|
||||||
|
|
||||||
|
if (f.u >= f16max.u) { // result is Inf or NaN (all exponent bits set)
|
||||||
|
val = (f.u > f32infty.u) ? 0x7e00 : 0x7c00; // NaN->qNaN and Inf->Inf
|
||||||
|
} else { // (De)normalized number or zero
|
||||||
|
if (f.u < (113 << 23)) { // resulting FP16 is subnormal or zero
|
||||||
|
// use a magic value to align our 10 mantissa bits at the bottom of
|
||||||
|
// the float. as long as FP addition is round-to-nearest-even this
|
||||||
|
// just works.
|
||||||
|
f.f += denorm_magic.f;
|
||||||
|
|
||||||
|
// and one integer subtract of the bias later, we have our final float!
|
||||||
|
val = static_cast<uint16_t>(f.u - denorm_magic.u);
|
||||||
|
} else {
|
||||||
|
unsigned int mant_odd = (f.u >> 13) & 1; // resulting mantissa is odd
|
||||||
|
|
||||||
|
// update exponent, rounding bias part 1
|
||||||
|
// Equivalent to `f.u += ((unsigned int)(15 - 127) << 23) + 0xfff`, but
|
||||||
|
// without arithmetic overflow.
|
||||||
|
f.u += 0xc8000fffU;
|
||||||
|
// rounding bias part 2
|
||||||
|
f.u += mant_odd;
|
||||||
|
// take the bits!
|
||||||
|
val = static_cast<uint16_t>(f.u >> 13);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
val |= static_cast<uint16_t>(sign >> 16);
|
||||||
|
return val;
|
||||||
|
}
|
||||||
|
|
||||||
|
template <class Derived>
|
||||||
|
inline float Float16Impl<Derived>::ToFloatImpl() const noexcept {
|
||||||
|
constexpr detail::float32_bits magic = {113 << 23};
|
||||||
|
constexpr unsigned int shifted_exp = 0x7c00 << 13; // exponent mask after shift
|
||||||
|
detail::float32_bits o{};
|
||||||
|
|
||||||
|
o.u = (val & 0x7fff) << 13; // exponent/mantissa bits
|
||||||
|
unsigned int exp = shifted_exp & o.u; // just the exponent
|
||||||
|
o.u += (127 - 15) << 23; // exponent adjust
|
||||||
|
|
||||||
|
// handle exponent special cases
|
||||||
|
if (exp == shifted_exp) { // Inf/NaN?
|
||||||
|
o.u += (128 - 16) << 23; // extra exp adjust
|
||||||
|
} else if (exp == 0) { // Zero/Denormal?
|
||||||
|
o.u += 1 << 23; // extra exp adjust
|
||||||
|
o.f -= magic.f; // re-normalize
|
||||||
|
}
|
||||||
|
|
||||||
|
// Attempt to workaround the Internal Compiler Error on ARM64
|
||||||
|
// for bitwise | operator, including std::bitset
|
||||||
|
#if (defined _MSC_VER) && (defined _M_ARM || defined _M_ARM64 || defined _M_ARM64EC)
|
||||||
|
if (IsNegative()) {
|
||||||
|
return -o.f;
|
||||||
|
}
|
||||||
|
#else
|
||||||
|
// original code:
|
||||||
|
o.u |= (val & 0x8000U) << 16U; // sign bit
|
||||||
|
#endif
|
||||||
|
return o.f;
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Shared implementation between public and internal classes. CRTP pattern.
|
||||||
|
template <class Derived>
|
||||||
|
struct BFloat16Impl {
|
||||||
|
protected:
|
||||||
|
/// <summary>
|
||||||
|
/// Converts from float to uint16_t float16 representation
|
||||||
|
/// </summary>
|
||||||
|
/// <param name="v"></param>
|
||||||
|
/// <returns></returns>
|
||||||
|
static uint16_t ToUint16Impl(float v) noexcept;
|
||||||
|
|
||||||
|
/// <summary>
|
||||||
|
/// Converts bfloat16 to float
|
||||||
|
/// </summary>
|
||||||
|
/// <returns>float representation of bfloat16 value</returns>
|
||||||
|
float ToFloatImpl() const noexcept;
|
||||||
|
|
||||||
|
/// <summary>
|
||||||
|
/// Creates an instance that represents absolute value.
|
||||||
|
/// </summary>
|
||||||
|
/// <returns>Absolute value</returns>
|
||||||
|
uint16_t AbsImpl() const noexcept {
|
||||||
|
return static_cast<uint16_t>(val & ~kSignMask);
|
||||||
|
}
|
||||||
|
|
||||||
|
/// <summary>
|
||||||
|
/// Creates a new instance with the sign flipped.
|
||||||
|
/// </summary>
|
||||||
|
/// <returns>Flipped sign instance</returns>
|
||||||
|
uint16_t NegateImpl() const noexcept {
|
||||||
|
return IsNaN() ? val : static_cast<uint16_t>(val ^ kSignMask);
|
||||||
|
}
|
||||||
|
|
||||||
|
public:
|
||||||
|
// uint16_t special values
|
||||||
|
static constexpr uint16_t kSignMask = 0x8000U;
|
||||||
|
static constexpr uint16_t kBiasedExponentMask = 0x7F80U;
|
||||||
|
static constexpr uint16_t kPositiveInfinityBits = 0x7F80U;
|
||||||
|
static constexpr uint16_t kNegativeInfinityBits = 0xFF80U;
|
||||||
|
static constexpr uint16_t kPositiveQNaNBits = 0x7FC1U;
|
||||||
|
static constexpr uint16_t kNegativeQNaNBits = 0xFFC1U;
|
||||||
|
static constexpr uint16_t kSignaling_NaNBits = 0x7F80U;
|
||||||
|
static constexpr uint16_t kEpsilonBits = 0x0080U;
|
||||||
|
static constexpr uint16_t kMinValueBits = 0xFF7FU;
|
||||||
|
static constexpr uint16_t kMaxValueBits = 0x7F7FU;
|
||||||
|
static constexpr uint16_t kRoundToNearest = 0x7FFFU;
|
||||||
|
static constexpr uint16_t kOneBits = 0x3F80U;
|
||||||
|
static constexpr uint16_t kMinusOneBits = 0xBF80U;
|
||||||
|
|
||||||
|
uint16_t val{0};
|
||||||
|
|
||||||
|
BFloat16Impl() = default;
|
||||||
|
|
||||||
|
/// <summary>
|
||||||
|
/// Checks if the value is negative
|
||||||
|
/// </summary>
|
||||||
|
/// <returns>true if negative</returns>
|
||||||
|
bool IsNegative() const noexcept {
|
||||||
|
return static_cast<int16_t>(val) < 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
/// <summary>
|
||||||
|
/// Tests if the value is NaN
|
||||||
|
/// </summary>
|
||||||
|
/// <returns>true if NaN</returns>
|
||||||
|
bool IsNaN() const noexcept {
|
||||||
|
return AbsImpl() > kPositiveInfinityBits;
|
||||||
|
}
|
||||||
|
|
||||||
|
/// <summary>
|
||||||
|
/// Tests if the value is finite
|
||||||
|
/// </summary>
|
||||||
|
/// <returns>true if finite</returns>
|
||||||
|
bool IsFinite() const noexcept {
|
||||||
|
return AbsImpl() < kPositiveInfinityBits;
|
||||||
|
}
|
||||||
|
|
||||||
|
/// <summary>
|
||||||
|
/// Tests if the value represents positive infinity.
|
||||||
|
/// </summary>
|
||||||
|
/// <returns>true if positive infinity</returns>
|
||||||
|
bool IsPositiveInfinity() const noexcept {
|
||||||
|
return val == kPositiveInfinityBits;
|
||||||
|
}
|
||||||
|
|
||||||
|
/// <summary>
|
||||||
|
/// Tests if the value represents negative infinity
|
||||||
|
/// </summary>
|
||||||
|
/// <returns>true if negative infinity</returns>
|
||||||
|
bool IsNegativeInfinity() const noexcept {
|
||||||
|
return val == kNegativeInfinityBits;
|
||||||
|
}
|
||||||
|
|
||||||
|
/// <summary>
|
||||||
|
/// Tests if the value is either positive or negative infinity.
|
||||||
|
/// </summary>
|
||||||
|
/// <returns>True if absolute value is infinity</returns>
|
||||||
|
bool IsInfinity() const noexcept {
|
||||||
|
return AbsImpl() == kPositiveInfinityBits;
|
||||||
|
}
|
||||||
|
|
||||||
|
/// <summary>
|
||||||
|
/// Tests if the value is NaN or zero. Useful for comparisons.
|
||||||
|
/// </summary>
|
||||||
|
/// <returns>True if NaN or zero.</returns>
|
||||||
|
bool IsNaNOrZero() const noexcept {
|
||||||
|
auto abs = AbsImpl();
|
||||||
|
return (abs == 0 || abs > kPositiveInfinityBits);
|
||||||
|
}
|
||||||
|
|
||||||
|
/// <summary>
|
||||||
|
/// Tests if the value is normal (not zero, subnormal, infinite, or NaN).
|
||||||
|
/// </summary>
|
||||||
|
/// <returns>True if so</returns>
|
||||||
|
bool IsNormal() const noexcept {
|
||||||
|
auto abs = AbsImpl();
|
||||||
|
return (abs < kPositiveInfinityBits) // is finite
|
||||||
|
&& (abs != 0) // is not zero
|
||||||
|
&& ((abs & kBiasedExponentMask) != 0); // is not subnormal (has a non-zero exponent)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// <summary>
|
||||||
|
/// Tests if the value is subnormal (denormal).
|
||||||
|
/// </summary>
|
||||||
|
/// <returns>True if so</returns>
|
||||||
|
bool IsSubnormal() const noexcept {
|
||||||
|
auto abs = AbsImpl();
|
||||||
|
return (abs < kPositiveInfinityBits) // is finite
|
||||||
|
&& (abs != 0) // is not zero
|
||||||
|
&& ((abs & kBiasedExponentMask) == 0); // is subnormal (has a zero exponent)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// <summary>
|
||||||
|
/// Creates an instance that represents absolute value.
|
||||||
|
/// </summary>
|
||||||
|
/// <returns>Absolute value</returns>
|
||||||
|
Derived Abs() const noexcept { return Derived::FromBits(AbsImpl()); }
|
||||||
|
|
||||||
|
/// <summary>
|
||||||
|
/// Creates a new instance with the sign flipped.
|
||||||
|
/// </summary>
|
||||||
|
/// <returns>Flipped sign instance</returns>
|
||||||
|
Derived Negate() const noexcept { return Derived::FromBits(NegateImpl()); }
|
||||||
|
|
||||||
|
/// <summary>
|
||||||
|
/// IEEE defines that positive and negative zero are equal, this gives us a quick equality check
|
||||||
|
/// for two values by or'ing the private bits together and stripping the sign. They are both zero,
|
||||||
|
/// and therefore equivalent, if the resulting value is still zero.
|
||||||
|
/// </summary>
|
||||||
|
/// <param name="lhs">first value</param>
|
||||||
|
/// <param name="rhs">second value</param>
|
||||||
|
/// <returns>True if both arguments represent zero</returns>
|
||||||
|
static bool AreZero(const BFloat16Impl& lhs, const BFloat16Impl& rhs) noexcept {
|
||||||
|
// IEEE defines that positive and negative zero are equal, this gives us a quick equality check
|
||||||
|
// for two values by or'ing the private bits together and stripping the sign. They are both zero,
|
||||||
|
// and therefore equivalent, if the resulting value is still zero.
|
||||||
|
return static_cast<uint16_t>((lhs.val | rhs.val) & ~kSignMask) == 0;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
template <class Derived>
|
||||||
|
inline uint16_t BFloat16Impl<Derived>::ToUint16Impl(float v) noexcept {
|
||||||
|
uint16_t result;
|
||||||
|
if (std::isnan(v)) {
|
||||||
|
result = kPositiveQNaNBits;
|
||||||
|
} else {
|
||||||
|
auto get_msb_half = [](float fl) {
|
||||||
|
uint16_t result;
|
||||||
|
#ifdef __cpp_if_constexpr
|
||||||
|
if constexpr (detail::endian::native == detail::endian::little) {
|
||||||
|
#else
|
||||||
|
if (detail::endian::native == detail::endian::little) {
|
||||||
|
#endif
|
||||||
|
std::memcpy(&result, reinterpret_cast<char*>(&fl) + sizeof(uint16_t), sizeof(uint16_t));
|
||||||
|
} else {
|
||||||
|
std::memcpy(&result, &fl, sizeof(uint16_t));
|
||||||
|
}
|
||||||
|
return result;
|
||||||
|
};
|
||||||
|
|
||||||
|
uint16_t upper_bits = get_msb_half(v);
|
||||||
|
union {
|
||||||
|
uint32_t U32;
|
||||||
|
float F32;
|
||||||
|
};
|
||||||
|
F32 = v;
|
||||||
|
U32 += (upper_bits & 1) + kRoundToNearest;
|
||||||
|
result = get_msb_half(F32);
|
||||||
|
}
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
|
template <class Derived>
|
||||||
|
inline float BFloat16Impl<Derived>::ToFloatImpl() const noexcept {
|
||||||
|
if (IsNaN()) {
|
||||||
|
return std::numeric_limits<float>::quiet_NaN();
|
||||||
|
}
|
||||||
|
float result;
|
||||||
|
char* const first = reinterpret_cast<char*>(&result);
|
||||||
|
char* const second = first + sizeof(uint16_t);
|
||||||
|
#ifdef __cpp_if_constexpr
|
||||||
|
if constexpr (detail::endian::native == detail::endian::little) {
|
||||||
|
#else
|
||||||
|
if (detail::endian::native == detail::endian::little) {
|
||||||
|
#endif
|
||||||
|
std::memset(first, 0, sizeof(uint16_t));
|
||||||
|
std::memcpy(second, &val, sizeof(uint16_t));
|
||||||
|
} else {
|
||||||
|
std::memcpy(first, &val, sizeof(uint16_t));
|
||||||
|
std::memset(second, 0, sizeof(uint16_t));
|
||||||
|
}
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace onnxruntime_float16
|
File diff suppressed because it is too large
Load Diff
@ -0,0 +1,51 @@
|
|||||||
|
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||||
|
// Licensed under the MIT License.
|
||||||
|
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
/*
|
||||||
|
* This file defines RunOptions Config Keys and format of the Config Values.
|
||||||
|
*
|
||||||
|
* The Naming Convention for a RunOptions Config Key,
|
||||||
|
* "[Area][.[SubArea1].[SubArea2]...].[Keyname]"
|
||||||
|
* Such as "ep.cuda.use_arena"
|
||||||
|
* The Config Key cannot be empty
|
||||||
|
* The maximum length of the Config Key is 128
|
||||||
|
*
|
||||||
|
* The string format of a RunOptions Config Value is defined individually for each Config.
|
||||||
|
* The maximum length of the Config Value is 1024
|
||||||
|
*/
|
||||||
|
|
||||||
|
// Key for enabling shrinkages of user listed device memory arenas.
|
||||||
|
// Expects a list of semi-colon separated key value pairs separated by colon in the following format:
|
||||||
|
// "device_0:device_id_0;device_1:device_id_1"
|
||||||
|
// No white-spaces allowed in the provided list string.
|
||||||
|
// Currently, the only supported devices are : "cpu", "gpu" (case sensitive).
|
||||||
|
// If "cpu" is included in the list, DisableCpuMemArena() API must not be called (i.e.) arena for cpu should be enabled.
|
||||||
|
// Example usage: "cpu:0;gpu:0" (or) "gpu:0"
|
||||||
|
// By default, the value for this key is empty (i.e.) no memory arenas are shrunk
|
||||||
|
static const char* const kOrtRunOptionsConfigEnableMemoryArenaShrinkage = "memory.enable_memory_arena_shrinkage";
|
||||||
|
|
||||||
|
// Set to '1' to not synchronize execution providers with CPU at the end of session run.
|
||||||
|
// Per default it will be set to '0'
|
||||||
|
// Taking CUDA EP as an example, it omit triggering cudaStreamSynchronize on the compute stream.
|
||||||
|
static const char* const kOrtRunOptionsConfigDisableSynchronizeExecutionProviders = "disable_synchronize_execution_providers";
|
||||||
|
|
||||||
|
// Set HTP performance mode for QNN HTP backend before session run.
|
||||||
|
// options for HTP performance mode: "burst", "balanced", "default", "high_performance",
|
||||||
|
// "high_power_saver", "low_balanced", "extreme_power_saver", "low_power_saver", "power_saver",
|
||||||
|
// "sustained_high_performance". Default to "default".
|
||||||
|
static const char* const kOrtRunOptionsConfigQnnPerfMode = "qnn.htp_perf_mode";
|
||||||
|
|
||||||
|
// Set HTP performance mode for QNN HTP backend post session run.
|
||||||
|
static const char* const kOrtRunOptionsConfigQnnPerfModePostRun = "qnn.htp_perf_mode_post_run";
|
||||||
|
|
||||||
|
// Set RPC control latency for QNN HTP backend
|
||||||
|
static const char* const kOrtRunOptionsConfigQnnRpcControlLatency = "qnn.rpc_control_latency";
|
||||||
|
|
||||||
|
// Set graph annotation id for CUDA EP. Use with enable_cuda_graph=true.
|
||||||
|
// The value should be an integer. If the value is not set, the default value is 0 and
|
||||||
|
// ORT session only captures one cuda graph before another capture is requested.
|
||||||
|
// If the value is set to -1, cuda graph capture/replay is disabled in that run.
|
||||||
|
// User are not expected to set the value to 0 as it is reserved for internal use.
|
||||||
|
static const char* const kOrtRunOptionsConfigCudaGraphAnnotation = "gpu_graph_id";
|
@ -0,0 +1,267 @@
|
|||||||
|
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||||
|
// Licensed under the MIT License.
|
||||||
|
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
/*
|
||||||
|
* This file defines SessionOptions Config Keys and format of the Config Values.
|
||||||
|
*
|
||||||
|
* The Naming Convention for a SessionOptions Config Key,
|
||||||
|
* "[Area][.[SubArea1].[SubArea2]...].[Keyname]"
|
||||||
|
* Such as "ep.cuda.use_arena"
|
||||||
|
* The Config Key cannot be empty
|
||||||
|
* The maximum length of the Config Key is 128
|
||||||
|
*
|
||||||
|
* The string format of a SessionOptions Config Value is defined individually for each Config.
|
||||||
|
* The maximum length of the Config Value is 1024
|
||||||
|
*/
|
||||||
|
|
||||||
|
// Key for disable PrePacking,
|
||||||
|
// If the config value is set to "1" then the prepacking is disabled, otherwise prepacking is enabled (default value)
|
||||||
|
static const char* const kOrtSessionOptionsConfigDisablePrepacking = "session.disable_prepacking";
|
||||||
|
|
||||||
|
// A value of "1" means allocators registered in the env will be used. "0" means the allocators created in the session
|
||||||
|
// will be used. Use this to override the usage of env allocators on a per session level.
|
||||||
|
static const char* const kOrtSessionOptionsConfigUseEnvAllocators = "session.use_env_allocators";
|
||||||
|
|
||||||
|
// Set to 'ORT' (case sensitive) to load an ORT format model.
|
||||||
|
// If unset, model type will default to ONNX unless inferred from filename ('.ort' == ORT format) or bytes to be ORT
|
||||||
|
static const char* const kOrtSessionOptionsConfigLoadModelFormat = "session.load_model_format";
|
||||||
|
|
||||||
|
// Set to 'ORT' (case sensitive) to save optimized model in ORT format when SessionOptions.optimized_model_path is set.
|
||||||
|
// If unset, format will default to ONNX unless optimized_model_filepath ends in '.ort'.
|
||||||
|
static const char* const kOrtSessionOptionsConfigSaveModelFormat = "session.save_model_format";
|
||||||
|
|
||||||
|
// If a value is "1", flush-to-zero and denormal-as-zero are applied. The default is "0".
|
||||||
|
// When multiple sessions are created, a main thread doesn't override changes from succeeding session options,
|
||||||
|
// but threads in session thread pools follow option changes.
|
||||||
|
// When ORT runs with OpenMP, the same rule is applied, i.e. the first session option to flush-to-zero and
|
||||||
|
// denormal-as-zero is only applied to global OpenMP thread pool, which doesn't support per-session thread pool.
|
||||||
|
// Note that an alternative way not using this option at runtime is to train and export a model without denormals
|
||||||
|
// and that's recommended because turning this option on may hurt model accuracy.
|
||||||
|
static const char* const kOrtSessionOptionsConfigSetDenormalAsZero = "session.set_denormal_as_zero";
|
||||||
|
|
||||||
|
// It controls to run quantization model in QDQ (QuantizelinearDeQuantizelinear) format or not.
|
||||||
|
// "0": enable. ORT does fusion logic for QDQ format.
|
||||||
|
// "1": disable. ORT doesn't do fusion logic for QDQ format.
|
||||||
|
// Its default value is "0" unless the DirectML execution provider is registered, in which case it defaults to "1".
|
||||||
|
static const char* const kOrtSessionOptionsDisableQuantQDQ = "session.disable_quant_qdq";
|
||||||
|
|
||||||
|
// It controls whether to enable Double QDQ remover and Identical Children Consolidation
|
||||||
|
// "0": not to disable. ORT does remove the middle 2 Nodes from a Q->(QD->Q)->QD pairs
|
||||||
|
// "1": disable. ORT doesn't remove the middle 2 Nodes from a Q->(QD->Q)->QD pairs
|
||||||
|
// Its default value is "0"
|
||||||
|
static const char* const kOrtSessionOptionsDisableDoubleQDQRemover = "session.disable_double_qdq_remover";
|
||||||
|
|
||||||
|
// If set to "1", enables the removal of QuantizeLinear/DequantizeLinear node pairs once all QDQ handling has been
|
||||||
|
// completed. e.g. If after all QDQ handling has completed and we have -> FloatOp -> Q -> DQ -> FloatOp -> the
|
||||||
|
// Q -> DQ could potentially be removed. This will provide a performance benefit by avoiding going from float to
|
||||||
|
// 8-bit and back to float, but could impact accuracy. The impact on accuracy will be model specific and depend on
|
||||||
|
// other factors like whether the model was created using Quantization Aware Training or Post Training Quantization.
|
||||||
|
// As such, it's best to test to determine if enabling this works well for your scenario.
|
||||||
|
// The default value is "0"
|
||||||
|
// Available since version 1.11.
|
||||||
|
static const char* const kOrtSessionOptionsEnableQuantQDQCleanup = "session.enable_quant_qdq_cleanup";
|
||||||
|
|
||||||
|
// Enable or disable gelu approximation in graph optimization. "0": disable; "1": enable. The default is "0".
|
||||||
|
// GeluApproximation has side effects which may change the inference results. It is disabled by default due to this.
|
||||||
|
static const char* const kOrtSessionOptionsEnableGeluApproximation = "optimization.enable_gelu_approximation";
|
||||||
|
|
||||||
|
// This setting controls whether to enable AheadOfTime function inlining.
|
||||||
|
// AOT function inlining examines the graph and attempts to inline as many locally defined functions in the model
|
||||||
|
// as possible with the help of enabled execution providers.
|
||||||
|
// This can reduce the number of function calls and improve performance because it is done before
|
||||||
|
// Level1 optimizers and constant folding. However, under some circumstances, when the EPs are not available,
|
||||||
|
// one can disable the AOT inlining, produce an optimized model and postpone AOT until run time.
|
||||||
|
// "0": enable; "1": disable.
|
||||||
|
// Its default value is "0".
|
||||||
|
static const char* const kOrtSessionOptionsDisableAheadOfTimeFunctionInlining = "session.disable_aot_function_inlining";
|
||||||
|
|
||||||
|
#ifdef ENABLE_TRAINING
|
||||||
|
// Specifies a list of op types for memory footprint reduction.
|
||||||
|
// The value should be a ","-delimited list of pair of
|
||||||
|
// <subgraph string: optimization strategy: number of subgraph to apply>.
|
||||||
|
// For example, "Gelu+Cast+:1:0,Dropout+:1:1".
|
||||||
|
// A valid "subgraph string" should be one subgraph representation output by ORT graph transformations.
|
||||||
|
// "optimization strategy" currently has valid values: 0 - disabled, 1 - recompute.
|
||||||
|
// "number of subgraph to apply" is used to control how many subgraphs to apply optimization, to avoid "oversaving"
|
||||||
|
// the memory.
|
||||||
|
static const char* const kOrtSessionOptionsMemoryOptimizerEnabler = "optimization.memory_optimizer_config";
|
||||||
|
|
||||||
|
// Specifies the config for detecting subgraphs for memory footprint reduction.
|
||||||
|
// The value should be a string contains int separated using commas. The default value is "0:0".
|
||||||
|
static const char* const kOrtSessionOptionsMemoryOptimizerProbeConfig = "optimization.enable_memory_probe_recompute_config";
|
||||||
|
#endif
|
||||||
|
|
||||||
|
// This setting if set should contain a comma separated list of optimizers names that should be disabled.
|
||||||
|
// Optimizers may take time to execute and affect model loading time. If you feel that a specific optimizer
|
||||||
|
// does not provider runtime benefits, but affects your model loading time you may disable it using this config
|
||||||
|
// entry. This option is not enabled in ORT_MINIMAL_BUILD build.
|
||||||
|
// A list of optimizes is available in onnxruntime/core/optimizer/graph_transformer_utils.cc
|
||||||
|
//
|
||||||
|
// Default is an empty string which means no optimizers are disabled.
|
||||||
|
static const char* const kOrtSessionOptionsDisableSpecifiedOptimizers = "optimization.disable_specified_optimizers";
|
||||||
|
|
||||||
|
// Enable or disable using device allocator for allocating initialized tensor memory. "1": enable; "0": disable. The default is "0".
|
||||||
|
// Using device allocators means the memory allocation is made using malloc/new.
|
||||||
|
static const char* const kOrtSessionOptionsUseDeviceAllocatorForInitializers = "session.use_device_allocator_for_initializers";
|
||||||
|
|
||||||
|
// Configure whether to allow the inter_op/intra_op threads spinning a number of times before blocking
|
||||||
|
// "0": thread will block if found no job to run
|
||||||
|
// "1": default, thread will spin a number of times before blocking
|
||||||
|
static const char* const kOrtSessionOptionsConfigAllowInterOpSpinning = "session.inter_op.allow_spinning";
|
||||||
|
static const char* const kOrtSessionOptionsConfigAllowIntraOpSpinning = "session.intra_op.allow_spinning";
|
||||||
|
|
||||||
|
// Key for using model bytes directly for ORT format
|
||||||
|
// If a session is created using an input byte array contains the ORT format model data,
|
||||||
|
// By default we will copy the model bytes at the time of session creation to ensure the model bytes
|
||||||
|
// buffer is valid.
|
||||||
|
// Setting this option to "1" will disable copy the model bytes, and use the model bytes directly. The caller
|
||||||
|
// has to guarantee that the model bytes are valid until the ORT session using the model bytes is destroyed.
|
||||||
|
static const char* const kOrtSessionOptionsConfigUseORTModelBytesDirectly = "session.use_ort_model_bytes_directly";
|
||||||
|
|
||||||
|
/// <summary>
|
||||||
|
/// Key for using the ORT format model flatbuffer bytes directly for initializers.
|
||||||
|
/// This avoids copying the bytes and reduces peak memory usage during model loading and initialization.
|
||||||
|
/// Requires `session.use_ort_model_bytes_directly` to be true.
|
||||||
|
/// If set, the flatbuffer bytes provided when creating the InferenceSession MUST remain valid for the entire
|
||||||
|
/// duration of the InferenceSession.
|
||||||
|
/// </summary>
|
||||||
|
static const char* const kOrtSessionOptionsConfigUseORTModelBytesForInitializers =
|
||||||
|
"session.use_ort_model_bytes_for_initializers";
|
||||||
|
|
||||||
|
// This should only be specified when exporting an ORT format model for use on a different platform.
|
||||||
|
// If the ORT format model will be used on ARM platforms set to "1". For other platforms set to "0"
|
||||||
|
// Available since version 1.11.
|
||||||
|
static const char* const kOrtSessionOptionsQDQIsInt8Allowed = "session.qdqisint8allowed";
|
||||||
|
|
||||||
|
// x64 SSE4.1/AVX2/AVX512(with no VNNI) has overflow problem with quantizied matrix multiplication with U8S8.
|
||||||
|
// To avoid this we need to use slower U8U8 matrix multiplication instead. This option, if
|
||||||
|
// turned on, use slower U8U8 matrix multiplications. Only effective with AVX2 or AVX512
|
||||||
|
// platforms.
|
||||||
|
static const char* const kOrtSessionOptionsAvx2PrecisionMode = "session.x64quantprecision";
|
||||||
|
|
||||||
|
// Specifies how minimal build graph optimizations are handled in a full build.
|
||||||
|
// These optimizations are at the extended level or higher.
|
||||||
|
// Possible values and their effects are:
|
||||||
|
// "save": Save runtime optimizations when saving an ORT format model.
|
||||||
|
// "apply": Only apply optimizations available in a minimal build.
|
||||||
|
// ""/<unspecified>: Apply optimizations available in a full build.
|
||||||
|
// Available since version 1.11.
|
||||||
|
static const char* const kOrtSessionOptionsConfigMinimalBuildOptimizations =
|
||||||
|
"optimization.minimal_build_optimizations";
|
||||||
|
|
||||||
|
// Note: The options specific to an EP should be specified prior to appending that EP to the session options object in
|
||||||
|
// order for them to take effect.
|
||||||
|
|
||||||
|
// Specifies a list of stop op types. Nodes of a type in the stop op types and nodes downstream from them will not be
|
||||||
|
// run by the NNAPI EP.
|
||||||
|
// The value should be a ","-delimited list of op types. For example, "Add,Sub".
|
||||||
|
// If not specified, the default set of stop ops is used. To specify an empty stop ops types list and disable stop op
|
||||||
|
// exclusion, set the value to "".
|
||||||
|
static const char* const kOrtSessionOptionsConfigNnapiEpPartitioningStopOps = "ep.nnapi.partitioning_stop_ops";
|
||||||
|
|
||||||
|
// Enabling dynamic block-sizing for multithreading.
|
||||||
|
// With a positive value, thread pool will split a task of N iterations to blocks of size starting from:
|
||||||
|
// N / (num_of_threads * dynamic_block_base)
|
||||||
|
// As execution progresses, the size will decrease according to the diminishing residual of N,
|
||||||
|
// meaning the task will be distributed in smaller granularity for better parallelism.
|
||||||
|
// For some models, it helps to reduce the variance of E2E inference latency and boost performance.
|
||||||
|
// The feature will not function by default, specify any positive integer, e.g. "4", to enable it.
|
||||||
|
// Available since version 1.11.
|
||||||
|
static const char* const kOrtSessionOptionsConfigDynamicBlockBase = "session.dynamic_block_base";
|
||||||
|
|
||||||
|
// This option allows to decrease CPU usage between infrequent
|
||||||
|
// requests and forces any TP threads spinning stop immediately when the last of
|
||||||
|
// concurrent Run() call returns.
|
||||||
|
// Spinning is restarted on the next Run() call.
|
||||||
|
// Applies only to internal thread-pools
|
||||||
|
static const char* const kOrtSessionOptionsConfigForceSpinningStop = "session.force_spinning_stop";
|
||||||
|
|
||||||
|
// "1": all inconsistencies encountered during shape and type inference
|
||||||
|
// will result in failures.
|
||||||
|
// "0": in some cases warnings will be logged but processing will continue. The default.
|
||||||
|
// May be useful to expose bugs in models.
|
||||||
|
static const char* const kOrtSessionOptionsConfigStrictShapeTypeInference = "session.strict_shape_type_inference";
|
||||||
|
|
||||||
|
// "1": every model using a more recent opset than the latest released one will fail
|
||||||
|
// "0": the model may or may not work if onnxruntime cannot find an implementation, this option
|
||||||
|
// is used for development purpose.
|
||||||
|
static const char* const kOrtSessionOptionsConfigStrictAllowReleasedOpsetsOnly = "session.allow_released_opsets_only";
|
||||||
|
|
||||||
|
// The file saves configuration for partitioning node among logic streams
|
||||||
|
static const char* const kNodePartitionConfigFile = "session.node_partition_config_file";
|
||||||
|
|
||||||
|
// This Option allows setting affinities for intra op threads.
|
||||||
|
// Affinity string follows format:
|
||||||
|
// logical_processor_id,logical_processor_id;logical_processor_id,logical_processor_id
|
||||||
|
// Semicolon isolates configurations among threads, while comma split processors where ith thread expected to attach to.
|
||||||
|
// e.g.1,2,3;4,5
|
||||||
|
// specifies affinities for two threads, with the 1st thread attach to the 1st, 2nd, and 3rd processor, and 2nd thread to the 4th and 5th.
|
||||||
|
// To ease the configuration, an "interval" is also allowed:
|
||||||
|
// e.g. 1-8;8-16;17-24
|
||||||
|
// orders that the 1st thread runs on first eight processors, 2nd thread runs on next eight processors, and so forth.
|
||||||
|
// Note:
|
||||||
|
// 1. Once set, the number of thread affinities must equal to intra_op_num_threads - 1, since ort does not set affinity on the main thread which
|
||||||
|
// is started and managed by the calling app;
|
||||||
|
// 2. For windows, ort will infer the group id from a logical processor id, for example, assuming there are two groups with each has 64 logical processors,
|
||||||
|
// an id of 64 will be inferred as the last processor of the 1st group, while 65 will be interpreted as the 1st processor of the second group.
|
||||||
|
// Hence 64-65 is an invalid configuration, because a windows thread cannot be attached to processors across group boundary.
|
||||||
|
static const char* const kOrtSessionOptionsConfigIntraOpThreadAffinities = "session.intra_op_thread_affinities";
|
||||||
|
|
||||||
|
// This option will dump out the model to assist debugging any issues with layout transformation,
|
||||||
|
// and is primarily intended for developer usage. It is only relevant if an execution provider that requests
|
||||||
|
// NHWC layout is enabled such as NNAPI, XNNPACK or QNN.
|
||||||
|
//
|
||||||
|
// Default is off. Set to "1" to enable.
|
||||||
|
//
|
||||||
|
// If modified by layout transformation the model will be dumped after these steps:
|
||||||
|
// 1) insertion of the layout transformation Transpose nodes
|
||||||
|
// 2) after those are optimized using the transpose optimizer,
|
||||||
|
// 3) after the L1 transformers are applied to the updated graph.
|
||||||
|
// The model will be saved to filename post_layout_transform_step_<step_number>.onnx.
|
||||||
|
static const char* const kDebugLayoutTransformation = "session.debug_layout_transformation";
|
||||||
|
|
||||||
|
// Graph nodes that are not supported by the execution providers (EPs) explicitly added to the session are
|
||||||
|
// assigned (i.e., "fallback") to the CPU EP by default.
|
||||||
|
//
|
||||||
|
// This option allows the user to disable the fallback of unsupported graph nodes to the CPU EP.
|
||||||
|
// If this option is set to "1", session creation will fail if the execution providers other than the CPU EP cannot
|
||||||
|
// fully support all of the nodes in the graph.
|
||||||
|
//
|
||||||
|
// It is invalid to set this option and explicitly add the CPU EP to the session. In this case, session creation
|
||||||
|
// will also fail with an error.
|
||||||
|
//
|
||||||
|
// Option values:
|
||||||
|
// - "0": CPU EP fallback is not disabled. [DEFAULT]
|
||||||
|
// - "1": CPU EP fallback is disabled.
|
||||||
|
static const char* const kOrtSessionOptionsDisableCPUEPFallback = "session.disable_cpu_ep_fallback";
|
||||||
|
|
||||||
|
// Use this config when serializing a large model after optimization to specify an external initializers file
|
||||||
|
static const char* const kOrtSessionOptionsOptimizedModelExternalInitializersFileName =
|
||||||
|
"session.optimized_model_external_initializers_file_name";
|
||||||
|
|
||||||
|
// Use this config to control the minimum size of the initializer when externalizing it during serialization
|
||||||
|
static const char* const kOrtSessionOptionsOptimizedModelExternalInitializersMinSizeInBytes =
|
||||||
|
"session.optimized_model_external_initializers_min_size_in_bytes";
|
||||||
|
|
||||||
|
// Enable EP context feature to dump the partitioned graph which includes the EP context into Onnx file.
|
||||||
|
// The dumped Onnx model with EP context can be used for future inference to avoid the EP graph partitioning/compile overhead.
|
||||||
|
// "0": disable. (default)
|
||||||
|
// "1": enable.
|
||||||
|
static const char* const kOrtSessionOptionEpContextEnable = "ep.context_enable";
|
||||||
|
|
||||||
|
// Specify the file path for the Onnx model which has EP context.
|
||||||
|
// Default to original_file_name_ctx.onnx if not specified
|
||||||
|
static const char* const kOrtSessionOptionEpContextFilePath = "ep.context_file_path";
|
||||||
|
|
||||||
|
// Flag to specify whether to dump the EP context into the Onnx model.
|
||||||
|
// "0": dump the EP context into separate file, keep the file name in the Onnx model.
|
||||||
|
// "1": dump the EP context into the Onnx model. (default).
|
||||||
|
static const char* const kOrtSessionOptionEpContextEmbedMode = "ep.context_embed_mode";
|
||||||
|
|
||||||
|
// Gemm fastmath mode provides fp32 gemm acceleration with bfloat16 based matmul.
|
||||||
|
// Option values:
|
||||||
|
// - "0": Gemm FastMath mode is not enabled. [DEFAULT]
|
||||||
|
// - "1": Gemm FastMath mode is enabled.
|
||||||
|
static const char* const kOrtSessionOptionsMlasGemmFastMathArm64Bfloat16 = "mlas.enable_gemm_fastmath_arm64_bfloat16";
|
@ -0,0 +1,731 @@
|
|||||||
|
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||||
|
// Licensed under the MIT License.
|
||||||
|
|
||||||
|
// This file contains the training c apis.
|
||||||
|
|
||||||
|
#pragma once
|
||||||
|
#include <stdbool.h>
|
||||||
|
#include "onnxruntime_c_api.h"
|
||||||
|
|
||||||
|
/** \page training_c_cpp_api Training C & C++ APIs
|
||||||
|
*
|
||||||
|
* Training C and C++ APIs are an extension of the \ref c_cpp_api "onnxruntime core C and C++ APIs" and should be used in conjunction with them.
|
||||||
|
*
|
||||||
|
* In order to train a model with onnxruntime, the following training artifacts must be generated:
|
||||||
|
* - The training onnx model
|
||||||
|
* - The checkpoint file
|
||||||
|
* - The optimizer onnx model
|
||||||
|
* - The eval onnx model model (optional)
|
||||||
|
*
|
||||||
|
* These training artifacts can be generated as part of an offline step using the python [utilities](https://github.com/microsoft/onnxruntime/blob/main/orttraining/orttraining/python/training/onnxblock/README.md) made available in the `onnxruntime-training` python package.
|
||||||
|
*
|
||||||
|
* After these artifacts have been generated, the C and C++ utilities listed in this documentation can be leveraged to perform training.
|
||||||
|
*
|
||||||
|
* If any problem is encountered, please create an [issue](https://github.com/microsoft/onnxruntime/issues/new) with your scenario and requirements, and we will be sure to respond and follow up on the request.
|
||||||
|
*
|
||||||
|
* <h1>Training C API</h1>
|
||||||
|
*
|
||||||
|
* ::OrtTrainingApi - Training C API functions.
|
||||||
|
*
|
||||||
|
* This C structure contains functions that enable users to perform training with onnxruntime.
|
||||||
|
*
|
||||||
|
* _Sample Code_:
|
||||||
|
*
|
||||||
|
* ```c
|
||||||
|
* #include <onnxruntime_training_api.h>
|
||||||
|
*
|
||||||
|
* OrtApi* g_ort_api = OrtGetApiBase()->GetApi(ORT_API_VERSION);
|
||||||
|
* OrtTrainingApi* g_ort_training_api = g_ort_api->GetTrainingApi(ORT_API_VERSION);
|
||||||
|
*
|
||||||
|
* OrtEnv* env = NULL;
|
||||||
|
* g_ort_api->CreateEnv(logging_level, logid, &env);
|
||||||
|
* OrtSessionOptions* session_options = NULL;
|
||||||
|
* g_ort_api->CreateSessionOptions(&session_options);
|
||||||
|
*
|
||||||
|
* OrtCheckpointState* state = NULL;
|
||||||
|
* g_ort_training_api->LoadCheckpoint(path_to_checkpoint, &state);
|
||||||
|
*
|
||||||
|
* OrtTrainingSession* training_session = NULL;
|
||||||
|
* g_ort_training_api->CreateTrainingSession(env, session_options, training_model_path,
|
||||||
|
* state, eval_model_path, optimizer_model_path,
|
||||||
|
* &training_session);
|
||||||
|
* // Training loop
|
||||||
|
* {
|
||||||
|
* g_ort_training_api->TrainStep(...);
|
||||||
|
* g_ort_training_api->OptimizerStep(...);
|
||||||
|
* g_ort_training_api->LazyResetGrad(...);
|
||||||
|
* }
|
||||||
|
*
|
||||||
|
* g_ort_training_api->ExportModelForInferencing(training_session, inference_model_path, ...);
|
||||||
|
* g_ort_training_api->SaveCheckpoint(state, path_to_checkpoint, false);
|
||||||
|
*
|
||||||
|
* g_ort_training_api->ReleaseTrainingSession(training_session);
|
||||||
|
* g_ort_training_api->ReleaseCheckpointState(state);
|
||||||
|
* ```
|
||||||
|
*
|
||||||
|
* > **Note**
|
||||||
|
* > The ::OrtCheckpointState contains the entire training state that the ::OrtTrainingSession uses. As a result, the training session must always have access to the state. That is to say, the ::OrtCheckpointState instance must outlive the lifetime of the ::OrtTrainingSession instance.
|
||||||
|
*
|
||||||
|
* <h1>Training C++ API</h1>
|
||||||
|
*
|
||||||
|
* @ref TrainingCpp - Training C++ API classes and functions.
|
||||||
|
*
|
||||||
|
* These C++ classes and functions enable users to perform training with onnxruntime.
|
||||||
|
*
|
||||||
|
* _Sample Code_:
|
||||||
|
*
|
||||||
|
* ```cc
|
||||||
|
* #include <onnxruntime_training_cxx_api.h>
|
||||||
|
*
|
||||||
|
* Ort::Env env;
|
||||||
|
* Ort::SessionOptions session_options;
|
||||||
|
*
|
||||||
|
* auto state = Ort::CheckpointState::LoadCheckpoint(path_to_checkpoint);
|
||||||
|
* auto training_session = Ort::TrainingSession(env, session_options, state, training_model_path,
|
||||||
|
* eval_model_path, optimizer_model_path);
|
||||||
|
*
|
||||||
|
* // Training Loop
|
||||||
|
* {
|
||||||
|
* training_session.TrainStep(...);
|
||||||
|
* training_session.OptimizerStep(...);
|
||||||
|
* training_session.LazyResetGrad(...);
|
||||||
|
* }
|
||||||
|
*
|
||||||
|
* training_session->ExportModelForInferencing(inference_model_path, ...);
|
||||||
|
* Ort::CheckpointState::SaveCheckpoint(state, path_to_checkpoint, false);
|
||||||
|
* ```
|
||||||
|
* > **Note**
|
||||||
|
* > The ::Ort::CheckpointState contains the entire training state that the ::Ort::TrainingSession uses. As a result, the training session must always have access to the state. That is to say, the ::Ort::CheckpointState instance must outlive the lifetime of the ::Ort::TrainingSession instance.
|
||||||
|
*/
|
||||||
|
|
||||||
|
/** @defgroup TrainingC Ort Training C API
|
||||||
|
* @{
|
||||||
|
*/
|
||||||
|
ORT_RUNTIME_CLASS(TrainingSession); // Type that enables performing training for the given user models.
|
||||||
|
ORT_RUNTIME_CLASS(CheckpointState); // Type that holds the training states for the training session.
|
||||||
|
|
||||||
|
/** \brief Type of property to be added to or returned from the ::OrtCheckpointState.
|
||||||
|
*/
|
||||||
|
typedef enum OrtPropertyType {
|
||||||
|
OrtIntProperty = 0,
|
||||||
|
OrtFloatProperty = 1,
|
||||||
|
OrtStringProperty = 2,
|
||||||
|
} OrtPropertyType;
|
||||||
|
|
||||||
|
/** \brief The Training C API that holds onnxruntime training function pointers
|
||||||
|
*
|
||||||
|
* All the Training C API functions are defined inside this structure as pointers to functions.
|
||||||
|
* Call OrtApi::GetTrainingApi to get a pointer to this struct.
|
||||||
|
*
|
||||||
|
* \nosubgrouping
|
||||||
|
*/
|
||||||
|
struct OrtTrainingApi {
|
||||||
|
/// \name Accessing The Training Session State
|
||||||
|
/// @{
|
||||||
|
|
||||||
|
/** \brief Load a checkpoint state from a file on disk into checkpoint_state.
|
||||||
|
*
|
||||||
|
* This function will parse a checkpoint file, pull relevant data and load the training
|
||||||
|
* state into the checkpoint_state. This checkpoint state can then be used to create the
|
||||||
|
* training session by invoking OrtTrainingApi::CreateTrainingSession. By doing so, the training
|
||||||
|
* session will resume training from the given checkpoint state.
|
||||||
|
* \note Note that the training session created with a checkpoint state uses this state to store the entire
|
||||||
|
* training state (including model parameters, its gradients, the optimizer states and the properties).
|
||||||
|
* As a result, it is required that the checkpoint state outlive the lifetime of the training session.
|
||||||
|
* \note Note that the checkpoint file can be either the complete checkpoint or the nominal checkpoint.
|
||||||
|
*
|
||||||
|
* \param[in] checkpoint_path Path to the checkpoint file
|
||||||
|
* \param[out] checkpoint_state Checkpoint state that contains the states of the training session.
|
||||||
|
*
|
||||||
|
* \snippet{doc} snippets.dox OrtStatus Return Value
|
||||||
|
*
|
||||||
|
*/
|
||||||
|
ORT_API2_STATUS(LoadCheckpoint, _In_ const ORTCHAR_T* checkpoint_path,
|
||||||
|
_Outptr_ OrtCheckpointState** checkpoint_state);
|
||||||
|
|
||||||
|
/** \brief Save the given state to a checkpoint file on disk.
|
||||||
|
*
|
||||||
|
* This function serializes the provided checkpoint state to a file on disk.
|
||||||
|
* This checkpoint can later be loaded by invoking OrtTrainingApi::LoadCheckpoint to resume
|
||||||
|
* training from this snapshot of the state.
|
||||||
|
*
|
||||||
|
* \param[in] checkpoint_state The checkpoint state to save.
|
||||||
|
* \param[in] checkpoint_path Path to the checkpoint file.
|
||||||
|
* \param[in] include_optimizer_state Flag to indicate whether to save the optimizer state or not.
|
||||||
|
*
|
||||||
|
* \snippet{doc} snippets.dox OrtStatus Return Value
|
||||||
|
*
|
||||||
|
*/
|
||||||
|
ORT_API2_STATUS(SaveCheckpoint, _In_ OrtCheckpointState* checkpoint_state, _In_ const ORTCHAR_T* checkpoint_path,
|
||||||
|
const bool include_optimizer_state);
|
||||||
|
|
||||||
|
/// @}
|
||||||
|
|
||||||
|
/// \name Implementing The Training Loop
|
||||||
|
/// @{
|
||||||
|
/** \brief Create a training session that can be used to begin or resume training.
|
||||||
|
*
|
||||||
|
* This function creates a training session based on the env and session options provided that can
|
||||||
|
* begin or resume training from a given checkpoint state for the given onnx models.
|
||||||
|
* The checkpoint state represents the parameters of the training session which will be moved
|
||||||
|
* to the device specified by the user through the session options (if necessary).
|
||||||
|
* The training session requires four training artifacts
|
||||||
|
* - The training onnx model
|
||||||
|
* - The evaluation onnx model (optional)
|
||||||
|
* - The optimizer onnx model
|
||||||
|
* - The checkpoint file
|
||||||
|
*
|
||||||
|
* These artifacts can be generated using the `onnxruntime-training` python [utility](https://github.com/microsoft/onnxruntime/blob/main/orttraining/orttraining/python/training/onnxblock/README.md).
|
||||||
|
*
|
||||||
|
* \param[in] env Environment to be used for the training session.
|
||||||
|
* \param[in] options Session options that the user can customize for this training session.
|
||||||
|
* \param[in] checkpoint_state Training states that the training session uses as a starting point for training.
|
||||||
|
* \param[in] train_model_path Model to be used to perform training.
|
||||||
|
* \param[in] eval_model_path Model to be used to perform evaluation.
|
||||||
|
* \param[in] optimizer_model_path Model to be used to perform gradient descent.
|
||||||
|
* \param[out] out Created training session.
|
||||||
|
*
|
||||||
|
* \snippet{doc} snippets.dox OrtStatus Return Value
|
||||||
|
*
|
||||||
|
*/
|
||||||
|
ORT_API2_STATUS(CreateTrainingSession, _In_ const OrtEnv* env, _In_ const OrtSessionOptions* options,
|
||||||
|
_Inout_ OrtCheckpointState* checkpoint_state, _In_ const ORTCHAR_T* train_model_path,
|
||||||
|
_In_ const ORTCHAR_T* eval_model_path, _In_ const ORTCHAR_T* optimizer_model_path,
|
||||||
|
_Outptr_result_maybenull_ OrtTrainingSession** out);
|
||||||
|
|
||||||
|
/** \brief Create a training session that can be used to begin or resume training.
|
||||||
|
* This api provides a way to load all the training artifacts from buffers instead of files.
|
||||||
|
*
|
||||||
|
* \param[in] env Environment to be used for the training session.
|
||||||
|
* \param[in] options Session options that the user can customize for this training session.
|
||||||
|
* \param[in] checkpoint_state Training states that the training session uses as a starting point for training.
|
||||||
|
* \param[in] train_model_data Buffer containing the model data to be used to perform training
|
||||||
|
* \param[in] train_data_length Length of the buffer containing train_model_data
|
||||||
|
* \param[in] eval_model_data Buffer containing the model data to be used to perform evaluation
|
||||||
|
* \param[in] eval_data_length Length of the buffer containing eval_model_data
|
||||||
|
* \param[in] optim_model_data Buffer containing the model data to be used to perform weight update
|
||||||
|
* \param[in] optim_data_length Length of the buffer containing optim_model_data
|
||||||
|
* \param[out] out Created training session.
|
||||||
|
*
|
||||||
|
*/
|
||||||
|
ORT_API2_STATUS(CreateTrainingSessionFromBuffer, _In_ const OrtEnv* env,
|
||||||
|
_In_ const OrtSessionOptions* options, _Inout_ OrtCheckpointState* checkpoint_state,
|
||||||
|
_In_ const void* train_model_data, size_t train_data_length,
|
||||||
|
_In_ const void* eval_model_data, size_t eval_data_length,
|
||||||
|
_In_ const void* optim_model_data, size_t optim_data_length,
|
||||||
|
_Outptr_result_maybenull_ OrtTrainingSession** out);
|
||||||
|
|
||||||
|
/// @}
|
||||||
|
|
||||||
|
/// \name Model IO Information
|
||||||
|
/// @{
|
||||||
|
|
||||||
|
/** \brief Retrieves the number of user outputs in the training model.
|
||||||
|
*
|
||||||
|
* This function returns the number of outputs of the training model so that the user can
|
||||||
|
* allocate space for the number of outputs when OrtTrainingApi::TrainStep is invoked.
|
||||||
|
*
|
||||||
|
* \param[in] sess The `this` pointer to the training session.
|
||||||
|
* \param[out] out Number of user outputs in the training model.
|
||||||
|
*
|
||||||
|
* \snippet{doc} snippets.dox OrtStatus Return Value
|
||||||
|
*
|
||||||
|
*/
|
||||||
|
ORT_API2_STATUS(TrainingSessionGetTrainingModelOutputCount, _In_ const OrtTrainingSession* sess, _Out_ size_t* out);
|
||||||
|
|
||||||
|
/** \brief Retrieves the number of user outputs in the eval model.
|
||||||
|
*
|
||||||
|
* This function returns the number of outputs of the eval model so that the user can
|
||||||
|
* allocate space for the number of outputs when OrtTrainingApi::EvalStep is invoked.
|
||||||
|
*
|
||||||
|
* \param[in] sess The `this` pointer to the training session.
|
||||||
|
* \param[out] out Number of user outputs in the eval model.
|
||||||
|
*
|
||||||
|
* \snippet{doc} snippets.dox OrtStatus Return Value
|
||||||
|
*
|
||||||
|
*/
|
||||||
|
ORT_API2_STATUS(TrainingSessionGetEvalModelOutputCount, _In_ const OrtTrainingSession* sess, _Out_ size_t* out);
|
||||||
|
|
||||||
|
/** \brief Retrieves the names of user outputs in the training model.
|
||||||
|
*
|
||||||
|
* This function returns the names of outputs of the training model that can be associated with the OrtValue(s)
|
||||||
|
* returned by the OrtTrainingApi::TrainStep function.
|
||||||
|
*
|
||||||
|
* \param[in] sess The `this` pointer to the training session.
|
||||||
|
* \param[in] index Index of the output name requested.
|
||||||
|
* \param[in] allocator Allocator to use to allocate the memory for the name.
|
||||||
|
* \param[out] output Name of the training model output at the given index.
|
||||||
|
*
|
||||||
|
* \snippet{doc} snippets.dox OrtStatus Return Value
|
||||||
|
*
|
||||||
|
*/
|
||||||
|
ORT_API2_STATUS(TrainingSessionGetTrainingModelOutputName, _In_ const OrtTrainingSession* sess, size_t index, _Inout_ OrtAllocator* allocator, _Outptr_ char** output);
|
||||||
|
|
||||||
|
/** \brief Retrieves the names of user outputs in the eval model.
|
||||||
|
*
|
||||||
|
* This function returns the names of outputs of the eval model that can be associated with the OrtValue(s) returned
|
||||||
|
* by the OrtTrainingApi::EvalStep function.
|
||||||
|
*
|
||||||
|
* \param[in] sess The `this` pointer to the training session.
|
||||||
|
* \param[in] index Index of the output name requested.
|
||||||
|
* \param[in] allocator Allocator to use to allocate the memory for the name.
|
||||||
|
* \param[out] output Name of the eval model output at the given index.
|
||||||
|
*
|
||||||
|
* \snippet{doc} snippets.dox OrtStatus Return Value
|
||||||
|
*
|
||||||
|
*/
|
||||||
|
ORT_API2_STATUS(TrainingSessionGetEvalModelOutputName, _In_ const OrtTrainingSession* sess, size_t index, _Inout_ OrtAllocator* allocator, _Outptr_ char** output);
|
||||||
|
|
||||||
|
/// @}
|
||||||
|
|
||||||
|
/// \name Implementing The Training Loop
|
||||||
|
/// @{
|
||||||
|
|
||||||
|
/** \brief Reset the gradients of all trainable parameters to zero lazily.
|
||||||
|
*
|
||||||
|
* This function sets the internal state of the training session such that the gradients of the trainable
|
||||||
|
* parameters in the OrtCheckpointState will be scheduled to be reset just before the new gradients are
|
||||||
|
* computed on the next invocation of the next OrtTrainingApi::TrainStep.
|
||||||
|
*
|
||||||
|
* \param[in] session The `this` pointer to the training session.
|
||||||
|
*
|
||||||
|
* \snippet{doc} snippets.dox OrtStatus Return Value
|
||||||
|
*
|
||||||
|
*/
|
||||||
|
ORT_API2_STATUS(LazyResetGrad, _Inout_ OrtTrainingSession* session);
|
||||||
|
|
||||||
|
/** \brief Computes the outputs of the training model and the gradients of the trainable parameters for the given inputs
|
||||||
|
*
|
||||||
|
* This function performs a training step that computes the outputs of the training model and the gradients
|
||||||
|
* of the trainable parameters for the given inputs. The train step is performed based on the training model
|
||||||
|
* that was provided to the training session.
|
||||||
|
* The OrtTrainingApi::TrainStep is equivalent of running forward propagation and backward propagation in a single
|
||||||
|
* step.
|
||||||
|
* The gradients computed are stored inside the training session state so they can be later consumed
|
||||||
|
* by the OrtTrainingApi::OptimizerStep function.
|
||||||
|
* The gradients can be lazily reset by invoking the OrtTrainingApi::LazyResetGrad function.
|
||||||
|
*
|
||||||
|
* \param[in] sess The `this` pointer to the training session.
|
||||||
|
* \param[in] run_options Run options for this training step.
|
||||||
|
* \param[in] inputs_len Number of user inputs to the training model.
|
||||||
|
* \param[in] inputs The user inputs to the training model.
|
||||||
|
* \param[in] outputs_len Number of user outputs expected from this training step.
|
||||||
|
* \param[out] outputs User outputs computed by train step.
|
||||||
|
*
|
||||||
|
* \snippet{doc} snippets.dox OrtStatus Return Value
|
||||||
|
*
|
||||||
|
*/
|
||||||
|
ORT_API2_STATUS(TrainStep, _Inout_ OrtTrainingSession* sess, _In_opt_ const OrtRunOptions* run_options,
|
||||||
|
_In_ size_t inputs_len, _In_reads_(inputs_len) const OrtValue* const* inputs,
|
||||||
|
_In_ size_t outputs_len, _Inout_updates_all_(outputs_len) OrtValue** outputs);
|
||||||
|
|
||||||
|
/** \brief Computes the outputs for the eval model for the given inputs
|
||||||
|
*
|
||||||
|
* This function performs an eval step that computes the outputs of the eval model for the given inputs.
|
||||||
|
* The eval step is performed based on the eval model that was provided to the training session.
|
||||||
|
*
|
||||||
|
* \param[in] sess The `this` pointer to the training session.
|
||||||
|
* \param[in] run_options Run options for this eval step.
|
||||||
|
* \param[in] inputs_len Number of user inputs to the eval model.
|
||||||
|
* \param[in] inputs The user inputs to the eval model.
|
||||||
|
* \param[in] outputs_len Number of user outputs expected from this eval step.
|
||||||
|
* \param[out] outputs User outputs computed by eval step.
|
||||||
|
*
|
||||||
|
* \snippet{doc} snippets.dox OrtStatus Return Value
|
||||||
|
*
|
||||||
|
*/
|
||||||
|
ORT_API2_STATUS(EvalStep, _In_ const OrtTrainingSession* sess, _In_opt_ const OrtRunOptions* run_options,
|
||||||
|
_In_ size_t inputs_len, _In_reads_(inputs_len) const OrtValue* const* inputs,
|
||||||
|
_In_ size_t outputs_len, _Inout_updates_all_(outputs_len) OrtValue** outputs);
|
||||||
|
|
||||||
|
/** \brief Sets the learning rate for this training session.
|
||||||
|
*
|
||||||
|
* This function allows users to set the learning rate for the training session. The current
|
||||||
|
* learning rate is maintained by the training session and can be overwritten by invoking
|
||||||
|
* this function with the desired learning rate. This function should not be used when a valid
|
||||||
|
* learning rate scheduler is registered. It should be used either to set the learning rate
|
||||||
|
* derived from a custom learning rate scheduler or to set a constant learning rate to be used
|
||||||
|
* throughout the training session.
|
||||||
|
* \note Please note that this function does not set the initial learning rate that may be needed
|
||||||
|
* by the predefined learning rate schedulers. To set the initial learning rate for learning
|
||||||
|
* rate schedulers, please look at the function OrtTrainingApi::RegisterLinearLRScheduler.
|
||||||
|
*
|
||||||
|
* \param[in] sess The `this` pointer to the training session.
|
||||||
|
* \param[in] learning_rate Desired learning rate to be set.
|
||||||
|
*
|
||||||
|
* \snippet{doc} snippets.dox OrtStatus Return Value
|
||||||
|
*
|
||||||
|
*/
|
||||||
|
ORT_API2_STATUS(SetLearningRate, _Inout_ OrtTrainingSession* sess, _In_ float learning_rate);
|
||||||
|
|
||||||
|
/** \brief Gets the current learning rate for this training session.
|
||||||
|
*
|
||||||
|
* This function allows users to get the learning rate for the training session. The current
|
||||||
|
* learning rate is maintained by the training session, and users can query it for the purpose
|
||||||
|
* of implementing their own learning rate schedulers.
|
||||||
|
*
|
||||||
|
* \param[in] sess The `this` pointer to the training session.
|
||||||
|
* \param[out] learning_rate Learning rate currently in use by the training session.
|
||||||
|
*
|
||||||
|
* \snippet{doc} snippets.dox OrtStatus Return Value
|
||||||
|
*
|
||||||
|
*/
|
||||||
|
ORT_API2_STATUS(GetLearningRate, _Inout_ OrtTrainingSession* sess, _Out_ float* learning_rate);
|
||||||
|
|
||||||
|
/** \brief Performs the weight updates for the trainable parameters using the optimizer model.
|
||||||
|
*
|
||||||
|
* This function performs the weight update step that updates the trainable parameters such that they
|
||||||
|
* take a step in the direction of their gradients (gradient descent). The optimizer step is performed
|
||||||
|
* based on the optimizer model that was provided to the training session.
|
||||||
|
* The updated parameters are stored inside the training state so that they can be used by the next
|
||||||
|
* OrtTrainingApi::TrainStep function call.
|
||||||
|
*
|
||||||
|
* \param[in] sess The `this` pointer to the training session.
|
||||||
|
* \param[in] run_options Run options for this optimizer step.
|
||||||
|
*
|
||||||
|
* \snippet{doc} snippets.dox OrtStatus Return Value
|
||||||
|
*
|
||||||
|
*/
|
||||||
|
ORT_API2_STATUS(OptimizerStep, _Inout_ OrtTrainingSession* sess,
|
||||||
|
_In_opt_ const OrtRunOptions* run_options);
|
||||||
|
|
||||||
|
/** \brief Registers a linear learning rate scheduler for the training session.
|
||||||
|
*
|
||||||
|
* Register a linear learning rate scheduler that decays the learning rate by linearly updated
|
||||||
|
* multiplicative factor from the initial learning rate set on the training session to 0. The decay
|
||||||
|
* is performed after the initial warm up phase where the learning rate is linearly incremented
|
||||||
|
* from 0 to the initial learning rate provided.
|
||||||
|
*
|
||||||
|
* \param[in] sess The `this` pointer to the training session.
|
||||||
|
* \param[in] warmup_step_count Warmup steps for LR warmup.
|
||||||
|
* \param[in] total_step_count Total step count.
|
||||||
|
* \param[in] initial_lr The initial learning rate to be used by the training session.
|
||||||
|
*
|
||||||
|
* \snippet{doc} snippets.dox OrtStatus Return Value
|
||||||
|
*
|
||||||
|
*/
|
||||||
|
ORT_API2_STATUS(RegisterLinearLRScheduler, _Inout_ OrtTrainingSession* sess, _In_ const int64_t warmup_step_count,
|
||||||
|
_In_ const int64_t total_step_count, _In_ const float initial_lr);
|
||||||
|
|
||||||
|
/** \brief Update the learning rate based on the registered learing rate scheduler.
|
||||||
|
*
|
||||||
|
* Takes a scheduler step that updates the learning rate that is being used by the training session.
|
||||||
|
* This function should typically be called before invoking the optimizer step for each round,
|
||||||
|
* or as determined necessary to update the learning rate being used by the training session.
|
||||||
|
* \note Please note that a valid predefined learning rate scheduler must be first registered to invoke this
|
||||||
|
* function.
|
||||||
|
*
|
||||||
|
* \param[in] sess The `this` pointer to the training session.
|
||||||
|
*
|
||||||
|
* \snippet{doc} snippets.dox OrtStatus Return Value
|
||||||
|
*
|
||||||
|
*/
|
||||||
|
ORT_API2_STATUS(SchedulerStep, _Inout_ OrtTrainingSession* sess);
|
||||||
|
|
||||||
|
/// @}
|
||||||
|
|
||||||
|
/// \name Accessing The Training Session State
|
||||||
|
/// @{
|
||||||
|
/** \brief Retrieves the size of all the parameters.
|
||||||
|
*
|
||||||
|
* Calculates the total number of primitive (datatype of the parameters) elements of all the parameters in the
|
||||||
|
* training state.
|
||||||
|
* When trainable_only argument is true, the size is calculated for trainable params only.
|
||||||
|
*
|
||||||
|
* \param[in] sess The `this` pointer to the training session.
|
||||||
|
* \param[out] out Size of all parameter elements.
|
||||||
|
* \param[in] trainable_only Whether to skip non-trainable parameters
|
||||||
|
*
|
||||||
|
* \snippet{doc} snippets.dox OrtStatus Return Value
|
||||||
|
*
|
||||||
|
*/
|
||||||
|
ORT_API2_STATUS(GetParametersSize, _Inout_ OrtTrainingSession* sess, _Out_ size_t* out, bool trainable_only);
|
||||||
|
|
||||||
|
/** \brief Copy all parameters to a contiguous buffer held by the argument parameters_buffer
|
||||||
|
*
|
||||||
|
* The parameters_buffer has to be of the size given by GetParametersSize api call,
|
||||||
|
* with matching setting for the argument trainable_only. All the target parameters must be of the same
|
||||||
|
* datatype. The OrtValue must be pre-allocated onto
|
||||||
|
* the desired device. This is a complementary function to OrtTrainingApi::CopyBufferToParameters.
|
||||||
|
* Parameter ordering is preserved.
|
||||||
|
* User is responsible for allocating and freeing the resources used by the parameters_buffer.
|
||||||
|
*
|
||||||
|
* \param[in] sess The `this` pointer to the training session.
|
||||||
|
* \param[in] trainable_only Whether to skip non-trainable parameters
|
||||||
|
* \param[out] parameters_buffer The pre-allocated OrtValue buffer to copy onto.
|
||||||
|
*
|
||||||
|
* \snippet{doc} snippets.dox OrtStatus Return Value
|
||||||
|
*
|
||||||
|
*/
|
||||||
|
ORT_API2_STATUS(CopyParametersToBuffer, _Inout_ OrtTrainingSession* sess,
|
||||||
|
_Inout_ OrtValue* parameters_buffer, bool trainable_only);
|
||||||
|
|
||||||
|
/** \brief Copy parameter values from the given contiguous buffer held by parameters_buffer to the training state
|
||||||
|
*
|
||||||
|
* The parameters_buffer argument has to be of the size given by OrtTrainingApi::GetParametersSize api call,
|
||||||
|
* with matching setting for trainable_only argument. All the target parameters must be of the same
|
||||||
|
* datatype. This is a complementary function to OrtTrainingApi::CopyParametersToBuffer
|
||||||
|
* and can be used to load updated buffer values onto the training state.
|
||||||
|
* Parameter ordering is preserved.
|
||||||
|
* User is responsible for allocating and freeing the resources used by the parameters_buffer.
|
||||||
|
* In case the training session was created with a nominal checkpoint, invoking this function is required
|
||||||
|
* to load the updated parameters onto the checkpoint to complete it.
|
||||||
|
*
|
||||||
|
* \param[in] sess The `this` pointer to the training session.
|
||||||
|
* \param[in] trainable_only Whether to skip non-trainable parameters
|
||||||
|
* \param[out] parameters_buffer The pre-allocated OrtValue buffer to copy from.
|
||||||
|
*
|
||||||
|
* \snippet{doc} snippets.dox OrtStatus Return Value
|
||||||
|
*
|
||||||
|
*/
|
||||||
|
ORT_API2_STATUS(CopyBufferToParameters, _Inout_ OrtTrainingSession* sess,
|
||||||
|
_Inout_ OrtValue* parameters_buffer, bool trainable_only);
|
||||||
|
|
||||||
|
/// @}
|
||||||
|
|
||||||
|
/// \name Release Training Resources
|
||||||
|
/// @{
|
||||||
|
|
||||||
|
/** \brief Frees up the memory used up by the training session.
|
||||||
|
*
|
||||||
|
* This function frees up any memory that was allocated in the training session. The training
|
||||||
|
* session can no longer be used after this call.
|
||||||
|
*
|
||||||
|
*/
|
||||||
|
ORT_CLASS_RELEASE(TrainingSession);
|
||||||
|
|
||||||
|
/** \brief Frees up the memory used up by the checkpoint state.
|
||||||
|
*
|
||||||
|
* This function frees up any memory that was allocated in the checkpoint state. The checkpoint
|
||||||
|
* state can no longer be used after this call.
|
||||||
|
* \note Note that the checkpoint state must be released only after the training session has been released.
|
||||||
|
*
|
||||||
|
*/
|
||||||
|
ORT_CLASS_RELEASE(CheckpointState);
|
||||||
|
|
||||||
|
/// @}
|
||||||
|
|
||||||
|
/// \name Prepare For Inferencing
|
||||||
|
/// @{
|
||||||
|
/** \brief Export a model that can be used for inferencing.
|
||||||
|
*
|
||||||
|
* If the training session was provided with an eval model, the training session can generate
|
||||||
|
* an inference model if it knows the inference graph outputs. The input inference graph outputs
|
||||||
|
* are used to prune the eval model so that the inference model's outputs align with the provided outputs.
|
||||||
|
* The exported model is saved at the path provided and can be used for inferencing with InferenceSession.
|
||||||
|
* \note Note that the function re-loads the eval model from the path provided to OrtTrainingApi::CreateTrainingSession
|
||||||
|
* and expects that this path still be valid.
|
||||||
|
*
|
||||||
|
* \param[in] sess The `this` pointer to the training session.
|
||||||
|
* \param[in] inference_model_path Path where the inference model should be serialized to.
|
||||||
|
* \param[in] graph_outputs_len Size of the graph output names array.
|
||||||
|
* \param[in] graph_output_names Names of the outputs that are needed in the inference model.
|
||||||
|
*
|
||||||
|
* \snippet{doc} snippets.dox OrtStatus Return Value
|
||||||
|
*
|
||||||
|
*/
|
||||||
|
ORT_API2_STATUS(ExportModelForInferencing, _Inout_ OrtTrainingSession* sess,
|
||||||
|
_In_ const ORTCHAR_T* inference_model_path, size_t graph_outputs_len,
|
||||||
|
_In_reads_(graph_outputs_len) const char* const* graph_output_names);
|
||||||
|
|
||||||
|
/// @}
|
||||||
|
|
||||||
|
/// \name Training Utilities
|
||||||
|
/// @{
|
||||||
|
/** \brief Sets the seed used for random number generation in Onnxruntime.
|
||||||
|
*
|
||||||
|
* Use this function to generate reproducible results. It should be noted that completely reproducible
|
||||||
|
* results are not guaranteed.
|
||||||
|
*
|
||||||
|
* \param[in] seed The seed to be set.
|
||||||
|
*
|
||||||
|
* \snippet{doc} snippets.dox OrtStatus Return Value
|
||||||
|
*
|
||||||
|
*/
|
||||||
|
ORT_API2_STATUS(SetSeed, _In_ const int64_t seed);
|
||||||
|
|
||||||
|
/// @}
|
||||||
|
|
||||||
|
/// \name Model IO Information
|
||||||
|
/// @{
|
||||||
|
/** \brief Retrieves the number of user inputs in the training model.
|
||||||
|
*
|
||||||
|
* This function returns the number of inputs of the training model so that the user can accordingly
|
||||||
|
* allocate the OrtValue(s) provided to the OrtTrainingApi::TrainStep function.
|
||||||
|
*
|
||||||
|
* \param[in] sess The `this` pointer to the training session.
|
||||||
|
* \param[out] out Number of user inputs in the training model.
|
||||||
|
*
|
||||||
|
* \snippet{doc} snippets.dox OrtStatus Return Value
|
||||||
|
*
|
||||||
|
*/
|
||||||
|
ORT_API2_STATUS(TrainingSessionGetTrainingModelInputCount, _In_ const OrtTrainingSession* sess, _Out_ size_t* out);
|
||||||
|
|
||||||
|
/** \brief Retrieves the number of user inputs in the eval model.
|
||||||
|
*
|
||||||
|
* This function returns the number of inputs of the eval model so that the user can accordingly
|
||||||
|
* allocate the OrtValue(s) provided to the OrtTrainingApi::EvalStep function.
|
||||||
|
*
|
||||||
|
* \param[in] sess The `this` pointer to the training session.
|
||||||
|
* \param[out] out Number of user inputs in the eval model.
|
||||||
|
*
|
||||||
|
* \snippet{doc} snippets.dox OrtStatus Return Value
|
||||||
|
*
|
||||||
|
*/
|
||||||
|
ORT_API2_STATUS(TrainingSessionGetEvalModelInputCount, _In_ const OrtTrainingSession* sess, _Out_ size_t* out);
|
||||||
|
|
||||||
|
/** \brief Retrieves the name of the user input at given index in the training model.
|
||||||
|
*
|
||||||
|
* This function returns the names of inputs of the training model that can be associated with the
|
||||||
|
* OrtValue(s) provided to the OrtTrainingApi::TrainStep function.
|
||||||
|
*
|
||||||
|
* \param[in] sess The `this` pointer to the training session.
|
||||||
|
* \param[in] index The index of the training model input name requested.
|
||||||
|
* \param[in] allocator The allocator to use to allocate the memory for the requested name.
|
||||||
|
* \param[out] output Name of the user input for the training model at the given index.
|
||||||
|
*
|
||||||
|
* \snippet{doc} snippets.dox OrtStatus Return Value
|
||||||
|
*
|
||||||
|
*/
|
||||||
|
ORT_API2_STATUS(TrainingSessionGetTrainingModelInputName, _In_ const OrtTrainingSession* sess, size_t index,
|
||||||
|
_In_ OrtAllocator* allocator, _Outptr_ char** output);
|
||||||
|
|
||||||
|
/** \brief Retrieves the name of the user input at given index in the eval model.
|
||||||
|
*
|
||||||
|
* This function returns the names of inputs of the eval model that can be associated with the OrtValue(s) provided
|
||||||
|
* to the OrtTrainingApi::EvalStep function.
|
||||||
|
*
|
||||||
|
* \param[in] sess The `this` pointer to the training session.
|
||||||
|
* \param[in] index The index of the eval model input name requested.
|
||||||
|
* \param[in] allocator The allocator to use to allocate the memory for the requested name.
|
||||||
|
* \param[out] output Name of the user input for the eval model at the given index.
|
||||||
|
*
|
||||||
|
* \snippet{doc} snippets.dox OrtStatus Return Value
|
||||||
|
*
|
||||||
|
*/
|
||||||
|
ORT_API2_STATUS(TrainingSessionGetEvalModelInputName, _In_ const OrtTrainingSession* sess, size_t index,
|
||||||
|
_In_ OrtAllocator* allocator, _Outptr_ char** output);
|
||||||
|
|
||||||
|
/// @}
|
||||||
|
|
||||||
|
/// \name Accessing The Training Session State
|
||||||
|
/// @{
|
||||||
|
|
||||||
|
/** \brief Adds or updates the given property to/in the checkpoint state.
|
||||||
|
*
|
||||||
|
* Runtime properties such as epoch, training step, best score, and others can be added to the checkpoint
|
||||||
|
* state by the user by calling this function with the corresponding property name and value.
|
||||||
|
* The given property name must be unique to be able to successfully add the property.
|
||||||
|
*
|
||||||
|
* \param[in] checkpoint_state The checkpoint state which should hold the property.
|
||||||
|
* \param[in] property_name Name of the property being added or updated.
|
||||||
|
* \param[in] property_type Type of the property associated with the given name.
|
||||||
|
* \param[in] property_value Property value associated with the given name.
|
||||||
|
*
|
||||||
|
* \snippet{doc} snippets.dox OrtStatus Return Value
|
||||||
|
*
|
||||||
|
*/
|
||||||
|
ORT_API2_STATUS(AddProperty, _Inout_ OrtCheckpointState* checkpoint_state,
|
||||||
|
_In_ const char* property_name, _In_ enum OrtPropertyType property_type,
|
||||||
|
_In_ void* property_value);
|
||||||
|
|
||||||
|
/** \brief Gets the property value associated with the given name from the checkpoint state.
|
||||||
|
*
|
||||||
|
* Gets the property value from an existing entry in the checkpoint state. The property must
|
||||||
|
* exist in the checkpoint state to be able to retrieve it successfully.
|
||||||
|
*
|
||||||
|
* \param[in] checkpoint_state The checkpoint state that is currently holding the property.
|
||||||
|
* \param[in] property_name Name of the property being retrieved.
|
||||||
|
* \param[in] allocator Allocator used to allocate the memory for the property_value.
|
||||||
|
* \param[out] property_type Type of the property associated with the given name.
|
||||||
|
* \param[out] property_value Property value associated with the given name.
|
||||||
|
*
|
||||||
|
* \snippet{doc} snippets.dox OrtStatus Return Value
|
||||||
|
*
|
||||||
|
*/
|
||||||
|
ORT_API2_STATUS(GetProperty, _In_ const OrtCheckpointState* checkpoint_state,
|
||||||
|
_In_ const char* property_name, _Inout_ OrtAllocator* allocator,
|
||||||
|
_Out_ enum OrtPropertyType* property_type, _Outptr_ void** property_value);
|
||||||
|
|
||||||
|
/// @}
|
||||||
|
|
||||||
|
/// \name Accessing The Training Session State
|
||||||
|
/// @{
|
||||||
|
|
||||||
|
/** \brief Load a checkpoint state from a buffer into checkpoint_state.
|
||||||
|
*
|
||||||
|
* This function will parse a checkpoint bytes buffer, pull relevant data and load the training
|
||||||
|
* state into the checkpoint_state. This checkpoint state can then be used to create the
|
||||||
|
* training session by invoking OrtTrainingApi::CreateTrainingSession. By doing so, the training
|
||||||
|
* session will resume training from the given checkpoint state.
|
||||||
|
* \note Note that the training session created with a checkpoint state uses this state to store the entire
|
||||||
|
* training state (including model parameters, its gradients, the optimizer states and the properties).
|
||||||
|
* As a result, it is required that the checkpoint state outlive the lifetime of the training session.
|
||||||
|
*
|
||||||
|
* \param[in] checkpoint_buffer Path to the checkpoint bytes buffer.
|
||||||
|
* \param[in] num_bytes Number of bytes in the checkpoint buffer.
|
||||||
|
* \param[out] checkpoint_state Checkpoint state that contains the states of the training session.
|
||||||
|
*
|
||||||
|
* \snippet{doc} snippets.dox OrtStatus Return Value
|
||||||
|
*
|
||||||
|
*/
|
||||||
|
ORT_API2_STATUS(LoadCheckpointFromBuffer, _In_ const void* checkpoint_buffer,
|
||||||
|
_In_ const size_t num_bytes, _Outptr_ OrtCheckpointState** checkpoint_state);
|
||||||
|
|
||||||
|
/** \brief Retrieves the type and shape information of the parameter associated with the given parameter name.
|
||||||
|
*
|
||||||
|
* This function retrieves the type and shape of the parameter associated with the given parameter name.
|
||||||
|
* The parameter must exist in the checkpoint state to be able to retrieve its type and shape information successfully.
|
||||||
|
*
|
||||||
|
* \param[in] checkpoint_state The checkpoint state.
|
||||||
|
* \param[in] parameter_name Name of the parameter being retrieved.
|
||||||
|
* \param[out] parameter_type_and_shape The type and shape of the parameter being retrieved.
|
||||||
|
*
|
||||||
|
* \snippet{doc} snippets.dox OrtStatus Return Value
|
||||||
|
*
|
||||||
|
*/
|
||||||
|
ORT_API2_STATUS(GetParameterTypeAndShape, _In_ const OrtCheckpointState* checkpoint_state,
|
||||||
|
_In_ const char* parameter_name, _Outptr_ OrtTensorTypeAndShapeInfo** parameter_type_and_shape);
|
||||||
|
|
||||||
|
/** \brief Updates the data associated with the model parameter in the checkpoint state for the given parameter name.
|
||||||
|
*
|
||||||
|
* This function updates a model parameter in the checkpoint state with the given parameter data.
|
||||||
|
* The training session must be already created with the checkpoint state that contains the parameter
|
||||||
|
* being updated. The given parameter is copied over to the registered device for the training session.
|
||||||
|
* The parameter must exist in the checkpoint state to be able to update it successfully.
|
||||||
|
*
|
||||||
|
* \param[in] checkpoint_state The checkpoint state.
|
||||||
|
* \param[in] parameter_name Name of the parameter being updated.
|
||||||
|
* \param[in] parameter The parameter data that should replace the existing parameter data.
|
||||||
|
*
|
||||||
|
* \snippet{doc} snippets.dox OrtStatus Return Value
|
||||||
|
*
|
||||||
|
*/
|
||||||
|
ORT_API2_STATUS(UpdateParameter, _Inout_ OrtCheckpointState* checkpoint_state,
|
||||||
|
_In_ const char* parameter_name, _In_ OrtValue* parameter);
|
||||||
|
|
||||||
|
/** \brief Gets the data associated with the model parameter from the checkpoint state for the given parameter name.
|
||||||
|
*
|
||||||
|
* This function retrieves the model parameter data from the checkpoint state for the given parameter name.
|
||||||
|
* The parameter is copied over and returned as an OrtValue. The training session must be already created
|
||||||
|
* with the checkpoint state that contains the parameter being retrieved.
|
||||||
|
* The parameter must exist in the checkpoint state to be able to retrieve it successfully.
|
||||||
|
*
|
||||||
|
* \param[in] checkpoint_state The checkpoint state.
|
||||||
|
* \param[in] parameter_name Name of the parameter being retrieved.
|
||||||
|
* \param[in] allocator Allocator used to allocate the memory for the parameter.
|
||||||
|
* \param[out] parameter The parameter data that is retrieved from the checkpoint state.
|
||||||
|
*
|
||||||
|
* \snippet{doc} snippets.dox OrtStatus Return Value
|
||||||
|
*
|
||||||
|
*/
|
||||||
|
ORT_API2_STATUS(GetParameter, _In_ const OrtCheckpointState* checkpoint_state,
|
||||||
|
_In_ const char* parameter_name, _Inout_ OrtAllocator* allocator,
|
||||||
|
_Outptr_ OrtValue** parameter);
|
||||||
|
|
||||||
|
/// @}
|
||||||
|
};
|
||||||
|
|
||||||
|
typedef struct OrtTrainingApi OrtTrainingApi;
|
||||||
|
|
||||||
|
/// @}
|
@ -0,0 +1,418 @@
|
|||||||
|
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||||
|
// Licensed under the MIT License.
|
||||||
|
|
||||||
|
#pragma once
|
||||||
|
#include "onnxruntime_training_c_api.h"
|
||||||
|
#include <optional>
|
||||||
|
#include <variant>
|
||||||
|
|
||||||
|
namespace Ort::detail {
|
||||||
|
|
||||||
|
#define ORT_DECLARE_TRAINING_RELEASE(NAME) \
|
||||||
|
void OrtRelease(Ort##NAME* ptr);
|
||||||
|
|
||||||
|
// These release methods must be forward declared before including onnxruntime_cxx_api.h
|
||||||
|
// otherwise class Base won't be aware of them
|
||||||
|
ORT_DECLARE_TRAINING_RELEASE(CheckpointState);
|
||||||
|
ORT_DECLARE_TRAINING_RELEASE(TrainingSession);
|
||||||
|
|
||||||
|
} // namespace Ort::detail
|
||||||
|
|
||||||
|
#include "onnxruntime_cxx_api.h"
|
||||||
|
|
||||||
|
namespace Ort {
|
||||||
|
|
||||||
|
/// <summary>
|
||||||
|
/// This function returns the C training api struct with the pointers to the ort training C functions.
|
||||||
|
/// If using C++, please use the class instances instead of invoking the C functions directly.
|
||||||
|
/// </summary>
|
||||||
|
/// <returns>OrtTrainingApi struct with ort training C function pointers.</returns>
|
||||||
|
inline const OrtTrainingApi& GetTrainingApi() { return *GetApi().GetTrainingApi(ORT_API_VERSION); }
|
||||||
|
|
||||||
|
namespace detail {
|
||||||
|
|
||||||
|
#define ORT_DEFINE_TRAINING_RELEASE(NAME) \
|
||||||
|
inline void OrtRelease(Ort##NAME* ptr) { GetTrainingApi().Release##NAME(ptr); }
|
||||||
|
|
||||||
|
ORT_DEFINE_TRAINING_RELEASE(CheckpointState);
|
||||||
|
ORT_DEFINE_TRAINING_RELEASE(TrainingSession);
|
||||||
|
|
||||||
|
#undef ORT_DECLARE_TRAINING_RELEASE
|
||||||
|
#undef ORT_DEFINE_TRAINING_RELEASE
|
||||||
|
|
||||||
|
} // namespace detail
|
||||||
|
|
||||||
|
using Property = std::variant<int64_t, float, std::string>;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* \defgroup TrainingCpp Ort Training C++ API
|
||||||
|
* @{
|
||||||
|
*/
|
||||||
|
|
||||||
|
/** \brief Holds the state of the training session.
|
||||||
|
*
|
||||||
|
* This class holds the entire training session state that includes model parameters, their gradients,
|
||||||
|
* optimizer parameters, and user properties. The Ort::TrainingSession leverages the Ort::CheckpointState
|
||||||
|
* by accessing and updating the contained training state.
|
||||||
|
* \note Note that the training session created with a checkpoint state uses this state to store the entire
|
||||||
|
* training state (including model parameters, its gradients, the optimizer states and the properties).
|
||||||
|
* The Ort::TrainingSession does not hold a copy of the Ort::CheckpointState and as a result, it is required
|
||||||
|
* that the checkpoint state outlive the lifetime of the training session.
|
||||||
|
* \note Note that the checkpoint state can be either the complete checkpoint state or the nominal checkpoint
|
||||||
|
* state depending on the version provided while loading the checkpoint.
|
||||||
|
*
|
||||||
|
*/
|
||||||
|
class CheckpointState : public detail::Base<OrtCheckpointState> {
|
||||||
|
private:
|
||||||
|
CheckpointState(OrtCheckpointState* checkpoint_state) { p_ = checkpoint_state; }
|
||||||
|
|
||||||
|
public:
|
||||||
|
// Construct the checkpoint state by loading the checkpoint by calling LoadCheckpoint
|
||||||
|
CheckpointState() = delete;
|
||||||
|
|
||||||
|
/// \name Accessing The Training Session State
|
||||||
|
/// @{
|
||||||
|
|
||||||
|
/** \brief Load a checkpoint state from a file on disk into checkpoint_state.
|
||||||
|
*
|
||||||
|
* This function will parse a checkpoint file, pull relevant data and load the training
|
||||||
|
* state and return an instance of Ort::CheckpointState. This checkpoint state can then be used to create the
|
||||||
|
* training session by instantiating Ort::TrainingSession. By doing so, the training session will resume
|
||||||
|
* training from the given checkpoint state.
|
||||||
|
*
|
||||||
|
* \param[in] path_to_checkpoint Path to the checkpoint file
|
||||||
|
* \return Ort::CheckpointState object which holds the state of the training session parameters.
|
||||||
|
*
|
||||||
|
*/
|
||||||
|
static CheckpointState LoadCheckpoint(const std::basic_string<ORTCHAR_T>& path_to_checkpoint);
|
||||||
|
|
||||||
|
/** \brief Load a checkpoint state from a buffer.
|
||||||
|
*
|
||||||
|
* This function will parse a checkpoint buffer, pull relevant data and load the training
|
||||||
|
* state and return an instance of Ort::CheckpointState. This checkpoint state can then be used to create the
|
||||||
|
* training session by instantiating Ort::TrainingSession. By doing so, the training session will resume
|
||||||
|
* training from the given checkpoint state.
|
||||||
|
*
|
||||||
|
* \param[in] buffer Buffer containing the checkpoint data.
|
||||||
|
* \return Ort::CheckpointState object which holds the state of the training session parameters.
|
||||||
|
*
|
||||||
|
*/
|
||||||
|
static CheckpointState LoadCheckpointFromBuffer(const std::vector<uint8_t>& buffer);
|
||||||
|
|
||||||
|
/** \brief Save the given state to a checkpoint file on disk.
|
||||||
|
*
|
||||||
|
* This function serializes the provided checkpoint state to a file on disk.
|
||||||
|
* This checkpoint can later be loaded by invoking Ort::CheckpointState::LoadCheckpoint to resume
|
||||||
|
* training from this snapshot of the state.
|
||||||
|
*
|
||||||
|
* \param[in] checkpoint_state The checkpoint state to save.
|
||||||
|
* \param[in] path_to_checkpoint Path to the checkpoint file.
|
||||||
|
* \param[in] include_optimizer_state Flag to indicate whether to save the optimizer state or not.
|
||||||
|
*
|
||||||
|
*/
|
||||||
|
static void SaveCheckpoint(const CheckpointState& checkpoint_state,
|
||||||
|
const std::basic_string<ORTCHAR_T>& path_to_checkpoint,
|
||||||
|
const bool include_optimizer_state = false);
|
||||||
|
|
||||||
|
/** \brief Adds or updates the given property to/in the checkpoint state.
|
||||||
|
*
|
||||||
|
* Runtime properties such as epoch, training step, best score, and others can be added to the checkpoint
|
||||||
|
* state by the user by calling this function with the corresponding property name and value.
|
||||||
|
* The given property name must be unique to be able to successfully add the property.
|
||||||
|
*
|
||||||
|
* \param[in] property_name Name of the property being added or updated.
|
||||||
|
* \param[in] property_value Property value associated with the given name.
|
||||||
|
*
|
||||||
|
*/
|
||||||
|
void AddProperty(const std::string& property_name, const Property& property_value);
|
||||||
|
|
||||||
|
/** \brief Gets the property value associated with the given name from the checkpoint state.
|
||||||
|
*
|
||||||
|
* Gets the property value from an existing entry in the checkpoint state. The property must
|
||||||
|
* exist in the checkpoint state to be able to retrieve it successfully.
|
||||||
|
*
|
||||||
|
* \param[in] property_name Name of the property being retrieved.
|
||||||
|
* \return Property value associated with the given property name.
|
||||||
|
*
|
||||||
|
*/
|
||||||
|
Property GetProperty(const std::string& property_name);
|
||||||
|
|
||||||
|
/** \brief Updates the data associated with the model parameter in the checkpoint state for the given parameter name.
|
||||||
|
*
|
||||||
|
* This function updates a model parameter in the checkpoint state with the given parameter data.
|
||||||
|
* The training session must be already created with the checkpoint state that contains the parameter
|
||||||
|
* being updated. The given parameter is copied over to the registered device for the training session.
|
||||||
|
* The parameter must exist in the checkpoint state to be able to update it successfully.
|
||||||
|
*
|
||||||
|
* \param[in] parameter_name Name of the parameter being updated.
|
||||||
|
* \param[in] parameter The parameter data that should replace the existing parameter data.
|
||||||
|
*
|
||||||
|
*/
|
||||||
|
void UpdateParameter(const std::string& parameter_name, const Value& parameter);
|
||||||
|
|
||||||
|
/** \brief Gets the data associated with the model parameter from the checkpoint state for the given parameter name.
|
||||||
|
*
|
||||||
|
* This function retrieves the model parameter data from the checkpoint state for the given parameter name.
|
||||||
|
* The parameter is copied over to the provided OrtValue. The training session must be already created
|
||||||
|
* with the checkpoint state that contains the parameter being retrieved.
|
||||||
|
* The parameter must exist in the checkpoint state to be able to retrieve it successfully.
|
||||||
|
*
|
||||||
|
* \param[in] parameter_name Name of the parameter being retrieved.
|
||||||
|
* \return The parameter data that is retrieved from the checkpoint state.
|
||||||
|
*
|
||||||
|
*/
|
||||||
|
Value GetParameter(const std::string& parameter_name);
|
||||||
|
|
||||||
|
/// @}
|
||||||
|
};
|
||||||
|
|
||||||
|
/** \brief Trainer class that provides training, evaluation and optimizer methods for training an ONNX models.
|
||||||
|
*
|
||||||
|
* The training session requires four training artifacts
|
||||||
|
* - The training onnx model
|
||||||
|
* - The evaluation onnx model (optional)
|
||||||
|
* - The optimizer onnx model
|
||||||
|
* - The checkpoint file
|
||||||
|
*
|
||||||
|
* These artifacts can be generated using the `onnxruntime-training` python [utility](https://github.com/microsoft/onnxruntime/blob/main/orttraining/orttraining/python/training/onnxblock/README.md).
|
||||||
|
*
|
||||||
|
*/
|
||||||
|
class TrainingSession : public detail::Base<OrtTrainingSession> {
|
||||||
|
private:
|
||||||
|
size_t training_model_output_count_, eval_model_output_count_;
|
||||||
|
|
||||||
|
public:
|
||||||
|
/// \name Constructing the Training Session
|
||||||
|
/// @{
|
||||||
|
/** \brief Create a training session that can be used to begin or resume training.
|
||||||
|
*
|
||||||
|
* This constructor instantiates the training session based on the env and session options provided that can
|
||||||
|
* begin or resume training from a given checkpoint state for the given onnx models.
|
||||||
|
* The checkpoint state represents the parameters of the training session which will be moved
|
||||||
|
* to the device specified by the user through the session options (if necessary).
|
||||||
|
*
|
||||||
|
* \param[in] env Env to be used for the training session.
|
||||||
|
* \param[in] session_options SessionOptions that the user can customize for this training session.
|
||||||
|
* \param[in] checkpoint_state Training states that the training session uses as a starting point for training.
|
||||||
|
* \param[in] train_model_path Model to be used to perform training.
|
||||||
|
* \param[in] eval_model_path Model to be used to perform evaluation.
|
||||||
|
* \param[in] optimizer_model_path Model to be used to perform gradient descent.
|
||||||
|
*
|
||||||
|
*/
|
||||||
|
TrainingSession(const Env& env, const SessionOptions& session_options, CheckpointState& checkpoint_state,
|
||||||
|
const std::basic_string<ORTCHAR_T>& train_model_path,
|
||||||
|
const std::optional<std::basic_string<ORTCHAR_T>>& eval_model_path = std::nullopt,
|
||||||
|
const std::optional<std::basic_string<ORTCHAR_T>>& optimizer_model_path = std::nullopt);
|
||||||
|
|
||||||
|
/** \brief Create a training session that can be used to begin or resume training.
|
||||||
|
* This constructor allows the users to load the models from buffers instead of files.
|
||||||
|
*
|
||||||
|
* \param[in] env Env to be used for the training session.
|
||||||
|
* \param[in] session_options SessionOptions that the user can customize for this training session.
|
||||||
|
* \param[in] checkpoint_state Training states that the training session uses as a starting point for training.
|
||||||
|
* \param[in] train_model_data Buffer containing training model data.
|
||||||
|
* \param[in] eval_model_data Buffer containing evaluation model data.
|
||||||
|
* \param[in] optim_model_data Buffer containing optimizer model (used for performing weight/parameter update).
|
||||||
|
*
|
||||||
|
*/
|
||||||
|
TrainingSession(const Env& env, const SessionOptions& session_options, CheckpointState& checkpoint_state,
|
||||||
|
const std::vector<uint8_t>& train_model_data, const std::vector<uint8_t>& eval_model_data = {},
|
||||||
|
const std::vector<uint8_t>& optim_model_data = {});
|
||||||
|
/// @}
|
||||||
|
|
||||||
|
/// \name Implementing The Training Loop
|
||||||
|
/// @{
|
||||||
|
/** \brief Computes the outputs of the training model and the gradients of the trainable parameters for the given inputs
|
||||||
|
*
|
||||||
|
* This function performs a training step that computes the outputs of the training model and the gradients
|
||||||
|
* of the trainable parameters for the given inputs. The train step is performed based on the training model
|
||||||
|
* that was provided to the training session.
|
||||||
|
* The Ort::TrainingSession::TrainStep is equivalent of running forward propagation and backward propagation in a single
|
||||||
|
* step.
|
||||||
|
* The gradients computed are stored inside the training session state so they can be later consumed
|
||||||
|
* by the Ort::TrainingSession::OptimizerStep function.
|
||||||
|
* The gradients can be lazily reset by invoking the Ort::TrainingSession::LazyResetGrad function.
|
||||||
|
*
|
||||||
|
* \param[in] input_values The user inputs to the training model.
|
||||||
|
* \return A std::vector of Ort::Value objects that represents the output of the forward pass of the training model.
|
||||||
|
*
|
||||||
|
*
|
||||||
|
*/
|
||||||
|
std::vector<Value> TrainStep(const std::vector<Value>& input_values);
|
||||||
|
|
||||||
|
/** \brief Reset the gradients of all trainable parameters to zero lazily.
|
||||||
|
*
|
||||||
|
* This function sets the internal state of the training session such that the gradients of the trainable
|
||||||
|
* parameters in the OrtCheckpointState will be scheduled to be reset just before the new gradients are
|
||||||
|
* computed on the next invocation of the next Ort::TrainingSession::TrainStep.
|
||||||
|
*
|
||||||
|
*/
|
||||||
|
void LazyResetGrad();
|
||||||
|
|
||||||
|
/** \brief Computes the outputs for the eval model for the given inputs
|
||||||
|
*
|
||||||
|
* This function performs an eval step that computes the outputs of the eval model for the given inputs.
|
||||||
|
* The eval step is performed based on the eval model that was provided to the training session.
|
||||||
|
*
|
||||||
|
* \param[in] input_values The user inputs to the eval model.
|
||||||
|
* \return A std::vector of Ort::Value objects that represents the output of the eval pass.
|
||||||
|
*
|
||||||
|
*/
|
||||||
|
std::vector<Value> EvalStep(const std::vector<Value>& input_values);
|
||||||
|
|
||||||
|
/** \brief Sets the learning rate for this training session.
|
||||||
|
*
|
||||||
|
* This function allows users to set the learning rate for the training session. The current
|
||||||
|
* learning rate is maintained by the training session and can be overwritten by invoking
|
||||||
|
* this function with the desired learning rate. This function should not be used when a valid
|
||||||
|
* learning rate scheduler is registered. It should be used either to set the learning rate
|
||||||
|
* derived from a custom learning rate scheduler or to set a constant learning rate to be used
|
||||||
|
* throughout the training session.
|
||||||
|
* \note Please note that this function does not set the initial learning rate that may be needed
|
||||||
|
* by the predefined learning rate schedulers. To set the initial learning rate for learning
|
||||||
|
* rate schedulers, please look at the function Ort::TrainingSession::RegisterLinearLRScheduler.
|
||||||
|
*
|
||||||
|
* \param[in] learning_rate Desired learning rate to be set.
|
||||||
|
*
|
||||||
|
*/
|
||||||
|
void SetLearningRate(float learning_rate);
|
||||||
|
|
||||||
|
/** \brief Gets the current learning rate for this training session.
|
||||||
|
*
|
||||||
|
* This function allows users to get the learning rate for the training session. The current
|
||||||
|
* learning rate is maintained by the training session, and users can query it for the purpose
|
||||||
|
* of implementing their own learning rate schedulers.
|
||||||
|
*
|
||||||
|
* \return float representing the current learning rate.
|
||||||
|
*
|
||||||
|
*/
|
||||||
|
float GetLearningRate() const;
|
||||||
|
|
||||||
|
/** \brief Registers a linear learning rate scheduler for the training session.
|
||||||
|
*
|
||||||
|
* Register a linear learning rate scheduler that decays the learning rate by linearly updated
|
||||||
|
* multiplicative factor from the initial learning rate set on the training session to 0. The decay
|
||||||
|
* is performed after the initial warm up phase where the learning rate is linearly incremented
|
||||||
|
* from 0 to the initial learning rate provided.
|
||||||
|
*
|
||||||
|
* \param[in] warmup_step_count Warmup steps for LR warmup.
|
||||||
|
* \param[in] total_step_count Total step count.
|
||||||
|
* \param[in] initial_lr The initial learning rate to be used by the training session.
|
||||||
|
*
|
||||||
|
*/
|
||||||
|
void RegisterLinearLRScheduler(int64_t warmup_step_count, int64_t total_step_count,
|
||||||
|
float initial_lr);
|
||||||
|
|
||||||
|
/** \brief Update the learning rate based on the registered learing rate scheduler.
|
||||||
|
*
|
||||||
|
* Takes a scheduler step that updates the learning rate that is being used by the training session.
|
||||||
|
* This function should typically be called before invoking the optimizer step for each round,
|
||||||
|
* or as determined necessary to update the learning rate being used by the training session.
|
||||||
|
* \note Please note that a valid predefined learning rate scheduler must be first registered to invoke this
|
||||||
|
* function.
|
||||||
|
*
|
||||||
|
*/
|
||||||
|
void SchedulerStep();
|
||||||
|
|
||||||
|
/** \brief Performs the weight updates for the trainable parameters using the optimizer model.
|
||||||
|
*
|
||||||
|
* This function performs the weight update step that updates the trainable parameters such that they
|
||||||
|
* take a step in the direction of their gradients (gradient descent). The optimizer step is performed
|
||||||
|
* based on the optimizer model that was provided to the training session.
|
||||||
|
* The updated parameters are stored inside the training state so that they can be used by the next
|
||||||
|
* Ort::TrainingSession::TrainStep function call.
|
||||||
|
*
|
||||||
|
*/
|
||||||
|
void OptimizerStep();
|
||||||
|
|
||||||
|
/// @}
|
||||||
|
|
||||||
|
/// \name Prepare For Inferencing
|
||||||
|
/// @{
|
||||||
|
|
||||||
|
/** \brief Export a model that can be used for inferencing.
|
||||||
|
*
|
||||||
|
* If the training session was provided with an eval model, the training session can generate
|
||||||
|
* an inference model if it knows the inference graph outputs. The input inference graph outputs
|
||||||
|
* are used to prune the eval model so that the inference model's outputs align with the provided outputs.
|
||||||
|
* The exported model is saved at the path provided and can be used for inferencing with Ort::Session.
|
||||||
|
* \note Note that the function re-loads the eval model from the path provided to Ort::TrainingSession
|
||||||
|
* and expects that this path still be valid.
|
||||||
|
*
|
||||||
|
* \param[in] inference_model_path Path where the inference model should be serialized to.
|
||||||
|
* \param[in] graph_output_names Names of the outputs that are needed in the inference model.
|
||||||
|
*
|
||||||
|
*/
|
||||||
|
void ExportModelForInferencing(const std::basic_string<ORTCHAR_T>& inference_model_path,
|
||||||
|
const std::vector<std::string>& graph_output_names);
|
||||||
|
|
||||||
|
/// @}
|
||||||
|
|
||||||
|
/// \name Model IO Information
|
||||||
|
/// @{
|
||||||
|
/** \brief Retrieves the names of the user inputs for the training and eval models.
|
||||||
|
*
|
||||||
|
* This function returns the names of inputs of the training or eval model that can be associated
|
||||||
|
* with the Ort::Value(s) provided to the Ort::TrainingSession::TrainStep or Ort::TrainingSession::EvalStep
|
||||||
|
* function.
|
||||||
|
*
|
||||||
|
* \param[in] training Whether the training model input names are requested or eval model input names.
|
||||||
|
* \return Graph input names for either the training model or the eval model.
|
||||||
|
*
|
||||||
|
*/
|
||||||
|
std::vector<std::string> InputNames(const bool training);
|
||||||
|
|
||||||
|
/** \brief Retrieves the names of the user outputs for the training and eval models.
|
||||||
|
*
|
||||||
|
* This function returns the names of outputs of the training or eval model that can be associated
|
||||||
|
* with the Ort::Value(s) returned by the Ort::TrainingSession::TrainStep or Ort::TrainingSession::EvalStep
|
||||||
|
* function.
|
||||||
|
*
|
||||||
|
* \param[in] training Whether the training model output names are requested or eval model output names.
|
||||||
|
* \return Graph output names for either the training model or the eval model.
|
||||||
|
*
|
||||||
|
*/
|
||||||
|
std::vector<std::string> OutputNames(const bool training);
|
||||||
|
|
||||||
|
/// @}
|
||||||
|
|
||||||
|
/// \name Accessing The Training Session State
|
||||||
|
/// @{
|
||||||
|
|
||||||
|
/** \brief Returns a contiguous buffer that holds a copy of all training state parameters
|
||||||
|
*
|
||||||
|
* \param[in] only_trainable Whether to only copy trainable parameters or to copy all parameters.
|
||||||
|
* \return Contiguous buffer to the model parameters.
|
||||||
|
*
|
||||||
|
*/
|
||||||
|
Value ToBuffer(const bool only_trainable);
|
||||||
|
|
||||||
|
/** \brief Loads the training session model parameters from a contiguous buffer
|
||||||
|
*
|
||||||
|
* In case the training session was created with a nominal checkpoint, invoking this function is required
|
||||||
|
* to load the updated parameters onto the checkpoint to complete it.
|
||||||
|
*
|
||||||
|
* \param[in] buffer Contiguous buffer to load the parameters from.
|
||||||
|
*/
|
||||||
|
void FromBuffer(Value& buffer);
|
||||||
|
|
||||||
|
/// @}
|
||||||
|
};
|
||||||
|
|
||||||
|
/// \name Training Utilities
|
||||||
|
/// @{
|
||||||
|
/** \brief This function sets the seed for generating random numbers.
|
||||||
|
*
|
||||||
|
* Use this function to generate reproducible results. It should be noted that completely
|
||||||
|
* reproducible results are not guaranteed.
|
||||||
|
*
|
||||||
|
* \param[in] seed Manual seed to use for random number generation.
|
||||||
|
*/
|
||||||
|
void SetSeed(const int64_t seed);
|
||||||
|
/// @}
|
||||||
|
|
||||||
|
/// @}
|
||||||
|
|
||||||
|
} // namespace Ort
|
||||||
|
|
||||||
|
#include "onnxruntime_training_cxx_inline.h"
|
@ -0,0 +1,295 @@
|
|||||||
|
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||||
|
// Licensed under the MIT License.
|
||||||
|
|
||||||
|
#pragma once
|
||||||
|
#include "onnxruntime_training_c_api.h"
|
||||||
|
#include "onnxruntime_cxx_api.h"
|
||||||
|
|
||||||
|
namespace Ort {
|
||||||
|
|
||||||
|
inline TrainingSession::TrainingSession(const Env& env, const SessionOptions& session_options,
|
||||||
|
CheckpointState& checkpoint_state,
|
||||||
|
const std::basic_string<ORTCHAR_T>& train_model_path,
|
||||||
|
const std::optional<std::basic_string<ORTCHAR_T>>& eval_model_path,
|
||||||
|
const std::optional<std::basic_string<ORTCHAR_T>>& optimizer_model_path) {
|
||||||
|
ThrowOnError(GetTrainingApi().CreateTrainingSession(
|
||||||
|
env, session_options, checkpoint_state,
|
||||||
|
train_model_path.c_str(),
|
||||||
|
eval_model_path.has_value() ? eval_model_path.value().c_str() : nullptr,
|
||||||
|
optimizer_model_path.has_value() ? optimizer_model_path.value().c_str() : nullptr,
|
||||||
|
&p_));
|
||||||
|
|
||||||
|
ThrowOnError(GetTrainingApi().TrainingSessionGetTrainingModelOutputCount(p_, &training_model_output_count_));
|
||||||
|
|
||||||
|
ThrowOnError(GetTrainingApi().TrainingSessionGetEvalModelOutputCount(p_, &eval_model_output_count_));
|
||||||
|
}
|
||||||
|
|
||||||
|
inline TrainingSession::TrainingSession(const Env& env, const SessionOptions& session_options,
|
||||||
|
CheckpointState& checkpoint_state,
|
||||||
|
const std::vector<uint8_t>& train_model_data,
|
||||||
|
const std::vector<uint8_t>& eval_model_data,
|
||||||
|
const std::vector<uint8_t>& optim_model_data) {
|
||||||
|
ThrowOnError(GetTrainingApi().CreateTrainingSessionFromBuffer(
|
||||||
|
env, session_options, checkpoint_state,
|
||||||
|
train_model_data.data(), train_model_data.size(),
|
||||||
|
eval_model_data.data(), eval_model_data.size(),
|
||||||
|
optim_model_data.data(), optim_model_data.size(),
|
||||||
|
&p_));
|
||||||
|
|
||||||
|
ThrowOnError(GetTrainingApi().TrainingSessionGetTrainingModelOutputCount(p_, &training_model_output_count_));
|
||||||
|
|
||||||
|
ThrowOnError(GetTrainingApi().TrainingSessionGetEvalModelOutputCount(p_, &eval_model_output_count_));
|
||||||
|
}
|
||||||
|
|
||||||
|
inline std::vector<Value> TrainingSession::TrainStep(const std::vector<Value>& input_values) {
|
||||||
|
std::vector<Value> output_values;
|
||||||
|
output_values.reserve(training_model_output_count_);
|
||||||
|
for (size_t i = 0; i < training_model_output_count_; i++) output_values.emplace_back(nullptr);
|
||||||
|
auto ort_input_values = reinterpret_cast<const OrtValue* const*>(input_values.data());
|
||||||
|
auto ort_output_values = reinterpret_cast<OrtValue**>(output_values.data());
|
||||||
|
RunOptions run_options;
|
||||||
|
ThrowOnError(GetTrainingApi().TrainStep(
|
||||||
|
p_, run_options, input_values.size(), ort_input_values,
|
||||||
|
training_model_output_count_, ort_output_values));
|
||||||
|
|
||||||
|
return output_values;
|
||||||
|
}
|
||||||
|
|
||||||
|
inline void TrainingSession::LazyResetGrad() {
|
||||||
|
ThrowOnError(GetTrainingApi().LazyResetGrad(p_));
|
||||||
|
}
|
||||||
|
|
||||||
|
inline std::vector<Value> TrainingSession::EvalStep(const std::vector<Value>& input_values) {
|
||||||
|
std::vector<Value> output_values;
|
||||||
|
output_values.reserve(eval_model_output_count_);
|
||||||
|
for (size_t i = 0; i < eval_model_output_count_; i++) output_values.emplace_back(nullptr);
|
||||||
|
auto ort_input_values = reinterpret_cast<const OrtValue* const*>(input_values.data());
|
||||||
|
auto ort_output_values = reinterpret_cast<OrtValue**>(output_values.data());
|
||||||
|
RunOptions run_options;
|
||||||
|
ThrowOnError(GetTrainingApi().EvalStep(
|
||||||
|
p_, run_options, input_values.size(), ort_input_values,
|
||||||
|
eval_model_output_count_, ort_output_values));
|
||||||
|
|
||||||
|
return output_values;
|
||||||
|
}
|
||||||
|
|
||||||
|
inline void TrainingSession::SetLearningRate(float learning_rate) {
|
||||||
|
ThrowOnError(GetTrainingApi().SetLearningRate(p_, learning_rate));
|
||||||
|
}
|
||||||
|
|
||||||
|
inline float TrainingSession::GetLearningRate() const {
|
||||||
|
float learning_rate = 0;
|
||||||
|
ThrowOnError(GetTrainingApi().GetLearningRate(p_, &learning_rate));
|
||||||
|
return learning_rate;
|
||||||
|
}
|
||||||
|
|
||||||
|
inline void TrainingSession::RegisterLinearLRScheduler(int64_t warmup_step_count, int64_t total_step_count,
|
||||||
|
float initial_lr) {
|
||||||
|
ThrowOnError(GetTrainingApi().RegisterLinearLRScheduler(p_, warmup_step_count, total_step_count,
|
||||||
|
initial_lr));
|
||||||
|
}
|
||||||
|
|
||||||
|
inline void TrainingSession::SchedulerStep() {
|
||||||
|
ThrowOnError(GetTrainingApi().SchedulerStep(p_));
|
||||||
|
}
|
||||||
|
|
||||||
|
inline void TrainingSession::OptimizerStep() {
|
||||||
|
RunOptions run_options;
|
||||||
|
ThrowOnError(GetTrainingApi().OptimizerStep(p_, run_options));
|
||||||
|
}
|
||||||
|
|
||||||
|
inline std::vector<std::string> TrainingSession::InputNames(const bool training) {
|
||||||
|
auto& input_count_function = training ? GetTrainingApi().TrainingSessionGetTrainingModelInputCount
|
||||||
|
: GetTrainingApi().TrainingSessionGetEvalModelInputCount;
|
||||||
|
auto& input_name_function = training ? GetTrainingApi().TrainingSessionGetTrainingModelInputName
|
||||||
|
: GetTrainingApi().TrainingSessionGetEvalModelInputName;
|
||||||
|
|
||||||
|
size_t input_count = 0;
|
||||||
|
ThrowOnError(input_count_function(p_, &input_count));
|
||||||
|
std::vector<std::string> input_names(input_count);
|
||||||
|
AllocatorWithDefaultOptions allocator;
|
||||||
|
for (size_t index = 0; index < input_count; ++index) {
|
||||||
|
char* input_name;
|
||||||
|
ThrowOnError(input_name_function(p_, index, allocator, &input_name));
|
||||||
|
input_names[index] = std::string(input_name);
|
||||||
|
allocator.Free(input_name);
|
||||||
|
}
|
||||||
|
|
||||||
|
return input_names;
|
||||||
|
}
|
||||||
|
|
||||||
|
inline std::vector<std::string> TrainingSession::OutputNames(const bool training) {
|
||||||
|
auto& output_count_function = training ? GetTrainingApi().TrainingSessionGetTrainingModelOutputCount
|
||||||
|
: GetTrainingApi().TrainingSessionGetEvalModelOutputCount;
|
||||||
|
auto& output_name_function = training ? GetTrainingApi().TrainingSessionGetTrainingModelOutputName
|
||||||
|
: GetTrainingApi().TrainingSessionGetEvalModelOutputName;
|
||||||
|
|
||||||
|
size_t output_count = 0;
|
||||||
|
ThrowOnError(output_count_function(p_, &output_count));
|
||||||
|
std::vector<std::string> output_names(output_count);
|
||||||
|
AllocatorWithDefaultOptions allocator;
|
||||||
|
for (size_t index = 0; index < output_count; ++index) {
|
||||||
|
char* output_name;
|
||||||
|
ThrowOnError(output_name_function(p_, index, allocator, &output_name));
|
||||||
|
output_names[index] = std::string(output_name);
|
||||||
|
allocator.Free(output_name);
|
||||||
|
}
|
||||||
|
|
||||||
|
return output_names;
|
||||||
|
}
|
||||||
|
|
||||||
|
inline Value TrainingSession::ToBuffer(const bool only_trainable) {
|
||||||
|
size_t buffer_size = 0U;
|
||||||
|
ThrowOnError(GetTrainingApi().GetParametersSize(p_, &buffer_size, only_trainable));
|
||||||
|
|
||||||
|
std::array<int64_t, 1> buffer_shape{static_cast<int64_t>(buffer_size)};
|
||||||
|
|
||||||
|
AllocatorWithDefaultOptions allocator;
|
||||||
|
Value buffer = Value::CreateTensor(allocator, buffer_shape.data(), 1U,
|
||||||
|
ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT);
|
||||||
|
|
||||||
|
ThrowOnError(GetTrainingApi().CopyParametersToBuffer(p_, buffer, only_trainable));
|
||||||
|
|
||||||
|
return buffer;
|
||||||
|
}
|
||||||
|
|
||||||
|
inline void TrainingSession::FromBuffer(Value& buffer) {
|
||||||
|
if (!buffer.IsTensor()) {
|
||||||
|
ThrowStatus(Status("Incorrect buffer received. Expected a tensor buffer.", OrtErrorCode::ORT_INVALID_ARGUMENT));
|
||||||
|
}
|
||||||
|
|
||||||
|
auto tensor_info = buffer.GetTensorTypeAndShapeInfo();
|
||||||
|
auto buffer_shape = tensor_info.GetShape();
|
||||||
|
|
||||||
|
if (buffer_shape.size() != 1U) {
|
||||||
|
ThrowStatus(Status("Incorrect buffer received. Expected a contiguous tensor buffer.",
|
||||||
|
OrtErrorCode::ORT_INVALID_ARGUMENT));
|
||||||
|
}
|
||||||
|
|
||||||
|
auto buffer_size = buffer_shape.front();
|
||||||
|
|
||||||
|
size_t session_buffer_size = 0U;
|
||||||
|
ThrowOnError(GetTrainingApi().GetParametersSize(p_, &session_buffer_size, false));
|
||||||
|
|
||||||
|
if (buffer_size == static_cast<int64_t>(session_buffer_size)) {
|
||||||
|
ThrowOnError(GetTrainingApi().CopyBufferToParameters(p_, buffer, false));
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
size_t session_buffer_size_trainable_only = 0U;
|
||||||
|
ThrowOnError(GetTrainingApi().GetParametersSize(p_, &session_buffer_size_trainable_only, true));
|
||||||
|
|
||||||
|
if (buffer_size == static_cast<int64_t>(session_buffer_size_trainable_only)) {
|
||||||
|
ThrowOnError(GetTrainingApi().CopyBufferToParameters(p_, buffer, true));
|
||||||
|
return;
|
||||||
|
} else {
|
||||||
|
ThrowStatus(Status("Incorrect buffer size received.", OrtErrorCode::ORT_INVALID_ARGUMENT));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
inline CheckpointState CheckpointState::LoadCheckpoint(const std::basic_string<ORTCHAR_T>& path_to_checkpoint) {
|
||||||
|
OrtCheckpointState* checkpoint_state;
|
||||||
|
ThrowOnError(GetTrainingApi().LoadCheckpoint(path_to_checkpoint.c_str(), &checkpoint_state));
|
||||||
|
return CheckpointState(checkpoint_state);
|
||||||
|
}
|
||||||
|
|
||||||
|
inline CheckpointState CheckpointState::LoadCheckpointFromBuffer(const std::vector<uint8_t>& buffer) {
|
||||||
|
OrtCheckpointState* checkpoint_state;
|
||||||
|
ThrowOnError(GetTrainingApi().LoadCheckpointFromBuffer(buffer.data(), buffer.size(), &checkpoint_state));
|
||||||
|
return CheckpointState(checkpoint_state);
|
||||||
|
}
|
||||||
|
|
||||||
|
inline void CheckpointState::SaveCheckpoint(const CheckpointState& checkpoint_states,
|
||||||
|
const std::basic_string<ORTCHAR_T>& path_to_checkpoint,
|
||||||
|
const bool include_optimizer_state) {
|
||||||
|
ThrowOnError(GetTrainingApi().SaveCheckpoint(checkpoint_states, path_to_checkpoint.c_str(),
|
||||||
|
include_optimizer_state));
|
||||||
|
}
|
||||||
|
|
||||||
|
inline void TrainingSession::ExportModelForInferencing(const std::basic_string<ORTCHAR_T>& inference_model_path,
|
||||||
|
const std::vector<std::string>& graph_output_names) {
|
||||||
|
std::vector<const char*> output_names;
|
||||||
|
output_names.reserve(graph_output_names.size());
|
||||||
|
for (const auto& output_name : graph_output_names) {
|
||||||
|
output_names.push_back(output_name.c_str());
|
||||||
|
}
|
||||||
|
ThrowOnError(GetTrainingApi().ExportModelForInferencing(
|
||||||
|
p_, inference_model_path.c_str(), graph_output_names.size(), output_names.data()));
|
||||||
|
}
|
||||||
|
|
||||||
|
inline void SetSeed(const int64_t seed) {
|
||||||
|
ThrowOnError(GetTrainingApi().SetSeed(seed));
|
||||||
|
}
|
||||||
|
|
||||||
|
inline void CheckpointState::AddProperty(const std::string& property_name, const Property& property_value) {
|
||||||
|
if (std::holds_alternative<int64_t>(property_value)) {
|
||||||
|
int64_t value = std::get<int64_t>(property_value);
|
||||||
|
void* value_p = &value;
|
||||||
|
ThrowOnError(GetTrainingApi().AddProperty(p_, property_name.c_str(), OrtPropertyType::OrtIntProperty, value_p));
|
||||||
|
} else if (std::holds_alternative<float>(property_value)) {
|
||||||
|
float value = std::get<float>(property_value);
|
||||||
|
void* value_p = &value;
|
||||||
|
ThrowOnError(GetTrainingApi().AddProperty(p_, property_name.c_str(), OrtPropertyType::OrtFloatProperty, value_p));
|
||||||
|
} else if (std::holds_alternative<std::string>(property_value)) {
|
||||||
|
std::string value = std::get<std::string>(property_value);
|
||||||
|
auto buffer = std::make_unique<char[]>(value.length() + 1);
|
||||||
|
memcpy(buffer.get(), value.c_str(), value.length());
|
||||||
|
// AddProperty takes a char* and calls PropertyBag::AddProperty which takes a std::string. The data will be
|
||||||
|
// copied at that point so buffer can free the local allocation once the call is made.
|
||||||
|
ThrowOnError(GetTrainingApi().AddProperty(p_, property_name.c_str(), OrtPropertyType::OrtStringProperty,
|
||||||
|
buffer.get()));
|
||||||
|
} else {
|
||||||
|
ThrowStatus(Status("Unknown property type received.", OrtErrorCode::ORT_INVALID_ARGUMENT));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
inline Property CheckpointState::GetProperty(const std::string& property_name) {
|
||||||
|
void* property_value = nullptr;
|
||||||
|
OrtPropertyType property_type;
|
||||||
|
|
||||||
|
AllocatorWithDefaultOptions allocator;
|
||||||
|
ThrowOnError(GetTrainingApi().GetProperty(p_, property_name.c_str(), allocator, &property_type, &property_value));
|
||||||
|
|
||||||
|
Property property;
|
||||||
|
|
||||||
|
switch (property_type) {
|
||||||
|
case OrtPropertyType::OrtIntProperty: {
|
||||||
|
auto value_p = reinterpret_cast<int64_t*>(property_value);
|
||||||
|
property = *value_p;
|
||||||
|
allocator.Free(property_value);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
case OrtPropertyType::OrtFloatProperty: {
|
||||||
|
auto value_p = reinterpret_cast<float*>(property_value);
|
||||||
|
property = *value_p;
|
||||||
|
allocator.Free(property_value);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
case OrtPropertyType::OrtStringProperty: {
|
||||||
|
auto value_p = reinterpret_cast<char*>(property_value);
|
||||||
|
property = std::string(value_p);
|
||||||
|
allocator.Free(property_value);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
default: {
|
||||||
|
ThrowStatus(Status("Unknown property type received.", OrtErrorCode::ORT_INVALID_ARGUMENT));
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return property;
|
||||||
|
}
|
||||||
|
|
||||||
|
inline void CheckpointState::UpdateParameter(const std::string& parameter_name, const Value& parameter) {
|
||||||
|
ThrowOnError(GetTrainingApi().UpdateParameter(p_, parameter_name.c_str(), parameter));
|
||||||
|
}
|
||||||
|
|
||||||
|
inline Value CheckpointState::GetParameter(const std::string& parameter_name) {
|
||||||
|
AllocatorWithDefaultOptions allocator;
|
||||||
|
OrtValue* parameter;
|
||||||
|
ThrowOnError(GetTrainingApi().GetParameter(p_, parameter_name.c_str(), allocator, ¶meter));
|
||||||
|
|
||||||
|
return Value{parameter};
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace Ort
|
@ -0,0 +1,18 @@
|
|||||||
|
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||||
|
// Licensed under the MIT License.
|
||||||
|
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include <string>
|
||||||
|
#include <unordered_map>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
namespace onnxruntime {
|
||||||
|
|
||||||
|
// data types for execution provider options
|
||||||
|
|
||||||
|
using ProviderOptions = std::unordered_map<std::string, std::string>;
|
||||||
|
using ProviderOptionsVector = std::vector<ProviderOptions>;
|
||||||
|
using ProviderOptionsMap = std::unordered_map<std::string, ProviderOptions>;
|
||||||
|
|
||||||
|
} // namespace onnxruntime
|
@ -0,0 +1 @@
|
|||||||
|
libonnxruntime.so.1.18.1
|
BIN
onnxruntime/onnxruntime-linux-x64-gpu-1.18.1/lib/libonnxruntime.so.1.18.1
Executable file
BIN
onnxruntime/onnxruntime-linux-x64-gpu-1.18.1/lib/libonnxruntime.so.1.18.1
Executable file
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
BIN
onnxruntime/onnxruntime-linux-x64-gpu-cuda12-1.18.1.tgz
Normal file
BIN
onnxruntime/onnxruntime-linux-x64-gpu-cuda12-1.18.1.tgz
Normal file
Binary file not shown.
Reference in New Issue
Block a user