
DMMHA: add unit tests; fix CPU, CUDA kernel #22567

Merged
mindest merged 16 commits into main from linmin/dmmha_test on Nov 2, 2024
Conversation

mindest (Contributor) commented Oct 23, 2024

Description

Fixes:
(1) CPU kernel: apply scale before adding bias and mask, like other MHA ops.
(2) CPU kernel: correct the offset when appending past to present.
(3) CUDA kernel: apply the mask if provided; fix the output_qk offset.

Add DMMHA unit tests.
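
For context, a minimal sketch of fix (1) with hypothetical names (not the actual kernel code): the raw Q·K dot product is scaled first, then the attention bias is added, then the mask is applied, matching the other MHA ops.

```cpp
#include <cstdint>
#include <limits>
#include <vector>

// Illustrative sketch only (hypothetical names, not the ORT kernel symbols):
// compute one row of attention scores in the corrected order used by other
// MHA ops -- scale the raw dot product first, then add bias, then mask.
void ComputeScoresRow(const float* q, const float* k, const float* attn_bias,
                      const int32_t* mask, int total_seq_len, int head_size,
                      float scale, std::vector<float>& scores) {
  scores.resize(total_seq_len);
  for (int t = 0; t < total_seq_len; ++t) {
    float dot = 0.0f;
    for (int d = 0; d < head_size; ++d) {
      dot += q[d] * k[t * head_size + d];
    }
    float s = dot * scale;                        // (1) scale first
    if (attn_bias != nullptr) s += attn_bias[t];  // (2) then add the attention bias
    if (mask != nullptr && mask[t] == 0) {        // (3) then mask out padded positions
      s = std::numeric_limits<float>::lowest();
    }
    scores[t] = s;
  }
}
```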

docs/ContribOperators.md (outdated review thread, resolved)
hariharans29 (Member) commented

When the PR is ready, could you please update the PR title and description to better reflect the problem and the fix? Thanks.

mindest changed the title from "[WIP] Add DMMHA unit tests with fix" to "DMMHA: add unit tests; fix CPU, CUDA kernel" on Oct 28, 2024
mindest (Contributor, Author) commented Oct 28, 2024

@hariharans29, is it true that for the cross-attention CUDA kernel the key layout is also reordered to [B, H, head_size/x, L, x],

// The layout of the cache is [B, H, head_size/x, L, x] with x == 4/8/16 for FP32/FP16/FP8. Since each thread
// owns x elements, we have to decompose the linear index into chunks of x values and the posi-
// tion of the thread in that chunk.

instead of BNSH?

.Input(1,
       "key",
       "Key with shape (batch_size, 1, hidden_size) for self attention "
       "or past_key with shape (batch_size, num_heads, kv_sequence_length, head_size) for cross attention",
       "T",
       OpSchema::Optional)
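
For illustration only (my own sketch, not code from the repository), the plain BNSH indexing and the chunked [B, H, head_size/x, L, x] layout differ as follows, where L is the cache's max sequence length and x is 4/8/16 for FP32/FP16/FP8:

```cpp
#include <cstddef>

// Element (b, h, s, d) in a BNSH-layout cache.
inline size_t IndexBNSH(size_t b, size_t h, size_t s, size_t d,
                        size_t num_heads, size_t L, size_t head_size) {
  return ((b * num_heads + h) * L + s) * head_size + d;
}

// Same element in the chunked [B, H, head_size/x, L, x] layout: the head
// dimension is decomposed into (chunk = d / x, offset = d % x) so that each
// thread can read x consecutive values for one time step.
inline size_t IndexChunked(size_t b, size_t h, size_t s, size_t d,
                           size_t num_heads, size_t L, size_t head_size, size_t x) {
  return (((b * num_heads + h) * (head_size / x) + d / x) * L + s) * x + d % x;
}
```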

@mindest mindest marked this pull request as ready for review October 28, 2024 17:48
mindest (Contributor, Author) commented Oct 29, 2024

/azp run Big Models,Linux Android Emulator QNN CI Pipeline,Linux CPU CI Pipeline,Linux CPU Minimal Build E2E CI Pipeline,Linux GPU CI Pipeline,Linux GPU TensorRT CI Pipeline,Linux OpenVINO CI Pipeline,Linux QNN CI Pipeline,MacOS CI Pipeline


Azure Pipelines successfully started running 9 pipeline(s).

mindest (Contributor, Author) commented Oct 29, 2024

/azp run ONNX Runtime Web CI Pipeline,Windows ARM64 QNN CI Pipeline,Windows CPU CI Pipeline,Windows GPU CUDA CI Pipeline,Windows GPU DML CI Pipeline,Windows GPU Doc Gen CI Pipeline,Windows GPU TensorRT CI Pipeline,Windows x64 QNN CI Pipeline


Azure Pipelines successfully started running 8 pipeline(s).

kunal-vaishnavi (Contributor) commented

@hariharans29, is it true that for the cross-attention CUDA kernel the key layout is also reordered to [B, H, head_size/x, L, x],

// The layout of the cache is [B, H, head_size/x, L, x] with x == 4/8/16 for FP32/FP16/FP8. Since each thread
// owns x elements, we have to decompose the linear index into chunks of x values and the posi-
// tion of the thread in that chunk.

instead of BNSH?

.Input(1,
       "key",
       "Key with shape (batch_size, 1, hidden_size) for self attention "
       "or past_key with shape (batch_size, num_heads, kv_sequence_length, head_size) for cross attention",
       "T",
       OpSchema::Optional)

For self-attention, parameters.k_cache = present_key_data = past_key_data since past_present_share_buffer = true. For cross-attention, parameters.k_cache = key_data.

  if (past_key == nullptr && present_key == nullptr) {
    if (attention_bias != nullptr) {
      return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED,
                             "DecoderMaskedMultiHeadAttention does not support attention bias for cross-attention");
    }
    parameters.is_cross_attention = true;
    parameters.total_sequence_length = parameters.kv_sequence_length;
    parameters.max_sequence_length = parameters.kv_sequence_length;
    // parameters.k and parameters.v are nullptr
    parameters.k_cache = const_cast<T1*>(key->Data<T1>());
    parameters.v_cache = const_cast<T1*>(value->Data<T1>());
    parameters.k_bias = nullptr;
    parameters.v_bias = nullptr;
  } else {
    // Sanity check
    ORT_ENFORCE(past_present_share_buffer_);
    ORT_ENFORCE(past_key != nullptr && past_value != nullptr);
    auto* present_key_data = present_key->MutableData<T1>();
    auto* present_value_data = present_value->MutableData<T1>();
    auto* past_key_data = past_key->Data<T1>();
    auto* past_value_data = past_value->Data<T1>();
    // No production use-case will incur this copy cost as the implementation of
    // GreedySearch/BeamSearch is written in such a way that the past and present buffers
    // will be shared.
    // This is just to circumvent the OpTester's limitation of not being able to bind a specific
    // buffer to inputs/outputs.
    if (present_key_data != past_key_data) {
      CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(present_key_data, past_key_data, past_key->SizeInBytes(),
                                           cudaMemcpyDeviceToDevice, cuda_stream));
    }
    if (present_value_data != past_value_data) {
      CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(present_value_data, past_value_data, past_value->SizeInBytes(),
                                           cudaMemcpyDeviceToDevice, cuda_stream));
    }
    parameters.is_cross_attention = false;
    bool is_packed_qkv = (key == nullptr && value == nullptr);
    parameters.is_packed_qkv = is_packed_qkv;
    parameters.k = is_packed_qkv
                       ? const_cast<T1*>(query->Data<T1>() + parameters.hidden_size)
                       : const_cast<T1*>(key->Data<T1>());
    parameters.v = is_packed_qkv
                       ? const_cast<T1*>(query->Data<T1>() + 2 * static_cast<size_t>(parameters.hidden_size))
                       : const_cast<T1*>(value->Data<T1>());
    parameters.k_cache = present_key_data;
    parameters.v_cache = present_value_data;
  }

I believe parameters.k_cache should be reordered before the kernel is launched, so past_key is already reordered for self-attention and key is already reordered for cross-attention. This behavior would match the comments below.

#ifdef USE_CUDA
    // Here we only need to reorder the past key for self-attention and cross-attention.
    for (size_t i = 0; i < 2 * static_cast<size_t>(decoder_subgraph_.num_layers); ++i) {
      ORT_RETURN_IF_ERROR(reorder_past_state_func_(cuda_device_prop_,
                                                   *decoder_feeds[offset + 2 * i].GetMutable<Tensor>(),
                                                   beam_state.staging_for_past_state_reorder,
                                                   this->ort_stream_));
    }
    size_t cache_indir_input_offset = static_cast<size_t>(decoder_subgraph_.GetFirstPastInputIndex()) + 4 * static_cast<size_t>(decoder_subgraph_.num_layers) + 2;
    ORT_RETURN_IF_ERROR(init_cache_indir_func_(*decoder_feeds[cache_indir_input_offset].GetMutable<Tensor>(), this->ort_stream_));
#endif
  }
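
To make the reorder-before-launch point concrete, here is a rough host-side sketch of what such a repack amounts to (my own illustration with hypothetical names; the real reorder_past_state_func_ is a CUDA kernel with a different signature):

```cpp
#include <cstddef>
#include <vector>

// Illustrative only: repack a BNSH-layout key cache into the chunked
// [B, H, head_size/x, L, x] layout that the decoding kernel reads.
void ReorderKeyCache(const std::vector<float>& bnsh, std::vector<float>& chunked,
                     size_t B, size_t H, size_t L, size_t head_size, size_t x) {
  chunked.resize(bnsh.size());
  for (size_t b = 0; b < B; ++b)
    for (size_t h = 0; h < H; ++h)
      for (size_t s = 0; s < L; ++s)
        for (size_t d = 0; d < head_size; ++d) {
          size_t src = ((b * H + h) * L + s) * head_size + d;                          // BNSH
          size_t dst = (((b * H + h) * (head_size / x) + d / x) * L + s) * x + d % x;  // chunked
          chunked[dst] = bnsh[src];
        }
}
```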

tianleiwu (Contributor) commented Oct 31, 2024

@kunal-vaishnavi, what is the reason this op needs cross attention? I think cross attention should be supported by MHA, and this op should be for decoding only. That would make the logic clearer.

mindest (Contributor, Author) commented Oct 31, 2024

To #22567 (comment): Makes sense to me; in the tests I also use the reordered key as input for cross attention. Maybe I should also add some comments in

.Input(1,
       "key",
       "Key with shape (batch_size, 1, hidden_size) for self attention "
       "or past_key with shape (batch_size, num_heads, kv_sequence_length, head_size) for cross attention",
       "T",
       OpSchema::Optional)

so that it is clear that the key in cross-attention is also reordered for the CUDA EP.

mindest (Contributor, Author) commented Oct 31, 2024

Thanks @tianleiwu @kunal-vaishnavi for the review! Is this PR ready to merge if I move the following changes to another PR?

  • Support float16 for out_qk.
  • Add comments explaining why it is sum_tlength + 1 instead of tlength (for cross-attn the total length is kv_seq_len, for self-attn it is past_len + 1); see the sketch after this list.
  • Update the schema comments of input 1 (key) for cross-attn.
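
A minimal sketch of the second point, using a hypothetical helper rather than the actual kernel code:

```cpp
// Hypothetical helper mirroring the comment above: the total number of key
// positions attended to in one decoding step.
int TotalLength(bool is_cross_attention, int kv_sequence_length, int past_sequence_length) {
  // Cross-attention: the encoder key/value are fixed, so the total length is kv_sequence_length.
  // Self-attention: the current token is appended to the past, so it is past_sequence_length + 1.
  return is_cross_attention ? kv_sequence_length : past_sequence_length + 1;
}
```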

kunal-vaishnavi (Contributor) commented

@kunal-vaishnavi, what is the reason this op needs cross attention? I think cross attention should be supported by MHA, and this op should be for decoding only. That would make the logic clearer.

Whisper uses alternating layers of self-attention and cross-attention during decoding.

mindest (Contributor, Author) commented Nov 2, 2024

Thanks @tianleiwu, @kunal-vaishnavi, @hariharans29!

mindest merged commit 4ffc1ff into main on Nov 2, 2024
74 checks passed
mindest deleted the linmin/dmmha_test branch on November 2, 2024 13:05
ishwar-raut1 pushed a commit to ishwar-raut1/onnxruntime that referenced this pull request on Nov 19, 2024.
ankitm3k pushed three commits to intel/onnxruntime that referenced this pull request on Dec 11, 2024.