diff --git a/cmake/onnxruntime_rocm_hipify.cmake b/cmake/onnxruntime_rocm_hipify.cmake
index 9bc2bdd208a92..4140eeee0d111 100644
--- a/cmake/onnxruntime_rocm_hipify.cmake
+++ b/cmake/onnxruntime_rocm_hipify.cmake
@@ -94,6 +94,11 @@ set(contrib_ops_excluded_files
"cuda_contrib_kernels.h"
"inverse.cc"
"fused_conv.cc"
+ "bert/group_query_attention_helper.h"
+ "bert/group_query_attention.h"
+ "bert/group_query_attention.cc"
+ "bert/group_query_attention_impl.h"
+ "bert/group_query_attention_impl.cu"
)
if (NOT onnxruntime_ENABLE_ATEN)
diff --git a/docs/ContribOperators.md b/docs/ContribOperators.md
index ed1049b0bd73a..8e86862a62e7d 100644
--- a/docs/ContribOperators.md
+++ b/docs/ContribOperators.md
@@ -2422,14 +2422,14 @@ This version of the operator has been available since version 1 of the 'com.micr
When buffered past_key and past_value are used (present_key uses same tensor as past_key), it is required to specify past_sequence_length (could be 0). Otherwise, past_sequence_length is inferred from past_key.
-#### Outputs (1 - 3)
+#### Outputs
- output : T
- 3D output tensor with shape (batch_size, sequence_length, hidden_size)
-- present_key (optional) : T
+- present_key : T
- present state key with support for format BSNH or BNSH. When past_key uses same tensor as present_key(k-v buffer), it is of length max_sequence_length... otherwise of length past_sequence_length +kv_sequence_length.
-- present_value (optional) : T
+- present_value : T
- present state value with support for format BSNH or BNSH. When past_value uses same tensor as present_value(k-v buffer), it is of length max_sequence_length... otherwise of length past_sequence_length +kv_sequence_length.
diff --git a/onnxruntime/contrib_ops/cuda/bert/attention_impl.cu b/onnxruntime/contrib_ops/cuda/bert/attention_impl.cu
index eb9e6d5c62467..16ce3a899fb5e 100644
--- a/onnxruntime/contrib_ops/cuda/bert/attention_impl.cu
+++ b/onnxruntime/contrib_ops/cuda/bert/attention_impl.cu
@@ -374,6 +374,7 @@ Status EfficientAttention(
p.num_heads = parameters.num_heads;
p.sequence_length = parameters.sequence_length;
p.kv_sequence_length = parameters.total_sequence_length;
+ p.max_sequence_length = parameters.total_sequence_length;
p.qk_head_size = parameters.head_size;
p.v_head_size = parameters.v_head_size;
p.causal = parameters.is_unidirectional;
@@ -395,6 +396,7 @@ Status EfficientAttention(
p.attn_bias = nullptr == data.relative_position_bias ? nullptr : data.relative_position_bias;
p.is_attn_bias_batched = !parameters.broadcast_res_pos_bias;
p.output = data.output;
+ p.is_kv_bsnh = true;
p.workspace = MemoryEfficientAttentionParams::need_workspace(parameters.v_head_size, sizeof(T) == sizeof(float))
? data.scratch
: nullptr;
diff --git a/onnxruntime/contrib_ops/cuda/bert/cutlass_fmha/fmha_launch_template.h b/onnxruntime/contrib_ops/cuda/bert/cutlass_fmha/fmha_launch_template.h
index ed330b0fca332..51c3d3d3a458b 100644
--- a/onnxruntime/contrib_ops/cuda/bert/cutlass_fmha/fmha_launch_template.h
+++ b/onnxruntime/contrib_ops/cuda/bert/cutlass_fmha/fmha_launch_template.h
@@ -51,25 +51,45 @@ void LaunchCutlassFmha(const MemoryEfficientAttentionParams& params) {
p.num_keys = params.kv_sequence_length;
if (params.causal) {
- p.custom_mask_type = Attention::CausalFromTopLeft;
+ p.custom_mask_type = Attention::CausalFromBottomRight;
}
- // Input format is BxSxNxH, output is BxSxNxH
- p.q_strideH = params.qk_head_size;
- p.k_strideH = params.qk_head_size;
- p.v_strideH = params.v_head_size;
- p.bias_strideH = nullptr == params.attn_bias ? 0 : p.num_queries * p.num_keys;
-
- p.q_strideM = params.num_heads * params.qk_head_size;
- p.k_strideM = params.num_heads * params.qk_head_size;
- p.v_strideM = params.num_heads * params.v_head_size;
- p.o_strideM = params.num_heads * params.v_head_size;
- p.bias_strideM = nullptr == params.attn_bias ? 0 : p.num_keys;
-
-  p.q_strideB = static_cast<int64_t>(p.q_strideM) * params.sequence_length;
-  p.k_strideB = static_cast<int64_t>(p.k_strideM) * params.kv_sequence_length;
-  p.v_strideB = static_cast<int64_t>(p.v_strideM) * params.kv_sequence_length;
-  p.bias_strideB = params.is_attn_bias_batched ? static_cast<int64_t>(p.bias_strideH) * params.num_heads : 0;
+ // We use max_sequence_length to calculate KV stride
+ if (params.is_kv_bsnh) {
+ // Input Q, K, V format is BxSxNxH, output is BxSxNxH
+ p.q_strideH = params.qk_head_size;
+ p.k_strideH = params.qk_head_size;
+ p.v_strideH = params.v_head_size;
+ p.bias_strideH = nullptr == params.attn_bias ? 0 : p.num_queries * p.num_keys;
+
+ p.q_strideM = params.num_heads * params.qk_head_size;
+ p.k_strideM = params.num_heads * params.qk_head_size;
+ p.v_strideM = params.num_heads * params.v_head_size;
+ p.o_strideM = params.num_heads * params.v_head_size;
+ p.bias_strideM = nullptr == params.attn_bias ? 0 : p.num_keys;
+
+    p.q_strideB = static_cast<int64_t>(p.q_strideM) * params.sequence_length;
+    p.k_strideB = static_cast<int64_t>(p.k_strideM) * params.max_sequence_length;
+    p.v_strideB = static_cast<int64_t>(p.v_strideM) * params.max_sequence_length;
+    p.bias_strideB = params.is_attn_bias_batched ? static_cast<int64_t>(p.bias_strideH) * params.num_heads : 0;
+ } else {
+ // Input K, V format is BxNxSxH, Input Q is BxSxNxH, output is BxSxNxH
+ p.q_strideH = params.qk_head_size;
+ p.k_strideH = params.max_sequence_length * params.qk_head_size;
+ p.v_strideH = params.max_sequence_length * params.v_head_size;
+ p.bias_strideH = nullptr == params.attn_bias ? 0 : p.num_queries * p.num_keys;
+
+ p.q_strideM = params.num_heads * params.qk_head_size;
+ p.k_strideM = params.qk_head_size;
+ p.v_strideM = params.v_head_size;
+ p.o_strideM = params.num_heads * params.v_head_size;
+ p.bias_strideM = nullptr == params.attn_bias ? 0 : p.num_keys;
+
+ p.q_strideB = params.num_heads * params.qk_head_size * params.sequence_length;
+ p.k_strideB = params.num_heads * params.qk_head_size * params.max_sequence_length;
+ p.v_strideB = params.num_heads * params.v_head_size * params.max_sequence_length;
+    p.bias_strideB = params.is_attn_bias_batched ? static_cast<int64_t>(p.bias_strideH) * params.num_heads : 0;
+ }
}
  constexpr auto kernel_fn = attention_kernel_batched_impl<Attention>;
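The BNSH branch above is what lets a buffered KV cache (physical length `max_sequence_length`, logical length `num_keys`) be addressed without a transpose. A minimal stand-alone sketch, not part of this patch and using illustrative names only, that reproduces those stride choices and checks they address the same elements as direct BNSH indexing:

```cpp
#include <cassert>
#include <cstdint>

int main() {
  const int64_t num_heads = 4, head_size = 16;
  const int64_t kv_len = 5;   // logical keys, i.e. p.num_keys
  const int64_t max_seq = 8;  // physical buffer length, i.e. params.max_sequence_length

  // BNSH strides as set in the else-branch above.
  const int64_t k_strideH = max_seq * head_size;
  const int64_t k_strideM = head_size;
  const int64_t k_strideB = num_heads * head_size * max_seq;

  for (int64_t b = 0; b < 2; ++b)
    for (int64_t n = 0; n < num_heads; ++n)
      for (int64_t s = 0; s < kv_len; ++s)
        for (int64_t h = 0; h < head_size; ++h) {
          // Offset the fmha kernel would compute from the strides.
          const int64_t via_strides = b * k_strideB + s * k_strideM + n * k_strideH + h;
          // Direct B x N x S_max x H indexing of the same element.
          const int64_t direct = ((b * num_heads + n) * max_seq + s) * head_size + h;
          assert(via_strides == direct);
        }
  return 0;
}
```

The same reasoning explains the BSNH branch, where only `k_strideB`/`v_strideB` pick up `max_sequence_length`.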
diff --git a/onnxruntime/contrib_ops/cuda/bert/cutlass_fmha/memory_efficient_attention.h b/onnxruntime/contrib_ops/cuda/bert/cutlass_fmha/memory_efficient_attention.h
index f725be8d7cf89..f16567bb6f2b7 100644
--- a/onnxruntime/contrib_ops/cuda/bert/cutlass_fmha/memory_efficient_attention.h
+++ b/onnxruntime/contrib_ops/cuda/bert/cutlass_fmha/memory_efficient_attention.h
@@ -14,10 +14,12 @@ namespace cuda {
struct MemoryEfficientAttentionParams {
int32_t sm;
bool is_half;
+ bool is_kv_bsnh = true;
int32_t batch_size;
int32_t num_heads;
int32_t sequence_length;
int32_t kv_sequence_length;
+ int32_t max_sequence_length;
int32_t qk_head_size;
int32_t v_head_size;
bool causal;
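The two new fields are only meaningful together: `kv_sequence_length` stays the logical number of keys, `max_sequence_length` is the physical row count of the KV buffer, and `is_kv_bsnh` records its layout. A small sketch (hypothetical mirror struct, not the ORT header) of how a non-buffered caller, like the packed attention paths patched below, would fill them:

```cpp
#include <cstdint>

// Hypothetical mirror of the fields relevant here, not the full ORT struct.
struct KvLayoutParams {
  bool is_kv_bsnh = true;           // BSNH unless a BNSH KV cache is passed
  int32_t kv_sequence_length = 0;   // logical number of keys/values
  int32_t max_sequence_length = 0;  // physical rows per batch in the KV buffer
};

// Non-buffered callers keep the defaults: BSNH layout and a buffer exactly as
// long as the logical KV length, so strides degenerate to the old behavior.
KvLayoutParams MakeUnbuffered(int32_t total_seq_len) {
  KvLayoutParams p;
  p.is_kv_bsnh = true;
  p.kv_sequence_length = total_seq_len;
  p.max_sequence_length = total_seq_len;
  return p;
}

int main() {
  KvLayoutParams p = MakeUnbuffered(128);
  return p.max_sequence_length == p.kv_sequence_length ? 0 : 1;
}
```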
diff --git a/onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc b/onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc
index 67d750aeac11a..8694dc998c7a8 100644
--- a/onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc
+++ b/onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc
@@ -6,9 +6,8 @@
#include "contrib_ops/cuda/bert/group_query_attention_impl.h"
#include "contrib_ops/cuda/bert/group_query_attention.h"
#include "contrib_ops/cuda/bert/group_query_attention_helper.h"
+#include "contrib_ops/cuda/bert/cutlass_fmha/memory_efficient_attention.h"
#include "contrib_ops/cuda/bert/flash_attention/flash_api.h"
-// #include "contrib_ops/cuda/transformers/dump_cuda_tensor.h"
-// #include "contrib_ops/cpu/utils/console_dumper.h"
using namespace onnxruntime::cuda;
using namespace ::onnxruntime::common;
@@ -55,6 +54,13 @@ GroupQueryAttention::GroupQueryAttention(const OpKernelInfo& info)
#else
disable_flash_attention_ = true;
#endif
+
+#if USE_MEMORY_EFFICIENT_ATTENTION
+ disable_memory_efficient_attention_ = sizeof(T) != 2 ||
+                                        ParseEnvironmentVariableWithDefault<bool>(attention::kDisableMemoryEfficientAttention, false);
+#else
+ disable_memory_efficient_attention_ = true;
+#endif
}
template <typename T>
@@ -92,18 +98,6 @@ Status GroupQueryAttention::ComputeInternal(OpKernelContext* context) const {
  output_shape[2] = static_cast<int64_t>(parameters.hidden_size);
Tensor* output = context->Output(0, output_shape);
-  std::vector<int64_t> present_dims;
- if (parameters.past_kv_format == AttentionQkvFormat::Q_K_V_BSNH) {
- present_dims = {
- parameters.batch_size, parameters.present_sequence_length, parameters.kv_num_heads, parameters.head_size};
- } else { // BNSH
- present_dims = {
- parameters.batch_size, parameters.kv_num_heads, parameters.present_sequence_length, parameters.head_size};
- }
- TensorShape present_shape(present_dims);
- Tensor* present_key = context->Output(1, present_shape);
- Tensor* present_value = context->Output(2, present_shape);
-
#if USE_FLASH_ATTENTION
bool use_flash_attention = !disable_flash_attention_ &&
onnxruntime::flash::is_supported(device_prop,
@@ -143,8 +137,47 @@ Status GroupQueryAttention::ComputeInternal(OpKernelContext* context) const {
  auto seqlens_k_buffer = GetScratchBuffer<void>(0, context->GetComputeStream());  // nullptr
#endif
- // only kernel implemented for gqa right now
- ORT_ENFORCE(use_flash_attention);
+#if USE_MEMORY_EFFICIENT_ATTENTION
+ int sm = (device_prop.major * 10) + device_prop.minor;
+ bool use_memory_efficient_attention =
+ !use_flash_attention &&
+ !disable_memory_efficient_attention_ &&
+ (parameters.head_size & 7) == 0 &&
+ parameters.sequence_length <= parameters.past_sequence_length + parameters.kv_sequence_length &&
+ (sizeof(T) == 2 || parameters.sequence_length >= attention::kMinSeqLenForMemoryEfficientAttentionFp32) &&
+ has_memory_efficient_attention(sm, sizeof(T) == 2);
+ // allocate buffers
+ size_t kv_buffer_bytes = 0;
+ // need a buffer if we must ungroup kv
+ const bool needs_buff = (parameters.num_heads != parameters.kv_num_heads);
+ if (use_memory_efficient_attention && needs_buff) {
+ kv_buffer_bytes = (sizeof(T) * parameters.batch_size * parameters.num_heads * (parameters.past_sequence_length + parameters.kv_sequence_length) * parameters.head_size);
+ }
+ size_t fmha_buffer_bytes = 0;
+ if (use_memory_efficient_attention && MemoryEfficientAttentionParams::need_workspace(parameters.head_size, sizeof(T) == sizeof(float))) {
+ fmha_buffer_bytes = (parameters.batch_size * parameters.sequence_length * parameters.num_heads * parameters.head_size * sizeof(float));
+ }
+  auto k_buffer = GetScratchBuffer<void>(kv_buffer_bytes, context->GetComputeStream());
+  auto v_buffer = GetScratchBuffer<void>(kv_buffer_bytes, context->GetComputeStream());
+  auto fmha_buffer = GetScratchBuffer<void>(fmha_buffer_bytes, context->GetComputeStream());
+#else
+ constexpr bool use_memory_efficient_attention = false;
+  auto k_buffer = GetScratchBuffer<void>(0, context->GetComputeStream());
+  auto v_buffer = GetScratchBuffer<void>(0, context->GetComputeStream());
+  auto fmha_buffer = GetScratchBuffer<void>(0, context->GetComputeStream());
+#endif
+
+  std::vector<int64_t> present_dims;
+ if (parameters.past_kv_format == AttentionQkvFormat::Q_K_V_BSNH) {
+ present_dims = {
+ parameters.batch_size, parameters.present_sequence_length, parameters.kv_num_heads, parameters.head_size};
+ } else { // BNSH
+ present_dims = {
+ parameters.batch_size, parameters.kv_num_heads, parameters.present_sequence_length, parameters.head_size};
+ }
+ TensorShape present_shape(present_dims);
+ Tensor* present_key = context->Output(1, present_shape);
+ Tensor* present_value = context->Output(2, present_shape);
  data.query = reinterpret_cast<const CudaT*>(query->Data<T>());
  data.key = reinterpret_cast<const CudaT*>(key->Data<T>());
@@ -155,6 +188,7 @@ Status GroupQueryAttention::ComputeInternal(OpKernelContext* context) const {
  data.present_key = (nullptr == present_key) ? nullptr : reinterpret_cast<CudaT*>(present_key->MutableData<T>());
  data.present_value = (nullptr == present_value) ? nullptr : reinterpret_cast<CudaT*>(present_value->MutableData<T>());
data.use_flash_attention = use_flash_attention;
+ data.use_memory_efficient_attention = use_memory_efficient_attention;
if (softmax_lse_buffer != nullptr) {
    data.softmax_lse = reinterpret_cast<CudaT*>(softmax_lse_buffer.get());
}
@@ -167,6 +201,13 @@ Status GroupQueryAttention::ComputeInternal(OpKernelContext* context) const {
if (seqlens_k_buffer != nullptr) {
    data.seqlens_k = reinterpret_cast<int*>(seqlens_k_buffer.get());
}
+ if (k_buffer != nullptr) {
+    data.k = reinterpret_cast<CudaT*>(k_buffer.get());
+    data.v = reinterpret_cast<CudaT*>(v_buffer.get());
+ }
+ if (fmha_buffer != nullptr) {
+    data.fmha_buffer = reinterpret_cast<CudaT*>(fmha_buffer.get());
+ }
cublasHandle_t cublas = GetCublasHandle(context);
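The scratch sizing above is easy to misread in diff form. Below is a compact sketch of the same arithmetic, with assumed helper names rather than ORT APIs: the ungrouped K/V buffers are only needed when num_heads != kv_num_heads, and the fmha workspace only when the kernel accumulates in fp32.

```cpp
#include <cstddef>

// Names here are assumptions for illustration, not ORT symbols.
struct GqaScratchSizes {
  size_t kv_buffer_bytes = 0;    // one buffer each for ungrouped K and V
  size_t fmha_buffer_bytes = 0;  // fp32 accumulation workspace
};

GqaScratchSizes ComputeScratchSizes(size_t element_size, int batch, int num_heads, int kv_num_heads,
                                    int past_len, int new_kv_len, int q_seq_len, int head_size,
                                    bool use_memory_efficient_attention, bool need_fp32_workspace) {
  GqaScratchSizes s;
  // Only ungroup when Q and KV head counts differ (mirrors needs_buff above).
  if (use_memory_efficient_attention && num_heads != kv_num_heads) {
    s.kv_buffer_bytes = element_size * batch * num_heads *
                        static_cast<size_t>(past_len + new_kv_len) * head_size;
  }
  // Workspace sized for one fp32 output tensor (mirrors need_workspace above).
  if (use_memory_efficient_attention && need_fp32_workspace) {
    s.fmha_buffer_bytes = static_cast<size_t>(batch) * q_seq_len * num_heads * head_size * sizeof(float);
  }
  return s;
}

int main() {
  GqaScratchSizes s = ComputeScratchSizes(/*element_size=*/2, /*batch=*/1, /*num_heads=*/8, /*kv_num_heads=*/4,
                                          /*past_len=*/32, /*new_kv_len=*/1, /*q_seq_len=*/1, /*head_size=*/64,
                                          /*use_memory_efficient_attention=*/true, /*need_fp32_workspace=*/false);
  return s.kv_buffer_bytes == 2u * 1 * 8 * 33 * 64 ? 0 : 1;
}
```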
diff --git a/onnxruntime/contrib_ops/cuda/bert/group_query_attention.h b/onnxruntime/contrib_ops/cuda/bert/group_query_attention.h
index 72c9814fad670..a90418ec2243a 100644
--- a/onnxruntime/contrib_ops/cuda/bert/group_query_attention.h
+++ b/onnxruntime/contrib_ops/cuda/bert/group_query_attention.h
@@ -27,6 +27,7 @@ class GroupQueryAttention final : public CudaKernel {
bool is_past_bsnh_;
float scale_;
bool disable_flash_attention_;
+ bool disable_memory_efficient_attention_;
};
} // namespace cuda
diff --git a/onnxruntime/contrib_ops/cuda/bert/group_query_attention_helper.h b/onnxruntime/contrib_ops/cuda/bert/group_query_attention_helper.h
index be8f5ca0ae3e9..8c21de9ced058 100644
--- a/onnxruntime/contrib_ops/cuda/bert/group_query_attention_helper.h
+++ b/onnxruntime/contrib_ops/cuda/bert/group_query_attention_helper.h
@@ -29,13 +29,13 @@ Status CheckInputs(const Tensor* query,
// query (Q) : (B, S, D)
// key (K) : (B, S+, D_kv)
// value (V) : (B, S+, D_kv)
+ ORT_UNUSED_PARAMETER(value);
AttentionQkvFormat qkv_format = Q_K_V_BSNH;
AttentionQkvFormat past_kv_format = Q_K_V_BSNH;
const auto& query_dims = query->Shape().GetDims();
const auto& key_dims = key->Shape().GetDims();
- const auto& value_dims = value->Shape().GetDims();
if (query_dims.size() != 3) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'query' is expected to have 3 dimensions, got ",
@@ -47,10 +47,8 @@ Status CheckInputs(const Tensor* query,
  int q_hidden_size = static_cast<int>(query_dims[2]);
  int head_size = static_cast<int>(q_hidden_size) / num_heads;
- int kv_sequence_length = sequence_length;
- int kv_hidden_size = (key_dims.size() == 3)
-                           ? static_cast<int>(key_dims[2])
-                           : (kv_num_heads * static_cast<int>(key_dims[3]));
+  int kv_sequence_length = static_cast<int>(key_dims[1]);
+  int kv_hidden_size = static_cast<int>(key_dims[2]);
int max_sequence_length = 0;
if (past_key != nullptr && past_value != nullptr) {
@@ -134,63 +132,49 @@ Status CheckInputs(const Tensor* query,
"Input 'past_key' and 'past_value' shall be both present or both absent");
}
- if (key != nullptr) {
- const auto& key_dims = key->Shape().GetDims();
- if (key_dims.size() != 3) {
- return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'key' is expected to have 3 dimensions, got ",
- key_dims.size());
- }
- if (query_dims[0] != key_dims[0]) {
- return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
- "Input 'query' and 'key' shall have same dim 0 (batch size)");
- }
-
- if (num_heads % kv_num_heads != 0) {
- return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
- "num_heads must be a multiple of kv_num_heads. Got num_heads % kv_num_heads == ",
- num_heads % kv_num_heads);
- }
- if (key_dims[2] != value_dims[2]) {
- return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
- "Input 'key' and 'value' shall have same dim 2 (kv_hidden_size)");
- }
-
- qkv_format = Q_K_V_BSNH;
-    kv_sequence_length = static_cast<int>(key_dims[1]);
- } else {
+ if (key_dims.size() != 3) {
+ return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'key' is expected to have 3 dimensions, got ",
+ key_dims.size());
+ }
+ if (query_dims[0] != key_dims[0]) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
- "Missing key tensor.");
+ "Input 'query' and 'key' shall have same dim 0 (batch size)");
}
- if (value != nullptr) {
- const auto& value_dims = value->Shape().GetDims();
- if (value_dims.size() != 3) {
- return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'value' is expected to have 3 dimensions, got ",
- value_dims.size());
- }
+ if (num_heads % kv_num_heads != 0) {
+ return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+ "num_heads must be a multiple of kv_num_heads. Got num_heads % kv_num_heads == ",
+ num_heads % kv_num_heads);
+ }
- if (query_dims[0] != value_dims[0]) {
- return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
- "Input 'query' and 'value' shall have same dim 0 (batch_size)");
- }
+ const auto& value_dims = value->Shape().GetDims();
+ if (value_dims.size() != 3) {
+ return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'value' is expected to have 3 dimensions, got ",
+ value_dims.size());
+ }
-    if (static_cast<int64_t>(kv_sequence_length) != value_dims[1]) {
- return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
- "Input 'key' and 'value' shall have the same dim 1 (kv_sequence_length)");
- }
+ if (query_dims[0] != value_dims[0]) {
+ return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+ "Input 'query' and 'value' shall have same dim 0 (batch_size)");
+ }
- if (value_dims[2] != kv_hidden_size) {
- return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'value' is expected to have same hidden size as key.");
- }
- } else {
+  if (static_cast<int64_t>(kv_sequence_length) != value_dims[1]) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
- "Missing value tensor.");
+ "Input 'key' and 'value' shall have the same dim 1 (kv_sequence_length)");
+ }
+
+ if (value_dims[2] != kv_hidden_size) {
+ return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'value' is expected to have same hidden size as key.");
}
// When kv-cache, we take past_seq_len as an argument... otherwise we use sequence length of past kv directly.
int32_t past_sequence_length = 0;
- int present_sequence_length = 0;
+ int present_sequence_length = kv_sequence_length;
if (past_seq_len != nullptr) {
+ if (past_key == nullptr) {
+ return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+ "Past KV must be present as share-buffer when using past_seq_len pointer.");
+ }
if (!onnxruntime::IsScalarOr1ElementVector(past_seq_len)) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"past_sequence_length tensor must be of one element when using past kv.");
@@ -200,6 +184,10 @@ Status CheckInputs(const Tensor* query,
} else {
      past_sequence_length = static_cast<int32_t>(*((*past_seq_len).template Data<int64_t>()));
}
+ if (past_sequence_length + kv_sequence_length > max_sequence_length) {
+ return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+ "KV buffer too small... shall be that max_sequence_length >= past_sequence_length + kv_sequence_length");
+ }
present_sequence_length = max_sequence_length;
} else if (past_key != nullptr) {
past_sequence_length = max_sequence_length; // this is the length of past_key tensor
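The length bookkeeping in CheckInputs now has three cases. A stand-alone sketch of the same logic, with assumed names and exceptions standing in for ORT_MAKE_STATUS, makes the shared-buffer constraint explicit:

```cpp
#include <optional>
#include <stdexcept>

struct KvLengths {
  int past_sequence_length;
  int present_sequence_length;
};

KvLengths ResolveLengths(std::optional<int> past_seq_len,  // value of the past_seq_len input, if given
                         bool has_past_kv,                 // past_key/past_value provided
                         int max_sequence_length,          // buffer length of past_key when present
                         int kv_sequence_length) {         // length of the new key/value
  if (past_seq_len.has_value()) {
    // Shared k-v buffer: past must be present and large enough to append the new tokens.
    if (!has_past_kv) throw std::invalid_argument("past KV must be present when past_seq_len is given");
    if (*past_seq_len + kv_sequence_length > max_sequence_length)
      throw std::invalid_argument("KV buffer too small");
    return {*past_seq_len, max_sequence_length};
  }
  if (has_past_kv) {
    // Non-buffered past: its full length is the past; present grows by the new tokens.
    return {max_sequence_length, max_sequence_length + kv_sequence_length};
  }
  // Prompt case: no past at all.
  return {0, kv_sequence_length};
}

int main() {
  KvLengths prompt = ResolveLengths(std::nullopt, /*has_past_kv=*/false,
                                    /*max_sequence_length=*/0, /*kv_sequence_length=*/128);
  return (prompt.past_sequence_length == 0 && prompt.present_sequence_length == 128) ? 0 : 1;
}
```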
diff --git a/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.cu b/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.cu
index ab3029ca34886..0455825c364a2 100644
--- a/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.cu
+++ b/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.cu
@@ -37,6 +37,7 @@ limitations under the License.
#include "contrib_ops/cpu/bert/attention_base.h"
#include "contrib_ops/cuda/bert/bert_padding.h"
#include "contrib_ops/cuda/transformers/dump_cuda_tensor.h"
+#include "contrib_ops/cuda/bert/cutlass_fmha/memory_efficient_attention.h"
#include "contrib_ops/cuda/bert/flash_attention/flash_api.h"
#include "contrib_ops/cuda/bert/group_query_attention_impl.h"
#include "contrib_ops/cuda/bert/attention_impl.h"
@@ -47,6 +48,8 @@ namespace onnxruntime {
namespace contrib {
namespace cuda {
+////////// Auxiliary Kernels for KV prep
+
// Kernel for seqlens_k
__global__ void repeat_seqlen(int32_t* seqlens_k, int32_t seqlen, int batch_size) {
int id = blockDim.x * blockIdx.x + threadIdx.x;
@@ -75,7 +78,7 @@ __global__ void ConcatNewToPastKV(const int new_seqlen,
const int present_head_stride = is_bsnh ? H : present_seqlen * H;
// past_kv: BPNH or BNPH
- // new_kv: BLNH or BNLH
+ // new_kv: BLNH
// present_kv: BTNH or BNTH, where T = P + L
const int past_seqlen = present_seqlen - new_seqlen;
@@ -95,33 +98,32 @@ __global__ void ConcatNewToPastKV(const int new_seqlen,
}
}
+// Use when (H*)*num_heads > 1024
template <typename T>
__global__ void ConcatNewToPastKVLarge(const int new_seqlen,
const int H,
+ const int num_heads,
const T* past_kv,
const T* new_kv,
T* present_kv,
const bool is_bsnh) {
- // Use when (H*)*num_heads > 1024
- int h = threadIdx.x;
- const int n = threadIdx.y;
- const int s = blockIdx.x;
- const int b = blockIdx.y;
+ int i = threadIdx.x + (blockDim.x * blockIdx.x);
+ if (i < H * num_heads) {
+ const int h = i % H;
+ const int n = i / H;
+ const int s = blockIdx.y;
+ const int b = blockIdx.z;
+ const int present_seqlen = gridDim.y;
+
+ const int present_batch_stride = present_seqlen * num_heads * H;
+ const int row_stride = is_bsnh ? num_heads * H : H;
+ const int present_head_stride = is_bsnh ? H : present_seqlen * H;
+
+ // past_kv: BPNH or BNPH
+ // new_kv: BLNH
+ // present_kv: BTNH or BNTH, where T = P + L
+ const int past_seqlen = present_seqlen - new_seqlen;
- const int present_seqlen = gridDim.x;
- const int num_heads = blockDim.y;
- const int thread_stride = blockDim.x;
-
- const int present_batch_stride = present_seqlen * num_heads * H;
- const int row_stride = is_bsnh ? num_heads * H : H;
- const int present_head_stride = is_bsnh ? H : present_seqlen * H;
-
- // past_kv: BPNH or BNPH
- // new_kv: BLNH or BNLH
- // present_kv: BTNH or BNTH, where T = P + L
- const int past_seqlen = present_seqlen - new_seqlen;
-
- while (h < H) {
int out_offset = b * present_batch_stride + s * row_stride + n * present_head_stride + h;
if (s < past_seqlen) {
const int past_batch_stride = past_seqlen * num_heads * H;
@@ -135,133 +137,477 @@ __global__ void ConcatNewToPastKVLarge(const int new_seqlen,
const int in_offset = b * new_batch_stride + (s - past_seqlen) * new_row_stride + n * new_head_stride + h;
present_kv[out_offset] = new_kv[in_offset];
}
- h += thread_stride;
}
}
+// Concat new to past in present. Supports past BSNH or past BNSH
template <typename T>
-Status QkvToContext(
+Status LaunchConcatNewToPastKV(contrib::GroupQueryAttentionParameters& parameters,
+                               GroupQueryAttentionData<T>& data,
+ cudaStream_t stream,
+ const int max_threads_per_block) {
+ const int batch_size = parameters.batch_size;
+ const int kv_sequence_length = parameters.kv_sequence_length;
+ const int present_sequence_length = parameters.present_sequence_length;
+ const int kv_num_heads = parameters.kv_num_heads;
+ const int head_size = parameters.head_size;
+ AttentionQkvFormat past_kv_format = parameters.past_kv_format;
+
+ assert(past_kv_format == AttentionQkvFormat::Q_K_V_BSNH || past_kv_format == AttentionQkvFormat::Q_K_V_BNSH);
+ const int H = head_size / 4; // divide by 4 so kernel can operate on 4 float16 elements at a time.
+ if (H * kv_num_heads <= max_threads_per_block) {
+ const dim3 grid(present_sequence_length, batch_size, 1);
+ const dim3 block(H, kv_num_heads, 1);
+    ConcatNewToPastKV<float2><<<grid, block, 0, stream>>>(kv_sequence_length,
+                                                          reinterpret_cast<const float2*>(data.past_key),
+                                                          reinterpret_cast<const float2*>(data.key),
+                                                          reinterpret_cast<float2*>(data.present_key),
+                                                          past_kv_format == AttentionQkvFormat::Q_K_V_BSNH);
+    ConcatNewToPastKV<float2><<<grid, block, 0, stream>>>(kv_sequence_length,
+                                                          reinterpret_cast<const float2*>(data.past_value),
+                                                          reinterpret_cast<const float2*>(data.value),
+                                                          reinterpret_cast<float2*>(data.present_value),
+                                                          past_kv_format == AttentionQkvFormat::Q_K_V_BSNH);
+ } else {
+ int steps = (H * kv_num_heads + 255) / 256;
+ const dim3 grid(steps, present_sequence_length, batch_size);
+ const dim3 block(256, 1, 1);
+    ConcatNewToPastKVLarge<float2><<<grid, block, 0, stream>>>(kv_sequence_length,
+                                                               H,
+                                                               kv_num_heads,
+                                                               reinterpret_cast<const float2*>(data.past_key),
+                                                               reinterpret_cast<const float2*>(data.key),
+                                                               reinterpret_cast<float2*>(data.present_key),
+                                                               past_kv_format == AttentionQkvFormat::Q_K_V_BSNH);
+    ConcatNewToPastKVLarge<float2><<<grid, block, 0, stream>>>(kv_sequence_length,
+                                                               H,
+                                                               kv_num_heads,
+                                                               reinterpret_cast<const float2*>(data.past_value),
+                                                               reinterpret_cast<const float2*>(data.value),
+                                                               reinterpret_cast<float2*>(data.present_value),
+                                                               past_kv_format == AttentionQkvFormat::Q_K_V_BSNH);
+ }
+ return CUDA_CALL(cudaGetLastError());
+}
+
+// Kernel to append new kv to kv buffer in place
+template <typename T>
+__global__ void ConcatKVInPlace(const int past_seqlen,
+ const int present_seqlen,
+ T* kv_buff,
+ const T* new_kv,
+ const bool is_bsnh) { // refers to kv buff; otherwise bnsh
+ const int h = threadIdx.x;
+ const int n = threadIdx.y;
+ const int s = blockIdx.x;
+ const int b = blockIdx.y;
+
+ const int new_seqlen = gridDim.x;
+ const int num_heads = blockDim.y;
+ const int H = blockDim.x;
+
+ const int present_batch_stride = present_seqlen * num_heads * H;
+ const int present_row_stride = is_bsnh ? num_heads * H : H;
+ const int present_head_stride = is_bsnh ? H : present_seqlen * H;
+
+ // kv_buff: BTNH or BNTH with buffered memory for new
+ // new_kv: BLNH
+
+ int out_offset = b * present_batch_stride + (s + past_seqlen) * present_row_stride + n * present_head_stride + h;
+ // Note: new KV always BSNH
+ const int new_batch_stride = new_seqlen * num_heads * H;
+ const int new_row_stride = num_heads * H;
+ const int new_head_stride = H;
+ const int in_offset = b * new_batch_stride + s * new_row_stride + n * new_head_stride + h;
+ kv_buff[out_offset] = new_kv[in_offset];
+}
+
+template <typename T>
+__global__ void ConcatKVInPlaceLarge(const int past_seqlen,
+ const int present_seqlen,
+ const int H,
+ const int num_heads,
+ T* kv_buff,
+ const T* new_kv,
+ const bool is_bsnh) { // refers to kv buff; otherwise bnsh
+ int i = threadIdx.x + (blockDim.x * blockIdx.x);
+ if (i < H * num_heads) {
+ const int h = i % H;
+ const int n = i / H;
+ const int s = blockIdx.y;
+ const int b = blockIdx.z;
+ const int new_seqlen = gridDim.y;
+
+ const int present_batch_stride = present_seqlen * num_heads * H;
+ const int present_row_stride = is_bsnh ? num_heads * H : H;
+ const int present_head_stride = is_bsnh ? H : present_seqlen * H;
+
+ // kv_buff: BTNH or BNTH with buffered memory for new
+ // new_kv: BLNH
+
+ int out_offset = b * present_batch_stride + (s + past_seqlen) * present_row_stride + n * present_head_stride + h;
+ // Note: new KV always BSNH
+ const int new_batch_stride = new_seqlen * num_heads * H;
+ const int new_row_stride = num_heads * H;
+ const int new_head_stride = H;
+ const int in_offset = b * new_batch_stride + s * new_row_stride + n * new_head_stride + h;
+ kv_buff[out_offset] = new_kv[in_offset];
+ }
+}
+
+// Concat new to kv buffer in place
+template <typename T>
+Status LaunchConcatKVInPlace(contrib::GroupQueryAttentionParameters& parameters,
+                             GroupQueryAttentionData<T>& data,
+ cudaStream_t stream,
+ const int max_threads_per_block) {
+ const int batch_size = parameters.batch_size;
+ const int kv_sequence_length = parameters.kv_sequence_length;
+ const int present_sequence_length = parameters.present_sequence_length;
+ const int past_sequence_length = parameters.past_sequence_length;
+ const int kv_num_heads = parameters.kv_num_heads;
+ const int head_size = parameters.head_size;
+ AttentionQkvFormat past_kv_format = parameters.past_kv_format;
+ assert(past_kv_format == AttentionQkvFormat::Q_K_V_BSNH || past_kv_format == AttentionQkvFormat::Q_K_V_BNSH);
+ const int H = head_size / 4;
+ if (H * kv_num_heads <= max_threads_per_block) {
+ const dim3 grid(kv_sequence_length, batch_size, 1);
+ const dim3 block(H, kv_num_heads, 1);
+    ConcatKVInPlace<float2><<<grid, block, 0, stream>>>(past_sequence_length,
+                                                        present_sequence_length,
+                                                        reinterpret_cast<float2*>(data.present_key),
+                                                        reinterpret_cast<const float2*>(data.key),
+                                                        past_kv_format == AttentionQkvFormat::Q_K_V_BSNH);
+    ConcatKVInPlace<float2><<<grid, block, 0, stream>>>(past_sequence_length,
+                                                        present_sequence_length,
+                                                        reinterpret_cast<float2*>(data.present_value),
+                                                        reinterpret_cast<const float2*>(data.value),
+                                                        past_kv_format == AttentionQkvFormat::Q_K_V_BSNH);
+ } else {
+ int steps = int(ceil(float(H * kv_num_heads) / 256.0));
+ const dim3 grid(steps, kv_sequence_length, batch_size);
+ const dim3 block(256, 1, 1);
+    ConcatKVInPlaceLarge<float2><<<grid, block, 0, stream>>>(past_sequence_length,
+                                                             present_sequence_length,
+                                                             H,
+                                                             kv_num_heads,
+                                                             reinterpret_cast<float2*>(data.present_key),
+                                                             reinterpret_cast<const float2*>(data.key),
+                                                             past_kv_format == AttentionQkvFormat::Q_K_V_BSNH);
+    ConcatKVInPlaceLarge<float2><<<grid, block, 0, stream>>>(past_sequence_length,
+                                                             present_sequence_length,
+                                                             H,
+                                                             kv_num_heads,
+                                                             reinterpret_cast<float2*>(data.present_value),
+                                                             reinterpret_cast<const float2*>(data.value),
+                                                             past_kv_format == AttentionQkvFormat::Q_K_V_BSNH);
+ }
+ return CUDA_CALL(cudaGetLastError());
+}
+
+// Kernel for use with memory efficient kernel... kv_in is grouped and of bnsh or bsnh... kv_out is ungrouped and bsnh
+template <typename T>
+__global__ void Ungroup(const T* kv_in,
+ T* kv_out,
+ const int in_seqlen,
+ const int kv_num_heads,
+ const bool is_bsnh) {
+ const int h = threadIdx.x;
+ const int out_n = threadIdx.y;
+ const int s = blockIdx.x;
+ const int b = blockIdx.y;
+
+ const int out_seqlen = gridDim.x;
+ const int q_num_heads = blockDim.y;
+ const int H = blockDim.x;
+
+ const int q_kv_head_ratio = q_num_heads / kv_num_heads;
+ const int out_batch_stride = out_seqlen * q_num_heads * H;
+ const int out_row_stride = is_bsnh ? q_num_heads * H : H;
+ const int out_head_stride = is_bsnh ? H : out_seqlen * H;
+
+ const int in_batch_stride = in_seqlen * kv_num_heads * H;
+ const int in_row_stride = is_bsnh ? kv_num_heads * H : H;
+ const int in_head_stride = is_bsnh ? H : in_seqlen * H;
+ const int in_n = out_n / q_kv_head_ratio;
+
+ const int out_offset = out_batch_stride * b + out_row_stride * s + out_head_stride * out_n + h;
+ const int in_offset = in_batch_stride * b + in_row_stride * s + in_head_stride * in_n + h;
+ kv_out[out_offset] = kv_in[in_offset];
+}
+
+template <typename T>
+__global__ void UngroupLarge(const T* kv_in,
+ T* kv_out,
+ const int H,
+ const int in_seqlen,
+ const int q_num_heads,
+ const int kv_num_heads,
+ const bool is_bsnh) {
+ int i = threadIdx.x + (blockDim.x * blockIdx.x); // index along H * q_num_heads elements
+ if (i < H * q_num_heads) {
+ const int out_seqlen = gridDim.y;
+ const int s = blockIdx.y;
+ const int b = blockIdx.z;
+
+ const int q_kv_head_ratio = q_num_heads / kv_num_heads;
+ const int out_batch_stride = out_seqlen * q_num_heads * H;
+ const int out_row_stride = is_bsnh ? q_num_heads * H : H;
+ const int out_head_stride = is_bsnh ? H : out_seqlen * H;
+
+ const int in_batch_stride = in_seqlen * kv_num_heads * H;
+ const int in_row_stride = is_bsnh ? kv_num_heads * H : H;
+ const int in_head_stride = is_bsnh ? H : in_seqlen * H;
+
+ const int h = i % H;
+ const int out_n = i / H;
+ const int in_n = out_n / q_kv_head_ratio;
+ const int out_offset = out_batch_stride * b + out_row_stride * s + out_head_stride * out_n + h;
+ const int in_offset = in_batch_stride * b + in_row_stride * s + in_head_stride * in_n + h;
+ kv_out[out_offset] = kv_in[in_offset];
+ }
+}
+
+// Ungroup kv or present kv for use in Memory Efficient kernel. If present kv is not null and is BNSH, transposes it.
+Status LaunchUngroup(contrib::GroupQueryAttentionParameters& parameters,
+ float2* k_buff, float2* v_buff,
+ const float2* k_og, const float2* v_og,
+ const int buff_seqlen, const int og_seqlen,
+ const bool is_bsnh,
+ cudaStream_t stream,
+ const int max_threads_per_block) {
+ const int batch_size = parameters.batch_size;
+ const int num_heads = parameters.num_heads;
+ const int kv_num_heads = parameters.kv_num_heads;
+ const int head_size = parameters.head_size;
+
+ const int H = head_size / 4;
+ if (H * num_heads <= max_threads_per_block) {
+ const dim3 grid(buff_seqlen, batch_size, 1);
+ const dim3 block(H, num_heads, 1);
+    Ungroup<float2><<<grid, block, 0, stream>>>(k_og,
+                                                k_buff,
+                                                og_seqlen,
+                                                kv_num_heads,
+                                                is_bsnh);
+    Ungroup<float2><<<grid, block, 0, stream>>>(v_og,
+                                                v_buff,
+                                                og_seqlen,
+                                                kv_num_heads,
+                                                is_bsnh);
+ } else {
+ int steps = int(ceil(float(H * num_heads) / 256.0));
+ const dim3 grid(steps, buff_seqlen, batch_size);
+ const dim3 block(256, 1, 1);
+    UngroupLarge<float2><<<grid, block, 0, stream>>>(k_og,
+                                                     k_buff,
+                                                     H,
+                                                     og_seqlen,
+                                                     num_heads,
+                                                     kv_num_heads,
+                                                     is_bsnh);
+    UngroupLarge<float2><<<grid, block, 0, stream>>>(v_og,
+                                                     v_buff,
+                                                     H,
+                                                     og_seqlen,
+                                                     num_heads,
+                                                     kv_num_heads,
+                                                     is_bsnh);
+ }
+ return CUDA_CALL(cudaGetLastError());
+}
+
+////////// Launch Kernels
+
+#if USE_FLASH_ATTENTION
+template <typename T>
+Status FlashAttention(
const cudaDeviceProp& device_prop,
- cublasHandle_t& cublas,
- Stream* ort_stream,
+ cudaStream_t stream,
contrib::GroupQueryAttentionParameters& parameters,
-    GroupQueryAttentionData<T>& data) {
-  assert(data.use_flash_attention);
+    GroupQueryAttentionData<T>& data,
+ float scale) {
+ const int max_threads_per_block = device_prop.maxThreadsPerBlock;
+ const int batch_size = parameters.batch_size;
+ const int sequence_length = parameters.sequence_length;
+ const int kv_sequence_length = parameters.kv_sequence_length;
+ const int present_sequence_length = parameters.present_sequence_length;
+ const int num_heads = parameters.num_heads;
+ const int kv_num_heads = parameters.kv_num_heads;
+ const int head_size = parameters.head_size;
+ AttentionQkvFormat past_kv_format = parameters.past_kv_format;
-#if USE_FLASH_ATTENTION
-  auto stream = static_cast<cudaStream_t>(ort_stream->GetHandle());
+  void* query = reinterpret_cast<void*>(const_cast<T*>(data.query));
+  void* key = reinterpret_cast<void*>(const_cast<T*>(data.key));
+  void* value = reinterpret_cast<void*>(const_cast<T*>(data.value));
+
+ bool is_causal = parameters.is_unidirectional;
+
+ if (data.past_key != nullptr && data.past_key == data.present_key) {
+ // Share buffer case
+    void* present_key = reinterpret_cast<void*>(const_cast<T*>(data.present_key));
+    void* present_value = reinterpret_cast<void*>(const_cast<T*>(data.present_value));
+
+ // Launch kernel to copy seqlen
+ int thr_per_blk = 256;
+ int blk_in_grid = ceil(float(batch_size) / thr_per_blk);
+    repeat_seqlen<<<blk_in_grid, thr_per_blk, 0, stream>>>(data.seqlens_k, parameters.past_sequence_length, batch_size);
+
+ bool past_bsnh = past_kv_format == AttentionQkvFormat::Q_K_V_BSNH;
+ ORT_RETURN_IF_ERROR(onnxruntime::flash::mha_fwd_kvcache(
+        device_prop, stream, query, present_key, present_value, key, value, data.output, reinterpret_cast<void*>(data.softmax_lse),
+        reinterpret_cast<void*>(data.seqlens_k), batch_size, num_heads, kv_num_heads,
+        head_size, sequence_length, present_sequence_length, kv_sequence_length,
+        scale, is_causal, past_bsnh, parameters.num_splits, reinterpret_cast<void*>(data.softmax_lse_accum),
+        reinterpret_cast<void*>(data.out_accum)));
+
+ } else {
+ // Not share buffer or no past (prompt generation)
+    // Note that Flash Attention kv-caching operates in place on a buffer... therefore this path is inefficient
+ ORT_RETURN_IF_ERROR(LaunchConcatNewToPastKV(parameters, data, stream, max_threads_per_block));
+
+    void* present_key = reinterpret_cast<void*>(const_cast<T*>(data.present_key));
+    void* present_value = reinterpret_cast<void*>(const_cast<T*>(data.present_value));
+
+ bool past_bsnh = past_kv_format == AttentionQkvFormat::Q_K_V_BSNH;
+ ORT_RETURN_IF_ERROR(onnxruntime::flash::mha_fwd(
+        device_prop, stream, query, present_key, present_value, data.output, reinterpret_cast<void*>(data.softmax_lse),
+        batch_size, num_heads, kv_num_heads, head_size,
+        sequence_length, present_sequence_length, scale, is_causal, parameters.num_splits,
+        reinterpret_cast<void*>(data.softmax_lse_accum), reinterpret_cast<void*>(data.out_accum), past_bsnh));
+ }
+
+ DUMP_TENSOR_INIT();
+ DUMP_TENSOR("flash attention output", data.output, batch_size, sequence_length, num_heads, head_size);
+
+ return Status::OK();
+}
+#endif
+
+#if USE_MEMORY_EFFICIENT_ATTENTION
+template <typename T>
+Status EfficientAttention(
+ const cudaDeviceProp& device_prop,
+ cudaStream_t stream,
+ contrib::GroupQueryAttentionParameters& parameters,
+    GroupQueryAttentionData<T>& data,
+ float scale) {
const int max_threads_per_block = device_prop.maxThreadsPerBlock;
const int batch_size = parameters.batch_size;
const int sequence_length = parameters.sequence_length;
const int kv_sequence_length = parameters.kv_sequence_length;
+ const int past_sequence_length = parameters.past_sequence_length;
const int present_sequence_length = parameters.present_sequence_length;
const int num_heads = parameters.num_heads;
const int kv_num_heads = parameters.kv_num_heads;
const int head_size = parameters.head_size;
AttentionQkvFormat past_kv_format = parameters.past_kv_format;
-  const float scale = parameters.scale == 0.0f ? 1.f / sqrt(static_cast<float>(head_size)) : parameters.scale;
- if (data.use_flash_attention) {
- assert(parameters.qkv_format == AttentionQkvFormat::Q_K_V_BSNH);
- assert(parameters.num_heads % parameters.kv_num_heads == 0);
-
-    void* query = reinterpret_cast<void*>(const_cast<T*>(data.query));
-    void* key = reinterpret_cast<void*>(const_cast<T*>(data.key));
-    void* value = reinterpret_cast<void*>(const_cast<T*>(data.value));
-
- bool is_causal = parameters.is_unidirectional;
-
- if (data.past_key == nullptr && data.present_key == nullptr) {
- ORT_RETURN_IF_ERROR(onnxruntime::flash::mha_fwd(
-          device_prop, stream, query, key, value, data.output, reinterpret_cast<void*>(data.softmax_lse),
-          parameters.batch_size, parameters.num_heads, parameters.kv_num_heads, head_size,
-          parameters.sequence_length, parameters.kv_sequence_length, scale, is_causal, parameters.num_splits,
-          reinterpret_cast<void*>(data.softmax_lse_accum), reinterpret_cast<void*>(data.out_accum)));
-
- } else if (data.past_key == data.present_key) {
- // Assume past and present kv share buffer.
- assert(past_kv_format == AttentionQkvFormat::Q_K_V_BSNH || past_kv_format == AttentionQkvFormat::Q_K_V_BNSH);
- assert(parameters.past_sequence_length >= 0);
- assert(data.past_value != nullptr);
-
-      void* present_key = reinterpret_cast<void*>(const_cast<T*>(data.present_key));
-      void* present_value = reinterpret_cast<void*>(const_cast<T*>(data.present_value));
-
- // Launch kernel to copy seqlen
- int thr_per_blk = 256;
- int blk_in_grid = ceil(float(batch_size) / thr_per_blk);
-      repeat_seqlen<<<blk_in_grid, thr_per_blk, 0, stream>>>(data.seqlens_k, parameters.past_sequence_length, batch_size);
-
- DUMP_TENSOR_INIT();
- DUMP_TENSOR("seqlens_k", data.seqlens_k, 1, batch_size);
-
- bool past_bsnh = past_kv_format == AttentionQkvFormat::Q_K_V_BSNH;
- ORT_RETURN_IF_ERROR(onnxruntime::flash::mha_fwd_kvcache(
-          device_prop, stream, query, present_key, present_value, key, value, data.output, reinterpret_cast<void*>(data.softmax_lse),
-          reinterpret_cast<void*>(data.seqlens_k), batch_size, num_heads, kv_num_heads,
-          head_size, sequence_length, present_sequence_length, kv_sequence_length,
-          scale, is_causal, past_bsnh, parameters.num_splits, reinterpret_cast<void*>(data.softmax_lse_accum),
-          reinterpret_cast<void*>(data.out_accum)));
-
- } else if (data.present_key != nullptr && (data.past_key != nullptr || kv_sequence_length == present_sequence_length)) {
- assert(past_kv_format == AttentionQkvFormat::Q_K_V_BSNH || past_kv_format == AttentionQkvFormat::Q_K_V_BNSH);
- // Note that Flash Attention kv-caching operates in place on a buffer... therefore this path is inneficient
- if (head_size % 4 != 0) {
- return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED, "requires head_size be divisible by 4");
- }
- const int H = head_size / 4;
- if (H * kv_num_heads <= max_threads_per_block) {
- const dim3 grid(present_sequence_length, batch_size, 1);
- const dim3 block(H, kv_num_heads, 1);
-        ConcatNewToPastKV<float2><<<grid, block, 0, stream>>>(kv_sequence_length,
-                                                              reinterpret_cast<const float2*>(data.past_key),
-                                                              reinterpret_cast<const float2*>(data.key),
-                                                              reinterpret_cast<float2*>(data.present_key),
-                                                              past_kv_format == AttentionQkvFormat::Q_K_V_BSNH);
-        ConcatNewToPastKV<float2><<<grid, block, 0, stream>>>(kv_sequence_length,
-                                                              reinterpret_cast<const float2*>(data.past_value),
-                                                              reinterpret_cast<const float2*>(data.value),
-                                                              reinterpret_cast<float2*>(data.present_value),
-                                                              past_kv_format == AttentionQkvFormat::Q_K_V_BSNH);
- } else {
- const dim3 grid(present_sequence_length, batch_size, 1);
- const dim3 block(max_threads_per_block / kv_num_heads, kv_num_heads, 1);
-        ConcatNewToPastKVLarge<float2><<<grid, block, 0, stream>>>(kv_sequence_length,
-                                                                   H,
-                                                                   reinterpret_cast<const float2*>(data.past_key),
-                                                                   reinterpret_cast<const float2*>(data.key),
-                                                                   reinterpret_cast<float2*>(data.present_key),
-                                                                   past_kv_format == AttentionQkvFormat::Q_K_V_BSNH);
-        ConcatNewToPastKVLarge<float2><<<grid, block, 0, stream>>>(kv_sequence_length,
-                                                                   H,
-                                                                   reinterpret_cast<const float2*>(data.past_value),
-                                                                   reinterpret_cast<const float2*>(data.value),
-                                                                   reinterpret_cast<float2*>(data.present_value),
-                                                                   past_kv_format == AttentionQkvFormat::Q_K_V_BSNH);
- }
-
-      void* present_key = reinterpret_cast<void*>(const_cast<T*>(data.present_key));
-      void* present_value = reinterpret_cast<void*>(const_cast<T*>(data.present_value));
-
- // Launch kernel to copy seqlen
- int thr_per_blk = 256;
- int blk_in_grid = ceil(float(batch_size) / thr_per_blk);
-      repeat_seqlen<<<blk_in_grid, thr_per_blk, 0, stream>>>(data.seqlens_k, parameters.past_sequence_length, batch_size);
-
- bool past_bsnh = past_kv_format == AttentionQkvFormat::Q_K_V_BSNH;
- ORT_RETURN_IF_ERROR(onnxruntime::flash::mha_fwd(
-          device_prop, stream, query, present_key, present_value, data.output, reinterpret_cast<void*>(data.softmax_lse),
-          batch_size, num_heads, kv_num_heads, head_size,
-          sequence_length, present_sequence_length, scale, is_causal, parameters.num_splits,
-          reinterpret_cast<void*>(data.softmax_lse_accum), reinterpret_cast<void*>(data.out_accum), past_bsnh));
+  const void* query = reinterpret_cast<const void*>(data.query);
+  const void* key = reinterpret_cast<const void*>(data.key);
+  const void* value = reinterpret_cast<const void*>(data.value);
+ if (data.past_key != nullptr) {
+ // Past key case
+ // concatenate new kv to past kv
+ if (data.past_key == data.present_key) {
+ ORT_RETURN_IF_ERROR(LaunchConcatKVInPlace(parameters, data, stream, max_threads_per_block));
+ } else {
+ ORT_RETURN_IF_ERROR(LaunchConcatNewToPastKV(parameters, data, stream, max_threads_per_block));
}
+ const bool is_bsnh = past_kv_format == AttentionQkvFormat::Q_K_V_BSNH;
+ if (num_heads == kv_num_heads) {
+ // Use present kv directly if not grouped
+      key = reinterpret_cast<const void*>(data.present_key);
+      value = reinterpret_cast<const void*>(data.present_value);
+ } else {
+ // Otherwise we use intermediate buffers to run memory efficient attention... best avoid this path
+      float2* k_buff = reinterpret_cast<float2*>(data.k);
+      float2* v_buff = reinterpret_cast<float2*>(data.v);
+      const float2* k_og = reinterpret_cast<const float2*>(data.present_key);
+      const float2* v_og = reinterpret_cast<const float2*>(data.present_value);
+ ORT_RETURN_IF_ERROR(LaunchUngroup(parameters, k_buff, v_buff, k_og, v_og, past_sequence_length + kv_sequence_length,
+ present_sequence_length, is_bsnh, stream, max_threads_per_block));
+      key = reinterpret_cast<const void*>(data.k);
+      value = reinterpret_cast<const void*>(data.v);
+ }
+ } else if (num_heads == kv_num_heads) {
+ // no past or present and no need to ungroup... still copy kv into present buffer
+ ORT_RETURN_IF_ERROR(LaunchConcatNewToPastKV(parameters, data, stream, max_threads_per_block));
+    key = reinterpret_cast<const void*>(data.present_key);
+    value = reinterpret_cast<const void*>(data.present_value);
+ } else {
+ // intermediate buffer so q and kv have same num heads... still copy kv into present buffer
+ ORT_RETURN_IF_ERROR(LaunchConcatNewToPastKV(parameters, data, stream, max_threads_per_block));
+    float2* k_buff = reinterpret_cast<float2*>(data.k);
+    float2* v_buff = reinterpret_cast<float2*>(data.v);
+    const float2* k_og = reinterpret_cast<const float2*>(data.present_key);
+    const float2* v_og = reinterpret_cast<const float2*>(data.present_value);
+ ORT_RETURN_IF_ERROR(LaunchUngroup(parameters, k_buff, v_buff, k_og, v_og, kv_sequence_length,
+ kv_sequence_length, past_kv_format == AttentionQkvFormat::Q_K_V_BSNH, stream,
+ max_threads_per_block));
+    key = reinterpret_cast<const void*>(data.k);
+    value = reinterpret_cast<const void*>(data.v);
+ }
+
+ MemoryEfficientAttentionParams p;
+ p.sm = device_prop.major * 10 + device_prop.minor;
+ p.is_half = sizeof(T) == 2;
+ p.batch_size = batch_size;
+ p.num_heads = num_heads;
+ p.sequence_length = sequence_length;
+ p.kv_sequence_length = past_sequence_length + kv_sequence_length;
+ p.max_sequence_length = (num_heads == kv_num_heads) ? present_sequence_length : past_sequence_length + kv_sequence_length;
+ p.qk_head_size = head_size;
+ p.v_head_size = head_size;
+ p.causal = parameters.is_unidirectional;
+ p.scale = scale;
+ p.seqlen_k_ptr = nullptr;
+ p.seqstart_q_ptr = nullptr;
+ p.seqstart_k_ptr = nullptr;
+ p.query = query;
+ p.key = key;
+ p.value = value;
+ p.attn_bias = nullptr;
+ p.is_attn_bias_batched = false;
+ p.is_kv_bsnh = past_kv_format == AttentionQkvFormat::Q_K_V_BSNH;
+ p.output = data.output;
+ p.workspace = MemoryEfficientAttentionParams::need_workspace(p.v_head_size, sizeof(T) == sizeof(float))
+ ? data.fmha_buffer
+ : nullptr;
+ p.stream = stream;
+ run_memory_efficient_attention(p);
+
+ DUMP_TENSOR_INIT();
+ DUMP_TENSOR("efficient attention output", data.output, batch_size, sequence_length, num_heads, head_size);
+
+ return Status::OK();
+}
+#endif
+
+////////// API Functions
+
+template <typename T>
+Status QkvToContext(
+    const cudaDeviceProp& device_prop,
+    cublasHandle_t& cublas,
+    Stream* ort_stream,
+    contrib::GroupQueryAttentionParameters& parameters,
+    GroupQueryAttentionData<T>& data) {
+  auto stream = static_cast<cudaStream_t>(ort_stream->GetHandle());
+  const float scale = parameters.scale == 0.0f ? 1.f / sqrt(static_cast<float>(parameters.head_size)) : parameters.scale;
- DUMP_TENSOR_INIT();
- DUMP_TENSOR("flash attention output", data.output, batch_size, sequence_length, num_heads, head_size);
+#if USE_FLASH_ATTENTION
+ if (data.use_flash_attention) {
+ return FlashAttention(device_prop, stream, parameters, data, scale);
+ }
+#endif
- return Status::OK();
+#if USE_MEMORY_EFFICIENT_ATTENTION
+ if (data.use_memory_efficient_attention) {
+ return EfficientAttention(device_prop, stream, parameters, data, scale);
}
#endif
+
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Unfused Group Query Attention not implemented yet.");
}
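The Ungroup/UngroupLarge kernels only replicate KV heads so the memory efficient kernel sees matching Q and KV head counts; the mapping is `in_n = out_n / (num_heads / kv_num_heads)`. A CPU reference of that mapping for the BSNH case, illustrative only and not part of the patch:

```cpp
#include <cassert>
#include <vector>

// Replicates each KV head (num_heads / kv_num_heads) times along the head axis, BSNH layout.
std::vector<float> UngroupBSNH(const std::vector<float>& kv, int batch, int seqlen,
                               int kv_heads, int q_heads, int head_size) {
  assert(q_heads % kv_heads == 0);
  const int ratio = q_heads / kv_heads;
  std::vector<float> out(static_cast<size_t>(batch) * seqlen * q_heads * head_size);
  for (int b = 0; b < batch; ++b)
    for (int s = 0; s < seqlen; ++s)
      for (int out_n = 0; out_n < q_heads; ++out_n) {
        const int in_n = out_n / ratio;  // same mapping as in_n = out_n / q_kv_head_ratio above
        for (int h = 0; h < head_size; ++h) {
          const size_t dst = ((static_cast<size_t>(b) * seqlen + s) * q_heads + out_n) * head_size + h;
          const size_t src = ((static_cast<size_t>(b) * seqlen + s) * kv_heads + in_n) * head_size + h;
          out[dst] = kv[src];
        }
      }
  return out;
}

int main() {
  // 1 batch, 1 token, 2 KV heads expanded to 4 Q heads, head_size 2.
  std::vector<float> kv = {0, 1, 2, 3};  // KV head 0 = {0,1}, KV head 1 = {2,3}
  auto out = UngroupBSNH(kv, 1, 1, 2, 4, 2);
  assert(out[0] == 0 && out[2] == 0 && out[4] == 2 && out[6] == 2);  // heads 0,1 read KV head 0; heads 2,3 read KV head 1
  return 0;
}
```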
diff --git a/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.h b/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.h
index 0bad9eeb61231..8412631078e6a 100644
--- a/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.h
+++ b/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.h
@@ -14,19 +14,28 @@ namespace cuda {
template <typename T>
struct GroupQueryAttentionData {
+ // Input Tensors
const T* query = nullptr;
const T* key = nullptr;
const T* value = nullptr;
const T* past_key = nullptr;
const T* past_value = nullptr;
+ // Flash buffers
T* softmax_lse = nullptr;
T* softmax_lse_accum = nullptr;
T* out_accum = nullptr;
int* seqlens_k = nullptr;
+ // Memory Efficient buffers
+ T* fmha_buffer = nullptr;
+ T* k = nullptr;
+ T* v = nullptr;
+ // Output Tensors
T* output = nullptr;
T* present_key = nullptr;
T* present_value = nullptr;
+ // Kernel Flags
bool use_flash_attention = false;
+ bool use_memory_efficient_attention = false;
};
template <typename T>
diff --git a/onnxruntime/contrib_ops/cuda/bert/packed_attention_impl.cu b/onnxruntime/contrib_ops/cuda/bert/packed_attention_impl.cu
index aba0efdbd7d5f..d7aeef1501cd6 100644
--- a/onnxruntime/contrib_ops/cuda/bert/packed_attention_impl.cu
+++ b/onnxruntime/contrib_ops/cuda/bert/packed_attention_impl.cu
@@ -507,10 +507,12 @@ Status FusedScaledDotProductAttentionCutlass(
MemoryEfficientAttentionParams p;
p.sm = device_prop.major * 10 + device_prop.minor;
p.is_half = sizeof(T) == 2;
+ p.is_kv_bsnh = true;
p.batch_size = parameters.batch_size;
p.num_heads = parameters.num_heads;
p.sequence_length = parameters.sequence_length;
p.kv_sequence_length = parameters.sequence_length;
+ p.max_sequence_length = parameters.sequence_length;
p.qk_head_size = parameters.head_size;
p.v_head_size = parameters.v_head_size;
p.causal = false;
diff --git a/onnxruntime/contrib_ops/cuda/bert/packed_multihead_attention_impl.cu b/onnxruntime/contrib_ops/cuda/bert/packed_multihead_attention_impl.cu
index e09fd9e6b36e5..3fe9dbf8ed34a 100644
--- a/onnxruntime/contrib_ops/cuda/bert/packed_multihead_attention_impl.cu
+++ b/onnxruntime/contrib_ops/cuda/bert/packed_multihead_attention_impl.cu
@@ -688,6 +688,7 @@ Status FusedAttentionCutlass(
p.num_heads = parameters.num_heads;
p.sequence_length = parameters.sequence_length;
p.kv_sequence_length = parameters.sequence_length;
+ p.max_sequence_length = parameters.sequence_length;
p.qk_head_size = parameters.head_size;
p.v_head_size = parameters.v_head_size;
p.causal = false;
@@ -702,6 +703,7 @@ Status FusedAttentionCutlass(
p.attn_bias = data.relative_position_bias;
p.is_attn_bias_batched = !parameters.broadcast_res_pos_bias;
p.output = data.output;
+ p.is_kv_bsnh = true;
p.workspace = MemoryEfficientAttentionParams::need_workspace(v_head_size, sizeof(T) == sizeof(float))
? (data.workspace + (data.no_qkv_workspace ? 0 : (elements_qk + elements_qk + elements_v)))
: nullptr;
diff --git a/onnxruntime/core/graph/contrib_ops/bert_defs.cc b/onnxruntime/core/graph/contrib_ops/bert_defs.cc
index 76c3f8716ff09..5bc18a4e69b47 100644
--- a/onnxruntime/core/graph/contrib_ops/bert_defs.cc
+++ b/onnxruntime/core/graph/contrib_ops/bert_defs.cc
@@ -1051,15 +1051,13 @@ ONNX_MS_OPERATOR_SET_SCHEMA(
"present state key with support for format BSNH or BNSH. When past_key uses same tensor as present_key"
"(k-v buffer), it is of length max_sequence_length... otherwise of length past_sequence_length +"
"kv_sequence_length.",
- "T",
- OpSchema::Optional)
+ "T")
.Output(2,
"present_value",
"present state value with support for format BSNH or BNSH. When past_value uses same tensor as present_value"
"(k-v buffer), it is of length max_sequence_length... otherwise of length past_sequence_length +"
"kv_sequence_length.",
- "T",
- OpSchema::Optional)
+ "T")
.TypeConstraint("T", {"tensor(float16)"}, "Constrain input and output to float tensors.")
.TypeConstraint("M", {"tensor(int32)", "tensor(int64)"}, "Constrain past sequence length to int tensor.")
.TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) {
diff --git a/onnxruntime/test/python/transformers/test_flash_attn.py b/onnxruntime/test/python/transformers/test_flash_attn.py
index 04351cd6e6782..319fed87dc9eb 100644
--- a/onnxruntime/test/python/transformers/test_flash_attn.py
+++ b/onnxruntime/test/python/transformers/test_flash_attn.py
@@ -10,7 +10,10 @@
# license information.
# -------------------------------------------------------------------------
import math
+import os
+import platform
import random
+import unittest
import numpy
import torch
@@ -22,6 +25,8 @@
torch.manual_seed(0)
+pipeline_mode = True # Reduces number of tests so pipeline doesn't time out
+
class Formats:
BSNH = 0
@@ -159,7 +164,7 @@ def create_multihead_attention_graph(config):
return model.SerializeToString()
-def create_group_query_attention_graph_no_past(config, causal=False):
+def create_group_query_attention_graph_no_past(config, causal=False, present_kv_format=Formats.BSNH):
nodes = [
helper.make_node(
"GroupQueryAttention",
@@ -168,11 +173,12 @@ def create_group_query_attention_graph_no_past(config, causal=False):
"key",
"value",
],
- ["output"],
+ ["output", "present_key", "present_value"],
"GroupQueryAttention_0",
num_heads=config.num_heads,
kv_num_heads=config.kv_num_heads,
unidirectional=1 if causal else 0,
+ is_past_bsnh=1 if present_kv_format == Formats.BSNH else 0,
domain="com.microsoft",
),
]
@@ -213,6 +219,26 @@ def create_group_query_attention_graph_no_past(config, causal=False):
TensorProto.FLOAT16,
[config.batch_size, config.sequence_length, config.num_heads * config.head_size],
),
+ helper.make_tensor_value_info(
+ "present_key",
+ TensorProto.FLOAT16,
+ [
+ config.batch_size,
+ config.kv_sequence_length if present_kv_format == Formats.BSNH else config.kv_num_heads,
+ config.kv_num_heads if present_kv_format == Formats.BSNH else config.kv_sequence_length,
+ config.head_size,
+ ],
+ ),
+ helper.make_tensor_value_info(
+ "present_value",
+ TensorProto.FLOAT16,
+ [
+ config.batch_size,
+ config.kv_sequence_length if present_kv_format == Formats.BSNH else config.kv_num_heads,
+ config.kv_num_heads if present_kv_format == Formats.BSNH else config.kv_sequence_length,
+ config.head_size,
+ ],
+ ),
]
graph = helper.make_graph(
@@ -514,7 +540,6 @@ def generate_token_offset(cu_seqlens, max_seqlen):
return numpy.asarray(token_offset + token_padset, dtype=numpy.int32)
-# TODO(aciddelgado): rename
def flash_attn_varlen_qkvpacked_func(qkv_unpad, cu_seqlens, token_offset, config, causal=False):
onnx_model_str = create_packed_multihead_attention_graph(config)
qkv_unpad = torch.swapdims(qkv_unpad, 1, 2)
@@ -548,8 +573,8 @@ def mha_func(q, k, v, config):
return output
-def gqa_no_past_func(q, k, v, config, causal=True):
- onnx_model_str = create_group_query_attention_graph_no_past(config, causal)
+def gqa_no_past_func(q, k, v, config, causal=True, present_kv_format=Formats.BSNH):
+ onnx_model_str = create_group_query_attention_graph_no_past(config, causal, present_kv_format=present_kv_format)
q = torch.reshape(q, (config.batch_size, config.sequence_length, -1))
k = torch.reshape(k, (config.batch_size, config.kv_sequence_length, -1))
v = torch.reshape(v, (config.batch_size, config.kv_sequence_length, -1))
@@ -560,7 +585,7 @@ def gqa_no_past_func(q, k, v, config, causal=True):
}
sess_options = SessionOptions()
ort_session = InferenceSession(onnx_model_str, sess_options, providers=["CUDAExecutionProvider"])
- ort_output = ort_session.run(None, ort_inputs)
+ ort_output, _, _ = ort_session.run(None, ort_inputs)
ort_output = numpy.array(ort_output)
output = torch.tensor(ort_output)
return output
@@ -689,17 +714,12 @@ def attention_ref(
if key_padding_mask is not None:
scores.masked_fill_(rearrange(~key_padding_mask, "b s -> b 1 1 s"), float("-inf"))
if causal:
- # causal_mask = torch.triu(
- # torch.ones(seqlen_q, seqlen_k, dtype=torch.bool, device=q.device), 1
- # )
causal_mask = construct_causal_mask(seqlen_q, seqlen_k, query_padding_mask, key_padding_mask, q.device)
scores.masked_fill_(causal_mask, float("-inf"))
attention = torch.softmax(scores, dim=-1)
if causal: # Some rows are completely masked out so we fill them with zero instead of NaN
attention = attention.masked_fill(torch.all(causal_mask, dim=-1, keepdim=True), 0.0)
dropout_scaling = 1.0 / (1 - dropout_p)
- # attention_drop = attention.masked_fill(~dropout_mask, 0.0) * dropout_scaling
- # output = torch.einsum('bhts,bshd->bthd', attention_drop , v)
if dropout_mask is not None:
attention_drop = attention.masked_fill(~dropout_mask, 0.0)
else:
@@ -1072,12 +1092,6 @@ def parity_check_gqa_past_no_buff(
out = torch.reshape(out, (config.batch_size, config.sequence_length, config.num_heads, config.head_size))
out = out.detach().cpu().numpy()
- # print(present_k[0, 0, config.past_sequence_length, :10])
- # print(k_cache_ref[0, 0, config.past_sequence_length, :10])
- # print(k_cache_ref.shape)
-
- # print(present_k - k_cache_ref.detach().cpu().numpy())
-
# Make sure past-present buffer updating correctly
if past_format == Formats.BSNH:
assert numpy.allclose(
@@ -1141,84 +1155,185 @@ def parity_check_gqa_past_no_buff(
)
+class TestMHA(unittest.TestCase):
+ def test_packed_mha(self):
+ if not torch.cuda.is_available() or platform.system() != "Linux":
+ return
+ major, _ = torch.cuda.get_device_capability()
+ if major < 8:
+ return
+ print("-------- TEST PACKED MHA ---------")
+ batches = [2] if pipeline_mode else [1, 5]
+ seqs = [8, 97, 256, 1024] if pipeline_mode else [97, 128, 200, 256, 257, 384, 512, 768, 1024, 1025, 2048]
+ num_h = [1, 3] if pipeline_mode else [1, 6, 16]
+ h_sizes = [16, 256] if pipeline_mode else [32, 40, 64, 80, 96, 128, 160, 192, 224, 256]
+ for b in batches:
+ for s in seqs:
+ for n in num_h:
+ for h in h_sizes:
+ config = Config(b, s, s, 0, n, n, h)
+ parity_check_mha(config, True)
+
+ def test_mha(self):
+ if not torch.cuda.is_available() or platform.system() != "Linux":
+ return
+ major, _ = torch.cuda.get_device_capability()
+ if major < 8:
+ return
+ print("-------- TEST MHA ---------")
+ batches = [2] if pipeline_mode else [1, 5]
+ seqs = (
+ [(1, 128), (113, 211), (2048, 2048)]
+ if pipeline_mode
+ else [
+ (113, 203),
+ (128, 217),
+ (113, 211),
+ (108, 256),
+ (256, 512),
+ (512, 256),
+ (1024, 1024),
+ (1023, 1024),
+ (1024, 1023),
+ (2048, 2048),
+ ]
+ )
+ num_h = [1, 3] if pipeline_mode else [1, 6, 16]
+ h_sizes = [16, 256] if pipeline_mode else [32, 40, 64, 80, 96, 128, 160, 192, 224, 256]
+ for b in batches:
+ for s, s2 in seqs:
+ for n in num_h:
+ for h in h_sizes:
+ config = Config(b, s, s2, 0, n, n, h)
+ parity_check_mha(config, False)
+
+
+class TestGQA(unittest.TestCase):
+ def test_gqa_no_past(self):
+ if not torch.cuda.is_available():
+ return
+ major, minor = torch.cuda.get_device_capability()
+ torch.manual_seed(69)
+ print("-------- TEST GQA ---------")
+ batches = [2] if pipeline_mode else [1, 5]
+ seqs = (
+ [(1, 128), (113, 211), (2048, 2048)]
+ if pipeline_mode
+ else [
+ (113, 203),
+ (128, 217),
+ (113, 211),
+ (108, 256),
+ (256, 512),
+ (1024, 1024),
+ (1023, 1024),
+ (2048, 2048),
+ ]
+ )
+ num_h = [(9, 3), (4, 4)] if pipeline_mode else [(6, 6), (6, 3), (9, 9), (9, 3)]
+ h_sizes = [16, 256] if pipeline_mode else [32, 40, 64, 80, 96, 128, 160, 192, 224, 256]
+ if major < 5 or (major == 5 and minor < 3):
+ return
+ print("------- MEMORY EFFICIENT ATTENTION ---------")
+ os.environ["ORT_DISABLE_FLASH_ATTENTION"] = "1"
+ for b in batches:
+ for s, s2 in seqs:
+ for n, n2 in num_h:
+ for h in h_sizes:
+ for causal in [True, False]:
+ config = Config(b, s, s2, 0, n, n2, h)
+ parity_check_gqa_no_past(config, causal=causal)
+ if major < 8 or platform.system() != "Linux":
+ return
+ print("------- FLASH ATTENTION --------")
+ os.environ["ORT_DISABLE_FLASH_ATTENTION"] = "0"
+ for b in batches:
+ for s, s2 in seqs:
+ for n, n2 in num_h:
+ for h in h_sizes:
+ for causal in [True, False]:
+ config = Config(b, s, s2, 0, n, n2, h)
+ parity_check_gqa_no_past(config, causal=causal)
+
+ def test_gqa_past(self):
+ if not torch.cuda.is_available():
+ return
+ major, minor = torch.cuda.get_device_capability()
+ if major < 5 or (major == 5 and minor < 3):
+ return
+ os.environ["ORT_DISABLE_FLASH_ATTENTION"] = "1"
+ print("-------- TEST GQA PAST ---------")
+        print("-------- MEMORY EFFICIENT --------")
+ batches = [2] if pipeline_mode else [1, 2]
+ seqs = (
+ [(1, 128), (3, 1024), (64, 2048)]
+ if pipeline_mode
+ else [
+ (1, 128),
+ (1, 339),
+ (3, 1024),
+ (64, 800),
+ (64, 256),
+ (3, 799),
+ (64, 2048),
+ (16, 20000),
+ (1, 128 * 512),
+ (16, 128 * 512),
+ (128, 128),
+ ]
+ )
+ num_h = [(9, 3), (4, 4)] if pipeline_mode else [(6, 6), (6, 3), (9, 9), (9, 3)]
+ h_sizes = [16, 256] if pipeline_mode else [32, 40, 64, 80, 96, 128, 160, 192, 224, 256]
+ random.seed(69)
+ for b in batches:
+ for s, s2 in seqs:
+ for n, n2 in num_h:
+ for h in h_sizes:
+ for causal in [True]:
+ for past_kv_format in [Formats.BNSH, Formats.BSNH]:
+ sp = random.randint(1, s2 - s) if s2 - s > 0 else 0
+ config = Config(b, s, s2, sp, n, n2, h)
+ parity_check_gqa_past(
+ config,
+ causal=causal,
+ past_format=past_kv_format,
+ rtol=1e-3,
+ atol=1e-3,
+ )
+ parity_check_gqa_past_no_buff(
+ config,
+ causal=causal,
+ past_format=past_kv_format,
+ rtol=1e-3,
+ atol=1e-3,
+ )
+ if major < 8 or platform.system() != "Linux":
+ return
+ print("------- FLASH ATTENTION -------")
+ os.environ["ORT_DISABLE_FLASH_ATTENTION"] = "0"
+ for b in batches:
+ for s, s2 in seqs:
+ for n, n2 in num_h:
+ for h in h_sizes:
+ for causal in [True]:
+ for past_kv_format in [Formats.BNSH, Formats.BSNH]:
+ sp = random.randint(1, s2 - s) if s2 - s > 0 else 0
+ config = Config(b, s, s2, sp, n, n2, h)
+ parity_check_gqa_past(
+ config,
+ causal=causal,
+ past_format=past_kv_format,
+ rtol=1e-3,
+ atol=1e-3,
+ )
+ parity_check_gqa_past_no_buff(
+ config,
+ causal=causal,
+ past_format=past_kv_format,
+ rtol=1e-3,
+ atol=1e-3,
+ )
+
+
if __name__ == "__main__":
- print("-------- TEST PACKED MHA ---------")
- for b in [5]:
- for s in [97, 128, 200, 256, 257, 384, 512, 768, 1024, 1025, 2048]:
- for n in [6]:
- for h in [32, 40, 59, 64, 80, 96, 111, 128, 160, 192, 224, 256]:
- config = Config(b, s, s, 0, n, n, h)
- parity_check_mha(config, True)
- print("-------- TEST MHA ---------")
- for b in [5]:
- for s, s2 in [
- (113, 203),
- (128, 217),
- (113, 211),
- (108, 256),
- (256, 512),
- (512, 256),
- (1024, 1024),
- (1023, 1024),
- (1024, 1023),
- (2048, 2048),
- ]:
- for n in [6]:
- for h in [32, 40, 59, 64, 80, 96, 111, 128, 160, 192, 224, 256]:
- config = Config(b, s, s2, 0, n, n, h)
- parity_check_mha(config, False)
- print("-------- TEST GQA ---------")
- for b in [5]:
- for s, s2 in [
- (113, 203),
- (128, 217),
- (113, 211),
- (108, 256),
- (256, 512),
- (512, 256),
- (1024, 1024),
- (1023, 1024),
- (1024, 1023),
- (2048, 2048),
- ]:
- for n, n2 in [(6, 6), (6, 3), (9, 9), (9, 3)]:
- for h in [32, 40, 64, 80, 96, 128, 160, 192, 224, 256]:
- for causal in [True, False]:
- config = Config(b, s, s2, 0, n, n2, h)
- parity_check_gqa_no_past(config, causal=causal)
- print("-------- TEST GQA PAST ---------")
- random.seed(69)
- for b in [2]:
- for s, s2 in [
- (1, 128),
- (1, 339),
- (3, 1024),
- (64, 800),
- (64, 256),
- (3, 799),
- (64, 2048),
- (16, 20000),
- (1, 128 * 512),
- (16, 128 * 512),
- (128, 128),
- ]:
- for n, n2 in [(6, 6), (6, 3), (9, 9), (9, 3)]:
- for h in [32, 40, 64, 80, 96, 128, 160, 192, 224, 256]:
- for causal in [True]:
- for past_kv_format in [Formats.BNSH, Formats.BSNH]:
- sp = random.randint(1, s2 - s) if s2 - s > 0 else 0
- config = Config(b, s, s2, sp, n, n2, h)
- parity_check_gqa_past(
- config,
- causal=causal,
- past_format=past_kv_format,
- rtol=1e-3,
- atol=1e-3,
- )
- parity_check_gqa_past_no_buff(
- config,
- causal=causal,
- past_format=past_kv_format,
- rtol=1e-3,
- atol=1e-3,
- )
+ unittest.main()