From b876d39e55e72d0f45bd9a5ae206ffd151455216 Mon Sep 17 00:00:00 2001
From: Patrice Vignola
Date: Fri, 15 Dec 2023 22:47:24 -0800
Subject: [PATCH 1/2] [DML EP] Support split hidden size for RotaryEmbedding

---
 .../Operators/DmlOperatorRotaryEmbedding.cpp  | 34 ++++++++++++++-----
 1 file changed, 25 insertions(+), 9 deletions(-)

diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorRotaryEmbedding.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorRotaryEmbedding.cpp
index 30c339b845b36..f85949547ff96 100644
--- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorRotaryEmbedding.cpp
+++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorRotaryEmbedding.cpp
@@ -43,6 +43,10 @@ class DmlOperatorRotaryEmbedding : public DmlOperator
         ML_CHECK_VALID_ARGUMENT(kernelInfo.GetInputCount() == 4);
         ML_CHECK_VALID_ARGUMENT(kernelInfo.GetOutputCount() == 1);
 
+        // When the input is 4D, it has the shape [batchSize, numHeads, sequenceLength, headSize]. Otherwise,
+        // it has the shape [batchSize, sequenceLength, hiddenSize]
+        const bool inputIs4D = kernelInfo.GetInputTensorDimensionCount(inputDataIndex) == 4;
+
         // When positionIds is a scalar, it represents the start offset for each sequence
         const bool positionIdsIsOffset = kernelInfo.GetInputTensorDimensionCount(positionIdsIndex) == 1;
 
@@ -63,9 +67,9 @@ class DmlOperatorRotaryEmbedding : public DmlOperator
 
         // We resize the data to be of shape [batchSize, sequenceLength, numHeads, headSize]
         const auto inputDataSizes = m_inputTensorDescs[inputDataIndex].GetSizes();
-        const uint32_t batchSize = inputDataSizes[1];
+        const uint32_t batchSize = inputIs4D ? inputDataSizes[0] : inputDataSizes[1];
         const uint32_t sequenceLength = inputDataSizes[2];
-        const uint32_t numHeads = inputDataSizes[3] / headSize;
+        const uint32_t numHeads = inputIs4D ? inputDataSizes[1] : inputDataSizes[3] / headSize;
 
         const auto cosCacheSizes = m_inputTensorDescs[cosCacheIndex].GetSizes();
         const uint32_t maxSequenceLength = cosCacheSizes[cosCacheSizes.size() - 2];
@@ -80,8 +84,8 @@ class DmlOperatorRotaryEmbedding : public DmlOperator
         std::vector<DML_TENSOR_DESC> inputDescs = GetDmlInputDescs();
         const MLOperatorTensorDataType dataType = kernelInfo.GetInputEdgeDescription(inputDataIndex).tensorDataType;
 
-        // Splitting the hiddenSize into numHeads and headSize dimensions makes it easier for DML to handle
-        const std::array<uint32_t, 4> inputOutputShape = {batchSize, sequenceLength, numHeads, headSize};
+        // We can collapse this size into a 1D tensor since it's only used as the input/output of elementwise operations
+        const std::array<uint32_t, 1> inputOutputShape = {batchSize * sequenceLength * numHeads * headSize};
         TensorDesc inputOutputTensorDesc = TensorDesc::ConstructDefaultTensorDesc(dataType, inputOutputShape);
         const DML_TENSOR_DESC inputOutputDmlTensorDesc = inputOutputTensorDesc.GetDmlDesc();
 
@@ -104,8 +108,20 @@ class DmlOperatorRotaryEmbedding : public DmlOperator
             : std::vector<uint32_t>({batchSize, sequenceLength, numHeads, 1, headSize / 2});
 
         TensorDesc inputDataTensorDesc = TensorDesc::ConstructDefaultTensorDesc(dataType, inputDataTensorShape);
 
+        // We need to stride it if the input was 4D
+        if (inputIs4D)
+        {
+            const std::vector<uint32_t> inputDataStrides = interleaved
+                ? std::vector<uint32_t>({sequenceLength * numHeads * headSize, headSize, sequenceLength * headSize, 2, 1})
+                : std::vector<uint32_t>({sequenceLength * numHeads * headSize, headSize, sequenceLength * headSize, headSize / 2, 1});
+        }
+
         const DML_TENSOR_DESC inputDataDmlTensorDesc = inputDataTensorDesc.GetDmlDesc();
+
+        TensorDesc joinedDataTensorDesc = TensorDesc::ConstructDefaultTensorDesc(dataType, inputDataTensorShape);
+        const DML_TENSOR_DESC joinedDataDmlTensorDesc = joinedDataTensorDesc.GetDmlDesc();
+
         TensorDesc splitInputDataTensorDesc = TensorDesc::ConstructDefaultTensorDesc(dataType, splitInputDataTensorShape);
         const std::array<DML_TENSOR_DESC, 2> splitInputDataDmlTensorDescs = {splitInputDataTensorDesc.GetDmlDesc(), splitInputDataTensorDesc.GetDmlDesc()};
 
@@ -122,7 +138,7 @@ class DmlOperatorRotaryEmbedding : public DmlOperator
         // Swap the 2 halves and join them together
         DML_JOIN_OPERATOR_DESC joinInputDesc{};
         joinInputDesc.InputTensors = splitInputDataDmlTensorDescs.data();
-        joinInputDesc.OutputTensor = &inputDataDmlTensorDesc;
+        joinInputDesc.OutputTensor = &joinedDataDmlTensorDesc;
         joinInputDesc.Axis = splitInputDesc.Axis;
         joinInputDesc.InputCount = gsl::narrow_cast<uint32_t>(splitInputDataDmlTensorDescs.size());
         const DML_OPERATOR_DESC joinInputDmlDesc = {DML_OPERATOR_JOIN, &joinInputDesc};
@@ -212,16 +228,16 @@ class DmlOperatorRotaryEmbedding : public DmlOperator
         const DML_TENSOR_DESC broadcastedSignDmlTensorDesc = broadcastedSignCosSinTensorDesc.GetDmlDesc();
 
         DML_ELEMENT_WISE_MULTIPLY_OPERATOR_DESC mulSignDesc{};
-        mulSignDesc.ATensor = &inputDataDmlTensorDesc;
+        mulSignDesc.ATensor = &joinedDataDmlTensorDesc;
         mulSignDesc.BTensor = &broadcastedSignDmlTensorDesc;
-        mulSignDesc.OutputTensor = &inputDataDmlTensorDesc;
+        mulSignDesc.OutputTensor = &joinedDataDmlTensorDesc;
         const DML_OPERATOR_DESC mulSignDmlDesc = {DML_OPERATOR_ELEMENT_WISE_MULTIPLY, &mulSignDesc};
 
         // Multiply the non-rotated data with the cos and the rotated data with the sin
         DML_ELEMENT_WISE_MULTIPLY_OPERATOR_DESC mulCosSinDesc{};
-        mulCosSinDesc.ATensor = &inputDataDmlTensorDesc;
+        mulCosSinDesc.ATensor = &joinedDataDmlTensorDesc;
         mulCosSinDesc.BTensor = &broadcastedCosSinDmlTensorDesc;
-        mulCosSinDesc.OutputTensor = &inputDataDmlTensorDesc;
+        mulCosSinDesc.OutputTensor = &joinedDataDmlTensorDesc;
         const DML_OPERATOR_DESC mulCosSinDmlDesc = {DML_OPERATOR_ELEMENT_WISE_MULTIPLY, &mulCosSinDesc};
 
         // Add the multiplied cos and sin values together

From 51716affb3bb6ef1ff61aefb5ec3d09eb4cda36a Mon Sep 17 00:00:00 2001
From: Patrice Vignola
Date: Thu, 18 Jan 2024 10:44:37 -0800
Subject: [PATCH 2/2] Add stride support

---
 .../Operators/DmlOperatorRotaryEmbedding.cpp  | 24 +++++++++++++----------
 1 file changed, 12 insertions(+), 12 deletions(-)

diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorRotaryEmbedding.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorRotaryEmbedding.cpp
index f85949547ff96..44004b5d77f70 100644
--- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorRotaryEmbedding.cpp
+++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorRotaryEmbedding.cpp
@@ -84,16 +84,24 @@ class DmlOperatorRotaryEmbedding : public DmlOperator
         std::vector<DML_TENSOR_DESC> inputDescs = GetDmlInputDescs();
         const MLOperatorTensorDataType dataType = kernelInfo.GetInputEdgeDescription(inputDataIndex).tensorDataType;
 
-        // We can collapse this size into a 1D tensor since it's only used as the input/output of elementwise operations
-        const std::array<uint32_t, 1> inputOutputShape = {batchSize * sequenceLength * numHeads * headSize};
+        const std::array<uint32_t, 4> inputOutputShape = {batchSize, sequenceLength, numHeads, headSize};
         TensorDesc inputOutputTensorDesc = TensorDesc::ConstructDefaultTensorDesc(dataType, inputOutputShape);
+        TensorDesc stridedInputOutputTensorDesc = TensorDesc::ConstructDefaultTensorDesc(dataType, inputOutputShape);
+
+        if (inputIs4D)
+        {
+            const std::array<uint32_t, 4> inputOutputStrides = {headSize * numHeads * sequenceLength, headSize, sequenceLength * headSize, 1};
+            stridedInputOutputTensorDesc.SetStrides(inputOutputStrides);
+        }
+
         const DML_TENSOR_DESC inputOutputDmlTensorDesc = inputOutputTensorDesc.GetDmlDesc();
+        const DML_TENSOR_DESC stridedInputOutputDmlTensorDesc = stridedInputOutputTensorDesc.GetDmlDesc();
 
         // Copy the input to preserve its real input shape in the graph without reshaping it. This will disappear during DML's graph compilation phase.
         DML_SCALE_BIAS scaleBias = {1.0f, 0.0f};
         DML_ELEMENT_WISE_IDENTITY_OPERATOR_DESC copyInputDesc{};
-        copyInputDesc.InputTensor = &inputOutputDmlTensorDesc;
+        copyInputDesc.InputTensor = &stridedInputOutputDmlTensorDesc;
         copyInputDesc.OutputTensor = &inputOutputDmlTensorDesc;
         copyInputDesc.ScaleBias = &scaleBias;
         const DML_OPERATOR_DESC copyInputDmlDesc = {DML_OPERATOR_ELEMENT_WISE_IDENTITY, &copyInputDesc};
 
@@ -109,14 +117,6 @@ class DmlOperatorRotaryEmbedding : public DmlOperator
 
         TensorDesc inputDataTensorDesc = TensorDesc::ConstructDefaultTensorDesc(dataType, inputDataTensorShape);
 
-        // We need to stride it if the input was 4D
-        if (inputIs4D)
-        {
-            const std::vector<uint32_t> inputDataStrides = interleaved
-                ? std::vector<uint32_t>({sequenceLength * numHeads * headSize, headSize, sequenceLength * headSize, 2, 1})
-                : std::vector<uint32_t>({sequenceLength * numHeads * headSize, headSize, sequenceLength * headSize, headSize / 2, 1});
-        }
-
         const DML_TENSOR_DESC inputDataDmlTensorDesc = inputDataTensorDesc.GetDmlDesc();
 
         TensorDesc joinedDataTensorDesc = TensorDesc::ConstructDefaultTensorDesc(dataType, inputDataTensorShape);
@@ -244,7 +244,7 @@ class DmlOperatorRotaryEmbedding : public DmlOperator
         DML_ELEMENT_WISE_ADD_OPERATOR_DESC addDesc{};
         addDesc.ATensor = &inputOutputDmlTensorDesc;
         addDesc.BTensor = &inputOutputDmlTensorDesc;
-        addDesc.OutputTensor = &inputOutputDmlTensorDesc;
+        addDesc.OutputTensor = &stridedInputOutputDmlTensorDesc;
         const DML_OPERATOR_DESC addDmlDesc = {DML_OPERATOR_ELEMENT_WISE_ADD, &addDesc};
 
         // Construct the graph
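
Note on the approach in PATCH 2/2: instead of physically transposing a 4D input of shape
[batchSize, numHeads, sequenceLength, headSize], the patch builds a strided TensorDesc whose
logical shape is [batchSize, sequenceLength, numHeads, headSize] but whose strides still address
the original buffer. The stand-alone sketch below is illustrative only and not part of the
patches; the sizes and variable names are made up for the example. It checks that the strides
used in the patch map a logical (b, s, n, h) coordinate to the same linear offset as the physical
[batchSize, numHeads, sequenceLength, headSize] layout, which is why no copy or transpose is
needed before the elementwise operators.

    // Illustrative sketch only: verifies the stride mapping assumed by PATCH 2/2.
    #include <array>
    #include <cstdint>
    #include <cstdio>

    int main()
    {
        // Example sizes (not taken from the patch)
        const uint32_t batchSize = 2, numHeads = 4, sequenceLength = 3, headSize = 8;

        // Strides from the patch for the logical view [batchSize, sequenceLength, numHeads, headSize]
        const std::array<uint32_t, 4> strides = {headSize * numHeads * sequenceLength, headSize, sequenceLength * headSize, 1};

        // An arbitrary element (b, s, n, h) in the logical view...
        const uint32_t b = 1, s = 2, n = 3, h = 5;
        const uint32_t logicalOffset = b * strides[0] + s * strides[1] + n * strides[2] + h * strides[3];

        // ...and the offset of the same element in the physical [batchSize, numHeads, sequenceLength, headSize] buffer
        const uint32_t physicalOffset = ((b * numHeads + n) * sequenceLength + s) * headSize + h;

        // Both offsets are 189 here, so the strided view reads the 4D input in place.
        std::printf("logical=%u physical=%u\n", static_cast<unsigned>(logicalOffset), static_cast<unsigned>(physicalOffset));
        return 0;
    }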