From b9c935f6050b3a57e23dbb79e739489f25f6924a Mon Sep 17 00:00:00 2001 From: mindest <30493312+mindest@users.noreply.github.com> Date: Fri, 24 Nov 2023 17:22:00 +0800 Subject: [PATCH] [ROCm] Some fixes in tunable (#18575) ### Description * Fix workspace size for hipBLASLt algos at 32M * Update according to API changes --- .../contrib_ops/rocm/diffusion/group_norm_triton.cuh | 2 +- onnxruntime/core/providers/rocm/math/softmax_triton.cuh | 2 +- onnxruntime/core/providers/rocm/tunable/gemm_hipblaslt.h | 7 +++++++ 3 files changed, 9 insertions(+), 2 deletions(-) diff --git a/onnxruntime/contrib_ops/rocm/diffusion/group_norm_triton.cuh b/onnxruntime/contrib_ops/rocm/diffusion/group_norm_triton.cuh index 526d220d4be24..b7b9441ac997d 100644 --- a/onnxruntime/contrib_ops/rocm/diffusion/group_norm_triton.cuh +++ b/onnxruntime/contrib_ops/rocm/diffusion/group_norm_triton.cuh @@ -77,7 +77,7 @@ auto GetTritonGroupNormNHWCTypeStringAndOps() { params->epsilon}; // Grid dim is (batch_count, groups, 1) - return LaunchTritonKernel(params->stream, i, params->n, params->groups, 1, &args, sizeof(args)); + return LaunchTritonKernel(params->StreamHandle(), i, params->n, params->groups, 1, &args, sizeof(args)); }; ret.emplace_back(std::make_pair(metadata->name, std::move(impl))); } diff --git a/onnxruntime/core/providers/rocm/math/softmax_triton.cuh b/onnxruntime/core/providers/rocm/math/softmax_triton.cuh index 737e396855e35..cc0e0d70056cc 100644 --- a/onnxruntime/core/providers/rocm/math/softmax_triton.cuh +++ b/onnxruntime/core/providers/rocm/math/softmax_triton.cuh @@ -60,7 +60,7 @@ auto GetSoftmaxTritonOps() { } args = {(void*)params->output, (const void*)params->input, params->input_stride, params->output_stride, params->softmax_elements}; // grid dim is (batch_count, 1, 1) - return LaunchTritonKernel(params->stream, i, params->batch_count, 1, 1, &args, sizeof(args)); + return LaunchTritonKernel(params->StreamHandle(), i, params->batch_count, 1, 1, &args, sizeof(args)); }; ret.emplace_back(std::make_pair(metadata->name, std::move(impl))); } diff --git a/onnxruntime/core/providers/rocm/tunable/gemm_hipblaslt.h b/onnxruntime/core/providers/rocm/tunable/gemm_hipblaslt.h index b9c0cdcc1c341..776dabd757af4 100644 --- a/onnxruntime/core/providers/rocm/tunable/gemm_hipblaslt.h +++ b/onnxruntime/core/providers/rocm/tunable/gemm_hipblaslt.h @@ -26,6 +26,10 @@ using onnxruntime::contrib::rocm::blas::GemmFastGeluParams; #ifdef USE_HIPBLASLT +// For large K and small M/N, K dim will be split to multiple workgroups and buffers, +// which will require additional workspace. Here we set the max workspace size to 32MB. +constexpr const size_t kHipBlasLtMaxWorkSpaceSizeInBytes = 32 * 1024 * 1024; + enum ActivationType { NONE = 0, RELU = 1, @@ -225,6 +229,9 @@ auto GetHipBlasLtTypeStringAndOps(ActivationType activation_type = ActivationTyp IAllocatorUniquePtr workspace_buffer; if (workspace_size > 0) { + TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF(workspace_size > kHipBlasLtMaxWorkSpaceSizeInBytes, + "Workspace size exceeds limit (32M): ", workspace_size); + workspace_size = kHipBlasLtMaxWorkSpaceSizeInBytes; workspace_buffer = params->tuning_ctx->GetScratchBuffer(workspace_size, params->stream); }