
DMMHA: add unit tests; fix CPU, CUDA kernel #22567

Merged · 16 commits · Nov 2, 2024
Changes from 13 commits
2 changes: 1 addition & 1 deletion onnxruntime/contrib_ops/cpu/bert/attention_cpu_base.h
@@ -77,7 +77,7 @@ class AttentionCPUBase : public AttentionBase {
   // Convert mask from boolean (0/1) to float (mask_filter_value/0.0f).
   // Merge padding mask with causal mask, and broadcast to 3D (BxSxT).
   PrepareMask(mask_index_data, mask_index_dims, static_cast<T*>(mask_data),
-              causal, batch_size, sequence_length, past_sequence_length, mask_filter_value_);
+              causal, batch_size, sequence_length, kv_sequence_length, past_sequence_length, mask_filter_value_);
   DUMP_CPU_TENSOR("Mask3D", static_cast<T*>(mask_data), batch_size, sequence_length, total_sequence_length);
 }

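For context on the call above: PrepareMask converts a boolean padding mask into an additive float mask and broadcasts it to BxSxT. A minimal sketch of that conversion, with assumed shapes and a hypothetical helper name (the real logic lives in attention_helper.h):

```cpp
#include <cstddef>
#include <cstdint>
#include <vector>

// Sketch: a BxT boolean padding mask becomes a BxSxT additive float mask.
// A 1 bit keeps a position (adds 0.0f to the logit); a 0 bit hides it by
// adding mask_filter_value, a large negative number.
std::vector<float> BroadcastPaddingMask(const std::vector<int32_t>& mask_2d,  // B x T
                                        int batch_size, int sequence_length,
                                        int total_sequence_length,
                                        float mask_filter_value) {
  std::vector<float> mask_3d(static_cast<size_t>(batch_size) * sequence_length * total_sequence_length);
  for (int b = 0; b < batch_size; ++b) {
    for (int s = 0; s < sequence_length; ++s) {
      for (int t = 0; t < total_sequence_length; ++t) {
        const bool keep = mask_2d[static_cast<size_t>(b) * total_sequence_length + t] == 1;
        mask_3d[(static_cast<size_t>(b) * sequence_length + s) * total_sequence_length + t] =
            keep ? 0.0f : mask_filter_value;
      }
    }
  }
  return mask_3d;
}
```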
3 changes: 2 additions & 1 deletion onnxruntime/contrib_ops/cpu/bert/attention_helper.h
@@ -120,9 +120,10 @@
     bool causal,
     int batch_size,
     int sequence_length,
+    int kv_sequence_length,
     int past_sequence_length,
     float mask_filter_value) {
-  const int all_sequence_length = past_sequence_length + sequence_length;
+  const int all_sequence_length = past_sequence_length + kv_sequence_length;

   // mask_data has been filled with 0, and its shape is BxSxT
   T* p_mask = mask_data;
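The new kv_sequence_length parameter matters whenever the key/value length differs from the query length: the mask's last dimension must span past_sequence_length + kv_sequence_length, which the old code conflated with past_sequence_length + sequence_length. A simplified sketch of a causal fill over the corrected total length (assumed layout; not the real PrepareMask):

```cpp
#include <cstddef>

// Sketch: query position s may attend to absolute key positions
// 0 .. past_sequence_length + s; anything later gets mask_filter_value added.
// Note the total length uses kv_sequence_length, matching the fix above.
void ApplyCausalMask(float* mask_3d,  // B x S x T, already holds the padding mask
                     int batch_size, int sequence_length,
                     int kv_sequence_length, int past_sequence_length,
                     float mask_filter_value) {
  const int total_sequence_length = past_sequence_length + kv_sequence_length;
  for (int b = 0; b < batch_size; ++b) {
    for (int s = 0; s < sequence_length; ++s) {
      float* row = mask_3d + (static_cast<size_t>(b) * sequence_length + s) * total_sequence_length;
      for (int t = past_sequence_length + s + 1; t < total_sequence_length; ++t) {
        row[t] += mask_filter_value;  // hide future positions
      }
    }
  }
}
```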
@@ -339,6 +339,7 @@ void DecoderMaskedMultiHeadAttention<T>::ComputeAttentionProbsWithBeams(
         T* attention_probs_ptr = reinterpret_cast<T*>(attention_probs) + last_offset;
         math::Dot<float, CPUMathUtil>(head_size, q_vec, K + i * head_size, attention_probs_ptr, nullptr);

+        *attention_probs_ptr *= scale;
         // Apply the attention bias and mask
         if (attn_bias_data != nullptr) {
           *attention_probs_ptr += attn_bias_data[attn_bias_base_offset + past_sequence_length];
@@ -348,7 +349,6 @@
         if (is_masked) {
           *attention_probs_ptr += mask_filter_value_;
         }
-        *attention_probs_ptr *= scale;
       }

       {
@@ -362,6 +362,8 @@ void DecoderMaskedMultiHeadAttention<T>::ComputeAttentionProbsWithBeams(
           const T* past_k_vec = past_key_data + beam_batch_offset + beam_offset + j * head_size;
           T* output = reinterpret_cast<T*>(attention_probs) + j + i * probs_matrix_size;
           math::Dot<float, CPUMathUtil>(head_size, q_vec, past_k_vec, output, nullptr);

+          *output *= scale;
           // Apply the attention bias and mask
           if (attn_bias_data != nullptr) {
             *output += attn_bias_data[attn_bias_base_offset + j];
@@ -371,11 +373,11 @@ void DecoderMaskedMultiHeadAttention<T>::ComputeAttentionProbsWithBeams(
           if (is_masked) {
             *output += mask_filter_value_;
           }
-          *output *= scale;
         }
       }
       // Append current key to present key (past_present_share_buffer_ is true)
-      memcpy(present_key_data + i * max_sequence_length * head_size, K + i * head_size, head_size * sizeof(T));
+      memcpy(present_key_data + (i * max_sequence_length + past_sequence_length) * head_size,
+             K + i * head_size, head_size * sizeof(T));
     }
   });
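The reordering above applies the scale to the raw QK dot product before the attention bias and mask are added; previously the scale was applied last, which wrongly scaled the bias and mask contributions as well. A standalone sketch of the corrected per-logit computation:

```cpp
// Sketch (simplified; not the kernel itself): softmax(scale * QK^T + bias + mask)
// requires scaling only the dot product, then adding bias and mask unscaled.
float AttentionLogit(float qk_dot, float scale, float attn_bias,
                     bool is_masked, float mask_filter_value) {
  float logit = qk_dot * scale;  // scale the raw dot product first
  logit += attn_bias;            // attention bias is added unscaled
  if (is_masked) {
    logit += mask_filter_value;  // mask penalty is added unscaled
  }
  return logit;
}
```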

@@ -460,7 +462,7 @@ void DecoderMaskedMultiHeadAttention<T>::ComputeVxAttentionScoreWithBeams(
         }
       }
       // Append current value to present value (past_present_share_buffer_ is true)
-      memcpy(present_value_data + i * max_sequence_length * v_head_size,
+      memcpy(present_value_data + (i * max_sequence_length + past_sequence_length) * v_head_size,
              V + i * v_head_size,
              v_head_size * sizeof(T));
     }
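Both memcpy fixes correct the same indexing bug: with past_present_share_buffer_ true, each head owns max_sequence_length slots in the present buffer, and the token for the current step belongs at slot past_sequence_length rather than overwriting slot 0. A hypothetical helper showing the corrected destination offset:

```cpp
#include <cstddef>
#include <cstring>

// Sketch: append the current token's key (or value) vector for head i into a
// present buffer laid out as [num_heads, max_sequence_length, head_size].
template <typename T>
void AppendToPresent(T* present, const T* current, int i,
                     int max_sequence_length, int past_sequence_length, int head_size) {
  T* dst = present +
           (static_cast<size_t>(i) * max_sequence_length + past_sequence_length) * head_size;
  std::memcpy(dst, current + static_cast<size_t>(i) * head_size, head_size * sizeof(T));
}
```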
@@ -33,7 +33,7 @@ class DecoderMaskedMultiHeadAttention final : public OpKernel, public AttentionC
                                       const Tensor* cache_indir,
                                       OpKernelContext* context,
                                       int beam_width,
-                                      Tensor* scaled_qk = nullptr) const;
+                                      Tensor* output_qk = nullptr) const;
   void ComputeAttentionProbsWithBeams(T* attention_probs,
                                       const T* Q,
                                       const T* K,
@@ -50,7 +50,7 @@ class DecoderMaskedMultiHeadAttention final : public OpKernel, public AttentionC
                                       bool broadcast_attn_bias_dim_1,
                                       const int32_t* cache_indir_data,
                                       int beam_width,
-                                      T* scaled_qk_data = nullptr) const;
+                                      T* output_qk_data = nullptr) const;
   void ComputeVxAttentionScoreWithBeams(T* output,
                                         T* tmp_buffer,
                                         const T* attention_probs,
onnxruntime/contrib_ops/cuda/bert/fastertransformer_decoder_attention/decoder_masked_multihead_attention_impl.cu
@@ -298,6 +298,9 @@
     if (params.attention_bias != nullptr) {
       qk = add_vec(qk, reinterpret_cast<T*>(params.attention_bias)[attn_bias_offset + tlength]);
     }
+    if (params.mask != nullptr && params.mask[bi_total_seq_length + params.past_sequence_length] == 0) {
+      qk += params.mask_filter_value;
+    }
     qk_max = qk;
     qk_smem[tlength] = qk;
   }
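The added branch applies the 2D mask to the current token's logit, matching what the kernel already does for earlier positions: a 0 entry pushes the logit far negative before the softmax. A host-side sketch of the same check (parameter names follow the kernel; the mask element type is assumed int32_t):

```cpp
#include <cstdint>

// Sketch: mask holds one row of total_sequence_length entries per batch*beam
// item; the current token sits at offset past_sequence_length in that row.
float ApplyMaskToCurrentToken(float qk, const int32_t* mask,
                              int bi_total_seq_length, int past_sequence_length,
                              float mask_filter_value) {
  if (mask != nullptr && mask[bi_total_seq_length + past_sequence_length] == 0) {
    qk += mask_filter_value;  // large negative value -> ~0 probability after softmax
  }
  return qk;
}
```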
@@ -534,7 +537,7 @@

   if (params.out_qk != nullptr) {
     // store cross qk before softmax, out_qk has shape [B(batchxbeam), #Head, 1, total_sequence_length]
-    float* target = ((float*)params.out_qk) + ((int64_t)bhi * tlength);
+    float* target = ((float*)params.out_qk) + ((int64_t)bhi * (sum_tlength + 1));
     for (int ti = tidx; ti <= sum_tlength; ti += THREADS_PER_BLOCK) {
       target[ti] = (float)(qk_smem[ti]);
     }
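The stride fix follows from the shape in the comment: out_qk is [batch*beam, #Head, 1, total_sequence_length], so each (batch*beam, head) row holds sum_tlength + 1 floats (every past token plus the current one), and consecutive rows are that far apart rather than tlength. A sketch of the assumed row addressing:

```cpp
#include <cstdint>

// Sketch: bhi indexes the flattened (batch*beam, head) pair; each row of
// out_qk stores sum_tlength + 1 pre-softmax logits, including the current token's.
float* OutQkRow(float* out_qk, int bhi, int sum_tlength) {
  return out_qk + static_cast<int64_t>(bhi) * (sum_tlength + 1);
}
```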
1 change: 0 additions & 1 deletion onnxruntime/core/graph/contrib_ops/bert_defs.cc
@@ -908,7 +908,6 @@ ONNX_MS_OPERATOR_SET_SCHEMA(
                OpSchema::Optional)
         .Input(9,
                "cache_indirection",
-               // This input is useful for CUDA EP only.
                "A buffer of shape [batch_size, beam_width, max_output_length] where an `[i, j, k]` entry specifies "
                "which beam the `k`-th token came from for the `j`-th beam for batch `i` in the current iteration",
                "M",