From 1440dcbb390c01fb1692f5aa4427d538f6a04521 Mon Sep 17 00:00:00 2001 From: carzh Date: Mon, 15 Jul 2024 17:16:44 -0700 Subject: [PATCH 01/12] experimentation + working export --- onnxruntime/core/session/inference_session.h | 4 ++ .../orttraining/training_api/module.cc | 48 ++++++++++++++----- orttraining/orttraining/training_api/module.h | 4 +- 3 files changed, 44 insertions(+), 12 deletions(-) diff --git a/onnxruntime/core/session/inference_session.h b/onnxruntime/core/session/inference_session.h index e1cd085d2c271..64e3e35cb7f2b 100644 --- a/onnxruntime/core/session/inference_session.h +++ b/onnxruntime/core/session/inference_session.h @@ -403,6 +403,10 @@ class InferenceSession { int32_t partial_graph_index); #endif +#ifdef ENABLE_TRAINING_APIS + std::shared_ptr GetModel() noexcept { return model_; }; +#endif + /** * @return pair.first = OK; FAIL otherwise. pair.second is non-NULL when pair.first = OK. * @note lifetime of the returned pointer is valid as long as the Session object is live. diff --git a/orttraining/orttraining/training_api/module.cc b/orttraining/orttraining/training_api/module.cc index dc724fbae48eb..25d414734a90d 100644 --- a/orttraining/orttraining/training_api/module.cc +++ b/orttraining/orttraining/training_api/module.cc @@ -56,7 +56,10 @@ Status RemoveUnusedNodes(Graph& inference_graph, InlinedVector& for (size_t idx = node_indices.size(); idx > 0; --idx) { const NodeIndex node_index = idx - 1; auto* node = inference_graph.GetNode(node_index); - if (!reachable_nodes.count(node)) { + if (!node) { + inference_graph.RemoveNode(node_index); + } + else if (!reachable_nodes.count(node)) { graph_utils::RemoveNodeOutputEdges(inference_graph, *node); inference_graph.RemoveNode(node_index); } @@ -65,6 +68,20 @@ Status RemoveUnusedNodes(Graph& inference_graph, InlinedVector& return Status::OK(); } +//TODO: REMOVE THIS METHOD BEFORE YOUR PR ITS JUST FOR DEBUGGING PURPOSES +Status RemoveThisMethodBeforeYourPR(Graph& inference_graph) { + GraphViewer graph_viewer(inference_graph); + const auto node_indices = graph_viewer.GetNodesInTopologicalOrder(); + for (size_t idx = node_indices.size(); idx > 0; --idx) { + const NodeIndex node_index = idx - 1; + auto* node = inference_graph.GetNode(node_index); + if (node->Name().empty()) { + inference_graph.RemoveNode(node_index); + } + } + + return Status::OK(); +} Status TransformModelOutputsForInference(Graph& inference_graph, gsl::span inference_graph_outputs) { // Model is updated to remove any outputs that are not defined in inference_graph_outputs. Nodes @@ -388,6 +405,11 @@ Module::Module(const ModelIdentifiers& model_identifiers, } ORT_THROW_IF_ERROR(eval_sess_->Initialize()); utils::GetGraphInputOutputNames(eval_sess_, eval_input_names_, eval_output_names_); + // TODO: remove this + // std::shared_ptr inference_model = eval_sess_->GetModel(); + // Graph& inference_graph = inference_model->MainGraph(); + + // ORT_THROW_IF_ERROR(RemoveThisMethodBeforeYourPR(inference_graph)); // Eval model validation // We are making certain assumptions: Like the order in which parameters occur will be same between train and eval @@ -412,6 +434,8 @@ Module::Module(const ModelIdentifiers& model_identifiers, eval_user_input_count_ = eval_user_input_names.size(); eval_input_names_.insert(eval_input_names_.end(), eval_param_input_names.begin(), eval_param_input_names.end()); + // TODO: remove this + // ORT_THROW_IF_ERROR(RemoveThisMethodBeforeYourPR(inference_graph)); // Keep a copy of the eval model path to be able to later export the model for inferencing. // The inference model will be reconstructed from the eval model. // TODO(askhade): Find a fix to export model for inference when the eval model is loaded from a buffer. @@ -613,6 +637,7 @@ Status Module::CopyBufferToParameters(OrtValue& parameters_buffer, const bool tr } Status Module::LazyResetGrad() { + ORT_RETURN_IF(finished_training_, "Cannot train after exporting for inferencing. To continue training from this point, please save the checkpoint and create a new TrainingSession."); accumulate_gradient_ = false; return Status::OK(); } @@ -620,6 +645,7 @@ Status Module::LazyResetGrad() { Status Module::TrainStep(const std::vector& inputs, std::vector& outputs) { ORT_RETURN_IF(state_->module_checkpoint_state.is_nominal_state, "Cannot perform TrainStep with a nominal state. Please load the model parameters first."); + ORT_RETURN_IF(finished_training_, "Cannot train after exporting for inferencing. To continue training from this point, please save the checkpoint and create a new TrainingSession."); std::vector> params; std::vector feeds{inputs}; feeds.insert(feeds.end(), weights_.begin(), weights_.end()); @@ -642,6 +668,7 @@ Status Module::TrainStep(const std::vector& inputs, std::vector& inputs, std::vector& outputs) { ORT_RETURN_IF(state_->module_checkpoint_state.is_nominal_state, "Cannot perform EvalStep with a nominal state. Please load the model parameters first."); + ORT_RETURN_IF(finished_training_, "Cannot evaluate after exporting for inferencing. To continue training from this point, please save the checkpoint and create a new TrainingSession."); ORT_ENFORCE(nullptr != eval_sess_, "Evaluation session not initialized."); std::vector feeds{inputs}; feeds.insert(feeds.end(), weights_.begin(), weights_.end()); @@ -655,26 +682,25 @@ Status Module::EvalStep(const std::vector& inputs, std::vector graph_output_names) const { + gsl::span graph_output_names) { ORT_RETURN_IF(state_->module_checkpoint_state.is_nominal_state, "Cannot export the model with a nominal state. Please load the model parameters first."); - ORT_RETURN_IF(!eval_sess_ || !eval_model_path_.has_value(), - "Eval model was not provided. Cannot export a model for inferencing."); - ONNX_NAMESPACE::ModelProto eval_model; - ORT_THROW_IF_ERROR(Model::Load(ToPathString(eval_model_path_.value()), eval_model)); + // Once finished_training is set to true, will no longer be able to train or evaluate with this module + // since the eval session graph will have been modified. + finished_training_ = true; - // Clone the eval mode into an inference onnxruntime::Model. - std::shared_ptr inference_model; - ORT_RETURN_IF_ERROR(Model::Load(eval_model, inference_model, nullptr, logging::LoggingManager::DefaultLogger())); + // Model& inference_model = const_cast(eval_sess_->GetModel()); + std::shared_ptr inference_model = eval_sess_->GetModel(); + Graph& inference_graph = inference_model->MainGraph(); // The cloned model's outputs are transformed such that the model has outputs as defined by graph_output_names // Any nodes not contributing to the inference outputs will be pruned. - ORT_THROW_IF_ERROR(TransformModelOutputsForInference(inference_model->MainGraph(), graph_output_names)); + ORT_THROW_IF_ERROR(TransformModelOutputsForInference(inference_graph, graph_output_names)); // The cloned model's inputs are transformed such that the model has only user defined inputs. All parameters // are moved to be constant initializers for the model. - ORT_RETURN_IF_ERROR(TransformModelInputsForInference(inference_model->MainGraph(), + ORT_RETURN_IF_ERROR(TransformModelInputsForInference(inference_graph, state_->module_checkpoint_state.named_parameters, eval_sess_->GetDataTransferManager())); diff --git a/orttraining/orttraining/training_api/module.h b/orttraining/orttraining/training_api/module.h index 917887404217f..51b469bae03d1 100644 --- a/orttraining/orttraining/training_api/module.h +++ b/orttraining/orttraining/training_api/module.h @@ -140,10 +140,11 @@ struct Module { #if !defined(ORT_MINIMAL_BUILD) // Load the eval model from eval_model_path_or_bytes and transform it for the purpose of // inferencing, and serialize to given path. + // This function modifies the graph stored with the eval session & marks the module as done training. // If the parameter state is not available; i.e. the module was created using the nominal checkpoint, // and the state has not been loaded yet, then this function will return an error. Status ExportModelForInferencing(const std::string& inference_model_path, - gsl::span graph_output_names) const; + gsl::span graph_output_names); #endif // Returns the user input count for training graph @@ -167,6 +168,7 @@ struct Module { private: std::unique_ptr train_sess_{nullptr}; std::unique_ptr eval_sess_{nullptr}; + bool finished_training_ = false; struct TrainInputNames { private: From 3e99612cd67cd30fb86c31cccd58537aa177135e Mon Sep 17 00:00:00 2001 From: carzh Date: Wed, 17 Jul 2024 09:43:20 -0700 Subject: [PATCH 02/12] updated module.cc to attempt to resolve bugs + added to py binding tests --- .../python/orttraining_test_ort_apis_py_bindings.py | 3 ++- orttraining/orttraining/training_api/module.cc | 11 ++++------- 2 files changed, 6 insertions(+), 8 deletions(-) diff --git a/orttraining/orttraining/test/python/orttraining_test_ort_apis_py_bindings.py b/orttraining/orttraining/test/python/orttraining_test_ort_apis_py_bindings.py index 68b3fa2176944..d035e2087031b 100644 --- a/orttraining/orttraining/test/python/orttraining_test_ort_apis_py_bindings.py +++ b/orttraining/orttraining/test/python/orttraining_test_ort_apis_py_bindings.py @@ -16,7 +16,7 @@ from orttraining_test_ort_apis_onnxblock import _get_models import onnxruntime.training.onnxblock as onnxblock -from onnxruntime import OrtValue, SessionOptions +from onnxruntime import OrtValue, SessionOptions, InferenceSession from onnxruntime.training import artifacts from onnxruntime.training.api import CheckpointState, LinearLRScheduler, Module, Optimizer @@ -283,6 +283,7 @@ def test_export_model_for_inferencing(): inference_model_file_path = os.path.join(temp_dir, "inference_model.onnx") model.export_model_for_inferencing(inference_model_file_path, ["output-0"]) assert os.path.exists(inference_model_file_path) + InferenceSession(inference_model_file_path) def test_cuda_execution_provider(): diff --git a/orttraining/orttraining/training_api/module.cc b/orttraining/orttraining/training_api/module.cc index 25d414734a90d..e23a5add4bbba 100644 --- a/orttraining/orttraining/training_api/module.cc +++ b/orttraining/orttraining/training_api/module.cc @@ -56,10 +56,7 @@ Status RemoveUnusedNodes(Graph& inference_graph, InlinedVector& for (size_t idx = node_indices.size(); idx > 0; --idx) { const NodeIndex node_index = idx - 1; auto* node = inference_graph.GetNode(node_index); - if (!node) { - inference_graph.RemoveNode(node_index); - } - else if (!reachable_nodes.count(node)) { + if (node && !reachable_nodes.count(node)) { graph_utils::RemoveNodeOutputEdges(inference_graph, *node); inference_graph.RemoveNode(node_index); } @@ -101,9 +98,9 @@ Status TransformModelOutputsForInference(Graph& inference_graph, // Set the inference graph outputs, and remove any unused nodes. inference_graph.SetOutputs(inference_graph_output_node_args); - ORT_RETURN_IF_ERROR(RemoveUnusedNodes(inference_graph, inference_graph_output_node_args)); + // ORT_RETURN_IF_ERROR(RemoveUnusedNodes(inference_graph, inference_graph_output_node_args)); - ORT_RETURN_IF_ERROR(inference_graph.Resolve()); + ORT_THROW_IF_ERROR(inference_graph.Resolve()); return Status::OK(); } @@ -131,7 +128,7 @@ Status TransformModelInputsForInference(Graph& inference_graph, } inference_graph.SetInputs(user_graph_inputs); - ORT_RETURN_IF_ERROR(inference_graph.Resolve()); + ORT_THROW_IF_ERROR(inference_graph.Resolve()); return Status::OK(); } From 46734ae03696f1ffb781662389cb86b3c7cf99ee Mon Sep 17 00:00:00 2001 From: carzh Date: Wed, 17 Jul 2024 13:51:19 -0700 Subject: [PATCH 03/12] added more unit test + continued trying to resolve bug --- .../training_api/core/training_api_tests.cc | 78 +++++++++++++++++++ .../orttraining/training_api/module.cc | 4 +- 2 files changed, 80 insertions(+), 2 deletions(-) diff --git a/orttraining/orttraining/test/training_api/core/training_api_tests.cc b/orttraining/orttraining/test/training_api/core/training_api_tests.cc index 90c97eed0c6d3..e9729558fcd35 100644 --- a/orttraining/orttraining/test/training_api/core/training_api_tests.cc +++ b/orttraining/orttraining/test/training_api/core/training_api_tests.cc @@ -32,6 +32,19 @@ namespace { constexpr int64_t TOTAL_STEP_COUNT = 100; constexpr float INITIAL_LR = 1e-3f; +std::vector ReadFileIntoBuffer(const std::string& file_path) { + size_t num_bytes = 0; + ASSERT_STATUS_OK(Env::Default().GetFileLength(file_path.c_str(), num_bytes)); + std::vector buffer(num_bytes); + + std::ifstream bytes_stream(file_path, std::ifstream::in | std::ifstream::binary); + bytes_stream.read(reinterpret_cast(buffer.data()), num_bytes); + + ASSERT_TRUE(bytes_stream); + + return buffer; +} + /** * @brief Create a Fake Optimizer Checkpoint State On CPU. * @@ -138,6 +151,66 @@ void TestModuleExport(const std::vector>& pr RunInferenceSession(*env, inference_model_path); } +void TestModuleExportFromBuffer(const std::vector>& providers) { + auto training_model_uri = MODEL_FOLDER "training_model.onnx"; + auto eval_model_uri = MODEL_FOLDER "eval_model.onnx"; + + onnxruntime::training::api::CheckpointState state; + auto checkpoint_to_load_path = MODEL_FOLDER "checkpoint.ckpt"; + // Load checkpoint, eval model, and training model into buffers + std::vector checkpoint_bytes = ReadFileIntoBuffer(checkpoint_to_load_path); + std::vector training_model_bytes = ReadFileIntoBuffer(training_model_uri); + std::vector eval_model_bytes = ReadFileIntoBuffer(eval_model_uri); + + ASSERT_STATUS_OK(onnxruntime::training::api::LoadCheckpointFromBuffer(checkpoint_bytes, state)); + + // load training and eval model into buffers + std::unique_ptr env; + ASSERT_STATUS_OK(Environment::Create(nullptr, env)); + auto model_identifier = ModelIdentifiers(training_model_bytes, + std::optional>(eval_model_bytes), + std::nullopt); + auto model = std::make_unique( + model_identifier, &state, onnxruntime::SessionOptions(), + *env, providers); + + auto test_dir = ORT_TSTR("export_model_for_inferencing_test_dir"); + if (Env::Default().FolderExists(test_dir)) { + ORT_ENFORCE(Env::Default().DeleteFolder(test_dir).IsOK()); + } + onnxruntime::test::TemporaryDirectory tmp_dir{test_dir}; + PathString inference_model_path{ + ConcatPathComponent(tmp_dir.Path(), ORT_TSTR("inference_model.onnx"))}; + + std::vector graph_output_names({"output-0"}); + ASSERT_STATUS_OK(model->ExportModelForInferencing(ToUTF8String(inference_model_path), graph_output_names)); + + // Load model + ONNX_NAMESPACE::ModelProto eval_model; + ONNX_NAMESPACE::ModelProto inference_model; + ORT_THROW_IF_ERROR(Model::Load(eval_model_uri, eval_model)); + ORT_THROW_IF_ERROR(Model::Load(inference_model_path, inference_model)); + + // Check it has only one graph input + ASSERT_EQ(eval_model.graph().input().size(), 6); + ASSERT_EQ(inference_model.graph().input().size(), 1); + ASSERT_EQ(inference_model.graph().input()[0].name(), "input-0"); + + // Check that it does not have any node which has op type SoftmaxCrossEntropyLoss + auto softmaxceloss_node_found = [](auto& model) -> bool { + for (auto& node : model.graph().node()) { + if (node.op_type() == "SoftmaxCrossEntropyLoss") { + return true; + } + } + return false; + }; + ASSERT_EQ(softmaxceloss_node_found(eval_model), true); + ASSERT_EQ(softmaxceloss_node_found(inference_model), false); + + RunInferenceSession(*env, inference_model_path); +} + void TestModuleExportWithExternalData(const std::vector>& providers) { auto training_model_uri = MODEL_FOLDER "training_model.onnx"; auto eval_model_uri = MODEL_FOLDER "eval_model.onnx"; @@ -495,6 +568,11 @@ TEST(TrainingApiTest, ModuleExportModelForInferencingCPU) { TestModuleExport(providers); } +TEST(TrainingApiTest, ModuleFromBufferExportModelForInferencingCPU) { + std::vector> providers{onnxruntime::test::DefaultCpuExecutionProvider()}; + TestModuleExportFromBuffer(providers); +} + TEST(TrainingApiTest, ModuleExportModelForInferencingCPU_WithExternalData) { std::vector> providers{onnxruntime::test::DefaultCpuExecutionProvider()}; TestModuleExportWithExternalData(providers); diff --git a/orttraining/orttraining/training_api/module.cc b/orttraining/orttraining/training_api/module.cc index e23a5add4bbba..56c46c2869564 100644 --- a/orttraining/orttraining/training_api/module.cc +++ b/orttraining/orttraining/training_api/module.cc @@ -54,7 +54,7 @@ Status RemoveUnusedNodes(Graph& inference_graph, InlinedVector& GraphViewer graph_viewer(inference_graph); const auto node_indices = graph_viewer.GetNodesInTopologicalOrder(); for (size_t idx = node_indices.size(); idx > 0; --idx) { - const NodeIndex node_index = idx - 1; + const NodeIndex node_index = node_indices[idx - 1]; auto* node = inference_graph.GetNode(node_index); if (node && !reachable_nodes.count(node)) { graph_utils::RemoveNodeOutputEdges(inference_graph, *node); @@ -98,7 +98,7 @@ Status TransformModelOutputsForInference(Graph& inference_graph, // Set the inference graph outputs, and remove any unused nodes. inference_graph.SetOutputs(inference_graph_output_node_args); - // ORT_RETURN_IF_ERROR(RemoveUnusedNodes(inference_graph, inference_graph_output_node_args)); + ORT_RETURN_IF_ERROR(RemoveUnusedNodes(inference_graph, inference_graph_output_node_args)); ORT_THROW_IF_ERROR(inference_graph.Resolve()); From 2ed4dfebbbac6cd8ac01fba55819b936a57ce477 Mon Sep 17 00:00:00 2001 From: carzh Date: Fri, 19 Jul 2024 16:44:48 -0700 Subject: [PATCH 04/12] added working unit test + additional throw statements + cleaned up --- .../orttraining_test_ort_apis_py_bindings.py | 2 +- .../training_api/core/training_api_tests.cc | 78 ------------------- .../training_api/core/training_capi_tests.cc | 69 ++++++++++++++++ .../orttraining/training_api/module.cc | 26 +++---- orttraining/orttraining/training_api/module.h | 6 +- .../training_api/training_session.cc | 8 +- .../training_api/training_session.h | 8 +- 7 files changed, 90 insertions(+), 107 deletions(-) diff --git a/orttraining/orttraining/test/python/orttraining_test_ort_apis_py_bindings.py b/orttraining/orttraining/test/python/orttraining_test_ort_apis_py_bindings.py index d035e2087031b..3dee3a4169942 100644 --- a/orttraining/orttraining/test/python/orttraining_test_ort_apis_py_bindings.py +++ b/orttraining/orttraining/test/python/orttraining_test_ort_apis_py_bindings.py @@ -16,7 +16,7 @@ from orttraining_test_ort_apis_onnxblock import _get_models import onnxruntime.training.onnxblock as onnxblock -from onnxruntime import OrtValue, SessionOptions, InferenceSession +from onnxruntime import InferenceSession, OrtValue, SessionOptions from onnxruntime.training import artifacts from onnxruntime.training.api import CheckpointState, LinearLRScheduler, Module, Optimizer diff --git a/orttraining/orttraining/test/training_api/core/training_api_tests.cc b/orttraining/orttraining/test/training_api/core/training_api_tests.cc index e9729558fcd35..90c97eed0c6d3 100644 --- a/orttraining/orttraining/test/training_api/core/training_api_tests.cc +++ b/orttraining/orttraining/test/training_api/core/training_api_tests.cc @@ -32,19 +32,6 @@ namespace { constexpr int64_t TOTAL_STEP_COUNT = 100; constexpr float INITIAL_LR = 1e-3f; -std::vector ReadFileIntoBuffer(const std::string& file_path) { - size_t num_bytes = 0; - ASSERT_STATUS_OK(Env::Default().GetFileLength(file_path.c_str(), num_bytes)); - std::vector buffer(num_bytes); - - std::ifstream bytes_stream(file_path, std::ifstream::in | std::ifstream::binary); - bytes_stream.read(reinterpret_cast(buffer.data()), num_bytes); - - ASSERT_TRUE(bytes_stream); - - return buffer; -} - /** * @brief Create a Fake Optimizer Checkpoint State On CPU. * @@ -151,66 +138,6 @@ void TestModuleExport(const std::vector>& pr RunInferenceSession(*env, inference_model_path); } -void TestModuleExportFromBuffer(const std::vector>& providers) { - auto training_model_uri = MODEL_FOLDER "training_model.onnx"; - auto eval_model_uri = MODEL_FOLDER "eval_model.onnx"; - - onnxruntime::training::api::CheckpointState state; - auto checkpoint_to_load_path = MODEL_FOLDER "checkpoint.ckpt"; - // Load checkpoint, eval model, and training model into buffers - std::vector checkpoint_bytes = ReadFileIntoBuffer(checkpoint_to_load_path); - std::vector training_model_bytes = ReadFileIntoBuffer(training_model_uri); - std::vector eval_model_bytes = ReadFileIntoBuffer(eval_model_uri); - - ASSERT_STATUS_OK(onnxruntime::training::api::LoadCheckpointFromBuffer(checkpoint_bytes, state)); - - // load training and eval model into buffers - std::unique_ptr env; - ASSERT_STATUS_OK(Environment::Create(nullptr, env)); - auto model_identifier = ModelIdentifiers(training_model_bytes, - std::optional>(eval_model_bytes), - std::nullopt); - auto model = std::make_unique( - model_identifier, &state, onnxruntime::SessionOptions(), - *env, providers); - - auto test_dir = ORT_TSTR("export_model_for_inferencing_test_dir"); - if (Env::Default().FolderExists(test_dir)) { - ORT_ENFORCE(Env::Default().DeleteFolder(test_dir).IsOK()); - } - onnxruntime::test::TemporaryDirectory tmp_dir{test_dir}; - PathString inference_model_path{ - ConcatPathComponent(tmp_dir.Path(), ORT_TSTR("inference_model.onnx"))}; - - std::vector graph_output_names({"output-0"}); - ASSERT_STATUS_OK(model->ExportModelForInferencing(ToUTF8String(inference_model_path), graph_output_names)); - - // Load model - ONNX_NAMESPACE::ModelProto eval_model; - ONNX_NAMESPACE::ModelProto inference_model; - ORT_THROW_IF_ERROR(Model::Load(eval_model_uri, eval_model)); - ORT_THROW_IF_ERROR(Model::Load(inference_model_path, inference_model)); - - // Check it has only one graph input - ASSERT_EQ(eval_model.graph().input().size(), 6); - ASSERT_EQ(inference_model.graph().input().size(), 1); - ASSERT_EQ(inference_model.graph().input()[0].name(), "input-0"); - - // Check that it does not have any node which has op type SoftmaxCrossEntropyLoss - auto softmaxceloss_node_found = [](auto& model) -> bool { - for (auto& node : model.graph().node()) { - if (node.op_type() == "SoftmaxCrossEntropyLoss") { - return true; - } - } - return false; - }; - ASSERT_EQ(softmaxceloss_node_found(eval_model), true); - ASSERT_EQ(softmaxceloss_node_found(inference_model), false); - - RunInferenceSession(*env, inference_model_path); -} - void TestModuleExportWithExternalData(const std::vector>& providers) { auto training_model_uri = MODEL_FOLDER "training_model.onnx"; auto eval_model_uri = MODEL_FOLDER "eval_model.onnx"; @@ -568,11 +495,6 @@ TEST(TrainingApiTest, ModuleExportModelForInferencingCPU) { TestModuleExport(providers); } -TEST(TrainingApiTest, ModuleFromBufferExportModelForInferencingCPU) { - std::vector> providers{onnxruntime::test::DefaultCpuExecutionProvider()}; - TestModuleExportFromBuffer(providers); -} - TEST(TrainingApiTest, ModuleExportModelForInferencingCPU_WithExternalData) { std::vector> providers{onnxruntime::test::DefaultCpuExecutionProvider()}; TestModuleExportWithExternalData(providers); diff --git a/orttraining/orttraining/test/training_api/core/training_capi_tests.cc b/orttraining/orttraining/test/training_api/core/training_capi_tests.cc index 8f25e1e4c92b8..04d13912e73d8 100644 --- a/orttraining/orttraining/test/training_api/core/training_capi_tests.cc +++ b/orttraining/orttraining/test/training_api/core/training_capi_tests.cc @@ -265,6 +265,75 @@ TEST(TrainingCApiTest, LoadONNXModelsFromBuffer) { train_model_data); } +TEST(TrainingCApiTest, LoadONNXModelsFromBufferThenExport) { + auto model_path = MODEL_FOLDER "training_model.onnx"; + size_t model_data_len = 0; + ASSERT_STATUS_OK(Env::Default().GetFileLength(model_path, model_data_len)); + std::vector train_model_data(model_data_len); + std::ifstream bytes_stream(model_path, std::ifstream::in | std::ifstream::binary); + bytes_stream.read(reinterpret_cast(train_model_data.data()), model_data_len); + ASSERT_TRUE(train_model_data.size() == model_data_len); + + auto eval_model_path = MODEL_FOLDER "eval_model.onnx"; + size_t eval_model_data_len = 0; + ASSERT_STATUS_OK(Env::Default().GetFileLength(eval_model_path, eval_model_data_len)); + std::vector eval_model_data(eval_model_data_len); + std::ifstream eval_bytes_stream(eval_model_path, std::ifstream::in | std::ifstream::binary); + eval_bytes_stream.read(reinterpret_cast(eval_model_data.data()), eval_model_data_len); + ASSERT_TRUE(eval_model_data.size() == eval_model_data_len); + + Ort::Env env; + Ort::CheckpointState checkpoint_state = Ort::CheckpointState::LoadCheckpoint(MODEL_FOLDER "checkpoint.ckpt"); + Ort::TrainingSession training_session = Ort::TrainingSession(env, + Ort::SessionOptions(), + checkpoint_state, + train_model_data, + eval_model_data); + + // randomly selected output name + std::vector graph_output_names({"onnx::loss::21273"}); + training_session.ExportModelForInferencing(MODEL_FOLDER "inference_model.onnx", graph_output_names); + + // Check that the model is a valid inference model by loading into an InferenceSession + std::unique_ptr environment; + ASSERT_STATUS_OK(Environment::Create(nullptr, environment)); + InferenceSession inference_session = InferenceSession(SessionOptions(), *environment, MODEL_FOLDER "inference_model.onnx"); + + // Check that you can no longer train or evaluate after exporting. Since passing incorrect inputs will also cause + // TrainStep and EvalStep to throw errors, we check for the error message. + ORT_TRY { + training_session.TrainStep({}); + FAIL() << "TrainStep after exporting for inference should have thrown an error."; + } + ORT_CATCH(const Ort::Exception& e) { + ORT_HANDLE_EXCEPTION([&e]() { + ASSERT_THAT(e.what(), + testing::HasSubstr("Cannot train after exporting for inferencing. To continue training from this point, please save the checkpoint and create a new TrainingSession.")); + }); + } + ORT_CATCH(...) { + FAIL() << "TrainStep after exporting for inference should have thrown an Ort::Exception."; + } + + ORT_TRY { + training_session.EvalStep({}); + FAIL() << "EvalStep after exporting for inference should have thrown an Ort::Exception."; + } + ORT_CATCH(const Ort::Exception& e) { + ORT_HANDLE_EXCEPTION([&e]() { + ASSERT_THAT(e.what(), + testing::HasSubstr("Cannot evaluate after exporting for inferencing. To continue training from this point, please save the checkpoint and create a new TrainingSession.")); + }); + } + ORT_CATCH(...) { + FAIL() << "EvalStep after exporting for inference should have thrown an Ort::Exception."; + } + + // attempt to retrieve the input & output names of the eval model + ASSERT_THROW(training_session.InputNames(false), Ort::Exception); + ASSERT_THROW(training_session.OutputNames(false), Ort::Exception); +} + TEST(TrainingCApiTest, LoadORTFormatModelsFromBuffer) { auto train_model_path = ORT_FORMAT_MODEL_FOLDER "training_model.ort"; auto eval_model_path = ORT_FORMAT_MODEL_FOLDER "eval_model.ort"; diff --git a/orttraining/orttraining/training_api/module.cc b/orttraining/orttraining/training_api/module.cc index 56c46c2869564..57843c689f9b5 100644 --- a/orttraining/orttraining/training_api/module.cc +++ b/orttraining/orttraining/training_api/module.cc @@ -65,20 +65,6 @@ Status RemoveUnusedNodes(Graph& inference_graph, InlinedVector& return Status::OK(); } -//TODO: REMOVE THIS METHOD BEFORE YOUR PR ITS JUST FOR DEBUGGING PURPOSES -Status RemoveThisMethodBeforeYourPR(Graph& inference_graph) { - GraphViewer graph_viewer(inference_graph); - const auto node_indices = graph_viewer.GetNodesInTopologicalOrder(); - for (size_t idx = node_indices.size(); idx > 0; --idx) { - const NodeIndex node_index = idx - 1; - auto* node = inference_graph.GetNode(node_index); - if (node->Name().empty()) { - inference_graph.RemoveNode(node_index); - } - } - - return Status::OK(); -} Status TransformModelOutputsForInference(Graph& inference_graph, gsl::span inference_graph_outputs) { // Model is updated to remove any outputs that are not defined in inference_graph_outputs. Nodes @@ -449,7 +435,8 @@ size_t Module::GetTrainingModelOutputCount() const noexcept { return train_output_names_.size(); } -size_t Module::GetEvalModelOutputCount() const noexcept { +size_t Module::GetEvalModelOutputCount() const { + ORT_ENFORCE(!finished_training_, "Exporting for inference has modified the eval model. Cannot retrieve EvalModel output count. "); return eval_output_names_.size(); } @@ -459,6 +446,7 @@ std::string Module::GetTrainingModelOutputName(size_t index) const { } std::string Module::GetEvalModelOutputName(size_t index) const { + ORT_ENFORCE(!finished_training_, "Exporting for inference has modified the eval model. Cannot retrieve EvalModel output name. "); ORT_ENFORCE(index < eval_output_names_.size(), "Eval output name index out of range. Expected in range [0-", eval_output_names_.size(), "). Actual: ", index); return eval_output_names_.at(index); @@ -682,6 +670,7 @@ Status Module::ExportModelForInferencing(const std::string& inference_model_path gsl::span graph_output_names) { ORT_RETURN_IF(state_->module_checkpoint_state.is_nominal_state, "Cannot export the model with a nominal state. Please load the model parameters first."); + ORT_RETURN_IF(!eval_sess_, "Eval model was not provided. Cannot export a model for inferencing."); // Once finished_training is set to true, will no longer be able to train or evaluate with this module // since the eval session graph will have been modified. @@ -719,7 +708,8 @@ size_t Module::GetTrainingModelInputCount() const noexcept { return train_input_names_.UserInputNames().size(); } -size_t Module::GetEvalModelInputCount() const noexcept { +size_t Module::GetEvalModelInputCount() const { + ORT_ENFORCE(!finished_training_, "Exporting for inference has modified the eval model. Cannot retrieve EvalModel input count. "); return eval_user_input_count_; } @@ -731,6 +721,7 @@ std::string Module::GetTrainingModelInputName(size_t index) const { } std::string Module::GetEvalModelInputName(size_t index) const { + ORT_ENFORCE(!finished_training_, "Exporting for inference has modified the eval model. Cannot retrieve EvalModel input name. "); ORT_ENFORCE(index < eval_user_input_count_, "Eval input name index out of range. Expected in range [0-", eval_user_input_count_, "). Actual: ", index); @@ -741,7 +732,8 @@ std::pair Module::GetTrainingModelInputs() return train_sess_->GetModelInputs(); } -std::pair Module::GetEvalModelInputs() const noexcept { +std::pair Module::GetEvalModelInputs() const { + ORT_ENFORCE(!finished_training_, "Exporting for inference has modified the eval model. Cannot retrieve EvalModel inputs. "); return eval_sess_->GetModelInputs(); } diff --git a/orttraining/orttraining/training_api/module.h b/orttraining/orttraining/training_api/module.h index 51b469bae03d1..e4a784b95bcfd 100644 --- a/orttraining/orttraining/training_api/module.h +++ b/orttraining/orttraining/training_api/module.h @@ -116,7 +116,7 @@ struct Module { size_t GetTrainingModelOutputCount() const noexcept; // Returns the output count for eval graph - size_t GetEvalModelOutputCount() const noexcept; + size_t GetEvalModelOutputCount() const; // Returns the output names for train graph std::string GetTrainingModelOutputName(size_t index) const; @@ -151,7 +151,7 @@ struct Module { size_t GetTrainingModelInputCount() const noexcept; // Returns the user input count for eval graph - size_t GetEvalModelInputCount() const noexcept; + size_t GetEvalModelInputCount() const; // Returns the user input name for train graph at given index std::string GetTrainingModelInputName(size_t index) const; @@ -163,7 +163,7 @@ struct Module { std::pair GetTrainingModelInputs() const noexcept; // Returns the input definitions of the Eval model - std::pair GetEvalModelInputs() const noexcept; + std::pair GetEvalModelInputs() const; private: std::unique_ptr train_sess_{nullptr}; diff --git a/orttraining/orttraining/training_api/training_session.cc b/orttraining/orttraining/training_api/training_session.cc index 78619947b8b18..a5a856b5fff32 100644 --- a/orttraining/orttraining/training_api/training_session.cc +++ b/orttraining/orttraining/training_api/training_session.cc @@ -37,7 +37,7 @@ size_t TrainingSession::GetTrainingModelOutputCount() const noexcept { return module_->GetTrainingModelOutputCount(); } -size_t TrainingSession::GetEvalModelOutputCount() const noexcept { +size_t TrainingSession::GetEvalModelOutputCount() const { return module_->GetEvalModelOutputCount(); } @@ -45,7 +45,7 @@ std::string TrainingSession::GetTrainingModelOutputName(size_t index) const noex return module_->GetTrainingModelOutputName(index); } -std::string TrainingSession::GetEvalModelOutputName(size_t index) const noexcept { +std::string TrainingSession::GetEvalModelOutputName(size_t index) const { return module_->GetEvalModelOutputName(index); } @@ -53,7 +53,7 @@ size_t TrainingSession::GetTrainingModelInputCount() const noexcept { return module_->GetTrainingModelInputCount(); } -size_t TrainingSession::GetEvalModelInputCount() const noexcept { +size_t TrainingSession::GetEvalModelInputCount() const { return module_->GetEvalModelInputCount(); } @@ -61,7 +61,7 @@ std::string TrainingSession::GetTrainingModelInputName(size_t index) const noexc return module_->GetTrainingModelInputName(index); } -std::string TrainingSession::GetEvalModelInputName(size_t index) const noexcept { +std::string TrainingSession::GetEvalModelInputName(size_t index) const { return module_->GetEvalModelInputName(index); } diff --git a/orttraining/orttraining/training_api/training_session.h b/orttraining/orttraining/training_api/training_session.h index 13b0ae79093de..d2552891a76a4 100644 --- a/orttraining/orttraining/training_api/training_session.h +++ b/orttraining/orttraining/training_api/training_session.h @@ -30,19 +30,19 @@ class TrainingSession { size_t GetTrainingModelOutputCount() const noexcept; - size_t GetEvalModelOutputCount() const noexcept; + size_t GetEvalModelOutputCount() const; std::string GetTrainingModelOutputName(size_t index) const noexcept; - std::string GetEvalModelOutputName(size_t index) const noexcept; + std::string GetEvalModelOutputName(size_t index) const; size_t GetTrainingModelInputCount() const noexcept; - size_t GetEvalModelInputCount() const noexcept; + size_t GetEvalModelInputCount() const; std::string GetTrainingModelInputName(size_t index) const noexcept; - std::string GetEvalModelInputName(size_t index) const noexcept; + std::string GetEvalModelInputName(size_t index) const; Status TrainStep(const RunOptions& run_options, const std::vector& inputs, From ec3d37d82555135f2cee73d7d4d6da8f29955c07 Mon Sep 17 00:00:00 2001 From: carzh Date: Fri, 19 Jul 2024 16:53:21 -0700 Subject: [PATCH 05/12] removed the eval model path parameter + more cleanup --- orttraining/orttraining/training_api/module.cc | 16 ++-------------- orttraining/orttraining/training_api/module.h | 1 - 2 files changed, 2 insertions(+), 15 deletions(-) diff --git a/orttraining/orttraining/training_api/module.cc b/orttraining/orttraining/training_api/module.cc index 57843c689f9b5..43183a8fe8245 100644 --- a/orttraining/orttraining/training_api/module.cc +++ b/orttraining/orttraining/training_api/module.cc @@ -86,7 +86,7 @@ Status TransformModelOutputsForInference(Graph& inference_graph, inference_graph.SetOutputs(inference_graph_output_node_args); ORT_RETURN_IF_ERROR(RemoveUnusedNodes(inference_graph, inference_graph_output_node_args)); - ORT_THROW_IF_ERROR(inference_graph.Resolve()); + ORT_RETURN_IF_ERROR(inference_graph.Resolve()); return Status::OK(); } @@ -114,7 +114,7 @@ Status TransformModelInputsForInference(Graph& inference_graph, } inference_graph.SetInputs(user_graph_inputs); - ORT_THROW_IF_ERROR(inference_graph.Resolve()); + ORT_RETURN_IF_ERROR(inference_graph.Resolve()); return Status::OK(); } @@ -388,11 +388,6 @@ Module::Module(const ModelIdentifiers& model_identifiers, } ORT_THROW_IF_ERROR(eval_sess_->Initialize()); utils::GetGraphInputOutputNames(eval_sess_, eval_input_names_, eval_output_names_); - // TODO: remove this - // std::shared_ptr inference_model = eval_sess_->GetModel(); - // Graph& inference_graph = inference_model->MainGraph(); - - // ORT_THROW_IF_ERROR(RemoveThisMethodBeforeYourPR(inference_graph)); // Eval model validation // We are making certain assumptions: Like the order in which parameters occur will be same between train and eval @@ -417,14 +412,8 @@ Module::Module(const ModelIdentifiers& model_identifiers, eval_user_input_count_ = eval_user_input_names.size(); eval_input_names_.insert(eval_input_names_.end(), eval_param_input_names.begin(), eval_param_input_names.end()); - // TODO: remove this - // ORT_THROW_IF_ERROR(RemoveThisMethodBeforeYourPR(inference_graph)); // Keep a copy of the eval model path to be able to later export the model for inferencing. // The inference model will be reconstructed from the eval model. - // TODO(askhade): Find a fix to export model for inference when the eval model is loaded from a buffer. - if (std::holds_alternative>(model_identifiers.eval_model)) { - eval_model_path_ = std::get>(model_identifiers.eval_model); - } } Module::~Module() { @@ -676,7 +665,6 @@ Status Module::ExportModelForInferencing(const std::string& inference_model_path // since the eval session graph will have been modified. finished_training_ = true; - // Model& inference_model = const_cast(eval_sess_->GetModel()); std::shared_ptr inference_model = eval_sess_->GetModel(); Graph& inference_graph = inference_model->MainGraph(); diff --git a/orttraining/orttraining/training_api/module.h b/orttraining/orttraining/training_api/module.h index e4a784b95bcfd..237822738782e 100644 --- a/orttraining/orttraining/training_api/module.h +++ b/orttraining/orttraining/training_api/module.h @@ -199,7 +199,6 @@ struct Module { CheckpointState* state_; // Non owning pointer to the state. bool accumulate_gradient_ = false; - std::optional eval_model_path_; size_t eval_user_input_count_{0U}; }; From 985a5816ae3bd8c7ef23d343100e3b4335fe3760 Mon Sep 17 00:00:00 2001 From: carzh Date: Mon, 22 Jul 2024 10:40:53 -0700 Subject: [PATCH 06/12] lint --- .../orttraining/training_api/module.cc | 27 +++++++++++++------ orttraining/orttraining/training_api/module.h | 1 + 2 files changed, 20 insertions(+), 8 deletions(-) diff --git a/orttraining/orttraining/training_api/module.cc b/orttraining/orttraining/training_api/module.cc index 43183a8fe8245..8e4fabced730c 100644 --- a/orttraining/orttraining/training_api/module.cc +++ b/orttraining/orttraining/training_api/module.cc @@ -3,6 +3,8 @@ #include "orttraining/training_api/module.h" +#include + #include "core/common/safeint.h" #include "core/common/string_utils.h" #include "core/framework/execution_provider.h" @@ -425,7 +427,8 @@ size_t Module::GetTrainingModelOutputCount() const noexcept { } size_t Module::GetEvalModelOutputCount() const { - ORT_ENFORCE(!finished_training_, "Exporting for inference has modified the eval model. Cannot retrieve EvalModel output count. "); + ORT_ENFORCE(!finished_training_, + "Exporting for inference has modified the eval model. Cannot retrieve EvalModel output count. "); return eval_output_names_.size(); } @@ -435,7 +438,8 @@ std::string Module::GetTrainingModelOutputName(size_t index) const { } std::string Module::GetEvalModelOutputName(size_t index) const { - ORT_ENFORCE(!finished_training_, "Exporting for inference has modified the eval model. Cannot retrieve EvalModel output name. "); + ORT_ENFORCE(!finished_training_, + "Exporting for inference has modified the eval model. Cannot retrieve EvalModel output name. "); ORT_ENFORCE(index < eval_output_names_.size(), "Eval output name index out of range. Expected in range [0-", eval_output_names_.size(), "). Actual: ", index); return eval_output_names_.at(index); @@ -611,7 +615,8 @@ Status Module::CopyBufferToParameters(OrtValue& parameters_buffer, const bool tr } Status Module::LazyResetGrad() { - ORT_RETURN_IF(finished_training_, "Cannot train after exporting for inferencing. To continue training from this point, please save the checkpoint and create a new TrainingSession."); + ORT_RETURN_IF(finished_training_, + "Cannot train after exporting for inferencing. To continue training from this point, please save the checkpoint and create a new TrainingSession."); accumulate_gradient_ = false; return Status::OK(); } @@ -619,7 +624,8 @@ Status Module::LazyResetGrad() { Status Module::TrainStep(const std::vector& inputs, std::vector& outputs) { ORT_RETURN_IF(state_->module_checkpoint_state.is_nominal_state, "Cannot perform TrainStep with a nominal state. Please load the model parameters first."); - ORT_RETURN_IF(finished_training_, "Cannot train after exporting for inferencing. To continue training from this point, please save the checkpoint and create a new TrainingSession."); + ORT_RETURN_IF(finished_training_, + "Cannot train after exporting for inferencing. To continue training from this point, please save the checkpoint and create a new TrainingSession."); std::vector> params; std::vector feeds{inputs}; feeds.insert(feeds.end(), weights_.begin(), weights_.end()); @@ -642,7 +648,8 @@ Status Module::TrainStep(const std::vector& inputs, std::vector& inputs, std::vector& outputs) { ORT_RETURN_IF(state_->module_checkpoint_state.is_nominal_state, "Cannot perform EvalStep with a nominal state. Please load the model parameters first."); - ORT_RETURN_IF(finished_training_, "Cannot evaluate after exporting for inferencing. To continue training from this point, please save the checkpoint and create a new TrainingSession."); + ORT_RETURN_IF(finished_training_, + "Cannot evaluate after exporting for inferencing. To continue training from this point, please save the checkpoint and create a new TrainingSession."); ORT_ENFORCE(nullptr != eval_sess_, "Evaluation session not initialized."); std::vector feeds{inputs}; feeds.insert(feeds.end(), weights_.begin(), weights_.end()); @@ -697,13 +704,16 @@ size_t Module::GetTrainingModelInputCount() const noexcept { } size_t Module::GetEvalModelInputCount() const { - ORT_ENFORCE(!finished_training_, "Exporting for inference has modified the eval model. Cannot retrieve EvalModel input count. "); + ORT_ENFORCE(!finished_training_, + "Exporting for inference has modified the eval model. Cannot retrieve EvalModel input count. "); return eval_user_input_count_; } std::string Module::GetTrainingModelInputName(size_t index) const { ORT_ENFORCE(index < train_input_names_.UserInputNames().size(), - "Train input name index out of range. Expected in range [0-", train_input_names_.UserInputNames().size(), "). Actual: ", + "Train input name index out of range. Expected in range [0-", + train_input_names_.UserInputNames().size(), + "). Actual: ", index); return train_input_names_.UserInputNames()[index]; } @@ -721,7 +731,8 @@ std::pair Module::GetTrainingModelInputs() } std::pair Module::GetEvalModelInputs() const { - ORT_ENFORCE(!finished_training_, "Exporting for inference has modified the eval model. Cannot retrieve EvalModel inputs. "); + ORT_ENFORCE(!finished_training_, + "Exporting for inference has modified the eval model. Cannot retrieve EvalModel inputs. "); return eval_sess_->GetModelInputs(); } diff --git a/orttraining/orttraining/training_api/module.h b/orttraining/orttraining/training_api/module.h index 237822738782e..13b08beef64ab 100644 --- a/orttraining/orttraining/training_api/module.h +++ b/orttraining/orttraining/training_api/module.h @@ -4,6 +4,7 @@ #pragma once #include +#include #include "core/session/inference_session.h" #include "orttraining/training_api/utils.h" From 26dc4ec596014695ffe24aa646d674fd09ffcd7c Mon Sep 17 00:00:00 2001 From: carzh Date: Mon, 22 Jul 2024 10:43:12 -0700 Subject: [PATCH 07/12] lintrunner --- .../orttraining/training_api/module.cc | 20 +++++++++++-------- 1 file changed, 12 insertions(+), 8 deletions(-) diff --git a/orttraining/orttraining/training_api/module.cc b/orttraining/orttraining/training_api/module.cc index 8e4fabced730c..cba9c35ac34d1 100644 --- a/orttraining/orttraining/training_api/module.cc +++ b/orttraining/orttraining/training_api/module.cc @@ -428,7 +428,7 @@ size_t Module::GetTrainingModelOutputCount() const noexcept { size_t Module::GetEvalModelOutputCount() const { ORT_ENFORCE(!finished_training_, - "Exporting for inference has modified the eval model. Cannot retrieve EvalModel output count. "); + "Exporting for inference has modified the eval model. Cannot retrieve EvalModel output count. "); return eval_output_names_.size(); } @@ -439,7 +439,7 @@ std::string Module::GetTrainingModelOutputName(size_t index) const { std::string Module::GetEvalModelOutputName(size_t index) const { ORT_ENFORCE(!finished_training_, - "Exporting for inference has modified the eval model. Cannot retrieve EvalModel output name. "); + "Exporting for inference has modified the eval model. Cannot retrieve EvalModel output name. "); ORT_ENFORCE(index < eval_output_names_.size(), "Eval output name index out of range. Expected in range [0-", eval_output_names_.size(), "). Actual: ", index); return eval_output_names_.at(index); @@ -616,7 +616,8 @@ Status Module::CopyBufferToParameters(OrtValue& parameters_buffer, const bool tr Status Module::LazyResetGrad() { ORT_RETURN_IF(finished_training_, - "Cannot train after exporting for inferencing. To continue training from this point, please save the checkpoint and create a new TrainingSession."); + "Cannot train after exporting for inferencing. ", + "To continue training from this point, please save the checkpoint and create a new TrainingSession."); accumulate_gradient_ = false; return Status::OK(); } @@ -625,7 +626,8 @@ Status Module::TrainStep(const std::vector& inputs, std::vectormodule_checkpoint_state.is_nominal_state, "Cannot perform TrainStep with a nominal state. Please load the model parameters first."); ORT_RETURN_IF(finished_training_, - "Cannot train after exporting for inferencing. To continue training from this point, please save the checkpoint and create a new TrainingSession."); + "Cannot train after exporting for inferencing. ", + "To continue training from this point, please save the checkpoint and create a new TrainingSession."); std::vector> params; std::vector feeds{inputs}; feeds.insert(feeds.end(), weights_.begin(), weights_.end()); @@ -649,7 +651,8 @@ Status Module::EvalStep(const std::vector& inputs, std::vectormodule_checkpoint_state.is_nominal_state, "Cannot perform EvalStep with a nominal state. Please load the model parameters first."); ORT_RETURN_IF(finished_training_, - "Cannot evaluate after exporting for inferencing. To continue training from this point, please save the checkpoint and create a new TrainingSession."); + "Cannot evaluate after exporting for inferencing. ", + "To continue training from this point, please save the checkpoint and create a new TrainingSession."); ORT_ENFORCE(nullptr != eval_sess_, "Evaluation session not initialized."); std::vector feeds{inputs}; feeds.insert(feeds.end(), weights_.begin(), weights_.end()); @@ -705,7 +708,7 @@ size_t Module::GetTrainingModelInputCount() const noexcept { size_t Module::GetEvalModelInputCount() const { ORT_ENFORCE(!finished_training_, - "Exporting for inference has modified the eval model. Cannot retrieve EvalModel input count. "); + "Exporting for inference has modified the eval model. Cannot retrieve EvalModel input count. "); return eval_user_input_count_; } @@ -719,7 +722,8 @@ std::string Module::GetTrainingModelInputName(size_t index) const { } std::string Module::GetEvalModelInputName(size_t index) const { - ORT_ENFORCE(!finished_training_, "Exporting for inference has modified the eval model. Cannot retrieve EvalModel input name. "); + ORT_ENFORCE(!finished_training_, + "Exporting for inference has modified the eval model. Cannot retrieve EvalModel input name. "); ORT_ENFORCE(index < eval_user_input_count_, "Eval input name index out of range. Expected in range [0-", eval_user_input_count_, "). Actual: ", index); @@ -732,7 +736,7 @@ std::pair Module::GetTrainingModelInputs() std::pair Module::GetEvalModelInputs() const { ORT_ENFORCE(!finished_training_, - "Exporting for inference has modified the eval model. Cannot retrieve EvalModel inputs. "); + "Exporting for inference has modified the eval model. Cannot retrieve EvalModel inputs. "); return eval_sess_->GetModelInputs(); } From 702e8b96d4df18da11e09e21479dfac3e35b19ec Mon Sep 17 00:00:00 2001 From: carzh Date: Mon, 22 Jul 2024 11:50:16 -0700 Subject: [PATCH 08/12] updated docs --- .../Training/TrainingSession.shared.cs | 6 ++++-- java/src/main/java/ai/onnxruntime/OrtTrainingSession.java | 5 +++-- objectivec/include/ort_training_session.h | 7 ++++--- orttraining/orttraining/python/training/api/module.py | 5 ++++- .../training_api/include/onnxruntime_training_c_api.h | 5 +++-- .../training_api/include/onnxruntime_training_cxx_api.h | 5 +++-- 6 files changed, 21 insertions(+), 12 deletions(-) diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/Training/TrainingSession.shared.cs b/csharp/src/Microsoft.ML.OnnxRuntime/Training/TrainingSession.shared.cs index fec0d46e96dfb..e7969b21341b6 100644 --- a/csharp/src/Microsoft.ML.OnnxRuntime/Training/TrainingSession.shared.cs +++ b/csharp/src/Microsoft.ML.OnnxRuntime/Training/TrainingSession.shared.cs @@ -568,8 +568,10 @@ public void OptimizerStep(RunOptions options) /// an inference model if it knows the inference graph outputs. The input inference graph outputs /// are used to prune the eval model so that the inference model's outputs align with the provided outputs. /// The exported model is saved at the path provided and can be used for inferencing with InferenceSession. - /// Note that the function re-loads the eval model from the path provided to TrainingSession - /// and expects that this path still be valid. + /// + /// This function modifies the eval graph in-place, so after this method is called, the TrainingSession can + /// no longer be used for training. In order to continue training from this point, save the checkpoint state + /// and create a new TrainingSession with the saved checkpoint state. /// /// Path where the inference model should be serialized to. /// Names of the outputs that are needed in the inference model. diff --git a/java/src/main/java/ai/onnxruntime/OrtTrainingSession.java b/java/src/main/java/ai/onnxruntime/OrtTrainingSession.java index eeede3a1bed0b..18786bdcc27b4 100644 --- a/java/src/main/java/ai/onnxruntime/OrtTrainingSession.java +++ b/java/src/main/java/ai/onnxruntime/OrtTrainingSession.java @@ -998,8 +998,9 @@ private native void schedulerStep(long apiHandle, long trainingApiHandle, long n * Exports the evaluation model as a model suitable for inference, setting the desired nodes as * output nodes. * - *

