support user_compute_stream for rocm ep (#19619)
### Description

Following PR #19229, which added support for an external compute stream in the CUDA EP, this change adds the same support to the ROCm EP.

While testing this feature with torch, we found that torch uses stream 0 as its default stream, so `torch.cuda.current_stream()` returns `0` for the current stream; ORT, however, treats `0` or `nullptr` as invalid and resets `has_user_compute_stream` to false.

The `has_user_compute_stream` option will be removed in the future.
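
A minimal usage sketch (not part of this change; the model path is hypothetical) of passing a torch stream to the ROCm EP through the new `user_compute_stream` option. A non-default stream is created because the default stream's handle is `0`, which ORT currently treats as no user stream:

```python
import onnxruntime as ort
import torch

# Use a non-default stream: the default stream's handle is 0, which ORT
# interprets as "no user compute stream".
stream = torch.cuda.Stream()
torch.cuda.set_current_stream(stream)

provider_options = {
    "device_id": "0",
    # the stream is passed as its integer address, serialized as a string
    "user_compute_stream": str(stream.cuda_stream),
}

sess = ort.InferenceSession(
    "model.onnx",  # hypothetical model path
    providers=[("ROCMExecutionProvider", provider_options)],
)
```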


### Motivation and Context

We want to use `torch.cuda.graph` to capture the kernels ORT launches, which requires torch and ORT to run on the same stream, so we use this API to set ORT's working stream.
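
A hedged sketch of that workflow, reusing `sess` and `stream` from the sketch above and assuming GPU-resident tensors bound through IOBinding so only kernel launches are captured (tensor names and shapes are hypothetical):

```python
import numpy as np

# Bind GPU-resident torch tensors so the captured region contains no host transfers.
x = torch.zeros(1, 3, 224, 224, device="cuda")  # hypothetical input
y = torch.empty(1, 1000, device="cuda")         # hypothetical output
io = sess.io_binding()
io.bind_input("input", "cuda", 0, np.float32, tuple(x.shape), x.data_ptr())
io.bind_output("output", "cuda", 0, np.float32, tuple(y.shape), y.data_ptr())

sess.run_with_iobinding(io)  # warm-up run before capture

g = torch.cuda.CUDAGraph()
with torch.cuda.graph(g, stream=stream):
    sess.run_with_iobinding(io)  # ORT kernels launch on the shared stream

g.replay()  # replay the captured kernels; results land in y
```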
kailums authored Feb 27, 2024
1 parent 8a71b65 commit 6f56656
Showing 2 changed files with 30 additions and 0 deletions.
20 changes: 20 additions & 0 deletions onnxruntime/core/providers/rocm/rocm_execution_provider_info.cc
@@ -13,6 +13,8 @@ namespace onnxruntime {
namespace rocm {
namespace provider_option_names {
constexpr const char* kDeviceId = "device_id";
constexpr const char* kHasUserComputeStream = "has_user_compute_stream";
constexpr const char* kUserComputeStream = "user_compute_stream";
constexpr const char* kMemLimit = "gpu_mem_limit";
constexpr const char* kArenaExtendStrategy = "arena_extend_strategy";
constexpr const char* kMiopenConvExhaustiveSearch = "miopen_conv_exhaustive_search";
@@ -38,6 +40,7 @@ ROCMExecutionProviderInfo ROCMExecutionProviderInfo::FromProviderOptions(const P
void* alloc = nullptr;
void* free = nullptr;
void* empty_cache = nullptr;
void* user_compute_stream = nullptr;
ORT_THROW_IF_ERROR(
ProviderOptionsParser{}
.AddValueParser(
@@ -52,6 +55,15 @@ ROCMExecutionProviderInfo ROCMExecutionProviderInfo::FromProviderOptions(const P
", must be between 0 (inclusive) and ", num_devices, " (exclusive).");
return Status::OK();
})
.AddAssignmentToReference(rocm::provider_option_names::kHasUserComputeStream, info.has_user_compute_stream)
.AddValueParser(
rocm::provider_option_names::kUserComputeStream,
[&user_compute_stream](const std::string& value_str) -> Status {
size_t address;
ORT_RETURN_IF_ERROR(ParseStringWithClassicLocale(value_str, address));
user_compute_stream = reinterpret_cast<void*>(address);
return Status::OK();
})
.AddValueParser(
rocm::provider_option_names::kGpuExternalAlloc,
[&alloc](const std::string& value_str) -> Status {
@@ -108,12 +120,18 @@ ROCMExecutionProviderInfo ROCMExecutionProviderInfo::FromProviderOptions(const P

ROCMExecutionProviderExternalAllocatorInfo alloc_info{alloc, free, empty_cache};
info.external_allocator_info = alloc_info;

info.user_compute_stream = user_compute_stream;
info.has_user_compute_stream = (user_compute_stream != nullptr);

return info;
}

ProviderOptions ROCMExecutionProviderInfo::ToProviderOptions(const ROCMExecutionProviderInfo& info) {
const ProviderOptions options{
{rocm::provider_option_names::kDeviceId, MakeStringWithClassicLocale(info.device_id)},
{rocm::provider_option_names::kHasUserComputeStream, MakeStringWithClassicLocale(info.has_user_compute_stream)},
{rocm::provider_option_names::kUserComputeStream, MakeStringWithClassicLocale(reinterpret_cast<size_t>(info.user_compute_stream))},
{rocm::provider_option_names::kMemLimit, MakeStringWithClassicLocale(info.gpu_mem_limit)},
{rocm::provider_option_names::kGpuExternalAlloc, MakeStringWithClassicLocale(reinterpret_cast<size_t>(info.external_allocator_info.alloc))},
{rocm::provider_option_names::kGpuExternalFree, MakeStringWithClassicLocale(reinterpret_cast<size_t>(info.external_allocator_info.free))},
@@ -135,6 +153,8 @@ ProviderOptions ROCMExecutionProviderInfo::ToProviderOptions(const ROCMExecution
ProviderOptions ROCMExecutionProviderInfo::ToProviderOptions(const OrtROCMProviderOptions& info) {
const ProviderOptions options{
{rocm::provider_option_names::kDeviceId, MakeStringWithClassicLocale(info.device_id)},
{rocm::provider_option_names::kHasUserComputeStream, MakeStringWithClassicLocale(info.has_user_compute_stream)},
{rocm::provider_option_names::kUserComputeStream, MakeStringWithClassicLocale(reinterpret_cast<size_t>(info.user_compute_stream))},
{rocm::provider_option_names::kMemLimit, MakeStringWithClassicLocale(info.gpu_mem_limit)},
{rocm::provider_option_names::kArenaExtendStrategy, EnumToName(arena_extend_strategy_mapping, static_cast<onnxruntime::ArenaExtendStrategy>(info.arena_extend_strategy))},
{rocm::provider_option_names::kMiopenConvExhaustiveSearch, MakeStringWithClassicLocale(info.miopen_conv_exhaustive_search)},
10 changes: 10 additions & 0 deletions onnxruntime/test/python/onnxruntime_test_python.py
@@ -559,6 +559,16 @@ def test_get_and_set_option_with_values(option_name, option_values):

test_get_and_set_option_with_values("enable_hip_graph", ["1", "0"])

# test for user_compute_stream
option = options["ROCMExecutionProvider"]
option["user_compute_stream"] = "1"
sess.set_providers(["ROCMExecutionProvider"], [option])
new_options = sess.get_provider_options()
new_option = new_options["ROCMExecutionProvider"]
self.assertEqual(new_option["user_compute_stream"], "1")
# setting user_compute_stream also sets has_user_compute_stream to 1
self.assertEqual(new_option["has_user_compute_stream"], "1")

run_rocm_options_test()

def test_invalid_set_providers(self):
