From 7c436edc827b41b2e63afc972fdf7d68bc47f223 Mon Sep 17 00:00:00 2001 From: Baiju Meswani Date: Mon, 29 Jul 2024 19:45:53 +0000 Subject: [PATCH] add other providers as well --- src/python/py/_dll_directory.py | 19 ++++++++++++------- 1 file changed, 12 insertions(+), 7 deletions(-) diff --git a/src/python/py/_dll_directory.py b/src/python/py/_dll_directory.py index 125afd5e6..d82c6a7c4 100644 --- a/src/python/py/_dll_directory.py +++ b/src/python/py/_dll_directory.py @@ -13,9 +13,11 @@ def _is_linux(): def add_onnxruntime_dependency(): - """Add the onnxruntime DLL directory to the DLL search path. + """Add the onnxruntime shared library dependency. - This function is a no-op on non-Windows platforms. + On Windows, this function adds the onnxruntime DLL directory to the DLL search path. + On Linux, this function loads the onnxruntime shared library and its dependencies + so that they can be found by the dynamic linker. """ if _is_windows(): import importlib.util @@ -36,11 +38,14 @@ def add_onnxruntime_dependency(): # Load the onnxruntime shared library here since we can find the path in python with ease. # This avoids needing to know the exact path of the shared library from native code. ort_package_path = ort_package.submodule_search_locations[0] - ort_lib_path = glob.glob(os.path.join(ort_package_path, "capi", "libonnxruntime.so*"))[0] - _ = ctypes.CDLL(ort_lib_path) - providers_lib_path = glob.glob(os.path.join(ort_package_path, "capi", "libonnxruntime_providers_shared.so*")) - if providers_lib_path: - _ = [ctypes.CDLL(providers_lib_path[i]) for i in range(len(providers_lib_path))] + ort_lib_path = glob.glob(os.path.join(ort_package_path, "capi", "libonnxruntime.so*")) + if not ort_lib_path: + raise ImportError("Could not find the onnxruntime shared library.") + + _ = ctypes.CDLL(ort_lib_path[0]) + + providers_lib_path = glob.glob(os.path.join(ort_package_path, "capi", "libonnxruntime_providers_*.so")) + _ = [ctypes.CDLL(providers_lib_path[i]) for i in range(len(providers_lib_path))] def add_cuda_dependency():