diff --git a/orttraining/orttraining/test/python/orttraining_test_ort_apis_py_bindings.py b/orttraining/orttraining/test/python/orttraining_test_ort_apis_py_bindings.py
index d035e2087031b..3dee3a4169942 100644
--- a/orttraining/orttraining/test/python/orttraining_test_ort_apis_py_bindings.py
+++ b/orttraining/orttraining/test/python/orttraining_test_ort_apis_py_bindings.py
@@ -16,7 +16,7 @@
 from orttraining_test_ort_apis_onnxblock import _get_models
 
 import onnxruntime.training.onnxblock as onnxblock
-from onnxruntime import OrtValue, SessionOptions, InferenceSession
+from onnxruntime import InferenceSession, OrtValue, SessionOptions
 from onnxruntime.training import artifacts
 from onnxruntime.training.api import CheckpointState, LinearLRScheduler, Module, Optimizer
 
diff --git a/orttraining/orttraining/test/training_api/core/training_api_tests.cc b/orttraining/orttraining/test/training_api/core/training_api_tests.cc
index e9729558fcd35..90c97eed0c6d3 100644
--- a/orttraining/orttraining/test/training_api/core/training_api_tests.cc
+++ b/orttraining/orttraining/test/training_api/core/training_api_tests.cc
@@ -32,19 +32,6 @@ namespace {
 constexpr int64_t TOTAL_STEP_COUNT = 100;
 constexpr float INITIAL_LR = 1e-3f;
 
-std::vector<uint8_t> ReadFileIntoBuffer(const std::string& file_path) {
-  size_t num_bytes = 0;
-  ASSERT_STATUS_OK(Env::Default().GetFileLength(file_path.c_str(), num_bytes));
-  std::vector<uint8_t> buffer(num_bytes);
-
-  std::ifstream bytes_stream(file_path, std::ifstream::in | std::ifstream::binary);
-  bytes_stream.read(reinterpret_cast<char*>(buffer.data()), num_bytes);
-
-  ASSERT_TRUE(bytes_stream);
-
-  return buffer;
-}
-
 /**
  * @brief Create a Fake Optimizer Checkpoint State On CPU.
  *
@@ -151,66 +138,6 @@ void TestModuleExport(const std::vector<std::shared_ptr<IExecutionProvider>>& pr
   RunInferenceSession(*env, inference_model_path);
 }
 
-void TestModuleExportFromBuffer(const std::vector<std::shared_ptr<IExecutionProvider>>& providers) {
-  auto training_model_uri = MODEL_FOLDER "training_model.onnx";
-  auto eval_model_uri = MODEL_FOLDER "eval_model.onnx";
-
-  onnxruntime::training::api::CheckpointState state;
-  auto checkpoint_to_load_path = MODEL_FOLDER "checkpoint.ckpt";
-  // Load checkpoint, eval model, and training model into buffers
-  std::vector<uint8_t> checkpoint_bytes = ReadFileIntoBuffer(checkpoint_to_load_path);
-  std::vector<uint8_t> training_model_bytes = ReadFileIntoBuffer(training_model_uri);
-  std::vector<uint8_t> eval_model_bytes = ReadFileIntoBuffer(eval_model_uri);
-
-  ASSERT_STATUS_OK(onnxruntime::training::api::LoadCheckpointFromBuffer(checkpoint_bytes, state));
-
-  // load training and eval model into buffers
-  std::unique_ptr<Environment> env;
-  ASSERT_STATUS_OK(Environment::Create(nullptr, env));
-  auto model_identifier = ModelIdentifiers(training_model_bytes,
-                                           std::optional<std::vector<uint8_t>>(eval_model_bytes),
-                                           std::nullopt);
-  auto model = std::make_unique<onnxruntime::training::api::Module>(
-      model_identifier, &state, onnxruntime::SessionOptions(),
-      *env, providers);
-
-  auto test_dir = ORT_TSTR("export_model_for_inferencing_test_dir");
-  if (Env::Default().FolderExists(test_dir)) {
-    ORT_ENFORCE(Env::Default().DeleteFolder(test_dir).IsOK());
-  }
-  onnxruntime::test::TemporaryDirectory tmp_dir{test_dir};
-  PathString inference_model_path{
-      ConcatPathComponent(tmp_dir.Path(), ORT_TSTR("inference_model.onnx"))};
-
-  std::vector<std::string> graph_output_names({"output-0"});
-  ASSERT_STATUS_OK(model->ExportModelForInferencing(ToUTF8String(inference_model_path), graph_output_names));
-
-  // Load model
-  ONNX_NAMESPACE::ModelProto eval_model;
-  ONNX_NAMESPACE::ModelProto inference_model;
-  ORT_THROW_IF_ERROR(Model::Load(eval_model_uri, eval_model));
-  ORT_THROW_IF_ERROR(Model::Load(inference_model_path, inference_model));
-
-  // Check it has only one graph input
-  ASSERT_EQ(eval_model.graph().input().size(), 6);
-  ASSERT_EQ(inference_model.graph().input().size(), 1);
-  ASSERT_EQ(inference_model.graph().input()[0].name(), "input-0");
-
-  // Check that it does not have any node which has op type SoftmaxCrossEntropyLoss
-  auto softmaxceloss_node_found = [](auto& model) -> bool {
-    for (auto& node : model.graph().node()) {
-      if (node.op_type() == "SoftmaxCrossEntropyLoss") {
-        return true;
-      }
-    }
-    return false;
-  };
-  ASSERT_EQ(softmaxceloss_node_found(eval_model), true);
-  ASSERT_EQ(softmaxceloss_node_found(inference_model), false);
-
-  RunInferenceSession(*env, inference_model_path);
-}
-
 void TestModuleExportWithExternalData(const std::vector<std::shared_ptr<IExecutionProvider>>& providers) {
   auto training_model_uri = MODEL_FOLDER "training_model.onnx";
   auto eval_model_uri = MODEL_FOLDER "eval_model.onnx";
@@ -568,11 +495,6 @@ TEST(TrainingApiTest, ModuleExportModelForInferencingCPU) {
   TestModuleExport(providers);
 }
 
-TEST(TrainingApiTest, ModuleFromBufferExportModelForInferencingCPU) {
-  std::vector<std::shared_ptr<IExecutionProvider>> providers{onnxruntime::test::DefaultCpuExecutionProvider()};
-  TestModuleExportFromBuffer(providers);
-}
-
 TEST(TrainingApiTest, ModuleExportModelForInferencingCPU_WithExternalData) {
   std::vector<std::shared_ptr<IExecutionProvider>> providers{onnxruntime::test::DefaultCpuExecutionProvider()};
   TestModuleExportWithExternalData(providers);
diff --git a/orttraining/orttraining/test/training_api/core/training_capi_tests.cc b/orttraining/orttraining/test/training_api/core/training_capi_tests.cc
index 8f25e1e4c92b8..04d13912e73d8 100644
--- a/orttraining/orttraining/test/training_api/core/training_capi_tests.cc
+++ b/orttraining/orttraining/test/training_api/core/training_capi_tests.cc
@@ -265,6 +265,75 @@ TEST(TrainingCApiTest, LoadONNXModelsFromBuffer) {
                                                  train_model_data);
 }
 
+TEST(TrainingCApiTest, LoadONNXModelsFromBufferThenExport) {
+  auto model_path = MODEL_FOLDER "training_model.onnx";
+  size_t model_data_len = 0;
+  ASSERT_STATUS_OK(Env::Default().GetFileLength(model_path, model_data_len));
+  std::vector<uint8_t> train_model_data(model_data_len);
+  std::ifstream bytes_stream(model_path, std::ifstream::in | std::ifstream::binary);
+  bytes_stream.read(reinterpret_cast<char*>(train_model_data.data()), model_data_len);
+  ASSERT_TRUE(train_model_data.size() == model_data_len);
+
+  auto eval_model_path = MODEL_FOLDER "eval_model.onnx";
+  size_t eval_model_data_len = 0;
+  ASSERT_STATUS_OK(Env::Default().GetFileLength(eval_model_path, eval_model_data_len));
+  std::vector<uint8_t> eval_model_data(eval_model_data_len);
+  std::ifstream eval_bytes_stream(eval_model_path, std::ifstream::in | std::ifstream::binary);
+  eval_bytes_stream.read(reinterpret_cast<char*>(eval_model_data.data()), eval_model_data_len);
+  ASSERT_TRUE(eval_model_data.size() == eval_model_data_len);
+
+  Ort::Env env;
+  Ort::CheckpointState checkpoint_state = Ort::CheckpointState::LoadCheckpoint(MODEL_FOLDER "checkpoint.ckpt");
+  Ort::TrainingSession training_session = Ort::TrainingSession(env,
+                                                               Ort::SessionOptions(),
+                                                               checkpoint_state,
+                                                               train_model_data,
+                                                               eval_model_data);
+
+  // randomly selected output name
+  std::vector<std::string> graph_output_names({"onnx::loss::21273"});
+  training_session.ExportModelForInferencing(MODEL_FOLDER "inference_model.onnx", graph_output_names);
+
+  // Check that the model is a valid inference model by loading into an InferenceSession
+  std::unique_ptr<Environment> environment;
+  ASSERT_STATUS_OK(Environment::Create(nullptr, environment));
+  InferenceSession inference_session = InferenceSession(SessionOptions(), *environment, MODEL_FOLDER "inference_model.onnx");
+
+  // Check that you can no longer train or evaluate after exporting. Since passing incorrect inputs will also cause
+  // TrainStep and EvalStep to throw errors, we check for the error message.
+  ORT_TRY {
+    training_session.TrainStep({});
+    FAIL() << "TrainStep after exporting for inference should have thrown an error.";
+  }
+  ORT_CATCH(const Ort::Exception& e) {
+    ORT_HANDLE_EXCEPTION([&e]() {
+      ASSERT_THAT(e.what(),
+                  testing::HasSubstr("Cannot train after exporting for inferencing. To continue training from this point, please save the checkpoint and create a new TrainingSession."));
+    });
+  }
+  ORT_CATCH(...) {
+    FAIL() << "TrainStep after exporting for inference should have thrown an Ort::Exception.";
+  }
+
+  ORT_TRY {
+    training_session.EvalStep({});
+    FAIL() << "EvalStep after exporting for inference should have thrown an Ort::Exception.";
+  }
+  ORT_CATCH(const Ort::Exception& e) {
+    ORT_HANDLE_EXCEPTION([&e]() {
+      ASSERT_THAT(e.what(),
+                  testing::HasSubstr("Cannot evaluate after exporting for inferencing. To continue training from this point, please save the checkpoint and create a new TrainingSession."));
+    });
+  }
+  ORT_CATCH(...) {
+    FAIL() << "EvalStep after exporting for inference should have thrown an Ort::Exception.";
+  }
+
+  // attempt to retrieve the input & output names of the eval model
+  ASSERT_THROW(training_session.InputNames(false), Ort::Exception);
+  ASSERT_THROW(training_session.OutputNames(false), Ort::Exception);
+}
+
 TEST(TrainingCApiTest, LoadORTFormatModelsFromBuffer) {
   auto train_model_path = ORT_FORMAT_MODEL_FOLDER "training_model.ort";
   auto eval_model_path = ORT_FORMAT_MODEL_FOLDER "eval_model.ort";
diff --git a/orttraining/orttraining/training_api/module.cc b/orttraining/orttraining/training_api/module.cc
index 56c46c2869564..57843c689f9b5 100644
--- a/orttraining/orttraining/training_api/module.cc
+++ b/orttraining/orttraining/training_api/module.cc
@@ -65,20 +65,6 @@ Status RemoveUnusedNodes(Graph& inference_graph, InlinedVector<const NodeArg*>&
   return Status::OK();
 }
 
-//TODO: REMOVE THIS METHOD BEFORE YOUR PR ITS JUST FOR DEBUGGING PURPOSES
-Status RemoveThisMethodBeforeYourPR(Graph& inference_graph) {
-  GraphViewer graph_viewer(inference_graph);
-  const auto node_indices = graph_viewer.GetNodesInTopologicalOrder();
-  for (size_t idx = node_indices.size(); idx > 0; --idx) {
-    const NodeIndex node_index = idx - 1;
-    auto* node = inference_graph.GetNode(node_index);
-    if (node->Name().empty()) {
-      inference_graph.RemoveNode(node_index);
-    }
-  }
-
-  return Status::OK();
-}
-
 Status TransformModelOutputsForInference(Graph& inference_graph,
                                          gsl::span<const std::string> inference_graph_outputs) {
   // Model is updated to remove any outputs that are not defined in inference_graph_outputs. Nodes
@@ -449,7 +435,8 @@ size_t Module::GetTrainingModelOutputCount() const noexcept {
   return train_output_names_.size();
 }
 
-size_t Module::GetEvalModelOutputCount() const noexcept {
+size_t Module::GetEvalModelOutputCount() const {
+  ORT_ENFORCE(!finished_training_, "Exporting for inference has modified the eval model. Cannot retrieve EvalModel output count. ");
   return eval_output_names_.size();
 }
 
@@ -459,6 +446,7 @@ std::string Module::GetTrainingModelOutputName(size_t index) const {
 }
 
 std::string Module::GetEvalModelOutputName(size_t index) const {
+  ORT_ENFORCE(!finished_training_, "Exporting for inference has modified the eval model. Cannot retrieve EvalModel output name. ");
"); ORT_ENFORCE(index < eval_output_names_.size(), "Eval output name index out of range. Expected in range [0-", eval_output_names_.size(), "). Actual: ", index); return eval_output_names_.at(index); @@ -682,6 +670,7 @@ Status Module::ExportModelForInferencing(const std::string& inference_model_path gsl::span graph_output_names) { ORT_RETURN_IF(state_->module_checkpoint_state.is_nominal_state, "Cannot export the model with a nominal state. Please load the model parameters first."); + ORT_RETURN_IF(!eval_sess_, "Eval model was not provided. Cannot export a model for inferencing."); // Once finished_training is set to true, will no longer be able to train or evaluate with this module // since the eval session graph will have been modified. @@ -719,7 +708,8 @@ size_t Module::GetTrainingModelInputCount() const noexcept { return train_input_names_.UserInputNames().size(); } -size_t Module::GetEvalModelInputCount() const noexcept { +size_t Module::GetEvalModelInputCount() const { + ORT_ENFORCE(!finished_training_, "Exporting for inference has modified the eval model. Cannot retrieve EvalModel input count. "); return eval_user_input_count_; } @@ -731,6 +721,7 @@ std::string Module::GetTrainingModelInputName(size_t index) const { } std::string Module::GetEvalModelInputName(size_t index) const { + ORT_ENFORCE(!finished_training_, "Exporting for inference has modified the eval model. Cannot retrieve EvalModel input name. "); ORT_ENFORCE(index < eval_user_input_count_, "Eval input name index out of range. Expected in range [0-", eval_user_input_count_, "). Actual: ", index); @@ -741,7 +732,8 @@ std::pair Module::GetTrainingModelInputs() return train_sess_->GetModelInputs(); } -std::pair Module::GetEvalModelInputs() const noexcept { +std::pair Module::GetEvalModelInputs() const { + ORT_ENFORCE(!finished_training_, "Exporting for inference has modified the eval model. Cannot retrieve EvalModel inputs. 
"); return eval_sess_->GetModelInputs(); } diff --git a/orttraining/orttraining/training_api/module.h b/orttraining/orttraining/training_api/module.h index 51b469bae03d1..e4a784b95bcfd 100644 --- a/orttraining/orttraining/training_api/module.h +++ b/orttraining/orttraining/training_api/module.h @@ -116,7 +116,7 @@ struct Module { size_t GetTrainingModelOutputCount() const noexcept; // Returns the output count for eval graph - size_t GetEvalModelOutputCount() const noexcept; + size_t GetEvalModelOutputCount() const; // Returns the output names for train graph std::string GetTrainingModelOutputName(size_t index) const; @@ -151,7 +151,7 @@ struct Module { size_t GetTrainingModelInputCount() const noexcept; // Returns the user input count for eval graph - size_t GetEvalModelInputCount() const noexcept; + size_t GetEvalModelInputCount() const; // Returns the user input name for train graph at given index std::string GetTrainingModelInputName(size_t index) const; @@ -163,7 +163,7 @@ struct Module { std::pair GetTrainingModelInputs() const noexcept; // Returns the input definitions of the Eval model - std::pair GetEvalModelInputs() const noexcept; + std::pair GetEvalModelInputs() const; private: std::unique_ptr train_sess_{nullptr}; diff --git a/orttraining/orttraining/training_api/training_session.cc b/orttraining/orttraining/training_api/training_session.cc index 78619947b8b18..a5a856b5fff32 100644 --- a/orttraining/orttraining/training_api/training_session.cc +++ b/orttraining/orttraining/training_api/training_session.cc @@ -37,7 +37,7 @@ size_t TrainingSession::GetTrainingModelOutputCount() const noexcept { return module_->GetTrainingModelOutputCount(); } -size_t TrainingSession::GetEvalModelOutputCount() const noexcept { +size_t TrainingSession::GetEvalModelOutputCount() const { return module_->GetEvalModelOutputCount(); } @@ -45,7 +45,7 @@ std::string TrainingSession::GetTrainingModelOutputName(size_t index) const noex return module_->GetTrainingModelOutputName(index); } -std::string TrainingSession::GetEvalModelOutputName(size_t index) const noexcept { +std::string TrainingSession::GetEvalModelOutputName(size_t index) const { return module_->GetEvalModelOutputName(index); } @@ -53,7 +53,7 @@ size_t TrainingSession::GetTrainingModelInputCount() const noexcept { return module_->GetTrainingModelInputCount(); } -size_t TrainingSession::GetEvalModelInputCount() const noexcept { +size_t TrainingSession::GetEvalModelInputCount() const { return module_->GetEvalModelInputCount(); } @@ -61,7 +61,7 @@ std::string TrainingSession::GetTrainingModelInputName(size_t index) const noexc return module_->GetTrainingModelInputName(index); } -std::string TrainingSession::GetEvalModelInputName(size_t index) const noexcept { +std::string TrainingSession::GetEvalModelInputName(size_t index) const { return module_->GetEvalModelInputName(index); } diff --git a/orttraining/orttraining/training_api/training_session.h b/orttraining/orttraining/training_api/training_session.h index 13b0ae79093de..d2552891a76a4 100644 --- a/orttraining/orttraining/training_api/training_session.h +++ b/orttraining/orttraining/training_api/training_session.h @@ -30,19 +30,19 @@ class TrainingSession { size_t GetTrainingModelOutputCount() const noexcept; - size_t GetEvalModelOutputCount() const noexcept; + size_t GetEvalModelOutputCount() const; std::string GetTrainingModelOutputName(size_t index) const noexcept; - std::string GetEvalModelOutputName(size_t index) const noexcept; + std::string GetEvalModelOutputName(size_t index) const; 
 
   size_t GetTrainingModelInputCount() const noexcept;
 
-  size_t GetEvalModelInputCount() const noexcept;
+  size_t GetEvalModelInputCount() const;
 
   std::string GetTrainingModelInputName(size_t index) const noexcept;
 
-  std::string GetEvalModelInputName(size_t index) const noexcept;
+  std::string GetEvalModelInputName(size_t index) const;
 
   Status TrainStep(const RunOptions& run_options,
                    const std::vector<OrtValue>& inputs,
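
Usage note (illustrative sketch, not part of the patch): with this change, ExportModelForInferencing becomes a terminal operation on the eval model. The sketch below shows the caller-visible behavior through the public C++ training API; the model, checkpoint, and graph-output-name strings are placeholders borrowed from the tests above, not real artifacts.

#include <iostream>
#include "onnxruntime_training_cxx_api.h"

int main() {
  Ort::Env env;
  Ort::CheckpointState state = Ort::CheckpointState::LoadCheckpoint(ORT_TSTR("checkpoint.ckpt"));
  Ort::TrainingSession session(env, Ort::SessionOptions(), state,
                               ORT_TSTR("training_model.onnx"),
                               ORT_TSTR("eval_model.onnx"));

  // Exporting strips training-only nodes from the loaded eval graph in place.
  session.ExportModelForInferencing(ORT_TSTR("inference_model.onnx"), {"output-0"});

  // Any further train/eval call on this session now throws instead of silently
  // running on the modified graph; continuing training requires saving the
  // checkpoint and creating a new TrainingSession.
  try {
    session.EvalStep({});
  } catch (const Ort::Exception& e) {
    std::cout << e.what() << "\n";  // "Cannot evaluate after exporting for inferencing. ..."
  }
  return 0;
}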