Note that this method reloads the evaluation model from the path provided to the training - * session, and this path must still be valid. + *

Note that this method modifies the eval session in-place; thus, after this method is called, the + * OrtTrainingSession can no longer be trained with. To continue training from this point, save the checkpoint + * and then load it into a new OrtTrainingSession. * * @param outputPath The path to write out the inference model. * @param outputNames The names of the output nodes. diff --git a/objectivec/include/ort_training_session.h b/objectivec/include/ort_training_session.h index 2ad4fed93c331..6b1b3a52343e3 100644 --- a/objectivec/include/ort_training_session.h +++ b/objectivec/include/ort_training_session.h @@ -229,10 +229,11 @@ NS_ASSUME_NONNULL_BEGIN * * If the training session was provided with an eval model, the training session can generate an inference model if it * knows the inference graph outputs. The input inference graph outputs are used to prune the eval model so that the - * inference model's outputs align with the provided outputs. The exported model is saved at the path provided and - * can be used for inferencing with `ORTSession`. + * inference model's outputs align with the provided outputs. * - * @note The method reloads the eval model from the path provided to the initializer and expects this path to be valid. + * @note This method modifies the eval model graph in-place, so after this method is called, the ORTTrainingSession + * can no longer be used for training. To resume training from this point, save the checkpoint state and create a new + * ORTTrainingSession with the saved checkpoint state. * * @param inferenceModelPath The path to the serialized the inference model. * @param graphOutputNames The names of the outputs that are needed in the inference model. diff --git a/orttraining/orttraining/python/training/api/module.py b/orttraining/orttraining/python/training/api/module.py index a87cd6fdd93cf..2f4ec0e4c3028 100644 --- a/orttraining/orttraining/python/training/api/module.py +++ b/orttraining/orttraining/python/training/api/module.py @@ -194,9 +194,12 @@ def export_model_for_inferencing( Once training is complete, this function can be used to drop the training specific nodes in the onnx model. In particular, this function does the following: - - Parse over the training graph and identify nodes that generate the given output names. + - Parse over the eval graph and identify nodes that generate the given output names. - Drop all subsequent nodes in the graph since they are not relevant to the inference graph. + Once this method is called, training is considered complete and the module can no longer be used for training. + To resume training from this point, save the checkpoint and create a new module from the checkpoint. + Args: inference_model_uri: The path to the inference model. graph_output_names: The list of output names that are required for inferencing. diff --git a/orttraining/orttraining/training_api/include/onnxruntime_training_c_api.h b/orttraining/orttraining/training_api/include/onnxruntime_training_c_api.h index ed6d151a595b4..0efbaa419bb83 100644 --- a/orttraining/orttraining/training_api/include/onnxruntime_training_c_api.h +++ b/orttraining/orttraining/training_api/include/onnxruntime_training_c_api.h @@ -513,8 +513,9 @@ struct OrtTrainingApi { * an inference model if it knows the inference graph outputs. The input inference graph outputs * are used to prune the eval model so that the inference model's outputs align with the provided outputs. * The exported model is saved at the path provided and can be used for inferencing with InferenceSession. - * \note Note that the function re-loads the eval model from the path provided to OrtTrainingApi::CreateTrainingSession - * and expects that this path still be valid. + * \note Note that the function modifies the eval model graph in-place, so after this method is called, the + * OrtTrainingSession can no longer be used for training. To resume training from this point, save the checkpoint + * state and create a new OrtTrainingSession with the updated eval model. * * \param[in] sess The `this` pointer to the training session. * \param[in] inference_model_path Path where the inference model should be serialized to. diff --git a/orttraining/orttraining/training_api/include/onnxruntime_training_cxx_api.h b/orttraining/orttraining/training_api/include/onnxruntime_training_cxx_api.h index e78c16136ab3f..c9a7627140d4b 100644 --- a/orttraining/orttraining/training_api/include/onnxruntime_training_cxx_api.h +++ b/orttraining/orttraining/training_api/include/onnxruntime_training_cxx_api.h @@ -336,8 +336,9 @@ class TrainingSession : public detail::Base { * an inference model if it knows the inference graph outputs. The input inference graph outputs * are used to prune the eval model so that the inference model's outputs align with the provided outputs. * The exported model is saved at the path provided and can be used for inferencing with Ort::Session. - * \note Note that the function re-loads the eval model from the path provided to Ort::TrainingSession - * and expects that this path still be valid. + * \note Note that the function modifies the eval model graph in-place, so after this method is called, the + * OrtTrainingSession can no longer be used for training. To resume training from this point, save the checkpoint + * state and create a new OrtTrainingSession with the updated eval model. * * \param[in] inference_model_path Path where the inference model should be serialized to. * \param[in] graph_output_names Names of the outputs that are needed in the inference model. From a01a96bac1c3a6cbcf0e235c4c18e5d410e0726a Mon Sep 17 00:00:00 2001 From: carzh Date: Tue, 23 Jul 2024 14:02:47 -0700 Subject: [PATCH 09/12] gradle spotlessapply --- java/src/main/java/ai/onnxruntime/OrtTrainingSession.java | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/java/src/main/java/ai/onnxruntime/OrtTrainingSession.java b/java/src/main/java/ai/onnxruntime/OrtTrainingSession.java index 18786bdcc27b4..7249f3d052013 100644 --- a/java/src/main/java/ai/onnxruntime/OrtTrainingSession.java +++ b/java/src/main/java/ai/onnxruntime/OrtTrainingSession.java @@ -998,9 +998,9 @@ private native void schedulerStep(long apiHandle, long trainingApiHandle, long n * Exports the evaluation model as a model suitable for inference, setting the desired nodes as * output nodes. * - *

Note that this method modifies the eval session in-place; thus, after this method is called, the - * OrtTrainingSession can no longer be trained with. To continue training from this point, save the checkpoint - * and then load it into a new OrtTrainingSession. + *

Note that this method modifies the eval session in-place; thus, after this method is called, + * the OrtTrainingSession can no longer be trained with. To continue training from this point, + * save the checkpoint and then load it into a new OrtTrainingSession. * * @param outputPath The path to write out the inference model. * @param outputNames The names of the output nodes. From aa97469cc87f2c20fd6389056481de31929738ec Mon Sep 17 00:00:00 2001 From: carzh Date: Thu, 25 Jul 2024 10:54:27 -0700 Subject: [PATCH 10/12] switched from method in inferencesession api to using a wrapper class instead --- onnxruntime/core/session/inference_session.h | 4 ---- orttraining/orttraining/training_api/module.cc | 10 ++++++---- orttraining/orttraining/training_api/module.h | 13 +++++++++++++ 3 files changed, 19 insertions(+), 8 deletions(-) diff --git a/onnxruntime/core/session/inference_session.h b/onnxruntime/core/session/inference_session.h index 64e3e35cb7f2b..e1cd085d2c271 100644 --- a/onnxruntime/core/session/inference_session.h +++ b/onnxruntime/core/session/inference_session.h @@ -403,10 +403,6 @@ class InferenceSession { int32_t partial_graph_index); #endif -#ifdef ENABLE_TRAINING_APIS - std::shared_ptr GetModel() noexcept { return model_; }; -#endif - /** * @return pair.first = OK; FAIL otherwise. pair.second is non-NULL when pair.first = OK. * @note lifetime of the returned pointer is valid as long as the Session object is live. diff --git a/orttraining/orttraining/training_api/module.cc b/orttraining/orttraining/training_api/module.cc index cba9c35ac34d1..3293fcea1fd65 100644 --- a/orttraining/orttraining/training_api/module.cc +++ b/orttraining/orttraining/training_api/module.cc @@ -675,8 +675,10 @@ Status Module::ExportModelForInferencing(const std::string& inference_model_path // since the eval session graph will have been modified. finished_training_ = true; - std::shared_ptr inference_model = eval_sess_->GetModel(); - Graph& inference_graph = inference_model->MainGraph(); + EvalSessionWrapper& eval_sess_wrapper = static_cast(*eval_sess_); + + Model& inference_model = eval_sess_wrapper.GetMutableModel(); + Graph& inference_graph = eval_sess_wrapper.GetMutableGraph(); // The cloned model's outputs are transformed such that the model has outputs as defined by graph_output_names // Any nodes not contributing to the inference outputs will be pruned. @@ -693,9 +695,9 @@ Status Module::ExportModelForInferencing(const std::string& inference_model_path ORT_TSTR_CONVERT_TO_PRINTABLE_STRING(ExternalCheckpointDataPath(ToPathString(inference_model_path))); PathString inference_model_pathstring = ToPathString(inference_model_path); ORT_THROW_IF_ERROR( - Model::SaveWithExternalInitializers(*inference_model, inference_model_pathstring, external_data_name, 64)); + Model::SaveWithExternalInitializers(inference_model, inference_model_pathstring, external_data_name, 64)); } else { - ORT_THROW_IF_ERROR(Model::Save(*inference_model, ToPathString(inference_model_path))); + ORT_THROW_IF_ERROR(Model::Save(inference_model, ToPathString(inference_model_path))); } // Save the model at the desired location. return Status::OK(); diff --git a/orttraining/orttraining/training_api/module.h b/orttraining/orttraining/training_api/module.h index 13b08beef64ab..3ce4aeaaedd60 100644 --- a/orttraining/orttraining/training_api/module.h +++ b/orttraining/orttraining/training_api/module.h @@ -12,6 +12,19 @@ namespace onnxruntime { namespace training { namespace api { +class EvalSessionWrapper : public InferenceSession { + public: + using InferenceSession::InferenceSession; + + Graph& GetMutableGraph() const { + return model_->MainGraph(); + } + + Model& GetMutableModel() { + return *model_; + } +}; + struct Parameter { public: Parameter(const std::string& name, const OrtValue& data, const bool requires_grad) From e66740dfa6ddc024e66568c9b34b5fb63f6e54de Mon Sep 17 00:00:00 2001 From: carzh Date: Thu, 25 Jul 2024 15:44:09 -0700 Subject: [PATCH 11/12] moved wrapper class from header file to locally defined in export function --- orttraining/orttraining/training_api/module.cc | 13 +++++++++++++ orttraining/orttraining/training_api/module.h | 13 ------------- 2 files changed, 13 insertions(+), 13 deletions(-) diff --git a/orttraining/orttraining/training_api/module.cc b/orttraining/orttraining/training_api/module.cc index 3293fcea1fd65..68e6cb8ae04a7 100644 --- a/orttraining/orttraining/training_api/module.cc +++ b/orttraining/orttraining/training_api/module.cc @@ -671,6 +671,19 @@ Status Module::ExportModelForInferencing(const std::string& inference_model_path "Cannot export the model with a nominal state. Please load the model parameters first."); ORT_RETURN_IF(!eval_sess_, "Eval model was not provided. Cannot export a model for inferencing."); + class EvalSessionWrapper : public InferenceSession { + public: + using InferenceSession::InferenceSession; + + Graph& GetMutableGraph() const { + return model_->MainGraph(); + } + + Model& GetMutableModel() { + return *model_; + } + }; + // Once finished_training is set to true, will no longer be able to train or evaluate with this module // since the eval session graph will have been modified. finished_training_ = true; diff --git a/orttraining/orttraining/training_api/module.h b/orttraining/orttraining/training_api/module.h index 3ce4aeaaedd60..13b08beef64ab 100644 --- a/orttraining/orttraining/training_api/module.h +++ b/orttraining/orttraining/training_api/module.h @@ -12,19 +12,6 @@ namespace onnxruntime { namespace training { namespace api { -class EvalSessionWrapper : public InferenceSession { - public: - using InferenceSession::InferenceSession; - - Graph& GetMutableGraph() const { - return model_->MainGraph(); - } - - Model& GetMutableModel() { - return *model_; - } -}; - struct Parameter { public: Parameter(const std::string& name, const OrtValue& data, const bool requires_grad) From 18fe5a4f939468acc6f64e8fbf54b286cb36e484 Mon Sep 17 00:00:00 2001 From: carzh Date: Mon, 29 Jul 2024 20:59:28 +0000 Subject: [PATCH 12/12] updated running the inference session check to use the same execution providers --- .../test/training_api/core/training_api_tests.cc | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/orttraining/orttraining/test/training_api/core/training_api_tests.cc b/orttraining/orttraining/test/training_api/core/training_api_tests.cc index 90c97eed0c6d3..52205abf38193 100644 --- a/orttraining/orttraining/test/training_api/core/training_api_tests.cc +++ b/orttraining/orttraining/test/training_api/core/training_api_tests.cc @@ -66,8 +66,12 @@ Status CreateFakeOptimizerCheckpointStateOnCPU( return Status::OK(); } -void RunInferenceSession(const Environment& env, const PathString& inference_model_path) { +void RunInferenceSession(const Environment& env, const PathString& inference_model_path, const std::vector>& providers) { auto inference_session = std::make_unique(onnxruntime::SessionOptions(), env); + + for (const std::shared_ptr& p_exec_provider : providers) { + ASSERT_STATUS_OK(inference_session->RegisterExecutionProvider(p_exec_provider)); + } ASSERT_STATUS_OK(inference_session->Load(inference_model_path)); ASSERT_STATUS_OK(inference_session->Initialize()); @@ -135,7 +139,7 @@ void TestModuleExport(const std::vector>& pr ASSERT_EQ(softmaxceloss_node_found(eval_model), true); ASSERT_EQ(softmaxceloss_node_found(inference_model), false); - RunInferenceSession(*env, inference_model_path); + RunInferenceSession(*env, inference_model_path, providers); } void TestModuleExportWithExternalData(const std::vector>& providers) { @@ -183,7 +187,7 @@ void TestModuleExportWithExternalData(const std::vector