added working unit test + additional throw statements + cleaned up
carzh committed Jul 19, 2024
1 parent 46734ae commit 2ed4dfe
Showing 7 changed files with 90 additions and 107 deletions.
@@ -16,7 +16,7 @@
from orttraining_test_ort_apis_onnxblock import _get_models

import onnxruntime.training.onnxblock as onnxblock
from onnxruntime import OrtValue, SessionOptions, InferenceSession
from onnxruntime import InferenceSession, OrtValue, SessionOptions
from onnxruntime.training import artifacts
from onnxruntime.training.api import CheckpointState, LinearLRScheduler, Module, Optimizer

@@ -32,19 +32,6 @@ namespace {
constexpr int64_t TOTAL_STEP_COUNT = 100;
constexpr float INITIAL_LR = 1e-3f;

std::vector<uint8_t> ReadFileIntoBuffer(const std::string& file_path) {
size_t num_bytes = 0;
ASSERT_STATUS_OK(Env::Default().GetFileLength(file_path.c_str(), num_bytes));
std::vector<uint8_t> buffer(num_bytes);

std::ifstream bytes_stream(file_path, std::ifstream::in | std::ifstream::binary);
bytes_stream.read(reinterpret_cast<char*>(buffer.data()), num_bytes);

ASSERT_TRUE(bytes_stream);

return buffer;
}

/**
* @brief Create a Fake Optimizer Checkpoint State On CPU.
*
@@ -151,66 +138,6 @@ void TestModuleExport(const std::vector<std::shared_ptr<IExecutionProvider>>& providers) {
RunInferenceSession(*env, inference_model_path);
}

void TestModuleExportFromBuffer(const std::vector<std::shared_ptr<IExecutionProvider>>& providers) {
auto training_model_uri = MODEL_FOLDER "training_model.onnx";
auto eval_model_uri = MODEL_FOLDER "eval_model.onnx";

onnxruntime::training::api::CheckpointState state;
auto checkpoint_to_load_path = MODEL_FOLDER "checkpoint.ckpt";
// Load checkpoint, eval model, and training model into buffers
std::vector<uint8_t> checkpoint_bytes = ReadFileIntoBuffer(checkpoint_to_load_path);
std::vector<uint8_t> training_model_bytes = ReadFileIntoBuffer(training_model_uri);
std::vector<uint8_t> eval_model_bytes = ReadFileIntoBuffer(eval_model_uri);

ASSERT_STATUS_OK(onnxruntime::training::api::LoadCheckpointFromBuffer(checkpoint_bytes, state));

// load training and eval model into buffers
std::unique_ptr<Environment> env;
ASSERT_STATUS_OK(Environment::Create(nullptr, env));
auto model_identifier = ModelIdentifiers(training_model_bytes,
std::optional<std::vector<uint8_t>>(eval_model_bytes),
std::nullopt);
auto model = std::make_unique<onnxruntime::training::api::Module>(
model_identifier, &state, onnxruntime::SessionOptions(),
*env, providers);

auto test_dir = ORT_TSTR("export_model_for_inferencing_test_dir");
if (Env::Default().FolderExists(test_dir)) {
ORT_ENFORCE(Env::Default().DeleteFolder(test_dir).IsOK());
}
onnxruntime::test::TemporaryDirectory tmp_dir{test_dir};
PathString inference_model_path{
ConcatPathComponent(tmp_dir.Path(), ORT_TSTR("inference_model.onnx"))};

std::vector<std::string> graph_output_names({"output-0"});
ASSERT_STATUS_OK(model->ExportModelForInferencing(ToUTF8String(inference_model_path), graph_output_names));

// Load model
ONNX_NAMESPACE::ModelProto eval_model;
ONNX_NAMESPACE::ModelProto inference_model;
ORT_THROW_IF_ERROR(Model::Load(eval_model_uri, eval_model));
ORT_THROW_IF_ERROR(Model::Load(inference_model_path, inference_model));

// Check it has only one graph input
ASSERT_EQ(eval_model.graph().input().size(), 6);
ASSERT_EQ(inference_model.graph().input().size(), 1);
ASSERT_EQ(inference_model.graph().input()[0].name(), "input-0");

// Check that it does not have any node which has op type SoftmaxCrossEntropyLoss
auto softmaxceloss_node_found = [](auto& model) -> bool {
for (auto& node : model.graph().node()) {
if (node.op_type() == "SoftmaxCrossEntropyLoss") {
return true;
}
}
return false;
};
ASSERT_EQ(softmaxceloss_node_found(eval_model), true);
ASSERT_EQ(softmaxceloss_node_found(inference_model), false);

RunInferenceSession(*env, inference_model_path);
}

void TestModuleExportWithExternalData(const std::vector<std::shared_ptr<IExecutionProvider>>& providers) {
auto training_model_uri = MODEL_FOLDER "training_model.onnx";
auto eval_model_uri = MODEL_FOLDER "eval_model.onnx";
@@ -568,11 +495,6 @@ TEST(TrainingApiTest, ModuleExportModelForInferencingCPU) {
TestModuleExport(providers);
}

TEST(TrainingApiTest, ModuleFromBufferExportModelForInferencingCPU) {
std::vector<std::shared_ptr<IExecutionProvider>> providers{onnxruntime::test::DefaultCpuExecutionProvider()};
TestModuleExportFromBuffer(providers);
}

TEST(TrainingApiTest, ModuleExportModelForInferencingCPU_WithExternalData) {
std::vector<std::shared_ptr<IExecutionProvider>> providers{onnxruntime::test::DefaultCpuExecutionProvider()};
TestModuleExportWithExternalData(providers);
@@ -265,6 +265,75 @@ TEST(TrainingCApiTest, LoadONNXModelsFromBuffer) {
train_model_data);
}

TEST(TrainingCApiTest, LoadONNXModelsFromBufferThenExport) {
auto model_path = MODEL_FOLDER "training_model.onnx";
size_t model_data_len = 0;
ASSERT_STATUS_OK(Env::Default().GetFileLength(model_path, model_data_len));
std::vector<uint8_t> train_model_data(model_data_len);
std::ifstream bytes_stream(model_path, std::ifstream::in | std::ifstream::binary);
bytes_stream.read(reinterpret_cast<char*>(train_model_data.data()), model_data_len);
ASSERT_TRUE(train_model_data.size() == model_data_len);

auto eval_model_path = MODEL_FOLDER "eval_model.onnx";
size_t eval_model_data_len = 0;
ASSERT_STATUS_OK(Env::Default().GetFileLength(eval_model_path, eval_model_data_len));
std::vector<uint8_t> eval_model_data(eval_model_data_len);
std::ifstream eval_bytes_stream(eval_model_path, std::ifstream::in | std::ifstream::binary);
eval_bytes_stream.read(reinterpret_cast<char*>(eval_model_data.data()), eval_model_data_len);
ASSERT_TRUE(eval_model_data.size() == eval_model_data_len);

Ort::Env env;
Ort::CheckpointState checkpoint_state = Ort::CheckpointState::LoadCheckpoint(MODEL_FOLDER "checkpoint.ckpt");
Ort::TrainingSession training_session = Ort::TrainingSession(env,
Ort::SessionOptions(),
checkpoint_state,
train_model_data,
eval_model_data);

// randomly selected output name
std::vector<std::string> graph_output_names({"onnx::loss::21273"});
training_session.ExportModelForInferencing(MODEL_FOLDER "inference_model.onnx", graph_output_names);

// Check that the model is a valid inference model by loading into an InferenceSession
std::unique_ptr<Environment> environment;
ASSERT_STATUS_OK(Environment::Create(nullptr, environment));
InferenceSession inference_session = InferenceSession(SessionOptions(), *environment, MODEL_FOLDER "inference_model.onnx");

// Check that you can no longer train or evaluate after exporting. Since passing incorrect inputs will also cause
// TrainStep and EvalStep to throw errors, we check for the error message.
ORT_TRY {
training_session.TrainStep({});
FAIL() << "TrainStep after exporting for inference should have thrown an error.";
}
ORT_CATCH(const Ort::Exception& e) {
ORT_HANDLE_EXCEPTION([&e]() {
ASSERT_THAT(e.what(),
testing::HasSubstr("Cannot train after exporting for inferencing. To continue training from this point, please save the checkpoint and create a new TrainingSession."));
});
}
ORT_CATCH(...) {
FAIL() << "TrainStep after exporting for inference should have thrown an Ort::Exception.";
}

ORT_TRY {
training_session.EvalStep({});
FAIL() << "EvalStep after exporting for inference should have thrown an Ort::Exception.";
}
ORT_CATCH(const Ort::Exception& e) {
ORT_HANDLE_EXCEPTION([&e]() {
ASSERT_THAT(e.what(),
testing::HasSubstr("Cannot evaluate after exporting for inferencing. To continue training from this point, please save the checkpoint and create a new TrainingSession."));
});
}
ORT_CATCH(...) {
FAIL() << "EvalStep after exporting for inference should have thrown an Ort::Exception.";
}

// attempt to retrieve the input & output names of the eval model
ASSERT_THROW(training_session.InputNames(false), Ort::Exception);
ASSERT_THROW(training_session.OutputNames(false), Ort::Exception);
}
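
Note: the error messages asserted above point at the supported recovery path. Once ExportModelForInferencing has been called, the session can no longer train or evaluate; the intended flow is to save the checkpoint and construct a fresh TrainingSession. Below is a minimal sketch of that flow, assuming the public C++ training API (onnxruntime_training_cxx_api.h) and the same MODEL_FOLDER artifacts; the resumed_checkpoint.ckpt name and ResumeAfterExport helper are illustrative, not part of this commit.

// Sketch only: resuming training after ExportModelForInferencing.
void ResumeAfterExport(Ort::Env& env, Ort::CheckpointState& state) {
  // Persist the parameters reached so far before discarding the finished session.
  Ort::CheckpointState::SaveCheckpoint(state, MODEL_FOLDER "resumed_checkpoint.ckpt",
                                       /*include_optimizer_state=*/true);

  // Only a new session built from the saved checkpoint may call TrainStep again.
  Ort::CheckpointState resumed_state =
      Ort::CheckpointState::LoadCheckpoint(MODEL_FOLDER "resumed_checkpoint.ckpt");
  Ort::TrainingSession resumed_session(env, Ort::SessionOptions(), resumed_state,
                                       MODEL_FOLDER "training_model.onnx",
                                       MODEL_FOLDER "eval_model.onnx");
  // resumed_session.TrainStep(inputs) proceeds as usual from here.
}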

TEST(TrainingCApiTest, LoadORTFormatModelsFromBuffer) {
auto train_model_path = ORT_FORMAT_MODEL_FOLDER "training_model.ort";
auto eval_model_path = ORT_FORMAT_MODEL_FOLDER "eval_model.ort";
26 changes: 9 additions & 17 deletions orttraining/orttraining/training_api/module.cc
@@ -65,20 +65,6 @@ Status RemoveUnusedNodes(Graph& inference_graph, InlinedVector<const NodeArg*>&
return Status::OK();
}

//TODO: REMOVE THIS METHOD BEFORE YOUR PR ITS JUST FOR DEBUGGING PURPOSES
Status RemoveThisMethodBeforeYourPR(Graph& inference_graph) {
GraphViewer graph_viewer(inference_graph);
const auto node_indices = graph_viewer.GetNodesInTopologicalOrder();
for (size_t idx = node_indices.size(); idx > 0; --idx) {
const NodeIndex node_index = idx - 1;
auto* node = inference_graph.GetNode(node_index);
if (node->Name().empty()) {
inference_graph.RemoveNode(node_index);
}
}

return Status::OK();
}
Status TransformModelOutputsForInference(Graph& inference_graph,
gsl::span<const std::string> inference_graph_outputs) {
// Model is updated to remove any outputs that are not defined in inference_graph_outputs. Nodes
@@ -449,7 +435,8 @@ size_t Module::GetTrainingModelOutputCount() const noexcept {
return train_output_names_.size();
}

size_t Module::GetEvalModelOutputCount() const noexcept {
size_t Module::GetEvalModelOutputCount() const {
ORT_ENFORCE(!finished_training_, "Exporting for inference has modified the eval model. Cannot retrieve EvalModel output count. ");
return eval_output_names_.size();
}
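
Note: dropping noexcept alongside these ORT_ENFORCE guards is not cosmetic. ORT_ENFORCE throws when its condition fails, and an exception escaping a noexcept function calls std::terminate rather than reaching any catch block, so the guards would abort the process instead of surfacing a catchable error. A standalone sketch of the distinction, using standard C++ only; the function names and the finished_training flag here are hypothetical stand-ins for the members in this diff.

#include <iostream>
#include <stdexcept>

bool finished_training = true;  // stand-in for Module::finished_training_

size_t EvalOutputCountThrowing() {  // noexcept removed: callers can catch this
  if (finished_training) throw std::runtime_error("eval model modified by export");
  return 1;
}

size_t EvalOutputCountNoexcept() noexcept {  // a throw here calls std::terminate
  if (finished_training) throw std::runtime_error("never reaches a catch block");
  return 1;
}

int main() {
  try {
    EvalOutputCountThrowing();
  } catch (const std::exception& e) {
    std::cout << "caught: " << e.what() << '\n';  // reachable
  }
  EvalOutputCountNoexcept();  // aborts the process; no catch is possible
}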

@@ -459,6 +446,7 @@ std::string Module::GetTrainingModelOutputName(size_t index) const {
}

std::string Module::GetEvalModelOutputName(size_t index) const {
ORT_ENFORCE(!finished_training_, "Exporting for inference has modified the eval model. Cannot retrieve EvalModel output name. ");
ORT_ENFORCE(index < eval_output_names_.size(), "Eval output name index out of range. Expected in range [0-",
eval_output_names_.size(), "). Actual: ", index);
return eval_output_names_.at(index);
@@ -682,6 +670,7 @@ Status Module::ExportModelForInferencing(const std::string& inference_model_path
gsl::span<const std::string> graph_output_names) {
ORT_RETURN_IF(state_->module_checkpoint_state.is_nominal_state,
"Cannot export the model with a nominal state. Please load the model parameters first.");
ORT_RETURN_IF(!eval_sess_, "Eval model was not provided. Cannot export a model for inferencing.");

// Once finished_training is set to true, will no longer be able to train or evaluate with this module
// since the eval session graph will have been modified.
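
Note: the new ORT_RETURN_IF(!eval_sess_, ...) guard makes the precondition explicit that export is only defined when an eval model was supplied at construction. At the public API level, a caller that builds a training-only session should now get a clean error rather than a null dereference. A sketch under that assumption; the paths and the "output-0" output name are illustrative.

// Sketch: ExportModelForInferencing requires an eval model.
Ort::Env env;
Ort::CheckpointState state =
    Ort::CheckpointState::LoadCheckpoint(MODEL_FOLDER "checkpoint.ckpt");
Ort::TrainingSession train_only(env, Ort::SessionOptions(), state,
                                MODEL_FOLDER "training_model.onnx");  // no eval model
try {
  train_only.ExportModelForInferencing(MODEL_FOLDER "inference_model.onnx", {"output-0"});
} catch (const Ort::Exception& e) {
  // expected: "Eval model was not provided. Cannot export a model for inferencing."
}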
@@ -719,7 +708,8 @@ size_t Module::GetTrainingModelInputCount() const noexcept {
return train_input_names_.UserInputNames().size();
}

size_t Module::GetEvalModelInputCount() const noexcept {
size_t Module::GetEvalModelInputCount() const {
ORT_ENFORCE(!finished_training_, "Exporting for inference has modified the eval model. Cannot retrieve EvalModel input count. ");
return eval_user_input_count_;
}

@@ -731,6 +721,7 @@ std::string Module::GetTrainingModelInputName(size_t index) const {
}

std::string Module::GetEvalModelInputName(size_t index) const {
ORT_ENFORCE(!finished_training_, "Exporting for inference has modified the eval model. Cannot retrieve EvalModel input name. ");
ORT_ENFORCE(index < eval_user_input_count_,
"Eval input name index out of range. Expected in range [0-", eval_user_input_count_, "). Actual: ",
index);
@@ -741,7 +732,8 @@ std::pair<common::Status, const InputDefList*> Module::GetTrainingModelInputs()
return train_sess_->GetModelInputs();
}

std::pair<common::Status, const InputDefList*> Module::GetEvalModelInputs() const noexcept {
std::pair<common::Status, const InputDefList*> Module::GetEvalModelInputs() const {
ORT_ENFORCE(!finished_training_, "Exporting for inference has modified the eval model. Cannot retrieve EvalModel inputs. ");
return eval_sess_->GetModelInputs();
}

6 changes: 3 additions & 3 deletions orttraining/orttraining/training_api/module.h
@@ -116,7 +116,7 @@ struct Module {
size_t GetTrainingModelOutputCount() const noexcept;

// Returns the output count for eval graph
size_t GetEvalModelOutputCount() const noexcept;
size_t GetEvalModelOutputCount() const;

// Returns the output names for train graph
std::string GetTrainingModelOutputName(size_t index) const;
@@ -151,7 +151,7 @@ size_t GetTrainingModelInputCount() const noexcept;
size_t GetTrainingModelInputCount() const noexcept;

// Returns the user input count for eval graph
size_t GetEvalModelInputCount() const noexcept;
size_t GetEvalModelInputCount() const;

// Returns the user input name for train graph at given index
std::string GetTrainingModelInputName(size_t index) const;
@@ -163,7 +163,7 @@ std::pair<common::Status, const InputDefList*> GetTrainingModelInputs() const noexcept;
std::pair<common::Status, const InputDefList*> GetTrainingModelInputs() const noexcept;

// Returns the input definitions of the Eval model
std::pair<common::Status, const InputDefList*> GetEvalModelInputs() const noexcept;
std::pair<common::Status, const InputDefList*> GetEvalModelInputs() const;

private:
std::unique_ptr<onnxruntime::InferenceSession> train_sess_{nullptr};
8 changes: 4 additions & 4 deletions orttraining/orttraining/training_api/training_session.cc
@@ -37,31 +37,31 @@ size_t TrainingSession::GetTrainingModelOutputCount() const noexcept {
return module_->GetTrainingModelOutputCount();
}

size_t TrainingSession::GetEvalModelOutputCount() const noexcept {
size_t TrainingSession::GetEvalModelOutputCount() const {
return module_->GetEvalModelOutputCount();
}

std::string TrainingSession::GetTrainingModelOutputName(size_t index) const noexcept {
return module_->GetTrainingModelOutputName(index);
}

std::string TrainingSession::GetEvalModelOutputName(size_t index) const noexcept {
std::string TrainingSession::GetEvalModelOutputName(size_t index) const {
return module_->GetEvalModelOutputName(index);
}

size_t TrainingSession::GetTrainingModelInputCount() const noexcept {
return module_->GetTrainingModelInputCount();
}

size_t TrainingSession::GetEvalModelInputCount() const noexcept {
size_t TrainingSession::GetEvalModelInputCount() const {
return module_->GetEvalModelInputCount();
}

std::string TrainingSession::GetTrainingModelInputName(size_t index) const noexcept {
return module_->GetTrainingModelInputName(index);
}

std::string TrainingSession::GetEvalModelInputName(size_t index) const noexcept {
std::string TrainingSession::GetEvalModelInputName(size_t index) const {
return module_->GetEvalModelInputName(index);
}

8 changes: 4 additions & 4 deletions orttraining/orttraining/training_api/training_session.h
@@ -30,19 +30,19 @@ class TrainingSession {

size_t GetTrainingModelOutputCount() const noexcept;

size_t GetEvalModelOutputCount() const noexcept;
size_t GetEvalModelOutputCount() const;

std::string GetTrainingModelOutputName(size_t index) const noexcept;

std::string GetEvalModelOutputName(size_t index) const noexcept;
std::string GetEvalModelOutputName(size_t index) const;

size_t GetTrainingModelInputCount() const noexcept;

size_t GetEvalModelInputCount() const noexcept;
size_t GetEvalModelInputCount() const;

std::string GetTrainingModelInputName(size_t index) const noexcept;

std::string GetEvalModelInputName(size_t index) const noexcept;
std::string GetEvalModelInputName(size_t index) const;

Status TrainStep(const RunOptions& run_options,
const std::vector<OrtValue>& inputs,
