From db307f363217b5d4678ee0e9ae7a13b8880a563c Mon Sep 17 00:00:00 2001
From: aciddelgado
Date: Fri, 27 Oct 2023 08:47:53 -0700
Subject: [PATCH] clarify input and output formats memory efficient attention

---
 onnxruntime/contrib_ops/cuda/bert/attention_impl.cu        | 2 +-
 .../cuda/bert/cutlass_fmha/fmha_launch_template.h          | 6 +++---
 .../cuda/bert/cutlass_fmha/memory_efficient_attention.h    | 3 ++-
 .../contrib_ops/cuda/bert/group_query_attention_impl.cu    | 2 +-
 onnxruntime/contrib_ops/cuda/bert/packed_attention_impl.cu | 2 +-
 .../cuda/bert/packed_multihead_attention_impl.cu           | 2 +-
 6 files changed, 9 insertions(+), 8 deletions(-)

diff --git a/onnxruntime/contrib_ops/cuda/bert/attention_impl.cu b/onnxruntime/contrib_ops/cuda/bert/attention_impl.cu
index 387b289acb758..2c2bd97263f0b 100644
--- a/onnxruntime/contrib_ops/cuda/bert/attention_impl.cu
+++ b/onnxruntime/contrib_ops/cuda/bert/attention_impl.cu
@@ -393,7 +393,7 @@ Status EfficientAttention(
   p.attn_bias = nullptr == data.relative_position_bias ? nullptr : data.relative_position_bias;
   p.is_attn_bias_batched = !parameters.broadcast_res_pos_bias;
   p.output = data.output;
-  p.is_bsnh = true;
+  p.is_kv_bsnh = true;
   p.workspace = MemoryEfficientAttentionParams::need_workspace(parameters.v_head_size, sizeof(T) == sizeof(float))
                     ? data.scratch
                     : nullptr;
diff --git a/onnxruntime/contrib_ops/cuda/bert/cutlass_fmha/fmha_launch_template.h b/onnxruntime/contrib_ops/cuda/bert/cutlass_fmha/fmha_launch_template.h
index 4e8facb37233e..84d979f6f2569 100644
--- a/onnxruntime/contrib_ops/cuda/bert/cutlass_fmha/fmha_launch_template.h
+++ b/onnxruntime/contrib_ops/cuda/bert/cutlass_fmha/fmha_launch_template.h
@@ -54,8 +54,8 @@ void LaunchCutlassFmha(const MemoryEfficientAttentionParams& params) {
     p.custom_mask_type = Attention::CausalFromBottomRight;
   }
 
-  if (params.is_bsnh) {
-    // Input format is BxSxNxH, output is BxSxNxH
+  if (params.is_kv_bsnh) {
+    // Input Q, K, V format is BxSxNxH, output is BxSxNxH
     p.q_strideH = params.qk_head_size;
     p.k_strideH = params.qk_head_size;
     p.v_strideH = params.v_head_size;
@@ -72,7 +72,7 @@ void LaunchCutlassFmha(const MemoryEfficientAttentionParams& params) {
     p.v_strideB = static_cast<int64_t>(p.v_strideM) * params.kv_sequence_length;
     p.bias_strideB = params.is_attn_bias_batched ? static_cast<int64_t>(p.bias_strideH) * params.num_heads : 0;
   } else {
-    // Input format is BxNxSxH, output is BxNxSxH
+    // Input K, V format is BxNxSxH, Input Q is BxSxNxH, output is BxSxNxH
     p.q_strideH = params.qk_head_size;
     p.k_strideH = params.kv_sequence_length * params.qk_head_size;
     p.v_strideH = params.kv_sequence_length * params.v_head_size;
diff --git a/onnxruntime/contrib_ops/cuda/bert/cutlass_fmha/memory_efficient_attention.h b/onnxruntime/contrib_ops/cuda/bert/cutlass_fmha/memory_efficient_attention.h
index eab1b0047cc8a..f16567bb6f2b7 100644
--- a/onnxruntime/contrib_ops/cuda/bert/cutlass_fmha/memory_efficient_attention.h
+++ b/onnxruntime/contrib_ops/cuda/bert/cutlass_fmha/memory_efficient_attention.h
@@ -14,11 +14,12 @@ namespace cuda {
 struct MemoryEfficientAttentionParams {
   int32_t sm;
   bool is_half;
-  bool is_bsnh = true;
+  bool is_kv_bsnh = true;
   int32_t batch_size;
   int32_t num_heads;
   int32_t sequence_length;
   int32_t kv_sequence_length;
+  int32_t max_sequence_length;
   int32_t qk_head_size;
   int32_t v_head_size;
   bool causal;
diff --git a/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.cu b/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.cu
index 57838edbaa7a4..8fa6a445de044 100644
--- a/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.cu
+++ b/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.cu
@@ -570,7 +570,7 @@ Status EfficientAttention(
   p.value = value;
   p.attn_bias = nullptr;
   p.is_attn_bias_batched = false;
-  p.is_bsnh = past_kv_format == AttentionQkvFormat::Q_K_V_BSNH;
+  p.is_kv_bsnh = past_kv_format == AttentionQkvFormat::Q_K_V_BSNH;
   p.output = data.output;
   p.workspace = MemoryEfficientAttentionParams::need_workspace(p.v_head_size, sizeof(T) == sizeof(float))
                     ? data.fmha_buffer
diff --git a/onnxruntime/contrib_ops/cuda/bert/packed_attention_impl.cu b/onnxruntime/contrib_ops/cuda/bert/packed_attention_impl.cu
index c27ba47afed5e..a0532ddd77cc6 100644
--- a/onnxruntime/contrib_ops/cuda/bert/packed_attention_impl.cu
+++ b/onnxruntime/contrib_ops/cuda/bert/packed_attention_impl.cu
@@ -507,7 +507,7 @@ Status FusedScaledDotProductAttentionCutlass(
   MemoryEfficientAttentionParams p;
   p.sm = device_prop.major * 10 + device_prop.minor;
   p.is_half = sizeof(T) == 2;
-  p.is_bsnh = true;
+  p.is_kv_bsnh = true;
   p.batch_size = parameters.batch_size;
   p.num_heads = parameters.num_heads;
   p.sequence_length = parameters.sequence_length;
diff --git a/onnxruntime/contrib_ops/cuda/bert/packed_multihead_attention_impl.cu b/onnxruntime/contrib_ops/cuda/bert/packed_multihead_attention_impl.cu
index cfbdf0959cdf5..12fab94b4bbba 100644
--- a/onnxruntime/contrib_ops/cuda/bert/packed_multihead_attention_impl.cu
+++ b/onnxruntime/contrib_ops/cuda/bert/packed_multihead_attention_impl.cu
@@ -702,7 +702,7 @@ Status FusedAttentionCutlass(
   p.attn_bias = data.relative_position_bias;
   p.is_attn_bias_batched = !parameters.broadcast_res_pos_bias;
   p.output = data.output;
-  p.is_bsnh = true;
+  p.is_kv_bsnh = true;
   p.workspace = MemoryEfficientAttentionParams::need_workspace(v_head_size, sizeof(T) == sizeof(float))
                     ? (data.workspace + (data.no_qkv_workspace ? 0 : (elements_qk + elements_qk + elements_v)))
                     : nullptr;
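
Reviewer note (not part of the patch): the rename makes explicit that the flag only describes the K and V layout; Q stays BxSxNxH in both branches, as the updated comments state. For reference, below is a minimal standalone C++ sketch of the stride arithmetic the two branches of LaunchCutlassFmha select between. The names KvStrides and ComputeKvStrides and the main driver are hypothetical, not ONNX Runtime API; the per-token (strideM) formulas are inferred from the layouts, since the hunks above only show the per-head strides.

#include <cstdint>
#include <cstdio>

struct KvStrides {
  int64_t strideH;  // elements between consecutive heads of K (or V)
  int64_t strideM;  // elements between consecutive tokens of K (or V)
};

// Hypothetical helper mirroring the is_kv_bsnh branch in
// fmha_launch_template.h for the K tensor; V is analogous with v_head_size.
KvStrides ComputeKvStrides(bool is_kv_bsnh, int64_t kv_sequence_length,
                           int64_t num_heads, int64_t qk_head_size) {
  if (is_kv_bsnh) {
    // BxSxNxH: the heads of one token are adjacent, so one head step is H
    // elements and one token step is N*H elements.
    return {qk_head_size, num_heads * qk_head_size};
  }
  // BxNxSxH: the tokens of one head are adjacent, so one head step is S*H
  // elements and one token step is H elements.
  return {kv_sequence_length * qk_head_size, qk_head_size};
}

int main() {
  // Example: kv_sequence_length = 4, num_heads = 2, qk_head_size = 8.
  const KvStrides bsnh = ComputeKvStrides(true, 4, 2, 8);
  const KvStrides bnsh = ComputeKvStrides(false, 4, 2, 8);
  std::printf("BSNH: strideH=%lld strideM=%lld\n",
              static_cast<long long>(bsnh.strideH),
              static_cast<long long>(bsnh.strideM));
  std::printf("BNSH: strideH=%lld strideM=%lld\n",
              static_cast<long long>(bnsh.strideH),
              static_cast<long long>(bnsh.strideM));
  return 0;
}

With those example sizes this prints strideH=8 strideM=16 for BSNH and strideH=32 strideM=8 for BNSH, which is why the same kernel can consume either K/V layout through one pointer-plus-strides interface and only the flag needs to change.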