Skip to content

Commit

Permalink
DMLEP QAttention update causal (#19533)
Browse files Browse the repository at this point in the history
### Description
Bug Fix for `QAttentionTest.QAttentionPastState*` test failures 

The change generates causal mask which is an upper Triangular Boolean
Matrix as input to MHA mask. DML internally adds `maskFilterValue` to
the "off" bits in the mask and sets the "on" bits to 0.
```
Note: Google Test filter = *QAttention*
[==========] Running 14 tests from 2 test suites.
[----------] Global test environment set-up.
[----------] 1 test from CPU_U8S8_Precision_Tests
[ RUN      ] CPU_U8S8_Precision_Tests.QAttention
[       OK ] CPU_U8S8_Precision_Tests.QAttention (124 ms)
[----------] 1 test from CPU_U8S8_Precision_Tests (124 ms total)

[----------] 13 tests from QAttentionTest
[ RUN      ] QAttentionTest.QAttentionBatch1
[       OK ] QAttentionTest.QAttentionBatch1 (531 ms)
[ RUN      ] QAttentionTest.QAttentionBatch1_Float16
[       OK ] QAttentionTest.QAttentionBatch1_Float16 (0 ms)
[ RUN      ] QAttentionTest.QAttentionBatch2
[       OK ] QAttentionTest.QAttentionBatch2 (441 ms)
[ RUN      ] QAttentionTest.QAttentionMaskPartialSequence
[       OK ] QAttentionTest.QAttentionMaskPartialSequence (410 ms)
[ RUN      ] QAttentionTest.QAttentionMaskExceedSequence
[       OK ] QAttentionTest.QAttentionMaskExceedSequence (398 ms)
[ RUN      ] QAttentionTest.QAttentionNoMaskIndex
[       OK ] QAttentionTest.QAttentionNoMaskIndex (389 ms)
[ RUN      ] QAttentionTest.QAttentionUnidirectional_U8U8
[       OK ] QAttentionTest.QAttentionUnidirectional_U8U8 (11 ms)
[ RUN      ] QAttentionTest.QAttentionUnidirectional_U8S8
[       OK ] QAttentionTest.QAttentionUnidirectional_U8S8 (10 ms)
[ RUN      ] QAttentionTest.QAttentionUnidirectional_CUDA
[       OK ] QAttentionTest.QAttentionUnidirectional_CUDA (0 ms)
[ RUN      ] QAttentionTest.QAttentionPastState_u8u8
[       OK ] QAttentionTest.QAttentionPastState_u8u8 (2683 ms)
[ RUN      ] QAttentionTest.QAttentionPastState_u8s8
[       OK ] QAttentionTest.QAttentionPastState_u8s8 (2674 ms)
[ RUN      ] QAttentionTest.QAttentionPrunedModel
[       OK ] QAttentionTest.QAttentionPrunedModel (399 ms)
[ RUN      ] QAttentionTest.SharedPrepackedWeights
[       OK ] QAttentionTest.SharedPrepackedWeights (89 ms)
[----------] 13 tests from QAttentionTest (8047 ms total)

[----------] Global test environment tear-down
[==========] 14 tests from 2 test suites ran. (8175 ms total)
[  PASSED  ] 14 tests.
memleakdbg:
----- No memory leaks detected -----
```

### Motivation and Context
<!-- - Why is this change required? What problem does it solve?
- If it fixes an open issue, please link to the issue here. -->
  • Loading branch information
raoanag authored Mar 4, 2024
1 parent fa19118 commit 21ba803
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 12 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -361,25 +361,30 @@ class DmlOperatorQAttention : public DmlOperator
const DML_OPERATOR_DESC pastKeySlicedDesc = { DML_OPERATOR_SLICE1, &pastKeySlicedOperatorDesc};
const DML_OPERATOR_DESC pastValueSlicedDesc = { DML_OPERATOR_SLICE1, &pastValueSlicedOperatorDesc};

// Causal Mask: [pastSequenceLength, pastSequenceLength + 1 ... pastSequenceLength + batchSize -1]
// Causal Mask: Upper Triangular Boolean Matrix
// Example: [[1, 0, 0, 0, 0],
// [1, 1, 0, 0, 0],
// [1, 1, 1, 0, 0],
// [1, 1, 1, 1, 0]]
// DML adds maskFilterValue to the "off" bits in the mask and sets the "on" bits to 0
// passed to MHA as maskIndex Tensor when unidirectional == 1
std::array<uint32_t, 2> causalMaskOutputShape = {1, batchSize};
std::array<uint32_t, 4> causalMaskOutputShape = {1, 1, sequenceLength, pastSequenceLength + sequenceLength};
TensorDesc causalMaskTensorDesc;
DML_FILL_VALUE_SEQUENCE_OPERATOR_DESC causalMaskOperatorDesc = {};
DML_DIAGONAL_MATRIX1_OPERATOR_DESC causalMaskOperatorDesc = {};
DML_TENSOR_DESC namedcausalMaskTensorDesc;

if (unidirectional && !hasMask)
{
causalMaskTensorDesc = TensorDesc::ConstructDefaultTensorDesc(MLOperatorTensorDataType::Int32, causalMaskOutputShape);
namedcausalMaskTensorDesc = causalMaskTensorDesc.GetDmlDesc();
causalMaskOperatorDesc.ValueDataType = DML_TENSOR_DATA_TYPE_INT32;
causalMaskOperatorDesc.ValueStart.Int32 = pastSequenceLength;
causalMaskOperatorDesc.ValueDelta.Int32 = 1;
causalMaskOperatorDesc.DiagonalFillBegin = INT32_MIN;
causalMaskOperatorDesc.DiagonalFillEnd = pastSequenceLength + 1;
causalMaskOperatorDesc.Value.Int32 = 1;
causalMaskOperatorDesc.OutputTensor = &namedcausalMaskTensorDesc;

maskType = DML_MULTIHEAD_ATTENTION_MASK_TYPE_KEY_SEQUENCE_LENGTH;
maskType = DML_MULTIHEAD_ATTENTION_MASK_TYPE_BOOLEAN;
}
DML_OPERATOR_DESC causalMaskDesc = { DML_OPERATOR_FILL_VALUE_SEQUENCE, &causalMaskOperatorDesc };
DML_OPERATOR_DESC causalMaskDesc = { DML_OPERATOR_DIAGONAL_MATRIX1, &causalMaskOperatorDesc };

DML_MULTIHEAD_ATTENTION_OPERATOR_DESC mhaOperatorDesc = {};
std::array<uint32_t, 5> presentKeyOutputShape = {1, batchSize, numHeads, pastSequenceLength + sequenceLength, headSize};
Expand All @@ -393,7 +398,11 @@ class DmlOperatorQAttention : public DmlOperator

if (unidirectional && !hasMask)
{
mhaOperatorDesc.MaskTensor = &namedcausalMaskTensorDesc;
// Broadcast to MHA MaskTensor Shape
std::array<uint32_t, 4> mhaMaskTensorShape = {batchSize, numHeads, sequenceLength, pastSequenceLength + sequenceLength};
TensorDesc broadcastedcausalMaskTensorDesc = TensorDesc::ConstructBroadcastedTensorDesc(MLOperatorTensorDataType::Int32, mhaMaskTensorShape, causalMaskOutputShape);
const DML_TENSOR_DESC namedbroadcastedcausalMaskTensorDesc = broadcastedcausalMaskTensorDesc.GetDmlDesc();
mhaOperatorDesc.MaskTensor = &namedbroadcastedcausalMaskTensorDesc;
}
else if (hasMaxSequenceMask)
{
Expand All @@ -407,7 +416,9 @@ class DmlOperatorQAttention : public DmlOperator
mhaOperatorDesc.RelativePositionBiasTensor = nullptr;
mhaOperatorDesc.OutputTensor = &outputDescs[outputIndex];
mhaOperatorDesc.Scale = kernelCreationContext.GetOptionalAttribute<float>(AttrName::Scale, gsl::narrow_cast<float>(1.0f / std::sqrt(headSize)));
mhaOperatorDesc.MaskFilterValue = kernelCreationContext.GetOptionalAttribute<float>(AttrName::MaskFilterValue, -10'000.0f);
// Set MaskFilterValue to lowest float for Causal Mask
mhaOperatorDesc.MaskFilterValue = unidirectional ? std::numeric_limits<float>::lowest() :
kernelCreationContext.GetOptionalAttribute<float>(AttrName::MaskFilterValue, -10'000.0f);
mhaOperatorDesc.HeadCount = numHeads;
mhaOperatorDesc.MaskType = maskType;
if (hasPast)
Expand Down
4 changes: 2 additions & 2 deletions onnxruntime/test/contrib_ops/quantize_attention_op_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -911,8 +911,8 @@ void TestQuantizedAttentionPastState(int64_t batch,
std::vector<int64_t> input_dims{batch, seq_len, hidden_size};
std::vector<InputT> input_data = random.Gaussian<InputT>(input_dims, input_mean, static_cast<InputT>(input_range / 6), input_min, input_max);

constexpr WeightT weight_min = std::numeric_limits<WeightT>::min();
constexpr WeightT weight_max = std::numeric_limits<WeightT>::max();
constexpr WeightT weight_min = constexpr(std::is_same_v<WeightT, int8_t>) ? std::numeric_limits<int8_t>::min() / 2 : std::numeric_limits<WeightT>::min();
constexpr WeightT weight_max = std::numeric_limits<WeightT>::max() / 2;
constexpr int32_t weight_range = weight_max - weight_min;

std::vector<WeightT> weight_zero_point(weight_scale_zp_size);
Expand Down

0 comments on commit 21ba803

Please sign in to comment.