Enable export for inference when eval model is loaded from buffer #21422

Closed
Wants to merge 12 commits.
@@ -568,8 +568,10 @@ public void OptimizerStep(RunOptions options)
/// an inference model if it knows the inference graph outputs. The input inference graph outputs
/// are used to prune the eval model so that the inference model's outputs align with the provided outputs.
/// The exported model is saved at the path provided and can be used for inferencing with InferenceSession.
-/// Note that the function re-loads the eval model from the path provided to TrainingSession
-/// and expects that this path still be valid.
+///
+/// This function modifies the eval graph in-place, so after this method is called, the TrainingSession can
+/// no longer be used for training. In order to continue training from this point, save the checkpoint state
+/// and create a new TrainingSession with the saved checkpoint state.
/// </summary>
/// <param name="inferenceModelPath">Path where the inference model should be serialized to.</param>
/// <param name="graphOutputNames">Names of the outputs that are needed in the inference model.</param>
5 changes: 3 additions & 2 deletions java/src/main/java/ai/onnxruntime/OrtTrainingSession.java
@@ -998,8 +998,9 @@ private native void schedulerStep(long apiHandle, long trainingApiHandle, long n
* Exports the evaluation model as a model suitable for inference, setting the desired nodes as
* output nodes.
*
* <p>Note that this method reloads the evaluation model from the path provided to the training
* session, and this path must still be valid.
* <p>Note that this method modifies the eval session in-place; thus, after this method is called,
* the OrtTrainingSession can no longer be trained with. To continue training from this point,
* save the checkpoint and then load it into a new OrtTrainingSession.
*
* @param outputPath The path to write out the inference model.
* @param outputNames The names of the output nodes.
7 changes: 4 additions & 3 deletions objectivec/include/ort_training_session.h
@@ -229,10 +229,11 @@ NS_ASSUME_NONNULL_BEGIN
*
* If the training session was provided with an eval model, the training session can generate an inference model if it
* knows the inference graph outputs. The input inference graph outputs are used to prune the eval model so that the
-* inference model's outputs align with the provided outputs. The exported model is saved at the path provided and
-* can be used for inferencing with `ORTSession`.
+* inference model's outputs align with the provided outputs.
 *
-* @note The method reloads the eval model from the path provided to the initializer and expects this path to be valid.
+* @note This method modifies the eval model graph in-place, so after this method is called, the ORTTrainingSession
+* can no longer be used for training. To resume training from this point, save the checkpoint state and create a new
+* ORTTrainingSession with the saved checkpoint state.
*
* @param inferenceModelPath The path to the serialized inference model.
* @param graphOutputNames The names of the outputs that are needed in the inference model.
5 changes: 4 additions & 1 deletion orttraining/orttraining/python/training/api/module.py
@@ -194,9 +194,12 @@ def export_model_for_inferencing(
Once training is complete, this function can be used to drop the training specific nodes in the onnx model.
In particular, this function does the following:

-- Parse over the training graph and identify nodes that generate the given output names.
+- Parse over the eval graph and identify nodes that generate the given output names.
 - Drop all subsequent nodes in the graph since they are not relevant to the inference graph.

+Once this method is called, training is considered complete and the module can no longer be used for training.
+To resume training from this point, save the checkpoint and create a new module from the checkpoint.
+
Args:
inference_model_uri: The path to the inference model.
graph_output_names: The list of output names that are required for inferencing.
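
For intuition, the pruning described in this docstring is essentially a backward slice over the eval graph. A minimal sketch of that idea with the onnx Python package (an illustration only, not ORT's actual implementation; it assumes a topologically sorted graph, ignores subgraphs, and leaves unused initializers in place):

```python
import onnx
from onnx import helper

def prune_to_outputs(model: onnx.ModelProto, output_names: list[str]) -> onnx.ModelProto:
    needed = set(output_names)
    kept = []
    # Walk the nodes in reverse topological order: keep a node if it produces
    # a value we still need, then mark its inputs as needed in turn.
    for node in reversed(model.graph.node):
        if needed.intersection(node.output):
            kept.append(node)
            needed.update(node.input)
    del model.graph.node[:]
    model.graph.node.extend(reversed(kept))
    # Re-point the graph outputs at the requested names (type/shape elided).
    del model.graph.output[:]
    model.graph.output.extend(helper.make_empty_tensor_value_info(n) for n in output_names)
    return model
```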
@@ -16,7 +16,7 @@
from orttraining_test_ort_apis_onnxblock import _get_models

import onnxruntime.training.onnxblock as onnxblock
-from onnxruntime import OrtValue, SessionOptions
+from onnxruntime import InferenceSession, OrtValue, SessionOptions
from onnxruntime.training import artifacts
from onnxruntime.training.api import CheckpointState, LinearLRScheduler, Module, Optimizer

@@ -283,6 +283,7 @@ def test_export_model_for_inferencing():
inference_model_file_path = os.path.join(temp_dir, "inference_model.onnx")
model.export_model_for_inferencing(inference_model_file_path, ["output-0"])
assert os.path.exists(inference_model_file_path)
+InferenceSession(inference_model_file_path)


def test_cuda_execution_provider():
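
Constructing the InferenceSession in test_export_model_for_inferencing is a cheap structural check; actually executing the pruned graph would look roughly like the sketch below, continuing from that test (treating the input dtype as float32 and substituting 1 for dynamic dimensions are assumptions, with the rest of the metadata pulled from the session itself):

```python
import numpy as np
from onnxruntime import InferenceSession

session = InferenceSession(inference_model_file_path)
input_meta = session.get_inputs()[0]
# Substitute 1 for any symbolic/dynamic dimensions and feed dummy data.
shape = [d if isinstance(d, int) else 1 for d in input_meta.shape]
(output,) = session.run(["output-0"], {input_meta.name: np.zeros(shape, dtype=np.float32)})
```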
@@ -265,6 +265,75 @@ TEST(TrainingCApiTest, LoadONNXModelsFromBuffer) {
train_model_data);
}

+TEST(TrainingCApiTest, LoadONNXModelsFromBufferThenExport) {
+auto model_path = MODEL_FOLDER "training_model.onnx";
+size_t model_data_len = 0;
+ASSERT_STATUS_OK(Env::Default().GetFileLength(model_path, model_data_len));
+std::vector<uint8_t> train_model_data(model_data_len);
+std::ifstream bytes_stream(model_path, std::ifstream::in | std::ifstream::binary);
+bytes_stream.read(reinterpret_cast<char*>(train_model_data.data()), model_data_len);
+ASSERT_TRUE(train_model_data.size() == model_data_len);
+
+auto eval_model_path = MODEL_FOLDER "eval_model.onnx";
+size_t eval_model_data_len = 0;
+ASSERT_STATUS_OK(Env::Default().GetFileLength(eval_model_path, eval_model_data_len));
+std::vector<uint8_t> eval_model_data(eval_model_data_len);
+std::ifstream eval_bytes_stream(eval_model_path, std::ifstream::in | std::ifstream::binary);
+eval_bytes_stream.read(reinterpret_cast<char*>(eval_model_data.data()), eval_model_data_len);
+ASSERT_TRUE(eval_model_data.size() == eval_model_data_len);
+
+Ort::Env env;
+Ort::CheckpointState checkpoint_state = Ort::CheckpointState::LoadCheckpoint(MODEL_FOLDER "checkpoint.ckpt");
+Ort::TrainingSession training_session = Ort::TrainingSession(env,
+Ort::SessionOptions(),
+checkpoint_state,
+train_model_data,
+eval_model_data);
+
+// randomly selected output name
+std::vector<std::string> graph_output_names({"onnx::loss::21273"});
+training_session.ExportModelForInferencing(MODEL_FOLDER "inference_model.onnx", graph_output_names);
+
+// Check that the model is a valid inference model by loading into an InferenceSession
+std::unique_ptr<Environment> environment;
+ASSERT_STATUS_OK(Environment::Create(nullptr, environment));
+InferenceSession inference_session = InferenceSession(SessionOptions(), *environment, MODEL_FOLDER "inference_model.onnx");
+
+// Check that you can no longer train or evaluate after exporting. Since passing incorrect inputs will also cause
+// TrainStep and EvalStep to throw errors, we check for the error message.
+ORT_TRY {
+training_session.TrainStep({});
+FAIL() << "TrainStep after exporting for inference should have thrown an error.";
+}
+ORT_CATCH(const Ort::Exception& e) {
+ORT_HANDLE_EXCEPTION([&e]() {
+ASSERT_THAT(e.what(),
+testing::HasSubstr("Cannot train after exporting for inferencing. To continue training from this point, please save the checkpoint and create a new TrainingSession."));
+});
+}
+ORT_CATCH(...) {
+FAIL() << "TrainStep after exporting for inference should have thrown an Ort::Exception.";
+}
+
+ORT_TRY {
+training_session.EvalStep({});
+FAIL() << "EvalStep after exporting for inference should have thrown an Ort::Exception.";
+}
+ORT_CATCH(const Ort::Exception& e) {
+ORT_HANDLE_EXCEPTION([&e]() {
+ASSERT_THAT(e.what(),
+testing::HasSubstr("Cannot evaluate after exporting for inferencing. To continue training from this point, please save the checkpoint and create a new TrainingSession."));
+});
+}
+ORT_CATCH(...) {
+FAIL() << "EvalStep after exporting for inference should have thrown an Ort::Exception.";
+}
+
+// attempt to retrieve the input & output names of the eval model
+ASSERT_THROW(training_session.InputNames(false), Ort::Exception);
+ASSERT_THROW(training_session.OutputNames(false), Ort::Exception);
+}

TEST(TrainingCApiTest, LoadORTFormatModelsFromBuffer) {
auto train_model_path = ORT_FORMAT_MODEL_FOLDER "training_model.ort";
auto eval_model_path = ORT_FORMAT_MODEL_FOLDER "eval_model.ort";
@@ -513,8 +513,9 @@ struct OrtTrainingApi {
* an inference model if it knows the inference graph outputs. The input inference graph outputs
* are used to prune the eval model so that the inference model's outputs align with the provided outputs.
* The exported model is saved at the path provided and can be used for inferencing with InferenceSession.
-* \note Note that the function re-loads the eval model from the path provided to OrtTrainingApi::CreateTrainingSession
-* and expects that this path still be valid.
+* \note Note that the function modifies the eval model graph in-place, so after this method is called, the
+* OrtTrainingSession can no longer be used for training. To resume training from this point, save the checkpoint
+* state and create a new OrtTrainingSession with the saved checkpoint state.
*
* \param[in] sess The `this` pointer to the training session.
* \param[in] inference_model_path Path where the inference model should be serialized to.
@@ -336,8 +336,9 @@ class TrainingSession : public detail::Base<OrtTrainingSession> {
* an inference model if it knows the inference graph outputs. The input inference graph outputs
* are used to prune the eval model so that the inference model's outputs align with the provided outputs.
* The exported model is saved at the path provided and can be used for inferencing with Ort::Session.
-* \note Note that the function re-loads the eval model from the path provided to Ort::TrainingSession
-* and expects that this path still be valid.
+* \note Note that the function modifies the eval model graph in-place, so after this method is called, the
+* Ort::TrainingSession can no longer be used for training. To resume training from this point, save the checkpoint
+* state and create a new Ort::TrainingSession with the saved checkpoint state.
*
* \param[in] inference_model_path Path where the inference model should be serialized to.
* \param[in] graph_output_names Names of the outputs that are needed in the inference model.