Skip to content

Commit

Permalink
[TensorRT EP] Fallback to CUDA EP if it's explicitly assigned (#17535)
Browse files Browse the repository at this point in the history
### Description
* TensorRT EP can fall back to CUDA EP if it's explicitly assigned
* MIGraphX can fall back to ROCM if it's explicitly assigned

Test cases:
| When user specifies providers= | self._fallback_providers= |
| --- | --- |
| ["TensorrtExecutionProvider", "CUDAExecutionProvider"] | ["CUDAExecutionProvider", "CPUExecutionProvider"] |
| ["TensorrtExecutionProvider", ("CUDAExecutionProvider", cuda_options)] | ["CUDAExecutionProvider", "CPUExecutionProvider"] |
| ["TensorrtExecutionProvider"] | ["CPUExecutionProvider"] |
| [("TensorrtExecutionProvider", trt_options)] | ["CPUExecutionProvider"] |
| [("TensorrtExecutionProvider", trt_options), ("CUDAExecutionProvider", cuda_options)] | ["CUDAExecutionProvider", "CPUExecutionProvider"] |
| ["TensorrtExecutionProvider", "CPUExecutionProvider"] | ["CPUExecutionProvider"] |





### Motivation and Context
<!-- - Why is this change required? What problem does it solve?
- If it fixes an open issue, please link to the issue here. -->
Addresses review comments from #17394
and applies the same fallback logic to the [MIGraphX, ROCM] pair.
  • Loading branch information
yf711 authored Sep 15, 2023
1 parent efd416b commit 705f8a3
Showing 1 changed file with 20 additions and 3 deletions.
23 changes: 20 additions & 3 deletions onnxruntime/python/onnxruntime_inference_collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -420,8 +420,10 @@ def __init__(
except (ValueError, RuntimeError) as e:
if self._enable_fallback:
try:
print("*************** EP Error ***************")
print(f"EP Error {e} when using {providers}")
print(f"Falling back to {self._fallback_providers} and retrying.")
print("****************************************")
self._create_inference_session(self._fallback_providers, None)
# Fallback only once.
self.disable_fallback()
Expand All @@ -434,11 +436,26 @@ def __init__(
def _create_inference_session(self, providers, provider_options, disabled_optimizers=None):
available_providers = C.get_available_providers()

# Tensorrt can fall back to CUDA. All others fall back to CPU.
# Tensorrt can fall back to CUDA if it's explicitly assigned. All others fall back to CPU.
if "TensorrtExecutionProvider" in available_providers:
self._fallback_providers = ["CUDAExecutionProvider", "CPUExecutionProvider"]
if any(
provider == "CUDAExecutionProvider"
or (isinstance(provider, tuple) and provider[0] == "CUDAExecutionProvider")
for provider in providers
):
self._fallback_providers = ["CUDAExecutionProvider", "CPUExecutionProvider"]
else:
self._fallback_providers = ["CPUExecutionProvider"]
# MIGraphX can fall back to ROCM if it's explicitly assigned. All others fall back to CPU.
elif "MIGraphXExecutionProvider" in available_providers:
self._fallback_providers = ["ROCMExecutionProvider", "CPUExecutionProvider"]
if any(
provider == "ROCMExecutionProvider"
or (isinstance(provider, tuple) and provider[0] == "ROCMExecutionProvider")
for provider in providers
):
self._fallback_providers = ["ROCMExecutionProvider", "CPUExecutionProvider"]
else:
self._fallback_providers = ["CPUExecutionProvider"]
else:
self._fallback_providers = ["CPUExecutionProvider"]

Expand Down

0 comments on commit 705f8a3

Please sign in to comment.