# [CUDA] support user_compute_stream in python API (#19229)
### Description
Passing a user-provided CUDA stream through the Python API is an important feature for avoiding unnecessary synchronization. This change lets the user pass a CUDA stream to the CUDA execution provider. Note that the TensorRT and ROCm providers need similar changes, which are not included in this pull request.

Note that `has_user_compute_stream` is set automatically based on whether a CUDA stream is passed, so setting `has_user_compute_stream` through the Python API has no effect.
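
For illustration, a minimal usage sketch of the new option (the model path is hypothetical; assumes a CUDA-enabled onnxruntime build and torch with CUDA available):

```python
import onnxruntime as ort
import torch

# torch.cuda.Stream().cuda_stream exposes the underlying cudaStream_t
# handle as a plain integer address.
s = torch.cuda.Stream()

# Provider option values are strings, so the address is passed in its
# decimal text form; has_user_compute_stream is then derived from it.
providers = [("CUDAExecutionProvider", {"user_compute_stream": str(s.cuda_stream)})]
sess = ort.InferenceSession("model.onnx", providers=providers)  # hypothetical model
```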

### Motivation and Context

#19094
tianleiwu authored Jan 26, 2024
1 parent 7d4dc66 commit d7ff81d
Showing 2 changed files with 35 additions and 0 deletions.
16 changes: 16 additions & 0 deletions onnxruntime/core/providers/cuda/cuda_execution_provider_info.cc
```diff
@@ -16,6 +16,7 @@ namespace cuda {
 namespace provider_option_names {
 constexpr const char* kDeviceId = "device_id";
 constexpr const char* kHasUserComputeStream = "has_user_compute_stream";
+constexpr const char* kUserComputeStream = "user_compute_stream";
 constexpr const char* kMemLimit = "gpu_mem_limit";
 constexpr const char* kArenaExtendStrategy = "arena_extend_strategy";
 constexpr const char* kCudnnConvAlgoSearch = "cudnn_conv_algo_search";
@@ -51,6 +52,7 @@ CUDAExecutionProviderInfo CUDAExecutionProviderInfo::FromProviderOptions(const P
   void* alloc = nullptr;
   void* free = nullptr;
   void* empty_cache = nullptr;
+  void* user_compute_stream = nullptr;
   ORT_THROW_IF_ERROR(
       ProviderOptionsParser{}
           .AddValueParser(
@@ -66,6 +68,14 @@ CUDAExecutionProviderInfo CUDAExecutionProviderInfo::FromProviderOptions(const P
                 return Status::OK();
               })
           .AddAssignmentToReference(cuda::provider_option_names::kHasUserComputeStream, info.has_user_compute_stream)
+          .AddValueParser(
+              cuda::provider_option_names::kUserComputeStream,
+              [&user_compute_stream](const std::string& value_str) -> Status {
+                size_t address;
+                ORT_RETURN_IF_ERROR(ParseStringWithClassicLocale(value_str, address));
+                user_compute_stream = reinterpret_cast<void*>(address);
+                return Status::OK();
+              })
           .AddValueParser(
               cuda::provider_option_names::kGpuExternalAlloc,
               [&alloc](const std::string& value_str) -> Status {
@@ -126,13 +136,18 @@ CUDAExecutionProviderInfo CUDAExecutionProviderInfo::FromProviderOptions(const P

   CUDAExecutionProviderExternalAllocatorInfo alloc_info{alloc, free, empty_cache};
   info.external_allocator_info = alloc_info;
+
+  info.user_compute_stream = user_compute_stream;
+  info.has_user_compute_stream = (user_compute_stream != nullptr);
+
   return info;
 }

 ProviderOptions CUDAExecutionProviderInfo::ToProviderOptions(const CUDAExecutionProviderInfo& info) {
   const ProviderOptions options{
       {cuda::provider_option_names::kDeviceId, MakeStringWithClassicLocale(info.device_id)},
       {cuda::provider_option_names::kHasUserComputeStream, MakeStringWithClassicLocale(info.has_user_compute_stream)},
+      {cuda::provider_option_names::kUserComputeStream, MakeStringWithClassicLocale(reinterpret_cast<size_t>(info.user_compute_stream))},
       {cuda::provider_option_names::kMemLimit, MakeStringWithClassicLocale(info.gpu_mem_limit)},
       {cuda::provider_option_names::kGpuExternalAlloc, MakeStringWithClassicLocale(reinterpret_cast<size_t>(info.external_allocator_info.alloc))},
       {cuda::provider_option_names::kGpuExternalFree, MakeStringWithClassicLocale(reinterpret_cast<size_t>(info.external_allocator_info.free))},
@@ -160,6 +175,7 @@ ProviderOptions CUDAExecutionProviderInfo::ToProviderOptions(const OrtCUDAProvid
   const ProviderOptions options{
       {cuda::provider_option_names::kDeviceId, MakeStringWithClassicLocale(info.device_id)},
       {cuda::provider_option_names::kHasUserComputeStream, MakeStringWithClassicLocale(info.has_user_compute_stream)},
+      {cuda::provider_option_names::kUserComputeStream, MakeStringWithClassicLocale(reinterpret_cast<size_t>(info.user_compute_stream))},
       {cuda::provider_option_names::kMemLimit, MakeStringWithClassicLocale(info.gpu_mem_limit)},
       {cuda::provider_option_names::kArenaExtendStrategy, EnumToName(arena_extend_strategy_mapping, info.arena_extend_strategy)},
       {cuda::provider_option_names::kCudnnConvAlgoSearch, EnumToName(ort_cudnn_conv_algo_search_mapping, info.cudnn_conv_algo_search)},
```
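
Since `ProviderOptions` values are plain strings, the stream handle crosses the Python/C++ boundary as the decimal text of its address: the parser above reads it back into a `size_t` and `reinterpret_cast`s it to `void*`, and `ToProviderOptions` serializes it the same way. A small sketch of that convention (the address is illustrative, not a real stream):

```python
# A cudaStream_t is an opaque pointer; ORT expects its integer value
# rendered as a decimal string in the provider options.
stream_address = 0x7F8A40000200  # illustrative address only

option_value = str(stream_address)           # what the Python side hands over
assert int(option_value) == stream_address   # what the C++ parser recovers
```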
19 changes: 19 additions & 0 deletions onnxruntime/test/python/onnxruntime_test_python.py
```diff
@@ -434,6 +434,25 @@ def test_get_and_set_option_with_values(option_name, option_values):
             self.assertEqual(options["CUDAExecutionProvider"]["gpu_external_alloc"], "0")
             self.assertEqual(options["CUDAExecutionProvider"]["gpu_external_free"], "0")
             self.assertEqual(options["CUDAExecutionProvider"]["gpu_external_empty_cache"], "0")
+
+            option["user_compute_stream"] = "0"
+            sess.set_providers(["CUDAExecutionProvider"], [option])
+            options = sess.get_provider_options()
+            self.assertEqual(options["CUDAExecutionProvider"]["user_compute_stream"], "0")
+
+            try:
+                import torch
+
+                if torch.cuda.is_available():
+                    s = torch.cuda.Stream()
+                    option["user_compute_stream"] = str(s.cuda_stream)
+                    sess.set_providers(["CUDAExecutionProvider"], [option])
+                    options = sess.get_provider_options()
+                    self.assertEqual(options["CUDAExecutionProvider"]["user_compute_stream"], str(s.cuda_stream))
+                    self.assertEqual(options["CUDAExecutionProvider"]["has_user_compute_stream"], "1")
+            except ImportError:
+                print("torch is not installed, skip testing setting user_compute_stream from torch cuda stream")
+
             #
             # Note: Tests that throw an exception leave an empty session due to how set_providers currently works,
             # so run them last. Each set_providers call will attempt to re-create a session, so it's
```
