Commit aef98d3: Fix cuda fallback

tianleiwu committed Jul 20, 2024
1 parent 34cd2e8
Showing 2 changed files with 26 additions and 16 deletions.
16 changes: 12 additions & 4 deletions onnxruntime/python/onnxruntime_inference_collection.py
@@ -438,10 +438,18 @@ def _create_inference_session(self, providers, provider_options, disabled_optimi
 
         # Tensorrt can fall back to CUDA if it's explicitly assigned. All others fall back to CPU.
         if "TensorrtExecutionProvider" in available_providers:
-            if providers and any(
-                provider == "CUDAExecutionProvider"
-                or (isinstance(provider, tuple) and provider[0] == "CUDAExecutionProvider")
-                for provider in providers
+            if (
+                providers
+                and any(
+                    provider == "CUDAExecutionProvider"
+                    or (isinstance(provider, tuple) and provider[0] == "CUDAExecutionProvider")
+                    for provider in providers
+                )
+                and any(
+                    provider == "TensorrtExecutionProvider"
+                    or (isinstance(provider, tuple) and provider[0] == "TensorrtExecutionProvider")
+                    for provider in providers
+                )
             ):
                 self._fallback_providers = ["CUDAExecutionProvider", "CPUExecutionProvider"]
             else:
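
What the fix does: when the TensorRT EP is available in the build, the CUDA fallback was previously selected whenever CUDAExecutionProvider appeared in the requested providers; it is now selected only when TensorrtExecutionProvider is explicitly requested as well. A minimal, self-contained sketch of the new predicate (pick_fallback_providers is a hypothetical helper for illustration, not ONNX Runtime API):

    # Sketch of the fixed fallback selection; not ONNX Runtime code.
    def pick_fallback_providers(providers):
        """providers: list of names or (name, options) tuples, as passed to InferenceSession."""
        def requested(name):
            # Matches both the plain-string and the (name, options)-tuple forms.
            return any(
                p == name or (isinstance(p, tuple) and p[0] == name)
                for p in providers or []
            )

        # TensorRT falls back to CUDA only when BOTH providers were requested.
        if requested("TensorrtExecutionProvider") and requested("CUDAExecutionProvider"):
            return ["CUDAExecutionProvider", "CPUExecutionProvider"]
        return ["CPUExecutionProvider"]

    # Both requested: CUDA stays in the fallback chain.
    assert pick_fallback_providers(
        ["TensorrtExecutionProvider", ("CUDAExecutionProvider", {"device_id": 0})]
    ) == ["CUDAExecutionProvider", "CPUExecutionProvider"]

    # CUDA requested alone: with this fix the fallback is CPU-only.
    assert pick_fallback_providers(["CUDAExecutionProvider"]) == ["CPUExecutionProvider"]
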
26 changes: 14 additions & 12 deletions onnxruntime/python/onnxruntime_pybind_state.cc
@@ -35,6 +35,11 @@
 #include "contrib_ops/cpu/aten_ops/aten_op_executor.h"
 #endif
 
+#ifdef USE_CUDA
+#include <cuda.h>   // for CUDA_VERSION
+#include <cudnn.h>  // for CUDNN_MAJOR
+#endif
+
 #include <pybind11/functional.h>
 
 // Explicitly provide a definition for the static const var 'GPU' in the OrtDevice struct,
@@ -946,26 +951,23 @@ std::unique_ptr<IExecutionProvider> CreateExecutionProviderInstance(
                                                                             provider_options_map);
 
         // This variable is never initialized because the APIs by which it should be initialized are deprecated,
-        // however they still exist are are in-use. Neverthless, it is used to return CUDAAllocator,
+        // however they still exist are are in-use. Nevertheless, it is used to return CUDAAllocator,
         // hence we must try to initialize it here if we can since FromProviderOptions might contain
         // external CUDA allocator.
         external_allocator_info = info.external_allocator_info;
         return cuda_provider_info->CreateExecutionProviderFactory(info)->CreateProvider();
-      } else {
-        if (!Env::Default().GetEnvironmentVar("CUDA_PATH").empty()) {
-          ORT_THROW(
-              "CUDA_PATH is set but CUDA wasn't able to be loaded. Please install the correct version of CUDA and"
-              "cuDNN as mentioned in the GPU requirements page "
-              " (https://onnxruntime.ai/docs/execution-providers/CUDA-ExecutionProvider.html#requirements), "
-              " make sure they're in the PATH, and that your GPU is supported.");
-        }
       }
     }
     LOGS_DEFAULT(WARNING) << "Failed to create "
                           << type
-                          << ". Please reference "
-                          << "https://onnxruntime.ai/docs/execution-providers/CUDA-ExecutionProvider.html#requirements"
-                          << "to ensure all dependencies are met.";
+                          << ". Require cuDNN " << CUDNN_MAJOR << ".* and "
+                          << "CUDA " << (CUDA_VERSION / 1000) << ".*"
+#if defined(_MSC_VER)
+                          << ", and the latest MSVC runtime"
+#endif
+                          << ". Please install all dependencies as mentioned in the GPU requirements page"
+                             " (https://onnxruntime.ai/docs/execution-providers/CUDA-ExecutionProvider.html#requirements), "
"make sure they're in the PATH, and that your GPU is supported.";
#endif
} else if (type == kRocmExecutionProvider) {
#ifdef USE_ROCM
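
For context on the new message: cuda.h defines CUDA_VERSION as major*1000 + minor*10 (e.g. 12040 for CUDA 12.4), so CUDA_VERSION / 1000 yields the major version the build was compiled against, and CUDNN_MAJOR comes from cudnn.h; on Windows builds the _MSC_VER branch also appends a note about the MSVC runtime. A rough sketch of the resulting text, using hypothetical macro values (not taken from the commit):

    # Hypothetical values for illustration: a build against CUDA 12.4 / cuDNN 9.
    CUDA_VERSION = 12040  # cuda.h encodes major*1000 + minor*10
    CUDNN_MAJOR = 9       # from cudnn.h

    warning = (
        f"Failed to create CUDAExecutionProvider. "
        f"Require cuDNN {CUDNN_MAJOR}.* and CUDA {CUDA_VERSION // 1000}.*. "
        "Please install all dependencies as mentioned in the GPU requirements page "
        "(https://onnxruntime.ai/docs/execution-providers/CUDA-ExecutionProvider.html#requirements), "
        "make sure they're in the PATH, and that your GPU is supported."
    )
    print(warning)  # "... Require cuDNN 9.* and CUDA 12.* ..."
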
