Skip to content

Commit

Permalink
implement DynamicQuantizeMatMul (#16757) (#18237)
Browse files Browse the repository at this point in the history
[Cherry Pick Reviewed]
This PR implements

[com.microsoft.DynamicQuantizeMatMul](https://github.com/microsoft/onnxruntime/blob/main/docs/ContribOperators.md#commicrosoftdynamicquantizematmul)

![image](https://github.com/microsoft/onnxruntime/assets/17421593/c8ab927a-5d69-40e5-a08b-79b89becf937)

<!-- Describe your changes. -->

<!-- - Why is this change required? What problem does it solve?
- If it fixes an open issue, please link to the issue here. -->

Co-authored-by: Xiang Zhang <[email protected]>
  • Loading branch information
raoanag and zhangxiang1993 committed Mar 4, 2024
1 parent 27b1dc9 commit d0fbd58
Show file tree
Hide file tree
Showing 6 changed files with 223 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -1454,6 +1454,12 @@ struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_ELEMENT_WISE_DEQUANTIZ
using DescType = DML_ELEMENT_WISE_DEQUANTIZE_LINEAR_OPERATOR_DESC;
};

// Maps the extension operator desc struct to its DML_OPERATOR_TYPE enum value.
template <>
struct OperatorDescTraits<DML_MATRIX_MULTIPLY_INTEGER_TO_FLOAT_OPERATOR_DESC>
{
    static constexpr DML_OPERATOR_TYPE Type = static_cast<DML_OPERATOR_TYPE>(DML_OPERATOR_MATRIX_MULTIPLY_INTEGER_TO_FLOAT);
};

template <>
struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_CONVOLUTION>
{
Expand Down Expand Up @@ -2221,6 +2227,11 @@ struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_ACTIVATION_SWISH>
{
using DescType = DML_ACTIVATION_SWISH_OPERATOR_DESC;
};
// Reverse mapping: DML_OPERATOR_MATRIX_MULTIPLY_INTEGER_TO_FLOAT back to its desc struct.
template <>
struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_MATRIX_MULTIPLY_INTEGER_TO_FLOAT>
{
using DescType = DML_MATRIX_MULTIPLY_INTEGER_TO_FLOAT_OPERATOR_DESC;
};

template <>
struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_ACTIVATION_HARD_SWISH>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1865,6 +1865,25 @@ constexpr DML_OPERATOR_SCHEMA DML_MATRIX_MULTIPLY_INTEGER_OPERATOR_SCHEMA {
DML_MATRIX_MULTIPLY_INTEGER_OPERATOR_SCHEMA_FIELDS,
};

// Schema fields for DML_OPERATOR_MATRIX_MULTIPLY_INTEGER_TO_FLOAT.
// The final boolean marks a field optional: AZeroPointTensor, BZeroPointTensor,
// and BiasTensor may be absent; all other tensors are required.
constexpr DML_SCHEMA_FIELD DML_MATRIX_MULTIPLY_INTEGER_TO_FLOAT_OPERATOR_SCHEMA_FIELDS[8] {
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "ATensor", false },
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "AScaleTensor", false },
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "AZeroPointTensor", true },
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "BTensor", false },
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "BScaleTensor", false },
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "BZeroPointTensor", true },
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "BiasTensor", true },
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false },
};

// Operator schema tying the field table above to the extension operator type.
// The field count (8) must match the extent of the FIELDS array.
constexpr DML_OPERATOR_SCHEMA DML_MATRIX_MULTIPLY_INTEGER_TO_FLOAT_OPERATOR_SCHEMA {
"DML_OPERATOR_MATRIX_MULTIPLY_INTEGER_TO_FLOAT",
static_cast<DML_OPERATOR_TYPE>(DML_OPERATOR_MATRIX_MULTIPLY_INTEGER_TO_FLOAT),
DML_SCHEMA_OPERATOR_SUPPORT_FLAG_NONE,
8,
DML_MATRIX_MULTIPLY_INTEGER_TO_FLOAT_OPERATOR_SCHEMA_FIELDS,
};

