
Fix scale order; revert doc change
mindest committed Oct 23, 2024
1 parent 9df4782 commit 50dfa85
Showing 4 changed files with 7 additions and 6 deletions.
2 changes: 1 addition & 1 deletion docs/ContribOperators.md
@@ -1186,7 +1186,7 @@ This version of the operator has been available since version 1 of the 'com.microsoft' operator set.

 <dl>
 <dt><tt>output</tt> : T</dt>
-<dd>3D output tensor with shape (batch_size, num_heads, v_hidden_size)</dd>
+<dd>3D output tensor with shape (batch_size, sequence_length, v_hidden_size)</dd>
 <dt><tt>present_key</tt> (optional) : T</dt>
 <dd>present state for key with shape (batch_size, num_heads, total_sequence_length, head_size). If past_present_share_buffer is set, its shape is (batch_size, num_heads, max_sequence_length, head_size), while effective_seq_length = (past_sequence_length + kv_sequence_length).</dd>
 <dt><tt>present_value</tt> (optional) : T</dt>
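The doc revert above restores the correct shape description: the operator emits one output row per query position, with the per-head value vectors concatenated into v_hidden_size; num_heads is never an output dimension on its own. A minimal sketch of that shape arithmetic, assuming the names used in the operator doc (the OutputShape helper is hypothetical, not repository code):

#include <array>
#include <cstdint>

// Hypothetical helper, for illustration only. The per-head attention results
// have shape (batch_size, num_heads, sequence_length, v_head_size); merging
// the head and value dimensions gives the documented 3D output
// (batch_size, sequence_length, num_heads * v_head_size) =
// (batch_size, sequence_length, v_hidden_size).
std::array<int64_t, 3> OutputShape(int64_t batch_size, int64_t sequence_length,
                                   int64_t num_heads, int64_t v_head_size) {
  return {batch_size, sequence_length, num_heads * v_head_size};
}

For example, batch_size = 2, sequence_length = 8, num_heads = 12, v_head_size = 64 yields {2, 8, 768}.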
@@ -339,6 +339,7 @@ void DecoderMaskedMultiHeadAttention<T>::ComputeAttentionProbsWithBeams(
         T* attention_probs_ptr = reinterpret_cast<T*>(attention_probs) + last_offset;
         math::Dot<float, CPUMathUtil>(head_size, q_vec, K + i * head_size, attention_probs_ptr, nullptr);

+        *attention_probs_ptr *= scale;
         // Apply the attention bias and mask
         if (attn_bias_data != nullptr) {
           *attention_probs_ptr += attn_bias_data[attn_bias_base_offset + past_sequence_length];
@@ -348,7 +349,6 @@ void DecoderMaskedMultiHeadAttention<T>::ComputeAttentionProbsWithBeams(
         if (is_masked) {
           *attention_probs_ptr += mask_filter_value_;
         }
-        *attention_probs_ptr *= scale;
       }

       {
@@ -362,6 +362,8 @@ void DecoderMaskedMultiHeadAttention<T>::ComputeAttentionProbsWithBeams(
           const T* past_k_vec = past_key_data + beam_batch_offset + beam_offset + j * head_size;
           T* output = reinterpret_cast<T*>(attention_probs) + j + i * probs_matrix_size;
           math::Dot<float, CPUMathUtil>(head_size, q_vec, past_k_vec, output, nullptr);
+
+          *output *= scale;
           // Apply the attention bias and mask
           if (attn_bias_data != nullptr) {
             *output += attn_bias_data[attn_bias_base_offset + j];
@@ -371,7 +373,6 @@ void DecoderMaskedMultiHeadAttention<T>::ComputeAttentionProbsWithBeams(
           if (is_masked) {
             *output += mask_filter_value_;
           }
-          *output *= scale;
         }
       }
       // Append current key to present key (past_present_share_buffer_ is true)
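Both hunks in this file make the same correction: the softmax scale (typically 1/sqrt(head_size)) must multiply only the raw Q·K dot product, but the old code applied it after the attention bias and the mask penalty had been added, scaling those terms as well. A standalone sketch of the before/after arithmetic, with illustrative names rather than the kernel's actual variables:

// Illustrative sketch, not repository code.
// Before the fix: logit = (dot + bias + mask_penalty) * scale  // bias/mask wrongly scaled
// After the fix:  logit = dot * scale + bias + mask_penalty
float AttentionLogit(float dot, float scale, float bias,
                     bool is_masked, float mask_filter_value) {
  float logit = dot * scale;     // scale only the raw Q.K product
  logit += bias;                 // then add the additive attention bias, if any
  if (is_masked) {
    logit += mask_filter_value;  // large negative penalty for masked positions
  }
  return logit;
}

With scale equal to 1/sqrt(head_size), this matches the standard scaled-dot-product formulation softmax(QK^T / sqrt(d) + bias + mask).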
@@ -33,7 +33,7 @@ class DecoderMaskedMultiHeadAttention final : public OpKernel, public AttentionCPUBase
                                          const Tensor* cache_indir,
                                          OpKernelContext* context,
                                          int beam_width,
-                                         Tensor* scaled_qk = nullptr) const;
+                                         Tensor* output_qk = nullptr) const;
   void ComputeAttentionProbsWithBeams(T* attention_probs,
                                       const T* Q,
                                       const T* K,
@@ -50,7 +50,7 @@ void DecoderMaskedMultiHeadAttention<T>::ComputeAttentionProbsWithBeams(
                                       bool broadcast_attn_bias_dim_1,
                                       const int32_t* cache_indir_data,
                                       int beam_width,
-                                      T* scaled_qk_data = nullptr) const;
+                                      T* output_qk_data = nullptr) const;
   void ComputeVxAttentionScoreWithBeams(T* output,
                                         T* tmp_buffer,
                                         const T* attention_probs,
2 changes: 1 addition & 1 deletion onnxruntime/core/graph/contrib_ops/bert_defs.cc
@@ -919,7 +919,7 @@ ONNX_MS_OPERATOR_SET_SCHEMA(
                 OpSchema::Optional)
         .Output(0,
                 "output",
-                "3D output tensor with shape (batch_size, num_heads, v_hidden_size)",
+                "3D output tensor with shape (batch_size, sequence_length, v_hidden_size)",
                 "T")
         .Output(1,
                 "present_key",
