From c877adcb5023c075d616f9ad8bc412dc83674183 Mon Sep 17 00:00:00 2001
From: Guangyun Han
Date: Tue, 4 Jun 2024 08:23:31 +0000
Subject: [PATCH 01/18] feat: init rocm gqa

---
 cmake/onnxruntime_rocm_hipify.cmake           |   1 -
 .../rocm/bert/group_query_attention.cu        | 393 ++++++++++++++++++
 .../rocm/bert/group_query_attention.h         |  34 ++
 .../contrib_ops/rocm/rocm_contrib_kernels.cc  |   4 +
 .../transformers/test_flash_attn_rocm.py      |  81 ++++
 5 files changed, 512 insertions(+), 1 deletion(-)
 create mode 100644 onnxruntime/contrib_ops/rocm/bert/group_query_attention.cu
 create mode 100644 onnxruntime/contrib_ops/rocm/bert/group_query_attention.h
 create mode 100644 onnxruntime/test/python/transformers/test_flash_attn_rocm.py

diff --git a/cmake/onnxruntime_rocm_hipify.cmake b/cmake/onnxruntime_rocm_hipify.cmake
index 2be68146b5e94..2966a4624a966 100644
--- a/cmake/onnxruntime_rocm_hipify.cmake
+++ b/cmake/onnxruntime_rocm_hipify.cmake
@@ -88,7 +88,6 @@ set(contrib_ops_excluded_files
   "cuda_contrib_kernels.h"
   "inverse.cc"
   "fused_conv.cc"
-  "bert/group_query_attention_helper.h"
   "bert/group_query_attention.h"
   "bert/group_query_attention.cc"
   "bert/group_query_attention_impl.h"
diff --git a/onnxruntime/contrib_ops/rocm/bert/group_query_attention.cu b/onnxruntime/contrib_ops/rocm/bert/group_query_attention.cu
new file mode 100644
index 0000000000000..cc88b66bea390
--- /dev/null
+++ b/onnxruntime/contrib_ops/rocm/bert/group_query_attention.cu
@@ -0,0 +1,393 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#include "core/providers/shared_library/provider_api.h"
+#include "ck_tile/core/numeric/integer.hpp"
+#include "core/providers/rocm/rocm_common.h"
+#include "core/platform/env_var_utils.h"
+#include "contrib_ops/rocm/bert/group_query_attention.h"
+#include "contrib_ops/rocm/bert/group_query_attention_helper.h"
+#include "contrib_ops/rocm/bert/batched_gemm_softmax_gemm_permute_pipelines.cuh"
+
+#ifdef USE_COMPOSABLE_KERNEL_CK_TILE
+#include "fmha_fwd.hpp"
+#endif
+
+using namespace onnxruntime::rocm;
+using namespace ::onnxruntime::common;
+using namespace ONNX_NAMESPACE;
+
+namespace onnxruntime {
+namespace contrib {
+namespace rocm {
+
+#define REGISTER_KERNEL_TYPED(T)                                        \
+  ONNX_OPERATOR_TYPED_KERNEL_EX(                                        \
+      GroupQueryAttention,                                              \
+      kMSDomain,                                                        \
+      1,                                                                \
+      T,                                                                \
+      kRocmExecutionProvider,                                           \
+      (*KernelDefBuilder::Create())                                     \
+          .TypeConstraint("T", DataTypeImpl::GetTensorType<T>())        \
+          .TypeConstraint("M", DataTypeImpl::GetTensorType<int32_t>())  \
+          .MayInplace(3, 1)                                             \
+          .MayInplace(4, 2)                                             \
+          .InputMemoryType(OrtMemTypeCPUInput, 6),                      \
+      GroupQueryAttention<T>);
+
+// REGISTER_KERNEL_TYPED(float)
+REGISTER_KERNEL_TYPED(MLFloat16)
+// REGISTER_KERNEL_TYPED(BFloat16)
+
+template <typename T>
+std::string GetCkFmhaDataTypeString();
+
+template <>
+std::string GetCkFmhaDataTypeString<MLFloat16>() {
+  return "fp16";
+}
+
+template <>
+std::string GetCkFmhaDataTypeString<BFloat16>() {
+  return "bf16";
+}
+
+__global__ void seqlens_inc_kernel(const int* seqlens, int* out, int num_elems, int inc) {
+  int idx = blockDim.x * blockIdx.x + threadIdx.x;
+  if (idx < num_elems) {
+    out[idx] = seqlens[idx] + inc;
+    printf("inc %d: %d\n", idx, out[idx]);
+  }
+}
+
+Status LaunchSeqlensInc(hipStream_t stream, const int* seqlens, int* out, int num_elems, int inc) {
+  constexpr int NumThreads = 128;
+  int num_blks = CeilDiv(num_elems, NumThreads);
+  seqlens_inc_kernel<<<num_blks, NumThreads, 0, stream>>>(seqlens, out, num_elems, inc);
+  return HIP_CALL(hipGetLastError());
+}
+
+__global__ void seqstart_init_kernel(int* out, int num_elems, int length_per_seq) {
+  int idx = 
blockDim.x * blockIdx.x + threadIdx.x; + if (idx < num_elems) { + out[idx] = idx * length_per_seq; + printf("init %d: %d\n", idx, out[idx]); + } + if (idx == 0) { + out[num_elems] = num_elems * length_per_seq; + } +} + +Status LaunchSeqStartInit(hipStream_t stream, int* out, int num_elems, int length_per_seq) { + constexpr int NumThreads = 128; + int num_blks = CeilDiv(num_elems, NumThreads); + seqstart_init_kernel<<>>(out, num_elems, length_per_seq); + return HIP_CALL(hipGetLastError()); +} + +template +GroupQueryAttention::GroupQueryAttention(const OpKernelInfo& info) + : RocmKernel(info) { + int64_t num_heads = 0; + int64_t kv_num_heads = 0; + ORT_ENFORCE(info.GetAttr("num_heads", &num_heads).IsOK() && num_heads > 0); + ORT_ENFORCE(info.GetAttr("kv_num_heads", &kv_num_heads).IsOK() && kv_num_heads > 0 && num_heads % kv_num_heads == 0); + num_heads_ = static_cast(num_heads); + kv_num_heads_ = static_cast(kv_num_heads); + is_past_bsnh_ = false; + is_unidirectional_ = true; + local_window_size_ = static_cast(info.GetAttrOrDefault("local_window_size", -1)); + do_rotary_ = info.GetAttrOrDefault("do_rotary", 0) == 1; + rotary_interleaved_ = info.GetAttrOrDefault("rotary_interleaved", 0) == 1; + scale_ = info.GetAttrOrDefault("scale", 0.0f); +} + +template +Status GroupQueryAttention::ComputeInternal(OpKernelContext* ctx) const { + auto hip_stream = static_cast(ctx->GetComputeStream()->GetHandle()); + const Tensor* query = ctx->Input(0); + const Tensor* key = ctx->Input(1); + const Tensor* value = ctx->Input(2); + const Tensor* past_key = ctx->Input(3); + const Tensor* past_value = ctx->Input(4); + const Tensor* seqlens_k = ctx->Input(5); + const Tensor* total_seqlen = ctx->Input(6); + const Tensor* cos_cache = ctx->Input(7); + const Tensor* sin_cache = ctx->Input(8); + + auto& device_prop = GetDeviceProp(); + GroupQueryAttentionParameters parameters; + using HipT = typename ToHipType::MappedType; + + const int max_thr_per_blk = device_prop.maxThreadsPerBlock; + + ORT_RETURN_IF_ERROR(group_query_attention_helper::CheckInputs(query, + key, + value, + past_key, + past_value, + cos_cache, + sin_cache, + ¶meters, + num_heads_, + kv_num_heads_, + seqlens_k, + total_seqlen, + is_past_bsnh_, + scale_, + max_thr_per_blk)); + + const int batch_size = parameters.batch_size; + const int sequence_length = parameters.sequence_length; + const int kv_sequence_length = parameters.sequence_length; + const int num_heads = parameters.num_heads; + const int kv_num_heads = parameters.kv_num_heads; + const int head_size = parameters.head_size; + AttentionQkvFormat past_kv_format = parameters.past_kv_format; + + parameters.local_window_size = local_window_size_; + parameters.is_unidirectional = is_unidirectional_; + // parameters.zeros_count = kZerosCount; + // parameters.zero_ptr = zeros_.get(); + // parameters.left_padding = left_padding_; + parameters.do_rotary = do_rotary_; + parameters.rotary_interleaved = rotary_interleaved_; + + if (do_rotary_ && (cos_cache == nullptr || sin_cache == nullptr)) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "cos_cache and sin_cache must be passed to GroupQueryAttention when do_rotary = 1"); + } + + TensorShapeVector output_shape(3); + output_shape[0] = static_cast(batch_size); + output_shape[1] = static_cast(sequence_length); + output_shape[2] = static_cast(parameters.hidden_size); + Tensor* output = ctx->Output(0, output_shape); + Strides output_strides = Strides::BSNHMemory(batch_size, sequence_length, num_heads, head_size); + + int4 past_shape; + std::vector 
present_dims; + Strides present_strides; + Strides past_strides; + if (past_kv_format == AttentionQkvFormat::Q_K_V_BSNH) { + past_shape = { + batch_size, parameters.seqlen_past_kv_cache, kv_num_heads, head_size}; + past_strides = Strides::BSNHMemory( + batch_size, parameters.seqlen_past_kv_cache, kv_num_heads, head_size); + present_dims = { + batch_size, parameters.seqlen_present_kv_cache, kv_num_heads, head_size}; + present_strides = Strides::BSNHMemory( + batch_size, parameters.seqlen_present_kv_cache, kv_num_heads, head_size); + } else { // BNSH + past_shape = { + batch_size, kv_num_heads, parameters.seqlen_past_kv_cache, head_size}; + past_strides = Strides::BNSHMemory( + batch_size, kv_num_heads, parameters.seqlen_past_kv_cache, head_size); + present_dims = { + batch_size, kv_num_heads, parameters.seqlen_present_kv_cache, head_size}; + present_strides = Strides::BNSHMemory( + batch_size, kv_num_heads, parameters.seqlen_present_kv_cache, head_size); + } + TensorShape present_shape(present_dims); + Tensor* present_key = ctx->Output(1, present_shape); + Tensor* present_value = ctx->Output(2, present_shape); + + Strides query_strides; + const void* query_ptr = query->DataRaw(); + const HipT* key_ptr; + const HipT* value_ptr; + if (!parameters.is_packed_qkv) { + query_strides = Strides::BSNHMemory(batch_size, sequence_length, num_heads, head_size); + key_ptr = reinterpret_cast(key->DataRaw()); + value_ptr = reinterpret_cast(value->DataRaw()); + } else { + query_strides = Strides::BSNHMemory(batch_size, sequence_length, num_heads + 2 * kv_num_heads, head_size); + const size_t key_offset = static_cast(num_heads * head_size); + const size_t value_offset = static_cast(kv_num_heads * head_size); + key_ptr = reinterpret_cast(query_ptr) + key_offset; + value_ptr = reinterpret_cast(key_ptr) + value_offset; + } + + const int* seqlens_k_ptr = seqlens_k ? 
reinterpret_cast<const int*>(seqlens_k->DataRaw()) : nullptr;
+  IAllocatorUniquePtr<int> seqlens_k_tmp;
+
+  // build present kv cache
+  auto* present_key_ptr = reinterpret_cast<HipT*>(present_key->MutableDataRaw());
+  ck_tile::index_t shape_seqlen_q = batch_size;
+  ck_tile::index_t shape_seqlen_k = batch_size;
+
+  auto* present_value_ptr = reinterpret_cast<HipT*>(present_value->MutableDataRaw());
+  int4 kv_shape{batch_size, kv_num_heads, kv_sequence_length, head_size};
+  auto kv_strides = Strides::BSNHMemory(batch_size, kv_sequence_length, kv_num_heads, head_size);
+  if (parameters.is_prompt) {
+    // copy prompt kv to present kv
+    // ORT_RETURN_IF_ERROR(LaunchStridedCopy(hip_stream, key_ptr, kv_shape, kv_strides.ForBNSHCoord(),
+    //                                       present_key_ptr, present_strides.ForBNSHCoord(), max_thr_per_blk));
+    // ORT_RETURN_IF_ERROR(LaunchStridedCopy(hip_stream, value_ptr, kv_shape, kv_strides.ForBNSHCoord(),
+    //                                       present_value_ptr, present_strides.ForBNSHCoord(), max_thr_per_blk));
+  } else {
+    const auto* past_key_ptr = reinterpret_cast<const HipT*>(past_key->DataRaw());
+    const auto* past_value_ptr = reinterpret_cast<const HipT*>(past_value->DataRaw());
+    if (!parameters.kv_share_buffer) {
+      // copy past to present,
+      // NOTE: we do a low perf full buffer copy due to the seqlens_k indicate the seqlen of different seqs are
+      // not the same, aka, can not be as simple as strided
+      ORT_RETURN_IF_ERROR(LaunchStridedCopy(hip_stream, past_key_ptr, past_shape, past_strides.ForBNSHCoord(),
+                                            present_key_ptr, present_strides.ForBNSHCoord(), max_thr_per_blk));
+      ORT_RETURN_IF_ERROR(LaunchStridedCopy(hip_stream, past_value_ptr, past_shape, past_strides.ForBNSHCoord(),
+                                            present_value_ptr, present_strides.ForBNSHCoord(), max_thr_per_blk));
+    } else {
+      // In the case of share buffer
+      ORT_ENFORCE(past_key_ptr == nullptr || past_key_ptr == present_key_ptr);
+      ORT_ENFORCE(past_key_ptr == nullptr || past_value_ptr == present_value_ptr);
+    }
+    // then append new kv to present
+    size_t buffer_offset = seqlens_k ? 0 : present_strides.OffsetAt(0, 0, kv_sequence_length, 0);
+    ORT_RETURN_IF_ERROR(LaunchStridedCopy(
+        hip_stream, key_ptr, kv_shape, kv_strides.ForBNSHCoord(), /*in_seqlens_offset=*/nullptr,
+        present_key_ptr + buffer_offset, present_strides.ForBNSHCoord(), seqlens_k_ptr,
+        max_thr_per_blk));
+    ORT_RETURN_IF_ERROR(LaunchStridedCopy(
+        hip_stream, value_ptr, kv_shape, kv_strides.ForBNSHCoord(), /*in_seqlens_offset=*/nullptr,
+        present_value_ptr + buffer_offset, present_strides.ForBNSHCoord(), seqlens_k_ptr,
+        max_thr_per_blk));
+
+    // NOTE: ORT: seqlens_k Indicates past sequence lengths for token generation case.
+    // we should call fmha with total sequence lengths
+    seqlens_k_tmp = GetScratchBuffer<int>(shape_seqlen_k * sizeof(int), ctx->GetComputeStream());
+    ORT_RETURN_IF_ERROR(LaunchSeqlensInc(hip_stream, seqlens_k_ptr, seqlens_k_tmp.get(), shape_seqlen_k, sequence_length));
+    seqlens_k_ptr = seqlens_k_tmp.get();
+  }
+  static_assert(std::is_same_v<ck_tile::index_t, int32_t>);
+
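+  // Illustrative note (annotation, not part of the upstream code; values are
+  // hypothetical): ck_tile group mode consumes variable-length sequences via
+  // seqstart/seqlen arrays instead of a dense [B, S] view. With batch_size = 3
+  // and a present KV buffer holding max_seqlen = 8 tokens per sequence,
+  // LaunchSeqStartInit below yields
+  //   seqstart_k = {0, 8, 16, 24}
+  // while seqlens_k (after LaunchSeqlensInc has added the new tokens) holds the
+  // real per-sequence totals, e.g. {5, 8, 3}, so batch entry b is read from
+  // rows [seqstart_k[b], seqstart_k[b] + seqlens_k[b]) of the cache.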
+  const float scale = parameters.scale == 0.0f ? 1.f / sqrt(static_cast<float>(parameters.head_size)) : parameters.scale;
+  // TODO:
+  mask_enum mask_type = mask_enum::no_mask;
+  bias_enum bias_type = bias_enum::no_bias;
+  mask_info mask = mask_info::decode("0", sequence_length, kv_sequence_length);
+
+  auto seqstart_q_tmp = GetScratchBuffer<int>((shape_seqlen_q + 1) * sizeof(int), ctx->GetComputeStream());
+  auto seqstart_k_tmp = GetScratchBuffer<int>((shape_seqlen_k + 1) * sizeof(int), ctx->GetComputeStream());
+  ORT_RETURN_IF_ERROR(LaunchSeqStartInit(hip_stream, seqstart_q_tmp.get(), shape_seqlen_q,
+                                         query_strides.strides_for_bnsh_coord.x / query_strides.strides_for_bnsh_coord.z));
+  ORT_RETURN_IF_ERROR(LaunchSeqStartInit(hip_stream, seqstart_k_tmp.get(), shape_seqlen_k,
+                                         present_strides.strides_for_bnsh_coord.x / present_strides.strides_for_bnsh_coord.z));
+
+  fmha_fwd_args args{
+      query->DataRaw(),
+      present_key->DataRaw(),
+      present_value->DataRaw(),
+      nullptr,  // bias, alibi/element
+      nullptr,  // lse, logsumexp buffer
+      output->MutableDataRaw(),
+      seqstart_q_tmp.get(),  // seqstart_q_ptr
+      seqstart_k_tmp.get(),  // seqstart_k_ptr
+      seqlens_k_ptr,         // seqlen_k_ptr
+      shape_seqlen_q,        // seqlen_q
+      shape_seqlen_k,        // seqlen_k
+      parameters.batch_size,       // batch
+      parameters.sequence_length,  // max_seqlen_q
+      parameters.head_size,        // hdim_q
+      parameters.head_size,        // hdim_v
+      parameters.num_heads,
+      parameters.kv_num_heads,
+      scale,
+      1.0f,  // scale_p of squant, useless
+      1.0f,  // scale_o of squant, useless
+      static_cast<ck_tile::index_t>(query_strides.strides_for_bnsh_coord.z),    // stride_q, to be regarded as stride of dim S
+      static_cast<ck_tile::index_t>(present_strides.strides_for_bnsh_coord.z),  // stride_k, to be regarded as stride of dim S
+      static_cast<ck_tile::index_t>(present_strides.strides_for_bnsh_coord.z),  // stride_v, to be regarded as stride of dim S
+      shape_seqlen_k,  // stride_bias, if alibi, b*h need set this to h, 1*h need set this to 0
+      static_cast<ck_tile::index_t>(output_strides.strides_for_bnsh_coord.z),   // stride_o, to be regarded as stride of dim S
+      static_cast<ck_tile::index_t>(query_strides.strides_for_bnsh_coord.y),    // nhead_stride_q, to be regarded as stride of dim N
+      static_cast<ck_tile::index_t>(present_strides.strides_for_bnsh_coord.y),  // nhead_stride_k, to be regarded as stride of dim N
+      static_cast<ck_tile::index_t>(present_strides.strides_for_bnsh_coord.y),  // nhead_stride_v, to be regarded as stride of dim N
+      0,               // nhead_stride_bias
+      shape_seqlen_q,  // nhead_stride_lse
+      static_cast<ck_tile::index_t>(output_strides.strides_for_bnsh_coord.y),   // nhead_stride_o, to be regarded as stride of dim N
+      static_cast<ck_tile::index_t>(query_strides.strides_for_bnsh_coord.x),    // batch_stride_q, to be regarded as stride of dim B
+      static_cast<ck_tile::index_t>(present_strides.strides_for_bnsh_coord.x),  // batch_stride_k, to be regarded as stride of dim B
+      static_cast<ck_tile::index_t>(present_strides.strides_for_bnsh_coord.x),  // batch_stride_v, to be regarded as stride of dim B
+      0,                           // batch_stride_bias
+      num_heads * shape_seqlen_q,  // batch_stride_lse
+      static_cast<ck_tile::index_t>(output_strides.strides_for_bnsh_coord.x),   // batch_stride_o, to be regarded as stride of dim B
+      mask.left,   // window_size_left
+      mask.right,  // window_size_right
+      static_cast<ck_tile::index_t>(mask.type)};
+
+  std::cout << "\n seqlen_past_kv_cache:" << parameters.seqlen_past_kv_cache
+            << "\n seqlen_present_kv_cache:" << parameters.seqlen_present_kv_cache << std::endl;
+
+  std::cout
+      << "\n q_ptr:" << args.q_ptr
+      << "\n k_ptr:" << args.k_ptr
+      << "\n v_ptr:" << args.v_ptr
+      << "\n bias_ptr:" << args.bias_ptr
+      << "\n lse_ptr:" << args.lse_ptr
+      << "\n o_ptr:" << args.o_ptr
+      << "\n seqstart_q_ptr:" << args.seqstart_q_ptr
+      << "\n seqstart_k_ptr:" << args.seqstart_k_ptr
+      
<< "\n seqlen_k_ptr:" << args.seqlen_k_ptr + << "\n seqlen_q:" << args.seqlen_q + << "\n seqlen_k:" << args.seqlen_k + << "\n batch:" << args.batch + << "\n max_seqlen_q:" << args.max_seqlen_q + << "\n hdim_q:" << args.hdim_q + << "\n hdim_v:" << args.hdim_v + << "\n nhead_q:" << args.nhead_q + << "\n nhead_k:" << args.nhead_k + << "\n scale_s:" << args.scale_s + << "\n scale_p:" << args.scale_p + << "\n scale_o:" << args.scale_o + << "\n stride_q:" << args.stride_q + << "\n stride_k:" << args.stride_k + << "\n stride_v:" << args.stride_v + << "\n stride_bias:" << args.stride_bias + << "\n stride_o:" << args.stride_o + << "\n nhead_stride_q:" << args.nhead_stride_q + << "\n nhead_stride_k:" << args.nhead_stride_k + << "\n nhead_stride_v:" << args.nhead_stride_v + << "\n nhead_stride_bias:" << args.nhead_stride_bias + << "\n nhead_stride_lse:" << args.nhead_stride_lse + << "\n nhead_stride_o:" << args.nhead_stride_o + << "\n batch_stride_q:" << args.batch_stride_q + << "\n batch_stride_k:" << args.batch_stride_k + << "\n batch_stride_v:" << args.batch_stride_v + << "\n batch_stride_bias:" << args.batch_stride_bias + << "\n batch_stride_lse:" << args.batch_stride_lse + << "\n batch_stride_o:" << args.batch_stride_o + << "\n window_size_left:" << args.window_size_left + << "\n window_size_right:" << args.window_size_right + << "\n mask_type:" << args.mask_type + << std::endl; + + fmha_fwd_traits traits{ + parameters.head_size, + parameters.head_size, // v head size + GetCkFmhaDataTypeString(), + true, // is_group_mode + true, // is_v_rowmajor ? dim is fastest : seq is fastest + mask_type, + bias_type, + false, // has_lse + false, // do_fp8_static_quant, aka, squant + }; + + ck_tile::stream_config stream_config{ + hip_stream, + false // time_kernel + }; + + auto duration = fmha_fwd(traits, args, stream_config); + if (duration < 0) { + return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, "fmha_fwd internal error"); + } + HIP_RETURN_IF_ERROR(hipGetLastError()); + + return Status::OK(); +} + +} // namespace rocm +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/rocm/bert/group_query_attention.h b/onnxruntime/contrib_ops/rocm/bert/group_query_attention.h new file mode 100644 index 0000000000000..4d40b5049a8ee --- /dev/null +++ b/onnxruntime/contrib_ops/rocm/bert/group_query_attention.h @@ -0,0 +1,34 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+ +#pragma once + +#include +#include "core/providers/rocm/rocm_kernel.h" + +namespace onnxruntime { +namespace contrib { +namespace rocm { + +using namespace onnxruntime::rocm; + +template +class GroupQueryAttention final : public RocmKernel { + public: + GroupQueryAttention(const OpKernelInfo& info); + Status ComputeInternal(OpKernelContext* context) const override; + + protected: + int num_heads_; // number of attention heads + int kv_num_heads_; // different for k and v for group query attention + int local_window_size_; + bool is_unidirectional_; + bool is_past_bsnh_; + bool do_rotary_; + bool rotary_interleaved_; + float scale_; +}; + +} // namespace rocm +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/rocm/rocm_contrib_kernels.cc b/onnxruntime/contrib_ops/rocm/rocm_contrib_kernels.cc index 7e5e7d7ee076d..4284b4254f485 100644 --- a/onnxruntime/contrib_ops/rocm/rocm_contrib_kernels.cc +++ b/onnxruntime/contrib_ops/rocm/rocm_contrib_kernels.cc @@ -71,6 +71,8 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, MLFloat16, Crop); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, MultiHeadAttention); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, MultiHeadAttention); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, GroupQueryAttention); +// class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, BFloat16, GroupQueryAttention); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, DecoderAttention); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, DecoderAttention); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, int32_t, DynamicSlice); @@ -227,6 +229,8 @@ Status RegisterRocmContribKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, + // BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, diff --git a/onnxruntime/test/python/transformers/test_flash_attn_rocm.py b/onnxruntime/test/python/transformers/test_flash_attn_rocm.py new file mode 100644 index 0000000000000..628372300aa65 --- /dev/null +++ b/onnxruntime/test/python/transformers/test_flash_attn_rocm.py @@ -0,0 +1,81 @@ +import platform +import unittest + +import torch +from parameterized import parameterized +from test_flash_attn_cuda import ( + Formats, + gqa_no_past_flash_attention_test_cases, + gqa_past_flash_attention_test_cases, + parity_check_gqa_past, + parity_check_gqa_past_no_buff, + parity_check_gqa_prompt, + parity_check_gqa_prompt_no_buff, +) + + +class TestGQA(unittest.TestCase): + @parameterized.expand(gqa_no_past_flash_attention_test_cases()) + def test_gqa_no_past_flash_attention(self, _, config, local, rotary, rotary_interleaved, packed): + config.ep = "ROCMExecutionProvider" + if not torch.cuda.is_available(): + return + if platform.system() != "Linux": + return + print("------- FLASH ATTENTION (PROMPT CASE) --------") + + parity_check_gqa_prompt( + config, + local=local, + past_format=Formats.BNSH, + rotary=rotary, + rotary_interleaved=rotary_interleaved, + packed=packed, + rtol=0.002, + atol=0.002, + ) + parity_check_gqa_prompt_no_buff( + config, + local=local, + 
past_format=Formats.BNSH,
+            rotary=rotary,
+            rotary_interleaved=rotary_interleaved,
+            packed=packed,
+            rtol=0.002,
+            atol=0.002,
+        )
+
+    @parameterized.expand(gqa_past_flash_attention_test_cases())
+    def test_gqa_past_flash_attention(self, _, config, local, rotary, rotary_interleaved, packed):
+        config.ep = "ROCMExecutionProvider"
+        if not torch.cuda.is_available():
+            return
+        if platform.system() != "Linux":
+            return
+        print("------- FLASH ATTENTION (TOKEN GEN) -------")
+
+        parity_check_gqa_past(
+            config,
+            local=local,
+            past_format=Formats.BNSH,
+            rotary=rotary,
+            rotary_interleaved=rotary_interleaved,
+            packed=packed,
+            rtol=0.002,
+            atol=0.002,
+        )
+        parity_check_gqa_past_no_buff(
+            config,
+            local=local,
+            past_format=Formats.BNSH,
+            rotary=rotary,
+            rotary_interleaved=rotary_interleaved,
+            packed=packed,
+            rtol=0.002,
+            atol=0.002,
+        )
+
+
+if __name__ == "__main__":
+    unittest.main()

From f845099da4a8af5988414ee02896d5e0898afe61 Mon Sep 17 00:00:00 2001
From: Guangyun Han
Date: Thu, 6 Jun 2024 10:59:28 +0000
Subject: [PATCH 02/18] feat: extend strided copy to support runtime tok idx

---
 .../contrib_ops/cuda/bert/attention_impl.h    |  6 ++
 .../cuda/bert/attention_strided_copy.cu       | 58 ++++++++++++++-----
 .../contrib_ops/rocm/bert/attention_impl.h    |  6 ++
 3 files changed, 54 insertions(+), 16 deletions(-)

diff --git a/onnxruntime/contrib_ops/cuda/bert/attention_impl.h b/onnxruntime/contrib_ops/cuda/bert/attention_impl.h
index 36fd7708de04b..2269ed143e929 100644
--- a/onnxruntime/contrib_ops/cuda/bert/attention_impl.h
+++ b/onnxruntime/contrib_ops/cuda/bert/attention_impl.h
@@ -176,6 +176,12 @@ Status LaunchAddBiasTransAppendKvToPresent(cudaStream_t stream,
                                            const T* qkv_buffer,
                                            T* present);
 
+template <typename T>
+Status LaunchStridedCopy(cudaStream_t stream,
+                         const T* in, int4 in_shape, longlong4 in_strides, const int* in_seqlens_offset,  // coord (b,n,s,h)
+                         T* out, longlong4 out_strides, const int* out_seqlens_offset,                    // coord (b,n,s,h)
+                         int max_threads_per_block);
+
 template <typename T>
 Status LaunchStridedCopy(cudaStream_t stream,
                          const T* in, int4 in_shape, longlong4 in_strides,  // coord (b,n,s,h)
diff --git a/onnxruntime/contrib_ops/cuda/bert/attention_strided_copy.cu b/onnxruntime/contrib_ops/cuda/bert/attention_strided_copy.cu
index 1466f5fcfe0be..2530398504ac6 100644
--- a/onnxruntime/contrib_ops/cuda/bert/attention_strided_copy.cu
+++ b/onnxruntime/contrib_ops/cuda/bert/attention_strided_copy.cu
@@ -12,23 +12,27 @@ namespace cuda {
 
 template <typename T>
 __global__ void StridedCopy(const T* in, const int H, longlong4 in_strides,  // coord (b,n,s,h)
-                            T* out, longlong4 out_strides                    // coord (b,n,s,h)
-) {
+                            T* out, longlong4 out_strides,                   // coord (b,n,s,h)
+                            const int32_t* in_seqlens_offset, const int32_t* out_seqlens_offset) {
   const int h = threadIdx.x;
   const int n = threadIdx.y;
   const int s = blockIdx.x;
   const int b = blockIdx.y;
+
+  const int s_offset_i = in_seqlens_offset == nullptr ? 0 : in_seqlens_offset[b];
+  const int s_offset_o = out_seqlens_offset == nullptr ? 
0 : out_seqlens_offset[b]; + if (h < H) { - const int in_offset = b * in_strides.x + n * in_strides.y + s * in_strides.z + h * in_strides.w; - const int out_offset = b * out_strides.x + n * out_strides.y + s * out_strides.z + h * out_strides.w; + const int in_offset = b * in_strides.x + n * in_strides.y + (s + s_offset_i) * in_strides.z + h * in_strides.w; + const int out_offset = b * out_strides.x + n * out_strides.y + (s + s_offset_o) * out_strides.z + h * out_strides.w; out[out_offset] = in[in_offset]; } } template __global__ void StridedCopyLarge(const T* in, const int H, longlong4 in_strides, // coord (b,n,s,h) - T* out, longlong4 out_strides // coord (b,n,s,h) -) { + T* out, longlong4 out_strides, // coord (b,n,s,h) + const int* in_seqlens_offset, const int* out_seqlens_offset) { // Use when (H*)*num_heads > 1024 int h = threadIdx.x; const int n = threadIdx.y; @@ -37,9 +41,12 @@ __global__ void StridedCopyLarge(const T* in, const int H, longlong4 in_strides, const int h_step = blockDim.x; + const int s_offset_i = in_seqlens_offset == nullptr ? 0 : in_seqlens_offset[b]; + const int s_offset_o = out_seqlens_offset == nullptr ? 0 : out_seqlens_offset[b]; + while (h < H) { - const int in_offset = b * in_strides.x + n * in_strides.y + s * in_strides.z + h * in_strides.w; - const int out_offset = b * out_strides.x + n * out_strides.y + s * out_strides.z + h * out_strides.w; + const int in_offset = b * in_strides.x + n * in_strides.y + (s + s_offset_i) * in_strides.z + h * in_strides.w; + const int out_offset = b * out_strides.x + n * out_strides.y + (s + s_offset_o) * out_strides.z + h * out_strides.w; out[out_offset] = in[in_offset]; h += h_step; } @@ -78,8 +85,8 @@ using ToBytes = typename ToByteType::T; template Status LaunchStridedCopy(cudaStream_t stream, - const T* in, int4 in_shape, longlong4 in_strides, // coord (b,n,s,h) - T* out, longlong4 out_strides, // coord (b,n,s,h) + const T* in, int4 in_shape, longlong4 in_strides, const int* in_seqlens_offset, // coord (b,n,s,h) + T* out, longlong4 out_strides, const int* out_seqlens_offset, // coord (b,n,s,h) int max_threads_per_block) { int batch_size = in_shape.x; int num_heads = in_shape.y; @@ -102,11 +109,13 @@ Status LaunchStridedCopy(cudaStream_t stream, if (H * num_heads <= max_threads_per_block) { const dim3 block(H, num_heads, 1); StridedCopy<<>>(reinterpret_cast(in), H, in_strides, - reinterpret_cast(out), out_strides); + reinterpret_cast(out), out_strides, + in_seqlens_offset, out_seqlens_offset); } else { const dim3 block(max_threads_per_block / num_heads, num_heads, 1); StridedCopyLarge<<>>(reinterpret_cast(in), H, in_strides, - reinterpret_cast(out), out_strides); + reinterpret_cast(out), out_strides, + in_seqlens_offset, out_seqlens_offset); } } else if (0 == (head_size % 2)) { // pack 2 element together using Bytes = ToBytes; @@ -120,27 +129,44 @@ Status LaunchStridedCopy(cudaStream_t stream, if (H * num_heads <= max_threads_per_block) { const dim3 block(H, num_heads, 1); StridedCopy<<>>(reinterpret_cast(in), H, in_strides, - reinterpret_cast(out), out_strides); + reinterpret_cast(out), out_strides, + in_seqlens_offset, out_seqlens_offset); } else { const dim3 block(max_threads_per_block / num_heads, num_heads, 1); StridedCopyLarge<<>>(reinterpret_cast(in), H, in_strides, - reinterpret_cast(out), out_strides); + reinterpret_cast(out), out_strides, + in_seqlens_offset, out_seqlens_offset); } } else { using Bytes = ToBytes; if (head_size * num_heads <= max_threads_per_block) { const dim3 block(head_size, num_heads, 1); 
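      // Thread-block mapping used by these launches (annotation added for
      // clarity, not part of the original patch): given the blockIdx usage in
      // the kernels above, the grid picks one (s, b) pair per block while the
      // block assigns one thread per (h, n) element; the optional per-batch
      // seqlens offsets then shift s before the strided address is formed,
      // which is how a runtime token index lands in the right KV-cache row.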
StridedCopy<<>>(reinterpret_cast(in), head_size, in_strides, - reinterpret_cast(out), out_strides); + reinterpret_cast(out), out_strides, + in_seqlens_offset, out_seqlens_offset); } else { const dim3 block(max_threads_per_block / num_heads, num_heads, 1); StridedCopyLarge<<>>(reinterpret_cast(in), head_size, in_strides, - reinterpret_cast(out), out_strides); + reinterpret_cast(out), out_strides, + in_seqlens_offset, out_seqlens_offset); } } return CUDA_CALL(cudaGetLastError()); } +template +Status LaunchStridedCopy(cudaStream_t stream, + const T* in, int4 in_shape, longlong4 in_strides, // coord (b,n,s,h) + T* out, longlong4 out_strides, // coord (b,n,s,h) + int max_threads_per_block) { + const int* in_seqlens_offset = nullptr; + const int* out_seqlens_offset = nullptr; + return LaunchStridedCopy( + stream, in, in_shape, in_strides, in_seqlens_offset, + out, out_strides, out_seqlens_offset, + max_threads_per_block); +} + template Status LaunchStridedCopy( cudaStream_t stream, const float* in, int4 in_shape, longlong4 in_strides, diff --git a/onnxruntime/contrib_ops/rocm/bert/attention_impl.h b/onnxruntime/contrib_ops/rocm/bert/attention_impl.h index 3164e8c211099..5e333c60d5968 100644 --- a/onnxruntime/contrib_ops/rocm/bert/attention_impl.h +++ b/onnxruntime/contrib_ops/rocm/bert/attention_impl.h @@ -169,6 +169,12 @@ Status ClassifyAttentionMode(AttentionType type, const std::vector& past, const std::vector& present); +template +Status LaunchStridedCopy(hipStream_t stream, + const T* in, int4 in_shape, longlong4 in_strides, const int* in_seqlens_offset, // coord (b,n,s,h) + T* out, longlong4 out_strides, const int* out_seqlens_offset, // coord (b,n,s,h) + int max_threads_per_block); + template Status LaunchStridedCopy(hipStream_t stream, const T* in, int4 in_shape, longlong4 in_strides, // coord (b,n,s,h) From 0ea33352b7a347227ab7f232a205bc7e52cc7001 Mon Sep 17 00:00:00 2001 From: Guangyun Han Date: Fri, 7 Jun 2024 08:24:59 +0000 Subject: [PATCH 03/18] more case --- .../rocm/bert/group_query_attention.cu | 60 +++++++++++-------- 1 file changed, 35 insertions(+), 25 deletions(-) diff --git a/onnxruntime/contrib_ops/rocm/bert/group_query_attention.cu b/onnxruntime/contrib_ops/rocm/bert/group_query_attention.cu index cc88b66bea390..c707c78fe5e86 100644 --- a/onnxruntime/contrib_ops/rocm/bert/group_query_attention.cu +++ b/onnxruntime/contrib_ops/rocm/bert/group_query_attention.cu @@ -214,18 +214,15 @@ Status GroupQueryAttention::ComputeInternal(OpKernelContext* ctx) const { // build present kv cache auto* present_key_ptr = reinterpret_cast(present_key->MutableDataRaw()); - ck_tile::index_t shape_seqlen_q = batch_size; - ck_tile::index_t shape_seqlen_k = batch_size; - auto* present_value_ptr = reinterpret_cast(present_value->MutableDataRaw()); int4 kv_shape{batch_size, kv_num_heads, kv_sequence_length, head_size}; auto kv_strides = Strides::BSNHMemory(batch_size, kv_sequence_length, kv_num_heads, head_size); if (parameters.is_prompt) { // copy prompt kv to present kv - // ORT_RETURN_IF_ERROR(LaunchStridedCopy(hip_stream, key_ptr, kv_shape, kv_strides.ForBNSHCoord(), - // present_key_ptr, present_strides.ForBNSHCoord(), max_thr_per_blk)); - // ORT_RETURN_IF_ERROR(LaunchStridedCopy(hip_stream, value_ptr, kv_shape, kv_strides.ForBNSHCoord(), - // present_value_ptr, present_strides.ForBNSHCoord(), max_thr_per_blk)); + ORT_RETURN_IF_ERROR(LaunchStridedCopy(hip_stream, key_ptr, kv_shape, kv_strides.ForBNSHCoord(), + present_key_ptr, present_strides.ForBNSHCoord(), max_thr_per_blk)); + 
ORT_RETURN_IF_ERROR(LaunchStridedCopy(hip_stream, value_ptr, kv_shape, kv_strides.ForBNSHCoord(),
+                                          present_value_ptr, present_strides.ForBNSHCoord(), max_thr_per_blk));
   } else {
     const auto* past_key_ptr = reinterpret_cast<const HipT*>(past_key->DataRaw());
     const auto* past_value_ptr = reinterpret_cast<const HipT*>(past_value->DataRaw());
@@ -255,23 +252,36 @@ Status GroupQueryAttention<T>::ComputeInternal(OpKernelContext* ctx) const {
 
     // NOTE: ORT: seqlens_k Indicates past sequence lengths for token generation case.
     // we should call fmha with total sequence lengths
-    seqlens_k_tmp = GetScratchBuffer<int>(shape_seqlen_k * sizeof(int), ctx->GetComputeStream());
-    ORT_RETURN_IF_ERROR(LaunchSeqlensInc(hip_stream, seqlens_k_ptr, seqlens_k_tmp.get(), shape_seqlen_k, sequence_length));
+    seqlens_k_tmp = GetScratchBuffer<int>(batch_size * sizeof(int), ctx->GetComputeStream());
+    ORT_RETURN_IF_ERROR(LaunchSeqlensInc(hip_stream, seqlens_k_ptr, seqlens_k_tmp.get(), batch_size, sequence_length));
     seqlens_k_ptr = seqlens_k_tmp.get();
   }
   static_assert(std::is_same_v<ck_tile::index_t, int32_t>);
 
   const float scale = parameters.scale == 0.0f ? 1.f / sqrt(static_cast<float>(parameters.head_size)) : parameters.scale;
   // TODO:
-  mask_enum mask_type = mask_enum::no_mask;
   bias_enum bias_type = bias_enum::no_bias;
-  mask_info mask = mask_info::decode("0", sequence_length, kv_sequence_length);
+
+  mask_enum mask_type;
+  mask_info mask;
+  if(local_window_size_ != -1) {
+    ORT_NOT_IMPLEMENTED("local_window_size support is not implemented");
+  }
+  if (parameters.is_prompt) {
+    if (is_unidirectional_) {
+      mask_type = mask_enum::mask_top_left;
+      mask = mask_info::decode("t", sequence_length, kv_sequence_length);
+    }
+  } else {
+    mask_type = mask_enum::no_mask;
+    mask = mask_info::decode("0", sequence_length, kv_sequence_length);
+  }
 
-  auto seqstart_q_tmp = GetScratchBuffer<int>((shape_seqlen_q + 1) * sizeof(int), ctx->GetComputeStream());
-  auto seqstart_k_tmp = GetScratchBuffer<int>((shape_seqlen_k + 1) * sizeof(int), ctx->GetComputeStream());
-  ORT_RETURN_IF_ERROR(LaunchSeqStartInit(hip_stream, seqstart_q_tmp.get(), shape_seqlen_q,
+  auto seqstart_q_tmp = GetScratchBuffer<int>((batch_size + 1) * sizeof(int), ctx->GetComputeStream());
+  auto seqstart_k_tmp = GetScratchBuffer<int>((batch_size + 1) * sizeof(int), ctx->GetComputeStream());
+  ORT_RETURN_IF_ERROR(LaunchSeqStartInit(hip_stream, seqstart_q_tmp.get(), batch_size,
                                          query_strides.strides_for_bnsh_coord.x / query_strides.strides_for_bnsh_coord.z));
-  ORT_RETURN_IF_ERROR(LaunchSeqStartInit(hip_stream, seqstart_k_tmp.get(), shape_seqlen_k,
+  ORT_RETURN_IF_ERROR(LaunchSeqStartInit(hip_stream, seqstart_k_tmp.get(), batch_size,
                                          present_strides.strides_for_bnsh_coord.x / present_strides.strides_for_bnsh_coord.z));
 
   fmha_fwd_args args{
       query->DataRaw(),
       present_key->DataRaw(),
       present_value->DataRaw(),
       nullptr,  // bias, alibi/element
       nullptr,  // lse, logsumexp buffer
       output->MutableDataRaw(),
-      seqstart_q_tmp.get(),  // seqstart_q_ptr
-      seqstart_k_tmp.get(),  // seqstart_k_ptr
-      seqlens_k_ptr,         // seqlen_k_ptr
-      shape_seqlen_q,        // seqlen_q
-      shape_seqlen_k,        // seqlen_k
+      seqstart_q_tmp.get(),  // seqstart_q_ptr, for group mode
+      seqstart_k_tmp.get(),  // seqstart_k_ptr, for group mode
+      seqlens_k_ptr,         // seqlen_k_ptr, for group mode
+      sequence_length,       // seqlen_q, for batch mode
+      kv_sequence_length,    // seqlen_k, for batch mode
       parameters.batch_size,       // batch
       parameters.sequence_length,  // max_seqlen_q
       parameters.head_size,        // hdim_q
       parameters.head_size,        // hdim_v
       parameters.num_heads,
       parameters.kv_num_heads,
       scale,
       1.0f,  // scale_p of squant, useless
       1.0f,  // scale_o of squant, useless
static_cast(query_strides.strides_for_bnsh_coord.z), // stride_q, to be regarded as stride of dim S static_cast(present_strides.strides_for_bnsh_coord.z), // stride_k, to be regarded as stride of dim S static_cast(present_strides.strides_for_bnsh_coord.z), // stride_v, to be regarded as stride of dim S - shape_seqlen_k, // stride_bias, if alibi, b*h need set this to h, 1*h need set this to 0 + batch_size, // stride_bias, if alibi, b*h need set this to h, 1*h need set this to 0 static_cast(output_strides.strides_for_bnsh_coord.z), // stride_o, to be regarded as stride of dim S static_cast(query_strides.strides_for_bnsh_coord.y), // nhead_stride_q, to be regarded as stride of dim N static_cast(present_strides.strides_for_bnsh_coord.y), // nhead_stride_k, to be regarded as stride of dim N static_cast(present_strides.strides_for_bnsh_coord.y), // nhead_stride_v, to be regarded as stride of dim N 0, // nhead_stride_bias - shape_seqlen_q, // nhead_stride_lse + batch_size, // nhead_stride_lse static_cast(output_strides.strides_for_bnsh_coord.y), // batch_stride_o, to be regarded as stride of dim B static_cast(query_strides.strides_for_bnsh_coord.x), // batch_stride_q, to be regarded as stride of dim B static_cast(present_strides.strides_for_bnsh_coord.x), // batch_stride_k, to be regarded as stride of dim B static_cast(present_strides.strides_for_bnsh_coord.x), // batch_stride_v, to be regarded as stride of dim B 0, // batch_stride_bias - num_heads * shape_seqlen_q, // batch_stride_lse + num_heads * batch_size, // batch_stride_lse static_cast(output_strides.strides_for_bnsh_coord.x), // batch_stride_o, to be regarded as stride of dim B mask.left, // window_size_left mask.right, // window_size_right @@ -366,8 +376,8 @@ Status GroupQueryAttention::ComputeInternal(OpKernelContext* ctx) const { parameters.head_size, parameters.head_size, // v head size GetCkFmhaDataTypeString(), - true, // is_group_mode - true, // is_v_rowmajor ? dim is fastest : seq is fastest + !parameters.is_prompt, // true, // is_group_mode + true, // is_v_rowmajor ? dim is fastest : seq is fastest mask_type, bias_type, false, // has_lse From 99b2feb61888a0c22efb75547a2aece857d481a2 Mon Sep 17 00:00:00 2001 From: Guangyun Han Date: Tue, 11 Jun 2024 15:09:07 +0000 Subject: [PATCH 04/18] feat: local --- .../rocm/bert/group_query_attention.cu | 31 ++++++++++--------- 1 file changed, 17 insertions(+), 14 deletions(-) diff --git a/onnxruntime/contrib_ops/rocm/bert/group_query_attention.cu b/onnxruntime/contrib_ops/rocm/bert/group_query_attention.cu index c707c78fe5e86..cc11db5cad18a 100644 --- a/onnxruntime/contrib_ops/rocm/bert/group_query_attention.cu +++ b/onnxruntime/contrib_ops/rocm/bert/group_query_attention.cu @@ -262,20 +262,23 @@ Status GroupQueryAttention::ComputeInternal(OpKernelContext* ctx) const { // TODO: bias_enum bias_type = bias_enum::no_bias; - mask_enum mask_type; - mask_info mask; - if(local_window_size_ != -1) { - ORT_NOT_IMPLEMENTED("local_window_size support is not implemented"); - } - if (parameters.is_prompt) { - if (is_unidirectional_) { - mask_type = mask_enum::mask_top_left; - mask = mask_info::decode("t", sequence_length, kv_sequence_length); + mask_info mask = [&]() { + if (local_window_size_ != -1) { + mask_info ret; + ret.type = mask_enum::window_generic; + ret.left = local_window_size_; + ret.right = parameters.is_unidirectional ? 
0 : -1; + // ret.x = kv_sequence_length - (sequence_length - ret.left); + // ret.y = sequence_length + (ret.right - kv_sequence_length); + return ret; } - } else { - mask_type = mask_enum::no_mask; - mask = mask_info::decode("0", sequence_length, kv_sequence_length); - } + + if (parameters.is_prompt && is_unidirectional_) { + return mask_info::decode("t", sequence_length, kv_sequence_length); + } + + return mask_info::decode("0", sequence_length, kv_sequence_length); + }(); auto seqstart_q_tmp = GetScratchBuffer((batch_size + 1) * sizeof(int), ctx->GetComputeStream()); auto seqstart_k_tmp = GetScratchBuffer((batch_size + 1) * sizeof(int), ctx->GetComputeStream()); @@ -378,7 +381,7 @@ Status GroupQueryAttention::ComputeInternal(OpKernelContext* ctx) const { GetCkFmhaDataTypeString(), !parameters.is_prompt, // true, // is_group_mode true, // is_v_rowmajor ? dim is fastest : seq is fastest - mask_type, + mask.type, bias_type, false, // has_lse false, // do_fp8_static_quant, aka, squant From 816249c83a151ef68774a67d48b71057d868c874 Mon Sep 17 00:00:00 2001 From: Guangyun Han Date: Wed, 12 Jun 2024 05:46:43 +0000 Subject: [PATCH 05/18] feat: rotary --- .../rocm/bert/group_query_attention.cu | 45 +++++++++++++++++-- 1 file changed, 41 insertions(+), 4 deletions(-) diff --git a/onnxruntime/contrib_ops/rocm/bert/group_query_attention.cu b/onnxruntime/contrib_ops/rocm/bert/group_query_attention.cu index cc11db5cad18a..84cecb17e4c62 100644 --- a/onnxruntime/contrib_ops/rocm/bert/group_query_attention.cu +++ b/onnxruntime/contrib_ops/rocm/bert/group_query_attention.cu @@ -194,7 +194,7 @@ Status GroupQueryAttention::ComputeInternal(OpKernelContext* ctx) const { Tensor* present_value = ctx->Output(2, present_shape); Strides query_strides; - const void* query_ptr = query->DataRaw(); + const HipT* query_ptr = reinterpret_cast(query->DataRaw()); const HipT* key_ptr; const HipT* value_ptr; if (!parameters.is_packed_qkv) { @@ -205,8 +205,45 @@ Status GroupQueryAttention::ComputeInternal(OpKernelContext* ctx) const { query_strides = Strides::BSNHMemory(batch_size, sequence_length, num_heads + 2 * kv_num_heads, head_size); const size_t key_offset = static_cast(num_heads * head_size); const size_t value_offset = static_cast(kv_num_heads * head_size); - key_ptr = reinterpret_cast(query_ptr) + key_offset; - value_ptr = reinterpret_cast(key_ptr) + value_offset; + key_ptr = query_ptr + key_offset; + value_ptr = key_ptr + value_offset; + } + + IAllocatorUniquePtr rotary_q_tmp; + IAllocatorUniquePtr rotary_k_tmp; + if (parameters.do_rotary) { + size_t q_size = static_cast(batch_size * sequence_length * num_heads * head_size); + size_t k_size = static_cast(batch_size * sequence_length * kv_num_heads * head_size); + // auto q_buffer = reinterpret_cast(data.rotary_buffer); + // auto k_buffer = q_buffer + q_size; + rotary_q_tmp = GetScratchBuffer(q_size, ctx->GetComputeStream()); + rotary_k_tmp = GetScratchBuffer(k_size, ctx->GetComputeStream()); + auto rotary_position_ids_tmp = GetScratchBuffer(sequence_length * batch_size, ctx->GetComputeStream()); + ORT_RETURN_IF_ERROR(LaunchSeqlensToPosIds(parameters, + reinterpret_cast(seqlens_k->DataRaw()), + reinterpret_cast(rotary_position_ids_tmp.get()), + hip_stream, max_thr_per_blk)); + // Launch rotary embedding kernel + ORT_RETURN_IF_ERROR(LaunchRotaryEmbeddingKernel(hip_stream, rotary_q_tmp.get(), query_ptr, + reinterpret_cast(rotary_position_ids_tmp.get()), + reinterpret_cast(cos_cache->DataRaw()), + reinterpret_cast(sin_cache->DataRaw()), + parameters.batch_size, 
parameters.sequence_length, + parameters.num_heads, parameters.head_size, + parameters.rotary_dim, parameters.seqlen_present_kv_cache, + /*position_ids_format*/ 1, parameters.rotary_interleaved, + max_thr_per_blk, /*transposed*/ false)); + ORT_RETURN_IF_ERROR(LaunchRotaryEmbeddingKernel(hip_stream, rotary_k_tmp.get(), key_ptr, + reinterpret_cast(rotary_position_ids_tmp.get()), + reinterpret_cast(cos_cache->DataRaw()), + reinterpret_cast(sin_cache->DataRaw()), + parameters.batch_size, parameters.sequence_length, + parameters.kv_num_heads, parameters.head_size, + parameters.rotary_dim, parameters.seqlen_present_kv_cache, + /*position_ids_format*/ 1, parameters.rotary_interleaved, + max_thr_per_blk, /*transposed*/ false)); + query_ptr = reinterpret_cast(rotary_q_tmp.get()); + key_ptr = reinterpret_cast(rotary_k_tmp.get()); } const int* seqlens_k_ptr = seqlens_k ? reinterpret_cast(seqlens_k->DataRaw()) : nullptr; @@ -288,7 +325,7 @@ Status GroupQueryAttention::ComputeInternal(OpKernelContext* ctx) const { present_strides.strides_for_bnsh_coord.x / present_strides.strides_for_bnsh_coord.z)); fmha_fwd_args args{ - query->DataRaw(), + query_ptr, present_key->DataRaw(), present_value->DataRaw(), nullptr, // bias, alibi/element From 6024dc960ad95c23c76196fec1fa3515f565d382 Mon Sep 17 00:00:00 2001 From: Guangyun Han Date: Thu, 13 Jun 2024 01:23:57 +0000 Subject: [PATCH 06/18] feat: allow rotary to read and write in a strided way, so that we don't need to explicit unpack the packed qkv tensor --- .../cuda/bert/rotary_embedding_impl.cu | 59 +++++++++++++------ .../cuda/bert/rotary_embedding_impl.h | 21 +++++++ 2 files changed, 62 insertions(+), 18 deletions(-) diff --git a/onnxruntime/contrib_ops/cuda/bert/rotary_embedding_impl.cu b/onnxruntime/contrib_ops/cuda/bert/rotary_embedding_impl.cu index 1b28b288f3d7c..c7fa071c27daf 100644 --- a/onnxruntime/contrib_ops/cuda/bert/rotary_embedding_impl.cu +++ b/onnxruntime/contrib_ops/cuda/bert/rotary_embedding_impl.cu @@ -25,8 +25,9 @@ __global__ void RotaryEmbeddingBSNH(T* output, // BxSxNxH const int64_t* position_ids, // (1) or BxS const int sequence_length, const int num_heads, const int head_size, const int rotary_embedding_dim, const int position_ids_format, - const bool interleaved, const int batch_stride, const int seq_stride, - const int head_stride) { + const bool interleaved, + int4 in_strides, int4 out_strides // strides in bnsh coord, h is always contiguous +) { // B = batch size, S = sequence length, N = num heads, H = head size, M = max sequence length // Use .x in innermost loop to access global memory efficiently @@ -40,10 +41,8 @@ __global__ void RotaryEmbeddingBSNH(T* output, // BxSxNxH return; } - const int block_offset = b * batch_stride + s * seq_stride + n * head_stride; - - const T* input_data = input + block_offset; - T* output_data = output + block_offset; + const T* input_data = input + b * in_strides.x + s * in_strides.z + n * in_strides.y; + T* output_data = output + b * out_strides.x + s * out_strides.z + n * out_strides.y; if (i >= rotary_embedding_dim) { output_data[i] = input_data[i]; @@ -77,34 +76,58 @@ template Status LaunchRotaryEmbeddingKernel(cudaStream_t stream, T* output, const T* input, const int64_t* position_ids, const T* cos_cache, const T* sin_cache, const int batch_size, const int sequence_length, const int num_heads, const int head_size, - const int rotary_embedding_dim, const int /*max_sequence_length*/, + const int rotary_embedding_dim, const int max_sequence_length, const int position_ids_format, const bool 
interleaved, const int max_threads_per_block, const bool is_input_bnsh_format) { + int4 in_strides; + int4 out_strides; + if (is_input_bnsh_format) { + int in_head_stride = sequence_length * head_size; + int out_head_stride = sequence_length * head_size; + in_strides = int4{num_heads * in_head_stride, in_head_stride, in_head_stride / sequence_length, 1}; + out_strides = int4{num_heads * out_head_stride, out_head_stride, out_head_stride / sequence_length, 1}; + } else { + int in_head_stride = head_size; + int out_head_stride = head_size; + in_strides = int4{sequence_length * num_heads * in_head_stride, in_head_stride, num_heads * in_head_stride, 1}; + out_strides = int4{sequence_length * num_heads * out_head_stride, out_head_stride, num_heads * out_head_stride, 1}; + } + return LaunchRotaryEmbeddingKernel( + stream, output, input, position_ids, + cos_cache, sin_cache, batch_size, + sequence_length, num_heads, head_size, + rotary_embedding_dim, max_sequence_length, + position_ids_format, interleaved, + max_threads_per_block, is_input_bnsh_format, + in_strides, out_strides); +} + +template +Status LaunchRotaryEmbeddingKernel(cudaStream_t stream, T* output, const T* input, const int64_t* position_ids, + const T* cos_cache, const T* sin_cache, const int batch_size, + const int sequence_length, const int num_heads, const int head_size, + const int rotary_embedding_dim, const int /*max_sequence_length*/, + const int position_ids_format, const bool interleaved, + const int max_threads_per_block, const bool is_input_bnsh_format, + int4 in_strides, int4 out_strides // strides in bnsh coord +) { // Note: Current implementation assumes head_size <= max_threads_per_block // because head_size is currently large for LLaMA-2. For smaller head_size // and num_heads values, we can create a block as `block(num_heads, head_size, 1)` // instead. This will require kernel changes to support. 
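  // Worked example of the bnsh-coordinate strides consumed here (added note;
  // the values are illustrative, not from the original patch): for a BSNH
  // tensor with B=1, S=4, N=8, H=64 the wrapper above passes
  //   in_strides = {S*N*H, H, N*H, 1} = {2048, 64, 512, 1},
  // i.e. strides for the (b, n, s, h) coordinate with h contiguous, which is
  // exactly what the ORT_ENFORCE below checks.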
ORT_ENFORCE(head_size <= max_threads_per_block, "Rotary embedding dim must be <= max_threads_per_block");
+  // strides in canonical bnsh coord, h is always contiguous (dim_stride == 1)
+  ORT_ENFORCE(in_strides.w == 1 && out_strides.w == 1, "head dim must be contiguous");
 
   int tpb = (head_size + 31) / 32 * 32;
 
   const dim3 block(tpb);
   const dim3 grid(sequence_length, batch_size, num_heads);
 
-  // Default input tensor shape is [batch, seq, hidden_size]
-  int head_stride = head_size;
-  int seq_stride = num_heads * head_stride;
-  int batch_stride = sequence_length * seq_stride;
-  if (is_input_bnsh_format) {
-    seq_stride = head_size;
-    head_stride = sequence_length * seq_stride;
-    batch_stride = num_heads * head_stride;
-  }
-
   assert(head_size <= max_threads_per_block);
   RotaryEmbeddingBSNH<<<grid, block, 0, stream>>>(output, input, cos_cache, sin_cache, position_ids,
                                                   sequence_length, num_heads, head_size, rotary_embedding_dim, position_ids_format,
-                                                  interleaved, batch_stride, seq_stride, head_stride);
+                                                  interleaved, in_strides, out_strides);
 
   return CUDA_CALL(cudaGetLastError());
 }
diff --git a/onnxruntime/contrib_ops/cuda/bert/rotary_embedding_impl.h b/onnxruntime/contrib_ops/cuda/bert/rotary_embedding_impl.h
index 6053814b835bb..c52b6d18141f2 100644
--- a/onnxruntime/contrib_ops/cuda/bert/rotary_embedding_impl.h
+++ b/onnxruntime/contrib_ops/cuda/bert/rotary_embedding_impl.h
@@ -28,6 +28,27 @@ Status LaunchRotaryEmbeddingKernel(
     const int max_threads_per_block,
     const bool is_input_bnsh_format);
 
+template <typename T>
+Status LaunchRotaryEmbeddingKernel(
+    cudaStream_t stream,
+    T* output,
+    const T* input,
+    const int64_t* position_ids,
+    const T* cos_cache,
+    const T* sin_cache,
+    const int batch_size,
+    const int sequence_length,
+    const int num_heads,
+    const int head_size,
+    const int rotary_embedding_dim,
+    const int max_sequence_length,
+    const int position_ids_format,
+    const bool interleaved,
+    const int max_threads_per_block,
+    const bool is_input_bnsh_format,
+    int4 in_strides,
+    int4 out_strides);
+
 }  // namespace cuda
 }  // namespace contrib
 }  // namespace onnxruntime

From 48092eeeedcc8782bc1c807b38b22f0db9809d9b Mon Sep 17 00:00:00 2001
From: Guangyun Han
Date: Thu, 13 Jun 2024 01:27:57 +0000
Subject: [PATCH 07/18] fix: rotary for packed qkv

---
 .../cuda/bert/group_query_attention_impl.cu   |   6 +-
 .../rocm/bert/group_query_attention.cu        | 113 +++++++++++++++---
 2 files changed, 100 insertions(+), 19 deletions(-)

diff --git a/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.cu b/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.cu
index 62974d12003fe..77c85afffb66c 100644
--- a/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.cu
+++ b/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.cu
@@ -577,7 +577,7 @@ Status LaunchUnpackQKV(const T* packed_qkv, T* unpacked_q, T* unpacked_k, T* unp
 }
 
 // Kernel to convert seqlens_k to position_ids
-__global__ void SeqlensToPosIdsPrompt(int32_t* seqlens_k, int64_t* position_ids, const int seqlen,
+__global__ void SeqlensToPosIdsPrompt(const int32_t* seqlens_k, int64_t* position_ids, const int seqlen,
                                       const int batch_size) {
   int tid = blockDim.x * blockIdx.x + threadIdx.x;
   int b = tid / seqlen;
@@ -592,7 +592,7 @@ __global__ void SeqlensToPosIdsPrompt(int32_t* seqlens_k, int64_t* position_ids,
 }
 
 // Kernel to convert seqlens_k to position_ids
-__global__ void SeqlensToPosIdsToken(int32_t* seqlens_k, int64_t* position_ids, const int batch_size) {
+__global__ void SeqlensToPosIdsToken(const int32_t* seqlens_k, int64_t* position_ids, const int 
batch_size) { int tid = blockDim.x * blockIdx.x + threadIdx.x; if (tid < batch_size) { position_ids[tid] = seqlens_k[tid]; @@ -600,7 +600,7 @@ __global__ void SeqlensToPosIdsToken(int32_t* seqlens_k, int64_t* position_ids, } // Convert seqlens_k to position_ids -Status LaunchSeqlensToPosIds(contrib::GroupQueryAttentionParameters& parameters, int32_t* seqlens_k, +Status LaunchSeqlensToPosIds(contrib::GroupQueryAttentionParameters& parameters, const int32_t* seqlens_k, int64_t* position_ids, cudaStream_t stream, const int max_threads_per_block) { const int seqlen = parameters.sequence_length; const int batch_size = parameters.batch_size; diff --git a/onnxruntime/contrib_ops/rocm/bert/group_query_attention.cu b/onnxruntime/contrib_ops/rocm/bert/group_query_attention.cu index 84cecb17e4c62..f0a9cebb0dcc4 100644 --- a/onnxruntime/contrib_ops/rocm/bert/group_query_attention.cu +++ b/onnxruntime/contrib_ops/rocm/bert/group_query_attention.cu @@ -7,6 +7,7 @@ #include "core/platform/env_var_utils.h" #include "contrib_ops/rocm/bert/group_query_attention.h" #include "contrib_ops/rocm/bert/group_query_attention_helper.h" +#include "contrib_ops/rocm/bert/rotary_embedding_impl.h" #include "contrib_ops/rocm/bert/batched_gemm_softmax_gemm_permute_pipelines.cuh" #ifdef USE_COMPOSABLE_KERNEL_CK_TILE @@ -21,6 +22,18 @@ namespace onnxruntime { namespace contrib { namespace rocm { +void print(const std::string& msg, const longlong4& s) { + std::cout << msg << ":" << s.x << "," << s.y << "," << s.z << "," << s.w << std::endl; +} + +void print(const std::string& msg, const int4& s) { + std::cout << msg << ":" << s.x << "," << s.y << "," << s.z << "," << s.w << std::endl; +} + +void print(const std::string& msg, const Strides& s) { + print(msg, s.strides_for_bnsh_coord); +} + #define REGISTER_KERNEL_TYPED(T) \ ONNX_OPERATOR_TYPED_KERNEL_EX( \ GroupQueryAttention, \ @@ -57,7 +70,6 @@ __global__ void seqlens_inc_kernel(const int* seqlens, int* out, int num_elems, int idx = blockDim.x * blockIdx.x + threadIdx.x; if (idx < num_elems) { out[idx] = seqlens[idx] + inc; - printf("inc %d: %d\n", idx, out[idx]); } } @@ -72,7 +84,6 @@ __global__ void seqstart_init_kernel(int* out, int num_elems, int length_per_seq int idx = blockDim.x * blockIdx.x + threadIdx.x; if (idx < num_elems) { out[idx] = idx * length_per_seq; - printf("init %d: %d\n", idx, out[idx]); } if (idx == 0) { out[num_elems] = num_elems * length_per_seq; @@ -86,6 +97,44 @@ Status LaunchSeqStartInit(hipStream_t stream, int* out, int num_elems, int lengt return HIP_CALL(hipGetLastError()); } +// Kernel to convert seqlens_k to position_ids +__global__ void SeqlensToPosIdsPrompt(const int32_t* seqlens_k, int64_t* position_ids, const int seqlen, + const int batch_size) { + int tid = blockDim.x * blockIdx.x + threadIdx.x; + int b = tid / seqlen; + int s = tid % seqlen; + if (b < batch_size) { + if (s < seqlens_k[b] + 1) { + position_ids[tid] = s; + } else { + position_ids[tid] = 1; + } + } +} + +// Kernel to convert seqlens_k to position_ids +__global__ void SeqlensToPosIdsToken(const int32_t* seqlens_k, int64_t* position_ids, const int batch_size) { + int tid = blockDim.x * blockIdx.x + threadIdx.x; + if (tid < batch_size) { + position_ids[tid] = seqlens_k[tid]; + } +} + +// Convert seqlens_k to position_ids +Status LaunchSeqlensToPosIds(contrib::GroupQueryAttentionParameters& parameters, const int32_t* seqlens_k, + int64_t* position_ids, hipStream_t stream, const int max_threads_per_block) { + const int seqlen = parameters.sequence_length; + const int batch_size 
= parameters.batch_size; + const int threads = max_threads_per_block; + const int blocks = (batch_size * seqlen + threads - 1) / threads; + if (parameters.is_prompt) { + SeqlensToPosIdsPrompt<<>>(seqlens_k, position_ids, seqlen, batch_size); + } else { + SeqlensToPosIdsToken<<>>(seqlens_k, position_ids, batch_size); + } + return HIP_CALL(hipGetLastError()); +} + template GroupQueryAttention::GroupQueryAttention(const OpKernelInfo& info) : RocmKernel(info) { @@ -194,19 +243,29 @@ Status GroupQueryAttention::ComputeInternal(OpKernelContext* ctx) const { Tensor* present_value = ctx->Output(2, present_shape); Strides query_strides; + Strides key_strides; + Strides value_strides; + int4 kv_shape{batch_size, kv_num_heads, kv_sequence_length, head_size}; // BNSH coord const HipT* query_ptr = reinterpret_cast(query->DataRaw()); const HipT* key_ptr; const HipT* value_ptr; if (!parameters.is_packed_qkv) { query_strides = Strides::BSNHMemory(batch_size, sequence_length, num_heads, head_size); + key_strides = Strides::BSNHMemory(batch_size, kv_sequence_length, kv_num_heads, head_size); + value_strides = key_strides; key_ptr = reinterpret_cast(key->DataRaw()); value_ptr = reinterpret_cast(value->DataRaw()); } else { query_strides = Strides::BSNHMemory(batch_size, sequence_length, num_heads + 2 * kv_num_heads, head_size); + key_strides = Strides::BSNHMemory(batch_size, sequence_length, num_heads + 2 * kv_num_heads, head_size); + value_strides = query_strides; const size_t key_offset = static_cast(num_heads * head_size); const size_t value_offset = static_cast(kv_num_heads * head_size); key_ptr = query_ptr + key_offset; value_ptr = key_ptr + value_offset; + + print("!is_packed_qkv, key_shape ", kv_shape); + print("!is_packed_qkv, key_strides", key_strides); } IAllocatorUniquePtr rotary_q_tmp; @@ -214,8 +273,12 @@ Status GroupQueryAttention::ComputeInternal(OpKernelContext* ctx) const { if (parameters.do_rotary) { size_t q_size = static_cast(batch_size * sequence_length * num_heads * head_size); size_t k_size = static_cast(batch_size * sequence_length * kv_num_heads * head_size); - // auto q_buffer = reinterpret_cast(data.rotary_buffer); - // auto k_buffer = q_buffer + q_size; + auto rotary_q_strides = Strides::BSNHMemory(batch_size, sequence_length, num_heads, head_size); + auto rotary_k_strides = Strides::BSNHMemory(batch_size, sequence_length, kv_num_heads, head_size); + + print("do_rotary, rotary_k_shape ", int4{batch_size, sequence_length, kv_num_heads, head_size}); + print("do_rotary, rotary_k_strides", rotary_k_strides); + rotary_q_tmp = GetScratchBuffer(q_size, ctx->GetComputeStream()); rotary_k_tmp = GetScratchBuffer(k_size, ctx->GetComputeStream()); auto rotary_position_ids_tmp = GetScratchBuffer(sequence_length * batch_size, ctx->GetComputeStream()); @@ -232,7 +295,9 @@ Status GroupQueryAttention::ComputeInternal(OpKernelContext* ctx) const { parameters.num_heads, parameters.head_size, parameters.rotary_dim, parameters.seqlen_present_kv_cache, /*position_ids_format*/ 1, parameters.rotary_interleaved, - max_thr_per_blk, /*transposed*/ false)); + max_thr_per_blk, /*transposed*/ false, + query_strides.ForBNSHCoord(), + rotary_q_strides.ForBNSHCoord())); ORT_RETURN_IF_ERROR(LaunchRotaryEmbeddingKernel(hip_stream, rotary_k_tmp.get(), key_ptr, reinterpret_cast(rotary_position_ids_tmp.get()), reinterpret_cast(cos_cache->DataRaw()), @@ -241,9 +306,15 @@ Status GroupQueryAttention::ComputeInternal(OpKernelContext* ctx) const { parameters.kv_num_heads, parameters.head_size, parameters.rotary_dim, 
parameters.seqlen_present_kv_cache, /*position_ids_format*/ 1, parameters.rotary_interleaved, - max_thr_per_blk, /*transposed*/ false)); + max_thr_per_blk, /*transposed*/ false, + key_strides.ForBNSHCoord(), + rotary_k_strides.ForBNSHCoord())); query_ptr = reinterpret_cast(rotary_q_tmp.get()); key_ptr = reinterpret_cast(rotary_k_tmp.get()); + query_strides = rotary_q_strides; + key_strides = rotary_k_strides; + print("do_rotary, key_shape ", kv_shape); + print("do_rotary, key_strides", key_strides); } const int* seqlens_k_ptr = seqlens_k ? reinterpret_cast(seqlens_k->DataRaw()) : nullptr; @@ -252,18 +323,22 @@ Status GroupQueryAttention::ComputeInternal(OpKernelContext* ctx) const { // build present kv cache auto* present_key_ptr = reinterpret_cast(present_key->MutableDataRaw()); auto* present_value_ptr = reinterpret_cast(present_value->MutableDataRaw()); - int4 kv_shape{batch_size, kv_num_heads, kv_sequence_length, head_size}; - auto kv_strides = Strides::BSNHMemory(batch_size, kv_sequence_length, kv_num_heads, head_size); if (parameters.is_prompt) { + std::cout << "is_prompt" << std::endl; // copy prompt kv to present kv - ORT_RETURN_IF_ERROR(LaunchStridedCopy(hip_stream, key_ptr, kv_shape, kv_strides.ForBNSHCoord(), + ORT_RETURN_IF_ERROR(LaunchStridedCopy(hip_stream, key_ptr, kv_shape, key_strides.ForBNSHCoord(), present_key_ptr, present_strides.ForBNSHCoord(), max_thr_per_blk)); - ORT_RETURN_IF_ERROR(LaunchStridedCopy(hip_stream, value_ptr, kv_shape, kv_strides.ForBNSHCoord(), + ORT_RETURN_IF_ERROR(LaunchStridedCopy(hip_stream, value_ptr, kv_shape, value_strides.ForBNSHCoord(), present_value_ptr, present_strides.ForBNSHCoord(), max_thr_per_blk)); + print("is_prompt, key_shape ", kv_shape); + print("is_prompt, key_strides", key_strides); } else { - const auto* past_key_ptr = reinterpret_cast(past_key->DataRaw()); - const auto* past_value_ptr = reinterpret_cast(past_value->DataRaw()); + std::cout << "!is_prompt" << std::endl; + const auto* past_key_ptr = past_key == nullptr ? nullptr : reinterpret_cast(past_key->DataRaw()); + const auto* past_value_ptr = past_key == nullptr ? nullptr : reinterpret_cast(past_value->DataRaw()); + parameters.kv_share_buffer = past_key_ptr == present_key_ptr; // FIXME: if (!parameters.kv_share_buffer) { + std::cout << "!kv_share_buffer" << std::endl; // copy past to present, // NOTE: we do a low perf full buffer copy due to the seqlens_k indicate the seqlen of different seqs are // not the same, aka, can not be as simple as strided @@ -272,6 +347,7 @@ Status GroupQueryAttention::ComputeInternal(OpKernelContext* ctx) const { ORT_RETURN_IF_ERROR(LaunchStridedCopy(hip_stream, past_value_ptr, past_shape, past_strides.ForBNSHCoord(), present_value_ptr, present_strides.ForBNSHCoord(), max_thr_per_blk)); } else { + std::cout << "kv_share_buffer" << std::endl; // In the case of share buffer ORT_ENFORCE(past_key_ptr == nullptr || past_key_ptr == present_key_ptr); ORT_ENFORCE(past_key_ptr == nullptr || past_value_ptr == present_value_ptr); @@ -279,11 +355,11 @@ Status GroupQueryAttention::ComputeInternal(OpKernelContext* ctx) const { // then append new kv to present size_t buffer_offset = seqlens_k ? 
0 : present_strides.OffsetAt(0, 0, kv_sequence_length, 0); ORT_RETURN_IF_ERROR(LaunchStridedCopy( - hip_stream, key_ptr, kv_shape, kv_strides.ForBNSHCoord(), /*in_seqlens_offset=*/nullptr, + hip_stream, key_ptr, kv_shape, key_strides.ForBNSHCoord(), /*in_seqlens_offset=*/nullptr, present_key_ptr + buffer_offset, present_strides.ForBNSHCoord(), seqlens_k_ptr, max_thr_per_blk)); ORT_RETURN_IF_ERROR(LaunchStridedCopy( - hip_stream, value_ptr, kv_shape, kv_strides.ForBNSHCoord(), /*in_seqlens_offset=*/nullptr, + hip_stream, value_ptr, kv_shape, value_strides.ForBNSHCoord(), /*in_seqlens_offset=*/nullptr, present_value_ptr + buffer_offset, present_strides.ForBNSHCoord(), seqlens_k_ptr, max_thr_per_blk)); @@ -366,8 +442,12 @@ Status GroupQueryAttention::ComputeInternal(OpKernelContext* ctx) const { mask.right, // window_size_right static_cast(mask.type)}; - std::cout << "\n seqlen_past_kv_cache:" << parameters.seqlen_past_kv_cache - << "\n seqlen_present_kv_cache:" << parameters.seqlen_present_kv_cache << std::endl; +#if 0 + std::cout + << "\n sequence_length:" << sequence_length + << "\n kv_sequence_length:" << kv_sequence_length + << "\n seqlen_past_kv_cache:" << parameters.seqlen_past_kv_cache + << "\n seqlen_present_kv_cache:" << parameters.seqlen_present_kv_cache << std::endl; std::cout << "\n q_ptr:" << args.q_ptr @@ -411,6 +491,7 @@ Status GroupQueryAttention::ComputeInternal(OpKernelContext* ctx) const { << "\n window_size_right:" << args.window_size_right << "\n mask_type:" << args.mask_type << std::endl; +#endif fmha_fwd_traits traits{ parameters.head_size, From de2f30aefdb278a40b079e73016494b938ac0155 Mon Sep 17 00:00:00 2001 From: Guangyun Han Date: Thu, 13 Jun 2024 01:29:21 +0000 Subject: [PATCH 08/18] remove debug print --- .../rocm/bert/group_query_attention.cu | 26 ------------------- 1 file changed, 26 deletions(-) diff --git a/onnxruntime/contrib_ops/rocm/bert/group_query_attention.cu b/onnxruntime/contrib_ops/rocm/bert/group_query_attention.cu index f0a9cebb0dcc4..04130b46b94e8 100644 --- a/onnxruntime/contrib_ops/rocm/bert/group_query_attention.cu +++ b/onnxruntime/contrib_ops/rocm/bert/group_query_attention.cu @@ -22,18 +22,6 @@ namespace onnxruntime { namespace contrib { namespace rocm { -void print(const std::string& msg, const longlong4& s) { - std::cout << msg << ":" << s.x << "," << s.y << "," << s.z << "," << s.w << std::endl; -} - -void print(const std::string& msg, const int4& s) { - std::cout << msg << ":" << s.x << "," << s.y << "," << s.z << "," << s.w << std::endl; -} - -void print(const std::string& msg, const Strides& s) { - print(msg, s.strides_for_bnsh_coord); -} - #define REGISTER_KERNEL_TYPED(T) \ ONNX_OPERATOR_TYPED_KERNEL_EX( \ GroupQueryAttention, \ @@ -263,9 +251,6 @@ Status GroupQueryAttention::ComputeInternal(OpKernelContext* ctx) const { const size_t value_offset = static_cast(kv_num_heads * head_size); key_ptr = query_ptr + key_offset; value_ptr = key_ptr + value_offset; - - print("!is_packed_qkv, key_shape ", kv_shape); - print("!is_packed_qkv, key_strides", key_strides); } IAllocatorUniquePtr rotary_q_tmp; @@ -276,9 +261,6 @@ Status GroupQueryAttention::ComputeInternal(OpKernelContext* ctx) const { auto rotary_q_strides = Strides::BSNHMemory(batch_size, sequence_length, num_heads, head_size); auto rotary_k_strides = Strides::BSNHMemory(batch_size, sequence_length, kv_num_heads, head_size); - print("do_rotary, rotary_k_shape ", int4{batch_size, sequence_length, kv_num_heads, head_size}); - print("do_rotary, rotary_k_strides", rotary_k_strides); - 
rotary_q_tmp = GetScratchBuffer(q_size, ctx->GetComputeStream()); rotary_k_tmp = GetScratchBuffer(k_size, ctx->GetComputeStream()); auto rotary_position_ids_tmp = GetScratchBuffer(sequence_length * batch_size, ctx->GetComputeStream()); @@ -313,8 +295,6 @@ Status GroupQueryAttention::ComputeInternal(OpKernelContext* ctx) const { key_ptr = reinterpret_cast(rotary_k_tmp.get()); query_strides = rotary_q_strides; key_strides = rotary_k_strides; - print("do_rotary, key_shape ", kv_shape); - print("do_rotary, key_strides", key_strides); } const int* seqlens_k_ptr = seqlens_k ? reinterpret_cast(seqlens_k->DataRaw()) : nullptr; @@ -324,21 +304,16 @@ Status GroupQueryAttention::ComputeInternal(OpKernelContext* ctx) const { auto* present_key_ptr = reinterpret_cast(present_key->MutableDataRaw()); auto* present_value_ptr = reinterpret_cast(present_value->MutableDataRaw()); if (parameters.is_prompt) { - std::cout << "is_prompt" << std::endl; // copy prompt kv to present kv ORT_RETURN_IF_ERROR(LaunchStridedCopy(hip_stream, key_ptr, kv_shape, key_strides.ForBNSHCoord(), present_key_ptr, present_strides.ForBNSHCoord(), max_thr_per_blk)); ORT_RETURN_IF_ERROR(LaunchStridedCopy(hip_stream, value_ptr, kv_shape, value_strides.ForBNSHCoord(), present_value_ptr, present_strides.ForBNSHCoord(), max_thr_per_blk)); - print("is_prompt, key_shape ", kv_shape); - print("is_prompt, key_strides", key_strides); } else { - std::cout << "!is_prompt" << std::endl; const auto* past_key_ptr = past_key == nullptr ? nullptr : reinterpret_cast(past_key->DataRaw()); const auto* past_value_ptr = past_key == nullptr ? nullptr : reinterpret_cast(past_value->DataRaw()); parameters.kv_share_buffer = past_key_ptr == present_key_ptr; // FIXME: if (!parameters.kv_share_buffer) { - std::cout << "!kv_share_buffer" << std::endl; // copy past to present, // NOTE: we do a low perf full buffer copy due to the seqlens_k indicate the seqlen of different seqs are // not the same, aka, can not be as simple as strided @@ -347,7 +322,6 @@ Status GroupQueryAttention::ComputeInternal(OpKernelContext* ctx) const { ORT_RETURN_IF_ERROR(LaunchStridedCopy(hip_stream, past_value_ptr, past_shape, past_strides.ForBNSHCoord(), present_value_ptr, present_strides.ForBNSHCoord(), max_thr_per_blk)); } else { - std::cout << "kv_share_buffer" << std::endl; // In the case of share buffer ORT_ENFORCE(past_key_ptr == nullptr || past_key_ptr == present_key_ptr); ORT_ENFORCE(past_key_ptr == nullptr || past_value_ptr == present_value_ptr); From 6c4e612520b7cbbf8bf581297d73bb2a066e0d61 Mon Sep 17 00:00:00 2001 From: Guangyun Han Date: Thu, 13 Jun 2024 09:27:05 +0000 Subject: [PATCH 09/18] workaround: add flash_attn test to ci --- .../orttraining-pai-ci-pipeline.yml | 27 +++++++++++++++++++ 1 file changed, 27 insertions(+) diff --git a/tools/ci_build/github/azure-pipelines/orttraining-pai-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/orttraining-pai-ci-pipeline.yml index 7ada4ee6757c9..001062452644e 100644 --- a/tools/ci_build/github/azure-pipelines/orttraining-pai-ci-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/orttraining-pai-ci-pipeline.yml @@ -255,6 +255,33 @@ jobs: arguments: -n $(Agent.Name) -d $HIP_VISIBLE_DEVICES -r $DRIVER_RENDER displayName: 'Check ROCm Environment' + # TODO: move to use ci_build/build.py driven tests + - task: CmdLine@2 + inputs: + script: |- + docker run --rm \ + --security-opt seccomp=unconfined \ + --shm-size=1024m \ + --device=/dev/kfd \ + --device=/dev/dri/renderD$DRIVER_RENDER \ + --group-add $(video) \ + --group-add 
$(render) \ --user onnxruntimedev \ --volume $(Build.SourcesDirectory):/onnxruntime_src \ --volume $(Build.BinariesDirectory):/build \ -e OPENBLAS_NUM_THREADS=1 \ -e OPENMP_NUM_THREADS=1 \ -e MKL_NUM_THREADS=1 \ -e PYTHONPATH=/build/$(BuildConfig) \ onnxruntimetrainingrocm-cibuild-rocm$(RocmVersion)-test \ /bin/bash -c " set -ex; \ pip install -r /onnxruntime_src/tools/ci_build/requirements-transformers-test.txt; \ pytest /onnxruntime_src/onnxruntime/test/python/transformers/test_flash_attn_rocm.py -v -n 4 --reruns 1" workingDirectory: $(Build.SourcesDirectory) displayName: 'Run transformers tests' condition: succeededOrFailed() - task: CmdLine@2 inputs: script: |- From e9f6d13fb16cbfbf476b238b3d08e521b406fd17 Mon Sep 17 00:00:00 2001 From: Guangyun Han Date: Fri, 14 Jun 2024 06:56:16 +0000 Subject: [PATCH 10/18] add gpu arch checking warning log --- .../rocm/bert/group_query_attention.cu | 21 +++++++++++++++++++ .../rocm/bert/group_query_attention.h | 4 ++++ 2 files changed, 25 insertions(+) diff --git a/onnxruntime/contrib_ops/rocm/bert/group_query_attention.cu b/onnxruntime/contrib_ops/rocm/bert/group_query_attention.cu index 04130b46b94e8..7730b0205b69c 100644 --- a/onnxruntime/contrib_ops/rocm/bert/group_query_attention.cu +++ b/onnxruntime/contrib_ops/rocm/bert/group_query_attention.cu @@ -140,6 +140,12 @@ GroupQueryAttention::GroupQueryAttention(const OpKernelInfo& info) scale_ = info.GetAttrOrDefault("scale", 0.0f); } +template <> +std::once_flag GroupQueryAttention::arch_checking_{}; + +template <> +std::once_flag GroupQueryAttention::arch_checking_{}; + template Status GroupQueryAttention::ComputeInternal(OpKernelContext* ctx) const { auto hip_stream = static_cast(ctx->GetComputeStream()->GetHandle()); @@ -154,6 +160,21 @@ Status GroupQueryAttention::ComputeInternal(OpKernelContext* ctx) const { const Tensor* sin_cache = ctx->Input(8); auto& device_prop = GetDeviceProp(); + std::call_once( + arch_checking_, + [](const hipDeviceProp_t& device_prop) { + if (std::string_view(device_prop.gcnArchName).find("gfx90a") == std::string_view::npos && + std::string_view(device_prop.gcnArchName).find("gfx942") == std::string_view::npos) { + LOGS_DEFAULT(WARNING) + << "GroupQueryAttention currently only supports ck_tile fmha backend which only supports " + << "CDNA2 and CDNA3 archs."; + LOGS_DEFAULT(WARNING) + << "GroupQueryAttention running on an unsupported GPU may result in " + << "hipErrorNoBinaryForGpu or hipErrorSharedObjectInitFailed errors."; + } + }, + device_prop); + GroupQueryAttentionParameters parameters; using HipT = typename ToHipType::MappedType; diff --git a/onnxruntime/contrib_ops/rocm/bert/group_query_attention.h b/onnxruntime/contrib_ops/rocm/bert/group_query_attention.h index 4d40b5049a8ee..ce0de1f761aa5 100644 --- a/onnxruntime/contrib_ops/rocm/bert/group_query_attention.h +++ b/onnxruntime/contrib_ops/rocm/bert/group_query_attention.h @@ -4,6 +4,7 @@ #pragma once #include +#include #include "core/providers/rocm/rocm_kernel.h" namespace onnxruntime { @@ -27,6 +28,9 @@ class GroupQueryAttention final : public RocmKernel { bool do_rotary_; bool rotary_interleaved_; float scale_; + + private: + static std::once_flag arch_checking_; }; } // namespace rocm From 2b0c46ed782561a6c5ad92c6f84bbaaeeab91b05 Mon Sep 17 00:00:00 2001 From: Guangyun Han Date: Thu, 20 Jun 2024 07:26:00 +0000 Subject: [PATCH 11/18] fix: build without ck tile --- onnxruntime/contrib_ops/rocm/bert/group_query_attention.cu | 4 ++++ 1 file changed, 4 insertions(+) diff 
--git a/onnxruntime/contrib_ops/rocm/bert/group_query_attention.cu b/onnxruntime/contrib_ops/rocm/bert/group_query_attention.cu index 7730b0205b69c..9aadb5c5caf88 100644 --- a/onnxruntime/contrib_ops/rocm/bert/group_query_attention.cu +++ b/onnxruntime/contrib_ops/rocm/bert/group_query_attention.cu @@ -148,6 +148,7 @@ std::once_flag GroupQueryAttention::arch_checking_{}; template Status GroupQueryAttention::ComputeInternal(OpKernelContext* ctx) const { +#if USE_COMPOSABLE_KERNEL_CK_TILE auto hip_stream = static_cast(ctx->GetComputeStream()->GetHandle()); const Tensor* query = ctx->Input(0); const Tensor* key = ctx->Input(1); @@ -512,6 +513,9 @@ Status GroupQueryAttention::ComputeInternal(OpKernelContext* ctx) const { HIP_RETURN_IF_ERROR(hipGetLastError()); return Status::OK(); +#else + return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, "GroupQueryAttention requires ck_tiles to be enabled"); +#endif } } // namespace rocm From 6091a69b2dc444ee8ef653d4404b6463baa02ca5 Mon Sep 17 00:00:00 2001 From: Guangyun Han Date: Fri, 28 Jun 2024 08:23:01 +0000 Subject: [PATCH 12/18] test: update CI PyTorch and Triton versions to fix tests that have failed with NaN and Inf in reference values --- tools/ci_build/github/pai/rocm-ci-pipeline-env.Dockerfile | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/tools/ci_build/github/pai/rocm-ci-pipeline-env.Dockerfile b/tools/ci_build/github/pai/rocm-ci-pipeline-env.Dockerfile index 59f6c0ab2136c..b94826ae0e4bc 100644 --- a/tools/ci_build/github/pai/rocm-ci-pipeline-env.Dockerfile +++ b/tools/ci_build/github/pai/rocm-ci-pipeline-env.Dockerfile @@ -77,7 +77,11 @@ RUN ln -sf /usr/lib/x86_64-linux-gnu/libstdc++.so.6 ${CONDA_ENVIRONMENT_PATH}/bi RUN export MAJOR=$(cut -d '.' -f 1 <<< "$ROCM_VERSION") && \ export MINOR=$(cut -d '.' -f 2 <<< "$ROCM_VERSION") && \ export PATCH=$(cut -d '.' 
-f 3 <<< "$ROCM_VERSION") && \ - pip install torch==2.0.1 torchvision==0.15.2 -f https://repo.radeon.com/rocm/manylinux/rocm-rel-${MAJOR}.${MINOR}/ && \ + if (( MAJOR >= 6 )); then \ + pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/rocm${MAJOR}.${MINOR} ; \ + else \ + pip install torch==2.0.1 torchvision==0.15.2 -f https://repo.radeon.com/rocm/manylinux/rocm-rel-${MAJOR}.${MINOR}/ ; \ + fi && \ pip install torch-ort --no-dependencies ##### Install Cupy to decrease CPU utilization From 8ca063405415a0c122e7ab761aa7188127161efb Mon Sep 17 00:00:00 2001 From: Guangyun Han Date: Mon, 1 Jul 2024 04:12:18 +0000 Subject: [PATCH 13/18] format --- .../contrib_ops/cuda/bert/attention_impl.h | 9 +++++---- .../cuda/bert/attention_strided_copy.cu | 9 +++++---- .../contrib_ops/rocm/bert/attention_impl.h | 9 +++++---- .../rocm/bert/group_query_attention.cu | 15 +++++++++------ 4 files changed, 24 insertions(+), 18 deletions(-) diff --git a/onnxruntime/contrib_ops/cuda/bert/attention_impl.h b/onnxruntime/contrib_ops/cuda/bert/attention_impl.h index 2269ed143e929..fda7ac2784129 100644 --- a/onnxruntime/contrib_ops/cuda/bert/attention_impl.h +++ b/onnxruntime/contrib_ops/cuda/bert/attention_impl.h @@ -177,10 +177,11 @@ Status LaunchAddBiasTransAppendKvToPresent(cudaStream_t stream, T* present); template -Status LaunchStridedCopy(cudaStream_t stream, - const T* in, int4 in_shape, longlong4 in_strides, const int* in_seqlens_offset, // coord (b,n,s,h) - T* out, longlong4 out_strides, const int* out_seqlens_offset, // coord (b,n,s,h) - int max_threads_per_block); +Status LaunchStridedCopy( + cudaStream_t stream, + const T* in, int4 in_shape, longlong4 in_strides, const int* in_seqlens_offset, // coord (b,n,s,h) + T* out, longlong4 out_strides, const int* out_seqlens_offset, // coord (b,n,s,h) + int max_threads_per_block); template Status LaunchStridedCopy(cudaStream_t stream, diff --git a/onnxruntime/contrib_ops/cuda/bert/attention_strided_copy.cu b/onnxruntime/contrib_ops/cuda/bert/attention_strided_copy.cu index 2530398504ac6..66e56e701c558 100644 --- a/onnxruntime/contrib_ops/cuda/bert/attention_strided_copy.cu +++ b/onnxruntime/contrib_ops/cuda/bert/attention_strided_copy.cu @@ -84,10 +84,11 @@ template using ToBytes = typename ToByteType::T; template -Status LaunchStridedCopy(cudaStream_t stream, - const T* in, int4 in_shape, longlong4 in_strides, const int* in_seqlens_offset, // coord (b,n,s,h) - T* out, longlong4 out_strides, const int* out_seqlens_offset, // coord (b,n,s,h) - int max_threads_per_block) { +Status LaunchStridedCopy( + cudaStream_t stream, + const T* in, int4 in_shape, longlong4 in_strides, const int* in_seqlens_offset, // coord (b,n,s,h) + T* out, longlong4 out_strides, const int* out_seqlens_offset, // coord (b,n,s,h) + int max_threads_per_block) { int batch_size = in_shape.x; int num_heads = in_shape.y; int sequence_length = in_shape.z; diff --git a/onnxruntime/contrib_ops/rocm/bert/attention_impl.h b/onnxruntime/contrib_ops/rocm/bert/attention_impl.h index 5e333c60d5968..349df045becf2 100644 --- a/onnxruntime/contrib_ops/rocm/bert/attention_impl.h +++ b/onnxruntime/contrib_ops/rocm/bert/attention_impl.h @@ -170,10 +170,11 @@ Status ClassifyAttentionMode(AttentionType type, const std::vector& present); template -Status LaunchStridedCopy(hipStream_t stream, - const T* in, int4 in_shape, longlong4 in_strides, const int* in_seqlens_offset, // coord (b,n,s,h) - T* out, longlong4 out_strides, const int* out_seqlens_offset, // coord (b,n,s,h) - int 
max_threads_per_block); +Status LaunchStridedCopy( + hipStream_t stream, + const T* in, int4 in_shape, longlong4 in_strides, const int* in_seqlens_offset, // coord (b,n,s,h) + T* out, longlong4 out_strides, const int* out_seqlens_offset, // coord (b,n,s,h) + int max_threads_per_block); template Status LaunchStridedCopy(hipStream_t stream, diff --git a/onnxruntime/contrib_ops/rocm/bert/group_query_attention.cu b/onnxruntime/contrib_ops/rocm/bert/group_query_attention.cu index 9aadb5c5caf88..33dd358708078 100644 --- a/onnxruntime/contrib_ops/rocm/bert/group_query_attention.cu +++ b/onnxruntime/contrib_ops/rocm/bert/group_query_attention.cu @@ -367,8 +367,9 @@ Status GroupQueryAttention::ComputeInternal(OpKernelContext* ctx) const { } static_assert(std::is_same_v); - const float scale = parameters.scale == 0.0f ? 1.f / sqrt(static_cast(parameters.head_size)) : parameters.scale; - // TODO: + const float scale = parameters.scale == 0.0f + ? 1.f / sqrt(static_cast(parameters.head_size)) + : parameters.scale; bias_enum bias_type = bias_enum::no_bias; mask_info mask = [&]() { @@ -391,10 +392,12 @@ Status GroupQueryAttention::ComputeInternal(OpKernelContext* ctx) const { auto seqstart_q_tmp = GetScratchBuffer((batch_size + 1) * sizeof(int), ctx->GetComputeStream()); auto seqstart_k_tmp = GetScratchBuffer((batch_size + 1) * sizeof(int), ctx->GetComputeStream()); - ORT_RETURN_IF_ERROR(LaunchSeqStartInit(hip_stream, seqstart_q_tmp.get(), batch_size, - query_strides.strides_for_bnsh_coord.x / query_strides.strides_for_bnsh_coord.z)); - ORT_RETURN_IF_ERROR(LaunchSeqStartInit(hip_stream, seqstart_k_tmp.get(), batch_size, - present_strides.strides_for_bnsh_coord.x / present_strides.strides_for_bnsh_coord.z)); + ORT_RETURN_IF_ERROR(LaunchSeqStartInit( + hip_stream, seqstart_q_tmp.get(), batch_size, + query_strides.strides_for_bnsh_coord.x / query_strides.strides_for_bnsh_coord.z)); + ORT_RETURN_IF_ERROR(LaunchSeqStartInit( + hip_stream, seqstart_k_tmp.get(), batch_size, + present_strides.strides_for_bnsh_coord.x / present_strides.strides_for_bnsh_coord.z)); fmha_fwd_args args{ query_ptr, From e22dfb9e1965c87af9a46dd4cd132f61b9423fa3 Mon Sep 17 00:00:00 2001 From: Guangyun Han Date: Mon, 1 Jul 2024 04:17:34 +0000 Subject: [PATCH 14/18] remove unused param is_input_bnsh_format from strided version LaunchRotaryEmbeddingKernel --- onnxruntime/contrib_ops/cuda/bert/rotary_embedding_impl.cu | 4 ++-- onnxruntime/contrib_ops/cuda/bert/rotary_embedding_impl.h | 1 - onnxruntime/contrib_ops/rocm/bert/group_query_attention.cu | 4 ++-- 3 files changed, 4 insertions(+), 5 deletions(-) diff --git a/onnxruntime/contrib_ops/cuda/bert/rotary_embedding_impl.cu b/onnxruntime/contrib_ops/cuda/bert/rotary_embedding_impl.cu index c7fa071c27daf..316fdfb25f45c 100644 --- a/onnxruntime/contrib_ops/cuda/bert/rotary_embedding_impl.cu +++ b/onnxruntime/contrib_ops/cuda/bert/rotary_embedding_impl.cu @@ -98,7 +98,7 @@ Status LaunchRotaryEmbeddingKernel(cudaStream_t stream, T* output, const T* inpu sequence_length, num_heads, head_size, rotary_embedding_dim, max_sequence_length, position_ids_format, interleaved, - max_threads_per_block, is_input_bnsh_format, + max_threads_per_block, in_strides, out_strides); } @@ -108,7 +108,7 @@ Status LaunchRotaryEmbeddingKernel(cudaStream_t stream, T* output, const T* inpu const int sequence_length, const int num_heads, const int head_size, const int rotary_embedding_dim, const int /*max_sequence_length*/, const int position_ids_format, const bool interleaved, - const int max_threads_per_block, const 
bool is_input_bnsh_format, + const int max_threads_per_block, int4 in_strides, int4 out_strides // strides in bnsh coord ) { // Note: Current implementation assumes head_size <= max_threads_per_block diff --git a/onnxruntime/contrib_ops/cuda/bert/rotary_embedding_impl.h b/onnxruntime/contrib_ops/cuda/bert/rotary_embedding_impl.h index c52b6d18141f2..dd0ac6a6e3274 100644 --- a/onnxruntime/contrib_ops/cuda/bert/rotary_embedding_impl.h +++ b/onnxruntime/contrib_ops/cuda/bert/rotary_embedding_impl.h @@ -45,7 +45,6 @@ Status LaunchRotaryEmbeddingKernel( const int position_ids_format, const bool interleaved, const int max_threads_per_block, - const bool is_input_bnsh_format, int4 in_strides, int4 out_strides); diff --git a/onnxruntime/contrib_ops/rocm/bert/group_query_attention.cu b/onnxruntime/contrib_ops/rocm/bert/group_query_attention.cu index 33dd358708078..e53b3e3a67140 100644 --- a/onnxruntime/contrib_ops/rocm/bert/group_query_attention.cu +++ b/onnxruntime/contrib_ops/rocm/bert/group_query_attention.cu @@ -299,7 +299,7 @@ Status GroupQueryAttention::ComputeInternal(OpKernelContext* ctx) const { parameters.num_heads, parameters.head_size, parameters.rotary_dim, parameters.seqlen_present_kv_cache, /*position_ids_format*/ 1, parameters.rotary_interleaved, - max_thr_per_blk, /*transposed*/ false, + max_thr_per_blk, query_strides.ForBNSHCoord(), rotary_q_strides.ForBNSHCoord())); ORT_RETURN_IF_ERROR(LaunchRotaryEmbeddingKernel(hip_stream, rotary_k_tmp.get(), key_ptr, reinterpret_cast(rotary_position_ids_tmp.get()), reinterpret_cast(cos_cache->DataRaw()), @@ -310,7 +310,7 @@ Status GroupQueryAttention::ComputeInternal(OpKernelContext* ctx) const { parameters.kv_num_heads, parameters.head_size, parameters.rotary_dim, parameters.seqlen_present_kv_cache, /*position_ids_format*/ 1, parameters.rotary_interleaved, - max_thr_per_blk, /*transposed*/ false, + max_thr_per_blk, key_strides.ForBNSHCoord(), rotary_k_strides.ForBNSHCoord())); query_ptr = reinterpret_cast(rotary_q_tmp.get()); From 789fee737ca2cd74dd55d437571d079d0e3e805f Mon Sep 17 00:00:00 2001 From: Guangyun Han Date: Mon, 1 Jul 2024 10:32:18 +0000 Subject: [PATCH 15/18] make onnxruntime_USE_COMPOSABLE_KERNEL_CK_TILE depend on onnxruntime_USE_COMPOSABLE_KERNEL --- cmake/CMakeLists.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cmake/CMakeLists.txt b/cmake/CMakeLists.txt index ce22def914851..4dd4cd7d34bbf 100644 --- a/cmake/CMakeLists.txt +++ b/cmake/CMakeLists.txt @@ -240,7 +240,7 @@ option(onnxruntime_ENABLE_TRITON "Enable Triton" OFF) # composable kernel is managed automatically, unless user want to explicitly disable it, it should not be manually set option(onnxruntime_USE_COMPOSABLE_KERNEL "Enable composable kernel for ROCm EP" ON) -option(onnxruntime_USE_COMPOSABLE_KERNEL_CK_TILE "Enable ck_tile for composable kernel" ON) +cmake_dependent_option(onnxruntime_USE_COMPOSABLE_KERNEL_CK_TILE "Enable ck_tile for composable kernel" ON "onnxruntime_USE_COMPOSABLE_KERNEL" OFF) option(onnxruntime_USE_ROCBLAS_EXTENSION_API "Enable rocblas tuning for ROCm EP" OFF) option(onnxruntime_USE_TRITON_KERNEL "Enable triton compiled kernel" OFF) option(onnxruntime_BUILD_KERNEL_EXPLORER "Build Kernel Explorer for testing and profiling GPU kernels" OFF) From b973217f8868681e15f990b79f2414355f660326 Mon Sep 17 00:00:00 2001 From: Guangyun Han Date: Mon, 1 Jul 2024 10:56:16 +0000 Subject: [PATCH 16/18] skip test_flash_attn_rocm on CUDA platform --- .../test/python/transformers/test_flash_attn_rocm.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git 
a/onnxruntime/test/python/transformers/test_flash_attn_rocm.py b/onnxruntime/test/python/transformers/test_flash_attn_rocm.py index 628372300aa65..b476edb3b9cd5 100644 --- a/onnxruntime/test/python/transformers/test_flash_attn_rocm.py +++ b/onnxruntime/test/python/transformers/test_flash_attn_rocm.py @@ -2,6 +2,7 @@ import unittest import torch +import onnxruntime from parameterized import parameterized from test_flash_attn_cuda import ( Formats, @@ -22,6 +23,8 @@ def test_gqa_no_past_flash_attention(self, _, config, local, rotary, rotary_inte return if platform.system() != "Linux": return + if "CUDAExecutionProvider" in onnxruntime.get_available_providers(): + return print("------- FLASH ATTENTION (PROMPT CASE) --------") parity_check_gqa_prompt( @@ -50,9 +53,10 @@ def test_gqa_past_flash_attention(self, _, config, local, rotary, rotary_interle config.ep = "ROCMExecutionProvider" if not torch.cuda.is_available(): return - config.ep = "ROCMExecutionProvider" if platform.system() != "Linux": return + if "CUDAExecutionProvider" in onnxruntime.get_available_providers(): + return print("------- FLASH ATTENTION (TOKEN GEN) -------") parity_check_gqa_past( From c3c7089da14c8e3461e8cbbfa7940030aa50e2b9 Mon Sep 17 00:00:00 2001 From: Guangyun Han Date: Tue, 2 Jul 2024 00:46:39 +0000 Subject: [PATCH 17/18] fix lint and ci --- onnxruntime/contrib_ops/rocm/bert/group_query_attention.cu | 2 +- onnxruntime/test/python/transformers/test_flash_attn_rocm.py | 3 ++- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/onnxruntime/contrib_ops/rocm/bert/group_query_attention.cu b/onnxruntime/contrib_ops/rocm/bert/group_query_attention.cu index e53b3e3a67140..10e62a2a3d70e 100644 --- a/onnxruntime/contrib_ops/rocm/bert/group_query_attention.cu +++ b/onnxruntime/contrib_ops/rocm/bert/group_query_attention.cu @@ -2,7 +2,6 @@ // Licensed under the MIT License. 
#include "core/providers/shared_library/provider_api.h" -#include "ck_tile/core/numeric/integer.hpp" #include "core/providers/rocm/rocm_common.h" #include "core/platform/env_var_utils.h" #include "contrib_ops/rocm/bert/group_query_attention.h" @@ -11,6 +10,7 @@ #include "contrib_ops/rocm/bert/batched_gemm_softmax_gemm_permute_pipelines.cuh" #ifdef USE_COMPOSABLE_KERNEL_CK_TILE +#include "ck_tile/core/numeric/integer.hpp" #include "fmha_fwd.hpp" #endif diff --git a/onnxruntime/test/python/transformers/test_flash_attn_rocm.py b/onnxruntime/test/python/transformers/test_flash_attn_rocm.py index b476edb3b9cd5..fe7e39722237f 100644 --- a/onnxruntime/test/python/transformers/test_flash_attn_rocm.py +++ b/onnxruntime/test/python/transformers/test_flash_attn_rocm.py @@ -2,7 +2,6 @@ import unittest import torch -import onnxruntime from parameterized import parameterized from test_flash_attn_cuda import ( Formats, @@ -14,6 +13,8 @@ parity_check_gqa_prompt_no_buff, ) +import onnxruntime + class TestGQA(unittest.TestCase): @parameterized.expand(gqa_no_past_flash_attention_test_cases()) From f4355d46d5b33aef3233a88fa2ea15fab29b122c Mon Sep 17 00:00:00 2001 From: Guangyun Han Date: Tue, 2 Jul 2024 00:48:45 +0000 Subject: [PATCH 18/18] fix typo --- onnxruntime/contrib_ops/cuda/bert/rotary_embedding_impl.cu | 2 +- onnxruntime/contrib_ops/rocm/bert/group_query_attention.cu | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/onnxruntime/contrib_ops/cuda/bert/rotary_embedding_impl.cu b/onnxruntime/contrib_ops/cuda/bert/rotary_embedding_impl.cu index 316fdfb25f45c..ad0a83c9cde65 100644 --- a/onnxruntime/contrib_ops/cuda/bert/rotary_embedding_impl.cu +++ b/onnxruntime/contrib_ops/cuda/bert/rotary_embedding_impl.cu @@ -116,7 +116,7 @@ Status LaunchRotaryEmbeddingKernel(cudaStream_t stream, T* output, const T* inpu // and num_heads values, we can create a block as `block(num_heads, head_size, 1)` // instead. This will require kernel changes to support. ORT_ENFORCE(head_size <= max_threads_per_block, "Rotary embedding dim must be <= max_threads_per_block"); - // strides in cannoical bnsh coord, h is always contiguous (dim_stride == 1) + // strides in canonical bnsh coord, h is always contiguous (dim_stride == 1) ORT_ENFORCE(in_strides.w == 1 && out_strides.w == 1, "head dim must contiguous"); int tpb = (head_size + 31) / 32 * 32; diff --git a/onnxruntime/contrib_ops/rocm/bert/group_query_attention.cu b/onnxruntime/contrib_ops/rocm/bert/group_query_attention.cu index 10e62a2a3d70e..92c780d4a9d41 100644 --- a/onnxruntime/contrib_ops/rocm/bert/group_query_attention.cu +++ b/onnxruntime/contrib_ops/rocm/bert/group_query_attention.cu @@ -517,7 +517,7 @@ Status GroupQueryAttention::ComputeInternal(OpKernelContext* ctx) const { return Status::OK(); #else - return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, "GroupQueryAttention requires ck_tiles to be enabled"); + return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, "GroupQueryAttention requires ck_tile to be enabled"); #endif }
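
Appendix (illustrative, not part of the patch series): the SeqlensToPosIdsPrompt and SeqlensToPosIdsToken kernels introduced above convert seqlens_k into rotary position ids element-wise on device. The following host-side C++ sketch restates the same mapping serially for reference; the function name and the std::vector-based interface are hypothetical, chosen only for this illustration.

#include <cstdint>
#include <vector>

// Serial reference of the device kernels: one output entry per (batch, position).
std::vector<int64_t> SeqlensToPosIds(const std::vector<int32_t>& seqlens_k,
                                     int sequence_length, bool is_prompt) {
  const int batch_size = static_cast<int>(seqlens_k.size());
  if (!is_prompt) {
    // Token generation: the single new token of sequence b sits at position
    // seqlens_k[b], i.e. the current past length.
    return std::vector<int64_t>(seqlens_k.begin(), seqlens_k.end());
  }
  // Prompt: positions 0..seqlens_k[b] are real tokens; padded slots receive
  // the fixed dummy position 1, matching SeqlensToPosIdsPrompt.
  std::vector<int64_t> position_ids(static_cast<size_t>(batch_size) * sequence_length);
  for (int b = 0; b < batch_size; ++b) {
    for (int s = 0; s < sequence_length; ++s) {
      position_ids[static_cast<size_t>(b) * sequence_length + s] =
          (s < seqlens_k[b] + 1) ? s : 1;
    }
  }
  return position_ids;
}

The device kernels compute the same values with one thread per output element (tid = b * sequence_length + s in the prompt case, tid = b in the token case), launched over ceil(batch_size * sequence_length / max_threads_per_block) blocks.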