Skip to content

Commit

Permalink
add other providers as well
Browse files Browse the repository at this point in the history
  • Loading branch information
baijumeswani committed Aug 5, 2024
1 parent 4ca6fe0 commit 7c436ed
Showing 1 changed file with 12 additions and 7 deletions.
19 changes: 12 additions & 7 deletions src/python/py/_dll_directory.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,11 @@ def _is_linux():


def add_onnxruntime_dependency():
"""Add the onnxruntime DLL directory to the DLL search path.
"""Add the onnxruntime shared library dependency.
This function is a no-op on non-Windows platforms.
On Windows, this function adds the onnxruntime DLL directory to the DLL search path.
On Linux, this function loads the onnxruntime shared library and its dependencies
so that they can be found by the dynamic linker.
"""
if _is_windows():
import importlib.util
Expand All @@ -36,11 +38,14 @@ def add_onnxruntime_dependency():
# Load the onnxruntime shared library here since we can find the path in python with ease.
# This avoids needing to know the exact path of the shared library from native code.
ort_package_path = ort_package.submodule_search_locations[0]
ort_lib_path = glob.glob(os.path.join(ort_package_path, "capi", "libonnxruntime.so*"))[0]
_ = ctypes.CDLL(ort_lib_path)
providers_lib_path = glob.glob(os.path.join(ort_package_path, "capi", "libonnxruntime_providers_shared.so*"))
if providers_lib_path:
_ = [ctypes.CDLL(providers_lib_path[i]) for i in range(len(providers_lib_path))]
ort_lib_path = glob.glob(os.path.join(ort_package_path, "capi", "libonnxruntime.so*"))
if not ort_lib_path:
raise ImportError("Could not find the onnxruntime shared library.")

_ = ctypes.CDLL(ort_lib_path[0])

providers_lib_path = glob.glob(os.path.join(ort_package_path, "capi", "libonnxruntime_providers_*.so"))
_ = [ctypes.CDLL(providers_lib_path[i]) for i in range(len(providers_lib_path))]


def add_cuda_dependency():
Expand Down

0 comments on commit 7c436ed

Please sign in to comment.