[DML] QAttention (#19766)
### Description
DML implementation of the
[com.microsoft.QAttention](https://github.com/microsoft/onnxruntime/blob/main/docs/ContribOperators.md#com.microsoft.QAttention)
contrib operator.
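
For readers unfamiliar with the op, QAttention is the quantized variant of the fused multi-head self-attention contrib op: the int8/uint8 `input` and the packed Q/K/V `weight` are dequantized with their scale/zero-point pairs, projected with the float `bias`, and run through per-head scaled dot-product attention. A minimal, unoptimized reference sketch of that math (my own simplification, not the DML kernel: single batch, per-tensor weight scale, no `mask_index`, no `past`/`present` cache):

```cpp
// Reference math only; illustrative names, not onnxruntime APIs.
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <limits>
#include <vector>

// input:  [seq_len, input_hidden]    quantized activations (uint8)
// weight: [input_hidden, 3*hidden]   quantized Q/K/V projections, packed along dim 1 (int8)
// bias:   [3*hidden]                 float
// Returns the attention output, [seq_len, hidden].
std::vector<float> QAttentionReference(
    const std::vector<uint8_t>& input, const std::vector<int8_t>& weight, const std::vector<float>& bias,
    float inputScale, uint8_t inputZeroPoint, float weightScale, int8_t weightZeroPoint,
    int seqLen, int inputHidden, int hidden, int numHeads)
{
    const int headSize = hidden / numHeads;

    // 1) Dequantize and project: qkv = dequant(input) * dequant(weight) + bias  ->  [seq_len, 3*hidden].
    std::vector<float> qkv(static_cast<size_t>(seqLen) * 3 * hidden);
    for (int s = 0; s < seqLen; ++s)
    {
        for (int o = 0; o < 3 * hidden; ++o)
        {
            float acc = bias[o];
            for (int k = 0; k < inputHidden; ++k)
            {
                const float x = (static_cast<int32_t>(input[s * inputHidden + k]) - inputZeroPoint) * inputScale;
                const float w = (static_cast<int32_t>(weight[k * 3 * hidden + o]) - weightZeroPoint) * weightScale;
                acc += x * w;
            }
            qkv[s * 3 * hidden + o] = acc;
        }
    }

    // 2) Per head: output = softmax(Q * K^T / sqrt(head_size)) * V, with Q/K/V at column offsets 0, hidden, 2*hidden.
    std::vector<float> output(static_cast<size_t>(seqLen) * hidden, 0.0f);
    for (int h = 0; h < numHeads; ++h)
    {
        for (int i = 0; i < seqLen; ++i)
        {
            std::vector<float> scores(seqLen);
            float maxScore = -std::numeric_limits<float>::infinity();
            for (int j = 0; j < seqLen; ++j)
            {
                float dot = 0.0f;
                for (int d = 0; d < headSize; ++d)
                {
                    dot += qkv[i * 3 * hidden + h * headSize + d] *          // Q[i]
                           qkv[j * 3 * hidden + hidden + h * headSize + d];  // K[j]
                }
                scores[j] = dot / std::sqrt(static_cast<float>(headSize));
                maxScore = std::max(maxScore, scores[j]);
            }
            float sum = 0.0f;
            for (float& sc : scores) { sc = std::exp(sc - maxScore); sum += sc; }
            for (int j = 0; j < seqLen; ++j)
            {
                for (int d = 0; d < headSize; ++d)
                {
                    output[i * hidden + h * headSize + d] +=
                        (scores[j] / sum) * qkv[j * 3 * hidden + 2 * hidden + h * headSize + d];  // V[j]
                }
            }
        }
    }
    return output;
}
```

The changes below register the op for the DML execution provider, add shape inference for the optional `present` output, and hook the DML path into the existing QAttention unit tests.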



### Motivation and Context
Adds quantized attention (QAttention) support to the DirectML execution provider so that quantized transformer models using this contrib op can run on DML.

---------

Co-authored-by: Xiang Zhang <[email protected]>
raoanag and zhangxiang1993 authored Mar 11, 2024
1 parent 5479124 commit 89aa469
Showing 8 changed files with 839 additions and 11 deletions.
1 change: 1 addition & 0 deletions docs/OperatorKernels.md
@@ -1277,6 +1277,7 @@ Do not modify directly.*
|MatMulIntegerToFloat|*in* A:**T1**<br> *in* B:**T2**<br> *in* a_scale:**T3**<br> *in* b_scale:**T3**<br> *in* a_zero_point:**T1**<br> *in* b_zero_point:**T2**<br> *in* bias:**T3**<br> *out* Y:**T3**|1+|**T1** = tensor(int8), tensor(uint8)<br/> **T2** = tensor(int8), tensor(uint8)<br/> **T3** = tensor(float), tensor(float16)|
|MultiHeadAttention|*in* query:**T**<br> *in* key:**T**<br> *in* value:**T**<br> *in* bias:**T**<br> *in* key_padding_mask:**M**<br> *in* relative_position_bias:**T**<br> *in* past_key:**T**<br> *in* past_value:**T**<br> *out* output:**T**<br> *out* present_key:**T**<br> *out* present_value:**T**|1+|**M** = tensor(int32)<br/> **T** = tensor(float), tensor(float16)|
|NhwcConv|*in* X:**T**<br> *in* W:**T**<br> *in* B:**T**<br> *out* Y:**T**|1+|**T** = tensor(float), tensor(float16)|
|QAttention|*in* input:**T1**<br> *in* weight:**T2**<br> *in* bias:**T3**<br> *in* input_scale:**T3**<br> *in* weight_scale:**T3**<br> *in* mask_index:**T4**<br> *in* input_zero_point:**T1**<br> *in* weight_zero_point:**T2**<br> *in* past:**T3**<br> *out* output:**T3**<br> *out* present:**T3**|1+|**T1** = tensor(int8), tensor(uint8)<br/> **T2** = tensor(int8), tensor(uint8)<br/> **T3** = tensor(float), tensor(float16)<br/> **T4** = tensor(int32)|
|QLinearAdd|*in* A:**T**<br> *in* A_scale:**tensor(float)**<br> *in* A_zero_point:**T**<br> *in* B:**T**<br> *in* B_scale:**tensor(float)**<br> *in* B_zero_point:**T**<br> *in* C_scale:**tensor(float)**<br> *in* C_zero_point:**T**<br> *out* C:**T**|1+|**T** = tensor(int8), tensor(uint8)|
|QLinearAveragePool|*in* X:**T**<br> *in* x_scale:**tensor(float)**<br> *in* x_zero_point:**T**<br> *in* y_scale:**tensor(float)**<br> *in* y_zero_point:**T**<br> *out* Y:**T**|1+|**T** = tensor(int8), tensor(uint8)|
|QLinearConcat|*in* Y_scale:**TF**<br> *in* Y_zero_point:**T8**<br> *in* inputs:**TV**<br> *out* Y:**T8**|1+|**T8** = tensor(int8), tensor(uint8)<br/> **TF** = tensor(float)<br/> **TV** = tensor(float), tensor(int8), tensor(uint8)|

Large diffs are not rendered by default.

@@ -178,4 +178,4 @@ void CALLBACK QueryQLinearSigmoid(IMLOperatorSupportQueryContextPrivate* context
}

DML_OP_DEFINE_CREATION_FUNCTION(QLinearSigmoid, DmlOperatorQLinearSigmoid);
} // namespace Dml
} // namespace Dml
@@ -516,6 +516,7 @@ DML_OP_EXTERN_CREATION_FUNCTION(Resize19);

DML_OP_EXTERN_CREATION_FUNCTION(Shape);
DML_OP_EXTERN_CREATION_FUNCTION(Size);
DML_OP_EXTERN_CREATION_FUNCTION(QAttention);
DML_OP_EXTERN_CREATION_FUNCTION(Attention);
DML_OP_EXTERN_CREATION_FUNCTION(MultiHeadAttention);
DML_OP_EXTERN_CREATION_FUNCTION(NonZero);
@@ -537,6 +538,7 @@ DML_OP_EXTERN_QUERY_FUNCTION(Pad);
DML_OP_EXTERN_QUERY_FUNCTION(LayerNormalization);
DML_OP_EXTERN_QUERY_FUNCTION(SkipLayerNormalization);
DML_OP_EXTERN_QUERY_FUNCTION(QLinearSigmoid);
DML_OP_EXTERN_QUERY_FUNCTION(QAttention);
DML_OP_EXTERN_QUERY_FUNCTION(Attention);

constexpr static std::array<const char*, 1> typeNameListDefault = {"T"};
@@ -614,15 +616,23 @@ constexpr static std::array<SupportedTensorDataTypes, 2> supportedTypeListLayerN
constexpr static std::array<SupportedTensorDataTypes, 2> supportedTypeListShape = {SupportedTensorDataTypes::All, SupportedTensorDataTypes::Int64};
constexpr static std::array<SupportedTensorDataTypes, 2> supportedTypeListSize = {SupportedTensorDataTypes::All, SupportedTensorDataTypes::Int64};
constexpr static std::array<SupportedTensorDataTypes, 1> supportedTypeListQLinearSigmoid = {SupportedTensorDataTypes::UInt8 | SupportedTensorDataTypes::Int8};

constexpr static std::array<SupportedTensorDataTypes, 4> supportedTypeListQAttention = {
SupportedTensorDataTypes::Ints8Bit,
SupportedTensorDataTypes::Ints8Bit,
SupportedTensorDataTypes::Float16to32,
SupportedTensorDataTypes::Int32
};
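// The four entries above presumably bind, in order, T1 (quantized input / input_zero_point),
// T2 (quantized weight / weight_zero_point), T3 (float scales, bias, past and outputs), and
// T4 (the int32 mask_index), matching the QAttention row in docs/OperatorKernels.md.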

constexpr static std::array<SupportedTensorDataTypes, 2> supportedTypeListAttention = {SupportedTensorDataTypes::Float16to32, SupportedTensorDataTypes::Int32};
constexpr static std::array<SupportedTensorDataTypes, 2> supportedTypeListRotaryEmbedding = {SupportedTensorDataTypes::Float16to32, SupportedTensorDataTypes::Int64};
constexpr static std::array<SupportedTensorDataTypes, 2> supportedTypeListGroupNorm = {SupportedTensorDataTypes::Float16to32, SupportedTensorDataTypes::Float16to32};
constexpr static std::array<SupportedTensorDataTypes, 1> supportedTypeListNonZero = {SupportedTensorDataTypes::Float16to32 | SupportedTensorDataTypes::Ints8Bit | SupportedTensorDataTypes::Ints16Bit | SupportedTensorDataTypes::Ints32Bit | SupportedTensorDataTypes::Bool};

constexpr static std::array<SupportedTensorDataTypes, 3> supportedTypeListQLinearMatMul = {
SupportedTensorDataTypes::Int8|SupportedTensorDataTypes::UInt8,
SupportedTensorDataTypes::Int8|SupportedTensorDataTypes::UInt8,
SupportedTensorDataTypes::Int8|SupportedTensorDataTypes::UInt8
SupportedTensorDataTypes::Ints8Bit,
SupportedTensorDataTypes::Ints8Bit,
SupportedTensorDataTypes::Ints8Bit
};

constexpr static std::array<SupportedTensorDataTypes, 3> supportedTypeListMatMulIntegerToFloat = {
@@ -632,9 +642,9 @@ constexpr static std::array<SupportedTensorDataTypes, 3> supportedTypeListMatMul
};

constexpr static std::array<SupportedTensorDataTypes, 4> supportedTypeListQLinearConv = {
SupportedTensorDataTypes::Int8|SupportedTensorDataTypes::UInt8,
SupportedTensorDataTypes::Int8|SupportedTensorDataTypes::UInt8,
SupportedTensorDataTypes::Int8|SupportedTensorDataTypes::UInt8,
SupportedTensorDataTypes::Ints8Bit,
SupportedTensorDataTypes::Ints8Bit,
SupportedTensorDataTypes::Ints8Bit,
SupportedTensorDataTypes::Int32
};

@@ -1069,6 +1079,7 @@ constexpr static OperatorRegistrationInformation operatorRegistrationInformation
{REG_INFO_MS( 1, DynamicQuantizeMatMul, typeNameListTwo, supportedTypeListDynamicQuantizeLinear, DmlGraphSupport::Supported)},
{REG_INFO_MS( 1, FusedMatMulActivation, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO_MS( 1, QLinearSigmoid, typeNameListDefault, supportedTypeListQLinearSigmoid, DmlGraphSupport::Supported, requiredConstantCpuInputs(), std::nullopt, QueryQLinearSigmoid)},
{REG_INFO_MS( 1, QAttention, typeNameListFour, supportedTypeListQAttention, DmlGraphSupport::Supported, requiredConstantCpuInputs(), std::nullopt, QueryQAttention)},
{REG_INFO_MS( 1, Attention, typeNameListAttention, supportedTypeListAttention, DmlGraphSupport::Supported, requiredConstantCpuInputs(), std::nullopt, QueryAttention)},
{REG_INFO_MS( 1, MultiHeadAttention, typeNameListAttention, supportedTypeListAttention, DmlGraphSupport::Supported)},
{REG_INFO_MS( 1, RotaryEmbedding, typeNameListRotaryEmbedding, supportedTypeListRotaryEmbedding, DmlGraphSupport::Supported)},
@@ -2802,6 +2802,48 @@ namespace OperatorHelper
m_qkvHiddenSizes = kernelInformation.GetAttributes().GetOptionalAttributeVectorInt32(AttrName::QkvHiddenSizes);
}

std::vector<EdgeShapes> QAttentionHelper::GetOutputShapes(const MLShapeInferenceContext& shapeInfo) const
{
ML_CHECK_VALID_ARGUMENT(shapeInfo.GetInputCount() >= 5);

auto queryShape = shapeInfo.GetInputTensorShape(0);
ML_CHECK_VALID_ARGUMENT(queryShape.size() == 3);

auto weightShape = shapeInfo.GetInputTensorShape(1);
ML_CHECK_VALID_ARGUMENT(weightShape.size() == 2);
ML_CHECK_VALID_ARGUMENT(weightShape[1] % 3 == 0);

const uint32_t batchSize = queryShape[0];
const uint32_t sequenceLength = queryShape[1];
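// The weight packs the Q, K and V projections along its second dimension, so hidden_size is one third of it.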
const uint32_t hiddenSize = weightShape[1] / 3;
const uint32_t headSize = hiddenSize / m_numHeads;

std::vector<EdgeShapes> outputShapes(2);

outputShapes[0] = EdgeShapes({batchSize, sequenceLength, hiddenSize});

uint32_t totalSequenceLength = sequenceLength;
if (shapeInfo.IsInputValid(8))
{
ML_CHECK_VALID_ARGUMENT(shapeInfo.GetInputTensorDimensionCount(8) == 5);
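// past is laid out as (2, batch_size, num_heads, past_sequence_length, head_size), so dimension 3 holds the past length.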
const uint32_t pastSequenceLength = shapeInfo.GetInputTensorShape(8)[3];
totalSequenceLength += pastSequenceLength;
}

if (shapeInfo.IsOutputValid(1))
{
ML_CHECK_VALID_ARGUMENT(shapeInfo.IsInputValid(8));
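// present stacks the key and value caches: (2, batch_size, num_heads, total_sequence_length, head_size).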
outputShapes[1] = EdgeShapes({2, batchSize, m_numHeads, totalSequenceLength, headSize});
}

return outputShapes;
}

void QAttentionHelper::Initialize(const IKernelInformationAdapter& kernelInformation)
{
m_numHeads = gsl::narrow_cast<uint32_t>(kernelInformation.GetAttributes().GetAttribute<int64_t>(AttrName::NumHeads));
}

std::vector<EdgeShapes> SkipLayerNormHelper::GetOutputShapes(const MLShapeInferenceContext& shapeInfo) const
{
ML_CHECK_VALID_ARGUMENT(shapeInfo.GetInputCount() >= 3);
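// Illustrative only, not part of this change: a standalone check of the shape rules in
// QAttentionHelper::GetOutputShapes above, using arbitrary example dimensions.
#include <cassert>
#include <cstdint>

int main()
{
    const uint32_t batchSize = 1;            // input shape (1, 2, 32)
    const uint32_t sequenceLength = 2;
    const uint32_t weightDim1 = 96;          // weight shape (32, 96): packs Q, K and V
    const uint32_t numHeads = 2;
    const uint32_t pastSequenceLength = 3;   // past shape (2, 1, 2, 3, 16)

    const uint32_t hiddenSize = weightDim1 / 3;                                 // 32
    const uint32_t headSize = hiddenSize / numHeads;                            // 16
    const uint32_t totalSequenceLength = sequenceLength + pastSequenceLength;   // 5

    // output:  (batch, sequence_length, hidden_size)                    = (1, 2, 32)
    // present: (2, batch, num_heads, total_sequence_length, head_size)  = (2, 1, 2, 5, 16)
    assert(batchSize == 1 && hiddenSize == 32 && headSize == 16 && totalSequenceLength == 5);
    return 0;
}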
@@ -1554,6 +1554,22 @@ class AttentionHelper
std::vector<int32_t> m_qkvHiddenSizes;
};

class QAttentionHelper
{
public:
template <typename Info_t, typename Shape_t>
QAttentionHelper(const Info_t& info, const Shape_t& shapeInfo)
{
Initialize(KernelInformationAdapter(info));
}

std::vector<EdgeShapes> GetOutputShapes(const MLShapeInferenceContext& shapeInfo) const;

private:
void Initialize(const IKernelInformationAdapter& kernelInformation);
uint32_t m_numHeads;
};

class SkipLayerNormHelper
{
public:
@@ -1699,6 +1715,7 @@ using ShapeInferenceHelper_Affine = GetOutputShapeAsInputShapeHelper;
using ShapeInferenceHelper_QuantizeLinear = GetOutputShapeAsInputShapeHelper;
using ShapeInferenceHelper_DequantizeLinear = GetOutputShapeAsInputShapeHelper;
using ShapeInferenceHelper_QLinearSigmoid = GetOutputShapeAsInputShapeHelper;
using ShapeInferenceHelper_QAttention = QAttentionHelper;
using ShapeInferenceHelper_Attention = AttentionHelper;
using ShapeInferenceHelper_MultiHeadAttention = MultiHeadAttentionHelper;
using ShapeInferenceHelper_RotaryEmbedding = GetOutputShapeAsInputShapeHelper;
@@ -448,6 +448,7 @@ namespace OperatorHelper
static const int sc_sinceVer_FusedMatMul = 1;
static const int sc_sinceVer_FusedMatMulActivation = 1;
static const int sc_sinceVer_QLinearSigmoid = 1;
static const int sc_sinceVer_QAttention = 1;
static const int sc_sinceVer_Attention = 1;
static const int sc_sinceVer_MatMulIntegerToFloat = 1;
static const int sc_sinceVer_MultiHeadAttention = 1;
60 changes: 56 additions & 4 deletions onnxruntime/test/contrib_ops/quantize_attention_op_test.cc
@@ -20,7 +20,8 @@ namespace test {
enum class EP : char {
CPU,
CUDA,
DNNL
DNNL,
DML
};

// input: [batch_size, sequence_length, hidden_size]
@@ -111,7 +112,9 @@ void RunQAttention(const std::vector<float>& input_data,
execution_providers.push_back(DefaultCudaExecutionProvider());
} else if constexpr (ep == EP::CPU) {
execution_providers.push_back(DefaultCpuExecutionProvider());
} else { // onednn ep
} else if constexpr (ep == EP::DML) {
execution_providers.push_back(DefaultDmlExecutionProvider());
} else { // onednn ep
execution_providers.push_back(DefaultDnnlExecutionProvider());
}

@@ -192,6 +195,52 @@ static void RunQAttentionDNNL(
#endif
}

static void RunQAttentionDML(
const std::vector<float>& input_data,
const std::vector<float>& weights_data,
const std::vector<float>& bias_data,
const std::vector<int32_t>& mask_index_data,
const std::vector<float>& output_data,
int batch_size,
int sequence_length,
int hidden_size,
int number_of_heads,
bool use_special_quantize_parameter = true,
bool is_unidirectional = false,
int input_hidden_size = 0) {
// Return without running code if USE_DML is not defined
#ifdef USE_DML
bool enable_dml = (nullptr != DefaultDmlExecutionProvider().get());
if (enable_dml) {
quantization::Params<uint8_t> input_quant_params(/*scale=*/0.0f, /*zero_point=*/0);
quantization::Params<int8_t> weights_quant_params(/*scale=*/0.0f, /*zero_point=*/0);
if (use_special_quantize_parameter) {
input_quant_params.scale = 0.1f;
weights_quant_params.scale = 0.1f;
input_quant_params.zero_point = 128;
weights_quant_params.zero_point = 1;
}

RunQAttention<uint8_t, int8_t, EP::DML>(
input_data, weights_data, bias_data, mask_index_data, output_data, input_quant_params, weights_quant_params,
batch_size, sequence_length, hidden_size, number_of_heads, is_unidirectional, false, input_hidden_size);
}
#else
ORT_UNUSED_PARAMETER(input_data);
ORT_UNUSED_PARAMETER(weights_data);
ORT_UNUSED_PARAMETER(bias_data);
ORT_UNUSED_PARAMETER(mask_index_data);
ORT_UNUSED_PARAMETER(output_data);
ORT_UNUSED_PARAMETER(batch_size);
ORT_UNUSED_PARAMETER(sequence_length);
ORT_UNUSED_PARAMETER(hidden_size);
ORT_UNUSED_PARAMETER(number_of_heads);
ORT_UNUSED_PARAMETER(use_special_quantize_parameter);
ORT_UNUSED_PARAMETER(is_unidirectional);
ORT_UNUSED_PARAMETER(input_hidden_size);
#endif
}

static void RunQAttentionU8U8(
const std::vector<float>& input_data,
const std::vector<float>& weights_data,
@@ -272,6 +321,9 @@ static void RunQAttentionAll(
RunQAttentionDNNL(input_data, weight_data, bias_data, mask_index_data, output_data,
batch_size, sequence_length, hidden_size, number_of_heads,
use_special_quantize_parameter, is_unidirectional, input_hidden_size);
RunQAttentionDML(input_data, weight_data, bias_data, mask_index_data, output_data,
batch_size, sequence_length, hidden_size, number_of_heads,
use_special_quantize_parameter, is_unidirectional, input_hidden_size);
}

// ONEDNN EP only supports 2D raw mask
@@ -859,8 +911,8 @@ void TestQuantizedAttentionPastState(int64_t batch,
std::vector<int64_t> input_dims{batch, seq_len, hidden_size};
std::vector<InputT> input_data = random.Gaussian<InputT>(input_dims, input_mean, static_cast<InputT>(input_range / 6), input_min, input_max);

constexpr WeightT weight_min = std::numeric_limits<WeightT>::min();
constexpr WeightT weight_max = std::numeric_limits<WeightT>::max();
constexpr WeightT weight_min = std::is_same_v<WeightT, int8_t> ? std::numeric_limits<int8_t>::min() / 2 : std::numeric_limits<WeightT>::min();
constexpr WeightT weight_max = std::numeric_limits<WeightT>::max() / 2;
constexpr int32_t weight_range = weight_max - weight_min;

std::vector<WeightT> weight_zero_point(weight_scale_zp_size);