constexpr DML_SCHEMA_FIELD DML_QUANTIZED_LINEAR_MATRIX_MULTIPLY_OPERATOR_SCHEMA_FIELDS[9] {
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "ATensor", false },
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "AScaleTensor", false },
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1685,6 +1685,19 @@ inline std::vector<OperatorField> GetFields(const DML_ACTIVATION_SHRINK_OPERATOR
OperatorField(&DML_ACTIVATION_SHRINK_OPERATOR_SCHEMA.Fields[3], ToOperatorFieldType(static_cast<FLOAT>(desc.Threshold))),
};
}
inline std::vector<OperatorField> GetFields(const DML_MATRIX_MULTIPLY_INTEGER_TO_FLOAT_OPERATOR_DESC& desc)
{
return {
OperatorField(&DML_MATRIX_MULTIPLY_INTEGER_TO_FLOAT_OPERATOR_SCHEMA.Fields[0], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.ATensor))),
OperatorField(&DML_MATRIX_MULTIPLY_INTEGER_TO_FLOAT_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.AScaleTensor))),
OperatorField(&DML_MATRIX_MULTIPLY_INTEGER_TO_FLOAT_OPERATOR_SCHEMA.Fields[2], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.AZeroPointTensor))),
OperatorField(&DML_MATRIX_MULTIPLY_INTEGER_TO_FLOAT_OPERATOR_SCHEMA.Fields[3], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.BTensor))),
OperatorField(&DML_MATRIX_MULTIPLY_INTEGER_TO_FLOAT_OPERATOR_SCHEMA.Fields[4], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.BScaleTensor))),
OperatorField(&DML_MATRIX_MULTIPLY_INTEGER_TO_FLOAT_OPERATOR_SCHEMA.Fields[5], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.BZeroPointTensor))),
OperatorField(&DML_MATRIX_MULTIPLY_INTEGER_TO_FLOAT_OPERATOR_SCHEMA.Fields[6], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.BiasTensor))),
OperatorField(&DML_MATRIX_MULTIPLY_INTEGER_TO_FLOAT_OPERATOR_SCHEMA.Fields[7], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.OutputTensor))),
};
}
inline std::vector<OperatorField> GetFields(const DML_ACTIVATION_GELU_OPERATOR_DESC& desc)
{
return {
Expand Down Expand Up @@ -2444,6 +2457,10 @@ inline AbstractOperatorDesc ConvertOperatorDesc(const DML_OPERATOR_DESC& opDesc)
return AbstractOperatorDesc(
&DML_ELEMENT_WISE_QUANTIZED_LINEAR_ADD_OPERATOR_SCHEMA,
GetFields(*static_cast<const DML_ELEMENT_WISE_QUANTIZED_LINEAR_ADD_OPERATOR_DESC*>(opDesc.Desc)));
case DML_OPERATOR_MATRIX_MULTIPLY_INTEGER_TO_FLOAT:
return AbstractOperatorDesc(
&DML_MATRIX_MULTIPLY_INTEGER_TO_FLOAT_OPERATOR_SCHEMA,
GetFields(*static_cast<const DML_MATRIX_MULTIPLY_INTEGER_TO_FLOAT_OPERATOR_DESC*>(opDesc.Desc)));
case DML_OPERATOR_ROI_ALIGN_GRAD:
return AbstractOperatorDesc(
&DML_ROI_ALIGN_GRAD_OPERATOR_SCHEMA,
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,173 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#include "precomp.h"

namespace Dml
{
// DynamicQuantizeMatMul = MatrixMultiplyIntegerToFloat(DynamicQuantizeLinear(A), B)
class DmlOperatorDynamicQuantizeMatMul : public DmlOperator
{
// This order matches the ONNX schema.
enum OnnxInputIndex
{
A, // Input
B,
B_scale,
B_zero_point,
Bias,
Count,
};

public:
DmlOperatorDynamicQuantizeMatMul(const MLOperatorKernelCreationContext& kernelCreationContext)
: DmlOperator(kernelCreationContext)
{
DmlOperator::Initialize(kernelCreationContext);

const bool hasBias = kernelCreationContext.IsInputValid(OnnxInputIndex::Bias);
const bool hasBZP = kernelCreationContext.IsInputValid(OnnxInputIndex::B_zero_point);

// Broadcast Bias tensor to the shape of the output tensor.
if (hasBias)
{
m_inputTensorDescs[OnnxInputIndex::Bias] = CreateTensorDescFromInput(
kernelCreationContext,
OnnxInputIndex::Bias,
TensorAxis::DoNotCoerce,
TensorAxis::W,
TensorAxis::RightAligned,
kernelCreationContext.GetTensorShapeDescription().GetOutputTensorShape(0)
);
}
MLOperatorTensorDataType BDatatype = kernelCreationContext.GetInputEdgeDescription(OnnxInputIndex::B).tensorDataType;

std::vector<uint32_t> ATensorShape = kernelCreationContext.GetTensorShapeDescription().GetInputTensorShape(OnnxInputIndex::A);
std::vector<uint32_t> ExpectedAScaleTensorShape = {1, 1, 1, 1};
std::vector<uint32_t> ExpectedAZeroPointTensorShape = {1, 1, 1, 1};

// output edges between DynQL and MMItoFloat node
TensorDesc intermediateQuantizedATensorDesc = TensorDesc(
BDatatype,
gsl::make_span(ATensorShape),
gsl::make_span(ATensorShape),
TensorAxis::DoNotCoerce,
TensorAxis::W,
TensorAxis::RightAligned,
NchwDimensionCount, // minDimensionCount
0 // guaranteedBaseOffsetAlignment
);

TensorDesc intermediateQuantizedAScaleTensorDesc = TensorDesc(
MLOperatorTensorDataType::Float,
gsl::make_span(ExpectedAScaleTensorShape),
gsl::make_span(ExpectedAScaleTensorShape),
TensorAxis::DoNotCoerce,
TensorAxis::W,
TensorAxis::RightAligned,
NchwDimensionCount, // minDimensionCount
0 // guaranteedBaseOffsetAlignment
);

TensorDesc intermediateQuantizedAZeroPointTensorDesc = TensorDesc(
BDatatype,
gsl::make_span(ExpectedAZeroPointTensorShape),
gsl::make_span(ExpectedAZeroPointTensorShape),
TensorAxis::DoNotCoerce,
TensorAxis::W,
TensorAxis::RightAligned,
NchwDimensionCount, // minDimensionCount
0 // guaranteedBaseOffsetAlignment
);

DML_TENSOR_DESC namedIntermediateQuantizedATensorDesc = intermediateQuantizedATensorDesc.GetDmlDesc();
DML_TENSOR_DESC namedIntermediateQuantizedAScaleTensorDesc = intermediateQuantizedAScaleTensorDesc.GetDmlDesc();
DML_TENSOR_DESC namedIntermediateQuantizedAZeroPointTensorDesc = intermediateQuantizedAZeroPointTensorDesc.GetDmlDesc();

std::vector<DML_TENSOR_DESC> inputDescs = GetDmlInputDescs();
std::vector<DML_TENSOR_DESC> outputDescs = GetDmlOutputDescs();

DML_DYNAMIC_QUANTIZE_LINEAR_OPERATOR_DESC dynamicQuantizeLinearOperatorDesc = {};
dynamicQuantizeLinearOperatorDesc.InputTensor = &inputDescs[OnnxInputIndex::A];
dynamicQuantizeLinearOperatorDesc.OutputTensor = &namedIntermediateQuantizedATensorDesc;
dynamicQuantizeLinearOperatorDesc.OutputScaleTensor = &namedIntermediateQuantizedAScaleTensorDesc;
dynamicQuantizeLinearOperatorDesc.OutputZeroPointTensor = &namedIntermediateQuantizedAZeroPointTensorDesc;

const DML_OPERATOR_DESC opDesc1{DML_OPERATOR_DYNAMIC_QUANTIZE_LINEAR, &dynamicQuantizeLinearOperatorDesc};

DML_MATRIX_MULTIPLY_INTEGER_TO_FLOAT_OPERATOR_DESC matrixMultiplyIntergerToFloatOperatorDesc = {};
matrixMultiplyIntergerToFloatOperatorDesc.ATensor = dynamicQuantizeLinearOperatorDesc.OutputTensor;
matrixMultiplyIntergerToFloatOperatorDesc.AScaleTensor = dynamicQuantizeLinearOperatorDesc.OutputScaleTensor;
matrixMultiplyIntergerToFloatOperatorDesc.AZeroPointTensor = dynamicQuantizeLinearOperatorDesc.OutputZeroPointTensor;
matrixMultiplyIntergerToFloatOperatorDesc.BTensor = &inputDescs[OnnxInputIndex::B];
matrixMultiplyIntergerToFloatOperatorDesc.BScaleTensor = &inputDescs[OnnxInputIndex::B_scale];
matrixMultiplyIntergerToFloatOperatorDesc.BZeroPointTensor = hasBZP? &inputDescs[OnnxInputIndex::B_zero_point] : nullptr;
matrixMultiplyIntergerToFloatOperatorDesc.BiasTensor = hasBias? &inputDescs[OnnxInputIndex::Bias] : nullptr;
matrixMultiplyIntergerToFloatOperatorDesc.OutputTensor = &outputDescs[0];

const DML_OPERATOR_DESC opDesc2{ (DML_OPERATOR_TYPE)DML_OPERATOR_MATRIX_MULTIPLY_INTEGER_TO_FLOAT, &matrixMultiplyIntergerToFloatOperatorDesc};

MLOperatorGraphDesc operatorGraphDesc = {};
std::vector<const DML_OPERATOR_DESC*> opDescs{&opDesc1, &opDesc2};
operatorGraphDesc.nodeCount = static_cast<uint32_t>(opDescs.size());
operatorGraphDesc.nodesAsOpDesc = opDescs.data();

// set input edges
std::pair<uint32_t, uint32_t> nodeToNodeInputIndex[OnnxInputIndex::Count] {{0, 0}, {1, 3}, {1, 4}, {1, 5}, {1, 6}};
std::vector<DML_INPUT_GRAPH_EDGE_DESC> inputEdges;
for (uint32_t inputIndex = 0; inputIndex < OnnxInputIndex::Count; inputIndex++)
{
if (inputIndex == OnnxInputIndex::B_zero_point && !hasBZP) continue;
if (inputIndex == OnnxInputIndex::Bias && !hasBias) continue;
DML_INPUT_GRAPH_EDGE_DESC inputEdge = {};
inputEdge.GraphInputIndex = inputIndex; // OnnxInputIndex and DmlInputIndex are identity for QLinearSigmoid
inputEdge.ToNodeIndex = nodeToNodeInputIndex[inputIndex].first;
inputEdge.ToNodeInputIndex = nodeToNodeInputIndex[inputIndex].second;
inputEdges.push_back(inputEdge);
}
operatorGraphDesc.inputEdgeCount = gsl::narrow_cast<uint32_t>(inputEdges.size());
operatorGraphDesc.inputEdges = inputEdges.data();

// set intermediate edges
std::vector<DML_INTERMEDIATE_GRAPH_EDGE_DESC> intermediateEdges;

DML_INTERMEDIATE_GRAPH_EDGE_DESC dynQLToMMItofloatEdge1 = {};
dynQLToMMItofloatEdge1.FromNodeIndex = 0;
dynQLToMMItofloatEdge1.FromNodeOutputIndex = 0;
dynQLToMMItofloatEdge1.ToNodeIndex = 1;
dynQLToMMItofloatEdge1.ToNodeInputIndex = 0;
intermediateEdges.push_back(dynQLToMMItofloatEdge1);

DML_INTERMEDIATE_GRAPH_EDGE_DESC dynQLToMMItofloatEdge2 = {};
dynQLToMMItofloatEdge2.FromNodeIndex = 0;
dynQLToMMItofloatEdge2.FromNodeOutputIndex = 1;
dynQLToMMItofloatEdge2.ToNodeIndex = 1;
dynQLToMMItofloatEdge2.ToNodeInputIndex = 1;
intermediateEdges.push_back(dynQLToMMItofloatEdge2);

DML_INTERMEDIATE_GRAPH_EDGE_DESC dynQLToMMItofloatEdge3 = {};
dynQLToMMItofloatEdge3.FromNodeIndex = 0;
dynQLToMMItofloatEdge3.FromNodeOutputIndex = 2;
dynQLToMMItofloatEdge3.ToNodeIndex = 1;
dynQLToMMItofloatEdge3.ToNodeInputIndex = 2;
intermediateEdges.push_back(dynQLToMMItofloatEdge3);

operatorGraphDesc.intermediateEdgeCount = gsl::narrow_cast<uint32_t>(intermediateEdges.size());
operatorGraphDesc.intermediateEdges = intermediateEdges.data();

// set the output edges
std::vector<DML_OUTPUT_GRAPH_EDGE_DESC> outputEdges;
DML_OUTPUT_GRAPH_EDGE_DESC outputEdge = {};
outputEdge.FromNodeIndex = 1;
outputEdge.FromNodeOutputIndex = 0;
outputEdge.GraphOutputIndex = 0;
outputEdges.push_back(outputEdge);
operatorGraphDesc.outputEdgeCount = gsl::narrow_cast<uint32_t>(outputEdges.size());
operatorGraphDesc.outputEdges = outputEdges.data();

SetDmlOperatorGraphDesc(std::move(operatorGraphDesc), kernelCreationContext);
}
};

DML_OP_DEFINE_CREATION_FUNCTION(DynamicQuantizeMatMul, DmlOperatorDynamicQuantizeMatMul);
} // namespace Dml
Original file line number Diff line number Diff line change
Expand Up @@ -435,6 +435,7 @@ DML_OP_EXTERN_CREATION_FUNCTION(Dropout);
DML_OP_EXTERN_CREATION_FUNCTION(MatMul);
DML_OP_EXTERN_CREATION_FUNCTION(FusedMatMul);
DML_OP_EXTERN_CREATION_FUNCTION(FusedMatMulActivation);
DML_OP_EXTERN_CREATION_FUNCTION(DynamicQuantizeMatMul);
DML_OP_EXTERN_CREATION_FUNCTION(Cast);
DML_OP_EXTERN_CREATION_FUNCTION(CastLike15);
DML_OP_EXTERN_CREATION_FUNCTION(CastLike19);
Expand Down Expand Up @@ -1065,6 +1066,7 @@ constexpr static OperatorRegistrationInformation operatorRegistrationInformation
{REG_INFO_MS( 1, Gelu, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO_MS( 1, BiasGelu, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO_MS( 1, FusedMatMul, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO_MS( 1, DynamicQuantizeMatMul, typeNameListTwo, supportedTypeListDynamicQuantizeLinear, DmlGraphSupport::Supported)},
{REG_INFO_MS( 1, FusedMatMulActivation, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO_MS( 1, QLinearSigmoid, typeNameListDefault, supportedTypeListQLinearSigmoid, DmlGraphSupport::Supported, requiredConstantCpuInputs(), std::nullopt, QueryQLinearSigmoid)},
{REG_INFO_MS( 1, Attention, typeNameListAttention, supportedTypeListAttention, DmlGraphSupport::Supported, requiredConstantCpuInputs(), std::nullopt, QueryAttention)},
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1776,6 +1776,7 @@ using ShapeInferenceHelper_Identity19 = GetOutputShapeAsInputShapeHelper;
using ShapeInferenceHelper_MatMul = MatMulHelper;
using ShapeInferenceHelper_MatMulInteger = MatMulHelper;
using ShapeInferenceHelper_MatMulIntegerToFloat = MatMulHelper;
using ShapeInferenceHelper_DynamicQuantizeMatMul = MatMulHelper;
using ShapeInferenceHelper_QLinearMatMul = QLinearMatMulHelper;
using ShapeInferenceHelper_QLinearAdd = GetBroadcastedOutputShapeHelper;
using ShapeInferenceHelper_DynamicQuantizeLinear = GetOutputShapeAsInputShapeHelper;
Expand Down

0 comments on commit d0fbd58

Please sign in to comment.