DMLEP QAttention update causal (#19533)
### Description
Bug fix for the `QAttentionTest.QAttentionPastState*` test failures.

The change generates a causal mask, an upper triangular Boolean matrix, and passes it as the mask input to MHA (multi-head attention). DML internally adds `maskFilterValue` to the "off" bits in the mask and sets the "on" bits to 0.

```
Note: Google Test filter = *QAttention*
[==========] Running 14 tests from 2 test suites.
[----------] Global test environment set-up.
[----------] 1 test from CPU_U8S8_Precision_Tests
[ RUN ] CPU_U8S8_Precision_Tests.QAttention
[ OK ] CPU_U8S8_Precision_Tests.QAttention (124 ms)
[----------] 1 test from CPU_U8S8_Precision_Tests (124 ms total)

[----------] 13 tests from QAttentionTest
[ RUN ] QAttentionTest.QAttentionBatch1
[ OK ] QAttentionTest.QAttentionBatch1 (531 ms)
[ RUN ] QAttentionTest.QAttentionBatch1_Float16
[ OK ] QAttentionTest.QAttentionBatch1_Float16 (0 ms)
[ RUN ] QAttentionTest.QAttentionBatch2
[ OK ] QAttentionTest.QAttentionBatch2 (441 ms)
[ RUN ] QAttentionTest.QAttentionMaskPartialSequence
[ OK ] QAttentionTest.QAttentionMaskPartialSequence (410 ms)
[ RUN ] QAttentionTest.QAttentionMaskExceedSequence
[ OK ] QAttentionTest.QAttentionMaskExceedSequence (398 ms)
[ RUN ] QAttentionTest.QAttentionNoMaskIndex
[ OK ] QAttentionTest.QAttentionNoMaskIndex (389 ms)
[ RUN ] QAttentionTest.QAttentionUnidirectional_U8U8
[ OK ] QAttentionTest.QAttentionUnidirectional_U8U8 (11 ms)
[ RUN ] QAttentionTest.QAttentionUnidirectional_U8S8
[ OK ] QAttentionTest.QAttentionUnidirectional_U8S8 (10 ms)
[ RUN ] QAttentionTest.QAttentionUnidirectional_CUDA
[ OK ] QAttentionTest.QAttentionUnidirectional_CUDA (0 ms)
[ RUN ] QAttentionTest.QAttentionPastState_u8u8
[ OK ] QAttentionTest.QAttentionPastState_u8u8 (2683 ms)
[ RUN ] QAttentionTest.QAttentionPastState_u8s8
[ OK ] QAttentionTest.QAttentionPastState_u8s8 (2674 ms)
[ RUN ] QAttentionTest.QAttentionPrunedModel
[ OK ] QAttentionTest.QAttentionPrunedModel (399 ms)
[ RUN ] QAttentionTest.SharedPrepackedWeights
[ OK ] QAttentionTest.SharedPrepackedWeights (89 ms)
[----------] 13 tests from QAttentionTest (8047 ms total)

[----------] Global test environment tear-down
[==========] 14 tests from 2 test suites ran. (8175 ms total)
[ PASSED ] 14 tests.
memleakdbg:
----- No memory leaks detected -----
```

### Motivation and Context
<!-- - Why is this change required? What problem does it solve? - If it fixes an open issue, please link to the issue here. -->
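To illustrate the masking described under Description, here is a minimal Python sketch. It is not the DML EP implementation. The helper names `causal_mask` and `to_additive_bias` are hypothetical, and the default of `-10000.0` for `mask_filter_value` is an assumption made for the example. The sketch also assumes that an "on" bit means query position `i` may attend to key position `j`, and that `past_len` past-state keys precede the current sequence; which triangle holds the "on" bits depends on that convention.

```python
def causal_mask(seq_len, past_len=0):
    """Boolean mask of shape [seq_len][past_len + seq_len].

    Assumed convention: True ("on") means query i may attend to key j.
    All past keys are visible, plus current keys up to and including i.
    """
    total = past_len + seq_len
    return [[j <= past_len + i for j in range(total)] for i in range(seq_len)]


def to_additive_bias(mask, mask_filter_value=-10000.0):
    """Mirror the behaviour described above: "on" bits become 0 and
    "off" bits receive mask_filter_value, giving a bias that would be
    added to the attention scores before softmax."""
    return [[0.0 if on else mask_filter_value for on in row] for row in mask]


if __name__ == "__main__":
    mask = causal_mask(seq_len=3, past_len=1)
    for row in to_additive_bias(mask):
        print(row)
```

With `past_len > 0` the mask is rectangular rather than square, which is the case the `QAttentionPastState*` tests exercise.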