From 6f566562cedff9996e55dbf623b1f0141733d52c Mon Sep 17 00:00:00 2001 From: kailums <109063327+kailums@users.noreply.github.com> Date: Tue, 27 Feb 2024 11:31:03 +0800 Subject: [PATCH] support user_compute_stream for rocm ep (#19619) ### Description Following PR #19229, which added support for the CUDA EP to use an external compute stream, we add the same support for the ROCm EP. While testing this feature with torch, we found that torch uses stream 0 as its default stream and `torch.cuda.current_stream()` returns `0` for the current stream, but ORT treats `0` or `nullptr` as invalid and resets has_user_compute_stream to false. The has_user_compute_stream option will be removed in the future. ### Motivation and Context The motivation for this PR is that we want to use torch.cuda.graph to capture ORT's running kernels, which requires torch and ORT to run on the same stream, so we use this API to set ORT's working stream. --- .../rocm/rocm_execution_provider_info.cc | 20 +++++++++++++++++++ .../test/python/onnxruntime_test_python.py | 10 ++++++++++ 2 files changed, 30 insertions(+) diff --git a/onnxruntime/core/providers/rocm/rocm_execution_provider_info.cc b/onnxruntime/core/providers/rocm/rocm_execution_provider_info.cc index b557f92287f2b..3cb826437a54f 100644 --- a/onnxruntime/core/providers/rocm/rocm_execution_provider_info.cc +++ b/onnxruntime/core/providers/rocm/rocm_execution_provider_info.cc @@ -13,6 +13,8 @@ namespace onnxruntime { namespace rocm { namespace provider_option_names { constexpr const char* kDeviceId = "device_id"; +constexpr const char* kHasUserComputeStream = "has_user_compute_stream"; +constexpr const char* kUserComputeStream = "user_compute_stream"; constexpr const char* kMemLimit = "gpu_mem_limit"; constexpr const char* kArenaExtendStrategy = "arena_extend_strategy"; constexpr const char* kMiopenConvExhaustiveSearch = "miopen_conv_exhaustive_search"; @@ -38,6 +40,7 @@ ROCMExecutionProviderInfo ROCMExecutionProviderInfo::FromProviderOptions(const P void* alloc = nullptr; void* 
free = nullptr; void* empty_cache = nullptr; + void* user_compute_stream = nullptr; ORT_THROW_IF_ERROR( ProviderOptionsParser{} .AddValueParser( @@ -52,6 +55,15 @@ ROCMExecutionProviderInfo ROCMExecutionProviderInfo::FromProviderOptions(const P ", must be between 0 (inclusive) and ", num_devices, " (exclusive)."); return Status::OK(); }) + .AddAssignmentToReference(rocm::provider_option_names::kHasUserComputeStream, info.has_user_compute_stream) + .AddValueParser( + rocm::provider_option_names::kUserComputeStream, + [&user_compute_stream](const std::string& value_str) -> Status { + size_t address; + ORT_RETURN_IF_ERROR(ParseStringWithClassicLocale(value_str, address)); + user_compute_stream = reinterpret_cast(address); + return Status::OK(); + }) .AddValueParser( rocm::provider_option_names::kGpuExternalAlloc, [&alloc](const std::string& value_str) -> Status { @@ -108,12 +120,18 @@ ROCMExecutionProviderInfo ROCMExecutionProviderInfo::FromProviderOptions(const P ROCMExecutionProviderExternalAllocatorInfo alloc_info{alloc, free, empty_cache}; info.external_allocator_info = alloc_info; + + info.user_compute_stream = user_compute_stream; + info.has_user_compute_stream = (user_compute_stream != nullptr); + return info; } ProviderOptions ROCMExecutionProviderInfo::ToProviderOptions(const ROCMExecutionProviderInfo& info) { const ProviderOptions options{ {rocm::provider_option_names::kDeviceId, MakeStringWithClassicLocale(info.device_id)}, + {rocm::provider_option_names::kHasUserComputeStream, MakeStringWithClassicLocale(info.has_user_compute_stream)}, + {rocm::provider_option_names::kUserComputeStream, MakeStringWithClassicLocale(reinterpret_cast(info.user_compute_stream))}, {rocm::provider_option_names::kMemLimit, MakeStringWithClassicLocale(info.gpu_mem_limit)}, {rocm::provider_option_names::kGpuExternalAlloc, MakeStringWithClassicLocale(reinterpret_cast(info.external_allocator_info.alloc))}, {rocm::provider_option_names::kGpuExternalFree, 
MakeStringWithClassicLocale(reinterpret_cast(info.external_allocator_info.free))}, @@ -135,6 +153,8 @@ ProviderOptions ROCMExecutionProviderInfo::ToProviderOptions(const ROCMExecution ProviderOptions ROCMExecutionProviderInfo::ToProviderOptions(const OrtROCMProviderOptions& info) { const ProviderOptions options{ {rocm::provider_option_names::kDeviceId, MakeStringWithClassicLocale(info.device_id)}, + {rocm::provider_option_names::kHasUserComputeStream, MakeStringWithClassicLocale(info.has_user_compute_stream)}, + {rocm::provider_option_names::kUserComputeStream, MakeStringWithClassicLocale(reinterpret_cast(info.user_compute_stream))}, {rocm::provider_option_names::kMemLimit, MakeStringWithClassicLocale(info.gpu_mem_limit)}, {rocm::provider_option_names::kArenaExtendStrategy, EnumToName(arena_extend_strategy_mapping, static_cast(info.arena_extend_strategy))}, {rocm::provider_option_names::kMiopenConvExhaustiveSearch, MakeStringWithClassicLocale(info.miopen_conv_exhaustive_search)}, diff --git a/onnxruntime/test/python/onnxruntime_test_python.py b/onnxruntime/test/python/onnxruntime_test_python.py index 91b6c71e735a8..ab56f3fa0f37f 100644 --- a/onnxruntime/test/python/onnxruntime_test_python.py +++ b/onnxruntime/test/python/onnxruntime_test_python.py @@ -559,6 +559,16 @@ def test_get_and_set_option_with_values(option_name, option_values): test_get_and_set_option_with_values("enable_hip_graph", ["1", "0"]) + # test for user_compute_stream + option = options["ROCMExecutionProvider"] + option["user_compute_stream"] = "1" + sess.set_providers(["ROCMExecutionProvider"], [option]) + new_options = sess.get_provider_options() + new_option = new_options["ROCMExecutionProvider"] + self.assertEqual(new_option["user_compute_stream"], "1") + # set user_compute_stream will set has_user_compute_stream to 1 too + self.assertEqual(new_option["has_user_compute_stream"], "1") + run_rocm_options_test() def test_invalid_set_providers(self):