Skip to content

Commit

Permalink
Update gradient ops tests (#18783)
Browse files Browse the repository at this point in the history
### Description
<!-- Describe your changes. -->
TrainingSession has been deprecated for a while now, but the gradient
ops tests are still using training session. This PR updates these tests
to use inference session instead of training session.



### Motivation and Context
<!-- - Why is this change required? What problem does it solve?
- If it fixes an open issue, please link to the issue here. -->
This will enable us to remove all the training session related
deprecated code from the repo.
  • Loading branch information
askhade authored Dec 13, 2023
1 parent 17eaf9b commit 487abcd
Showing 1 changed file with 4 additions and 5 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
#include "core/framework/kernel_type_str_resolver.h"
#include "core/session/inference_session.h"

#include "orttraining/core/session/training_session.h"
#include "orttraining/core/framework/gradient_graph_builder.h"
#include "orttraining/core/graph/gradient_config.h"

Expand Down Expand Up @@ -76,7 +75,7 @@ void GradientOpTester::Run(int output_index_to_use_as_loss,
}
}

onnxruntime::training::TrainingSession session_object{so, GetEnvironment()};
onnxruntime::InferenceSession session_object{so, GetEnvironment()};

ASSERT_TRUE(!execution_providers->empty()) << "Empty execution providers vector.";
std::string provider_types;
Expand All @@ -102,7 +101,7 @@ void GradientOpTester::Run(int output_index_to_use_as_loss,

has_run = true;

ExecuteModel<onnxruntime::training::TrainingSession>(
ExecuteModel<onnxruntime::InferenceSession>(
model, session_object, ExpectResult::kExpectSuccess, "", nullptr, feeds, output_names, provider_types);
} else {
for (const std::string& provider_type : all_provider_types) {
Expand Down Expand Up @@ -158,11 +157,11 @@ void GradientOpTester::Run(int output_index_to_use_as_loss,
continue;

has_run = true;
onnxruntime::training::TrainingSession session_object{so, GetEnvironment()};
onnxruntime::InferenceSession session_object{so, GetEnvironment()};

EXPECT_TRUE(session_object.RegisterExecutionProvider(std::move(execution_provider)).IsOK());

ExecuteModel<onnxruntime::training::TrainingSession>(
ExecuteModel<onnxruntime::InferenceSession>(
model, session_object, ExpectResult::kExpectSuccess, "", nullptr, feeds, output_names, provider_type);
}
}
Expand Down

0 comments on commit 487abcd

Please sign in to comment.