Skip to content

Commit

Permalink
comments
Browse files Browse the repository at this point in the history
  • Loading branch information
Linnea May committed Aug 11, 2023
1 parent 651bb22 commit c85678d
Showing 1 changed file with 35 additions and 28 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -496,7 +496,7 @@ class DmlOperatorElementwisePow : public DmlOperator
template <typename TOperatorDesc>
class DmlOperatorElementwiseQLinear : public DmlOperator
{
enum OnnxInputIndex : uint32_t
enum OnnxInputIndex : uint32_t
{
inputIndex,
scaleIndex,
Expand All @@ -507,8 +507,8 @@ class DmlOperatorElementwiseQLinear : public DmlOperator
public:
DmlOperatorElementwiseQLinear(const MLOperatorKernelCreationContext& kernelInfo) : DmlOperator(kernelInfo)
{
ML_CHECK_VALID_ARGUMENT(kernelInfo.GetInputCount() >= 2);

ML_CHECK_VALID_ARGUMENT(kernelInfo.GetInputCount() >= 2);
ML_CHECK_VALID_ARGUMENT(kernelInfo.GetOutputCount() == 1);

Initialize(kernelInfo, std::nullopt, std::nullopt);
Expand All @@ -532,8 +532,14 @@ class DmlOperatorElementwiseQLinear : public DmlOperator
axis = Dml::HandleNegativeAxis(signedAxis, outputShapeDimCount, /*validateAxis*/ false);
}

// Explicitly reshape each of the inputs after the first input (scale tensor and optional zero point tensor).
for (uint32_t index = 1, inputCount = gsl::narrow_cast<uint32_t>(m_inputTensorDescs.size()); index < inputCount; ++index)
{
if(!kernelInfo.IsInputValid(index))
{
continue;
}

auto edgeDesc = kernelInfo.GetInputEdgeDescription(index);
assert(edgeDesc.edgeType == MLOperatorEdgeType::Tensor);

Expand All @@ -557,7 +563,7 @@ class DmlOperatorElementwiseQLinear : public DmlOperator

m_inputTensorDescs[index] = TensorDesc(
edgeDesc.tensorDataType,
gsl::make_span(outputShape),
gsl::make_span(outputShape),
gsl::make_span(inputTensorShape),
TensorAxis::DoNotCoerce,
TensorAxis::W,
Expand Down Expand Up @@ -586,7 +592,7 @@ class DmlOperatorElementwiseQLinear : public DmlOperator
auto inputSizes = m_inputTensorDescs[0].GetSizes();
TensorDesc intermediateOutputTensorDesc = TensorDesc(
GetMlDataTypeFromDmlDataType(inputDataType),
inputSizes,
inputSizes,
inputSizes,
TensorAxis::DoNotCoerce,
TensorAxis::W,
Expand All @@ -595,7 +601,7 @@ class DmlOperatorElementwiseQLinear : public DmlOperator
0 // guaranteedBaseOffsetAlignment
);
DML_TENSOR_DESC namedIntermediateOutputTensorDesc = intermediateOutputTensorDesc.GetDmlDesc();

// Create a tensor full of zeros
DML_FILL_VALUE_CONSTANT_OPERATOR_DESC zerosDesc = {};
zerosDesc.ValueDataType = inputDataType;
Expand All @@ -610,54 +616,55 @@ class DmlOperatorElementwiseQLinear : public DmlOperator
const DML_OPERATOR_DESC opDesc2 = { ApiTraits::OperatorDescTraits<TOperatorDesc>::Type, &qLinearDesc};

// Construct the graph
std::vector<DML_INPUT_GRAPH_EDGE_DESC> inputEdges;
std::vector<DML_INTERMEDIATE_GRAPH_EDGE_DESC> intermediateEdges;
std::vector<DML_OUTPUT_GRAPH_EDGE_DESC> outputEdges;
DML_INPUT_GRAPH_EDGE_DESC inputEdges[2];
DML_INTERMEDIATE_GRAPH_EDGE_DESC intermediateEdges[1];
DML_OUTPUT_GRAPH_EDGE_DESC outputEdges[1];

MLOperatorGraphDesc operatorGraphDesc = {};
operatorGraphDesc.nodeCount = 2;
std::vector<const DML_OPERATOR_DESC*> opDescs{&opDesc1, &opDesc2};
operatorGraphDesc.nodesAsOpDesc = opDescs.data();
const DML_OPERATOR_DESC* opDescs[] = {&opDesc1, &opDesc2};
operatorGraphDesc.nodesAsOpDesc = std::data(opDescs);

const uint32_t fillValueNodeIndex = 0;
const uint32_t dequantizeNodeIndex = 1;

// Input edges
DML_INPUT_GRAPH_EDGE_DESC inputToDequantizeEdge = {};
inputToDequantizeEdge.GraphInputIndex = OnnxInputIndex::inputIndex;
inputToDequantizeEdge.ToNodeIndex = dequantizeNodeIndex;
inputToDequantizeEdge.ToNodeIndex = dequantizeNodeIndex;
inputToDequantizeEdge.ToNodeInputIndex = 0;
inputEdges.push_back(inputToDequantizeEdge);
inputEdges[0] = inputToDequantizeEdge;

DML_INPUT_GRAPH_EDGE_DESC scaleToDequantizeEdge = {};
scaleToDequantizeEdge.GraphInputIndex = OnnxInputIndex::scaleIndex; // dmlWeightsIndex
scaleToDequantizeEdge.ToNodeIndex = dequantizeNodeIndex; //dequantizeinputindex
scaleToDequantizeEdge.GraphInputIndex = OnnxInputIndex::scaleIndex;
scaleToDequantizeEdge.ToNodeIndex = dequantizeNodeIndex;
scaleToDequantizeEdge.ToNodeInputIndex = 1;
inputEdges.push_back(scaleToDequantizeEdge);

operatorGraphDesc.inputEdgeCount = gsl::narrow_cast<uint32_t>(inputEdges.size());
operatorGraphDesc.inputEdges = inputEdges.data();
inputEdges[1] = scaleToDequantizeEdge;

operatorGraphDesc.inputEdgeCount = gsl::narrow_cast<uint32_t>(std::size(inputEdges));
operatorGraphDesc.inputEdges = std::data(inputEdges);

// intermediate edges
DML_INTERMEDIATE_GRAPH_EDGE_DESC fillValueToDequantizeEdge = {};
fillValueToDequantizeEdge.FromNodeIndex = 0;
fillValueToDequantizeEdge.FromNodeOutputIndex = 0;
fillValueToDequantizeEdge.ToNodeIndex = 1;
fillValueToDequantizeEdge.ToNodeInputIndex = zeroPointIndex; // 2
intermediateEdges.push_back(fillValueToDequantizeEdge);
fillValueToDequantizeEdge.ToNodeInputIndex = zeroPointIndex;
intermediateEdges[0] = fillValueToDequantizeEdge;

operatorGraphDesc.intermediateEdgeCount = gsl::narrow_cast<uint32_t>(std::size(intermediateEdges));
operatorGraphDesc.intermediateEdges = std::data(intermediateEdges);

operatorGraphDesc.intermediateEdgeCount = gsl::narrow_cast<uint32_t>(intermediateEdges.size());
operatorGraphDesc.intermediateEdges = intermediateEdges.data();

// output edges
DML_OUTPUT_GRAPH_EDGE_DESC outputEdge = {};
outputEdge.FromNodeIndex = 1;
outputEdge.FromNodeOutputIndex = 0;
outputEdge.GraphOutputIndex = 0;
outputEdges.push_back(outputEdge);
operatorGraphDesc.outputEdgeCount = gsl::narrow_cast<uint32_t>(outputEdges.size());
operatorGraphDesc.outputEdges = outputEdges.data();

outputEdges[0] = outputEdge;

operatorGraphDesc.outputEdgeCount = gsl::narrow_cast<uint32_t>(std::size(outputEdges));
operatorGraphDesc.outputEdges = std::data(outputEdges);

SetDmlOperatorGraphDesc(std::move(operatorGraphDesc), kernelInfo);
}
}
Expand Down

0 comments on commit c85678d

Please sign in to comment.