From 25faf93ec00346b278d69f526c20be459326a1ac Mon Sep 17 00:00:00 2001 From: Changming Sun Date: Wed, 29 Nov 2023 09:34:47 -0800 Subject: [PATCH] Fix training build --- .../orttraining/python/orttraining_python_module.cc | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/orttraining/orttraining/python/orttraining_python_module.cc b/orttraining/orttraining/python/orttraining_python_module.cc index 4d1db7334f280..61885052e8a73 100644 --- a/orttraining/orttraining/python/orttraining_python_module.cc +++ b/orttraining/orttraining/python/orttraining_python_module.cc @@ -45,7 +45,6 @@ void addObjectMethodsForEager(py::module& m); #ifdef ENABLE_LAZY_TENSOR void addObjectMethodsForLazyTensor(py::module& m); #endif -void InitArray(); bool GetDyanmicExecutionProviderHash( const std::string& ep_shared_lib_path, @@ -225,7 +224,6 @@ class TrainingEnvInitialzer { private: TrainingEnvInitialzer() { - InitArray(); Env::Default().GetTelemetryProvider().SetLanguageProjection(OrtLanguageProjection::ORT_PROJECTION_PYTHON); ort_training_env_ = std::make_unique(); } @@ -318,7 +316,8 @@ void ORTTrainingRegisterExecutionProviders(InferenceSession* sess, const std::ve } } -PYBIND11_MODULE(onnxruntime_pybind11_state, m) { +static bool CreateTrainingPybindStateModule(py::module& m) { + import_array1(false); m.doc() = "pybind11 stateful interface to ORTTraining"; RegisterExceptions(m); @@ -344,6 +343,11 @@ PYBIND11_MODULE(onnxruntime_pybind11_state, m) { #ifdef ENABLE_LAZY_TENSOR addObjectMethodsForLazyTensor(m); #endif +} +PYBIND11_MODULE(onnxruntime_pybind11_state, m) { + if (!CreateTrainingPybindStateModule(m)) { + throw pybind11::import_error(); + } m.def("_register_provider_lib", [](const std::string& name, const std::string& provider_shared_lib_path,