Skip to content

Commit

Permalink
Add conditional check in Get/Set current GPU device id (#20932)
Browse files Browse the repository at this point in the history
### Description

Add conditional check in Get/Set current GPU device id


### Motivation and Context

Currently with ROCm build, calling `GetCurrentGpuDeviceId` will still
try to find CUDA libraries and log the following error message:

```text
[E:onnxruntime:, provider_bridge_ort.cc:1836 TryGetProviderInfo_CUDA] /onnxruntime_src/onnxruntime/core/session/provider_bridge_ort.cc:1511 onnxruntime::Provider& onnxruntime::ProviderLibrary::Get() [ONNXRuntimeError] : 1 : FAIL : Failed to load library libonnxruntime_providers_cuda.so with error: libonnxruntime_providers_cuda.so: cannot open shared object file: No such file or directory
```

This is unnecessary and confusing.
  • Loading branch information
skyline75489 authored Jun 6, 2024
1 parent 3ecf48e commit 5b87544
Showing 1 changed file with 16 additions and 2 deletions.
18 changes: 16 additions & 2 deletions onnxruntime/core/session/provider_bridge_ort.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2099,22 +2099,36 @@ ORT_API_STATUS_IMPL(OrtSessionOptionsAppendExecutionProvider_CUDA, _In_ OrtSessi
return OrtApis::SessionOptionsAppendExecutionProvider_CUDA(options, &provider_options);
}

ORT_API_STATUS_IMPL(OrtApis::SetCurrentGpuDeviceId, _In_ int device_id) {
ORT_API_STATUS_IMPL(OrtApis::SetCurrentGpuDeviceId, [[maybe_unused]] _In_ int device_id) {
API_IMPL_BEGIN

#ifdef USE_CUDA
if (auto* info = onnxruntime::TryGetProviderInfo_CUDA())
return info->SetCurrentGpuDeviceId(device_id);
#endif

#ifdef USE_ROCM
if (auto* info = onnxruntime::TryGetProviderInfo_ROCM())
return info->SetCurrentGpuDeviceId(device_id);
#endif

return CreateStatus(ORT_FAIL, "CUDA and/or ROCM execution provider is either not enabled or not available.");
API_IMPL_END
}

ORT_API_STATUS_IMPL(OrtApis::GetCurrentGpuDeviceId, _In_ int* device_id) {
ORT_API_STATUS_IMPL(OrtApis::GetCurrentGpuDeviceId, [[maybe_unused]] _In_ int* device_id) {
API_IMPL_BEGIN

#ifdef USE_CUDA
if (auto* info = onnxruntime::TryGetProviderInfo_CUDA())
return info->GetCurrentGpuDeviceId(device_id);
#endif

#ifdef USE_ROCM
if (auto* info = onnxruntime::TryGetProviderInfo_ROCM())
return info->GetCurrentGpuDeviceId(device_id);
#endif

return CreateStatus(ORT_FAIL, "CUDA and/or ROCM execution provider is either not enabled or not available.");
API_IMPL_END
}
Expand Down

0 comments on commit 5b87544

Please sign in to comment.