[DML] QAttention #19766

Merged · 7 commits · Mar 11, 2024
1 change: 1 addition & 0 deletions docs/OperatorKernels.md
@@ -1277,6 +1277,7 @@ Do not modify directly.*
|MatMulIntegerToFloat|*in* A:**T1**<br> *in* B:**T2**<br> *in* a_scale:**T3**<br> *in* b_scale:**T3**<br> *in* a_zero_point:**T1**<br> *in* b_zero_point:**T2**<br> *in* bias:**T3**<br> *out* Y:**T3**|1+|**T1** = tensor(int8), tensor(uint8)<br/> **T2** = tensor(int8), tensor(uint8)<br/> **T3** = tensor(float), tensor(float16)|
|MultiHeadAttention|*in* query:**T**<br> *in* key:**T**<br> *in* value:**T**<br> *in* bias:**T**<br> *in* key_padding_mask:**M**<br> *in* relative_position_bias:**T**<br> *in* past_key:**T**<br> *in* past_value:**T**<br> *out* output:**T**<br> *out* present_key:**T**<br> *out* present_value:**T**|1+|**M** = tensor(int32)<br/> **T** = tensor(float), tensor(float16)|
|NhwcConv|*in* X:**T**<br> *in* W:**T**<br> *in* B:**T**<br> *out* Y:**T**|1+|**T** = tensor(float), tensor(float16)|
|QAttention|*in* input:**T1**<br> *in* weight:**T2**<br> *in* bias:**T3**<br> *in* input_scale:**T3**<br> *in* weight_scale:**T3**<br> *in* mask_index:**T4**<br> *in* input_zero_point:**T1**<br> *in* weight_zero_point:**T2**<br> *in* past:**T3**<br> *out* output:**T3**<br> *out* present:**T3**|1+|**T1** = tensor(int8), tensor(uint8)<br/> **T2** = tensor(int8), tensor(uint8)<br/> **T3** = tensor(float), tensor(float16)<br/> **T4** = tensor(int32)|
|QLinearAdd|*in* A:**T**<br> *in* A_scale:**tensor(float)**<br> *in* A_zero_point:**T**<br> *in* B:**T**<br> *in* B_scale:**tensor(float)**<br> *in* B_zero_point:**T**<br> *in* C_scale:**tensor(float)**<br> *in* C_zero_point:**T**<br> *out* C:**T**|1+|**T** = tensor(int8), tensor(uint8)|
|QLinearAveragePool|*in* X:**T**<br> *in* x_scale:**tensor(float)**<br> *in* x_zero_point:**T**<br> *in* y_scale:**tensor(float)**<br> *in* y_zero_point:**T**<br> *out* Y:**T**|1+|**T** = tensor(int8), tensor(uint8)|
|QLinearConcat|*in* Y_scale:**TF**<br> *in* Y_zero_point:**T8**<br> *in* inputs:**TV**<br> *out* Y:**T8**|1+|**T8** = tensor(int8), tensor(uint8)<br/> **TF** = tensor(float)<br/> **TV** = tensor(float), tensor(int8), tensor(uint8)|
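The row above summarizes the newly registered QAttention signature: quantized activations (T1) and weights (T2), float scales, bias, past state and outputs (T3), and an int32 mask_index (T4). As a conceptual reference only (not the DML kernel added by this PR), the sketch below restates the quantized packed-QKV projection QAttention performs before the attention computation itself; per-column weight scales, masking, past state, and the unidirectional flag are omitted, and all names are illustrative:

```cpp
// Reference sketch of QAttention's dequantize + packed QKV projection:
//   gemm = dequant(input) * dequant(weight) + bias, shape [batch, seq, 3 * hidden].
// The per-head scaled dot-product attention that follows is not shown.
#include <cstdint>
#include <vector>

// real_value = (quantized_value - zero_point) * scale
inline float Dequantize(int32_t q, int32_t zero_point, float scale) {
  return static_cast<float>(q - zero_point) * scale;
}

std::vector<float> QkvProjection(
    const std::vector<uint8_t>& input, float input_scale, uint8_t input_zero_point,   // [batch * seq, input_hidden]
    const std::vector<int8_t>& weight, float weight_scale, int8_t weight_zero_point,  // [input_hidden, 3 * hidden]
    const std::vector<float>& bias,                                                   // [3 * hidden]
    int64_t batch, int64_t seq_len, int64_t input_hidden, int64_t hidden) {
  std::vector<float> out(batch * seq_len * 3 * hidden, 0.0f);
  for (int64_t bs = 0; bs < batch * seq_len; ++bs) {
    for (int64_t j = 0; j < 3 * hidden; ++j) {
      float acc = bias[j];
      for (int64_t k = 0; k < input_hidden; ++k) {
        acc += Dequantize(input[bs * input_hidden + k], input_zero_point, input_scale) *
               Dequantize(weight[k * 3 * hidden + j], weight_zero_point, weight_scale);
      }
      out[bs * 3 * hidden + j] = acc;
    }
  }
  return out;
}
```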

Large diffs are not rendered by default.

@@ -178,4 +178,4 @@
}

DML_OP_DEFINE_CREATION_FUNCTION(QLinearSigmoid, DmlOperatorQLinearSigmoid);
} // namespace Dml
} // namespace Dml

Check warning (GitHub Actions / Lint C++) on line 181 of onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorQLinearSigmoid.cpp: [cpplint] At least two spaces is best between code and comments [whitespace/comments] [2]

Check warning (GitHub Actions / Lint C++) on line 181 of onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorQLinearSigmoid.cpp: [cpplint] Could not find a newline character at the end of the file. [whitespace/ending_newline] [5]
@@ -516,6 +516,7 @@ DML_OP_EXTERN_CREATION_FUNCTION(Resize19);

DML_OP_EXTERN_CREATION_FUNCTION(Shape);
DML_OP_EXTERN_CREATION_FUNCTION(Size);
DML_OP_EXTERN_CREATION_FUNCTION(QAttention);
DML_OP_EXTERN_CREATION_FUNCTION(Attention);
DML_OP_EXTERN_CREATION_FUNCTION(MultiHeadAttention);
DML_OP_EXTERN_CREATION_FUNCTION(NonZero);
@@ -537,6 +538,7 @@ DML_OP_EXTERN_QUERY_FUNCTION(Pad);
DML_OP_EXTERN_QUERY_FUNCTION(LayerNormalization);
DML_OP_EXTERN_QUERY_FUNCTION(SkipLayerNormalization);
DML_OP_EXTERN_QUERY_FUNCTION(QLinearSigmoid);
DML_OP_EXTERN_QUERY_FUNCTION(QAttention);
DML_OP_EXTERN_QUERY_FUNCTION(Attention);

constexpr static std::array<const char*, 1> typeNameListDefault = {"T"};
@@ -614,15 +616,23 @@ constexpr static std::array<SupportedTensorDataTypes, 2> supportedTypeListLayerN
constexpr static std::array<SupportedTensorDataTypes, 2> supportedTypeListShape = {SupportedTensorDataTypes::All, SupportedTensorDataTypes::Int64};
constexpr static std::array<SupportedTensorDataTypes, 2> supportedTypeListSize = {SupportedTensorDataTypes::All, SupportedTensorDataTypes::Int64};
constexpr static std::array<SupportedTensorDataTypes, 1> supportedTypeListQLinearSigmoid = {SupportedTensorDataTypes::UInt8 | SupportedTensorDataTypes::Int8};

constexpr static std::array<SupportedTensorDataTypes, 4> supportedTypeListQAttention = {
SupportedTensorDataTypes::Ints8Bit,
SupportedTensorDataTypes::Ints8Bit,
SupportedTensorDataTypes::Float16to32,
SupportedTensorDataTypes::Int32
};

constexpr static std::array<SupportedTensorDataTypes, 2> supportedTypeListAttention = {SupportedTensorDataTypes::Float16to32, SupportedTensorDataTypes::Int32};
constexpr static std::array<SupportedTensorDataTypes, 2> supportedTypeListRotaryEmbedding = {SupportedTensorDataTypes::Float16to32, SupportedTensorDataTypes::Int64};
constexpr static std::array<SupportedTensorDataTypes, 2> supportedTypeListGroupNorm = {SupportedTensorDataTypes::Float16to32, SupportedTensorDataTypes::Float16to32};
constexpr static std::array<SupportedTensorDataTypes, 1> supportedTypeListNonZero = {SupportedTensorDataTypes::Float16to32 | SupportedTensorDataTypes::Ints8Bit | SupportedTensorDataTypes::Ints16Bit | SupportedTensorDataTypes::Ints32Bit | SupportedTensorDataTypes::Bool};

constexpr static std::array<SupportedTensorDataTypes, 3> supportedTypeListQLinearMatMul = {
-    SupportedTensorDataTypes::Int8|SupportedTensorDataTypes::UInt8,
-    SupportedTensorDataTypes::Int8|SupportedTensorDataTypes::UInt8,
-    SupportedTensorDataTypes::Int8|SupportedTensorDataTypes::UInt8
+    SupportedTensorDataTypes::Ints8Bit,
+    SupportedTensorDataTypes::Ints8Bit,
+    SupportedTensorDataTypes::Ints8Bit
};

constexpr static std::array<SupportedTensorDataTypes, 3> supportedTypeListMatMulIntegerToFloat = {
@@ -632,9 +642,9 @@ constexpr static std::array<SupportedTensorDataTypes, 3> supportedTypeListMatMul
};

constexpr static std::array<SupportedTensorDataTypes, 4> supportedTypeListQLinearConv = {
-    SupportedTensorDataTypes::Int8|SupportedTensorDataTypes::UInt8,
-    SupportedTensorDataTypes::Int8|SupportedTensorDataTypes::UInt8,
-    SupportedTensorDataTypes::Int8|SupportedTensorDataTypes::UInt8,
+    SupportedTensorDataTypes::Ints8Bit,
+    SupportedTensorDataTypes::Ints8Bit,
+    SupportedTensorDataTypes::Ints8Bit,
    SupportedTensorDataTypes::Int32
};

@@ -1069,6 +1079,7 @@ constexpr static OperatorRegistrationInformation operatorRegistrationInformation
{REG_INFO_MS( 1, DynamicQuantizeMatMul, typeNameListTwo, supportedTypeListDynamicQuantizeLinear, DmlGraphSupport::Supported)},
{REG_INFO_MS( 1, FusedMatMulActivation, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO_MS( 1, QLinearSigmoid, typeNameListDefault, supportedTypeListQLinearSigmoid, DmlGraphSupport::Supported, requiredConstantCpuInputs(), std::nullopt, QueryQLinearSigmoid)},
{REG_INFO_MS( 1, QAttention, typeNameListFour, supportedTypeListQAttention, DmlGraphSupport::Supported, requiredConstantCpuInputs(), std::nullopt, QueryQAttention)},
{REG_INFO_MS( 1, Attention, typeNameListAttention, supportedTypeListAttention, DmlGraphSupport::Supported, requiredConstantCpuInputs(), std::nullopt, QueryAttention)},
{REG_INFO_MS( 1, MultiHeadAttention, typeNameListAttention, supportedTypeListAttention, DmlGraphSupport::Supported)},
{REG_INFO_MS( 1, RotaryEmbedding, typeNameListRotaryEmbedding, supportedTypeListRotaryEmbedding, DmlGraphSupport::Supported)},
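For readers mapping the registration back to the kernel-doc table above: the REG_INFO_MS entry pairs typeNameListFour (defined elsewhere in this file, not shown in the diff) with supportedTypeListQAttention, so the four entries bind positionally to the operator's T1–T4 type constraints. An illustrative restatement (the enum and table below are stand-ins for exposition, not ONNX Runtime types):

```cpp
#include <map>
#include <string>
#include <vector>

enum class Dtype { Int8, UInt8, Float16, Float32, Int32 };

// T1: input / input_zero_point,  T2: weight / weight_zero_point,
// T3: scales, bias, past, output/present,  T4: mask_index.
static const std::map<std::string, std::vector<Dtype>> kQAttentionTypeConstraints = {
    {"T1", {Dtype::Int8, Dtype::UInt8}},       // Ints8Bit
    {"T2", {Dtype::Int8, Dtype::UInt8}},       // Ints8Bit
    {"T3", {Dtype::Float16, Dtype::Float32}},  // Float16to32
    {"T4", {Dtype::Int32}},                    // Int32
};
```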
@@ -2802,6 +2802,48 @@ namespace OperatorHelper
m_qkvHiddenSizes = kernelInformation.GetAttributes().GetOptionalAttributeVectorInt32(AttrName::QkvHiddenSizes);
}

std::vector<EdgeShapes> QAttentionHelper::GetOutputShapes(const MLShapeInferenceContext& shapeInfo) const
{
ML_CHECK_VALID_ARGUMENT(shapeInfo.GetInputCount() >= 5);

auto queryShape = shapeInfo.GetInputTensorShape(0);
ML_CHECK_VALID_ARGUMENT(queryShape.size() == 3);

auto weightShape = shapeInfo.GetInputTensorShape(1);
ML_CHECK_VALID_ARGUMENT(weightShape.size() == 2);
ML_CHECK_VALID_ARGUMENT(weightShape[1] % 3 == 0);

const uint32_t batchSize = queryShape[0];
const uint32_t sequenceLength = queryShape[1];
const uint32_t hiddenSize = weightShape[1] / 3;
const uint32_t headSize = hiddenSize / m_numHeads;

std::vector<EdgeShapes> outputShapes(2);

outputShapes[0] = EdgeShapes({batchSize, sequenceLength, hiddenSize});

uint32_t totalSequenceLength = sequenceLength;
if (shapeInfo.IsInputValid(8))
{
ML_CHECK_VALID_ARGUMENT(shapeInfo.GetInputTensorDimensionCount(8) == 5);
const uint32_t pastSequenceLength = shapeInfo.GetInputTensorShape(8)[3];
totalSequenceLength += pastSequenceLength;
}

if (shapeInfo.IsOutputValid(1))
{
ML_CHECK_VALID_ARGUMENT(shapeInfo.IsInputValid(8));
outputShapes[1] = EdgeShapes({2, batchSize, m_numHeads, totalSequenceLength, headSize});
}

return outputShapes;
}

void QAttentionHelper::Initialize(const IKernelInformationAdapter& kernelInformation)
{
m_numHeads = gsl::narrow_cast<uint32_t>(kernelInformation.GetAttributes().GetAttribute<int64_t>(AttrName::NumHeads));
}

std::vector<EdgeShapes> SkipLayerNormHelper::GetOutputShapes(const MLShapeInferenceContext& shapeInfo) const
{
ML_CHECK_VALID_ARGUMENT(shapeInfo.GetInputCount() >= 3);
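To make the shape inference above concrete, here is a tiny standalone check of the same arithmetic with illustrative values (not taken from this PR's tests):

```cpp
#include <cassert>
#include <cstdint>

int main() {
  const uint32_t batchSize = 2;
  const uint32_t sequenceLength = 3;  // input shape: [2, 3, input_hidden]
  const uint32_t numHeads = 2;        // "num_heads" attribute
  const uint32_t weightDim1 = 24;     // weight shape: [input_hidden, 24]; must be divisible by 3

  const uint32_t hiddenSize = weightDim1 / 3;       // 8
  const uint32_t headSize = hiddenSize / numHeads;  // 4

  const uint32_t pastSequenceLength = 5;  // past shape: [2, batchSize, numHeads, 5, headSize]
  const uint32_t totalSequenceLength = sequenceLength + pastSequenceLength;  // 8

  // output  -> {batchSize, sequenceLength, hiddenSize}                 == {2, 3, 8}
  // present -> {2, batchSize, numHeads, totalSequenceLength, headSize} == {2, 2, 2, 8, 4}
  assert(hiddenSize == 8 && headSize == 4 && totalSequenceLength == 8);
  return 0;
}
```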
@@ -1554,6 +1554,22 @@ class AttentionHelper
std::vector<int32_t> m_qkvHiddenSizes;
};

class QAttentionHelper
{
public:
template <typename Info_t, typename Shape_t>
QAttentionHelper(const Info_t& info, const Shape_t& shapeInfo)
{
Initialize(KernelInformationAdapter(info));
}

std::vector<EdgeShapes> GetOutputShapes(const MLShapeInferenceContext& shapeInfo) const;

private:
void Initialize(const IKernelInformationAdapter& kernelInformation);
uint32_t m_numHeads;
};

class SkipLayerNormHelper
{
public:
@@ -1699,6 +1715,7 @@ using ShapeInferenceHelper_Affine = GetOutputShapeAsInputShapeHelper;
using ShapeInferenceHelper_QuantizeLinear = GetOutputShapeAsInputShapeHelper;
using ShapeInferenceHelper_DequantizeLinear = GetOutputShapeAsInputShapeHelper;
using ShapeInferenceHelper_QLinearSigmoid = GetOutputShapeAsInputShapeHelper;
using ShapeInferenceHelper_QAttention = QAttentionHelper;
using ShapeInferenceHelper_Attention = AttentionHelper;
using ShapeInferenceHelper_MultiHeadAttention = MultiHeadAttentionHelper;
using ShapeInferenceHelper_RotaryEmbedding = GetOutputShapeAsInputShapeHelper;
@@ -448,6 +448,7 @@ namespace OperatorHelper
static const int sc_sinceVer_FusedMatMul = 1;
static const int sc_sinceVer_FusedMatMulActivation = 1;
static const int sc_sinceVer_QLinearSigmoid = 1;
static const int sc_sinceVer_QAttention = 1;
static const int sc_sinceVer_Attention = 1;
static const int sc_sinceVer_MatMulIntegerToFloat = 1;
static const int sc_sinceVer_MultiHeadAttention = 1;
60 changes: 56 additions & 4 deletions onnxruntime/test/contrib_ops/quantize_attention_op_test.cc
@@ -20,7 +20,8 @@
enum class EP : char {
  CPU,
  CUDA,
-  DNNL
+  DNNL,
+  DML
};

// input: [batch_size, sequence_length, hidden_size]
@@ -111,7 +112,9 @@
    execution_providers.push_back(DefaultCudaExecutionProvider());
  } else if constexpr (ep == EP::CPU) {
    execution_providers.push_back(DefaultCpuExecutionProvider());
-  } else { // onednn ep
+  } else if constexpr (ep == EP::DML) {
+    execution_providers.push_back(DefaultDmlExecutionProvider());
+  } else { // onednn ep
    execution_providers.push_back(DefaultDnnlExecutionProvider());
  }

@@ -192,6 +195,52 @@
#endif
}

static void RunQAttentionDML(
const std::vector<float>& input_data,
const std::vector<float>& weights_data,
const std::vector<float>& bias_data,
const std::vector<int32_t>& mask_index_data,
const std::vector<float>& output_data,
int batch_size,
int sequence_length,
int hidden_size,
int number_of_heads,
bool use_special_quantize_parameter = true,
bool is_unidirectional = false,
int input_hidden_size = 0) {
// Return without running code if USE_DML is not defined
#ifdef USE_DML
bool enable_dml = (nullptr != DefaultDmlExecutionProvider().get());
if (enable_dml) {
quantization::Params<uint8_t> input_quant_params(/*scale=*/0.0f, /*zero_point=*/0);
quantization::Params<int8_t> weights_quant_params(/*scale=*/0.0f, /*zero_point=*/0);
if (use_special_quantize_parameter) {
input_quant_params.scale = 0.1f;
weights_quant_params.scale = 0.1f;
input_quant_params.zero_point = 128;
weights_quant_params.zero_point = 1;
}

RunQAttention<uint8_t, int8_t, EP::DML>(
input_data, weights_data, bias_data, mask_index_data, output_data, input_quant_params, weights_quant_params,
batch_size, sequence_length, hidden_size, number_of_heads, is_unidirectional, false, input_hidden_size);
}
#else
ORT_UNUSED_PARAMETER(input_data);
ORT_UNUSED_PARAMETER(weights_data);
ORT_UNUSED_PARAMETER(bias_data);
ORT_UNUSED_PARAMETER(mask_index_data);
ORT_UNUSED_PARAMETER(output_data);
ORT_UNUSED_PARAMETER(batch_size);
ORT_UNUSED_PARAMETER(sequence_length);
ORT_UNUSED_PARAMETER(hidden_size);
ORT_UNUSED_PARAMETER(number_of_heads);
ORT_UNUSED_PARAMETER(use_special_quantize_parameter);
ORT_UNUSED_PARAMETER(is_unidirectional);
ORT_UNUSED_PARAMETER(input_hidden_size);
#endif
}

static void RunQAttentionU8U8(
const std::vector<float>& input_data,
const std::vector<float>& weights_data,
@@ -272,6 +321,9 @@
RunQAttentionDNNL(input_data, weight_data, bias_data, mask_index_data, output_data,
batch_size, sequence_length, hidden_size, number_of_heads,
use_special_quantize_parameter, is_unidirectional, input_hidden_size);
RunQAttentionDML(input_data, weight_data, bias_data, mask_index_data, output_data,
batch_size, sequence_length, hidden_size, number_of_heads,
use_special_quantize_parameter, is_unidirectional, input_hidden_size);
}

// ONEDNN EP only supports 2D raw mask
@@ -859,8 +911,8 @@
std::vector<int64_t> input_dims{batch, seq_len, hidden_size};
std::vector<InputT> input_data = random.Gaussian<InputT>(input_dims, input_mean, static_cast<InputT>(input_range / 6), input_min, input_max);

-constexpr WeightT weight_min = std::numeric_limits<WeightT>::min();
-constexpr WeightT weight_max = std::numeric_limits<WeightT>::max();
+constexpr WeightT weight_min = std::is_same_v<WeightT, int8_t> ? std::numeric_limits<int8_t>::min() / 2 : std::numeric_limits<WeightT>::min();
+constexpr WeightT weight_max = std::numeric_limits<WeightT>::max() / 2;

Check warning (GitHub Actions / Lint C++) on line 914 of onnxruntime/test/contrib_ops/quantize_attention_op_test.cc: [cpplint] Lines should be <= 120 characters long [whitespace/line_length] [2]
constexpr int32_t weight_range = weight_max - weight_min;

std::vector<WeightT> weight_zero_point(weight_scale_zp_size);
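With a DML-enabled build, the RunQAttentionDML path added above is exercised through the existing contrib-op tests; running the shared test binary with a gtest filter along the lines of `onnxruntime_test_all --gtest_filter=*QAttention*` (assuming the repo's usual test binary name and standard gtest filtering, neither of which is introduced by this PR) covers the new DML branches alongside the existing CPU runs.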