From ffc170232b27c6be6a9d9d232aba5c146a28ef6e Mon Sep 17 00:00:00 2001 From: Patrice Vignola Date: Sun, 29 Oct 2023 00:37:19 -0700 Subject: [PATCH 1/8] Add RotaryEmbedding (WIP) --- .../Operators/DmlOperatorRotaryEmbedding.cpp | 312 ++++++++++++++++++ .../src/Operators/OperatorRegistration.cpp | 4 + .../dml/OperatorAuthorHelper/OperatorHelper.h | 1 + .../OperatorAuthorHelper/OperatorVersions.h | 1 + 4 files changed, 318 insertions(+) create mode 100644 onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorRotaryEmbedding.cpp diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorRotaryEmbedding.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorRotaryEmbedding.cpp new file mode 100644 index 0000000000000..0a8a1cf85e87c --- /dev/null +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorRotaryEmbedding.cpp @@ -0,0 +1,312 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+ +#include "precomp.h" + +namespace Dml +{ +class DmlOperatorRotaryEmbedding : public DmlOperator +{ +public: + DmlOperatorRotaryEmbedding(const MLOperatorKernelCreationContext& kernelInfo) : DmlOperator(kernelInfo) + { + enum InputIndex : uint32_t + { + inputDataIndex, + positionIdsIndex, + cosCacheIndex, + sinCacheIndex, + }; + + ML_CHECK_VALID_ARGUMENT(kernelInfo.GetInputCount() == 4); + ML_CHECK_VALID_ARGUMENT(kernelInfo.GetOutputCount() == 1); + + Initialize(kernelInfo); + + ComPtr contextPrivate; + ORT_THROW_IF_FAILED(kernelInfo.GetInterface()->QueryInterface(contextPrivate.GetAddressOf())); + + ML_CHECK_VALID_ARGUMENT(m_inputTensorDescs[inputDataIndex].GetDimensionCount() == 4); + ML_CHECK_VALID_ARGUMENT(m_inputTensorDescs[positionIdsIndex].GetDimensionCount() == 4); + ML_CHECK_VALID_ARGUMENT(m_inputTensorDescs[cosCacheIndex].GetDimensionCount() == 4); + ML_CHECK_VALID_ARGUMENT(m_inputTensorDescs[sinCacheIndex].GetDimensionCount() == 4); + + ML_CHECK_VALID_ARGUMENT(m_outputTensorDescs[0].GetDimensionCount() == 4); + + std::vector inputDescs = GetDmlInputDescs(); + std::vector outputDescs = GetDmlOutputDescs(); + + ML_CHECK_VALID_ARGUMENT(m_inputTensorDescs[cosCacheIndex].GetSizes() == m_inputTensorDescs[sinCacheIndex].GetSizes()); + const uint32_t headSize = m_inputTensorDescs[cosCacheIndex].GetSizes().back() * 2; + + // The last dimension of the data is the hidden size, so it must be divisible by the head size + ML_CHECK_VALID_ARGUMENT(m_inputTensorDescs[inputDataIndex].GetSizes().back() % headSize == 0); + + // We resize the data to be of shape [batchSize, sequenceLength, numHeads, headSize] + const auto inputDataSizes = m_inputTensorDescs[inputDataIndex].GetSizes(); + const uint32_t batchSize = inputDataSizes[1]; + const uint32_t sequenceLength = inputDataSizes[2]; + const uint32_t numHeads = inputDataSizes[3] / headSize; + + // Split the input data into 2 equal parts + const MLOperatorTensorDataType dataType = 
kernelInfo.GetInputEdgeDescription(inputDataIndex).tensorDataType; + const std::array inputDataTensorShape {batchSize, sequenceLength, numHeads, headSize}; + TensorDesc inputDataTensorDesc = TensorDesc::ConstructDefaultTensorDesc(dataType, inputDataTensorShape); + const DML_TENSOR_DESC inputDataDmlTensorDesc = inputDataTensorDesc.GetDmlDesc(); + + const std::array splitInputDataTensorShape {batchSize, sequenceLength, numHeads, headSize / 2}; + TensorDesc splitInputDataTensorDesc = TensorDesc::ConstructDefaultTensorDesc(dataType, splitInputDataTensorShape); + const std::array splitInputDataDmlTensorDescs = {splitInputDataTensorDesc.GetDmlDesc(), splitInputDataTensorDesc.GetDmlDesc()}; + + DML_SPLIT_OPERATOR_DESC splitInputDesc{}; + splitInputDesc.InputTensor = &inputDataDmlTensorDesc; + splitInputDesc.OutputTensors = splitInputDataDmlTensorDescs.data(); + splitInputDesc.OutputCount = gsl::narrow_cast(splitInputDataDmlTensorDescs.size()); + splitInputDesc.Axis = gsl::narrow_cast(splitInputDataTensorShape.size()) - 1; + const DML_OPERATOR_DESC splitInputDmlDesc = {DML_OPERATOR_SPLIT, &splitInputDesc}; + + // Gather the cos/sin values based on the position ids + const std::array gatheredCosSinShape {1, 1, sequenceLength, headSize / 2}; + TensorDesc gatheredCosSinTensorDesc = TensorDesc::ConstructDefaultTensorDesc(dataType, gatheredCosSinShape); + const DML_TENSOR_DESC gatheredCosSinDmlTensorDesc = gatheredCosSinTensorDesc.GetDmlDesc(); + + DML_GATHER_OPERATOR_DESC gatherCosSinDesc{}; + gatherCosSinDesc.InputTensor = &inputDescs[cosCacheIndex]; + gatherCosSinDesc.IndicesTensor = &inputDescs[positionIdsIndex]; + gatherCosSinDesc.OutputTensor = &gatheredCosSinDmlTensorDesc; + gatherCosSinDesc.Axis = 2; + gatherCosSinDesc.IndexDimensions = 2; + const DML_OPERATOR_DESC gatherCosSinDmlDesc {DML_OPERATOR_GATHER, &gatherCosSinDesc}; + + // After gathering cos/sin, reshape and broadcast them to match the number of heads of the half input data + const std::array 
reshapedCosSinShape {1, sequenceLength, 1, headSize / 2}; + const std::array broadcastedCosSinShape {batchSize, sequenceLength, numHeads, headSize / 2}; + TensorDesc broadcastedCosSinTensorDesc = TensorDesc::ConstructBroadcastedTensorDesc(dataType, broadcastedCosSinShape, reshapedCosSinShape); + const DML_TENSOR_DESC broadcastedCosSinDmlTensorDesc = broadcastedCosSinTensorDesc.GetDmlDesc(); + + // Multiply the first half of the data with the cos and the second half of the negated data with the sin + DML_ELEMENT_WISE_MULTIPLY_OPERATOR_DESC mulHalfDataDesc{}; + mulHalfDataDesc.ATensor = &splitInputDataDmlTensorDescs.front(); + mulHalfDataDesc.BTensor = &broadcastedCosSinDmlTensorDesc; + mulHalfDataDesc.OutputTensor = &splitInputDataDmlTensorDescs.front(); + const DML_OPERATOR_DESC mulHalfDataDmlDesc {DML_OPERATOR_ELEMENT_WISE_MULTIPLY, &mulHalfDataDesc}; + + // Negate the second half of the data + DML_ELEMENT_WISE_NEGATE_OPERATOR_DESC negateHalfDataDesc{}; + negateHalfDataDesc.InputTensor = &splitInputDataDmlTensorDescs.front(); + negateHalfDataDesc.OutputTensor = &splitInputDataDmlTensorDescs.front(); + const DML_OPERATOR_DESC negateHalfDataDmlDesc {DML_OPERATOR_ELEMENT_WISE_NEGATE, &negateHalfDataDesc}; + + // Add the multiplied 2 halves together + DML_ELEMENT_WISE_ADD_OPERATOR_DESC addHalfDataDesc{}; + addHalfDataDesc.ATensor = &splitInputDataDmlTensorDescs.front(); + addHalfDataDesc.BTensor = &splitInputDataDmlTensorDescs.front(); + addHalfDataDesc.OutputTensor = &splitInputDataDmlTensorDescs.front(); + const DML_OPERATOR_DESC addHalfDataDmlDesc {DML_OPERATOR_ELEMENT_WISE_ADD, &addHalfDataDesc}; + + // Join the 2 halves together + DML_JOIN_OPERATOR_DESC joinHalfDataDesc{}; + joinHalfDataDesc.InputTensors = splitInputDataDmlTensorDescs.data(); + joinHalfDataDesc.OutputTensor = &inputDataDmlTensorDesc; + joinHalfDataDesc.Axis = gsl::narrow_cast(splitInputDataTensorShape.size()) - 1; + joinHalfDataDesc.InputCount = 
gsl::narrow_cast(splitInputDataDmlTensorDescs.size()); + const DML_OPERATOR_DESC joinHalfDataDmlDesc {DML_OPERATOR_JOIN, &joinHalfDataDesc}; + + // Construct the graph + std::vector inputEdges; + std::vector intermediateEdges; + std::vector outputEdges; + + std::array opDescs = { + &splitInputDmlDesc, // Split the input data + &gatherCosSinDmlDesc, // Gather cos + &gatherCosSinDmlDesc, // Gather sin + + &mulHalfDataDmlDesc, // Multiply cos with the first half of the input + &negateHalfDataDmlDesc, // Negate the second half of the input + &mulHalfDataDmlDesc, // Multiply sin with the negated second of half of the input + &addHalfDataDmlDesc, // Add the 2 halves together + + &mulHalfDataDmlDesc, // Multiply cos with the second half of the input + &mulHalfDataDmlDesc, // Multiply sin with the first half of the input + &addHalfDataDmlDesc, // Add the 2 halves together + + &joinHalfDataDmlDesc, // Join the halves together + }; + + enum NodeIndex : uint32_t + { + splitInputOpIndex, + gatherCosOpIndex, + gatherSinOpIndex, + + mulCosFirstHalfOpIndex, + negateSecondHalfOpIndex, + mulSinNegatedSecondHalfOpIndex, + addFirstHalfOpIndex, + + mulCosSecondHalfOpIndex, + mulSinFirstHalfOpIndex, + addSecondHalfOpIndex, + + joinOpIndex, + }; + + DML_INPUT_GRAPH_EDGE_DESC inputToSplitEdge = {}; + inputToSplitEdge.GraphInputIndex = inputDataIndex; + inputToSplitEdge.ToNodeIndex = splitInputOpIndex; + inputToSplitEdge.ToNodeInputIndex = 0; + inputEdges.push_back(inputToSplitEdge); + + DML_INPUT_GRAPH_EDGE_DESC cosToGatherEdge = {}; + cosToGatherEdge.GraphInputIndex = cosCacheIndex; + cosToGatherEdge.ToNodeIndex = gatherCosOpIndex; + cosToGatherEdge.ToNodeInputIndex = 0; + inputEdges.push_back(cosToGatherEdge); + + DML_INPUT_GRAPH_EDGE_DESC positionIdsToGatherCosEdge = {}; + positionIdsToGatherCosEdge.GraphInputIndex = positionIdsIndex; + positionIdsToGatherCosEdge.ToNodeIndex = gatherCosOpIndex; + positionIdsToGatherCosEdge.ToNodeInputIndex = 1; + 
inputEdges.push_back(positionIdsToGatherCosEdge); + + DML_INPUT_GRAPH_EDGE_DESC sinToGatherEdge = {}; + sinToGatherEdge.GraphInputIndex = sinCacheIndex; + sinToGatherEdge.ToNodeIndex = gatherSinOpIndex; + sinToGatherEdge.ToNodeInputIndex = 0; + inputEdges.push_back(sinToGatherEdge); + + DML_INPUT_GRAPH_EDGE_DESC positionIdsToGatherSinEdge = {}; + positionIdsToGatherSinEdge.GraphInputIndex = positionIdsIndex; + positionIdsToGatherSinEdge.ToNodeIndex = gatherSinOpIndex; + positionIdsToGatherSinEdge.ToNodeInputIndex = 1; + inputEdges.push_back(positionIdsToGatherSinEdge); + + DML_INTERMEDIATE_GRAPH_EDGE_DESC firstHalfDataToMulCosEdge = {}; + firstHalfDataToMulCosEdge.FromNodeIndex = splitInputOpIndex; + firstHalfDataToMulCosEdge.FromNodeOutputIndex = 0; + firstHalfDataToMulCosEdge.ToNodeIndex = mulCosFirstHalfOpIndex; + firstHalfDataToMulCosEdge.ToNodeInputIndex = 0; + intermediateEdges.push_back(firstHalfDataToMulCosEdge); + + DML_INTERMEDIATE_GRAPH_EDGE_DESC gatheredCosToMulFirstHalfDataEdge = {}; + gatheredCosToMulFirstHalfDataEdge.FromNodeIndex = gatherCosOpIndex; + gatheredCosToMulFirstHalfDataEdge.FromNodeOutputIndex = 0; + gatheredCosToMulFirstHalfDataEdge.ToNodeIndex = mulCosFirstHalfOpIndex; + gatheredCosToMulFirstHalfDataEdge.ToNodeInputIndex = 1; + intermediateEdges.push_back(gatheredCosToMulFirstHalfDataEdge); + + DML_INTERMEDIATE_GRAPH_EDGE_DESC secondHalfDataToNegateEdge = {}; + secondHalfDataToNegateEdge.FromNodeIndex = splitInputOpIndex; + secondHalfDataToNegateEdge.FromNodeOutputIndex = 1; + secondHalfDataToNegateEdge.ToNodeIndex = negateSecondHalfOpIndex; + secondHalfDataToNegateEdge.ToNodeInputIndex = 0; + intermediateEdges.push_back(secondHalfDataToNegateEdge); + + DML_INTERMEDIATE_GRAPH_EDGE_DESC negatedSecondHalfDataToMulSinEdge = {}; + negatedSecondHalfDataToMulSinEdge.FromNodeIndex = negateSecondHalfOpIndex; + negatedSecondHalfDataToMulSinEdge.FromNodeOutputIndex = 0; + negatedSecondHalfDataToMulSinEdge.ToNodeIndex = 
mulSinNegatedSecondHalfOpIndex; + negatedSecondHalfDataToMulSinEdge.ToNodeInputIndex = 0; + intermediateEdges.push_back(negatedSecondHalfDataToMulSinEdge); + + DML_INTERMEDIATE_GRAPH_EDGE_DESC gatheredSinToMulNegatedSecondHalfDataEdge = {}; + gatheredSinToMulNegatedSecondHalfDataEdge.FromNodeIndex = gatherSinOpIndex; + gatheredSinToMulNegatedSecondHalfDataEdge.FromNodeOutputIndex = 0; + gatheredSinToMulNegatedSecondHalfDataEdge.ToNodeIndex = mulSinNegatedSecondHalfOpIndex; + gatheredSinToMulNegatedSecondHalfDataEdge.ToNodeInputIndex = 1; + intermediateEdges.push_back(gatheredSinToMulNegatedSecondHalfDataEdge); + + DML_INTERMEDIATE_GRAPH_EDGE_DESC firstHalfCosMulToAddEdge = {}; + firstHalfCosMulToAddEdge.FromNodeIndex = mulCosFirstHalfOpIndex; + firstHalfCosMulToAddEdge.FromNodeOutputIndex = 0; + firstHalfCosMulToAddEdge.ToNodeIndex = addFirstHalfOpIndex; + firstHalfCosMulToAddEdge.ToNodeInputIndex = 0; + intermediateEdges.push_back(firstHalfCosMulToAddEdge); + + DML_INTERMEDIATE_GRAPH_EDGE_DESC secondHalfSinMulToAddEdge = {}; + secondHalfSinMulToAddEdge.FromNodeIndex = mulSinNegatedSecondHalfOpIndex; + secondHalfSinMulToAddEdge.FromNodeOutputIndex = 0; + secondHalfSinMulToAddEdge.ToNodeIndex = addFirstHalfOpIndex; + secondHalfSinMulToAddEdge.ToNodeInputIndex = 1; + intermediateEdges.push_back(secondHalfSinMulToAddEdge); + + DML_INTERMEDIATE_GRAPH_EDGE_DESC secondHalfDataToMulCosEdge = {}; + secondHalfDataToMulCosEdge.FromNodeIndex = splitInputOpIndex; + secondHalfDataToMulCosEdge.FromNodeOutputIndex = 1; + secondHalfDataToMulCosEdge.ToNodeIndex = mulCosSecondHalfOpIndex; + secondHalfDataToMulCosEdge.ToNodeInputIndex = 0; + intermediateEdges.push_back(secondHalfDataToMulCosEdge); + + DML_INTERMEDIATE_GRAPH_EDGE_DESC gatheredCosToMulSecondHalfDataEdge = {}; + gatheredCosToMulSecondHalfDataEdge.FromNodeIndex = gatherCosOpIndex; + gatheredCosToMulSecondHalfDataEdge.FromNodeOutputIndex = 0; + gatheredCosToMulSecondHalfDataEdge.ToNodeIndex = mulCosSecondHalfOpIndex; + 
gatheredCosToMulSecondHalfDataEdge.ToNodeInputIndex = 1; + intermediateEdges.push_back(gatheredCosToMulSecondHalfDataEdge); + + DML_INTERMEDIATE_GRAPH_EDGE_DESC firstHalfDataToMulSinEdge = {}; + firstHalfDataToMulSinEdge.FromNodeIndex = splitInputOpIndex; + firstHalfDataToMulSinEdge.FromNodeOutputIndex = 0; + firstHalfDataToMulSinEdge.ToNodeIndex = mulSinFirstHalfOpIndex; + firstHalfDataToMulSinEdge.ToNodeInputIndex = 0; + intermediateEdges.push_back(firstHalfDataToMulSinEdge); + + DML_INTERMEDIATE_GRAPH_EDGE_DESC gatheredSinToMulFirstHalfDataEdge = {}; + gatheredSinToMulFirstHalfDataEdge.FromNodeIndex = gatherSinOpIndex; + gatheredSinToMulFirstHalfDataEdge.FromNodeOutputIndex = 0; + gatheredSinToMulFirstHalfDataEdge.ToNodeIndex = mulSinFirstHalfOpIndex; + gatheredSinToMulFirstHalfDataEdge.ToNodeInputIndex = 1; + intermediateEdges.push_back(gatheredSinToMulFirstHalfDataEdge); + + DML_INTERMEDIATE_GRAPH_EDGE_DESC secondHalfCosMulToAddEdge = {}; + secondHalfCosMulToAddEdge.FromNodeIndex = mulCosSecondHalfOpIndex; + secondHalfCosMulToAddEdge.FromNodeOutputIndex = 0; + secondHalfCosMulToAddEdge.ToNodeIndex = addSecondHalfOpIndex; + secondHalfCosMulToAddEdge.ToNodeInputIndex = 0; + intermediateEdges.push_back(secondHalfCosMulToAddEdge); + + DML_INTERMEDIATE_GRAPH_EDGE_DESC firstHalfSinMulToAddEdge = {}; + firstHalfSinMulToAddEdge.FromNodeIndex = mulSinFirstHalfOpIndex; + firstHalfSinMulToAddEdge.FromNodeOutputIndex = 0; + firstHalfSinMulToAddEdge.ToNodeIndex = addSecondHalfOpIndex; + firstHalfSinMulToAddEdge.ToNodeInputIndex = 1; + intermediateEdges.push_back(firstHalfSinMulToAddEdge); + + DML_INTERMEDIATE_GRAPH_EDGE_DESC firstAddToJoinEdge = {}; + firstAddToJoinEdge.FromNodeIndex = addFirstHalfOpIndex; + firstAddToJoinEdge.FromNodeOutputIndex = 0; + firstAddToJoinEdge.ToNodeIndex = joinOpIndex; + firstAddToJoinEdge.ToNodeInputIndex = 0; + intermediateEdges.push_back(firstAddToJoinEdge); + + DML_INTERMEDIATE_GRAPH_EDGE_DESC secondAddToJoinEdge = {}; + 
secondAddToJoinEdge.FromNodeIndex = addSecondHalfOpIndex; + secondAddToJoinEdge.FromNodeOutputIndex = 0; + secondAddToJoinEdge.ToNodeIndex = joinOpIndex; + secondAddToJoinEdge.ToNodeInputIndex = 1; + intermediateEdges.push_back(secondAddToJoinEdge); + + DML_OUTPUT_GRAPH_EDGE_DESC joinToOutputEdge = {}; + joinToOutputEdge.FromNodeIndex = joinOpIndex; + joinToOutputEdge.FromNodeOutputIndex = 0; + joinToOutputEdge.GraphOutputIndex = 0; + outputEdges.push_back(joinToOutputEdge); + + MLOperatorGraphDesc operatorGraphDesc = {}; + operatorGraphDesc.inputEdgeCount = gsl::narrow_cast(inputEdges.size()); + operatorGraphDesc.inputEdges = inputEdges.data(); + operatorGraphDesc.intermediateEdgeCount = gsl::narrow_cast(intermediateEdges.size()); + operatorGraphDesc.intermediateEdges = intermediateEdges.data(); + operatorGraphDesc.outputEdgeCount = gsl::narrow_cast(outputEdges.size()); + operatorGraphDesc.outputEdges = outputEdges.data(); + operatorGraphDesc.nodeCount = gsl::narrow_cast(opDescs.size()); + operatorGraphDesc.nodesAsOpDesc = opDescs.data(); + + SetDmlOperatorGraphDesc(std::move(operatorGraphDesc), kernelInfo); + } +}; + +DML_OP_DEFINE_CREATION_FUNCTION(RotaryEmbedding, DmlOperatorRotaryEmbedding); + +} // namespace Dml diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorRegistration.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorRegistration.cpp index 30bc6e5e275a0..28360f09bcba3 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorRegistration.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorRegistration.cpp @@ -510,6 +510,7 @@ DML_OP_EXTERN_CREATION_FUNCTION(BitwiseAnd); DML_OP_EXTERN_CREATION_FUNCTION(BitwiseOr); DML_OP_EXTERN_CREATION_FUNCTION(BitwiseXor); DML_OP_EXTERN_CREATION_FUNCTION(BitwiseNot); +DML_OP_EXTERN_CREATION_FUNCTION(RotaryEmbedding); DML_OP_EXTERN_QUERY_FUNCTION(MaxPool); DML_OP_EXTERN_QUERY_FUNCTION(Slice); @@ 
-527,6 +528,7 @@ DML_OP_EXTERN_QUERY_FUNCTION(Attention); constexpr static std::array typeNameListDefault = {"T"}; constexpr static std::array typeNameListDefaultV = {"V"}; constexpr static std::array typeNameListAttention = {"T", "M"}; +constexpr static std::array typeNameListRotaryEmbedding = {"T", "M"}; constexpr static std::array typeNameListTwo = { "T1", "T2" }; constexpr static std::array typeNameListLayerNorm = { "T", "U" }; constexpr static std::array typeNameListLayerNormContrib = { "T", "V" }; @@ -597,6 +599,7 @@ constexpr static std::array supportedTypeListShape constexpr static std::array supportedTypeListSize = {SupportedTensorDataTypes::All, SupportedTensorDataTypes::Int64}; constexpr static std::array supportedTypeListQLinearSigmoid = {SupportedTensorDataTypes::UInt8 | SupportedTensorDataTypes::Int8}; constexpr static std::array supportedTypeListAttention = {SupportedTensorDataTypes::Float16to32, SupportedTensorDataTypes::Int32}; +constexpr static std::array supportedTypeListRotaryEmbedding = {SupportedTensorDataTypes::Float16to32, SupportedTensorDataTypes::Int64}; constexpr static std::array supportedTypeListGroupNorm = {SupportedTensorDataTypes::Float16to32, SupportedTensorDataTypes::Float16to32}; constexpr static std::array supportedTypeListNonZero = {SupportedTensorDataTypes::Float16to32 | SupportedTensorDataTypes::Ints8Bit | SupportedTensorDataTypes::Ints16Bit | SupportedTensorDataTypes::Ints32Bit | SupportedTensorDataTypes::Bool}; @@ -1006,6 +1009,7 @@ constexpr static OperatorRegistrationInformation operatorRegistrationInformation {REG_INFO_MS( 1, QLinearSigmoid, typeNameListDefault, supportedTypeListQLinearSigmoid, DmlGraphSupport::Supported, requiredConstantCpuInputs(), std::nullopt, QueryQLinearSigmoid)}, {REG_INFO_MS( 1, Attention, typeNameListAttention, supportedTypeListAttention, DmlGraphSupport::Supported, requiredConstantCpuInputs(), std::nullopt, QueryAttention)}, {REG_INFO_MS( 1, MultiHeadAttention, typeNameListAttention, 
supportedTypeListAttention, DmlGraphSupport::Supported)}, + {REG_INFO_MS( 1, RotaryEmbedding, typeNameListRotaryEmbedding, supportedTypeListRotaryEmbedding, DmlGraphSupport::Supported)}, {REG_INFO( 10, IsInf, typeNameListTwo, supportedTypeListIsInf, DmlGraphSupport::Supported)}, {REG_INFO( 10, Mod, typeNameListDefault, supportedTypeListNumericDefault, DmlGraphSupport::Supported)}, diff --git a/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.h b/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.h index 485e20c1dfe1e..f7e545d9d99a9 100644 --- a/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.h +++ b/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.h @@ -1584,6 +1584,7 @@ using ShapeInferenceHelper_DequantizeLinear = GetOutputShapeAsInputShapeHelper; using ShapeInferenceHelper_QLinearSigmoid = GetOutputShapeAsInputShapeHelper; using ShapeInferenceHelper_Attention = AttentionHelper; using ShapeInferenceHelper_MultiHeadAttention = MultiHeadAttentionHelper; +using ShapeInferenceHelper_RotaryEmbedding = GetOutputShapeAsInputShapeHelper; using ShapeInferenceHelper_Sign = GetBroadcastedOutputShapeHelper; using ShapeInferenceHelper_IsNaN = GetBroadcastedOutputShapeHelper; using ShapeInferenceHelper_Erf = GetBroadcastedOutputShapeHelper; diff --git a/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorVersions.h b/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorVersions.h index c1e525400be1a..e18ba31def48a 100644 --- a/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorVersions.h +++ b/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorVersions.h @@ -437,6 +437,7 @@ namespace OperatorHelper static const int sc_sinceVer_BiasAdd = 1; static const int sc_sinceVer_QuickGelu = 1; static const int sc_sinceVer_GroupNorm = 1; + static const int sc_sinceVer_RotaryEmbedding = 1; } // namespace MsftOperatorSet1 } // namespace OperatorHelper From 
20ba78d7a30eb70ef5556e63ec1d1fd5d2c54727 Mon Sep 17 00:00:00 2001 From: Patrice Vignola Date: Sun, 29 Oct 2023 22:23:43 -0700 Subject: [PATCH 2/8] Fix non-interleaved implementation --- .../Operators/DmlOperatorRotaryEmbedding.cpp | 121 ++++++++++++++---- .../contrib_ops/rotary_embedding_op_test.cc | 21 ++- 2 files changed, 113 insertions(+), 29 deletions(-) diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorRotaryEmbedding.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorRotaryEmbedding.cpp index 0a8a1cf85e87c..6539cfbbb279d 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorRotaryEmbedding.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorRotaryEmbedding.cpp @@ -21,6 +21,9 @@ class DmlOperatorRotaryEmbedding : public DmlOperator ML_CHECK_VALID_ARGUMENT(kernelInfo.GetInputCount() == 4); ML_CHECK_VALID_ARGUMENT(kernelInfo.GetOutputCount() == 1); + // When positionIds is a scalar, it represents the start offset for each sequence + const bool positionIdsIsOffset = kernelInfo.GetInputTensorDimensionCount(positionIdsIndex) == 1; + Initialize(kernelInfo); ComPtr contextPrivate; @@ -50,11 +53,11 @@ class DmlOperatorRotaryEmbedding : public DmlOperator // Split the input data into 2 equal parts const MLOperatorTensorDataType dataType = kernelInfo.GetInputEdgeDescription(inputDataIndex).tensorDataType; - const std::array inputDataTensorShape {batchSize, sequenceLength, numHeads, headSize}; + const std::array inputDataTensorShape = {batchSize, sequenceLength, numHeads, headSize}; TensorDesc inputDataTensorDesc = TensorDesc::ConstructDefaultTensorDesc(dataType, inputDataTensorShape); const DML_TENSOR_DESC inputDataDmlTensorDesc = inputDataTensorDesc.GetDmlDesc(); - const std::array splitInputDataTensorShape {batchSize, sequenceLength, numHeads, headSize / 2}; + const std::array splitInputDataTensorShape = {batchSize, 
sequenceLength, numHeads, headSize / 2}; TensorDesc splitInputDataTensorDesc = TensorDesc::ConstructDefaultTensorDesc(dataType, splitInputDataTensorShape); const std::array splitInputDataDmlTensorDescs = {splitInputDataTensorDesc.GetDmlDesc(), splitInputDataTensorDesc.GetDmlDesc()}; @@ -65,22 +68,55 @@ class DmlOperatorRotaryEmbedding : public DmlOperator splitInputDesc.Axis = gsl::narrow_cast(splitInputDataTensorShape.size()) - 1; const DML_OPERATOR_DESC splitInputDmlDesc = {DML_OPERATOR_SPLIT, &splitInputDesc}; + // We generate a sequence from 0 to sequenceLength and add the offset to it + const std::array positionIdsRangeShape = {1, 1, 1, sequenceLength}; + auto positionIdsDataType = kernelInfo.GetInputEdgeDescription(positionIdsIndex).tensorDataType; + TensorDesc positionIdsRangeTensorDesc = TensorDesc::ConstructDefaultTensorDesc(positionIdsDataType, positionIdsRangeShape); + const DML_TENSOR_DESC positionIdsRangeDmlTensorDesc = positionIdsRangeTensorDesc.GetDmlDesc(); + + const std::array broadcastedPositionIdsRangeShape = {1, 1, batchSize, sequenceLength}; + TensorDesc broadcastedPositionIdsRangeTensorDesc = TensorDesc::ConstructBroadcastedTensorDesc(positionIdsDataType, broadcastedPositionIdsRangeShape, positionIdsRangeShape); + const DML_TENSOR_DESC broadcastedPositionIdsRangeDmlTensorDesc = broadcastedPositionIdsRangeTensorDesc.GetDmlDesc(); + + const std::array broadcastedOffsetShape = {1, 1, batchSize, sequenceLength}; + TensorDesc broadcastedOffsetTensorDesc = TensorDesc::ConstructBroadcastedTensorDesc(positionIdsDataType, broadcastedOffsetShape, m_inputTensorDescs[positionIdsIndex].GetSizes()); + const DML_TENSOR_DESC broadcastedOffsetDmlTensorDesc = broadcastedOffsetTensorDesc.GetDmlDesc(); + + TensorDesc offsetPositionIdsTensorDesc = TensorDesc::ConstructDefaultTensorDesc(positionIdsDataType, broadcastedOffsetShape); + const DML_TENSOR_DESC offsetPositionIdsRangeDmlTensorDesc = offsetPositionIdsTensorDesc.GetDmlDesc(); + + 
DML_FILL_VALUE_SEQUENCE_OPERATOR_DESC positionIdsRange{}; + DML_ELEMENT_WISE_ADD_OPERATOR_DESC positionIdsAddOffset{}; + if (positionIdsIsOffset) + { + ML_CHECK_VALID_ARGUMENT(positionIdsDataType == MLOperatorTensorDataType::Int64); + positionIdsRange.ValueDataType = DML_TENSOR_DATA_TYPE_INT64; + positionIdsRange.ValueDelta.Int64 = 1; + positionIdsRange.OutputTensor = &positionIdsRangeDmlTensorDesc; + + positionIdsAddOffset.ATensor = &broadcastedPositionIdsRangeDmlTensorDesc; + positionIdsAddOffset.BTensor = &broadcastedOffsetDmlTensorDesc; + positionIdsAddOffset.OutputTensor = &offsetPositionIdsRangeDmlTensorDesc; + } + const DML_OPERATOR_DESC positionIdsRangeDmlDesc = {DML_OPERATOR_FILL_VALUE_SEQUENCE, &positionIdsRange}; + const DML_OPERATOR_DESC positionIdsAddOffsetDmlDesc = {DML_OPERATOR_ELEMENT_WISE_ADD, &positionIdsAddOffset}; + // Gather the cos/sin values based on the position ids - const std::array gatheredCosSinShape {1, 1, sequenceLength, headSize / 2}; + const std::array gatheredCosSinShape = {1, batchSize, sequenceLength, headSize / 2}; TensorDesc gatheredCosSinTensorDesc = TensorDesc::ConstructDefaultTensorDesc(dataType, gatheredCosSinShape); const DML_TENSOR_DESC gatheredCosSinDmlTensorDesc = gatheredCosSinTensorDesc.GetDmlDesc(); DML_GATHER_OPERATOR_DESC gatherCosSinDesc{}; gatherCosSinDesc.InputTensor = &inputDescs[cosCacheIndex]; - gatherCosSinDesc.IndicesTensor = &inputDescs[positionIdsIndex]; + gatherCosSinDesc.IndicesTensor = positionIdsIsOffset ? 
&offsetPositionIdsRangeDmlTensorDesc : &inputDescs[positionIdsIndex]; gatherCosSinDesc.OutputTensor = &gatheredCosSinDmlTensorDesc; gatherCosSinDesc.Axis = 2; gatherCosSinDesc.IndexDimensions = 2; const DML_OPERATOR_DESC gatherCosSinDmlDesc {DML_OPERATOR_GATHER, &gatherCosSinDesc}; // After gathering cos/sin, reshape and broadcast them to match the number of heads of the half input data - const std::array reshapedCosSinShape {1, sequenceLength, 1, headSize / 2}; - const std::array broadcastedCosSinShape {batchSize, sequenceLength, numHeads, headSize / 2}; + const std::array reshapedCosSinShape = {batchSize, sequenceLength, 1, headSize / 2}; + const std::array broadcastedCosSinShape = {batchSize, sequenceLength, numHeads, headSize / 2}; TensorDesc broadcastedCosSinTensorDesc = TensorDesc::ConstructBroadcastedTensorDesc(dataType, broadcastedCosSinShape, reshapedCosSinShape); const DML_TENSOR_DESC broadcastedCosSinDmlTensorDesc = broadcastedCosSinTensorDesc.GetDmlDesc(); @@ -89,20 +125,20 @@ class DmlOperatorRotaryEmbedding : public DmlOperator mulHalfDataDesc.ATensor = &splitInputDataDmlTensorDescs.front(); mulHalfDataDesc.BTensor = &broadcastedCosSinDmlTensorDesc; mulHalfDataDesc.OutputTensor = &splitInputDataDmlTensorDescs.front(); - const DML_OPERATOR_DESC mulHalfDataDmlDesc {DML_OPERATOR_ELEMENT_WISE_MULTIPLY, &mulHalfDataDesc}; + const DML_OPERATOR_DESC mulHalfDataDmlDesc = {DML_OPERATOR_ELEMENT_WISE_MULTIPLY, &mulHalfDataDesc}; // Negate the second half of the data DML_ELEMENT_WISE_NEGATE_OPERATOR_DESC negateHalfDataDesc{}; negateHalfDataDesc.InputTensor = &splitInputDataDmlTensorDescs.front(); negateHalfDataDesc.OutputTensor = &splitInputDataDmlTensorDescs.front(); - const DML_OPERATOR_DESC negateHalfDataDmlDesc {DML_OPERATOR_ELEMENT_WISE_NEGATE, &negateHalfDataDesc}; + const DML_OPERATOR_DESC negateHalfDataDmlDesc = {DML_OPERATOR_ELEMENT_WISE_NEGATE, &negateHalfDataDesc}; // Add the multiplied 2 halves together DML_ELEMENT_WISE_ADD_OPERATOR_DESC 
addHalfDataDesc{}; addHalfDataDesc.ATensor = &splitInputDataDmlTensorDescs.front(); addHalfDataDesc.BTensor = &splitInputDataDmlTensorDescs.front(); addHalfDataDesc.OutputTensor = &splitInputDataDmlTensorDescs.front(); - const DML_OPERATOR_DESC addHalfDataDmlDesc {DML_OPERATOR_ELEMENT_WISE_ADD, &addHalfDataDesc}; + const DML_OPERATOR_DESC addHalfDataDmlDesc = {DML_OPERATOR_ELEMENT_WISE_ADD, &addHalfDataDesc}; // Join the 2 halves together DML_JOIN_OPERATOR_DESC joinHalfDataDesc{}; @@ -110,14 +146,14 @@ class DmlOperatorRotaryEmbedding : public DmlOperator joinHalfDataDesc.OutputTensor = &inputDataDmlTensorDesc; joinHalfDataDesc.Axis = gsl::narrow_cast(splitInputDataTensorShape.size()) - 1; joinHalfDataDesc.InputCount = gsl::narrow_cast(splitInputDataDmlTensorDescs.size()); - const DML_OPERATOR_DESC joinHalfDataDmlDesc {DML_OPERATOR_JOIN, &joinHalfDataDesc}; + const DML_OPERATOR_DESC joinHalfDataDmlDesc = {DML_OPERATOR_JOIN, &joinHalfDataDesc}; // Construct the graph std::vector inputEdges; std::vector intermediateEdges; std::vector outputEdges; - std::array opDescs = { + std::vector opDescs = { &splitInputDmlDesc, // Split the input data &gatherCosSinDmlDesc, // Gather cos &gatherCosSinDmlDesc, // Gather sin @@ -150,8 +186,59 @@ class DmlOperatorRotaryEmbedding : public DmlOperator addSecondHalfOpIndex, joinOpIndex, + + // The following indices are optional + positionIdsRangeOpIndex, + positionIdsAddOffsetOpIndex, }; + if (positionIdsIsOffset) + { + opDescs.push_back(&positionIdsRangeDmlDesc); + opDescs.push_back(&positionIdsAddOffsetDmlDesc); + + DML_INPUT_GRAPH_EDGE_DESC positionIdsToAddOffsetEdge = {}; + positionIdsToAddOffsetEdge.GraphInputIndex = positionIdsIndex; + positionIdsToAddOffsetEdge.ToNodeIndex = positionIdsAddOffsetOpIndex; + positionIdsToAddOffsetEdge.ToNodeInputIndex = 1; + inputEdges.push_back(positionIdsToAddOffsetEdge); + + DML_INTERMEDIATE_GRAPH_EDGE_DESC positionIdsOffsetToAddOffsetEdge = {}; + positionIdsOffsetToAddOffsetEdge.FromNodeIndex = 
positionIdsRangeOpIndex; + positionIdsOffsetToAddOffsetEdge.FromNodeOutputIndex = 0; + positionIdsOffsetToAddOffsetEdge.ToNodeIndex = positionIdsAddOffsetOpIndex; + positionIdsOffsetToAddOffsetEdge.ToNodeInputIndex = 0; + intermediateEdges.push_back(positionIdsOffsetToAddOffsetEdge); + + DML_INTERMEDIATE_GRAPH_EDGE_DESC positionIdsAddOffsetToGatherCosEdge = {}; + positionIdsAddOffsetToGatherCosEdge.FromNodeIndex = positionIdsAddOffsetOpIndex; + positionIdsAddOffsetToGatherCosEdge.FromNodeOutputIndex = 0; + positionIdsAddOffsetToGatherCosEdge.ToNodeIndex = gatherCosOpIndex; + positionIdsAddOffsetToGatherCosEdge.ToNodeInputIndex = 1; + intermediateEdges.push_back(positionIdsAddOffsetToGatherCosEdge); + + DML_INTERMEDIATE_GRAPH_EDGE_DESC positionIdsAddOffsetToGatherSinEdge = {}; + positionIdsAddOffsetToGatherSinEdge.FromNodeIndex = positionIdsAddOffsetOpIndex; + positionIdsAddOffsetToGatherSinEdge.FromNodeOutputIndex = 0; + positionIdsAddOffsetToGatherSinEdge.ToNodeIndex = gatherSinOpIndex; + positionIdsAddOffsetToGatherSinEdge.ToNodeInputIndex = 1; + intermediateEdges.push_back(positionIdsAddOffsetToGatherSinEdge); + } + else + { + DML_INPUT_GRAPH_EDGE_DESC positionIdsToGatherCosEdge = {}; + positionIdsToGatherCosEdge.GraphInputIndex = positionIdsIndex; + positionIdsToGatherCosEdge.ToNodeIndex = gatherCosOpIndex; + positionIdsToGatherCosEdge.ToNodeInputIndex = 1; + inputEdges.push_back(positionIdsToGatherCosEdge); + + DML_INPUT_GRAPH_EDGE_DESC positionIdsToGatherSinEdge = {}; + positionIdsToGatherSinEdge.GraphInputIndex = positionIdsIndex; + positionIdsToGatherSinEdge.ToNodeIndex = gatherSinOpIndex; + positionIdsToGatherSinEdge.ToNodeInputIndex = 1; + inputEdges.push_back(positionIdsToGatherSinEdge); + } + DML_INPUT_GRAPH_EDGE_DESC inputToSplitEdge = {}; inputToSplitEdge.GraphInputIndex = inputDataIndex; inputToSplitEdge.ToNodeIndex = splitInputOpIndex; @@ -164,24 +251,12 @@ class DmlOperatorRotaryEmbedding : public DmlOperator cosToGatherEdge.ToNodeInputIndex = 0; 
inputEdges.push_back(cosToGatherEdge); - DML_INPUT_GRAPH_EDGE_DESC positionIdsToGatherCosEdge = {}; - positionIdsToGatherCosEdge.GraphInputIndex = positionIdsIndex; - positionIdsToGatherCosEdge.ToNodeIndex = gatherCosOpIndex; - positionIdsToGatherCosEdge.ToNodeInputIndex = 1; - inputEdges.push_back(positionIdsToGatherCosEdge); - DML_INPUT_GRAPH_EDGE_DESC sinToGatherEdge = {}; sinToGatherEdge.GraphInputIndex = sinCacheIndex; sinToGatherEdge.ToNodeIndex = gatherSinOpIndex; sinToGatherEdge.ToNodeInputIndex = 0; inputEdges.push_back(sinToGatherEdge); - DML_INPUT_GRAPH_EDGE_DESC positionIdsToGatherSinEdge = {}; - positionIdsToGatherSinEdge.GraphInputIndex = positionIdsIndex; - positionIdsToGatherSinEdge.ToNodeIndex = gatherSinOpIndex; - positionIdsToGatherSinEdge.ToNodeInputIndex = 1; - inputEdges.push_back(positionIdsToGatherSinEdge); - DML_INTERMEDIATE_GRAPH_EDGE_DESC firstHalfDataToMulCosEdge = {}; firstHalfDataToMulCosEdge.FromNodeIndex = splitInputOpIndex; firstHalfDataToMulCosEdge.FromNodeOutputIndex = 0; diff --git a/onnxruntime/test/contrib_ops/rotary_embedding_op_test.cc b/onnxruntime/test/contrib_ops/rotary_embedding_op_test.cc index 29d8219c162a5..e35e281fae4ac 100644 --- a/onnxruntime/test/contrib_ops/rotary_embedding_op_test.cc +++ b/onnxruntime/test/contrib_ops/rotary_embedding_op_test.cc @@ -25,7 +25,8 @@ static void RunTest( int64_t interleaved, bool use_float16, bool disable_cpu, - bool disable_cuda) { + bool disable_cuda, + bool disable_dml) { // input : (batch_size, sequence_length, hidden_size) // position ids : (1) or (batch_size, sequence_length) // cos cache : (max_sequence_length, head_size / 2) @@ -50,9 +51,14 @@ static void RunTest( int min_cuda_architecture = use_float16 ? 
530 : 0; bool enable_cuda = HasCudaEnvironment(min_cuda_architecture); + bool enable_dml = (nullptr != DefaultDmlExecutionProvider().get()) && !disable_dml; + if (enable_cuda && !disable_cuda) { execution_providers.push_back(DefaultCudaExecutionProvider()); } + if (enable_dml && !disable_dml) { + execution_providers.push_back(DefaultDmlExecutionProvider()); + } if (!use_float16 && !disable_cpu) { execution_providers.push_back(DefaultCpuExecutionProvider()); } @@ -107,9 +113,10 @@ static void RunTests(const std::vector& input_data, interleaved, false, /* use_fp16 */ false, /* disable_cpu */ - true /* disable_cuda */); + true, /* disable_cuda */ + true /* disable_dml */); - // FP32 test for CUDA + // FP32 test for CUDA and DML RunTest(input_data, position_ids, cos_cache, @@ -123,9 +130,10 @@ static void RunTests(const std::vector& input_data, interleaved, false, /* use_fp16 */ false, /* disable_cpu */ - false /* disable_cuda */); + false, /* disable_cuda */ + false /* disable_dml */); - // FP16 test for CUDA + // FP16 test for CUDA and DML if (use_float16) { RunTest(input_data, position_ids, @@ -140,7 +148,8 @@ static void RunTests(const std::vector& input_data, interleaved, true, /* use_fp16 */ true, /* disable_cpu */ - false /* disable_cuda*/); + false, /* disable_cuda*/ + false /* disable_dml */); } } From e5246c31904d7bf4469995bcb5f62dc21b4e9725 Mon Sep 17 00:00:00 2001 From: Patrice Vignola Date: Mon, 30 Oct 2023 00:20:44 -0700 Subject: [PATCH 3/8] All tests passing --- .../Operators/DmlOperatorRotaryEmbedding.cpp | 367 +++++++++++++++++- .../dml/OperatorAuthorHelper/Attributes.h | 1 + 2 files changed, 354 insertions(+), 14 deletions(-) diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorRotaryEmbedding.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorRotaryEmbedding.cpp index 6539cfbbb279d..e94dfd06635d8 100644 --- 
a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorRotaryEmbedding.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorRotaryEmbedding.cpp @@ -10,14 +10,6 @@ class DmlOperatorRotaryEmbedding : public DmlOperator public: DmlOperatorRotaryEmbedding(const MLOperatorKernelCreationContext& kernelInfo) : DmlOperator(kernelInfo) { - enum InputIndex : uint32_t - { - inputDataIndex, - positionIdsIndex, - cosCacheIndex, - sinCacheIndex, - }; - ML_CHECK_VALID_ARGUMENT(kernelInfo.GetInputCount() == 4); ML_CHECK_VALID_ARGUMENT(kernelInfo.GetOutputCount() == 1); @@ -26,9 +18,6 @@ class DmlOperatorRotaryEmbedding : public DmlOperator Initialize(kernelInfo); - ComPtr contextPrivate; - ORT_THROW_IF_FAILED(kernelInfo.GetInterface()->QueryInterface(contextPrivate.GetAddressOf())); - ML_CHECK_VALID_ARGUMENT(m_inputTensorDescs[inputDataIndex].GetDimensionCount() == 4); ML_CHECK_VALID_ARGUMENT(m_inputTensorDescs[positionIdsIndex].GetDimensionCount() == 4); ML_CHECK_VALID_ARGUMENT(m_inputTensorDescs[cosCacheIndex].GetDimensionCount() == 4); @@ -36,9 +25,6 @@ class DmlOperatorRotaryEmbedding : public DmlOperator ML_CHECK_VALID_ARGUMENT(m_outputTensorDescs[0].GetDimensionCount() == 4); - std::vector inputDescs = GetDmlInputDescs(); - std::vector outputDescs = GetDmlOutputDescs(); - ML_CHECK_VALID_ARGUMENT(m_inputTensorDescs[cosCacheIndex].GetSizes() == m_inputTensorDescs[sinCacheIndex].GetSizes()); const uint32_t headSize = m_inputTensorDescs[cosCacheIndex].GetSizes().back() * 2; @@ -51,6 +37,359 @@ class DmlOperatorRotaryEmbedding : public DmlOperator const uint32_t sequenceLength = inputDataSizes[2]; const uint32_t numHeads = inputDataSizes[3] / headSize; + const auto cosCacheSizes = m_inputTensorDescs[cosCacheIndex].GetSizes(); + const uint32_t maxSequenceLength = cosCacheSizes[cosCacheSizes.size() - 2]; + + if (sequenceLength > maxSequenceLength) + { + ORT_NOT_IMPLEMENTED("Updating cos_cache and sin_cache in RotaryEmbedding 
is not currently supported"); + } + + const bool interleaved = gsl::narrow_cast(kernelInfo.GetOptionalAttribute(AttrName::Interleaved, 0)); + + if (interleaved) + { + CreateInterleavedOperator(kernelInfo, positionIdsIsOffset, batchSize, sequenceLength, numHeads, headSize); + } + else + { + CreateNonInterleavedOperator(kernelInfo, positionIdsIsOffset, batchSize, sequenceLength, numHeads, headSize); + } + } + +private: + enum InputIndex : uint32_t + { + inputDataIndex, + positionIdsIndex, + cosCacheIndex, + sinCacheIndex, + }; + + void CreateInterleavedOperator(const MLOperatorKernelCreationContext& kernelInfo, bool positionIdsIsOffset, uint32_t batchSize, uint32_t sequenceLength, uint32_t numHeads, uint32_t headSize) + { + std::vector inputDescs = GetDmlInputDescs(); + + // Split the input data into 2 equal parts + const MLOperatorTensorDataType dataType = kernelInfo.GetInputEdgeDescription(inputDataIndex).tensorDataType; + const std::array inputDataTensorShape = {batchSize, sequenceLength, numHeads, headSize / 2, 2}; + TensorDesc inputDataTensorDesc = TensorDesc::ConstructDefaultTensorDesc(dataType, inputDataTensorShape); + const DML_TENSOR_DESC inputDataDmlTensorDesc = inputDataTensorDesc.GetDmlDesc(); + + const std::array splitInputDataTensorShape = {batchSize, sequenceLength, numHeads, headSize / 2, 1}; + TensorDesc splitInputDataTensorDesc = TensorDesc::ConstructDefaultTensorDesc(dataType, splitInputDataTensorShape); + const std::array splitInputDataDmlTensorDescs = {splitInputDataTensorDesc.GetDmlDesc(), splitInputDataTensorDesc.GetDmlDesc()}; + + DML_SPLIT_OPERATOR_DESC splitInputDesc{}; + splitInputDesc.InputTensor = &inputDataDmlTensorDesc; + splitInputDesc.OutputTensors = splitInputDataDmlTensorDescs.data(); + splitInputDesc.OutputCount = gsl::narrow_cast(splitInputDataDmlTensorDescs.size()); + splitInputDesc.Axis = gsl::narrow_cast(splitInputDataTensorShape.size()) - 1; + const DML_OPERATOR_DESC splitInputDmlDesc = {DML_OPERATOR_SPLIT, &splitInputDesc}; + 
+ // Swap the 2 halves and join them together + DML_JOIN_OPERATOR_DESC joinInputDesc{}; + joinInputDesc.InputTensors = splitInputDataDmlTensorDescs.data(); + joinInputDesc.OutputTensor = &inputDataDmlTensorDesc; + joinInputDesc.Axis = gsl::narrow_cast(splitInputDataTensorShape.size()) - 1; + joinInputDesc.InputCount = gsl::narrow_cast(splitInputDataDmlTensorDescs.size()); + const DML_OPERATOR_DESC joinInputDmlDesc = {DML_OPERATOR_JOIN, &joinInputDesc}; + + // We generate a sequence from 0 to sequenceLength and add the offset to it + const std::array positionIdsRangeShape = {1, 1, 1, sequenceLength}; + auto positionIdsDataType = kernelInfo.GetInputEdgeDescription(positionIdsIndex).tensorDataType; + TensorDesc positionIdsRangeTensorDesc = TensorDesc::ConstructDefaultTensorDesc(positionIdsDataType, positionIdsRangeShape); + const DML_TENSOR_DESC positionIdsRangeDmlTensorDesc = positionIdsRangeTensorDesc.GetDmlDesc(); + + const std::array broadcastedPositionIdsRangeShape = {1, 1, batchSize, sequenceLength}; + TensorDesc broadcastedPositionIdsRangeTensorDesc = TensorDesc::ConstructBroadcastedTensorDesc(positionIdsDataType, broadcastedPositionIdsRangeShape, positionIdsRangeShape); + const DML_TENSOR_DESC broadcastedPositionIdsRangeDmlTensorDesc = broadcastedPositionIdsRangeTensorDesc.GetDmlDesc(); + + const std::array broadcastedOffsetShape = {1, 1, batchSize, sequenceLength}; + TensorDesc broadcastedOffsetTensorDesc = TensorDesc::ConstructBroadcastedTensorDesc(positionIdsDataType, broadcastedOffsetShape, m_inputTensorDescs[positionIdsIndex].GetSizes()); + const DML_TENSOR_DESC broadcastedOffsetDmlTensorDesc = broadcastedOffsetTensorDesc.GetDmlDesc(); + + TensorDesc offsetPositionIdsTensorDesc = TensorDesc::ConstructDefaultTensorDesc(positionIdsDataType, broadcastedOffsetShape); + const DML_TENSOR_DESC offsetPositionIdsRangeDmlTensorDesc = offsetPositionIdsTensorDesc.GetDmlDesc(); + + DML_FILL_VALUE_SEQUENCE_OPERATOR_DESC positionIdsRange{}; + 
DML_ELEMENT_WISE_ADD_OPERATOR_DESC positionIdsAddOffset{}; + if (positionIdsIsOffset) + { + ML_CHECK_VALID_ARGUMENT(positionIdsDataType == MLOperatorTensorDataType::Int64); + positionIdsRange.ValueDataType = DML_TENSOR_DATA_TYPE_INT64; + positionIdsRange.ValueDelta.Int64 = 1; + positionIdsRange.OutputTensor = &positionIdsRangeDmlTensorDesc; + + positionIdsAddOffset.ATensor = &broadcastedPositionIdsRangeDmlTensorDesc; + positionIdsAddOffset.BTensor = &broadcastedOffsetDmlTensorDesc; + positionIdsAddOffset.OutputTensor = &offsetPositionIdsRangeDmlTensorDesc; + } + const DML_OPERATOR_DESC positionIdsRangeDmlDesc = {DML_OPERATOR_FILL_VALUE_SEQUENCE, &positionIdsRange}; + const DML_OPERATOR_DESC positionIdsAddOffsetDmlDesc = {DML_OPERATOR_ELEMENT_WISE_ADD, &positionIdsAddOffset}; + + // Gather the cos/sin values based on the position ids + const std::array gatheredCosSinShape = {1, batchSize, sequenceLength, headSize / 2}; + TensorDesc gatheredCosSinTensorDesc = TensorDesc::ConstructDefaultTensorDesc(dataType, gatheredCosSinShape); + const DML_TENSOR_DESC gatheredCosSinDmlTensorDesc = gatheredCosSinTensorDesc.GetDmlDesc(); + + DML_GATHER_OPERATOR_DESC gatherCosSinDesc{}; + gatherCosSinDesc.InputTensor = &inputDescs[cosCacheIndex]; + gatherCosSinDesc.IndicesTensor = positionIdsIsOffset ? 
&offsetPositionIdsRangeDmlTensorDesc : &inputDescs[positionIdsIndex]; + gatherCosSinDesc.OutputTensor = &gatheredCosSinDmlTensorDesc; + gatherCosSinDesc.Axis = 2; + gatherCosSinDesc.IndexDimensions = 2; + const DML_OPERATOR_DESC gatherCosSinDmlDesc {DML_OPERATOR_GATHER, &gatherCosSinDesc}; + + // After gathering cos/sin, reshape and broadcast them to match the number of heads of the half input data + const std::array reshapedCosSinShape = {batchSize, sequenceLength, 1, headSize / 2, 1}; + const std::array broadcastedCosSinShape = {batchSize, sequenceLength, numHeads, headSize / 2, 2}; + TensorDesc broadcastedCosSinTensorDesc = TensorDesc::ConstructBroadcastedTensorDesc(dataType, broadcastedCosSinShape, reshapedCosSinShape); + const DML_TENSOR_DESC broadcastedCosSinDmlTensorDesc = broadcastedCosSinTensorDesc.GetDmlDesc(); + + // Create a vector that contains the sign values {-1, 1} + const std::array signTensorShape = {2}; + TensorDesc signTensorDesc = TensorDesc::ConstructDefaultTensorDesc(dataType, signTensorShape); + const DML_TENSOR_DESC signDmlTensorDesc = signTensorDesc.GetDmlDesc(); + + DML_FILL_VALUE_SEQUENCE_OPERATOR_DESC signRange{}; + signRange.OutputTensor = &signDmlTensorDesc; + if (dataType == MLOperatorTensorDataType::Float16) + { + const auto valueStart = static_cast(-1.0f); + const auto valueDelta = static_cast(2.0f); + memcpy(signRange.ValueStart.Bytes, reinterpret_cast(&valueStart), sizeof(valueStart)); + memcpy(signRange.ValueDelta.Bytes, reinterpret_cast(&valueDelta), sizeof(valueDelta)); + signRange.ValueDataType = DML_TENSOR_DATA_TYPE_FLOAT16; + } + else + { + ML_CHECK_VALID_ARGUMENT(dataType == MLOperatorTensorDataType::Float); + signRange.ValueStart.Float32 = -1.0f; + signRange.ValueDelta.Float32 = 2.0f; + signRange.ValueDataType = DML_TENSOR_DATA_TYPE_FLOAT32; + } + const DML_OPERATOR_DESC signRangeDmlDesc = {DML_OPERATOR_FILL_VALUE_SEQUENCE, &signRange}; + + // Multiply the broadcasted sign values with the rotated input + const std::array 
reshapedSignShape = {1, 1, 1, 1, 2}; + const std::array broadcastedsignShape = {batchSize, sequenceLength, numHeads, headSize / 2, 2}; + TensorDesc broadcastedSignCosSinTensorDesc = TensorDesc::ConstructBroadcastedTensorDesc(dataType, broadcastedsignShape, reshapedSignShape); + const DML_TENSOR_DESC broadcastedSignDmlTensorDesc = broadcastedSignCosSinTensorDesc.GetDmlDesc(); + + DML_ELEMENT_WISE_MULTIPLY_OPERATOR_DESC mulSignDesc{}; + mulSignDesc.ATensor = &inputDataDmlTensorDesc; + mulSignDesc.BTensor = &broadcastedSignDmlTensorDesc; + mulSignDesc.OutputTensor = &inputDataDmlTensorDesc; + const DML_OPERATOR_DESC mulSignDmlDesc = {DML_OPERATOR_ELEMENT_WISE_MULTIPLY, &mulSignDesc}; + + // Multiply the non-rotated data with the cos and the rotated data with the sin + DML_ELEMENT_WISE_MULTIPLY_OPERATOR_DESC mulCosSinDesc{}; + mulCosSinDesc.ATensor = &inputDataDmlTensorDesc; + mulCosSinDesc.BTensor = &broadcastedCosSinDmlTensorDesc; + mulCosSinDesc.OutputTensor = &inputDataDmlTensorDesc; + const DML_OPERATOR_DESC mulCosSinDmlDesc = {DML_OPERATOR_ELEMENT_WISE_MULTIPLY, &mulCosSinDesc}; + + // Add the multiplied cos and sin values together + DML_ELEMENT_WISE_ADD_OPERATOR_DESC addDesc{}; + addDesc.ATensor = &inputDataDmlTensorDesc; + addDesc.BTensor = &inputDataDmlTensorDesc; + addDesc.OutputTensor = &inputDataDmlTensorDesc; + const DML_OPERATOR_DESC addDmlDesc = {DML_OPERATOR_ELEMENT_WISE_ADD, &addDesc}; + + // Construct the graph + std::vector inputEdges; + std::vector intermediateEdges; + std::vector outputEdges; + + std::vector opDescs = { + &splitInputDmlDesc, // Split the input data + &gatherCosSinDmlDesc, // Gather cos + &gatherCosSinDmlDesc, // Gather sin + &signRangeDmlDesc, // Generate the signs + + &joinInputDmlDesc, // Join the split data + &mulCosSinDmlDesc, // Multiply cos with the non-rotated data + &mulCosSinDmlDesc, // Multiply sin with the rotated data + &mulSignDmlDesc, // Multiply the sign with the rotated data + &addDmlDesc, // Add the non-rotated cos and 
rotated sin parts together + }; + + enum NodeIndex : uint32_t + { + splitInputOpIndex, + gatherCosOpIndex, + gatherSinOpIndex, + signRangeOpIndex, + + joinInputOpIndex, + mulCosOpIndex, + mulSinOpIndex, + mulSignOpIndex, + addOpIndex, + + // The following indices are optional + positionIdsRangeOpIndex, + positionIdsAddOffsetOpIndex, + }; + + if (positionIdsIsOffset) + { + opDescs.push_back(&positionIdsRangeDmlDesc); + opDescs.push_back(&positionIdsAddOffsetDmlDesc); + + DML_INPUT_GRAPH_EDGE_DESC positionIdsToAddOffsetEdge = {}; + positionIdsToAddOffsetEdge.GraphInputIndex = positionIdsIndex; + positionIdsToAddOffsetEdge.ToNodeIndex = positionIdsAddOffsetOpIndex; + positionIdsToAddOffsetEdge.ToNodeInputIndex = 1; + inputEdges.push_back(positionIdsToAddOffsetEdge); + + DML_INTERMEDIATE_GRAPH_EDGE_DESC positionIdsOffsetToAddOffsetEdge = {}; + positionIdsOffsetToAddOffsetEdge.FromNodeIndex = positionIdsRangeOpIndex; + positionIdsOffsetToAddOffsetEdge.FromNodeOutputIndex = 0; + positionIdsOffsetToAddOffsetEdge.ToNodeIndex = positionIdsAddOffsetOpIndex; + positionIdsOffsetToAddOffsetEdge.ToNodeInputIndex = 0; + intermediateEdges.push_back(positionIdsOffsetToAddOffsetEdge); + + DML_INTERMEDIATE_GRAPH_EDGE_DESC positionIdsAddOffsetToGatherCosEdge = {}; + positionIdsAddOffsetToGatherCosEdge.FromNodeIndex = positionIdsAddOffsetOpIndex; + positionIdsAddOffsetToGatherCosEdge.FromNodeOutputIndex = 0; + positionIdsAddOffsetToGatherCosEdge.ToNodeIndex = gatherCosOpIndex; + positionIdsAddOffsetToGatherCosEdge.ToNodeInputIndex = 1; + intermediateEdges.push_back(positionIdsAddOffsetToGatherCosEdge); + + DML_INTERMEDIATE_GRAPH_EDGE_DESC positionIdsAddOffsetToGatherSinEdge = {}; + positionIdsAddOffsetToGatherSinEdge.FromNodeIndex = positionIdsAddOffsetOpIndex; + positionIdsAddOffsetToGatherSinEdge.FromNodeOutputIndex = 0; + positionIdsAddOffsetToGatherSinEdge.ToNodeIndex = gatherSinOpIndex; + positionIdsAddOffsetToGatherSinEdge.ToNodeInputIndex = 1; + 
intermediateEdges.push_back(positionIdsAddOffsetToGatherSinEdge); + } + else + { + DML_INPUT_GRAPH_EDGE_DESC positionIdsToGatherCosEdge = {}; + positionIdsToGatherCosEdge.GraphInputIndex = positionIdsIndex; + positionIdsToGatherCosEdge.ToNodeIndex = gatherCosOpIndex; + positionIdsToGatherCosEdge.ToNodeInputIndex = 1; + inputEdges.push_back(positionIdsToGatherCosEdge); + + DML_INPUT_GRAPH_EDGE_DESC positionIdsToGatherSinEdge = {}; + positionIdsToGatherSinEdge.GraphInputIndex = positionIdsIndex; + positionIdsToGatherSinEdge.ToNodeIndex = gatherSinOpIndex; + positionIdsToGatherSinEdge.ToNodeInputIndex = 1; + inputEdges.push_back(positionIdsToGatherSinEdge); + } + + DML_INPUT_GRAPH_EDGE_DESC inputToSplitEdge = {}; + inputToSplitEdge.GraphInputIndex = inputDataIndex; + inputToSplitEdge.ToNodeIndex = splitInputOpIndex; + inputToSplitEdge.ToNodeInputIndex = 0; + inputEdges.push_back(inputToSplitEdge); + + DML_INPUT_GRAPH_EDGE_DESC cosToGatherEdge = {}; + cosToGatherEdge.GraphInputIndex = cosCacheIndex; + cosToGatherEdge.ToNodeIndex = gatherCosOpIndex; + cosToGatherEdge.ToNodeInputIndex = 0; + inputEdges.push_back(cosToGatherEdge); + + DML_INPUT_GRAPH_EDGE_DESC sinToGatherEdge = {}; + sinToGatherEdge.GraphInputIndex = sinCacheIndex; + sinToGatherEdge.ToNodeIndex = gatherSinOpIndex; + sinToGatherEdge.ToNodeInputIndex = 0; + inputEdges.push_back(sinToGatherEdge); + + DML_INPUT_GRAPH_EDGE_DESC nonRotatedDataToMulEdge = {}; + nonRotatedDataToMulEdge.GraphInputIndex = inputDataIndex; + nonRotatedDataToMulEdge.ToNodeIndex = mulCosOpIndex; + nonRotatedDataToMulEdge.ToNodeInputIndex = 0; + inputEdges.push_back(nonRotatedDataToMulEdge); + + DML_INTERMEDIATE_GRAPH_EDGE_DESC secondHalfDataToJoinEdge = {}; + secondHalfDataToJoinEdge.FromNodeIndex = splitInputOpIndex; + secondHalfDataToJoinEdge.FromNodeOutputIndex = 1; + secondHalfDataToJoinEdge.ToNodeIndex = joinInputOpIndex; + secondHalfDataToJoinEdge.ToNodeInputIndex = 0; + intermediateEdges.push_back(secondHalfDataToJoinEdge); + + 
DML_INTERMEDIATE_GRAPH_EDGE_DESC firstHalfDataToJoinEdge = {}; + firstHalfDataToJoinEdge.FromNodeIndex = splitInputOpIndex; + firstHalfDataToJoinEdge.FromNodeOutputIndex = 0; + firstHalfDataToJoinEdge.ToNodeIndex = joinInputOpIndex; + firstHalfDataToJoinEdge.ToNodeInputIndex = 1; + intermediateEdges.push_back(firstHalfDataToJoinEdge); + + DML_INTERMEDIATE_GRAPH_EDGE_DESC cosToMulEdge = {}; + cosToMulEdge.FromNodeIndex = gatherCosOpIndex; + cosToMulEdge.FromNodeOutputIndex = 0; + cosToMulEdge.ToNodeIndex = mulCosOpIndex; + cosToMulEdge.ToNodeInputIndex = 1; + intermediateEdges.push_back(cosToMulEdge); + + DML_INTERMEDIATE_GRAPH_EDGE_DESC rotatedDataToMulEdge = {}; + rotatedDataToMulEdge.FromNodeIndex = joinInputOpIndex; + rotatedDataToMulEdge.FromNodeOutputIndex = 0; + rotatedDataToMulEdge.ToNodeIndex = mulSinOpIndex; + rotatedDataToMulEdge.ToNodeInputIndex = 0; + intermediateEdges.push_back(rotatedDataToMulEdge); + + DML_INTERMEDIATE_GRAPH_EDGE_DESC sinToMulEdge = {}; + sinToMulEdge.FromNodeIndex = gatherSinOpIndex; + sinToMulEdge.FromNodeOutputIndex = 0; + sinToMulEdge.ToNodeIndex = mulSinOpIndex; + sinToMulEdge.ToNodeInputIndex = 1; + intermediateEdges.push_back(sinToMulEdge); + + DML_INTERMEDIATE_GRAPH_EDGE_DESC rotatedSinToMulEdge = {}; + rotatedSinToMulEdge.FromNodeIndex = mulSinOpIndex; + rotatedSinToMulEdge.FromNodeOutputIndex = 0; + rotatedSinToMulEdge.ToNodeIndex = mulSignOpIndex; + rotatedSinToMulEdge.ToNodeInputIndex = 0; + intermediateEdges.push_back(rotatedSinToMulEdge); + + DML_INTERMEDIATE_GRAPH_EDGE_DESC signToMulEdge = {}; + signToMulEdge.FromNodeIndex = signRangeOpIndex; + signToMulEdge.FromNodeOutputIndex = 0; + signToMulEdge.ToNodeIndex = mulSignOpIndex; + signToMulEdge.ToNodeInputIndex = 1; + intermediateEdges.push_back(signToMulEdge); + + DML_INTERMEDIATE_GRAPH_EDGE_DESC nonRotatedCosToAddEdge = {}; + nonRotatedCosToAddEdge.FromNodeIndex = mulCosOpIndex; + nonRotatedCosToAddEdge.FromNodeOutputIndex = 0; + nonRotatedCosToAddEdge.ToNodeIndex = 
addOpIndex; + nonRotatedCosToAddEdge.ToNodeInputIndex = 0; + intermediateEdges.push_back(nonRotatedCosToAddEdge); + + DML_INTERMEDIATE_GRAPH_EDGE_DESC rotatedSinToAddEdge = {}; + rotatedSinToAddEdge.FromNodeIndex = mulSignOpIndex; + rotatedSinToAddEdge.FromNodeOutputIndex = 0; + rotatedSinToAddEdge.ToNodeIndex = addOpIndex; + rotatedSinToAddEdge.ToNodeInputIndex = 1; + intermediateEdges.push_back(rotatedSinToAddEdge); + + DML_OUTPUT_GRAPH_EDGE_DESC addToOutputEdge = {}; + addToOutputEdge.FromNodeIndex = addOpIndex; + addToOutputEdge.FromNodeOutputIndex = 0; + addToOutputEdge.GraphOutputIndex = 0; + outputEdges.push_back(addToOutputEdge); + + MLOperatorGraphDesc operatorGraphDesc = {}; + operatorGraphDesc.inputEdgeCount = gsl::narrow_cast(inputEdges.size()); + operatorGraphDesc.inputEdges = inputEdges.data(); + operatorGraphDesc.intermediateEdgeCount = gsl::narrow_cast(intermediateEdges.size()); + operatorGraphDesc.intermediateEdges = intermediateEdges.data(); + operatorGraphDesc.outputEdgeCount = gsl::narrow_cast(outputEdges.size()); + operatorGraphDesc.outputEdges = outputEdges.data(); + operatorGraphDesc.nodeCount = gsl::narrow_cast(opDescs.size()); + operatorGraphDesc.nodesAsOpDesc = opDescs.data(); + + SetDmlOperatorGraphDesc(std::move(operatorGraphDesc), kernelInfo); + } + + void CreateNonInterleavedOperator(const MLOperatorKernelCreationContext& kernelInfo, bool positionIdsIsOffset, uint32_t batchSize, uint32_t sequenceLength, uint32_t numHeads, uint32_t headSize) + { + std::vector inputDescs = GetDmlInputDescs(); + // Split the input data into 2 equal parts const MLOperatorTensorDataType dataType = kernelInfo.GetInputEdgeDescription(inputDataIndex).tensorDataType; const std::array inputDataTensorShape = {batchSize, sequenceLength, numHeads, headSize}; diff --git a/onnxruntime/core/providers/dml/OperatorAuthorHelper/Attributes.h b/onnxruntime/core/providers/dml/OperatorAuthorHelper/Attributes.h index dac128f92ae0c..e9591cfce6870 100644 --- 
a/onnxruntime/core/providers/dml/OperatorAuthorHelper/Attributes.h +++ b/onnxruntime/core/providers/dml/OperatorAuthorHelper/Attributes.h @@ -122,6 +122,7 @@ namespace AttrName static constexpr const char* GraphFusedActivation = "activation"; static constexpr const char* GraphFusedAxis = "activation_axis"; + static constexpr const char* Interleaved = "interleaved"; } // namespace AttrName From 10e61f0b8320b3d98b3eadc3085164fc182f63c6 Mon Sep 17 00:00:00 2001 From: Patrice Vignola Date: Mon, 30 Oct 2023 00:59:07 -0700 Subject: [PATCH 4/8] Refactor implementation --- .../Operators/DmlOperatorRotaryEmbedding.cpp | 420 +++--------------- 1 file changed, 53 insertions(+), 367 deletions(-) diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorRotaryEmbedding.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorRotaryEmbedding.cpp index e94dfd06635d8..44c29860e016c 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorRotaryEmbedding.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorRotaryEmbedding.cpp @@ -3,6 +3,28 @@ #include "precomp.h" +// This operator is easier to understand by looking at a python implementation of the non-interleaved version: +// +// def rotate_half(x): +// """Rotates half the hidden dims of the input.""" +// half_dim = x.shape[-1] // 2 +// x1 = x[..., :half_dim] +// x2 = x[..., half_dim:] +// return np.concatenate((-x2, x1), axis=-1) +// +// +// def apply_rope(x, cos, sin, position_ids): +// cos = cos[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim] +// sin = sin[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim] +// x_embed = (x * cos) + (rotate_half(x) * sin) +// return x_embed +// +// For the non-interleaved version, we multiply the cos cache by the non-rotated input tensor while we multiply the sin cache +// by the rotated input tensor. 
Rotating the tensor means slicing it in half on the head dimension and swapping the 2 halves. +// +// The interleaved version is very similar but instead of swapping 2 halves, we swap every pair of adjacent elements and we swap +// the sign of every adjacent element. + namespace Dml { class DmlOperatorRotaryEmbedding : public DmlOperator @@ -10,6 +32,14 @@ class DmlOperatorRotaryEmbedding : public DmlOperator public: DmlOperatorRotaryEmbedding(const MLOperatorKernelCreationContext& kernelInfo) : DmlOperator(kernelInfo) { + enum InputIndex : uint32_t + { + inputDataIndex, + positionIdsIndex, + cosCacheIndex, + sinCacheIndex, + }; + ML_CHECK_VALID_ARGUMENT(kernelInfo.GetInputCount() == 4); ML_CHECK_VALID_ARGUMENT(kernelInfo.GetOutputCount() == 1); @@ -47,36 +77,21 @@ class DmlOperatorRotaryEmbedding : public DmlOperator const bool interleaved = gsl::narrow_cast(kernelInfo.GetOptionalAttribute(AttrName::Interleaved, 0)); - if (interleaved) - { - CreateInterleavedOperator(kernelInfo, positionIdsIsOffset, batchSize, sequenceLength, numHeads, headSize); - } - else - { - CreateNonInterleavedOperator(kernelInfo, positionIdsIsOffset, batchSize, sequenceLength, numHeads, headSize); - } - } - -private: - enum InputIndex : uint32_t - { - inputDataIndex, - positionIdsIndex, - cosCacheIndex, - sinCacheIndex, - }; - - void CreateInterleavedOperator(const MLOperatorKernelCreationContext& kernelInfo, bool positionIdsIsOffset, uint32_t batchSize, uint32_t sequenceLength, uint32_t numHeads, uint32_t headSize) - { std::vector inputDescs = GetDmlInputDescs(); + const MLOperatorTensorDataType dataType = kernelInfo.GetInputEdgeDescription(inputDataIndex).tensorDataType; // Split the input data into 2 equal parts - const MLOperatorTensorDataType dataType = kernelInfo.GetInputEdgeDescription(inputDataIndex).tensorDataType; - const std::array inputDataTensorShape = {batchSize, sequenceLength, numHeads, headSize / 2, 2}; + const std::vector inputDataTensorShape = interleaved + ? 
std::vector({batchSize, sequenceLength, numHeads, headSize / 2, 2}) + : std::vector({batchSize, sequenceLength, numHeads, 2, headSize / 2}); + + const std::vector splitInputDataTensorShape = interleaved + ? std::vector({batchSize, sequenceLength, numHeads, headSize / 2, 1}) + : std::vector({batchSize, sequenceLength, numHeads, 1, headSize / 2}); + TensorDesc inputDataTensorDesc = TensorDesc::ConstructDefaultTensorDesc(dataType, inputDataTensorShape); const DML_TENSOR_DESC inputDataDmlTensorDesc = inputDataTensorDesc.GetDmlDesc(); - const std::array splitInputDataTensorShape = {batchSize, sequenceLength, numHeads, headSize / 2, 1}; TensorDesc splitInputDataTensorDesc = TensorDesc::ConstructDefaultTensorDesc(dataType, splitInputDataTensorShape); const std::array splitInputDataDmlTensorDescs = {splitInputDataTensorDesc.GetDmlDesc(), splitInputDataTensorDesc.GetDmlDesc()}; @@ -84,14 +99,17 @@ class DmlOperatorRotaryEmbedding : public DmlOperator splitInputDesc.InputTensor = &inputDataDmlTensorDesc; splitInputDesc.OutputTensors = splitInputDataDmlTensorDescs.data(); splitInputDesc.OutputCount = gsl::narrow_cast(splitInputDataDmlTensorDescs.size()); - splitInputDesc.Axis = gsl::narrow_cast(splitInputDataTensorShape.size()) - 1; + splitInputDesc.Axis = interleaved + ? 
gsl::narrow_cast(splitInputDataTensorShape.size()) - 1 + : gsl::narrow_cast(splitInputDataTensorShape.size()) - 2; + const DML_OPERATOR_DESC splitInputDmlDesc = {DML_OPERATOR_SPLIT, &splitInputDesc}; // Swap the 2 halves and join them together DML_JOIN_OPERATOR_DESC joinInputDesc{}; joinInputDesc.InputTensors = splitInputDataDmlTensorDescs.data(); joinInputDesc.OutputTensor = &inputDataDmlTensorDesc; - joinInputDesc.Axis = gsl::narrow_cast(splitInputDataTensorShape.size()) - 1; + joinInputDesc.Axis = splitInputDesc.Axis; joinInputDesc.InputCount = gsl::narrow_cast(splitInputDataDmlTensorDescs.size()); const DML_OPERATOR_DESC joinInputDmlDesc = {DML_OPERATOR_JOIN, &joinInputDesc}; @@ -141,10 +159,11 @@ class DmlOperatorRotaryEmbedding : public DmlOperator gatherCosSinDesc.IndexDimensions = 2; const DML_OPERATOR_DESC gatherCosSinDmlDesc {DML_OPERATOR_GATHER, &gatherCosSinDesc}; - // After gathering cos/sin, reshape and broadcast them to match the number of heads of the half input data - const std::array reshapedCosSinShape = {batchSize, sequenceLength, 1, headSize / 2, 1}; - const std::array broadcastedCosSinShape = {batchSize, sequenceLength, numHeads, headSize / 2, 2}; - TensorDesc broadcastedCosSinTensorDesc = TensorDesc::ConstructBroadcastedTensorDesc(dataType, broadcastedCosSinShape, reshapedCosSinShape); + // After gathering cos/sin, reshape and broadcast them to match the number of heads of the input data + const std::vector reshapedCosSinShape = interleaved + ? 
std::vector({batchSize, sequenceLength, 1, headSize / 2, 1}) + : std::vector({batchSize, sequenceLength, 1, 1, headSize / 2}); + TensorDesc broadcastedCosSinTensorDesc = TensorDesc::ConstructBroadcastedTensorDesc(dataType, inputDataTensorShape, reshapedCosSinShape); const DML_TENSOR_DESC broadcastedCosSinDmlTensorDesc = broadcastedCosSinTensorDesc.GetDmlDesc(); // Create a vector that contains the sign values {-1, 1} @@ -172,9 +191,10 @@ class DmlOperatorRotaryEmbedding : public DmlOperator const DML_OPERATOR_DESC signRangeDmlDesc = {DML_OPERATOR_FILL_VALUE_SEQUENCE, &signRange}; // Multiply the broadcasted sign values with the rotated input - const std::array reshapedSignShape = {1, 1, 1, 1, 2}; - const std::array broadcastedsignShape = {batchSize, sequenceLength, numHeads, headSize / 2, 2}; - TensorDesc broadcastedSignCosSinTensorDesc = TensorDesc::ConstructBroadcastedTensorDesc(dataType, broadcastedsignShape, reshapedSignShape); + const std::vector reshapedSignShape = interleaved + ? 
std::vector({1, 1, 1, 1, 2}) + : std::vector({1, 1, 1, 2, 1}); + TensorDesc broadcastedSignCosSinTensorDesc = TensorDesc::ConstructBroadcastedTensorDesc(dataType, inputDataTensorShape, reshapedSignShape); const DML_TENSOR_DESC broadcastedSignDmlTensorDesc = broadcastedSignCosSinTensorDesc.GetDmlDesc(); DML_ELEMENT_WISE_MULTIPLY_OPERATOR_DESC mulSignDesc{}; @@ -385,340 +405,6 @@ class DmlOperatorRotaryEmbedding : public DmlOperator SetDmlOperatorGraphDesc(std::move(operatorGraphDesc), kernelInfo); } - - void CreateNonInterleavedOperator(const MLOperatorKernelCreationContext& kernelInfo, bool positionIdsIsOffset, uint32_t batchSize, uint32_t sequenceLength, uint32_t numHeads, uint32_t headSize) - { - std::vector inputDescs = GetDmlInputDescs(); - - // Split the input data into 2 equal parts - const MLOperatorTensorDataType dataType = kernelInfo.GetInputEdgeDescription(inputDataIndex).tensorDataType; - const std::array inputDataTensorShape = {batchSize, sequenceLength, numHeads, headSize}; - TensorDesc inputDataTensorDesc = TensorDesc::ConstructDefaultTensorDesc(dataType, inputDataTensorShape); - const DML_TENSOR_DESC inputDataDmlTensorDesc = inputDataTensorDesc.GetDmlDesc(); - - const std::array splitInputDataTensorShape = {batchSize, sequenceLength, numHeads, headSize / 2}; - TensorDesc splitInputDataTensorDesc = TensorDesc::ConstructDefaultTensorDesc(dataType, splitInputDataTensorShape); - const std::array splitInputDataDmlTensorDescs = {splitInputDataTensorDesc.GetDmlDesc(), splitInputDataTensorDesc.GetDmlDesc()}; - - DML_SPLIT_OPERATOR_DESC splitInputDesc{}; - splitInputDesc.InputTensor = &inputDataDmlTensorDesc; - splitInputDesc.OutputTensors = splitInputDataDmlTensorDescs.data(); - splitInputDesc.OutputCount = gsl::narrow_cast(splitInputDataDmlTensorDescs.size()); - splitInputDesc.Axis = gsl::narrow_cast(splitInputDataTensorShape.size()) - 1; - const DML_OPERATOR_DESC splitInputDmlDesc = {DML_OPERATOR_SPLIT, &splitInputDesc}; - - // We generate a sequence from 
0 to sequenceLength and add the offset to it - const std::array positionIdsRangeShape = {1, 1, 1, sequenceLength}; - auto positionIdsDataType = kernelInfo.GetInputEdgeDescription(positionIdsIndex).tensorDataType; - TensorDesc positionIdsRangeTensorDesc = TensorDesc::ConstructDefaultTensorDesc(positionIdsDataType, positionIdsRangeShape); - const DML_TENSOR_DESC positionIdsRangeDmlTensorDesc = positionIdsRangeTensorDesc.GetDmlDesc(); - - const std::array broadcastedPositionIdsRangeShape = {1, 1, batchSize, sequenceLength}; - TensorDesc broadcastedPositionIdsRangeTensorDesc = TensorDesc::ConstructBroadcastedTensorDesc(positionIdsDataType, broadcastedPositionIdsRangeShape, positionIdsRangeShape); - const DML_TENSOR_DESC broadcastedPositionIdsRangeDmlTensorDesc = broadcastedPositionIdsRangeTensorDesc.GetDmlDesc(); - - const std::array broadcastedOffsetShape = {1, 1, batchSize, sequenceLength}; - TensorDesc broadcastedOffsetTensorDesc = TensorDesc::ConstructBroadcastedTensorDesc(positionIdsDataType, broadcastedOffsetShape, m_inputTensorDescs[positionIdsIndex].GetSizes()); - const DML_TENSOR_DESC broadcastedOffsetDmlTensorDesc = broadcastedOffsetTensorDesc.GetDmlDesc(); - - TensorDesc offsetPositionIdsTensorDesc = TensorDesc::ConstructDefaultTensorDesc(positionIdsDataType, broadcastedOffsetShape); - const DML_TENSOR_DESC offsetPositionIdsRangeDmlTensorDesc = offsetPositionIdsTensorDesc.GetDmlDesc(); - - DML_FILL_VALUE_SEQUENCE_OPERATOR_DESC positionIdsRange{}; - DML_ELEMENT_WISE_ADD_OPERATOR_DESC positionIdsAddOffset{}; - if (positionIdsIsOffset) - { - ML_CHECK_VALID_ARGUMENT(positionIdsDataType == MLOperatorTensorDataType::Int64); - positionIdsRange.ValueDataType = DML_TENSOR_DATA_TYPE_INT64; - positionIdsRange.ValueDelta.Int64 = 1; - positionIdsRange.OutputTensor = &positionIdsRangeDmlTensorDesc; - - positionIdsAddOffset.ATensor = &broadcastedPositionIdsRangeDmlTensorDesc; - positionIdsAddOffset.BTensor = &broadcastedOffsetDmlTensorDesc; - 
positionIdsAddOffset.OutputTensor = &offsetPositionIdsRangeDmlTensorDesc; - } - const DML_OPERATOR_DESC positionIdsRangeDmlDesc = {DML_OPERATOR_FILL_VALUE_SEQUENCE, &positionIdsRange}; - const DML_OPERATOR_DESC positionIdsAddOffsetDmlDesc = {DML_OPERATOR_ELEMENT_WISE_ADD, &positionIdsAddOffset}; - - // Gather the cos/sin values based on the position ids - const std::array gatheredCosSinShape = {1, batchSize, sequenceLength, headSize / 2}; - TensorDesc gatheredCosSinTensorDesc = TensorDesc::ConstructDefaultTensorDesc(dataType, gatheredCosSinShape); - const DML_TENSOR_DESC gatheredCosSinDmlTensorDesc = gatheredCosSinTensorDesc.GetDmlDesc(); - - DML_GATHER_OPERATOR_DESC gatherCosSinDesc{}; - gatherCosSinDesc.InputTensor = &inputDescs[cosCacheIndex]; - gatherCosSinDesc.IndicesTensor = positionIdsIsOffset ? &offsetPositionIdsRangeDmlTensorDesc : &inputDescs[positionIdsIndex]; - gatherCosSinDesc.OutputTensor = &gatheredCosSinDmlTensorDesc; - gatherCosSinDesc.Axis = 2; - gatherCosSinDesc.IndexDimensions = 2; - const DML_OPERATOR_DESC gatherCosSinDmlDesc {DML_OPERATOR_GATHER, &gatherCosSinDesc}; - - // After gathering cos/sin, reshape and broadcast them to match the number of heads of the half input data - const std::array reshapedCosSinShape = {batchSize, sequenceLength, 1, headSize / 2}; - const std::array broadcastedCosSinShape = {batchSize, sequenceLength, numHeads, headSize / 2}; - TensorDesc broadcastedCosSinTensorDesc = TensorDesc::ConstructBroadcastedTensorDesc(dataType, broadcastedCosSinShape, reshapedCosSinShape); - const DML_TENSOR_DESC broadcastedCosSinDmlTensorDesc = broadcastedCosSinTensorDesc.GetDmlDesc(); - - // Multiply the first half of the data with the cos and the second half of the negated data with the sin - DML_ELEMENT_WISE_MULTIPLY_OPERATOR_DESC mulHalfDataDesc{}; - mulHalfDataDesc.ATensor = &splitInputDataDmlTensorDescs.front(); - mulHalfDataDesc.BTensor = &broadcastedCosSinDmlTensorDesc; - mulHalfDataDesc.OutputTensor = 
&splitInputDataDmlTensorDescs.front(); - const DML_OPERATOR_DESC mulHalfDataDmlDesc = {DML_OPERATOR_ELEMENT_WISE_MULTIPLY, &mulHalfDataDesc}; - - // Negate the second half of the data - DML_ELEMENT_WISE_NEGATE_OPERATOR_DESC negateHalfDataDesc{}; - negateHalfDataDesc.InputTensor = &splitInputDataDmlTensorDescs.front(); - negateHalfDataDesc.OutputTensor = &splitInputDataDmlTensorDescs.front(); - const DML_OPERATOR_DESC negateHalfDataDmlDesc = {DML_OPERATOR_ELEMENT_WISE_NEGATE, &negateHalfDataDesc}; - - // Add the multiplied 2 halves together - DML_ELEMENT_WISE_ADD_OPERATOR_DESC addHalfDataDesc{}; - addHalfDataDesc.ATensor = &splitInputDataDmlTensorDescs.front(); - addHalfDataDesc.BTensor = &splitInputDataDmlTensorDescs.front(); - addHalfDataDesc.OutputTensor = &splitInputDataDmlTensorDescs.front(); - const DML_OPERATOR_DESC addHalfDataDmlDesc = {DML_OPERATOR_ELEMENT_WISE_ADD, &addHalfDataDesc}; - - // Join the 2 halves together - DML_JOIN_OPERATOR_DESC joinHalfDataDesc{}; - joinHalfDataDesc.InputTensors = splitInputDataDmlTensorDescs.data(); - joinHalfDataDesc.OutputTensor = &inputDataDmlTensorDesc; - joinHalfDataDesc.Axis = gsl::narrow_cast(splitInputDataTensorShape.size()) - 1; - joinHalfDataDesc.InputCount = gsl::narrow_cast(splitInputDataDmlTensorDescs.size()); - const DML_OPERATOR_DESC joinHalfDataDmlDesc = {DML_OPERATOR_JOIN, &joinHalfDataDesc}; - - // Construct the graph - std::vector inputEdges; - std::vector intermediateEdges; - std::vector outputEdges; - - std::vector opDescs = { - &splitInputDmlDesc, // Split the input data - &gatherCosSinDmlDesc, // Gather cos - &gatherCosSinDmlDesc, // Gather sin - - &mulHalfDataDmlDesc, // Multiply cos with the first half of the input - &negateHalfDataDmlDesc, // Negate the second half of the input - &mulHalfDataDmlDesc, // Multiply sin with the negated second of half of the input - &addHalfDataDmlDesc, // Add the 2 halves together - - &mulHalfDataDmlDesc, // Multiply cos with the second half of the input - 
&mulHalfDataDmlDesc, // Multiply sin with the first half of the input - &addHalfDataDmlDesc, // Add the 2 halves together - - &joinHalfDataDmlDesc, // Join the halves together - }; - - enum NodeIndex : uint32_t - { - splitInputOpIndex, - gatherCosOpIndex, - gatherSinOpIndex, - - mulCosFirstHalfOpIndex, - negateSecondHalfOpIndex, - mulSinNegatedSecondHalfOpIndex, - addFirstHalfOpIndex, - - mulCosSecondHalfOpIndex, - mulSinFirstHalfOpIndex, - addSecondHalfOpIndex, - - joinOpIndex, - - // The following indices are optional - positionIdsRangeOpIndex, - positionIdsAddOffsetOpIndex, - }; - - if (positionIdsIsOffset) - { - opDescs.push_back(&positionIdsRangeDmlDesc); - opDescs.push_back(&positionIdsAddOffsetDmlDesc); - - DML_INPUT_GRAPH_EDGE_DESC positionIdsToAddOffsetEdge = {}; - positionIdsToAddOffsetEdge.GraphInputIndex = positionIdsIndex; - positionIdsToAddOffsetEdge.ToNodeIndex = positionIdsAddOffsetOpIndex; - positionIdsToAddOffsetEdge.ToNodeInputIndex = 1; - inputEdges.push_back(positionIdsToAddOffsetEdge); - - DML_INTERMEDIATE_GRAPH_EDGE_DESC positionIdsOffsetToAddOffsetEdge = {}; - positionIdsOffsetToAddOffsetEdge.FromNodeIndex = positionIdsRangeOpIndex; - positionIdsOffsetToAddOffsetEdge.FromNodeOutputIndex = 0; - positionIdsOffsetToAddOffsetEdge.ToNodeIndex = positionIdsAddOffsetOpIndex; - positionIdsOffsetToAddOffsetEdge.ToNodeInputIndex = 0; - intermediateEdges.push_back(positionIdsOffsetToAddOffsetEdge); - - DML_INTERMEDIATE_GRAPH_EDGE_DESC positionIdsAddOffsetToGatherCosEdge = {}; - positionIdsAddOffsetToGatherCosEdge.FromNodeIndex = positionIdsAddOffsetOpIndex; - positionIdsAddOffsetToGatherCosEdge.FromNodeOutputIndex = 0; - positionIdsAddOffsetToGatherCosEdge.ToNodeIndex = gatherCosOpIndex; - positionIdsAddOffsetToGatherCosEdge.ToNodeInputIndex = 1; - intermediateEdges.push_back(positionIdsAddOffsetToGatherCosEdge); - - DML_INTERMEDIATE_GRAPH_EDGE_DESC positionIdsAddOffsetToGatherSinEdge = {}; - positionIdsAddOffsetToGatherSinEdge.FromNodeIndex = 
positionIdsAddOffsetOpIndex; - positionIdsAddOffsetToGatherSinEdge.FromNodeOutputIndex = 0; - positionIdsAddOffsetToGatherSinEdge.ToNodeIndex = gatherSinOpIndex; - positionIdsAddOffsetToGatherSinEdge.ToNodeInputIndex = 1; - intermediateEdges.push_back(positionIdsAddOffsetToGatherSinEdge); - } - else - { - DML_INPUT_GRAPH_EDGE_DESC positionIdsToGatherCosEdge = {}; - positionIdsToGatherCosEdge.GraphInputIndex = positionIdsIndex; - positionIdsToGatherCosEdge.ToNodeIndex = gatherCosOpIndex; - positionIdsToGatherCosEdge.ToNodeInputIndex = 1; - inputEdges.push_back(positionIdsToGatherCosEdge); - - DML_INPUT_GRAPH_EDGE_DESC positionIdsToGatherSinEdge = {}; - positionIdsToGatherSinEdge.GraphInputIndex = positionIdsIndex; - positionIdsToGatherSinEdge.ToNodeIndex = gatherSinOpIndex; - positionIdsToGatherSinEdge.ToNodeInputIndex = 1; - inputEdges.push_back(positionIdsToGatherSinEdge); - } - - DML_INPUT_GRAPH_EDGE_DESC inputToSplitEdge = {}; - inputToSplitEdge.GraphInputIndex = inputDataIndex; - inputToSplitEdge.ToNodeIndex = splitInputOpIndex; - inputToSplitEdge.ToNodeInputIndex = 0; - inputEdges.push_back(inputToSplitEdge); - - DML_INPUT_GRAPH_EDGE_DESC cosToGatherEdge = {}; - cosToGatherEdge.GraphInputIndex = cosCacheIndex; - cosToGatherEdge.ToNodeIndex = gatherCosOpIndex; - cosToGatherEdge.ToNodeInputIndex = 0; - inputEdges.push_back(cosToGatherEdge); - - DML_INPUT_GRAPH_EDGE_DESC sinToGatherEdge = {}; - sinToGatherEdge.GraphInputIndex = sinCacheIndex; - sinToGatherEdge.ToNodeIndex = gatherSinOpIndex; - sinToGatherEdge.ToNodeInputIndex = 0; - inputEdges.push_back(sinToGatherEdge); - - DML_INTERMEDIATE_GRAPH_EDGE_DESC firstHalfDataToMulCosEdge = {}; - firstHalfDataToMulCosEdge.FromNodeIndex = splitInputOpIndex; - firstHalfDataToMulCosEdge.FromNodeOutputIndex = 0; - firstHalfDataToMulCosEdge.ToNodeIndex = mulCosFirstHalfOpIndex; - firstHalfDataToMulCosEdge.ToNodeInputIndex = 0; - intermediateEdges.push_back(firstHalfDataToMulCosEdge); - - DML_INTERMEDIATE_GRAPH_EDGE_DESC 
gatheredCosToMulFirstHalfDataEdge = {}; - gatheredCosToMulFirstHalfDataEdge.FromNodeIndex = gatherCosOpIndex; - gatheredCosToMulFirstHalfDataEdge.FromNodeOutputIndex = 0; - gatheredCosToMulFirstHalfDataEdge.ToNodeIndex = mulCosFirstHalfOpIndex; - gatheredCosToMulFirstHalfDataEdge.ToNodeInputIndex = 1; - intermediateEdges.push_back(gatheredCosToMulFirstHalfDataEdge); - - DML_INTERMEDIATE_GRAPH_EDGE_DESC secondHalfDataToNegateEdge = {}; - secondHalfDataToNegateEdge.FromNodeIndex = splitInputOpIndex; - secondHalfDataToNegateEdge.FromNodeOutputIndex = 1; - secondHalfDataToNegateEdge.ToNodeIndex = negateSecondHalfOpIndex; - secondHalfDataToNegateEdge.ToNodeInputIndex = 0; - intermediateEdges.push_back(secondHalfDataToNegateEdge); - - DML_INTERMEDIATE_GRAPH_EDGE_DESC negatedSecondHalfDataToMulSinEdge = {}; - negatedSecondHalfDataToMulSinEdge.FromNodeIndex = negateSecondHalfOpIndex; - negatedSecondHalfDataToMulSinEdge.FromNodeOutputIndex = 0; - negatedSecondHalfDataToMulSinEdge.ToNodeIndex = mulSinNegatedSecondHalfOpIndex; - negatedSecondHalfDataToMulSinEdge.ToNodeInputIndex = 0; - intermediateEdges.push_back(negatedSecondHalfDataToMulSinEdge); - - DML_INTERMEDIATE_GRAPH_EDGE_DESC gatheredSinToMulNegatedSecondHalfDataEdge = {}; - gatheredSinToMulNegatedSecondHalfDataEdge.FromNodeIndex = gatherSinOpIndex; - gatheredSinToMulNegatedSecondHalfDataEdge.FromNodeOutputIndex = 0; - gatheredSinToMulNegatedSecondHalfDataEdge.ToNodeIndex = mulSinNegatedSecondHalfOpIndex; - gatheredSinToMulNegatedSecondHalfDataEdge.ToNodeInputIndex = 1; - intermediateEdges.push_back(gatheredSinToMulNegatedSecondHalfDataEdge); - - DML_INTERMEDIATE_GRAPH_EDGE_DESC firstHalfCosMulToAddEdge = {}; - firstHalfCosMulToAddEdge.FromNodeIndex = mulCosFirstHalfOpIndex; - firstHalfCosMulToAddEdge.FromNodeOutputIndex = 0; - firstHalfCosMulToAddEdge.ToNodeIndex = addFirstHalfOpIndex; - firstHalfCosMulToAddEdge.ToNodeInputIndex = 0; - intermediateEdges.push_back(firstHalfCosMulToAddEdge); - - 
DML_INTERMEDIATE_GRAPH_EDGE_DESC secondHalfSinMulToAddEdge = {}; - secondHalfSinMulToAddEdge.FromNodeIndex = mulSinNegatedSecondHalfOpIndex; - secondHalfSinMulToAddEdge.FromNodeOutputIndex = 0; - secondHalfSinMulToAddEdge.ToNodeIndex = addFirstHalfOpIndex; - secondHalfSinMulToAddEdge.ToNodeInputIndex = 1; - intermediateEdges.push_back(secondHalfSinMulToAddEdge); - - DML_INTERMEDIATE_GRAPH_EDGE_DESC secondHalfDataToMulCosEdge = {}; - secondHalfDataToMulCosEdge.FromNodeIndex = splitInputOpIndex; - secondHalfDataToMulCosEdge.FromNodeOutputIndex = 1; - secondHalfDataToMulCosEdge.ToNodeIndex = mulCosSecondHalfOpIndex; - secondHalfDataToMulCosEdge.ToNodeInputIndex = 0; - intermediateEdges.push_back(secondHalfDataToMulCosEdge); - - DML_INTERMEDIATE_GRAPH_EDGE_DESC gatheredCosToMulSecondHalfDataEdge = {}; - gatheredCosToMulSecondHalfDataEdge.FromNodeIndex = gatherCosOpIndex; - gatheredCosToMulSecondHalfDataEdge.FromNodeOutputIndex = 0; - gatheredCosToMulSecondHalfDataEdge.ToNodeIndex = mulCosSecondHalfOpIndex; - gatheredCosToMulSecondHalfDataEdge.ToNodeInputIndex = 1; - intermediateEdges.push_back(gatheredCosToMulSecondHalfDataEdge); - - DML_INTERMEDIATE_GRAPH_EDGE_DESC firstHalfDataToMulSinEdge = {}; - firstHalfDataToMulSinEdge.FromNodeIndex = splitInputOpIndex; - firstHalfDataToMulSinEdge.FromNodeOutputIndex = 0; - firstHalfDataToMulSinEdge.ToNodeIndex = mulSinFirstHalfOpIndex; - firstHalfDataToMulSinEdge.ToNodeInputIndex = 0; - intermediateEdges.push_back(firstHalfDataToMulSinEdge); - - DML_INTERMEDIATE_GRAPH_EDGE_DESC gatheredSinToMulFirstHalfDataEdge = {}; - gatheredSinToMulFirstHalfDataEdge.FromNodeIndex = gatherSinOpIndex; - gatheredSinToMulFirstHalfDataEdge.FromNodeOutputIndex = 0; - gatheredSinToMulFirstHalfDataEdge.ToNodeIndex = mulSinFirstHalfOpIndex; - gatheredSinToMulFirstHalfDataEdge.ToNodeInputIndex = 1; - intermediateEdges.push_back(gatheredSinToMulFirstHalfDataEdge); - - DML_INTERMEDIATE_GRAPH_EDGE_DESC secondHalfCosMulToAddEdge = {}; - 
secondHalfCosMulToAddEdge.FromNodeIndex = mulCosSecondHalfOpIndex; - secondHalfCosMulToAddEdge.FromNodeOutputIndex = 0; - secondHalfCosMulToAddEdge.ToNodeIndex = addSecondHalfOpIndex; - secondHalfCosMulToAddEdge.ToNodeInputIndex = 0; - intermediateEdges.push_back(secondHalfCosMulToAddEdge); - - DML_INTERMEDIATE_GRAPH_EDGE_DESC firstHalfSinMulToAddEdge = {}; - firstHalfSinMulToAddEdge.FromNodeIndex = mulSinFirstHalfOpIndex; - firstHalfSinMulToAddEdge.FromNodeOutputIndex = 0; - firstHalfSinMulToAddEdge.ToNodeIndex = addSecondHalfOpIndex; - firstHalfSinMulToAddEdge.ToNodeInputIndex = 1; - intermediateEdges.push_back(firstHalfSinMulToAddEdge); - - DML_INTERMEDIATE_GRAPH_EDGE_DESC firstAddToJoinEdge = {}; - firstAddToJoinEdge.FromNodeIndex = addFirstHalfOpIndex; - firstAddToJoinEdge.FromNodeOutputIndex = 0; - firstAddToJoinEdge.ToNodeIndex = joinOpIndex; - firstAddToJoinEdge.ToNodeInputIndex = 0; - intermediateEdges.push_back(firstAddToJoinEdge); - - DML_INTERMEDIATE_GRAPH_EDGE_DESC secondAddToJoinEdge = {}; - secondAddToJoinEdge.FromNodeIndex = addSecondHalfOpIndex; - secondAddToJoinEdge.FromNodeOutputIndex = 0; - secondAddToJoinEdge.ToNodeIndex = joinOpIndex; - secondAddToJoinEdge.ToNodeInputIndex = 1; - intermediateEdges.push_back(secondAddToJoinEdge); - - DML_OUTPUT_GRAPH_EDGE_DESC joinToOutputEdge = {}; - joinToOutputEdge.FromNodeIndex = joinOpIndex; - joinToOutputEdge.FromNodeOutputIndex = 0; - joinToOutputEdge.GraphOutputIndex = 0; - outputEdges.push_back(joinToOutputEdge); - - MLOperatorGraphDesc operatorGraphDesc = {}; - operatorGraphDesc.inputEdgeCount = gsl::narrow_cast(inputEdges.size()); - operatorGraphDesc.inputEdges = inputEdges.data(); - operatorGraphDesc.intermediateEdgeCount = gsl::narrow_cast(intermediateEdges.size()); - operatorGraphDesc.intermediateEdges = intermediateEdges.data(); - operatorGraphDesc.outputEdgeCount = gsl::narrow_cast(outputEdges.size()); - operatorGraphDesc.outputEdges = outputEdges.data(); - operatorGraphDesc.nodeCount = 
gsl::narrow_cast(opDescs.size()); - operatorGraphDesc.nodesAsOpDesc = opDescs.data(); - - SetDmlOperatorGraphDesc(std::move(operatorGraphDesc), kernelInfo); - } }; DML_OP_DEFINE_CREATION_FUNCTION(RotaryEmbedding, DmlOperatorRotaryEmbedding); From eb5b540d8a7c7a4e7c3e92e96b560bee527d010f Mon Sep 17 00:00:00 2001 From: Patrice Vignola Date: Mon, 30 Oct 2023 01:10:57 -0700 Subject: [PATCH 5/8] Apply lintrunner --- onnxruntime/test/contrib_ops/rotary_embedding_op_test.cc | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/onnxruntime/test/contrib_ops/rotary_embedding_op_test.cc b/onnxruntime/test/contrib_ops/rotary_embedding_op_test.cc index e35e281fae4ac..55f01bf0d3f1d 100644 --- a/onnxruntime/test/contrib_ops/rotary_embedding_op_test.cc +++ b/onnxruntime/test/contrib_ops/rotary_embedding_op_test.cc @@ -113,7 +113,7 @@ static void RunTests(const std::vector& input_data, interleaved, false, /* use_fp16 */ false, /* disable_cpu */ - true, /* disable_cuda */ + true, /* disable_cuda */ true /* disable_dml */); // FP32 test for CUDA and DML @@ -146,8 +146,8 @@ static void RunTests(const std::vector& input_data, num_heads, max_sequence_length, interleaved, - true, /* use_fp16 */ - true, /* disable_cpu */ + true, /* use_fp16 */ + true, /* disable_cpu */ false, /* disable_cuda*/ false /* disable_dml */); } From 33b3952554a609a44a3bc4eb43dee31524905a0f Mon Sep 17 00:00:00 2001 From: Patrice Vignola Date: Mon, 30 Oct 2023 13:07:56 -0700 Subject: [PATCH 6/8] Update contrib ops docs --- docs/ContribOperators.md | 1 + docs/OperatorKernels.md | 1 + 2 files changed, 2 insertions(+) diff --git a/docs/ContribOperators.md b/docs/ContribOperators.md index 890403556cc47..0d2b472b6a779 100644 --- a/docs/ContribOperators.md +++ b/docs/ContribOperators.md @@ -2582,6 +2582,7 @@ This version of the operator has been available since version 1 of the 'com.micr Input B is stored as uint8_t with shape: [(N * K + 1) / 2]. 
Input absmax is stored in same type as original type of B(float32, float16) with shape like: [(N * K + block_size - 1) / block_size]. + #### Version This version of the operator has been available since version 1 of the 'com.microsoft' operator set. diff --git a/docs/OperatorKernels.md b/docs/OperatorKernels.md index bfb7716dc5cea..0ff42449894f6 100644 --- a/docs/OperatorKernels.md +++ b/docs/OperatorKernels.md @@ -1246,6 +1246,7 @@ Do not modify directly.* |QLinearSigmoid|*in* X:**T**
*in* X_scale:**tensor(float)**
*in* X_zero_point:**T**
*in* Y_scale:**tensor(float)**
*in* Y_zero_point:**T**
*out* Y:**T**|1+|**T** = tensor(int8), tensor(uint8)| |QuantizeLinear|*in* x:**T1**
*in* y_scale:**T1**
*in* y_zero_point:**T2**
*out* y:**T2**|1+|**T1** = tensor(float), tensor(float16), tensor(int32)
**T2** = tensor(int8), tensor(uint8)| |QuickGelu|*in* X:**T**
*out* Y:**T**|1+|**T** = tensor(float), tensor(float16)| +|RotaryEmbedding|*in* input:**T**
*in* position_ids:**M**
*in* cos_cache:**T**
*in* sin_cache:**T**
*out* output:**T**|1+|**M** = tensor(int64)
**T** = tensor(float), tensor(float16)| |SkipLayerNormalization|*in* input:**T**
*in* skip:**T**
*in* gamma:**T**
*in* beta:**T**
*in* bias:**T**
*out* output:**T**
*out* mean:**U**
*out* inv_std_var:**U**
*out* input_skip_bias_sum:**T**|1+|**T** = tensor(float), tensor(float16)| | | | | From baf62c5de1e57fcfc85d3ae5feb2100015d63dd0 Mon Sep 17 00:00:00 2001 From: Patrice Vignola Date: Fri, 3 Nov 2023 12:10:48 -0700 Subject: [PATCH 7/8] Reshape output --- .../src/Operators/DmlOperatorRotaryEmbedding.cpp | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorRotaryEmbedding.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorRotaryEmbedding.cpp index 44c29860e016c..234fddcb820c2 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorRotaryEmbedding.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorRotaryEmbedding.cpp @@ -78,6 +78,7 @@ class DmlOperatorRotaryEmbedding : public DmlOperator const bool interleaved = gsl::narrow_cast(kernelInfo.GetOptionalAttribute(AttrName::Interleaved, 0)); std::vector inputDescs = GetDmlInputDescs(); + std::vector outputDescs = GetDmlOutputDescs(); const MLOperatorTensorDataType dataType = kernelInfo.GetInputEdgeDescription(inputDataIndex).tensorDataType; // Split the input data into 2 equal parts @@ -212,9 +213,9 @@ class DmlOperatorRotaryEmbedding : public DmlOperator // Add the multiplied cos and sin values together DML_ELEMENT_WISE_ADD_OPERATOR_DESC addDesc{}; - addDesc.ATensor = &inputDataDmlTensorDesc; - addDesc.BTensor = &inputDataDmlTensorDesc; - addDesc.OutputTensor = &inputDataDmlTensorDesc; + addDesc.ATensor = &outputDescs[0]; + addDesc.BTensor = &outputDescs[0]; + addDesc.OutputTensor = &outputDescs[0]; const DML_OPERATOR_DESC addDmlDesc = {DML_OPERATOR_ELEMENT_WISE_ADD, &addDesc}; // Construct the graph From 8fc37171f518f646797de0d682e3da005d883d93 Mon Sep 17 00:00:00 2001 From: Patrice Vignola Date: Fri, 3 Nov 2023 22:27:27 -0700 Subject: [PATCH 8/8] Add copy node to keep the input shape intact --- 
.../Operators/DmlOperatorRotaryEmbedding.cpp | 47 ++++++++++++++----- 1 file changed, 35 insertions(+), 12 deletions(-) diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorRotaryEmbedding.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorRotaryEmbedding.cpp index 234fddcb820c2..30c339b845b36 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorRotaryEmbedding.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorRotaryEmbedding.cpp @@ -78,9 +78,22 @@ class DmlOperatorRotaryEmbedding : public DmlOperator const bool interleaved = gsl::narrow_cast(kernelInfo.GetOptionalAttribute(AttrName::Interleaved, 0)); std::vector inputDescs = GetDmlInputDescs(); - std::vector outputDescs = GetDmlOutputDescs(); const MLOperatorTensorDataType dataType = kernelInfo.GetInputEdgeDescription(inputDataIndex).tensorDataType; + // Splitting the hiddenSize into numHeads and headSize dimensions makes it easier for DML to handle + const std::array inputOutputShape = {batchSize, sequenceLength, numHeads, headSize}; + TensorDesc inputOutputTensorDesc = TensorDesc::ConstructDefaultTensorDesc(dataType, inputOutputShape); + const DML_TENSOR_DESC inputOutputDmlTensorDesc = inputOutputTensorDesc.GetDmlDesc(); + + // Copy the input to preserve its real input shape in the graph without reshaping it. This will disappear during DML's graph compilation phase. + DML_SCALE_BIAS scaleBias = {1.0f, 0.0f}; + + DML_ELEMENT_WISE_IDENTITY_OPERATOR_DESC copyInputDesc{}; + copyInputDesc.InputTensor = &inputOutputDmlTensorDesc; + copyInputDesc.OutputTensor = &inputOutputDmlTensorDesc; + copyInputDesc.ScaleBias = &scaleBias; + const DML_OPERATOR_DESC copyInputDmlDesc = {DML_OPERATOR_ELEMENT_WISE_IDENTITY, &copyInputDesc}; + // Split the input data into 2 equal parts const std::vector inputDataTensorShape = interleaved ? 
std::vector({batchSize, sequenceLength, numHeads, headSize / 2, 2}) @@ -213,9 +226,9 @@ class DmlOperatorRotaryEmbedding : public DmlOperator // Add the multiplied cos and sin values together DML_ELEMENT_WISE_ADD_OPERATOR_DESC addDesc{}; - addDesc.ATensor = &outputDescs[0]; - addDesc.BTensor = &outputDescs[0]; - addDesc.OutputTensor = &outputDescs[0]; + addDesc.ATensor = &inputOutputDmlTensorDesc; + addDesc.BTensor = &inputOutputDmlTensorDesc; + addDesc.OutputTensor = &inputOutputDmlTensorDesc; const DML_OPERATOR_DESC addDmlDesc = {DML_OPERATOR_ELEMENT_WISE_ADD, &addDesc}; // Construct the graph @@ -224,6 +237,7 @@ class DmlOperatorRotaryEmbedding : public DmlOperator std::vector outputEdges; std::vector opDescs = { + &copyInputDmlDesc, // Copy the input data to preserve the real input shape &splitInputDmlDesc, // Split the input data &gatherCosSinDmlDesc, // Gather cos &gatherCosSinDmlDesc, // Gather sin @@ -238,6 +252,7 @@ class DmlOperatorRotaryEmbedding : public DmlOperator enum NodeIndex : uint32_t { + copyInputOpIndex, splitInputOpIndex, gatherCosOpIndex, gatherSinOpIndex, @@ -301,11 +316,11 @@ class DmlOperatorRotaryEmbedding : public DmlOperator inputEdges.push_back(positionIdsToGatherSinEdge); } - DML_INPUT_GRAPH_EDGE_DESC inputToSplitEdge = {}; - inputToSplitEdge.GraphInputIndex = inputDataIndex; - inputToSplitEdge.ToNodeIndex = splitInputOpIndex; - inputToSplitEdge.ToNodeInputIndex = 0; - inputEdges.push_back(inputToSplitEdge); + DML_INPUT_GRAPH_EDGE_DESC inputToCopyInputEdge = {}; + inputToCopyInputEdge.GraphInputIndex = inputDataIndex; + inputToCopyInputEdge.ToNodeIndex = copyInputOpIndex; + inputToCopyInputEdge.ToNodeInputIndex = 0; + inputEdges.push_back(inputToCopyInputEdge); DML_INPUT_GRAPH_EDGE_DESC cosToGatherEdge = {}; cosToGatherEdge.GraphInputIndex = cosCacheIndex; @@ -319,11 +334,19 @@ class DmlOperatorRotaryEmbedding : public DmlOperator sinToGatherEdge.ToNodeInputIndex = 0; inputEdges.push_back(sinToGatherEdge); - DML_INPUT_GRAPH_EDGE_DESC 
nonRotatedDataToMulEdge = {}; - nonRotatedDataToMulEdge.GraphInputIndex = inputDataIndex; + DML_INTERMEDIATE_GRAPH_EDGE_DESC inputToSplitEdge = {}; + inputToSplitEdge.FromNodeIndex = copyInputOpIndex; + inputToSplitEdge.FromNodeOutputIndex = 0; + inputToSplitEdge.ToNodeIndex = splitInputOpIndex; + inputToSplitEdge.ToNodeInputIndex = 0; + intermediateEdges.push_back(inputToSplitEdge); + + DML_INTERMEDIATE_GRAPH_EDGE_DESC nonRotatedDataToMulEdge = {}; + nonRotatedDataToMulEdge.FromNodeIndex = copyInputOpIndex; + nonRotatedDataToMulEdge.FromNodeOutputIndex = 0; nonRotatedDataToMulEdge.ToNodeIndex = mulCosOpIndex; nonRotatedDataToMulEdge.ToNodeInputIndex = 0; - inputEdges.push_back(nonRotatedDataToMulEdge); + intermediateEdges.push_back(nonRotatedDataToMulEdge); DML_INTERMEDIATE_GRAPH_EDGE_DESC secondHalfDataToJoinEdge = {}; secondHalfDataToJoinEdge.FromNodeIndex = splitInputOpIndex;