Commit
clarify input and output formats for memory efficient attention
aciddelgado committed Oct 27, 2023
1 parent 6c6aead commit db307f3
Showing 6 changed files with 9 additions and 8 deletions.
2 changes: 1 addition & 1 deletion onnxruntime/contrib_ops/cuda/bert/attention_impl.cu
@@ -393,7 +393,7 @@ Status EfficientAttention(
   p.attn_bias = nullptr == data.relative_position_bias ? nullptr : data.relative_position_bias;
   p.is_attn_bias_batched = !parameters.broadcast_res_pos_bias;
   p.output = data.output;
-  p.is_bsnh = true;
+  p.is_kv_bsnh = true;
   p.workspace = MemoryEfficientAttentionParams::need_workspace(parameters.v_head_size, sizeof(T) == sizeof(float))
                     ? data.scratch
                     : nullptr;
@@ -54,8 +54,8 @@ void LaunchCutlassFmha(const MemoryEfficientAttentionParams& params) {
     p.custom_mask_type = Attention::CausalFromBottomRight;
   }

-  if (params.is_bsnh) {
-    // Input format is BxSxNxH, output is BxSxNxH
+  if (params.is_kv_bsnh) {
+    // Input Q, K, V format is BxSxNxH, output is BxSxNxH
     p.q_strideH = params.qk_head_size;
     p.k_strideH = params.qk_head_size;
     p.v_strideH = params.v_head_size;
@@ -72,7 +72,7 @@ void LaunchCutlassFmha(const MemoryEfficientAttentionParams& params) {
     p.v_strideB = static_cast<int64_t>(p.v_strideM) * params.kv_sequence_length;
     p.bias_strideB = params.is_attn_bias_batched ? static_cast<int64_t>(p.bias_strideH) * params.num_heads : 0;
   } else {
-    // Input format is BxNxSxH, output is BxNxSxH
+    // Input K, V format is BxNxSxH, Input Q is BxSxNxH, output is BxSxNxH
     p.q_strideH = params.qk_head_size;
     p.k_strideH = params.kv_sequence_length * params.qk_head_size;
     p.v_strideH = params.kv_sequence_length * params.v_head_size;
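As a sanity check on the stride arithmetic above, here is a minimal standalone sketch (not part of this commit; the helper, names, and example shapes are illustrative) of how the per-head, per-token, and per-batch strides of a contiguous K or V tensor fall out of the two layouts:

#include <cstdint>
#include <cstdio>

// Strides (in elements) of a contiguous K or V tensor under the two layouts
// handled above. B = batch, S = kv_sequence_length, N = num_heads, H = head size.
struct KvStrides {
  int64_t strideH;  // distance between consecutive heads
  int64_t strideM;  // distance between consecutive tokens
  int64_t strideB;  // distance between consecutive batch items
};

static KvStrides ComputeKvStrides(bool is_kv_bsnh, int64_t S, int64_t N, int64_t H) {
  if (is_kv_bsnh) {
    // BxSxNxH: heads of one token are adjacent, so one token spans N * H elements.
    return {H, N * H, S * N * H};
  }
  // BxNxSxH: tokens of one head are adjacent, so one head spans S * H elements.
  return {S * H, H, N * S * H};
}

int main() {
  const KvStrides bsnh = ComputeKvStrides(true, /*S=*/8, /*N=*/4, /*H=*/64);
  const KvStrides bnsh = ComputeKvStrides(false, /*S=*/8, /*N=*/4, /*H=*/64);
  std::printf("BSNH strideH=%lld strideM=%lld strideB=%lld\n",
              static_cast<long long>(bsnh.strideH),
              static_cast<long long>(bsnh.strideM),
              static_cast<long long>(bsnh.strideB));
  std::printf("BNSH strideH=%lld strideM=%lld strideB=%lld\n",
              static_cast<long long>(bnsh.strideH),
              static_cast<long long>(bnsh.strideM),
              static_cast<long long>(bnsh.strideB));
  return 0;
}

This matches the diff: in the BSNH branch k_strideH is simply qk_head_size because the heads of one token sit next to each other, while in the BNSH branch it becomes kv_sequence_length * qk_head_size because an entire sequence separates consecutive heads; Q keeps BxSxNxH strides in both branches.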
@@ -14,11 +14,12 @@ namespace cuda {
 struct MemoryEfficientAttentionParams {
   int32_t sm;
   bool is_half;
-  bool is_bsnh = true;
+  bool is_kv_bsnh = true;
   int32_t batch_size;
   int32_t num_heads;
   int32_t sequence_length;
   int32_t kv_sequence_length;
+  int32_t max_sequence_length;
   int32_t qk_head_size;
   int32_t v_head_size;
   bool causal;
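To make the renamed flag and the new max_sequence_length field concrete, a hedged usage sketch follows (the helper, its argument names, and the decode scenario are hypothetical, not from the commit; only fields visible in this diff are set, and pointer fields such as output and workspace are omitted):

// Hypothetical helper: fill the format-related fields for a decode step
// whose KV cache is laid out as BxNxSxH.
MemoryEfficientAttentionParams MakeDecodeParams(int32_t sm, bool is_half,
                                                int32_t batch_size, int32_t num_heads,
                                                int32_t past_length, int32_t cache_capacity,
                                                int32_t head_size) {
  MemoryEfficientAttentionParams p;
  p.sm = sm;
  p.is_half = is_half;
  p.is_kv_bsnh = false;                    // K/V come from a BxNxSxH cache; Q stays BxSxNxH
  p.batch_size = batch_size;
  p.num_heads = num_heads;
  p.sequence_length = 1;                   // one new query token per decode step
  p.kv_sequence_length = past_length + 1;  // tokens actually attended to
  p.max_sequence_length = cache_capacity;  // allocated KV-cache length
  p.qk_head_size = head_size;
  p.v_head_size = head_size;
  p.causal = true;
  return p;
}

The likely value of keeping kv_sequence_length and max_sequence_length separate is an over-allocated KV cache: a kernel can stride through a buffer sized for max_sequence_length while attending only over the kv_sequence_length tokens that are valid so far.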
@@ -570,7 +570,7 @@ Status EfficientAttention(
   p.value = value;
   p.attn_bias = nullptr;
   p.is_attn_bias_batched = false;
-  p.is_bsnh = past_kv_format == AttentionQkvFormat::Q_K_V_BSNH;
+  p.is_kv_bsnh = past_kv_format == AttentionQkvFormat::Q_K_V_BSNH;
   p.output = data.output;
   p.workspace = MemoryEfficientAttentionParams::need_workspace(p.v_head_size, sizeof(T) == sizeof(float))
                     ? data.fmha_buffer
2 changes: 1 addition & 1 deletion onnxruntime/contrib_ops/cuda/bert/packed_attention_impl.cu
@@ -507,7 +507,7 @@ Status FusedScaledDotProductAttentionCutlass(
   MemoryEfficientAttentionParams p;
   p.sm = device_prop.major * 10 + device_prop.minor;
   p.is_half = sizeof(T) == 2;
-  p.is_bsnh = true;
+  p.is_kv_bsnh = true;
   p.batch_size = parameters.batch_size;
   p.num_heads = parameters.num_heads;
   p.sequence_length = parameters.sequence_length;
@@ -702,7 +702,7 @@ Status FusedAttentionCutlass(
   p.attn_bias = data.relative_position_bias;
   p.is_attn_bias_batched = !parameters.broadcast_res_pos_bias;
   p.output = data.output;
-  p.is_bsnh = true;
+  p.is_kv_bsnh = true;
   p.workspace = MemoryEfficientAttentionParams::need_workspace(v_head_size, sizeof(T) == sizeof(float))
                     ? (data.workspace + (data.no_qkv_workspace ? 0 : (elements_qk + elements_qk + elements_v)))
                     : nullptr;
