# Enabling support for QAttention (#18326)
[Cherry Pick Reviewed]

Includes #16837, #16851, #17947.

### Description
Enables support for the `past` input, `present` output, and `unidirectional` attribute of the
[QAttention](https://github.com/microsoft/onnxruntime/blob/main/docs/ContribOperators.md#com.microsoft.QAttention)
contrib op in the DML execution provider.
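With `unidirectional = 1`, attention becomes causal: position `i` may attend only to positions `j <= i` (including any past state). A minimal standalone sketch of that masking rule, my own illustration rather than the DML implementation:

```cpp
#include <cstdio>
#include <limits>
#include <vector>

int main() {
    const int seqLen = 4;
    // Start from uniform attention scores, then mask out future positions
    // (j > i) by driving them to -inf so softmax assigns them zero weight.
    std::vector<std::vector<float>> scores(seqLen, std::vector<float>(seqLen, 1.0f));
    for (int i = 0; i < seqLen; ++i) {
        for (int j = i + 1; j < seqLen; ++j) {
            scores[i][j] = -std::numeric_limits<float>::infinity();
        }
    }
    for (int i = 0; i < seqLen; ++i) {
        for (int j = 0; j < seqLen; ++j) {
            std::printf("%6.1f ", scores[i][j]);
        }
        std::printf("\n");
    }
}
```

The unit-test run below exercises the new paths; note that the two `QAttentionPastState_*` cases still fail on the DML execution provider, exceeding the 2e-4 absolute tolerance.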



```
Note: Google Test filter = *QAttention*
[==========] Running 14 tests from 2 test suites.
[----------] Global test environment set-up.
[----------] 1 test from CPU_U8S8_Precision_Tests
[ RUN      ] CPU_U8S8_Precision_Tests.QAttention
[       OK ] CPU_U8S8_Precision_Tests.QAttention (104 ms)
[----------] 1 test from CPU_U8S8_Precision_Tests (105 ms total)

[----------] 13 tests from QAttentionTest
[ RUN      ] QAttentionTest.QAttentionBatch1
[       OK ] QAttentionTest.QAttentionBatch1 (255 ms)
[ RUN      ] QAttentionTest.QAttentionBatch1_Float16
[       OK ] QAttentionTest.QAttentionBatch1_Float16 (0 ms)
[ RUN      ] QAttentionTest.QAttentionBatch2
[       OK ] QAttentionTest.QAttentionBatch2 (201 ms)
[ RUN      ] QAttentionTest.QAttentionMaskPartialSequence
[       OK ] QAttentionTest.QAttentionMaskPartialSequence (197 ms)
[ RUN      ] QAttentionTest.QAttentionMaskExceedSequence
[       OK ] QAttentionTest.QAttentionMaskExceedSequence (192 ms)
[ RUN      ] QAttentionTest.QAttentionNoMaskIndex
[       OK ] QAttentionTest.QAttentionNoMaskIndex (186 ms)
[ RUN      ] QAttentionTest.QAttentionUnidirectional_U8U8
[       OK ] QAttentionTest.QAttentionUnidirectional_U8U8 (9 ms)
[ RUN      ] QAttentionTest.QAttentionUnidirectional_U8S8
[       OK ] QAttentionTest.QAttentionUnidirectional_U8S8 (9 ms)
[ RUN      ] QAttentionTest.QAttentionUnidirectional_CUDA
[       OK ] QAttentionTest.QAttentionUnidirectional_CUDA (0 ms)
[ RUN      ] QAttentionTest.QAttentionPastState_u8u8
C:\workspace\ORT\onnxruntime\onnxruntime\test\providers\checkers.cc(268): error: The difference between cur_expected[i] and cur_actual[i] is 0.75743968039751053, which exceeds *(params.absolute_error), where
cur_expected[i] evaluates to 0.67312467098236084,
cur_actual[i] evaluates to -0.084315009415149689, and
*(params.absolute_error) evaluates to 0.00019999999494757503.
i:0
Google Test trace:
C:\workspace\ORT\onnxruntime\onnxruntime\test\providers\checkers.cc(484): provider type: DmlExecutionProvider
C:\workspace\ORT\onnxruntime\onnxruntime\test/common/random_generator.h(49): ORT test random seed: 2178993560
C:\workspace\ORT\onnxruntime\onnxruntime\test\providers\checkers.cc(268): error: The difference between cur_expected[i] and cur_actual[i] is 0.75743968039751053, which exceeds *(params.absolute_error), where
cur_expected[i] evaluates to 0.67312467098236084,
cur_actual[i] evaluates to -0.084315009415149689, and
*(params.absolute_error) evaluates to 0.00019999999494757503.
i:0
Google Test trace:
C:\workspace\ORT\onnxruntime\onnxruntime\test\providers\checkers.cc(484): provider type: DmlExecutionProvider
C:\workspace\ORT\onnxruntime\onnxruntime\test/common/random_generator.h(49): ORT test random seed: 2178993560
C:\workspace\ORT\onnxruntime\onnxruntime\test\providers\checkers.cc(268): error: The difference between cur_expected[i] and cur_actual[i] is 0.03001787792891264, which exceeds *(params.absolute_error), where
cur_expected[i] evaluates to -0.021467097103595734,
cur_actual[i] evaluates to 0.008550780825316906, and
*(params.absolute_error) evaluates to 0.00019999999494757503.
i:0
Google Test trace:
C:\workspace\ORT\onnxruntime\onnxruntime\test\providers\checkers.cc(484): provider type: DmlExecutionProvider
C:\workspace\ORT\onnxruntime\onnxruntime\test/common/random_generator.h(49): ORT test random seed: 2178993560
C:\workspace\ORT\onnxruntime\onnxruntime\test\providers\checkers.cc(268): error: The difference between cur_expected[i] and cur_actual[i] is 0.03001787792891264, which exceeds *(params.absolute_error), where
cur_expected[i] evaluates to -0.021467097103595734,
cur_actual[i] evaluates to 0.008550780825316906, and
*(params.absolute_error) evaluates to 0.00019999999494757503.
i:0
Google Test trace:
C:\workspace\ORT\onnxruntime\onnxruntime\test\providers\checkers.cc(484): provider type: DmlExecutionProvider
C:\workspace\ORT\onnxruntime\onnxruntime\test/common/random_generator.h(49): ORT test random seed: 2178993560
[  FAILED  ] QAttentionTest.QAttentionPastState_u8u8 (2067 ms)
[ RUN      ] QAttentionTest.QAttentionPastState_u8s8
C:\workspace\ORT\onnxruntime\onnxruntime\test\providers\checkers.cc(268): error: The difference between cur_expected[i] and cur_actual[i] is 0.74043640494346619, which exceeds *(params.absolute_error), where
cur_expected[i] evaluates to 0.65650326013565063,
cur_actual[i] evaluates to -0.083933144807815552, and
*(params.absolute_error) evaluates to 0.00019999999494757503.
i:0
Google Test trace:
C:\workspace\ORT\onnxruntime\onnxruntime\test\providers\checkers.cc(484): provider type: DmlExecutionProvider
C:\workspace\ORT\onnxruntime\onnxruntime\test/common/random_generator.h(49): ORT test random seed: 2178993560
C:\workspace\ORT\onnxruntime\onnxruntime\test\providers\checkers.cc(268): error: The difference between cur_expected[i] and cur_actual[i] is 0.081788420677185059, which exceeds *(params.absolute_error), where
cur_expected[i] evaluates to 1.0076344013214111,
cur_actual[i] evaluates to 1.0894228219985962, and
*(params.absolute_error) evaluates to 0.00019999999494757503.
i:965
Google Test trace:
C:\workspace\ORT\onnxruntime\onnxruntime\test\providers\checkers.cc(484): provider type: DmlExecutionProvider
C:\workspace\ORT\onnxruntime\onnxruntime\test/common/random_generator.h(49): ORT test random seed: 2178993560
C:\workspace\ORT\onnxruntime\onnxruntime\test\providers\checkers.cc(268): error: The difference between cur_expected[i] and cur_actual[i] is 0.74043640494346619, which exceeds *(params.absolute_error), where
cur_expected[i] evaluates to 0.65650326013565063,
cur_actual[i] evaluates to -0.083933144807815552, and
*(params.absolute_error) evaluates to 0.00019999999494757503.
i:0
Google Test trace:
C:\workspace\ORT\onnxruntime\onnxruntime\test\providers\checkers.cc(484): provider type: DmlExecutionProvider
C:\workspace\ORT\onnxruntime\onnxruntime\test/common/random_generator.h(49): ORT test random seed: 2178993560
C:\workspace\ORT\onnxruntime\onnxruntime\test\providers\checkers.cc(268): error: The difference between cur_expected[i] and cur_actual[i] is 0.081788420677185059, which exceeds *(params.absolute_error), where
cur_expected[i] evaluates to 1.0076344013214111,
cur_actual[i] evaluates to 1.0894228219985962, and
*(params.absolute_error) evaluates to 0.00019999999494757503.
i:965
Google Test trace:
C:\workspace\ORT\onnxruntime\onnxruntime\test\providers\checkers.cc(484): provider type: DmlExecutionProvider
C:\workspace\ORT\onnxruntime\onnxruntime\test/common/random_generator.h(49): ORT test random seed: 2178993560
C:\workspace\ORT\onnxruntime\onnxruntime\test\providers\checkers.cc(268): error: The difference between cur_expected[i] and cur_actual[i] is 0.024714200757443905, which exceeds *(params.absolute_error), where
cur_expected[i] evaluates to -0.016048312187194824,
cur_actual[i] evaluates to 0.0086658885702490807, and
*(params.absolute_error) evaluates to 0.00019999999494757503.
i:0
Google Test trace:
C:\workspace\ORT\onnxruntime\onnxruntime\test\providers\checkers.cc(484): provider type: DmlExecutionProvider
C:\workspace\ORT\onnxruntime\onnxruntime\test/common/random_generator.h(49): ORT test random seed: 2178993560
C:\workspace\ORT\onnxruntime\onnxruntime\test\providers\checkers.cc(268): error: The difference between cur_expected[i] and cur_actual[i] is 0.0092324763536453247, which exceeds *(params.absolute_error), where
cur_expected[i] evaluates to 0.24175386130809784,
cur_actual[i] evaluates to 0.25098633766174316, and
*(params.absolute_error) evaluates to 0.00019999999494757503.
i:979
Google Test trace:
C:\workspace\ORT\onnxruntime\onnxruntime\test\providers\checkers.cc(484): provider type: DmlExecutionProvider
C:\workspace\ORT\onnxruntime\onnxruntime\test/common/random_generator.h(49): ORT test random seed: 2178993560
C:\workspace\ORT\onnxruntime\onnxruntime\test\providers\checkers.cc(268): error: The difference between cur_expected[i] and cur_actual[i] is 0.024714200757443905, which exceeds *(params.absolute_error), where
cur_expected[i] evaluates to -0.016048312187194824,
cur_actual[i] evaluates to 0.0086658885702490807, and
*(params.absolute_error) evaluates to 0.00019999999494757503.
i:0
Google Test trace:
C:\workspace\ORT\onnxruntime\onnxruntime\test\providers\checkers.cc(484): provider type: DmlExecutionProvider
C:\workspace\ORT\onnxruntime\onnxruntime\test/common/random_generator.h(49): ORT test random seed: 2178993560
C:\workspace\ORT\onnxruntime\onnxruntime\test\providers\checkers.cc(268): error: The difference between cur_expected[i] and cur_actual[i] is 0.0092324763536453247, which exceeds *(params.absolute_error), where
cur_expected[i] evaluates to 0.24175386130809784,
cur_actual[i] evaluates to 0.25098633766174316, and
*(params.absolute_error) evaluates to 0.00019999999494757503.
i:979
Google Test trace:
C:\workspace\ORT\onnxruntime\onnxruntime\test\providers\checkers.cc(484): provider type: DmlExecutionProvider
C:\workspace\ORT\onnxruntime\onnxruntime\test/common/random_generator.h(49): ORT test random seed: 2178993560
[  FAILED  ] QAttentionTest.QAttentionPastState_u8s8 (2079 ms)
[ RUN      ] QAttentionTest.QAttentionPrunedModel
[       OK ] QAttentionTest.QAttentionPrunedModel (206 ms)
[ RUN      ] QAttentionTest.SharedPrepackedWeights
[       OK ] QAttentionTest.SharedPrepackedWeights (79 ms)
[----------] 13 tests from QAttentionTest (5492 ms total)

[----------] Global test environment tear-down
[==========] 14 tests from 2 test suites ran. (5600 ms total)
[  PASSED  ] 12 tests.
[  FAILED  ] 2 tests, listed below:
[  FAILED  ] QAttentionTest.QAttentionPastState_u8u8
[  FAILED  ] QAttentionTest.QAttentionPastState_u8s8

 2 FAILED TESTS
memleakdbg:
----- No memory leaks detected -----
```


### Motivation and Context

---------

Co-authored-by: Xiang Zhang <[email protected]>
Authored by raoanag and zhangxiang1993 on Nov 9, 2023
1 parent deed125, commit b2768bb
Showing 11 changed files with 833 additions and 12 deletions.
```diff
@@ -888,7 +888,7 @@ constexpr DML_SCHEMA_FIELD DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_SCHEMA_

 constexpr DML_OPERATOR_SCHEMA DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_SCHEMA {
     "DML_OPERATOR_QUANTIZED_LINEAR_AVERAGE_POOLING",
-    static_cast<DML_OPERATOR_TYPE>(DML_OPERATOR_QUANTIZED_LINEAR_AVERAGE_POOLING),
+    DML_OPERATOR_QUANTIZED_LINEAR_AVERAGE_POOLING,
     DML_SCHEMA_OPERATOR_SUPPORT_FLAG_NONE,
     13,
     DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_SCHEMA_FIELDS,
@@ -1923,7 +1923,7 @@ constexpr DML_SCHEMA_FIELD DML_MATRIX_MULTIPLY_INTEGER_TO_FLOAT_OPERATOR_SCHEMA_

 constexpr DML_OPERATOR_SCHEMA DML_MATRIX_MULTIPLY_INTEGER_TO_FLOAT_OPERATOR_SCHEMA {
     "DML_OPERATOR_MATRIX_MULTIPLY_INTEGER_TO_FLOAT",
-    static_cast<DML_OPERATOR_TYPE>(DML_OPERATOR_MATRIX_MULTIPLY_INTEGER_TO_FLOAT),
+    DML_OPERATOR_MATRIX_MULTIPLY_INTEGER_TO_FLOAT,
     DML_SCHEMA_OPERATOR_SUPPORT_FLAG_NONE,
     8,
     DML_MATRIX_MULTIPLY_INTEGER_TO_FLOAT_OPERATOR_SCHEMA_FIELDS,
```
```diff
@@ -571,6 +571,12 @@ void CALLBACK QueryAttention(IMLOperatorSupportQueryContextPrivate* context, /*o
         return;
     }

+    // `past_present_share_buffer == 1` is not supported yet
+    if (attributes.GetOptionalAttribute<int32_t>(AttrName::PastPresentShareBuffer, 0) != 0)
+    {
+        return;
+    }
+
     *isSupported = true;
 }

```
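For context: a DML support-query callback leaves `*isSupported` false and returns early for any configuration the kernel cannot handle, which lets the node fall back to another execution provider (typically the CPU implementation). A simplified standalone sketch of the pattern; the types and helper below are placeholders of my own, not ORT's actual interfaces:

```cpp
#include <cstdint>

// Placeholder stand-ins for the ORT/DML query machinery (illustrative only).
struct AttributeBag {
    int32_t pastPresentShareBuffer = 0;
    int32_t GetOptionalInt32(const char* /*name*/, int32_t defaultValue) const {
        return pastPresentShareBuffer != 0 ? pastPresentShareBuffer : defaultValue;
    }
};

void QueryAttentionSketch(const AttributeBag& attributes, bool* isSupported) {
    *isSupported = false;

    // Decline `past_present_share_buffer == 1`, mirroring the hunk above;
    // the node then falls back to a different execution provider.
    if (attributes.GetOptionalInt32("past_present_share_buffer", 0) != 0) {
        return;
    }

    *isSupported = true;
}

int main() {
    AttributeBag attrs;
    attrs.pastPresentShareBuffer = 1;
    bool supported = false;
    QueryAttentionSketch(attrs, &supported);
    return supported ? 1 : 0;  // unsupported here, so this returns 0
}
```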
```diff
@@ -105,7 +105,7 @@ class DmlOperatorDynamicQuantizeMatMul : public DmlOperator
     matrixMultiplyIntergerToFloatOperatorDesc.BiasTensor = hasBias ? &inputDescs[OnnxInputIndex::Bias] : nullptr;
     matrixMultiplyIntergerToFloatOperatorDesc.OutputTensor = &outputDescs[0];

-    const DML_OPERATOR_DESC opDesc2{ (DML_OPERATOR_TYPE)DML_OPERATOR_MATRIX_MULTIPLY_INTEGER_TO_FLOAT, &matrixMultiplyIntergerToFloatOperatorDesc};
+    const DML_OPERATOR_DESC opDesc2{ DML_OPERATOR_MATRIX_MULTIPLY_INTEGER_TO_FLOAT, &matrixMultiplyIntergerToFloatOperatorDesc};

     MLOperatorGraphDesc operatorGraphDesc = {};
     std::vector<const DML_OPERATOR_DESC*> opDescs{&opDesc1, &opDesc2};
```

*Large diffs are not rendered by default.*

```diff
@@ -171,4 +171,4 @@ void CALLBACK QueryQLinearSigmoid(IMLOperatorSupportQueryContextPrivate* context
 }

 DML_OP_DEFINE_CREATION_FUNCTION(QLinearSigmoid, DmlOperatorQLinearSigmoid);
-} // namespace Dml
\ No newline at end of file
+} // namespace Dml
```
```diff
@@ -507,6 +507,7 @@ DML_OP_EXTERN_CREATION_FUNCTION(ConvInteger);
 DML_OP_EXTERN_CREATION_FUNCTION(Trilu);
 DML_OP_EXTERN_CREATION_FUNCTION(Shape);
 DML_OP_EXTERN_CREATION_FUNCTION(Size);
+DML_OP_EXTERN_CREATION_FUNCTION(QAttention);
 DML_OP_EXTERN_CREATION_FUNCTION(Attention);
 DML_OP_EXTERN_CREATION_FUNCTION(MultiHeadAttention);
 DML_OP_EXTERN_CREATION_FUNCTION(NonZero);
@@ -527,6 +528,7 @@ DML_OP_EXTERN_QUERY_FUNCTION(Pad);
 DML_OP_EXTERN_QUERY_FUNCTION(LayerNormalization);
 DML_OP_EXTERN_QUERY_FUNCTION(SkipLayerNormalization);
 DML_OP_EXTERN_QUERY_FUNCTION(QLinearSigmoid);
+DML_OP_EXTERN_QUERY_FUNCTION(QAttention);
 DML_OP_EXTERN_QUERY_FUNCTION(Attention);

 constexpr static std::array<const char*, 1> typeNameListDefault = {"T"};
@@ -602,14 +604,22 @@ constexpr static std::array<SupportedTensorDataTypes, 2> supportedTypeListLayerN
 constexpr static std::array<SupportedTensorDataTypes, 2> supportedTypeListShape = {SupportedTensorDataTypes::All, SupportedTensorDataTypes::Int64};
 constexpr static std::array<SupportedTensorDataTypes, 2> supportedTypeListSize = {SupportedTensorDataTypes::All, SupportedTensorDataTypes::Int64};
 constexpr static std::array<SupportedTensorDataTypes, 1> supportedTypeListQLinearSigmoid = {SupportedTensorDataTypes::UInt8 | SupportedTensorDataTypes::Int8};
+
+constexpr static std::array<SupportedTensorDataTypes, 4> supportedTypeListQAttention = {
+    SupportedTensorDataTypes::Ints8Bit,
+    SupportedTensorDataTypes::Ints8Bit,
+    SupportedTensorDataTypes::Float16to32,
+    SupportedTensorDataTypes::Int32
+};
+
 constexpr static std::array<SupportedTensorDataTypes, 2> supportedTypeListAttention = {SupportedTensorDataTypes::Float16to32, SupportedTensorDataTypes::Int32};
 constexpr static std::array<SupportedTensorDataTypes, 2> supportedTypeListGroupNorm = {SupportedTensorDataTypes::Float16to32, SupportedTensorDataTypes::Float16to32};
 constexpr static std::array<SupportedTensorDataTypes, 1> supportedTypeListNonZero = {SupportedTensorDataTypes::Float16to32 | SupportedTensorDataTypes::Ints8Bit | SupportedTensorDataTypes::Ints16Bit | SupportedTensorDataTypes::Ints32Bit | SupportedTensorDataTypes::Bool};

 constexpr static std::array<SupportedTensorDataTypes, 3> supportedTypeListQLinearMatMul = {
-    SupportedTensorDataTypes::Int8|SupportedTensorDataTypes::UInt8,
-    SupportedTensorDataTypes::Int8|SupportedTensorDataTypes::UInt8,
-    SupportedTensorDataTypes::Int8|SupportedTensorDataTypes::UInt8
+    SupportedTensorDataTypes::Ints8Bit,
+    SupportedTensorDataTypes::Ints8Bit,
+    SupportedTensorDataTypes::Ints8Bit
 };

 constexpr static std::array<SupportedTensorDataTypes, 3> supportedTypeListMatMulIntegerToFloat = {
@@ -619,9 +629,9 @@ constexpr static std::array<SupportedTensorDataTypes, 3> supportedTypeListMatMul
 };

 constexpr static std::array<SupportedTensorDataTypes, 4> supportedTypeListQLinearConv = {
-    SupportedTensorDataTypes::Int8|SupportedTensorDataTypes::UInt8,
-    SupportedTensorDataTypes::Int8|SupportedTensorDataTypes::UInt8,
-    SupportedTensorDataTypes::Int8|SupportedTensorDataTypes::UInt8,
+    SupportedTensorDataTypes::Ints8Bit,
+    SupportedTensorDataTypes::Ints8Bit,
+    SupportedTensorDataTypes::Ints8Bit,
     SupportedTensorDataTypes::Int32
 };

@@ -1035,6 +1045,7 @@ constexpr static OperatorRegistrationInformation operatorRegistrationInformation
     {REG_INFO_MS( 1, DynamicQuantizeMatMul, typeNameListTwo, supportedTypeListDynamicQuantizeLinear, DmlGraphSupport::Supported)},
     {REG_INFO_MS( 1, FusedMatMulActivation, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
     {REG_INFO_MS( 1, QLinearSigmoid, typeNameListDefault, supportedTypeListQLinearSigmoid, DmlGraphSupport::Supported, requiredConstantCpuInputs(), std::nullopt, QueryQLinearSigmoid)},
+    {REG_INFO_MS( 1, QAttention, typeNameListFour, supportedTypeListQAttention, DmlGraphSupport::Supported, requiredConstantCpuInputs(), std::nullopt, QueryQAttention)},
     {REG_INFO_MS( 1, Attention, typeNameListAttention, supportedTypeListAttention, DmlGraphSupport::Supported, requiredConstantCpuInputs(), std::nullopt, QueryAttention)},
     {REG_INFO_MS( 1, MultiHeadAttention, typeNameListAttention, supportedTypeListAttention, DmlGraphSupport::Supported)},
     {REG_INFO_MS( 1, QLinearConcat, typeNameListQLinearConcat, supportedTypeListQLinearConcat, DmlGraphSupport::Supported)},
```
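The `Int8|UInt8` to `Ints8Bit` edits above are a readability cleanup: `Ints8Bit` is a pre-combined mask covering both 8-bit integer tensor types. A standalone sketch of the bit-flag idiom; the bit values below are illustrative, not ORT's actual definitions:

```cpp
#include <cstdint>
#include <cstdio>

// Illustrative flag enum; the real SupportedTensorDataTypes lives in the DML EP.
enum SupportedTensorDataTypes : uint32_t {
    Int8     = 1u << 0,
    UInt8    = 1u << 1,
    Ints8Bit = Int8 | UInt8,  // alias for "either 8-bit integer type"
};

int main() {
    // Writing Ints8Bit is equivalent to OR-ing the two flags by hand.
    std::printf("%s\n", (Int8 | UInt8) == Ints8Bit ? "equal" : "different");
}
```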
```diff
@@ -107,6 +107,7 @@ namespace AttrName
     static constexpr const char* QkvHiddenSizes = "qkv_hidden_sizes";
     static constexpr const char* Unidirectional = "unidirectional";
     static constexpr const char* NumHeads = "num_heads";
+    static constexpr const char* PastPresentShareBuffer = "past_present_share_buffer";

     static constexpr const char* FusedActivation = "fused_activation";
     static constexpr const char* FusedActivationDomain = "fused_activation_domain";
```
```diff
@@ -2744,6 +2744,48 @@ namespace OperatorHelper
         m_qkvHiddenSizes = kernelInformation.GetAttributes().GetOptionalAttributeVectorInt32(AttrName::QkvHiddenSizes);
     }

+    std::vector<EdgeShapes> QAttentionHelper::GetOutputShapes(const MLShapeInferenceContext& shapeInfo) const
+    {
+        ML_CHECK_VALID_ARGUMENT(shapeInfo.GetInputCount() >= 5);
+
+        auto queryShape = shapeInfo.GetInputTensorShape(0);
+        ML_CHECK_VALID_ARGUMENT(queryShape.size() == 3);
+
+        auto weightShape = shapeInfo.GetInputTensorShape(1);
+        ML_CHECK_VALID_ARGUMENT(weightShape.size() == 2);
+        ML_CHECK_VALID_ARGUMENT(weightShape[1] % 3 == 0);
+
+        const uint32_t batchSize = queryShape[0];
+        const uint32_t sequenceLength = queryShape[1];
+        const uint32_t hiddenSize = weightShape[1] / 3;
+        const uint32_t headSize = hiddenSize / m_numHeads;
+
+        std::vector<EdgeShapes> outputShapes(2);
+
+        outputShapes[0] = EdgeShapes({batchSize, sequenceLength, hiddenSize});
+
+        uint32_t totalSequenceLength = sequenceLength;
+        if (shapeInfo.IsInputValid(8))
+        {
+            ML_CHECK_VALID_ARGUMENT(shapeInfo.GetInputTensorDimensionCount(8) == 5);
+            const uint32_t pastSequenceLength = shapeInfo.GetInputTensorShape(8)[3];
+            totalSequenceLength += pastSequenceLength;
+        }
+
+        if (shapeInfo.IsOutputValid(1))
+        {
+            ML_CHECK_VALID_ARGUMENT(shapeInfo.IsInputValid(8));
+            outputShapes[1] = EdgeShapes({2, batchSize, m_numHeads, totalSequenceLength, headSize});
+        }
+
+        return outputShapes;
+    }
+
+    void QAttentionHelper::Initialize(const IKernelInformationAdapter& kernelInformation)
+    {
+        m_numHeads = gsl::narrow_cast<uint32_t>(kernelInformation.GetAttributes().GetAttribute<int64_t>(AttrName::NumHeads));
+    }
+
     std::vector<EdgeShapes> SkipLayerNormHelper::GetOutputShapes(const MLShapeInferenceContext& shapeInfo) const
     {
         ML_CHECK_VALID_ARGUMENT(shapeInfo.GetInputCount() >= 3);
```
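A worked example of the shape arithmetic in `QAttentionHelper::GetOutputShapes`, as a standalone sketch with concrete numbers of my own choosing (input `[1, 2, 12]`, weight `[12, 36]`, `num_heads = 3`, past sequence length 5):

```cpp
#include <cstdint>
#include <cstdio>

int main() {
    const uint32_t batchSize = 1;           // input shape [batch, seq, inputHidden] = [1, 2, 12]
    const uint32_t sequenceLength = 2;
    const uint32_t weightDim1 = 36;         // weight shape [inputHidden, 3 * hidden] = [12, 36]
    const uint32_t numHeads = 3;            // `num_heads` attribute
    const uint32_t pastSequenceLength = 5;  // past shape [2, batch, numHeads, pastSeq, headSize]

    const uint32_t hiddenSize = weightDim1 / 3;                               // 12
    const uint32_t headSize = hiddenSize / numHeads;                          // 4
    const uint32_t totalSequenceLength = sequenceLength + pastSequenceLength; // 7

    // output:  [batch, seq, hidden]                     -> [1, 2, 12]
    // present: [2, batch, numHeads, totalSeq, headSize] -> [2, 1, 3, 7, 4]
    std::printf("output  = [%u, %u, %u]\n", batchSize, sequenceLength, hiddenSize);
    std::printf("present = [2, %u, %u, %u, %u]\n", batchSize, numHeads, totalSequenceLength, headSize);
}
```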
```diff
@@ -1490,6 +1490,22 @@ class AttentionHelper
     std::vector<int32_t> m_qkvHiddenSizes;
 };

+class QAttentionHelper
+{
+public:
+    template <typename Info_t, typename Shape_t>
+    QAttentionHelper(const Info_t& info, const Shape_t& shapeInfo)
+    {
+        Initialize(KernelInformationAdapter(info));
+    }
+
+    std::vector<EdgeShapes> GetOutputShapes(const MLShapeInferenceContext& shapeInfo) const;
+
+private:
+    void Initialize(const IKernelInformationAdapter& kernelInformation);
+    uint32_t m_numHeads;
+};
+
 class SkipLayerNormHelper
 {
 public:
@@ -1630,6 +1646,7 @@ using ShapeInferenceHelper_Affine = GetOutputShapeAsInputShapeHelper;
 using ShapeInferenceHelper_QuantizeLinear = GetOutputShapeAsInputShapeHelper;
 using ShapeInferenceHelper_DequantizeLinear = GetOutputShapeAsInputShapeHelper;
 using ShapeInferenceHelper_QLinearSigmoid = GetOutputShapeAsInputShapeHelper;
+using ShapeInferenceHelper_QAttention = QAttentionHelper;
 using ShapeInferenceHelper_Attention = AttentionHelper;
 using ShapeInferenceHelper_MultiHeadAttention = MultiHeadAttentionHelper;
 using ShapeInferenceHelper_Sign = GetBroadcastedOutputShapeHelper;
```
```diff
@@ -434,6 +434,7 @@ namespace OperatorHelper
         static const int sc_sinceVer_FusedMatMul = 1;
         static const int sc_sinceVer_FusedMatMulActivation = 1;
         static const int sc_sinceVer_QLinearSigmoid = 1;
+        static const int sc_sinceVer_QAttention = 1;
         static const int sc_sinceVer_Attention = 1;
         static const int sc_sinceVer_MatMulIntegerToFloat = 1;
         static const int sc_sinceVer_MultiHeadAttention = 1;
```
**onnxruntime/test/contrib_ops/quantize_attention_op_test.cc** (56 changes: 54 additions & 2 deletions)
```diff
@@ -20,7 +20,8 @@ namespace test {
 enum class EP : char {
   CPU,
   CUDA,
-  DNNL
+  DNNL,
+  DML
 };

 // input: [batch_size, sequence_length, hidden_size]
```
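The new `DML` enumerator feeds the compile-time provider dispatch in `RunQAttention`, shown in the next hunk; each `RunQAttention<..., ep>` instantiation keeps only its own provider's branch. A self-contained sketch of the `if constexpr` idiom, with names of my own choosing:

```cpp
#include <cstdio>

enum class EP : char { CPU, CUDA, DNNL, DML };

template <EP ep>
const char* ProviderName() {
    // Branches whose condition is false for this `ep` are discarded at compile time.
    if constexpr (ep == EP::CUDA) { return "CUDAExecutionProvider"; }
    else if constexpr (ep == EP::CPU) { return "CPUExecutionProvider"; }
    else if constexpr (ep == EP::DML) { return "DmlExecutionProvider"; }
    else { return "DnnlExecutionProvider"; }
}

int main() {
    std::printf("%s\n", ProviderName<EP::DML>());  // prints DmlExecutionProvider
}
```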
```diff
@@ -111,7 +112,9 @@ void RunQAttention(const std::vector<float>& input_data,
     execution_providers.push_back(DefaultCudaExecutionProvider());
   } else if constexpr (ep == EP::CPU) {
     execution_providers.push_back(DefaultCpuExecutionProvider());
-  } else {  // onednn ep
+  } else if constexpr (ep == EP::DML) {
+    execution_providers.push_back(DefaultDmlExecutionProvider());
+  } else {  // onednn ep
     execution_providers.push_back(DefaultDnnlExecutionProvider());
   }

```
```diff
@@ -192,6 +195,52 @@ static void RunQAttentionDNNL(
 #endif
 }

+static void RunQAttentionDML(
+    const std::vector<float>& input_data,
+    const std::vector<float>& weights_data,
+    const std::vector<float>& bias_data,
+    const std::vector<int32_t>& mask_index_data,
+    const std::vector<float>& output_data,
+    int batch_size,
+    int sequence_length,
+    int hidden_size,
+    int number_of_heads,
+    bool use_special_quantize_parameter = true,
+    bool is_unidirectional = false,
+    int input_hidden_size = 0) {
+  // Return without running code if USE_DML is not defined
+#ifdef USE_DML
+  bool enable_dml = (nullptr != DefaultDmlExecutionProvider().get());
+  if (enable_dml) {
+    quantization::Params<uint8_t> input_quant_params(/*scale=*/0.0f, /*zero_point=*/0);
+    quantization::Params<int8_t> weights_quant_params(/*scale=*/0.0f, /*zero_point=*/0);
+    if (use_special_quantize_parameter) {
+      input_quant_params.scale = 0.1f;
+      weights_quant_params.scale = 0.1f;
+      input_quant_params.zero_point = 128;
+      weights_quant_params.zero_point = 1;
+    }
+
+    RunQAttention<uint8_t, int8_t, EP::DML>(
+        input_data, weights_data, bias_data, mask_index_data, output_data, input_quant_params, weights_quant_params,
+        batch_size, sequence_length, hidden_size, number_of_heads, is_unidirectional, false, input_hidden_size);
+  }
+#else
+  ORT_UNUSED_PARAMETER(input_data);
+  ORT_UNUSED_PARAMETER(weights_data);
+  ORT_UNUSED_PARAMETER(bias_data);
+  ORT_UNUSED_PARAMETER(mask_index_data);
+  ORT_UNUSED_PARAMETER(output_data);
+  ORT_UNUSED_PARAMETER(batch_size);
+  ORT_UNUSED_PARAMETER(sequence_length);
+  ORT_UNUSED_PARAMETER(hidden_size);
+  ORT_UNUSED_PARAMETER(number_of_heads);
+  ORT_UNUSED_PARAMETER(use_special_quantize_parameter);
+  ORT_UNUSED_PARAMETER(is_unidirectional);
+  ORT_UNUSED_PARAMETER(input_hidden_size);
+#endif
+}
+
 static void RunQAttentionU8U8(
     const std::vector<float>& input_data,
     const std::vector<float>& weights_data,
```
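The "special" parameters above (`scale = 0.1`, `zero_point = 128` for the uint8 input) follow the standard asymmetric quantization rule `q = round(x / scale) + zero_point`. A quick standalone check of what those parameters do; the example value is my own:

```cpp
#include <cmath>
#include <cstdio>

int main() {
    const float scale = 0.1f;
    const int zeroPoint = 128;  // uint8 input zero point used by the helper above

    const float x = 1.25f;
    const int q = static_cast<int>(std::lround(x / scale)) + zeroPoint;  // 13 + 128 = 141
    const float roundTrip = (q - zeroPoint) * scale;                     // 1.3, rounding error 0.05

    std::printf("q = %d, dequantized = %.2f\n", q, roundTrip);
}
```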
```diff
@@ -272,6 +321,9 @@ static void RunQAttentionAll(
   RunQAttentionDNNL(input_data, weight_data, bias_data, mask_index_data, output_data,
                     batch_size, sequence_length, hidden_size, number_of_heads,
                     use_special_quantize_parameter, is_unidirectional, input_hidden_size);
+  RunQAttentionDML(input_data, weight_data, bias_data, mask_index_data, output_data,
+                   batch_size, sequence_length, hidden_size, number_of_heads,
+                   use_special_quantize_parameter, is_unidirectional, input_hidden_size);
 }

 // ONEDNN EP only supports 2D raw mask
```