Skip to content

Commit

Permalink
Add QLinearConcat for DML EP (#16971) (#18268)
Browse files Browse the repository at this point in the history
### Description
[Cherry Pick Reviewed]
```
[ OK ] QLinearConcatS8.ExpectFail_WrongZeroPointType_1 (372 ms)
[ RUN ] QLinearConcatS8.InputOne_Dynamic
[ OK ] QLinearConcatS8.InputOne_Dynamic (255 ms)
[ RUN ] QLinearConcatS8.InputOne_Const
[ OK ] QLinearConcatS8.InputOne_Const (255 ms)
[----------] 11 tests from QLinearConcatS8 (3385 ms total)

[----------] Global test environment tear-down
[==========] 21 tests from 3 test suites ran. (9355 ms total)
[ PASSED ] 21 tests.
```
[#16971](#16971)

### Motivation and Context
<!-- - Why is this change required? What problem does it solve?
- If it fixes an open issue, please link to the issue here. -->

Co-authored-by: Xiang Zhang <[email protected]>
  • Loading branch information
2 people authored and jeffbloo committed Jan 4, 2024
1 parent 9bbe425 commit 9ff5e3b
Show file tree
Hide file tree
Showing 7 changed files with 290 additions and 13 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,236 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#include "precomp.h"

namespace Dml
{
// QLinearConcat = Dequantize + Join + Quantize
class DmlOperatorQLinearConcat : public DmlOperator, public QLinearConcatHelper
{
// This order matches the ONNX schema.
enum OnnxInputIndex
{
YScale,
YZeroPoint,
Count,
};

public:
DmlOperatorQLinearConcat(const MLOperatorKernelCreationContext& kernelCreationContext)
: DmlOperator(kernelCreationContext),
QLinearConcatHelper(kernelCreationContext, kernelCreationContext.GetTensorShapeDescription())
{
DmlOperator::Initialize(kernelCreationContext);

auto outputShape = kernelCreationContext.GetTensorShapeDescription().GetOutputTensorShape(0);

// inputs: {y_scale, y_zero_point, tuple(x_tensor, x_scale, x_zero_point)}
uint32_t inputDefinitionCount = kernelCreationContext.GetInputCount();
ML_CHECK_VALID_ARGUMENT(inputDefinitionCount >= 5, "Require at least 5 inputs.");
ML_CHECK_VALID_ARGUMENT((inputDefinitionCount - 2) % 3 == 0, "Each input must be (tensor, scale, zero_point) tuple!");

uint32_t inputCount = (inputDefinitionCount - 2) / 3;

auto yScaleDataType = kernelCreationContext.GetInputEdgeDescription(OnnxInputIndex::YScale).tensorDataType;
auto yZeroPointDataType = kernelCreationContext.GetInputEdgeDescription(OnnxInputIndex::YZeroPoint).tensorDataType;

// broadcast y_scale and y_zero_point to output shape
m_inputTensorDescs[OnnxInputIndex::YScale] = TensorDesc(
yScaleDataType,
outputShape,
kernelCreationContext.GetTensorShapeDescription().GetInputTensorShape(OnnxInputIndex::YScale),
TensorAxis::DoNotCoerce,
TensorAxis::W,
TensorAxis::RightAligned,
NchwDimensionCount, // minDimensionCount
0 // guaranteedBaseOffsetAlignment
);

m_inputTensorDescs[OnnxInputIndex::YZeroPoint] = TensorDesc(
yZeroPointDataType,
outputShape,
kernelCreationContext.GetTensorShapeDescription().GetInputTensorShape(OnnxInputIndex::YZeroPoint),
TensorAxis::DoNotCoerce,
TensorAxis::W,
TensorAxis::RightAligned,
NchwDimensionCount, // minDimensionCount
0 // guaranteedBaseOffsetAlignment
);

// Validate input tensors
for (uint32_t inputIndex = 0; inputIndex < inputCount; ++inputIndex)
{
// Inputs(input tensor, scale, zero_point) are in tuple and starting from index 2
auto tupleStartIndex = 2 + inputIndex * 3;
auto xScaleDataType = kernelCreationContext.GetInputEdgeDescription(tupleStartIndex + 1).tensorDataType;
auto xZeroPointDataType = kernelCreationContext.GetInputEdgeDescription(tupleStartIndex + 2).tensorDataType;
ML_CHECK_VALID_ARGUMENT(xScaleDataType == yScaleDataType, "Wrong input type encountered for scale");
ML_CHECK_VALID_ARGUMENT(xZeroPointDataType == yZeroPointDataType, "Wrong input type encountered for zero point");

// broadcast x_scale and x_zero_point to shape of corresponding x
m_inputTensorDescs[tupleStartIndex + 1] = TensorDesc(
xScaleDataType,
kernelCreationContext.GetTensorShapeDescription().GetInputTensorShape(tupleStartIndex),
kernelCreationContext.GetTensorShapeDescription().GetInputTensorShape(tupleStartIndex + 1),
TensorAxis::DoNotCoerce,
TensorAxis::W,
TensorAxis::RightAligned,
NchwDimensionCount, // minDimensionCount
0 // guaranteedBaseOffsetAlignment
);

m_inputTensorDescs[tupleStartIndex + 2] = TensorDesc(
xZeroPointDataType,
kernelCreationContext.GetTensorShapeDescription().GetInputTensorShape(tupleStartIndex),
kernelCreationContext.GetTensorShapeDescription().GetInputTensorShape(tupleStartIndex + 2),
TensorAxis::DoNotCoerce,
TensorAxis::W,
TensorAxis::RightAligned,
NchwDimensionCount, // minDimensionCount
0 // guaranteedBaseOffsetAlignment
);
}

uint32_t dmlAxis = GetDmlAdjustedAxis(m_axis, kernelCreationContext, m_inputTensorDescs.front().GetDimensionCount(), 2);

std::vector<DML_TENSOR_DESC> inputDescs = GetDmlInputDescs();
std::vector<DML_TENSOR_DESC> outputDescs = GetDmlOutputDescs();

// 1. output edges between Dequantize and Join node
// 2. input edge between Join and Quantize node
std::vector<TensorDesc> intermediateOutputTensorDescs(inputCount);
std::vector<DML_TENSOR_DESC> namedDequantizeOperatorDescs(inputCount);
std::vector<DML_ELEMENT_WISE_DEQUANTIZE_LINEAR_OPERATOR_DESC> dequantizeOperatorDescs(inputCount);
std::vector<DML_OPERATOR_DESC> dmlOpDesc(inputCount);
std::vector<const DML_OPERATOR_DESC*> opDescs;
for (uint32_t inputIndex = 0; inputIndex < inputCount; ++inputIndex)
{
auto tupleStartIndex = 2 + inputIndex * 3;
intermediateOutputTensorDescs[inputIndex] = TensorDesc(
MLOperatorTensorDataType::Float,
kernelCreationContext.GetTensorShapeDescription().GetInputTensorShape(tupleStartIndex),
kernelCreationContext.GetTensorShapeDescription().GetInputTensorShape(tupleStartIndex),
TensorAxis::DoNotCoerce,
TensorAxis::W,
TensorAxis::RightAligned,
NchwDimensionCount, // minDimensionCount
0 // guaranteedBaseOffsetAlignment)
);
namedDequantizeOperatorDescs[inputIndex] = intermediateOutputTensorDescs[inputIndex].GetDmlDesc();

dequantizeOperatorDescs[inputIndex].InputTensor = &inputDescs[tupleStartIndex];
dequantizeOperatorDescs[inputIndex].ScaleTensor = &inputDescs[tupleStartIndex + 1];
dequantizeOperatorDescs[inputIndex].ZeroPointTensor = &inputDescs[tupleStartIndex + 2];
dequantizeOperatorDescs[inputIndex].OutputTensor = &namedDequantizeOperatorDescs[inputIndex];

dmlOpDesc[inputIndex] = {DML_OPERATOR_ELEMENT_WISE_DEQUANTIZE_LINEAR, &dequantizeOperatorDescs[inputIndex]};
opDescs.push_back(&dmlOpDesc[inputIndex]);
}

TensorDesc joinOutputTensorDesc = TensorDesc(
MLOperatorTensorDataType::Float,
outputShape,
outputShape,
TensorAxis::DoNotCoerce,
TensorAxis::W,
TensorAxis::RightAligned,
NchwDimensionCount, // minDimensionCount
0 // guaranteedBaseOffsetAlignment
);
DML_TENSOR_DESC namedJoinOutputTensorDesc = joinOutputTensorDesc.GetDmlDesc();

DML_JOIN_OPERATOR_DESC joinDesc = {};
joinDesc.InputCount = gsl::narrow_cast<uint32_t>(namedDequantizeOperatorDescs.size());
joinDesc.InputTensors = namedDequantizeOperatorDescs.data();
joinDesc.OutputTensor = &namedJoinOutputTensorDesc;
joinDesc.Axis = dmlAxis;

const DML_OPERATOR_DESC opJoinDesc = {DML_OPERATOR_JOIN, &joinDesc};
opDescs.push_back(&opJoinDesc);

DML_ELEMENT_WISE_QUANTIZE_LINEAR_OPERATOR_DESC quantizeOperatorDesc = {};
quantizeOperatorDesc.InputTensor = joinDesc.OutputTensor;
quantizeOperatorDesc.ScaleTensor = &inputDescs[OnnxInputIndex::YScale];
quantizeOperatorDesc.ZeroPointTensor = &inputDescs[OnnxInputIndex::YZeroPoint];
quantizeOperatorDesc.OutputTensor = &outputDescs[0];
const DML_OPERATOR_DESC opQuantizeDesc = {DML_OPERATOR_ELEMENT_WISE_QUANTIZE_LINEAR, &quantizeOperatorDesc};
opDescs.push_back(&opQuantizeDesc);

MLOperatorGraphDesc operatorGraphDesc = {};
operatorGraphDesc.nodeCount = static_cast<uint32_t>(opDescs.size());
operatorGraphDesc.nodesAsOpDesc = opDescs.data();

uint32_t joinNodeIndex = operatorGraphDesc.nodeCount - 2;
uint32_t quantizeNodeIndex = operatorGraphDesc.nodeCount - 1;

std::vector<DML_INPUT_GRAPH_EDGE_DESC> inputEdges;
// Input edges to Dequantize nodes
for (uint32_t inputIndex = 0; inputIndex < inputCount; ++inputIndex)
{
auto tupleStartIndex = 2 + inputIndex * 3;
for (auto edge_index = 0; edge_index < 3; ++edge_index)
{
DML_INPUT_GRAPH_EDGE_DESC inputEdge = {};
inputEdge.GraphInputIndex = tupleStartIndex + edge_index;
inputEdge.ToNodeIndex = inputIndex;
inputEdge.ToNodeInputIndex = edge_index;
inputEdges.push_back(inputEdge);
}
}

// Input edge from y_scale to quantize node
DML_INPUT_GRAPH_EDGE_DESC yScaleInputEdge = {};
yScaleInputEdge.GraphInputIndex = 0; // Y_scale
yScaleInputEdge.ToNodeIndex = quantizeNodeIndex;
yScaleInputEdge.ToNodeInputIndex = 1;
inputEdges.push_back(yScaleInputEdge);

// Input edge from y_zero_point to quantize node
DML_INPUT_GRAPH_EDGE_DESC yZeroPointInputEdge = {};
yZeroPointInputEdge.GraphInputIndex = 1; // Y_zero_point
yZeroPointInputEdge.ToNodeIndex = quantizeNodeIndex;
yZeroPointInputEdge.ToNodeInputIndex = 2;
inputEdges.push_back(yZeroPointInputEdge);

operatorGraphDesc.inputEdgeCount = gsl::narrow_cast<uint32_t>(inputEdges.size());
operatorGraphDesc.inputEdges = inputEdges.data();

// set intermediate edges
std::vector<DML_INTERMEDIATE_GRAPH_EDGE_DESC> intermediateEdges;
for (uint32_t inputIndex = 0; inputIndex < inputCount; ++inputIndex)
{
DML_INTERMEDIATE_GRAPH_EDGE_DESC dequantizeToJoinEdge = {};
dequantizeToJoinEdge.FromNodeIndex = inputIndex;
dequantizeToJoinEdge.FromNodeOutputIndex = 0;
dequantizeToJoinEdge.ToNodeIndex = joinNodeIndex; // The second last node Join
dequantizeToJoinEdge.ToNodeInputIndex = inputIndex;
intermediateEdges.push_back(dequantizeToJoinEdge);
}

DML_INTERMEDIATE_GRAPH_EDGE_DESC joinToQuantizeEdge = {};
joinToQuantizeEdge.FromNodeIndex = joinNodeIndex;
joinToQuantizeEdge.FromNodeOutputIndex = 0;
joinToQuantizeEdge.ToNodeIndex = quantizeNodeIndex; // The second last node Join
joinToQuantizeEdge.ToNodeInputIndex = 0;
intermediateEdges.push_back(joinToQuantizeEdge);

operatorGraphDesc.intermediateEdgeCount = gsl::narrow_cast<uint32_t>(intermediateEdges.size());
operatorGraphDesc.intermediateEdges = intermediateEdges.data();

// set the output edges
std::vector<DML_OUTPUT_GRAPH_EDGE_DESC> outputEdges;
DML_OUTPUT_GRAPH_EDGE_DESC outputEdge = {};
outputEdge.FromNodeIndex = quantizeNodeIndex;
outputEdge.FromNodeOutputIndex = 0;
outputEdge.GraphOutputIndex = 0;
outputEdges.push_back(outputEdge);
operatorGraphDesc.outputEdgeCount = gsl::narrow_cast<uint32_t>(outputEdges.size());
operatorGraphDesc.outputEdges = outputEdges.data();

SetDmlOperatorGraphDesc(std::move(operatorGraphDesc), kernelCreationContext);
};
};

DML_OP_DEFINE_CREATION_FUNCTION(QLinearConcat, DmlOperatorQLinearConcat);
} // namespace Dml
Original file line number Diff line number Diff line change
Expand Up @@ -496,6 +496,7 @@ DML_OP_EXTERN_CREATION_FUNCTION(ScatterND);
DML_OP_EXTERN_CREATION_FUNCTION(QLinearAdd);
DML_OP_EXTERN_CREATION_FUNCTION(QLinearConv);
DML_OP_EXTERN_CREATION_FUNCTION(QLinearMatMul);
DML_OP_EXTERN_CREATION_FUNCTION(QLinearConcat);
DML_OP_EXTERN_CREATION_FUNCTION(DynamicQuantizeLinear);
DML_OP_EXTERN_CREATION_FUNCTION(MatMulInteger);
DML_OP_EXTERN_CREATION_FUNCTION(ConvInteger);
Expand Down Expand Up @@ -547,6 +548,7 @@ constexpr static std::array<const char*, 2> typeNameListEyeLike = { "T1", "T2" }
constexpr static std::array<const char*, 2> typeNameShape = { "T", "T1" };
constexpr static std::array<const char*, 2> typeNameSize = { "T", "T1" };
constexpr static std::array<const char*, 2> typeNameListGroupNorm = {"T", "M"};
// Type-constraint names passed to the QLinearConcat registration entry.
// NOTE(review): presumably paired by position with supportedTypeListQLinearConcat — confirm.
constexpr static std::array<const char*, 3> typeNameListQLinearConcat= {"TF", "T8", "TV"};

constexpr static std::array<SupportedTensorDataTypes, 1> supportedTypeListAll = {SupportedTensorDataTypes::All};
constexpr static std::array<SupportedTensorDataTypes, 1> supportedTypeListFloat32 = {SupportedTensorDataTypes::Float32};
Expand Down Expand Up @@ -618,7 +620,18 @@ constexpr static std::array<SupportedTensorDataTypes, 4> supportedTypeListQLinea

constexpr static std::array<SupportedTensorDataTypes, 2> supportedTypeListDynamicQuantizeLinear = {
SupportedTensorDataTypes::Float32,
SupportedTensorDataTypes::UInt8,
SupportedTensorDataTypes::Ints8Bit
};

// Supported tensor data types for DynamicQuantizeMatMul: a float32 entry and an 8-bit integer entry.
// NOTE(review): the registration entry that consumes this list is not visible here — confirm the type-name pairing.
constexpr static std::array<SupportedTensorDataTypes, 2> supportedTypeListDynamicQuantizeMatMul= {
SupportedTensorDataTypes::Float32,
SupportedTensorDataTypes::Ints8Bit,
};

// Supported tensor data types passed to the QLinearConcat registration entry.
// NOTE(review): the three entries presumably pair by position with
// typeNameListQLinearConcat ("TF", "T8", "TV") — confirm.
constexpr static std::array<SupportedTensorDataTypes, 3> supportedTypeListQLinearConcat= {
SupportedTensorDataTypes::Float32,
SupportedTensorDataTypes::Ints8Bit,
// The last entry accepts either 8-bit integer or float32 tensors.
SupportedTensorDataTypes::Ints8Bit|SupportedTensorDataTypes::Float32,
};

template<typename... Args>
Expand Down Expand Up @@ -1012,6 +1025,7 @@ constexpr static OperatorRegistrationInformation operatorRegistrationInformation
{REG_INFO_MS( 1, Attention, typeNameListAttention, supportedTypeListAttention, DmlGraphSupport::Supported, requiredConstantCpuInputs(), std::nullopt, QueryAttention)},
{REG_INFO_MS( 1, MultiHeadAttention, typeNameListAttention, supportedTypeListAttention, DmlGraphSupport::Supported)},
{REG_INFO_MS( 1, RotaryEmbedding, typeNameListRotaryEmbedding, supportedTypeListRotaryEmbedding, DmlGraphSupport::Supported)},
{REG_INFO_MS( 1, QLinearConcat, typeNameListQLinearConcat, supportedTypeListQLinearConcat, DmlGraphSupport::Supported)},

{REG_INFO( 10, IsInf, typeNameListTwo, supportedTypeListIsInf, DmlGraphSupport::Supported)},
{REG_INFO( 10, Mod, typeNameListDefault, supportedTypeListNumericDefault, DmlGraphSupport::Supported)},
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -419,9 +419,9 @@ namespace Dml

} // namespace FusionHelpers

uint32_t GetDmlAdjustedAxis(int32_t onnxAxis, const MLOperatorKernelCreationContext& kernelCreationContext, uint32_t dmlDimCount)
uint32_t GetDmlAdjustedAxis(int32_t onnxAxis, const MLOperatorKernelCreationContext& kernelCreationContext, uint32_t dmlDimCount, uint32_t firstInputIndex)
{
const std::vector<DimensionType> inputDimensions = kernelCreationContext.GetTensorShapeDescription().GetInputTensorShape(0);
const std::vector<DimensionType> inputDimensions = kernelCreationContext.GetTensorShapeDescription().GetInputTensorShape(firstInputIndex);
uint32_t onnxDimCount = gsl::narrow_cast<uint32_t>(inputDimensions.size());
onnxAxis = HandleNegativeAxis(onnxAxis, onnxDimCount);
return GetDmlAdjustedAxis(onnxAxis, onnxDimCount, dmlDimCount);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -64,8 +64,7 @@ namespace Dml
} // namespace FusionHelpers

// Given an axis in ONNX axis numbering, return the axis adjusted for DML based on how the sizes have been coerced.
// Note this function presumes the axis attribute is relative to the first input tensor (which is always the case).
uint32_t GetDmlAdjustedAxis(int32_t onnxAxis, const MLOperatorKernelCreationContext& kernelCreationContext, uint32_t dmlDimCount);
uint32_t GetDmlAdjustedAxis(int32_t onnxAxis, const MLOperatorKernelCreationContext& kernelCreationContext, uint32_t dmlDimCount, uint32_t firstInputIndex = 0);

uint32_t GetDmlAdjustedAxis(int32_t onnxAxis, uint32_t onnxDimCount, uint32_t dmlDimCount);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1862,7 +1862,7 @@ namespace OperatorHelper
return { std::move(outputShape) };
}

