Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

DMMHA: add unit tests; fix CPU, CUDA kernel #22567

Merged
merged 16 commits into from
Nov 2, 2024
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docs/ContribOperators.md
Original file line number Diff line number Diff line change
Expand Up @@ -1186,7 +1186,7 @@ This version of the operator has been available since version 1 of the 'com.micr

<dl>
<dt><tt>output</tt> : T</dt>
<dd>3D output tensor with shape (batch_size, sequence_length, v_hidden_size)</dd>
<dd>3D output tensor with shape (batch_size, num_heads, v_hidden_size)</dd>
mindest marked this conversation as resolved.
Show resolved Hide resolved
<dt><tt>present_key</tt> (optional) : T</dt>
<dd>present state for key with shape (batch_size, num_heads, total_sequence_length, head_size). If past_present_share_buffer is set, its shape is (batch_size, num_heads, max_sequence_length, head_size), while effective_seq_length = (past_sequence_length + kv_sequence_length).</dd>
<dt><tt>present_value</tt> (optional) : T</dt>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -375,7 +375,8 @@ void DecoderMaskedMultiHeadAttention<T>::ComputeAttentionProbsWithBeams(
}
}
// Append current key to present key (past_present_share_buffer_ is true)
memcpy(present_key_data + i * max_sequence_length * head_size, K + i * head_size, head_size * sizeof(T));
memcpy(present_key_data + (i * max_sequence_length + past_sequence_length) * head_size,
K + i * head_size, head_size * sizeof(T));
}
});

Expand Down Expand Up @@ -460,7 +461,7 @@ void DecoderMaskedMultiHeadAttention<T>::ComputeVxAttentionScoreWithBeams(
}
}
// Append current value to present value (past_present_share_buffer_ is true)
memcpy(present_value_data + i * max_sequence_length * v_head_size,
memcpy(present_value_data + (i * max_sequence_length + past_sequence_length) * v_head_size,
V + i * v_head_size,
v_head_size * sizeof(T));
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -534,7 +534,7 @@ __global__ void masked_multihead_attention_kernel(DecoderMaskedMultiHeadAttentio

if (params.out_qk != nullptr) {
// store cross qk before softmax, out_qk has shape [B(batchxbeam), #Head, 1, total_sequence_length]
float* target = ((float*)params.out_qk) + ((int64_t)bhi * tlength);
float* target = ((float*)params.out_qk) + ((int64_t)bhi * (tlength + 1));
for (int ti = tidx; ti <= sum_tlength; ti += THREADS_PER_BLOCK) {
target[ti] = (float)(qk_smem[ti]);
}
Expand Down
3 changes: 1 addition & 2 deletions onnxruntime/core/graph/contrib_ops/bert_defs.cc
Original file line number Diff line number Diff line change
Expand Up @@ -908,7 +908,6 @@ ONNX_MS_OPERATOR_SET_SCHEMA(
OpSchema::Optional)
.Input(9,
"cache_indirection",
// This input is useful for CUDA EP only.
"A buffer of shape [batch_size, beam_width, max_output_length] where an `[i, j, k]` entry specifies "
"which beam the `k`-th token came from for the `j`-th beam for batch `i` in the current iteration",
"M",
Expand All @@ -920,7 +919,7 @@ ONNX_MS_OPERATOR_SET_SCHEMA(
OpSchema::Optional)
.Output(0,
"output",
"3D output tensor with shape (batch_size, sequence_length, v_hidden_size)",
"3D output tensor with shape (batch_size, num_heads, v_hidden_size)",
"T")
.Output(1,
"present_key",
Expand Down
Loading
Loading