Update python code
jchen351 committed Oct 21, 2024
1 parent 87c51fb commit c95dbce
Showing 2 changed files with 18 additions and 12 deletions.
onnxruntime/python/onnxruntime_cuda_temp_env.py (20 changes: 14 additions & 6 deletions)
@@ -13,27 +13,27 @@ def __exit__(self, exc_type, exc_value, traceback):
         os.environ.update(self.original_env)
 
 
-def get_nvidia_dll_paths():
+def get_nvidia_dll_paths() -> str:
     # Get the site-packages path where nvidia packages are installed
     site_packages_path = site.getsitepackages()[0]
     nvidia_path = os.path.join(site_packages_path, "nvidia")
 
     # Collect all directories under site-packages/nvidia that contain .dll files (for Windows)
     dll_paths = []
-    for root, dirs, files in os.walk(nvidia_path):
+    for root, files in os.walk(nvidia_path):
         if any(file.endswith(".dll") for file in files):
             dll_paths.append(root)
     return os.pathsep.join(dll_paths)
 
 
-def get_nvidia_so_paths():
+def get_nvidia_so_paths() -> str:
     # Get the site-packages path where nvidia packages are installed
     site_packages_path = site.getsitepackages()[0]
     nvidia_path = os.path.join(site_packages_path, "nvidia")
 
     # Collect all directories under site-packages/nvidia that contain .so files (for Linux)
     so_paths = []
-    for root, dirs, files in os.walk(nvidia_path):
+    for root, files in os.walk(nvidia_path):
         if any(file.endswith(".so") for file in files):
             so_paths.append(root)
     return os.pathsep.join(so_paths)
@@ -43,7 +43,15 @@ def setup_temp_env_for_ort_cuda():
     # Determine platform and set up the environment accordingly
     if platform.system() == "Windows":  # Windows
         nvidia_dlls_path = get_nvidia_dll_paths()
-        return TemporaryEnv({"PATH": nvidia_dlls_path + os.pathsep + os.environ.get("PATH", "")})
+        if nvidia_dlls_path:
+            return TemporaryEnv({"PATH": nvidia_dlls_path + os.pathsep + os.environ.get("PATH")})
+        else:
+            return TemporaryEnv({"PATH": os.environ.get("PATH")})
     elif platform.system() == "Linux":
         nvidia_so_paths = get_nvidia_so_paths()
-        return TemporaryEnv({"LD_LIBRARY_PATH": nvidia_so_paths + os.pathsep + os.environ.get("LD_LIBRARY_PATH", "")})
+        if nvidia_so_paths:
+            return TemporaryEnv({"LD_LIBRARY_PATH": nvidia_so_paths + os.pathsep + os.environ.get("LD_LIBRARY_PATH")})
+        else:
+            return TemporaryEnv({"LD_LIBRARY_PATH": os.environ.get("LD_LIBRARY_PATH")})
+    else:
+        return None
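
For context on the file above: setup_temp_env_for_ort_cuda() returns a TemporaryEnv built from the directories found under site-packages/nvidia, or None on other platforms, and TemporaryEnv.__exit__ restores the saved variables. Below is a minimal, self-contained sketch of the same save-and-restore pattern plus an os.walk-based directory scan; find_dirs_with_suffix, TemporaryEnvSketch, and MY_LIB_SEARCH_PATH are hypothetical names used only for illustration, not the actual onnxruntime implementations.

import os


def find_dirs_with_suffix(base_path: str, suffix: str) -> str:
    # os.walk yields (dirpath, dirnames, filenames) tuples; keep directories
    # that contain at least one file with the requested suffix.
    matches = []
    for dirpath, _dirnames, filenames in os.walk(base_path):
        if any(name.endswith(suffix) for name in filenames):
            matches.append(dirpath)
    return os.pathsep.join(matches)


class TemporaryEnvSketch:
    """Hypothetical stand-in for TemporaryEnv: apply overrides on creation,
    restore the saved values in __exit__ (the only part visible in the hunk above)."""

    def __init__(self, overrides: dict):
        # Snapshot the current values of the variables we are about to change.
        self.original_env = {key: os.environ.get(key, "") for key in overrides}
        os.environ.update(overrides)

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        # Mirrors the visible hunk: put the saved values back.
        os.environ.update(self.original_env)
        return False


# Usage sketch: temporarily expose discovered library directories, then restore.
with TemporaryEnvSketch({"MY_LIB_SEARCH_PATH": find_dirs_with_suffix(os.getcwd(), ".so")}):
    print(os.environ["MY_LIB_SEARCH_PATH"])
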
onnxruntime/python/onnxruntime_inference_collection.py (10 changes: 4 additions & 6 deletions)
@@ -401,8 +401,6 @@ class InferenceSession(Session):
     This is the main class used to run a model.
     """
 
-    env_manager = None
-
     def __init__(
         self,
         path_or_bytes: str | bytes | os.PathLike,
@@ -441,10 +439,9 @@ def __init__(
             means execute a node using `CUDAExecutionProvider`
             if capable, otherwise execute using `CPUExecutionProvider`.
         """
-        if device_type == "gpu":
-            from .onnxruntime_cuda_temp_env import setup_temp_env_for_ort_cuda
+        from .onnxruntime_cuda_temp_env import setup_temp_env_for_ort_cuda
 
-            self.env_manager = setup_temp_env_for_ort_cuda()
+        self.env_manager = setup_temp_env_for_ort_cuda()
         super().__init__()
 
         if isinstance(path_or_bytes, (str, os.PathLike)):
@@ -581,9 +578,10 @@ def _register_ep_custom_ops(self, session_options, providers, provider_options,
             ):
                 C.register_tensorrt_plugins_as_custom_ops(session_options, providers[i][1])
 
-    def __exit__(self):
+    def __exit__(self, exc_type, exc_value, traceback):
         if self.env_manager is not None:
             self.env_manager.__exit__()
+        return False
 
 
 class IOBinding:
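
Taken together, these changes make InferenceSession.__init__ call setup_temp_env_for_ort_cuda() unconditionally (previously gated on device_type == "gpu") and make __exit__ a standard three-argument context-manager hook that tears the environment back down via env_manager. A hedged usage sketch follows; it assumes InferenceSession also defines __enter__ (not shown in this diff) and that onnxruntime and a model.onnx file are available.

# Hypothetical usage sketch; "model.onnx" is a placeholder path.
import onnxruntime as ort

with ort.InferenceSession(
    "model.onnx",
    providers=["CUDAExecutionProvider", "CPUExecutionProvider"],
) as session:
    # __init__ created session.env_manager via setup_temp_env_for_ort_cuda(), so the
    # process PATH / LD_LIBRARY_PATH temporarily include the site-packages/nvidia
    # library directories while the session is alive.
    print(session.get_providers())
# Leaving the with-block runs __exit__, which hands off to env_manager to restore
# the original environment.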
