Skip to content

Commit

Permalink
Fix training build
Browse files Browse the repository at this point in the history
  • Loading branch information
snnn committed Nov 29, 2023
1 parent 2384587 commit 25faf93
Showing 1 changed file with 7 additions and 3 deletions.
10 changes: 7 additions & 3 deletions orttraining/orttraining/python/orttraining_python_module.cc
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,6 @@ void addObjectMethodsForEager(py::module& m);
#ifdef ENABLE_LAZY_TENSOR
void addObjectMethodsForLazyTensor(py::module& m);
#endif
void InitArray();

bool GetDyanmicExecutionProviderHash(
const std::string& ep_shared_lib_path,
Expand Down Expand Up @@ -225,7 +224,6 @@ class TrainingEnvInitialzer {

private:
TrainingEnvInitialzer() {
InitArray();
Env::Default().GetTelemetryProvider().SetLanguageProjection(OrtLanguageProjection::ORT_PROJECTION_PYTHON);
ort_training_env_ = std::make_unique<ORTTrainingPythonEnv>();
}
Expand Down Expand Up @@ -318,7 +316,8 @@ void ORTTrainingRegisterExecutionProviders(InferenceSession* sess, const std::ve
}
}

PYBIND11_MODULE(onnxruntime_pybind11_state, m) {
static bool CreateTrainingPybindStateModule(py::module& m) {
import_array1(false);
m.doc() = "pybind11 stateful interface to ORTTraining";
RegisterExceptions(m);

Expand All @@ -344,6 +343,11 @@ PYBIND11_MODULE(onnxruntime_pybind11_state, m) {
#ifdef ENABLE_LAZY_TENSOR
addObjectMethodsForLazyTensor(m);
#endif
}
PYBIND11_MODULE(onnxruntime_pybind11_state, m) {
if (!CreateTrainingPybindStateModule(m)) {
throw pybind11::import_error();
}

m.def("_register_provider_lib", [](const std::string& name,
const std::string& provider_shared_lib_path,
Expand Down

0 comments on commit 25faf93

Please sign in to comment.