Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enable exporting for inference when loading from buffer without behavior changes #21601

Merged
merged 4 commits into from
Aug 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion onnxruntime/core/graph/model.cc
Original file line number Diff line number Diff line change
Expand Up @@ -646,7 +646,7 @@ Status Model::SaveWithExternalInitializers(Model& model, const std::filesystem::
return SaveModelWithExternalInitializers(model, file_path, external_file_name, initializer_size_threshold);
}

Status Model::LoadFromBytes(int count, void* p_bytes, /*out*/ ONNX_NAMESPACE::ModelProto& model_proto) {
Status Model::LoadFromBytes(int count, const void* p_bytes, /*out*/ ONNX_NAMESPACE::ModelProto& model_proto) {
const bool result = model_proto.ParseFromArray(p_bytes, count);
if (!result) {
return Status(ONNXRUNTIME, INVALID_PROTOBUF, "Protobuf parsing failed.");
Expand Down
2 changes: 1 addition & 1 deletion onnxruntime/core/graph/model.h
Original file line number Diff line number Diff line change
Expand Up @@ -234,7 +234,7 @@ class Model {
const ModelOptions& options = {});

// 'int' rather than 'size_t' because of a protobuf design choice; let callers handle type checks
static common::Status LoadFromBytes(int count, void* pBytes,
static common::Status LoadFromBytes(int count, const void* pBytes,
/*out*/ ONNX_NAMESPACE::ModelProto& model_proto);

// 'int' rather than 'size_t' because of a protobuf design choice; let callers handle type checks
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -265,6 +265,41 @@ TEST(TrainingCApiTest, LoadONNXModelsFromBuffer) {
train_model_data);
}

TEST(TrainingCApiTest, LoadONNXModelsFromBufferThenExport) {
auto model_path = MODEL_FOLDER "training_model.onnx";
size_t model_data_len = 0;
ASSERT_STATUS_OK(Env::Default().GetFileLength(model_path, model_data_len));
std::vector<uint8_t> train_model_data(model_data_len);
std::ifstream bytes_stream(model_path, std::ifstream::in | std::ifstream::binary);
bytes_stream.read(reinterpret_cast<char*>(train_model_data.data()), model_data_len);
ASSERT_TRUE(train_model_data.size() == model_data_len);

auto eval_model_path = MODEL_FOLDER "eval_model.onnx";
size_t eval_model_data_len = 0;
ASSERT_STATUS_OK(Env::Default().GetFileLength(eval_model_path, eval_model_data_len));
std::vector<uint8_t> eval_model_data(eval_model_data_len);
std::ifstream eval_bytes_stream(eval_model_path, std::ifstream::in | std::ifstream::binary);
eval_bytes_stream.read(reinterpret_cast<char*>(eval_model_data.data()), eval_model_data_len);
ASSERT_TRUE(eval_model_data.size() == eval_model_data_len);

Ort::Env env;
Ort::CheckpointState checkpoint_state = Ort::CheckpointState::LoadCheckpoint(MODEL_FOLDER "checkpoint.ckpt");
Ort::TrainingSession training_session = Ort::TrainingSession(env,
Ort::SessionOptions(),
checkpoint_state,
train_model_data,
eval_model_data);

// randomly selected output name
std::vector<std::string> graph_output_names({"onnx::loss::21273"});
training_session.ExportModelForInferencing(MODEL_FOLDER "inference_model.onnx", graph_output_names);

// Check that the model is a valid inference model by loading into an InferenceSession
std::unique_ptr<Environment> environment;
ASSERT_STATUS_OK(Environment::Create(nullptr, environment));
InferenceSession inference_session = InferenceSession(SessionOptions(), *environment, MODEL_FOLDER "inference_model.onnx");
}

TEST(TrainingCApiTest, LoadORTFormatModelsFromBuffer) {
auto train_model_path = ORT_FORMAT_MODEL_FOLDER "training_model.ort";
auto eval_model_path = ORT_FORMAT_MODEL_FOLDER "eval_model.ort";
Expand Down
15 changes: 11 additions & 4 deletions orttraining/orttraining/training_api/module.cc
Original file line number Diff line number Diff line change
Expand Up @@ -412,11 +412,12 @@ Module::Module(const ModelIdentifiers& model_identifiers,
eval_user_input_count_ = eval_user_input_names.size();
eval_input_names_.insert(eval_input_names_.end(), eval_param_input_names.begin(), eval_param_input_names.end());

// Keep a copy of the eval model path to be able to later export the model for inferencing.
// Keep a copy of the eval model path or buffer to be able to later export the model for inferencing.
// The inference model will be reconstructed from the eval model.
// TODO(askhade): Find a fix to export model for inference when the eval model is loaded from a buffer.
if (std::holds_alternative<std::optional<std::string>>(model_identifiers.eval_model)) {
eval_model_path_ = std::get<std::optional<std::string>>(model_identifiers.eval_model);
} else if (std::holds_alternative<gsl::span<const uint8_t>>(model_identifiers.eval_model)) {
eval_model_buffer_ = std::get<gsl::span<const uint8_t>>(model_identifiers.eval_model);
}
}

Expand Down Expand Up @@ -658,11 +659,17 @@ Status Module::ExportModelForInferencing(const std::string& inference_model_path
gsl::span<const std::string> graph_output_names) const {
ORT_RETURN_IF(state_->module_checkpoint_state.is_nominal_state,
"Cannot export the model with a nominal state. Please load the model parameters first.");
ORT_RETURN_IF(!eval_sess_ || !eval_model_path_.has_value(),
ORT_RETURN_IF(!eval_sess_ || (!eval_model_path_.has_value() && !eval_model_buffer_.has_value()),
"Eval model was not provided. Cannot export a model for inferencing.");

ONNX_NAMESPACE::ModelProto eval_model;
ORT_THROW_IF_ERROR(Model::Load(ToPathString(eval_model_path_.value()), eval_model));
if (eval_model_path_.has_value()) {
ORT_THROW_IF_ERROR(Model::Load(ToPathString(eval_model_path_.value()), eval_model));
} else if (eval_model_buffer_.has_value()) {
int eval_model_buffer_size = static_cast<int>(eval_model_buffer_.value().size());
const void* eval_model_buffer_ptr = static_cast<const void*>(eval_model_buffer_.value().data());
ORT_THROW_IF_ERROR(Model::LoadFromBytes(eval_model_buffer_size, eval_model_buffer_ptr, eval_model));
}

// Clone the eval mode into an inference onnxruntime::Model.
std::shared_ptr<Model> inference_model;
Expand Down
1 change: 1 addition & 0 deletions orttraining/orttraining/training_api/module.h
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,7 @@ struct Module {

bool accumulate_gradient_ = false;
std::optional<std::string> eval_model_path_;
std::optional<gsl::span<const uint8_t>> eval_model_buffer_;
size_t eval_user_input_count_{0U};
};

Expand Down
Loading