# DMMHA: add unit tests; fix CPU, CUDA kernel #22567

## Conversation
When the PR is ready, could you please update the PR title and description to better reflect the problem and the fix? Thanks.
Reduce error tolerance; make ToFloat(float) constexpr
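A minimal sketch of what that commit implies, assuming the test helper is a plain overload set (identifiers here are illustrative, not necessarily the PR's exact code):

```cpp
// Making the float passthrough constexpr lets tolerance helpers in the
// tests be evaluated in constant expressions.
constexpr float ToFloat(float f) { return f; }

// A half-precision overload would still convert at runtime, e.g.:
// inline float ToFloat(MLFloat16 h) { return h.ToFloat(); }

static_assert(ToFloat(1.5f) == 1.5f, "usable in constant expressions");
```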
@hariharans29, is it true that for the cross-attention CUDA kernel the key layout is also reordered (lines 105 to 107 in dd28f09) instead of BNSH? See onnxruntime/onnxruntime/core/graph/contrib_ops/bert_defs.cc, lines 855 to 860 in dd28f09.
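For readers following this thread, here is a hedged sketch of the two indexing schemes being contrasted. It assumes a FasterTransformer-style reordered K cache that groups the head dimension into 16-byte chunks of x = 16 / sizeof(T) elements; the exact layout should be confirmed against decoder_masked_multihead_attention_impl.cu.

```cpp
#include <cstddef>

// Plain BNSH: [batch, num_heads, max_seq_len, head_size]
inline size_t BnshIndex(size_t b, size_t n, size_t s, size_t h,
                        size_t num_heads, size_t max_seq_len, size_t head_size) {
  return ((b * num_heads + n) * max_seq_len + s) * head_size + h;
}

// Assumed reordered cache: [batch, num_heads, head_size / x, max_seq_len, x],
// where x consecutive elements of the head dimension form one 16-byte chunk.
inline size_t ReorderedIndex(size_t b, size_t n, size_t s, size_t h,
                             size_t num_heads, size_t max_seq_len,
                             size_t head_size, size_t x) {
  const size_t chunk = h / x;   // which chunk of the head dimension
  const size_t within = h % x;  // position inside the chunk
  return (((b * num_heads + n) * (head_size / x) + chunk) * max_seq_len + s) * x + within;
}
```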
/azp run Big Models,Linux Android Emulator QNN CI Pipeline,Linux CPU CI Pipeline,Linux CPU Minimal Build E2E CI Pipeline,Linux GPU CI Pipeline,Linux GPU TensorRT CI Pipeline,Linux OpenVINO CI Pipeline,Linux QNN CI Pipeline,MacOS CI Pipeline
Azure Pipelines successfully started running 9 pipeline(s).
/azp run ONNX Runtime Web CI Pipeline,Windows ARM64 QNN CI Pipeline,Windows CPU CI Pipeline,Windows GPU CUDA CI Pipeline,Windows GPU DML CI Pipeline,Windows GPU Doc Gen CI Pipeline,Windows GPU TensorRT CI Pipeline,Windows x64 QNN CI Pipeline
Azure Pipelines successfully started running 8 pipeline(s).
For self-attention, see onnxruntime/onnxruntime/contrib_ops/cuda/bert/decoder_masked_multihead_attention.cc, lines 150 to 202 in dd28f09. I believe the relevant Whisper-side handling is in onnxruntime/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_whisper.h, lines 341 to 352 in 03ea5dc.
@kunal-vaishnavi, what is the reason cross attention is needed in this op? I think cross attention should be supported by MHA, and this op should be for decoding only. That would make the logic clearer.
Re #22567 (comment): Makes sense to me; in the tests I also use a reordered key as input for cross attention. Maybe I should also add some comments in onnxruntime/onnxruntime/core/graph/contrib_ops/bert_defs.cc (lines 855 to 860 in dd28f09).
Thanks @tianleiwu and @kunal-vaishnavi for the review! Is this PR ready to merge if I keep the following changes for another PR?
Whisper uses alternating layers of self-attention and cross-attention during decoding.
Thanks @tianleiwu, @kunal-vaishnavi, @hariharans29!
### Description
Fixes:
(1) CPU kernel: apply the scale before the bias and mask, as the other MHA ops do (see the sketch after this list).
(2) CPU kernel: correct the offset when appending past KV to present.
(3) CUDA kernel: apply the attention mask if provided; fix the output_qk offset.
Also adds DMMHA unit tests.
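To illustrate fix (1), here is a minimal sketch of the intended ordering for one row of attention logits, assuming an additive bias and mask; it is illustrative only, not the kernel's actual code. Scaling first keeps the bias and mask at their original magnitudes; scaling after they are added (the bug) would shrink or inflate them along with the logits.

```cpp
#include <algorithm>
#include <cmath>
#include <limits>
#include <vector>

// Softmax over one row of attention logits. The scale is applied to the raw
// Q.K^T products BEFORE adding the additive bias and mask, i.e.
//   probs = softmax(scale * (Q.K^T) + bias + mask)
void MaskedSoftmaxRow(std::vector<float>& logits,
                      const std::vector<float>& bias,
                      const std::vector<float>& mask,  // additive mask (0 or large negative)
                      float scale) {
  float max_logit = -std::numeric_limits<float>::infinity();
  for (size_t i = 0; i < logits.size(); ++i) {
    logits[i] = logits[i] * scale + bias[i] + mask[i];  // scale, then bias/mask
    max_logit = std::max(max_logit, logits[i]);
  }
  float sum = 0.0f;
  for (float& v : logits) {
    v = std::exp(v - max_logit);  // subtract max for numerical stability
    sum += v;
  }
  for (float& v : logits) v /= sum;  // normalize to probabilities
}
```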