Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Enabling support for QAttention (#18326)
[Cherry Pick Reviewed] #16837 #16851 #17947 ### Description Enabling support for `Past`, `Present` and `unidirectional` for [QAttention](https://github.com/microsoft/onnxruntime/blob/main/docs/ContribOperators.md#com.microsoft.QAttention) Contrib Op ``` Note: Google Test filter = *QAttention* [==========] Running 14 tests from 2 test suites. [----------] Global test environment set-up. [----------] 1 test from CPU_U8S8_Precision_Tests [ RUN ] CPU_U8S8_Precision_Tests.QAttention [ OK ] CPU_U8S8_Precision_Tests.QAttention (104 ms) [----------] 1 test from CPU_U8S8_Precision_Tests (105 ms total) [----------] 13 tests from QAttentionTest [ RUN ] QAttentionTest.QAttentionBatch1 [ OK ] QAttentionTest.QAttentionBatch1 (255 ms) [ RUN ] QAttentionTest.QAttentionBatch1_Float16 [ OK ] QAttentionTest.QAttentionBatch1_Float16 (0 ms) [ RUN ] QAttentionTest.QAttentionBatch2 [ OK ] QAttentionTest.QAttentionBatch2 (201 ms) [ RUN ] QAttentionTest.QAttentionMaskPartialSequence [ OK ] QAttentionTest.QAttentionMaskPartialSequence (197 ms) [ RUN ] QAttentionTest.QAttentionMaskExceedSequence [ OK ] QAttentionTest.QAttentionMaskExceedSequence (192 ms) [ RUN ] QAttentionTest.QAttentionNoMaskIndex [ OK ] QAttentionTest.QAttentionNoMaskIndex (186 ms) [ RUN ] QAttentionTest.QAttentionUnidirectional_U8U8 [ OK ] QAttentionTest.QAttentionUnidirectional_U8U8 (9 ms) [ RUN ] QAttentionTest.QAttentionUnidirectional_U8S8 [ OK ] QAttentionTest.QAttentionUnidirectional_U8S8 (9 ms) [ RUN ] QAttentionTest.QAttentionUnidirectional_CUDA [ OK ] QAttentionTest.QAttentionUnidirectional_CUDA (0 ms) [ RUN ] QAttentionTest.QAttentionPastState_u8u8 C:\workspace\ORT\onnxruntime\onnxruntime\test\providers\checkers.cc(268): error: The difference between cur_expected[i] and cur_actual[i] is 0.75743968039751053, which exceeds *(params.absolute_error), where cur_expected[i] evaluates to 0.67312467098236084, cur_actual[i] evaluates to -0.084315009415149689, and *(params.absolute_error) evaluates to 0.00019999999494757503. i:0 Google Test trace: C:\workspace\ORT\onnxruntime\onnxruntime\test\providers\checkers.cc(484): provider type: DmlExecutionProvider C:\workspace\ORT\onnxruntime\onnxruntime\test/common/random_generator.h(49): ORT test random seed: 2178993560 C:\workspace\ORT\onnxruntime\onnxruntime\test\providers\checkers.cc(268): error: The difference between cur_expected[i] and cur_actual[i] is 0.75743968039751053, which exceeds *(params.absolute_error), where cur_expected[i] evaluates to 0.67312467098236084, cur_actual[i] evaluates to -0.084315009415149689, and *(params.absolute_error) evaluates to 0.00019999999494757503. i:0 Google Test trace: C:\workspace\ORT\onnxruntime\onnxruntime\test\providers\checkers.cc(484): provider type: DmlExecutionProvider C:\workspace\ORT\onnxruntime\onnxruntime\test/common/random_generator.h(49): ORT test random seed: 2178993560 C:\workspace\ORT\onnxruntime\onnxruntime\test\providers\checkers.cc(268): error: The difference between cur_expected[i] and cur_actual[i] is 0.03001787792891264, which exceeds *(params.absolute_error), where cur_expected[i] evaluates to -0.021467097103595734, cur_actual[i] evaluates to 0.008550780825316906, and *(params.absolute_error) evaluates to 0.00019999999494757503. i:0 Google Test trace: C:\workspace\ORT\onnxruntime\onnxruntime\test\providers\checkers.cc(484): provider type: DmlExecutionProvider C:\workspace\ORT\onnxruntime\onnxruntime\test/common/random_generator.h(49): ORT test random seed: 2178993560 C:\workspace\ORT\onnxruntime\onnxruntime\test\providers\checkers.cc(268): error: The difference between cur_expected[i] and cur_actual[i] is 0.03001787792891264, which exceeds *(params.absolute_error), where cur_expected[i] evaluates to -0.021467097103595734, cur_actual[i] evaluates to 0.008550780825316906, and *(params.absolute_error) evaluates to 0.00019999999494757503. i:0 Google Test trace: C:\workspace\ORT\onnxruntime\onnxruntime\test\providers\checkers.cc(484): provider type: DmlExecutionProvider C:\workspace\ORT\onnxruntime\onnxruntime\test/common/random_generator.h(49): ORT test random seed: 2178993560 [ FAILED ] QAttentionTest.QAttentionPastState_u8u8 (2067 ms) [ RUN ] QAttentionTest.QAttentionPastState_u8s8 C:\workspace\ORT\onnxruntime\onnxruntime\test\providers\checkers.cc(268): error: The difference between cur_expected[i] and cur_actual[i] is 0.74043640494346619, which exceeds *(params.absolute_error), where cur_expected[i] evaluates to 0.65650326013565063, cur_actual[i] evaluates to -0.083933144807815552, and *(params.absolute_error) evaluates to 0.00019999999494757503. i:0 Google Test trace: C:\workspace\ORT\onnxruntime\onnxruntime\test\providers\checkers.cc(484): provider type: DmlExecutionProvider C:\workspace\ORT\onnxruntime\onnxruntime\test/common/random_generator.h(49): ORT test random seed: 2178993560 C:\workspace\ORT\onnxruntime\onnxruntime\test\providers\checkers.cc(268): error: The difference between cur_expected[i] and cur_actual[i] is 0.081788420677185059, which exceeds *(params.absolute_error), where cur_expected[i] evaluates to 1.0076344013214111, cur_actual[i] evaluates to 1.0894228219985962, and *(params.absolute_error) evaluates to 0.00019999999494757503. i:965 Google Test trace: C:\workspace\ORT\onnxruntime\onnxruntime\test\providers\checkers.cc(484): provider type: DmlExecutionProvider C:\workspace\ORT\onnxruntime\onnxruntime\test/common/random_generator.h(49): ORT test random seed: 2178993560 C:\workspace\ORT\onnxruntime\onnxruntime\test\providers\checkers.cc(268): error: The difference between cur_expected[i] and cur_actual[i] is 0.74043640494346619, which exceeds *(params.absolute_error), where cur_expected[i] evaluates to 0.65650326013565063, cur_actual[i] evaluates to -0.083933144807815552, and *(params.absolute_error) evaluates to 0.00019999999494757503. i:0 Google Test trace: C:\workspace\ORT\onnxruntime\onnxruntime\test\providers\checkers.cc(484): provider type: DmlExecutionProvider C:\workspace\ORT\onnxruntime\onnxruntime\test/common/random_generator.h(49): ORT test random seed: 2178993560 C:\workspace\ORT\onnxruntime\onnxruntime\test\providers\checkers.cc(268): error: The difference between cur_expected[i] and cur_actual[i] is 0.081788420677185059, which exceeds *(params.absolute_error), where cur_expected[i] evaluates to 1.0076344013214111, cur_actual[i] evaluates to 1.0894228219985962, and *(params.absolute_error) evaluates to 0.00019999999494757503. i:965 Google Test trace: C:\workspace\ORT\onnxruntime\onnxruntime\test\providers\checkers.cc(484): provider type: DmlExecutionProvider C:\workspace\ORT\onnxruntime\onnxruntime\test/common/random_generator.h(49): ORT test random seed: 2178993560 C:\workspace\ORT\onnxruntime\onnxruntime\test\providers\checkers.cc(268): error: The difference between cur_expected[i] and cur_actual[i] is 0.024714200757443905, which exceeds *(params.absolute_error), where cur_expected[i] evaluates to -0.016048312187194824, cur_actual[i] evaluates to 0.0086658885702490807, and *(params.absolute_error) evaluates to 0.00019999999494757503. i:0 Google Test trace: C:\workspace\ORT\onnxruntime\onnxruntime\test\providers\checkers.cc(484): provider type: DmlExecutionProvider C:\workspace\ORT\onnxruntime\onnxruntime\test/common/random_generator.h(49): ORT test random seed: 2178993560 C:\workspace\ORT\onnxruntime\onnxruntime\test\providers\checkers.cc(268): error: The difference between cur_expected[i] and cur_actual[i] is 0.0092324763536453247, which exceeds *(params.absolute_error), where cur_expected[i] evaluates to 0.24175386130809784, cur_actual[i] evaluates to 0.25098633766174316, and *(params.absolute_error) evaluates to 0.00019999999494757503. i:979 Google Test trace: C:\workspace\ORT\onnxruntime\onnxruntime\test\providers\checkers.cc(484): provider type: DmlExecutionProvider C:\workspace\ORT\onnxruntime\onnxruntime\test/common/random_generator.h(49): ORT test random seed: 2178993560 C:\workspace\ORT\onnxruntime\onnxruntime\test\providers\checkers.cc(268): error: The difference between cur_expected[i] and cur_actual[i] is 0.024714200757443905, which exceeds *(params.absolute_error), where cur_expected[i] evaluates to -0.016048312187194824, cur_actual[i] evaluates to 0.0086658885702490807, and *(params.absolute_error) evaluates to 0.00019999999494757503. i:0 Google Test trace: C:\workspace\ORT\onnxruntime\onnxruntime\test\providers\checkers.cc(484): provider type: DmlExecutionProvider C:\workspace\ORT\onnxruntime\onnxruntime\test/common/random_generator.h(49): ORT test random seed: 2178993560 C:\workspace\ORT\onnxruntime\onnxruntime\test\providers\checkers.cc(268): error: The difference between cur_expected[i] and cur_actual[i] is 0.0092324763536453247, which exceeds *(params.absolute_error), where cur_expected[i] evaluates to 0.24175386130809784, cur_actual[i] evaluates to 0.25098633766174316, and *(params.absolute_error) evaluates to 0.00019999999494757503. i:979 Google Test trace: C:\workspace\ORT\onnxruntime\onnxruntime\test\providers\checkers.cc(484): provider type: DmlExecutionProvider C:\workspace\ORT\onnxruntime\onnxruntime\test/common/random_generator.h(49): ORT test random seed: 2178993560 [ FAILED ] QAttentionTest.QAttentionPastState_u8s8 (2079 ms) [ RUN ] QAttentionTest.QAttentionPrunedModel [ OK ] QAttentionTest.QAttentionPrunedModel (206 ms) [ RUN ] QAttentionTest.SharedPrepackedWeights [ OK ] QAttentionTest.SharedPrepackedWeights (79 ms) [----------] 13 tests from QAttentionTest (5492 ms total) [----------] Global test environment tear-down [==========] 14 tests from 2 test suites ran. (5600 ms total) [ PASSED ] 12 tests. [ FAILED ] 2 tests, listed below: [ FAILED ] QAttentionTest.QAttentionPastState_u8u8 [ FAILED ] QAttentionTest.QAttentionPastState_u8s8 2 FAILED TESTS memleakdbg: ----- No memory leaks detected ----- ``` ### Motivation and Context <!-- - Why is this change required? What problem does it solve? - If it fixes an open issue, please link to the issue here. --> --------- Co-authored-by: Xiang Zhang <[email protected]>
- Loading branch information