From 5d09ec91c0283d8a70212e511ecda78461e6b78f Mon Sep 17 00:00:00 2001
From: Guangyun Han
Date: Fri, 14 Jun 2024 06:56:16 +0000
Subject: [PATCH] add gpu arch checking warning log

---
 .../rocm/bert/group_query_attention.cu        | 21 +++++++++++++++++++
 .../rocm/bert/group_query_attention.h         |  4 ++++
 2 files changed, 25 insertions(+)

diff --git a/onnxruntime/contrib_ops/rocm/bert/group_query_attention.cu b/onnxruntime/contrib_ops/rocm/bert/group_query_attention.cu
index 04130b46b94e8..7730b0205b69c 100644
--- a/onnxruntime/contrib_ops/rocm/bert/group_query_attention.cu
+++ b/onnxruntime/contrib_ops/rocm/bert/group_query_attention.cu
@@ -140,6 +140,12 @@ GroupQueryAttention<T>::GroupQueryAttention(const OpKernelInfo& info)
   scale_ = info.GetAttrOrDefault<float>("scale", 0.0f);
 }
 
+template <>
+std::once_flag GroupQueryAttention<MLFloat16>::arch_checking_{};
+
+template <>
+std::once_flag GroupQueryAttention<BFloat16>::arch_checking_{};
+
 template <typename T>
 Status GroupQueryAttention<T>::ComputeInternal(OpKernelContext* ctx) const {
   auto hip_stream = static_cast<hipStream_t>(ctx->GetComputeStream()->GetHandle());
@@ -154,6 +160,21 @@ Status GroupQueryAttention<T>::ComputeInternal(OpKernelContext* ctx) const {
   const Tensor* sin_cache = ctx->Input<Tensor>(8);
 
   auto& device_prop = GetDeviceProp();
+  std::call_once(
+      arch_checking_,
+      [](const hipDeviceProp_t& device_prop) {
+        if (std::string_view(device_prop.gcnArchName).find("gfx90a") == std::string_view::npos &&
+            std::string_view(device_prop.gcnArchName).find("gfx942") == std::string_view::npos) {
+          LOGS_DEFAULT(WARNING)
+              << "GroupQueryAttention currently only supports ck_tile fmha backend which only supports "
+              << "CDNA2 and CDNA3 archs.";
+          LOGS_DEFAULT(WARNING)
+              << "GroupQueryAttention running on an unsupported GPU may result in "
+              << "hipErrorNoBinaryForGpu or hipErrorSharedObjectInitFailed error.";
+        }
+      },
+      device_prop);
+
   GroupQueryAttentionParameters parameters;
   using HipT = typename ToHipType<T>::MappedType;
 
diff --git a/onnxruntime/contrib_ops/rocm/bert/group_query_attention.h b/onnxruntime/contrib_ops/rocm/bert/group_query_attention.h
index 4d40b5049a8ee..ce0de1f761aa5 100644
--- a/onnxruntime/contrib_ops/rocm/bert/group_query_attention.h
+++ b/onnxruntime/contrib_ops/rocm/bert/group_query_attention.h
@@ -4,6 +4,7 @@
 #pragma once
 
 #include <memory>
+#include <mutex>
 #include "core/providers/rocm/rocm_kernel.h"
 
 namespace onnxruntime {
@@ -27,6 +28,9 @@ class GroupQueryAttention final : public RocmKernel {
   bool do_rotary_;
   bool rotary_interleaved_;
   float scale_;
+
+ private:
+  static std::once_flag arch_checking_;
 };
 
 }  // namespace rocm