Skip to content

Commit

Permalink
DMMHA: add unit tests; fix CPU, CUDA kernel (#22567)
Browse files Browse the repository at this point in the history
### Description

Fixes:
(1) cpu kernel: applying scale before bias and mask like other MHA ops
(2) cpu kernel: correct offset during appending past to present.
(3) cuda kernel: apply mask if provided; fix output_qk offset.

Add DMMHA unit tests
  • Loading branch information
mindest authored Nov 2, 2024
1 parent 2e4e221 commit 4ffc1ff
Show file tree
Hide file tree
Showing 7 changed files with 381 additions and 367 deletions.
2 changes: 1 addition & 1 deletion onnxruntime/contrib_ops/cpu/bert/attention_cpu_base.h
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ class AttentionCPUBase : public AttentionBase {
// Convert mask from boolean (0/1) to float (mask_filter_value/0.0f).
// Merge padding mask with causal mask, and broadcast to 3D (BxSxT).
PrepareMask(mask_index_data, mask_index_dims, static_cast<T*>(mask_data),
causal, batch_size, sequence_length, past_sequence_length, mask_filter_value_);
causal, batch_size, sequence_length, kv_sequence_length, past_sequence_length, mask_filter_value_);
DUMP_CPU_TENSOR("Mask3D", static_cast<T*>(mask_data), batch_size, sequence_length, total_sequence_length);
}

Expand Down
3 changes: 2 additions & 1 deletion onnxruntime/contrib_ops/cpu/bert/attention_helper.h
Original file line number Diff line number Diff line change
Expand Up @@ -120,9 +120,10 @@ void PrepareMask(const int32_t* mask_index,
bool causal,
int batch_size,
int sequence_length,
int kv_sequence_length,
int past_sequence_length,
float mask_filter_value) {
const int all_sequence_length = past_sequence_length + sequence_length;
const int all_sequence_length = past_sequence_length + kv_sequence_length;

// mask_data has been filled with 0, and its shape is BxSxT
T* p_mask = mask_data;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -339,6 +339,7 @@ void DecoderMaskedMultiHeadAttention<T>::ComputeAttentionProbsWithBeams(
T* attention_probs_ptr = reinterpret_cast<T*>(attention_probs) + last_offset;
math::Dot<float, CPUMathUtil>(head_size, q_vec, K + i * head_size, attention_probs_ptr, nullptr);

*attention_probs_ptr *= scale;
// Apply the attention bias and mask
if (attn_bias_data != nullptr) {
*attention_probs_ptr += attn_bias_data[attn_bias_base_offset + past_sequence_length];
Expand All @@ -348,7 +349,6 @@ void DecoderMaskedMultiHeadAttention<T>::ComputeAttentionProbsWithBeams(
if (is_masked) {
*attention_probs_ptr += mask_filter_value_;
}
*attention_probs_ptr *= scale;
}

{
Expand All @@ -362,6 +362,8 @@ void DecoderMaskedMultiHeadAttention<T>::ComputeAttentionProbsWithBeams(
const T* past_k_vec = past_key_data + beam_batch_offset + beam_offset + j * head_size;
T* output = reinterpret_cast<T*>(attention_probs) + j + i * probs_matrix_size;
math::Dot<float, CPUMathUtil>(head_size, q_vec, past_k_vec, output, nullptr);

*output *= scale;
// Apply the attention bias and mask
if (attn_bias_data != nullptr) {
*output += attn_bias_data[attn_bias_base_offset + j];
Expand All @@ -371,11 +373,11 @@ void DecoderMaskedMultiHeadAttention<T>::ComputeAttentionProbsWithBeams(
if (is_masked) {
*output += mask_filter_value_;
}
*output *= scale;
}
}
// Append current key to present key (past_present_share_buffer_ is true)
memcpy(present_key_data + i * max_sequence_length * head_size, K + i * head_size, head_size * sizeof(T));
memcpy(present_key_data + (i * max_sequence_length + past_sequence_length) * head_size,
K + i * head_size, head_size * sizeof(T));
}
});

Expand Down Expand Up @@ -460,7 +462,7 @@ void DecoderMaskedMultiHeadAttention<T>::ComputeVxAttentionScoreWithBeams(
}
}
// Append current value to present value (past_present_share_buffer_ is true)
memcpy(present_value_data + i * max_sequence_length * v_head_size,
memcpy(present_value_data + (i * max_sequence_length + past_sequence_length) * v_head_size,
V + i * v_head_size,
v_head_size * sizeof(T));
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ class DecoderMaskedMultiHeadAttention final : public OpKernel, public AttentionC
const Tensor* cache_indir,
OpKernelContext* context,
int beam_width,
Tensor* scaled_qk = nullptr) const;
Tensor* output_qk = nullptr) const;
void ComputeAttentionProbsWithBeams(T* attention_probs,
const T* Q,
const T* K,
Expand All @@ -50,7 +50,7 @@ class DecoderMaskedMultiHeadAttention final : public OpKernel, public AttentionC
bool broadcast_attn_bias_dim_1,
const int32_t* cache_indir_data,
int beam_width,
T* scaled_qk_data = nullptr) const;
T* output_qk_data = nullptr) const;
void ComputeVxAttentionScoreWithBeams(T* output,
T* tmp_buffer,
const T* attention_probs,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -298,6 +298,9 @@ __global__ void masked_multihead_attention_kernel(DecoderMaskedMultiHeadAttentio
if (params.attention_bias != nullptr) {
qk = add_vec(qk, reinterpret_cast<T*>(params.attention_bias)[attn_bias_offset + tlength]);
}
if (params.mask != nullptr && params.mask[bi_total_seq_length + params.past_sequence_length] == 0) {
qk += params.mask_filter_value;
}
qk_max = qk;
qk_smem[tlength] = qk;
}
Expand Down Expand Up @@ -534,7 +537,7 @@ __global__ void masked_multihead_attention_kernel(DecoderMaskedMultiHeadAttentio

if (params.out_qk != nullptr) {
// store cross qk before softmax, out_qk has shape [B(batchxbeam), #Head, 1, total_sequence_length]
float* target = ((float*)params.out_qk) + ((int64_t)bhi * tlength);
float* target = (reinterpret_cast<float*>(params.out_qk)) + (static_cast<int64_t>(bhi) * (sum_tlength + 1));
for (int ti = tidx; ti <= sum_tlength; ti += THREADS_PER_BLOCK) {
target[ti] = (float)(qk_smem[ti]);
}
Expand Down
1 change: 0 additions & 1 deletion onnxruntime/core/graph/contrib_ops/bert_defs.cc
Original file line number Diff line number Diff line change
Expand Up @@ -908,7 +908,6 @@ ONNX_MS_OPERATOR_SET_SCHEMA(
OpSchema::Optional)
.Input(9,
"cache_indirection",
// This input is useful for CUDA EP only.
"A buffer of shape [batch_size, beam_width, max_output_length] where an `[i, j, k]` entry specifies "
"which beam the `k`-th token came from for the `j`-th beam for batch `i` in the current iteration",
"M",
Expand Down
Loading

0 comments on commit 4ffc1ff

Please sign in to comment.