// Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. #pragma once #include "onnxruntime_training_c_api.h" #include "onnxruntime_cxx_api.h" namespace Ort { inline TrainingSession::TrainingSession(const Env& env, const SessionOptions& session_options, CheckpointState& checkpoint_state, const std::basic_string& train_model_path, const std::optional>& eval_model_path, const std::optional>& optimizer_model_path) { ThrowOnError(GetTrainingApi().CreateTrainingSession( env, session_options, checkpoint_state, train_model_path.c_str(), eval_model_path.has_value() ? eval_model_path.value().c_str() : nullptr, optimizer_model_path.has_value() ? optimizer_model_path.value().c_str() : nullptr, &p_)); ThrowOnError(GetTrainingApi().TrainingSessionGetTrainingModelOutputCount(p_, &training_model_output_count_)); ThrowOnError(GetTrainingApi().TrainingSessionGetEvalModelOutputCount(p_, &eval_model_output_count_)); } inline TrainingSession::TrainingSession(const Env& env, const SessionOptions& session_options, CheckpointState& checkpoint_state, const std::vector& train_model_data, const std::vector& eval_model_data, const std::vector& optim_model_data) { ThrowOnError(GetTrainingApi().CreateTrainingSessionFromBuffer( env, session_options, checkpoint_state, train_model_data.data(), train_model_data.size(), eval_model_data.data(), eval_model_data.size(), optim_model_data.data(), optim_model_data.size(), &p_)); ThrowOnError(GetTrainingApi().TrainingSessionGetTrainingModelOutputCount(p_, &training_model_output_count_)); ThrowOnError(GetTrainingApi().TrainingSessionGetEvalModelOutputCount(p_, &eval_model_output_count_)); } inline std::vector TrainingSession::TrainStep(const std::vector& input_values) { std::vector output_values; output_values.reserve(training_model_output_count_); for (size_t i = 0; i < training_model_output_count_; i++) output_values.emplace_back(nullptr); auto ort_input_values = reinterpret_cast(input_values.data()); auto ort_output_values = reinterpret_cast(output_values.data()); RunOptions run_options; ThrowOnError(GetTrainingApi().TrainStep( p_, run_options, input_values.size(), ort_input_values, training_model_output_count_, ort_output_values)); return output_values; } inline void TrainingSession::LazyResetGrad() { ThrowOnError(GetTrainingApi().LazyResetGrad(p_)); } inline std::vector TrainingSession::EvalStep(const std::vector& input_values) { std::vector output_values; output_values.reserve(eval_model_output_count_); for (size_t i = 0; i < eval_model_output_count_; i++) output_values.emplace_back(nullptr); auto ort_input_values = reinterpret_cast(input_values.data()); auto ort_output_values = reinterpret_cast(output_values.data()); RunOptions run_options; ThrowOnError(GetTrainingApi().EvalStep( p_, run_options, input_values.size(), ort_input_values, eval_model_output_count_, ort_output_values)); return output_values; } inline void TrainingSession::SetLearningRate(float learning_rate) { ThrowOnError(GetTrainingApi().SetLearningRate(p_, learning_rate)); } inline float TrainingSession::GetLearningRate() const { float learning_rate = 0; ThrowOnError(GetTrainingApi().GetLearningRate(p_, &learning_rate)); return learning_rate; } inline void TrainingSession::RegisterLinearLRScheduler(int64_t warmup_step_count, int64_t total_step_count, float initial_lr) { ThrowOnError(GetTrainingApi().RegisterLinearLRScheduler(p_, warmup_step_count, total_step_count, initial_lr)); } inline void TrainingSession::SchedulerStep() { ThrowOnError(GetTrainingApi().SchedulerStep(p_)); } inline void TrainingSession::OptimizerStep() { RunOptions run_options; ThrowOnError(GetTrainingApi().OptimizerStep(p_, run_options)); } inline std::vector TrainingSession::InputNames(const bool training) { auto& input_count_function = training ? GetTrainingApi().TrainingSessionGetTrainingModelInputCount : GetTrainingApi().TrainingSessionGetEvalModelInputCount; auto& input_name_function = training ? GetTrainingApi().TrainingSessionGetTrainingModelInputName : GetTrainingApi().TrainingSessionGetEvalModelInputName; size_t input_count = 0; ThrowOnError(input_count_function(p_, &input_count)); std::vector input_names(input_count); AllocatorWithDefaultOptions allocator; for (size_t index = 0; index < input_count; ++index) { char* input_name; ThrowOnError(input_name_function(p_, index, allocator, &input_name)); input_names[index] = std::string(input_name); allocator.Free(input_name); } return input_names; } inline std::vector TrainingSession::OutputNames(const bool training) { auto& output_count_function = training ? GetTrainingApi().TrainingSessionGetTrainingModelOutputCount : GetTrainingApi().TrainingSessionGetEvalModelOutputCount; auto& output_name_function = training ? GetTrainingApi().TrainingSessionGetTrainingModelOutputName : GetTrainingApi().TrainingSessionGetEvalModelOutputName; size_t output_count = 0; ThrowOnError(output_count_function(p_, &output_count)); std::vector output_names(output_count); AllocatorWithDefaultOptions allocator; for (size_t index = 0; index < output_count; ++index) { char* output_name; ThrowOnError(output_name_function(p_, index, allocator, &output_name)); output_names[index] = std::string(output_name); allocator.Free(output_name); } return output_names; } inline Value TrainingSession::ToBuffer(const bool only_trainable) { size_t buffer_size = 0U; ThrowOnError(GetTrainingApi().GetParametersSize(p_, &buffer_size, only_trainable)); std::array buffer_shape{static_cast(buffer_size)}; AllocatorWithDefaultOptions allocator; Value buffer = Value::CreateTensor(allocator, buffer_shape.data(), 1U, ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT); ThrowOnError(GetTrainingApi().CopyParametersToBuffer(p_, buffer, only_trainable)); return buffer; } inline void TrainingSession::FromBuffer(Value& buffer) { if (!buffer.IsTensor()) { ThrowStatus(Status("Incorrect buffer received. Expected a tensor buffer.", OrtErrorCode::ORT_INVALID_ARGUMENT)); } auto tensor_info = buffer.GetTensorTypeAndShapeInfo(); auto buffer_shape = tensor_info.GetShape(); if (buffer_shape.size() != 1U) { ThrowStatus(Status("Incorrect buffer received. Expected a contiguous tensor buffer.", OrtErrorCode::ORT_INVALID_ARGUMENT)); } auto buffer_size = buffer_shape.front(); size_t session_buffer_size = 0U; ThrowOnError(GetTrainingApi().GetParametersSize(p_, &session_buffer_size, false)); if (buffer_size == static_cast(session_buffer_size)) { ThrowOnError(GetTrainingApi().CopyBufferToParameters(p_, buffer, false)); return; } size_t session_buffer_size_trainable_only = 0U; ThrowOnError(GetTrainingApi().GetParametersSize(p_, &session_buffer_size_trainable_only, true)); if (buffer_size == static_cast(session_buffer_size_trainable_only)) { ThrowOnError(GetTrainingApi().CopyBufferToParameters(p_, buffer, true)); return; } else { ThrowStatus(Status("Incorrect buffer size received.", OrtErrorCode::ORT_INVALID_ARGUMENT)); } } inline CheckpointState CheckpointState::LoadCheckpoint(const std::basic_string& path_to_checkpoint) { OrtCheckpointState* checkpoint_state; ThrowOnError(GetTrainingApi().LoadCheckpoint(path_to_checkpoint.c_str(), &checkpoint_state)); return CheckpointState(checkpoint_state); } inline CheckpointState CheckpointState::LoadCheckpointFromBuffer(const std::vector& buffer) { OrtCheckpointState* checkpoint_state; ThrowOnError(GetTrainingApi().LoadCheckpointFromBuffer(buffer.data(), buffer.size(), &checkpoint_state)); return CheckpointState(checkpoint_state); } inline void CheckpointState::SaveCheckpoint(const CheckpointState& checkpoint_states, const std::basic_string& path_to_checkpoint, const bool include_optimizer_state) { ThrowOnError(GetTrainingApi().SaveCheckpoint(checkpoint_states, path_to_checkpoint.c_str(), include_optimizer_state)); } inline void TrainingSession::ExportModelForInferencing(const std::basic_string& inference_model_path, const std::vector& graph_output_names) { std::vector output_names; output_names.reserve(graph_output_names.size()); for (const auto& output_name : graph_output_names) { output_names.push_back(output_name.c_str()); } ThrowOnError(GetTrainingApi().ExportModelForInferencing( p_, inference_model_path.c_str(), graph_output_names.size(), output_names.data())); } inline void SetSeed(const int64_t seed) { ThrowOnError(GetTrainingApi().SetSeed(seed)); } inline void CheckpointState::AddProperty(const std::string& property_name, const Property& property_value) { if (std::holds_alternative(property_value)) { int64_t value = std::get(property_value); void* value_p = &value; ThrowOnError(GetTrainingApi().AddProperty(p_, property_name.c_str(), OrtPropertyType::OrtIntProperty, value_p)); } else if (std::holds_alternative(property_value)) { float value = std::get(property_value); void* value_p = &value; ThrowOnError(GetTrainingApi().AddProperty(p_, property_name.c_str(), OrtPropertyType::OrtFloatProperty, value_p)); } else if (std::holds_alternative(property_value)) { std::string value = std::get(property_value); auto buffer = std::make_unique(value.length() + 1); memcpy(buffer.get(), value.c_str(), value.length()); // AddProperty takes a char* and calls PropertyBag::AddProperty which takes a std::string. The data will be // copied at that point so buffer can free the local allocation once the call is made. ThrowOnError(GetTrainingApi().AddProperty(p_, property_name.c_str(), OrtPropertyType::OrtStringProperty, buffer.get())); } else { ThrowStatus(Status("Unknown property type received.", OrtErrorCode::ORT_INVALID_ARGUMENT)); } } inline Property CheckpointState::GetProperty(const std::string& property_name) { void* property_value = nullptr; OrtPropertyType property_type; AllocatorWithDefaultOptions allocator; ThrowOnError(GetTrainingApi().GetProperty(p_, property_name.c_str(), allocator, &property_type, &property_value)); Property property; switch (property_type) { case OrtPropertyType::OrtIntProperty: { auto value_p = reinterpret_cast(property_value); property = *value_p; allocator.Free(property_value); break; } case OrtPropertyType::OrtFloatProperty: { auto value_p = reinterpret_cast(property_value); property = *value_p; allocator.Free(property_value); break; } case OrtPropertyType::OrtStringProperty: { auto value_p = reinterpret_cast(property_value); property = std::string(value_p); allocator.Free(property_value); break; } default: { ThrowStatus(Status("Unknown property type received.", OrtErrorCode::ORT_INVALID_ARGUMENT)); break; } } return property; } inline void CheckpointState::UpdateParameter(const std::string& parameter_name, const Value& parameter) { ThrowOnError(GetTrainingApi().UpdateParameter(p_, parameter_name.c_str(), parameter)); } inline Value CheckpointState::GetParameter(const std::string& parameter_name) { AllocatorWithDefaultOptions allocator; OrtValue* parameter; ThrowOnError(GetTrainingApi().GetParameter(p_, parameter_name.c_str(), allocator, ¶meter)); return Value{parameter}; } } // namespace Ort