
DMLEP QAttention update causal #19533

Merged: raoanag merged 6 commits into WindowsAI-dev from updateCausal on Mar 4, 2024

Conversation

@raoanag (Contributor) commented Feb 15, 2024

Description

Bug fix for `QAttentionTest.QAttentionPastState*` test failures.

The change generates a causal mask, an upper-triangular Boolean matrix, as input to the MHA mask. DML internally adds `maskFilterValue` to the "off" bits in the mask and sets the "on" bits to 0.
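For illustration, here is a minimal standalone sketch (hypothetical helper, not the DML EP implementation) of how such a causal mask translates into the additive bias described above: the "off" bits (future positions, the strict upper triangle) receive `maskFilterValue`, and the "on" bits stay 0.

```
#include <cstddef>
#include <vector>

// Hypothetical sketch, not the DML EP code: builds the additive form of a
// causal mask for a sequence of length n. Element (i, j) is "off" when
// j > i (a future position), so the "off" bits form the strict upper
// triangle of an n x n matrix.
std::vector<float> MakeCausalAdditiveMask(std::size_t n, float maskFilterValue) {
    std::vector<float> mask(n * n, 0.0f);       // "on" bits -> 0
    for (std::size_t i = 0; i < n; ++i) {
        for (std::size_t j = i + 1; j < n; ++j) {
            mask[i * n + j] = maskFilterValue;  // "off" bits -> maskFilterValue
        }
    }
    return mask;  // added to the attention logits before softmax
}
```

With a sufficiently negative `maskFilterValue`, the masked logits vanish after softmax, which is what makes the attention causal.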

```
Note: Google Test filter = *QAttention*
[==========] Running 14 tests from 2 test suites.
[----------] Global test environment set-up.
[----------] 1 test from CPU_U8S8_Precision_Tests
[ RUN      ] CPU_U8S8_Precision_Tests.QAttention
[       OK ] CPU_U8S8_Precision_Tests.QAttention (124 ms)
[----------] 1 test from CPU_U8S8_Precision_Tests (124 ms total)

[----------] 13 tests from QAttentionTest
[ RUN      ] QAttentionTest.QAttentionBatch1
[       OK ] QAttentionTest.QAttentionBatch1 (531 ms)
[ RUN      ] QAttentionTest.QAttentionBatch1_Float16
[       OK ] QAttentionTest.QAttentionBatch1_Float16 (0 ms)
[ RUN      ] QAttentionTest.QAttentionBatch2
[       OK ] QAttentionTest.QAttentionBatch2 (441 ms)
[ RUN      ] QAttentionTest.QAttentionMaskPartialSequence
[       OK ] QAttentionTest.QAttentionMaskPartialSequence (410 ms)
[ RUN      ] QAttentionTest.QAttentionMaskExceedSequence
[       OK ] QAttentionTest.QAttentionMaskExceedSequence (398 ms)
[ RUN      ] QAttentionTest.QAttentionNoMaskIndex
[       OK ] QAttentionTest.QAttentionNoMaskIndex (389 ms)
[ RUN      ] QAttentionTest.QAttentionUnidirectional_U8U8
[       OK ] QAttentionTest.QAttentionUnidirectional_U8U8 (11 ms)
[ RUN      ] QAttentionTest.QAttentionUnidirectional_U8S8
[       OK ] QAttentionTest.QAttentionUnidirectional_U8S8 (10 ms)
[ RUN      ] QAttentionTest.QAttentionUnidirectional_CUDA
[       OK ] QAttentionTest.QAttentionUnidirectional_CUDA (0 ms)
[ RUN      ] QAttentionTest.QAttentionPastState_u8u8
[       OK ] QAttentionTest.QAttentionPastState_u8u8 (2683 ms)
[ RUN      ] QAttentionTest.QAttentionPastState_u8s8
[       OK ] QAttentionTest.QAttentionPastState_u8s8 (2674 ms)
[ RUN      ] QAttentionTest.QAttentionPrunedModel
[       OK ] QAttentionTest.QAttentionPrunedModel (399 ms)
[ RUN      ] QAttentionTest.SharedPrepackedWeights
[       OK ] QAttentionTest.SharedPrepackedWeights (89 ms)
[----------] 13 tests from QAttentionTest (8047 ms total)

[----------] Global test environment tear-down
[==========] 14 tests from 2 test suites ran. (8175 ms total)
[  PASSED  ] 14 tests.
memleakdbg:
----- No memory leaks detected -----
```

Motivation and Context

@raoanag changed the base branch from main to WindowsAI-Old on Feb 15, 2024
@raoanag changed the base branch from WindowsAI-Old to WindowsAI-dev on Feb 23, 2024
@raoanag force-pushed the updateCausal branch 2 times, most recently from 3d81068 to 15dc524, on Feb 28, 2024
@raoanag changed the title from "Update causal" to "DMLEP QAttention update causal" on Feb 28, 2024
@raoanag (Contributor, Author) commented Feb 28, 2024

Windows CI GPU Pipeline

@raoanag marked this pull request as ready for review on Feb 28, 2024
@sumitsays (Contributor) commented
The default -10000.0 `MaskFilterValue` was originally taken from the CPU EP when it still used that value. Now that the CPU EP uses `std::numeric_limits<float>::lowest()`, could you update line 412 and verify the test passes? Ideally it should pass with `std::numeric_limits<float>::lowest()` as well; if so, please update it to keep the DML EP consistent with the CPU EP.

```
@@ -407,7 +416,7 @@ class DmlOperatorQAttention : public DmlOperator
     mhaOperatorDesc.RelativePositionBiasTensor = nullptr;
     mhaOperatorDesc.OutputTensor = &outputDescs[outputIndex];
     mhaOperatorDesc.Scale = kernelCreationContext.GetOptionalAttribute<float>(AttrName::Scale, gsl::narrow_cast<float>(1.0f / std::sqrt(headSize)));
-    mhaOperatorDesc.MaskFilterValue = kernelCreationContext.GetOptionalAttribute<float>(AttrName::MaskFilterValue, -10'000.0f);
+    mhaOperatorDesc.MaskFilterValue = std::numeric_limits<float>::lowest();
```
A contributor replied:
I think we should still give precedence to the user-provided value and just change the default value to `lowest()`:

```
mhaOperatorDesc.MaskFilterValue = kernelCreationContext.GetOptionalAttribute<float>(AttrName::MaskFilterValue, std::numeric_limits<float>::lowest());
```

Another contributor replied:
Did the CPU and CUDA EPs change the default to `lowest()`? The contrib ops page still documents -10000 as the default: https://github.com/microsoft/onnxruntime/blob/main/docs/ContribOperators.md#com.microsoft.QAttention

@raoanag (Contributor, Author) commented Feb 29, 2024
For unidirectional, that does not seem to be the case, per the CPU reference: https://github.com/microsoft/onnxruntime/blob/d5606cd7ee394ba9444ef509021720ebe63c9856/onnxruntime/contrib_ops/cpu/bert/attention_helper.h#L142C1-L149C6

I will use `std::numeric_limits<float>::lowest()` only when unidirectional is set.
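A sketch of the resulting default selection (hypothetical standalone form; the actual DML EP change may be structured differently): `lowest()` becomes the default only on the unidirectional path, while an explicitly provided `MaskFilterValue` attribute still takes precedence via `GetOptionalAttribute`.

```
#include <limits>

// Hypothetical sketch, not the exact DML EP change: pick the default
// MaskFilterValue, using lowest() only when unidirectional (causal)
// attention is requested.
float DefaultMaskFilterValue(bool unidirectional) {
    return unidirectional ? std::numeric_limits<float>::lowest() : -10'000.0f;
}

// Assumed call-site shape, mirroring the diff above:
//   mhaOperatorDesc.MaskFilterValue = kernelCreationContext.GetOptionalAttribute<float>(
//       AttrName::MaskFilterValue, DefaultMaskFilterValue(unidirectional));
```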

@raoanag merged commit 21ba803 into WindowsAI-dev on Mar 4, 2024
41 of 52 checks passed
@raoanag deleted the updateCausal branch on Mar 4, 2024
@raoanag restored the updateCausal branch on Mar 4, 2024
raoanag added a commit that referenced this pull request Mar 4, 2024
raoanag added a commit that referenced this pull request Mar 4, 2024
raoanag added a commit that referenced this pull request Mar 9, 2024