diff --git a/onnxruntime/core/graph/model.cc b/onnxruntime/core/graph/model.cc
index ee4d9f9154971..d38c1ace7d7a8 100644
--- a/onnxruntime/core/graph/model.cc
+++ b/onnxruntime/core/graph/model.cc
@@ -646,7 +646,7 @@ Status Model::SaveWithExternalInitializers(Model& model, const std::filesystem::
   return SaveModelWithExternalInitializers(model, file_path, external_file_name, initializer_size_threshold);
 }
 
-Status Model::LoadFromBytes(int count, void* p_bytes, /*out*/ ONNX_NAMESPACE::ModelProto& model_proto) {
+Status Model::LoadFromBytes(int count, const void* p_bytes, /*out*/ ONNX_NAMESPACE::ModelProto& model_proto) {
   const bool result = model_proto.ParseFromArray(p_bytes, count);
   if (!result) {
     return Status(ONNXRUNTIME, INVALID_PROTOBUF, "Protobuf parsing failed.");
diff --git a/onnxruntime/core/graph/model.h b/onnxruntime/core/graph/model.h
index 728af727ac83b..ea34dba889277 100644
--- a/onnxruntime/core/graph/model.h
+++ b/onnxruntime/core/graph/model.h
@@ -234,7 +234,7 @@ class Model {
                              const ModelOptions& options = {});
 
   // 'int' rather than 'size_t' because of a protobuf design choice; let callers handle type checks
-  static common::Status LoadFromBytes(int count, void* pBytes,
+  static common::Status LoadFromBytes(int count, const void* pBytes,
                                       /*out*/ ONNX_NAMESPACE::ModelProto& model_proto);
 
   // 'int' rather than 'size_t' because of a protobuf design choice; let callers handle type checks
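Note: the header comment above deliberately leaves the size_t-to-int check to callers. A minimal caller-side sketch of the now const-correct LoadFromBytes follows; the ParseModelFromBuffer wrapper is hypothetical and not part of this patch.

// Hypothetical caller of Model::LoadFromBytes, sketching the type check
// that the header comment delegates to callers.
#include <climits>
#include <cstdint>
#include <vector>

#include "core/graph/model.h"

namespace onnxruntime {

common::Status ParseModelFromBuffer(const std::vector<uint8_t>& bytes,
                                    /*out*/ ONNX_NAMESPACE::ModelProto& model_proto) {
  // LoadFromBytes takes 'int' because of the protobuf design choice, so guard
  // the size_t -> int narrowing before casting.
  if (bytes.size() > static_cast<size_t>(INT_MAX)) {
    return common::Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT,
                          "Model buffer is too large for protobuf to parse.");
  }
  // The buffer is only read, never written, so a const pointer suffices.
  return Model::LoadFromBytes(static_cast<int>(bytes.size()),
                              static_cast<const void*>(bytes.data()),
                              model_proto);
}

}  // namespace onnxruntime

Since ParseFromArray never mutates its input, accepting const void* costs nothing and is what lets ExportModelForInferencing below hand over a read-only eval model buffer.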
diff --git a/orttraining/orttraining/test/training_api/core/training_capi_tests.cc b/orttraining/orttraining/test/training_api/core/training_capi_tests.cc
index 8f25e1e4c92b8..cff060134e679 100644
--- a/orttraining/orttraining/test/training_api/core/training_capi_tests.cc
+++ b/orttraining/orttraining/test/training_api/core/training_capi_tests.cc
@@ -265,6 +265,41 @@ TEST(TrainingCApiTest, LoadONNXModelsFromBuffer) {
                         train_model_data);
 }
 
+TEST(TrainingCApiTest, LoadONNXModelsFromBufferThenExport) {
+  auto model_path = MODEL_FOLDER "training_model.onnx";
+  size_t model_data_len = 0;
+  ASSERT_STATUS_OK(Env::Default().GetFileLength(model_path, model_data_len));
+  std::vector<uint8_t> train_model_data(model_data_len);
+  std::ifstream bytes_stream(model_path, std::ifstream::in | std::ifstream::binary);
+  bytes_stream.read(reinterpret_cast<char*>(train_model_data.data()), model_data_len);
+  ASSERT_TRUE(train_model_data.size() == model_data_len);
+
+  auto eval_model_path = MODEL_FOLDER "eval_model.onnx";
+  size_t eval_model_data_len = 0;
+  ASSERT_STATUS_OK(Env::Default().GetFileLength(eval_model_path, eval_model_data_len));
+  std::vector<uint8_t> eval_model_data(eval_model_data_len);
+  std::ifstream eval_bytes_stream(eval_model_path, std::ifstream::in | std::ifstream::binary);
+  eval_bytes_stream.read(reinterpret_cast<char*>(eval_model_data.data()), eval_model_data_len);
+  ASSERT_TRUE(eval_model_data.size() == eval_model_data_len);
+
+  Ort::Env env;
+  Ort::CheckpointState checkpoint_state = Ort::CheckpointState::LoadCheckpoint(MODEL_FOLDER "checkpoint.ckpt");
+  Ort::TrainingSession training_session = Ort::TrainingSession(env,
+                                                               Ort::SessionOptions(),
+                                                               checkpoint_state,
+                                                               train_model_data,
+                                                               eval_model_data);
+
+  // randomly selected output name
+  std::vector<std::string> graph_output_names({"onnx::loss::21273"});
+  training_session.ExportModelForInferencing(MODEL_FOLDER "inference_model.onnx", graph_output_names);
+
+  // Check that the model is a valid inference model by loading it into an InferenceSession
+  std::unique_ptr<Environment> environment;
+  ASSERT_STATUS_OK(Environment::Create(nullptr, environment));
+  InferenceSession inference_session = InferenceSession(SessionOptions(), *environment,
+                                                        MODEL_FOLDER "inference_model.onnx");
+}
+
 TEST(TrainingCApiTest, LoadORTFormatModelsFromBuffer) {
   auto train_model_path = ORT_FORMAT_MODEL_FOLDER "training_model.ort";
   auto eval_model_path = ORT_FORMAT_MODEL_FOLDER "eval_model.ort";
diff --git a/orttraining/orttraining/training_api/module.cc b/orttraining/orttraining/training_api/module.cc
index dc724fbae48eb..939e1de334e52 100644
--- a/orttraining/orttraining/training_api/module.cc
+++ b/orttraining/orttraining/training_api/module.cc
@@ -412,11 +412,12 @@ Module::Module(const ModelIdentifiers& model_identifiers,
     eval_user_input_count_ = eval_user_input_names.size();
     eval_input_names_.insert(eval_input_names_.end(), eval_param_input_names.begin(), eval_param_input_names.end());
 
-    // Keep a copy of the eval model path to be able to later export the model for inferencing.
+    // Keep a copy of the eval model path or buffer to be able to later export the model for inferencing.
     // The inference model will be reconstructed from the eval model.
-    // TODO(askhade): Find a fix to export model for inference when the eval model is loaded from a buffer.
    if (std::holds_alternative<std::optional<std::string>>(model_identifiers.eval_model)) {
       eval_model_path_ = std::get<std::optional<std::string>>(model_identifiers.eval_model);
+    } else if (std::holds_alternative<gsl::span<const uint8_t>>(model_identifiers.eval_model)) {
+      eval_model_buffer_ = std::get<gsl::span<const uint8_t>>(model_identifiers.eval_model);
     }
   }
@@ -658,11 +659,17 @@ Status Module::ExportModelForInferencing(const std::string& inference_model_path,
                                          gsl::span<const std::string> graph_output_names) const {
   ORT_RETURN_IF(state_->module_checkpoint_state.is_nominal_state,
                 "Cannot export the model with a nominal state. Please load the model parameters first.");
-  ORT_RETURN_IF(!eval_sess_ || !eval_model_path_.has_value(),
+  ORT_RETURN_IF(!eval_sess_ || (!eval_model_path_.has_value() && !eval_model_buffer_.has_value()),
                 "Eval model was not provided. Cannot export a model for inferencing.");
 
   ONNX_NAMESPACE::ModelProto eval_model;
-  ORT_THROW_IF_ERROR(Model::Load(ToPathString(eval_model_path_.value()), eval_model));
+  if (eval_model_path_.has_value()) {
+    ORT_THROW_IF_ERROR(Model::Load(ToPathString(eval_model_path_.value()), eval_model));
+  } else if (eval_model_buffer_.has_value()) {
+    int eval_model_buffer_size = static_cast<int>(eval_model_buffer_.value().size());
+    const void* eval_model_buffer_ptr = static_cast<const void*>(eval_model_buffer_.value().data());
+    ORT_THROW_IF_ERROR(Model::LoadFromBytes(eval_model_buffer_size, eval_model_buffer_ptr, eval_model));
+  }
 
   // Clone the eval model into an inference onnxruntime::Model.
   std::shared_ptr<Model> inference_model;
diff --git a/orttraining/orttraining/training_api/module.h b/orttraining/orttraining/training_api/module.h
index 917887404217f..f4d894f33516a 100644
--- a/orttraining/orttraining/training_api/module.h
+++ b/orttraining/orttraining/training_api/module.h
@@ -198,6 +198,7 @@ struct Module {
   bool accumulate_gradient_ = false;
   std::optional<std::string> eval_model_path_;
+  std::optional<gsl::span<const uint8_t>> eval_model_buffer_;
   size_t eval_user_input_count_{0U};
 };
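Taken together, these changes close the gap the removed TODO described: a training session whose eval model only ever existed as an in-memory buffer can now still export an inference model. A minimal end-to-end sketch against the public C++ training API, mirroring the new test; the file names, checkpoint path, and graph output name are illustrative placeholders.

// Sketch: build a TrainingSession from in-memory train/eval buffers, then
// export an inference model. Paths and the output name are placeholders.
#include <fstream>
#include <iterator>
#include <string>
#include <vector>

#include "onnxruntime_training_cxx_api.h"

static std::vector<uint8_t> ReadFileBytes(const std::string& path) {
  std::ifstream stream(path, std::ifstream::in | std::ifstream::binary);
  return {std::istreambuf_iterator<char>(stream), std::istreambuf_iterator<char>()};
}

int main() {
  Ort::Env env;
  Ort::CheckpointState state = Ort::CheckpointState::LoadCheckpoint(ORT_TSTR("checkpoint.ckpt"));

  // Neither model path is retained; the session only ever sees the bytes.
  std::vector<uint8_t> train_bytes = ReadFileBytes("training_model.onnx");
  std::vector<uint8_t> eval_bytes = ReadFileBytes("eval_model.onnx");
  Ort::TrainingSession session(env, Ort::SessionOptions(), state, train_bytes, eval_bytes);

  // ... run training steps ...

  // Before this change, exporting threw unless the eval model came from a file;
  // the Module now keeps the eval buffer and reconstructs the inference model
  // from it via Model::LoadFromBytes.
  session.ExportModelForInferencing(ORT_TSTR("inference_model.onnx"), {"output_name"});
  return 0;
}

One caveat that follows from the module.h change: eval_model_buffer_ is a non-owning gsl::span<const uint8_t>, so the underlying bytes must remain alive until any later call to ExportModelForInferencing.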