DMLEP QAttention update causal #19533
Conversation
```diff
@@ -407,7 +416,7 @@ class DmlOperatorQAttention : public DmlOperator
     mhaOperatorDesc.RelativePositionBiasTensor = nullptr;
     mhaOperatorDesc.OutputTensor = &outputDescs[outputIndex];
     mhaOperatorDesc.Scale = kernelCreationContext.GetOptionalAttribute<float>(AttrName::Scale, gsl::narrow_cast<float>(1.0f / std::sqrt(headSize)));
-    mhaOperatorDesc.MaskFilterValue = kernelCreationContext.GetOptionalAttribute<float>(AttrName::MaskFilterValue, -10'000.0f);
+    mhaOperatorDesc.MaskFilterValue = std::numeric_limits<float>::lowest();
```
I think we should still give precedence to the user-provided value and only change the default to `lowest()`:

```cpp
mhaOperatorDesc.MaskFilterValue = kernelCreationContext.GetOptionalAttribute<float>(AttrName::MaskFilterValue, std::numeric_limits<float>::lowest());
```
Did the CPU and CUDA EPs change to `lowest()` for the default? The contrib ops page still mentions -10000 as the default: https://github.com/microsoft/onnxruntime/blob/main/docs/ContribOperators.md#com.microsoft.QAttention
For unidirectional, that does not seem to be the case, per the reference implementation: https://github.com/microsoft/onnxruntime/blob/d5606cd7ee394ba9444ef509021720ebe63c9856/onnxruntime/contrib_ops/cpu/bert/attention_helper.h#L142C1-L149C6

I will use `std::numeric_limits<float>::lowest()` only when unidirectional is set.
### Description

Bug fix for the `QAttentionTest.QAttentionPastState*` test failures. The change generates a causal mask, an upper-triangular Boolean matrix, as input to the MHA mask. DML internally adds `maskFilterValue` to the "off" bits in the mask and sets the "on" bits to 0.

```
Note: Google Test filter = *QAttention*
[==========] Running 14 tests from 2 test suites.
[----------] Global test environment set-up.
[----------] 1 test from CPU_U8S8_Precision_Tests
[ RUN ] CPU_U8S8_Precision_Tests.QAttention
[ OK ] CPU_U8S8_Precision_Tests.QAttention (124 ms)
[----------] 1 test from CPU_U8S8_Precision_Tests (124 ms total)
[----------] 13 tests from QAttentionTest
[ RUN ] QAttentionTest.QAttentionBatch1
[ OK ] QAttentionTest.QAttentionBatch1 (531 ms)
[ RUN ] QAttentionTest.QAttentionBatch1_Float16
[ OK ] QAttentionTest.QAttentionBatch1_Float16 (0 ms)
[ RUN ] QAttentionTest.QAttentionBatch2
[ OK ] QAttentionTest.QAttentionBatch2 (441 ms)
[ RUN ] QAttentionTest.QAttentionMaskPartialSequence
[ OK ] QAttentionTest.QAttentionMaskPartialSequence (410 ms)
[ RUN ] QAttentionTest.QAttentionMaskExceedSequence
[ OK ] QAttentionTest.QAttentionMaskExceedSequence (398 ms)
[ RUN ] QAttentionTest.QAttentionNoMaskIndex
[ OK ] QAttentionTest.QAttentionNoMaskIndex (389 ms)
[ RUN ] QAttentionTest.QAttentionUnidirectional_U8U8
[ OK ] QAttentionTest.QAttentionUnidirectional_U8U8 (11 ms)
[ RUN ] QAttentionTest.QAttentionUnidirectional_U8S8
[ OK ] QAttentionTest.QAttentionUnidirectional_U8S8 (10 ms)
[ RUN ] QAttentionTest.QAttentionUnidirectional_CUDA
[ OK ] QAttentionTest.QAttentionUnidirectional_CUDA (0 ms)
[ RUN ] QAttentionTest.QAttentionPastState_u8u8
[ OK ] QAttentionTest.QAttentionPastState_u8u8 (2683 ms)
[ RUN ] QAttentionTest.QAttentionPastState_u8s8
[ OK ] QAttentionTest.QAttentionPastState_u8s8 (2674 ms)
[ RUN ] QAttentionTest.QAttentionPrunedModel
[ OK ] QAttentionTest.QAttentionPrunedModel (399 ms)
[ RUN ] QAttentionTest.SharedPrepackedWeights
[ OK ] QAttentionTest.SharedPrepackedWeights (89 ms)
[----------] 13 tests from QAttentionTest (8047 ms total)
[----------] Global test environment tear-down
[==========] 14 tests from 2 test suites ran. (8175 ms total)
[ PASSED ] 14 tests.
memleakdbg:
----- No memory leaks detected -----
```

### Motivation and Context

<!--
- Why is this change required? What problem does it solve?
- If it fixes an open issue, please link to the issue here.
-->