Skip to content

Commit

Permalink
add gpu arch checking warning log
Browse files Browse the repository at this point in the history
  • Loading branch information
cloudhan committed Jun 14, 2024
1 parent d15329e commit 5d09ec9
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 0 deletions.
21 changes: 21 additions & 0 deletions onnxruntime/contrib_ops/rocm/bert/group_query_attention.cu
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,12 @@ GroupQueryAttention<T>::GroupQueryAttention(const OpKernelInfo& info)
scale_ = info.GetAttrOrDefault<float>("scale", 0.0f);
}

// Out-of-line definition of the per-specialization static std::once_flag
// (MLFloat16 instantiation). ComputeInternal passes it to std::call_once so
// the unsupported-GPU-arch warning below is logged at most once per process
// for this type. Each explicit specialization needs its own definition in
// exactly one TU to satisfy the ODR; brace-init value-initializes the flag.
template <>
std::once_flag GroupQueryAttention<MLFloat16>::arch_checking_{};

// Out-of-line definition of the per-specialization static std::once_flag
// (BFloat16 instantiation). ComputeInternal passes it to std::call_once so
// the unsupported-GPU-arch warning below is logged at most once per process
// for this type. Each explicit specialization needs its own definition in
// exactly one TU to satisfy the ODR; brace-init value-initializes the flag.
template <>
std::once_flag GroupQueryAttention<BFloat16>::arch_checking_{};

template <typename T>
Status GroupQueryAttention<T>::ComputeInternal(OpKernelContext* ctx) const {
auto hip_stream = static_cast<hipStream_t>(ctx->GetComputeStream()->GetHandle());
Expand All @@ -154,6 +160,21 @@ Status GroupQueryAttention<T>::ComputeInternal(OpKernelContext* ctx) const {
const Tensor* sin_cache = ctx->Input<Tensor>(8);

auto& device_prop = GetDeviceProp();
std::call_once(
arch_checking_,
[](const hipDeviceProp_t& device_prop) {
if (std::string_view(device_prop.gcnArchName).find("gfx90a") == std::string_view::npos &&
std::string_view(device_prop.gcnArchName).find("gfx942") == std::string_view::npos) {
LOGS_DEFAULT(WARNING)
<< "GroupQueryAttention currently only supports ck_tile fmha backend which only supports "
<< "CDNA2 and CDNA3 archs.";
LOGS_DEFAULT(WARNING)
<< "GroupQueryAttention running on an unsuppoted GPU may result in "
<< "hipErrorNoBinaryForGpu or hipErrorSharedObjectInitFailedshared error.";
}
},
device_prop);

GroupQueryAttentionParameters parameters;
using HipT = typename ToHipType<T>::MappedType;

Expand Down
4 changes: 4 additions & 0 deletions onnxruntime/contrib_ops/rocm/bert/group_query_attention.h
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
#pragma once

#include <memory>
#include <mutex>
#include "core/providers/rocm/rocm_kernel.h"

namespace onnxruntime {
Expand All @@ -27,6 +28,9 @@ class GroupQueryAttention final : public RocmKernel {
bool do_rotary_;
bool rotary_interleaved_;
float scale_;

private:
static std::once_flag arch_checking_;
};

} // namespace rocm
Expand Down

0 comments on commit 5d09ec9

Please sign in to comment.