void ConcatHelper::Initialize(
void ConcatHelperBase::Initialize(
const MLOperatorAttributes& operatorAttributes,
gsl::span<const DimensionType> inputDimensions
)
Expand All @@ -1872,13 +1872,13 @@ namespace OperatorHelper
ML_CHECK_VALID_ARGUMENT(m_axis < static_cast<int>(inputDimensions.size()));
}

std::vector<EdgeShapes> ConcatHelper::GetOutputShapes(const MLShapeInferenceContext& shapeInfo) const
std::vector<EdgeShapes> ConcatHelperBase::GetOutputShapes(const MLShapeInferenceContext& shapeInfo, uint32_t firstInputIndex, uint32_t step) const
{
auto outputShape = shapeInfo.GetInputTensorShape(0);
auto outputShape = shapeInfo.GetInputTensorShape(firstInputIndex);

uint32_t inputCount = shapeInfo.GetInputCount();

for (uint32_t i = 1; i < inputCount; ++i)
for (uint32_t i = firstInputIndex + step; i < inputCount; i += step)
{
auto inputShape = shapeInfo.GetInputTensorShape(i);
for (size_t j = 0; j < outputShape.size(); ++j)
Expand All @@ -1893,6 +1893,16 @@ namespace OperatorHelper
return { EdgeShapes(outputShape) };
}

// Plain Concat: every input is a tensor to concatenate, so start at input 0
// and visit each input in turn.
std::vector<EdgeShapes> ConcatHelper::GetOutputShapes(const MLShapeInferenceContext& shapeInfo) const
{
    constexpr uint32_t firstTensorInputIndex = 0;
    constexpr uint32_t tensorInputStride = 1;
    return ConcatHelperBase::GetOutputShapes(shapeInfo, firstTensorInputIndex, tensorInputStride);
}

// QLinearConcat: shape inference starts at input 2 and then looks at every
// third input, passing those two values to the shared base implementation.
std::vector<EdgeShapes> QLinearConcatHelper::GetOutputShapes(const MLShapeInferenceContext& shapeInfo) const
{
    constexpr uint32_t firstTensorInputIndex = 2;
    constexpr uint32_t tensorInputStride = 3;
    return ConcatHelperBase::GetOutputShapes(shapeInfo, firstTensorInputIndex, tensorInputStride);
}

void CropHelper::Initialize(
const MLOperatorAttributes& operatorAttributes,
gsl::span<const DimensionType> inputDimensions
Expand Down
Loading

0 comments on commit 9ff5e3b

Please sign in to comment.