added test + implementation where eval model is stored as buffer
carzh committed Aug 2, 2024
1 parent 4b8f6dc commit 9cef143
Showing 5 changed files with 51 additions and 6 deletions.
2 changes: 1 addition & 1 deletion onnxruntime/core/graph/model.cc
@@ -646,7 +646,7 @@ Status Model::SaveWithExternalInitializers(Model& model, const std::filesystem::
return SaveModelWithExternalInitializers(model, file_path, external_file_name, initializer_size_threshold);
}

-Status Model::LoadFromBytes(int count, void* p_bytes, /*out*/ ONNX_NAMESPACE::ModelProto& model_proto) {
+Status Model::LoadFromBytes(int count, const void* p_bytes, /*out*/ ONNX_NAMESPACE::ModelProto& model_proto) {
const bool result = model_proto.ParseFromArray(p_bytes, count);
if (!result) {
return Status(ONNXRUNTIME, INVALID_PROTOBUF, "Protobuf parsing failed.");
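The signature change above is a const-correctness fix: protobuf's ParseFromArray already accepts const void*, so the parameter never needed to be mutable. A minimal sketch of what the new signature permits for callers holding read-only data; the helper name and include path are assumptions, not part of this commit:

#include <cstdint>
#include <vector>

#include "core/graph/model.h"  // assumed include path for onnxruntime::Model

// Hypothetical helper: parse a ModelProto from an immutable byte buffer.
// Before this change, the void* parameter forced callers with const data
// to const_cast.
onnxruntime::common::Status ParseModelFromBuffer(const std::vector<uint8_t>& bytes,
                                                 ONNX_NAMESPACE::ModelProto& model_proto) {
  // const uint8_t* now converts implicitly to const void*.
  return onnxruntime::Model::LoadFromBytes(static_cast<int>(bytes.size()),
                                           bytes.data(), model_proto);
}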
2 changes: 1 addition & 1 deletion onnxruntime/core/graph/model.h
@@ -234,7 +234,7 @@ class Model {
const ModelOptions& options = {});

// 'int' rather than 'size_t' because of a protobuf design choice; let callers handle type checks
-static common::Status LoadFromBytes(int count, void* pBytes,
+static common::Status LoadFromBytes(int count, const void* pBytes,
/*out*/ ONNX_NAMESPACE::ModelProto& model_proto);

// 'int' rather than 'size_t' because of a protobuf design choice; let callers handle type checks
@@ -265,6 +265,41 @@ TEST(TrainingCApiTest, LoadONNXModelsFromBuffer) {
train_model_data);
}

+TEST(TrainingCApiTest, LoadONNXModelsFromBufferThenExport) {
+  auto model_path = MODEL_FOLDER "training_model.onnx";
+  size_t model_data_len = 0;
+  ASSERT_STATUS_OK(Env::Default().GetFileLength(model_path, model_data_len));
+  std::vector<uint8_t> train_model_data(model_data_len);
+  std::ifstream bytes_stream(model_path, std::ifstream::in | std::ifstream::binary);
+  bytes_stream.read(reinterpret_cast<char*>(train_model_data.data()), model_data_len);
+  ASSERT_TRUE(train_model_data.size() == model_data_len);
+
+  auto eval_model_path = MODEL_FOLDER "eval_model.onnx";
+  size_t eval_model_data_len = 0;
+  ASSERT_STATUS_OK(Env::Default().GetFileLength(eval_model_path, eval_model_data_len));
+  std::vector<uint8_t> eval_model_data(eval_model_data_len);
+  std::ifstream eval_bytes_stream(eval_model_path, std::ifstream::in | std::ifstream::binary);
+  eval_bytes_stream.read(reinterpret_cast<char*>(eval_model_data.data()), eval_model_data_len);
+  ASSERT_TRUE(eval_model_data.size() == eval_model_data_len);
+
+  Ort::Env env;
+  Ort::CheckpointState checkpoint_state = Ort::CheckpointState::LoadCheckpoint(MODEL_FOLDER "checkpoint.ckpt");
+  Ort::TrainingSession training_session = Ort::TrainingSession(env,
+                                                               Ort::SessionOptions(),
+                                                               checkpoint_state,
+                                                               train_model_data,
+                                                               eval_model_data);
+
+  // randomly selected output name
+  std::vector<std::string> graph_output_names({"onnx::loss::21273"});
+  training_session.ExportModelForInferencing(MODEL_FOLDER "inference_model.onnx", graph_output_names);
+
+  // Check that the model is a valid inference model by loading into an InferenceSession
+  std::unique_ptr<Environment> environment;
+  ASSERT_STATUS_OK(Environment::Create(nullptr, environment));
+  InferenceSession inference_session = InferenceSession(SessionOptions(), *environment, MODEL_FOLDER "inference_model.onnx");
+}

TEST(TrainingCApiTest, LoadORTFormatModelsFromBuffer) {
auto train_model_path = ORT_FORMAT_MODEL_FOLDER "training_model.ort";
auto eval_model_path = ORT_FORMAT_MODEL_FOLDER "eval_model.ort";
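As a usage note, the test's read pattern asserts only on the vector size, which is fixed before the read; a sketch of a more defensive file-to-buffer read, with a hypothetical ReadFileBytes helper that is not part of the commit:

#include <cstddef>
#include <cstdint>
#include <fstream>
#include <stdexcept>
#include <string>
#include <vector>

// Hypothetical helper: read a whole file into a byte vector, failing loudly
// instead of silently yielding a short or empty buffer.
std::vector<uint8_t> ReadFileBytes(const std::string& path) {
  std::ifstream stream(path, std::ifstream::in | std::ifstream::binary);
  if (!stream) {
    throw std::runtime_error("could not open " + path);
  }
  stream.seekg(0, std::ifstream::end);
  const std::streamsize length = static_cast<std::streamsize>(stream.tellg());
  stream.seekg(0, std::ifstream::beg);
  std::vector<uint8_t> bytes(static_cast<std::size_t>(length));
  stream.read(reinterpret_cast<char*>(bytes.data()), length);
  if (stream.gcount() != length) {
    throw std::runtime_error("short read on " + path);
  }
  return bytes;
}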
17 changes: 13 additions & 4 deletions orttraining/orttraining/training_api/module.cc
@@ -412,12 +412,14 @@ Module::Module(const ModelIdentifiers& model_identifiers,
eval_user_input_count_ = eval_user_input_names.size();
eval_input_names_.insert(eval_input_names_.end(), eval_param_input_names.begin(), eval_param_input_names.end());

-// Keep a copy of the eval model path to be able to later export the model for inferencing.
+// Keep a copy of the eval model path or buffer to be able to later export the model for inferencing.
// The inference model will be reconstructed from the eval model.
-// TODO(askhade): Find a fix to export model for inference when the eval model is loaded from a buffer.
if (std::holds_alternative<std::optional<std::string>>(model_identifiers.eval_model)) {
eval_model_path_ = std::get<std::optional<std::string>>(model_identifiers.eval_model);
}
+else if (std::holds_alternative<gsl::span<const uint8_t>>(model_identifiers.eval_model)) {
[cpplint warning, module.cc:420] An else should appear on the same line as the preceding } [whitespace/newline] [4]
[cpplint warning, module.cc:420] If an else has a brace on one side, it should have it on both [readability/braces] [5]
+  eval_model_buffer_ = std::get<gsl::span<const uint8_t>>(model_identifiers.eval_model);
+}
}

Module::~Module() {
@@ -658,11 +660,18 @@ Status Module::ExportModelForInferencing(const std::string& inference_model_path,
gsl::span<const std::string> graph_output_names) const {
ORT_RETURN_IF(state_->module_checkpoint_state.is_nominal_state,
"Cannot export the model with a nominal state. Please load the model parameters first.");
-ORT_RETURN_IF(!eval_sess_ || !eval_model_path_.has_value(),
+ORT_RETURN_IF(!eval_sess_ || (!eval_model_path_.has_value() && !eval_model_buffer_.has_value()),
"Eval model was not provided. Cannot export a model for inferencing.");

ONNX_NAMESPACE::ModelProto eval_model;
-ORT_THROW_IF_ERROR(Model::Load(ToPathString(eval_model_path_.value()), eval_model));
+if (eval_model_path_.has_value()) {
+  ORT_THROW_IF_ERROR(Model::Load(ToPathString(eval_model_path_.value()), eval_model));
+}
+else if (eval_model_buffer_.has_value()) {
[cpplint warning, module.cc:670] An else should appear on the same line as the preceding } [whitespace/newline] [4]
[cpplint warning, module.cc:670] If an else has a brace on one side, it should have it on both [readability/braces] [5]
+  int eval_model_buffer_size = static_cast<int>(eval_model_buffer_.value().size());
+  const void* eval_model_buffer_ptr = static_cast<const void*>(eval_model_buffer_.value().data());
+  ORT_THROW_IF_ERROR(Model::LoadFromBytes(eval_model_buffer_size, eval_model_buffer_ptr, eval_model));
+}

// Clone the eval mode into an inference onnxruntime::Model.
std::shared_ptr<Model> inference_model;
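The constructor hunk dispatches on model_identifiers.eval_model, a std::variant; for readers unfamiliar with the idiom, a self-contained sketch of the same holds_alternative/get pattern. The types here are illustrative stand-ins, not the real ModelIdentifiers definition (which uses std::optional<std::string> and gsl::span<const uint8_t>):

#include <cstdint>
#include <iostream>
#include <string>
#include <variant>
#include <vector>

// Illustrative stand-in: an eval model is either a file path or raw bytes.
using EvalModelSource = std::variant<std::string, std::vector<uint8_t>>;

void DescribeSource(const EvalModelSource& source) {
  // std::holds_alternative reports which alternative is active;
  // std::get then extracts it, mirroring the dispatch in Module::Module.
  if (std::holds_alternative<std::string>(source)) {
    std::cout << "load from path: " << std::get<std::string>(source) << "\n";
  } else if (std::holds_alternative<std::vector<uint8_t>>(source)) {
    std::cout << "load from buffer of "
              << std::get<std::vector<uint8_t>>(source).size() << " bytes\n";
  }
}

int main() {
  DescribeSource(EvalModelSource{std::string{"eval_model.onnx"}});
  DescribeSource(EvalModelSource{std::vector<uint8_t>(64)});
  return 0;
}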
1 change: 1 addition & 0 deletions orttraining/orttraining/training_api/module.h
@@ -198,6 +198,7 @@ struct Module {

bool accumulate_gradient_ = false;
std::optional<std::string> eval_model_path_;
+std::optional<gsl::span<const uint8_t>> eval_model_buffer_;
size_t eval_user_input_count_{0U};
};

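The new member is a gsl::span, i.e. a non-owning view: the byte buffer supplied by the caller must stay alive for the Module's lifetime, including any later ExportModelForInferencing call that reads it. A minimal sketch of that contract, using std::span (C++20) in place of gsl::span; ModuleLike is a hypothetical stand-in, not the real Module:

#include <cstdint>
#include <span>
#include <vector>

// Sketch: a span member views caller-owned bytes; it does not copy them.
struct ModuleLike {
  std::span<const uint8_t> eval_model_buffer_;
};

int main() {
  std::vector<uint8_t> eval_model_bytes(64, 0);  // caller owns the storage
  ModuleLike m{std::span<const uint8_t>(eval_model_bytes)};
  // eval_model_bytes must outlive every read of m.eval_model_buffer_
  // (e.g. a later export); destroying or reallocating the vector first
  // would leave the span dangling.
  (void)m;
  return 0;
}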
