Skip to content

Commit

Permalink
add zero point tensor if not already defined
Browse files Browse the repository at this point in the history
  • Loading branch information
Linnea May committed Aug 2, 2023
1 parent c17ed2d commit da9afec
Show file tree
Hide file tree
Showing 2 changed files with 104 additions and 21 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -496,16 +496,27 @@ class DmlOperatorElementwisePow : public DmlOperator
template <typename TOperatorDesc>
class DmlOperatorElementwiseQLinear : public DmlOperator
{
enum OnnxInputIndex : uint32_t
{
inputIndex,
scaleIndex,
zeroPointIndex,
inputCount
};

public:
DmlOperatorElementwiseQLinear(const MLOperatorKernelCreationContext& kernelInfo) : DmlOperator(kernelInfo)
{
ML_CHECK_VALID_ARGUMENT(kernelInfo.GetInputCount() == 3); // TODO: Can be 2 inputs, since x_zero_point is optional.

ML_CHECK_VALID_ARGUMENT(kernelInfo.GetInputCount() >= 2);
ML_CHECK_VALID_ARGUMENT(kernelInfo.GetOutputCount() == 1);

Initialize(kernelInfo, std::nullopt, std::nullopt);

std::vector<uint32_t> outputShape = kernelInfo.GetTensorShapeDescription().GetOutputTensorShape(0);
const uint32_t outputShapeDimCount = gsl::narrow_cast<uint32_t>(outputShape.size());

Initialize(kernelInfo, std::nullopt, std::nullopt);
const DML_TENSOR_DATA_TYPE inputDataType = m_inputTensorDescs[0].GetDmlDataType();
bool hasZeroPointTensor = kernelInfo.IsInputValid(OnnxInputIndex::zeroPointIndex);

uint32_t axis = 0;

Expand All @@ -521,7 +532,6 @@ class DmlOperatorElementwiseQLinear : public DmlOperator
axis = Dml::HandleNegativeAxis(signedAxis, outputShapeDimCount, /*validateAxis*/ false);
}

// Explicitly reshape each of the inputs after the first input (scale and zero point tensors).
for (uint32_t index = 1, inputCount = gsl::narrow_cast<uint32_t>(m_inputTensorDescs.size()); index < inputCount; ++index)
{
auto edgeDesc = kernelInfo.GetInputEdgeDescription(index);
Expand All @@ -547,7 +557,7 @@ class DmlOperatorElementwiseQLinear : public DmlOperator

m_inputTensorDescs[index] = TensorDesc(
edgeDesc.tensorDataType,
gsl::make_span(outputShape),
gsl::make_span(outputShape),
gsl::make_span(inputTensorShape),
TensorAxis::DoNotCoerce,
TensorAxis::W,
Expand All @@ -560,13 +570,96 @@ class DmlOperatorElementwiseQLinear : public DmlOperator
std::vector<DML_TENSOR_DESC> inputDescs = GetDmlInputDescs();
std::vector<DML_TENSOR_DESC> outputDescs = GetDmlOutputDescs();

TOperatorDesc opDesc = {};
opDesc.InputTensor = &inputDescs[0];
opDesc.ScaleTensor = &inputDescs[1];
opDesc.ZeroPointTensor = &inputDescs[2];
opDesc.OutputTensor = &outputDescs[0];
if (hasZeroPointTensor)
{
TOperatorDesc opDesc = {};
opDesc.InputTensor = &inputDescs[OnnxInputIndex::inputIndex];
opDesc.ScaleTensor = &inputDescs[OnnxInputIndex::scaleIndex];
opDesc.ZeroPointTensor = &inputDescs[OnnxInputIndex::zeroPointIndex];
opDesc.OutputTensor = &outputDescs[0];

SetDmlOperatorDesc({ApiTraits::OperatorDescTraits<TOperatorDesc>::Type, &opDesc}, kernelInfo);
SetDmlOperatorDesc({ApiTraits::OperatorDescTraits<TOperatorDesc>::Type, &opDesc}, kernelInfo);
}
// Create a zero point tensor, since it's a required input by DML
else
{
auto inputSizes = m_inputTensorDescs[0].GetSizes();
TensorDesc intermediateOutputTensorDesc = TensorDesc(
GetMlDataTypeFromDmlDataType(inputDataType),
inputSizes,
inputSizes,
TensorAxis::DoNotCoerce,
TensorAxis::W,
TensorAxis::RightAligned,
NchwDimensionCount, // minDimensionCount
0 // guaranteedBaseOffsetAlignment
);
DML_TENSOR_DESC namedIntermediateOutputTensorDesc = intermediateOutputTensorDesc.GetDmlDesc();

// Create a tensor full of zeros
DML_FILL_VALUE_CONSTANT_OPERATOR_DESC zerosDesc = {};
zerosDesc.ValueDataType = inputDataType;
zerosDesc.OutputTensor = &namedIntermediateOutputTensorDesc;
const DML_OPERATOR_DESC opDesc1 = { DML_OPERATOR_FILL_VALUE_CONSTANT, &zerosDesc };

TOperatorDesc qLinearDesc = {};
qLinearDesc.InputTensor = &inputDescs[OnnxInputIndex::inputIndex];
qLinearDesc.ScaleTensor = &inputDescs[OnnxInputIndex::scaleIndex];
qLinearDesc.ZeroPointTensor = zerosDesc.OutputTensor;
qLinearDesc.OutputTensor = &outputDescs[0];
const DML_OPERATOR_DESC opDesc2 = { ApiTraits::OperatorDescTraits<TOperatorDesc>::Type, &qLinearDesc};

// Construct the graph
std::vector<DML_INPUT_GRAPH_EDGE_DESC> inputEdges;
std::vector<DML_INTERMEDIATE_GRAPH_EDGE_DESC> intermediateEdges;
std::vector<DML_OUTPUT_GRAPH_EDGE_DESC> outputEdges;

MLOperatorGraphDesc operatorGraphDesc = {};
operatorGraphDesc.nodeCount = 2;
std::vector<const DML_OPERATOR_DESC*> opDescs{&opDesc1, &opDesc2};
operatorGraphDesc.nodesAsOpDesc = opDescs.data();

const uint32_t fillValueNodeIndex = 0;
const uint32_t dequantizeNodeIndex = 1;

// Input edges
DML_INPUT_GRAPH_EDGE_DESC inputToDequantizeEdge = {};
inputToDequantizeEdge.GraphInputIndex = OnnxInputIndex::inputIndex;
inputToDequantizeEdge.ToNodeIndex = dequantizeNodeIndex;
inputToDequantizeEdge.ToNodeInputIndex = 0;
inputEdges.push_back(inputToDequantizeEdge);

DML_INPUT_GRAPH_EDGE_DESC scaleToDequantizeEdge = {};
scaleToDequantizeEdge.GraphInputIndex = OnnxInputIndex::scaleIndex; // dmlWeightsIndex
scaleToDequantizeEdge.ToNodeIndex = dequantizeNodeIndex; //dequantizeinputindex
scaleToDequantizeEdge.ToNodeInputIndex = 1;
inputEdges.push_back(scaleToDequantizeEdge);

operatorGraphDesc.inputEdgeCount = gsl::narrow_cast<uint32_t>(inputEdges.size());
operatorGraphDesc.inputEdges = inputEdges.data();

// intermediate edges
DML_INTERMEDIATE_GRAPH_EDGE_DESC fillValueToDequantizeEdge = {};
fillValueToDequantizeEdge.FromNodeIndex = 0;
fillValueToDequantizeEdge.FromNodeOutputIndex = 0;
fillValueToDequantizeEdge.ToNodeIndex = 1;
fillValueToDequantizeEdge.ToNodeInputIndex = zeroPointIndex; // 2
intermediateEdges.push_back(fillValueToDequantizeEdge);

operatorGraphDesc.intermediateEdgeCount = gsl::narrow_cast<uint32_t>(intermediateEdges.size());
operatorGraphDesc.intermediateEdges = intermediateEdges.data();

// output edges
DML_OUTPUT_GRAPH_EDGE_DESC outputEdge = {};
outputEdge.FromNodeIndex = 1;
outputEdge.FromNodeOutputIndex = 0;
outputEdge.GraphOutputIndex = 0;
outputEdges.push_back(outputEdge);
operatorGraphDesc.outputEdgeCount = gsl::narrow_cast<uint32_t>(outputEdges.size());
operatorGraphDesc.outputEdges = outputEdges.data();

SetDmlOperatorGraphDesc(std::move(operatorGraphDesc), kernelInfo);
}
}
};

Expand Down
10 changes: 0 additions & 10 deletions onnxruntime/test/providers/cpu/tensor/quantize_linear_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -34,11 +34,6 @@ TEST(DequantizeLinearOpTest, Int8) {

// scalar zero & scale with int8
TEST(DequantizeLinearOpTest, Int32) {
// TODO: Unskip when fixed #41968513
if (DefaultDmlExecutionProvider().get() != nullptr) {
GTEST_SKIP() << "Skipping because of the following error: AbiCustomRegistry.cpp(507): The parameter is incorrect";
}

OpTester test("DequantizeLinear", 10);
std::vector<int64_t> dims{4};
test.AddInput<int32_t>("x", dims, {-30, -3, 100, 127});
Expand Down Expand Up @@ -77,11 +72,6 @@ TEST(DequantizeLinearOpTest, Scalar) {

// dequantize without zero point
TEST(DequantizeLinearOpTest, Without_Zero_Point) {
// TODO: Unskip when fixed #41968513
if (DefaultDmlExecutionProvider().get() != nullptr) {
GTEST_SKIP() << "Skipping because of the following error: AbiCustomRegistry.cpp(507): The parameter is incorrect";
}

OpTester test("DequantizeLinear", 10);
test.AddInput<int8_t>("x", {}, {100});
test.AddInput<float>("x_scale", {}, {2.0f});
Expand Down

0 comments on commit da9afec

Please sign in to comment.