From 545cefd6dbfb5571ac716b5fcec1fc53df9aaf06 Mon Sep 17 00:00:00 2001 From: kailums <109063327+kailums@users.noreply.github.com> Date: Tue, 27 Feb 2024 11:31:03 +0800 Subject: [PATCH] support user_compute_stream for rocm ep (#19619) According to the pr #19229 supporting cuda EP use external compute stream, we add support for rocm EP. And when we testing this feature with torch, we found torch use stream 0 for the default stream, and `torch.cuda.current_stream()` returns `0` for current stream, but ort treat `0` or `nullptr` as invalid, and reset has_user_compute_stream to false. Will remove has_user_compute_stream option in the future. The motivation for this pr is that we want to use torch.cuda.graph to capture ort running kernel, which requires torch and ort are running in the same stream, so we use this API to set ort's working stream. --- .../rocm/rocm_execution_provider_info.cc | 20 +++++++++++++++++++ .../test/python/onnxruntime_test_python.py | 12 +++++++++++ 2 files changed, 32 insertions(+) diff --git a/onnxruntime/core/providers/rocm/rocm_execution_provider_info.cc b/onnxruntime/core/providers/rocm/rocm_execution_provider_info.cc index 650635c153640..ca85a01916fd3 100644 --- a/onnxruntime/core/providers/rocm/rocm_execution_provider_info.cc +++ b/onnxruntime/core/providers/rocm/rocm_execution_provider_info.cc @@ -13,6 +13,8 @@ namespace onnxruntime { namespace rocm { namespace provider_option_names { constexpr const char* kDeviceId = "device_id"; +constexpr const char* kHasUserComputeStream = "has_user_compute_stream"; +constexpr const char* kUserComputeStream = "user_compute_stream"; constexpr const char* kMemLimit = "gpu_mem_limit"; constexpr const char* kArenaExtendStrategy = "arena_extend_strategy"; constexpr const char* kMiopenConvExhaustiveSearch = "miopen_conv_exhaustive_search"; @@ -37,6 +39,7 @@ ROCMExecutionProviderInfo ROCMExecutionProviderInfo::FromProviderOptions(const P void* alloc = nullptr; void* free = nullptr; void* empty_cache = nullptr; + void* user_compute_stream = nullptr; ORT_THROW_IF_ERROR( ProviderOptionsParser{} .AddValueParser( @@ -51,6 +54,15 @@ ROCMExecutionProviderInfo ROCMExecutionProviderInfo::FromProviderOptions(const P ", must be between 0 (inclusive) and ", num_devices, " (exclusive)."); return Status::OK(); }) + .AddAssignmentToReference(rocm::provider_option_names::kHasUserComputeStream, info.has_user_compute_stream) + .AddValueParser( + rocm::provider_option_names::kUserComputeStream, + [&user_compute_stream](const std::string& value_str) -> Status { + size_t address; + ORT_RETURN_IF_ERROR(ParseStringWithClassicLocale(value_str, address)); + user_compute_stream = reinterpret_cast(address); + return Status::OK(); + }) .AddValueParser( rocm::provider_option_names::kGpuExternalAlloc, [&alloc](const std::string& value_str) -> Status { @@ -106,12 +118,18 @@ ROCMExecutionProviderInfo ROCMExecutionProviderInfo::FromProviderOptions(const P ROCMExecutionProviderExternalAllocatorInfo alloc_info{alloc, free, empty_cache}; info.external_allocator_info = alloc_info; + + info.user_compute_stream = user_compute_stream; + info.has_user_compute_stream = (user_compute_stream != nullptr); + return info; } ProviderOptions ROCMExecutionProviderInfo::ToProviderOptions(const ROCMExecutionProviderInfo& info) { const ProviderOptions options{ {rocm::provider_option_names::kDeviceId, MakeStringWithClassicLocale(info.device_id)}, + {rocm::provider_option_names::kHasUserComputeStream, MakeStringWithClassicLocale(info.has_user_compute_stream)}, + {rocm::provider_option_names::kUserComputeStream, MakeStringWithClassicLocale(reinterpret_cast(info.user_compute_stream))}, {rocm::provider_option_names::kMemLimit, MakeStringWithClassicLocale(info.gpu_mem_limit)}, {rocm::provider_option_names::kGpuExternalAlloc, MakeStringWithClassicLocale(reinterpret_cast(info.external_allocator_info.alloc))}, {rocm::provider_option_names::kGpuExternalFree, MakeStringWithClassicLocale(reinterpret_cast(info.external_allocator_info.free))}, @@ -132,6 +150,8 @@ ProviderOptions ROCMExecutionProviderInfo::ToProviderOptions(const ROCMExecution ProviderOptions ROCMExecutionProviderInfo::ToProviderOptions(const OrtROCMProviderOptions& info) { const ProviderOptions options{ {rocm::provider_option_names::kDeviceId, MakeStringWithClassicLocale(info.device_id)}, + {rocm::provider_option_names::kHasUserComputeStream, MakeStringWithClassicLocale(info.has_user_compute_stream)}, + {rocm::provider_option_names::kUserComputeStream, MakeStringWithClassicLocale(reinterpret_cast(info.user_compute_stream))}, {rocm::provider_option_names::kMemLimit, MakeStringWithClassicLocale(info.gpu_mem_limit)}, {rocm::provider_option_names::kArenaExtendStrategy, EnumToName(arena_extend_strategy_mapping, static_cast(info.arena_extend_strategy))}, {rocm::provider_option_names::kMiopenConvExhaustiveSearch, MakeStringWithClassicLocale(info.miopen_conv_exhaustive_search)}, diff --git a/onnxruntime/test/python/onnxruntime_test_python.py b/onnxruntime/test/python/onnxruntime_test_python.py index d8628c4288206..8614b29937948 100644 --- a/onnxruntime/test/python/onnxruntime_test_python.py +++ b/onnxruntime/test/python/onnxruntime_test_python.py @@ -528,6 +528,18 @@ def test_get_and_set_option_with_values(option_name, option_values): test_get_and_set_option_with_values("tunable_op_max_tuning_duration_ms", ["-1", "1"]) + test_get_and_set_option_with_values("enable_hip_graph", ["1", "0"]) + + # test for user_compute_stream + option = options["ROCMExecutionProvider"] + option["user_compute_stream"] = "1" + sess.set_providers(["ROCMExecutionProvider"], [option]) + new_options = sess.get_provider_options() + new_option = new_options["ROCMExecutionProvider"] + self.assertEqual(new_option["user_compute_stream"], "1") + # set user_compute_stream will set has_user_compute_stream to 1 too + self.assertEqual(new_option["has_user_compute_stream"], "1") + run_rocm_options_test() def test_invalid_set_providers(self):