diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorElementWise.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorElementWise.cpp index 10c53e4c8fb0a..440d07a736763 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorElementWise.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorElementWise.cpp @@ -496,7 +496,7 @@ class DmlOperatorElementwisePow : public DmlOperator template class DmlOperatorElementwiseQLinear : public DmlOperator { - enum OnnxInputIndex : uint32_t + enum OnnxInputIndex : uint32_t { inputIndex, scaleIndex, @@ -507,8 +507,8 @@ class DmlOperatorElementwiseQLinear : public DmlOperator public: DmlOperatorElementwiseQLinear(const MLOperatorKernelCreationContext& kernelInfo) : DmlOperator(kernelInfo) { - - ML_CHECK_VALID_ARGUMENT(kernelInfo.GetInputCount() >= 2); + + ML_CHECK_VALID_ARGUMENT(kernelInfo.GetInputCount() >= 2); ML_CHECK_VALID_ARGUMENT(kernelInfo.GetOutputCount() == 1); Initialize(kernelInfo, std::nullopt, std::nullopt); @@ -532,8 +532,14 @@ class DmlOperatorElementwiseQLinear : public DmlOperator axis = Dml::HandleNegativeAxis(signedAxis, outputShapeDimCount, /*validateAxis*/ false); } + // Explicitly reshape each of the inputs after the first input (scale tensor and optional zero point tensor). for (uint32_t index = 1, inputCount = gsl::narrow_cast(m_inputTensorDescs.size()); index < inputCount; ++index) { + if(!kernelInfo.IsInputValid(index)) + { + continue; + } + auto edgeDesc = kernelInfo.GetInputEdgeDescription(index); assert(edgeDesc.edgeType == MLOperatorEdgeType::Tensor); @@ -557,7 +563,7 @@ class DmlOperatorElementwiseQLinear : public DmlOperator m_inputTensorDescs[index] = TensorDesc( edgeDesc.tensorDataType, - gsl::make_span(outputShape), + gsl::make_span(outputShape), gsl::make_span(inputTensorShape), TensorAxis::DoNotCoerce, TensorAxis::W, @@ -586,7 +592,7 @@ class DmlOperatorElementwiseQLinear : public DmlOperator auto inputSizes = m_inputTensorDescs[0].GetSizes(); TensorDesc intermediateOutputTensorDesc = TensorDesc( GetMlDataTypeFromDmlDataType(inputDataType), - inputSizes, + inputSizes, inputSizes, TensorAxis::DoNotCoerce, TensorAxis::W, @@ -595,7 +601,7 @@ class DmlOperatorElementwiseQLinear : public DmlOperator 0 // guaranteedBaseOffsetAlignment ); DML_TENSOR_DESC namedIntermediateOutputTensorDesc = intermediateOutputTensorDesc.GetDmlDesc(); - + // Create a tensor full of zeros DML_FILL_VALUE_CONSTANT_OPERATOR_DESC zerosDesc = {}; zerosDesc.ValueDataType = inputDataType; @@ -610,14 +616,14 @@ class DmlOperatorElementwiseQLinear : public DmlOperator const DML_OPERATOR_DESC opDesc2 = { ApiTraits::OperatorDescTraits::Type, &qLinearDesc}; // Construct the graph - std::vector inputEdges; - std::vector intermediateEdges; - std::vector outputEdges; + DML_INPUT_GRAPH_EDGE_DESC inputEdges[2]; + DML_INTERMEDIATE_GRAPH_EDGE_DESC intermediateEdges[1]; + DML_OUTPUT_GRAPH_EDGE_DESC outputEdges[1]; MLOperatorGraphDesc operatorGraphDesc = {}; operatorGraphDesc.nodeCount = 2; - std::vector opDescs{&opDesc1, &opDesc2}; - operatorGraphDesc.nodesAsOpDesc = opDescs.data(); + const DML_OPERATOR_DESC* opDescs[] = {&opDesc1, &opDesc2}; + operatorGraphDesc.nodesAsOpDesc = std::data(opDescs); const uint32_t fillValueNodeIndex = 0; const uint32_t dequantizeNodeIndex = 1; @@ -625,39 +631,40 @@ class DmlOperatorElementwiseQLinear : public DmlOperator // Input edges DML_INPUT_GRAPH_EDGE_DESC inputToDequantizeEdge = {}; inputToDequantizeEdge.GraphInputIndex = OnnxInputIndex::inputIndex; - inputToDequantizeEdge.ToNodeIndex = dequantizeNodeIndex; + inputToDequantizeEdge.ToNodeIndex = dequantizeNodeIndex; inputToDequantizeEdge.ToNodeInputIndex = 0; - inputEdges.push_back(inputToDequantizeEdge); + inputEdges[0] = inputToDequantizeEdge; DML_INPUT_GRAPH_EDGE_DESC scaleToDequantizeEdge = {}; - scaleToDequantizeEdge.GraphInputIndex = OnnxInputIndex::scaleIndex; // dmlWeightsIndex - scaleToDequantizeEdge.ToNodeIndex = dequantizeNodeIndex; //dequantizeinputindex + scaleToDequantizeEdge.GraphInputIndex = OnnxInputIndex::scaleIndex; + scaleToDequantizeEdge.ToNodeIndex = dequantizeNodeIndex; scaleToDequantizeEdge.ToNodeInputIndex = 1; - inputEdges.push_back(scaleToDequantizeEdge); - - operatorGraphDesc.inputEdgeCount = gsl::narrow_cast(inputEdges.size()); - operatorGraphDesc.inputEdges = inputEdges.data(); + inputEdges[1] = scaleToDequantizeEdge; + operatorGraphDesc.inputEdgeCount = gsl::narrow_cast(std::size(inputEdges)); + operatorGraphDesc.inputEdges = std::data(inputEdges); + // intermediate edges DML_INTERMEDIATE_GRAPH_EDGE_DESC fillValueToDequantizeEdge = {}; fillValueToDequantizeEdge.FromNodeIndex = 0; fillValueToDequantizeEdge.FromNodeOutputIndex = 0; fillValueToDequantizeEdge.ToNodeIndex = 1; - fillValueToDequantizeEdge.ToNodeInputIndex = zeroPointIndex; // 2 - intermediateEdges.push_back(fillValueToDequantizeEdge); + fillValueToDequantizeEdge.ToNodeInputIndex = zeroPointIndex; + intermediateEdges[0] = fillValueToDequantizeEdge; + + operatorGraphDesc.intermediateEdgeCount = gsl::narrow_cast(std::size(intermediateEdges)); + operatorGraphDesc.intermediateEdges = std::data(intermediateEdges); - operatorGraphDesc.intermediateEdgeCount = gsl::narrow_cast(intermediateEdges.size()); - operatorGraphDesc.intermediateEdges = intermediateEdges.data(); - // output edges DML_OUTPUT_GRAPH_EDGE_DESC outputEdge = {}; outputEdge.FromNodeIndex = 1; outputEdge.FromNodeOutputIndex = 0; outputEdge.GraphOutputIndex = 0; - outputEdges.push_back(outputEdge); - operatorGraphDesc.outputEdgeCount = gsl::narrow_cast(outputEdges.size()); - operatorGraphDesc.outputEdges = outputEdges.data(); - + outputEdges[0] = outputEdge; + + operatorGraphDesc.outputEdgeCount = gsl::narrow_cast(std::size(outputEdges)); + operatorGraphDesc.outputEdges = std::data(outputEdges); + SetDmlOperatorGraphDesc(std::move(operatorGraphDesc), kernelInfo); } }