diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorRotaryEmbedding.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorRotaryEmbedding.cpp index 0f15ebf342b3a..95d9644b4ca30 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorRotaryEmbedding.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorRotaryEmbedding.cpp @@ -25,6 +25,41 @@ // The interleaved version is very similar but instead of swapping 2 halves, we swap every pair of adjacent elements and we swap // the sign of every adjacent element. +// Here's a representation of what the graph looks like in DML, before getting fused together: +/* + Input CosCache PositionIds SinCache + | | | | + | | +--------+-----------+ | + Split | | | | + | | Gather Gather + +-------+ | | | + | | | | + | Identity----------+ | | + | | | | | + | | | | | + | --Split-- | | | + | \ / | +-----------------+ | + | \ / | | | + | \ / Mul | + | \ / | | + | X | | + | / \ | | + | / \ | | + | Join | | + | | | | + | | +---------------------------------------------------------+ + | | | | + | Mul | + | | | + | +-----+ +------+ + | | | + | Add + | | + +-------------+ | + | | + Join +*/ + namespace Dml { class DmlOperatorRotaryEmbedding : public DmlOperator @@ -56,25 +91,45 @@ class DmlOperatorRotaryEmbedding : public DmlOperator ML_CHECK_VALID_ARGUMENT(m_inputTensorDescs[positionIdsIndex].GetDimensionCount() == 4); ML_CHECK_VALID_ARGUMENT(m_inputTensorDescs[cosCacheIndex].GetDimensionCount() == 4); ML_CHECK_VALID_ARGUMENT(m_inputTensorDescs[sinCacheIndex].GetDimensionCount() == 4); - ML_CHECK_VALID_ARGUMENT(m_outputTensorDescs[0].GetDimensionCount() == 4); - ML_CHECK_VALID_ARGUMENT(m_inputTensorDescs[cosCacheIndex].GetSizes() == m_inputTensorDescs[sinCacheIndex].GetSizes()); - const uint32_t headSize = m_inputTensorDescs[cosCacheIndex].GetSizes().back() * 2; - // The last dimension of the data is the hidden 
size, so it must be divisible by the head size - ML_CHECK_VALID_ARGUMENT(m_inputTensorDescs[inputDataIndex].GetSizes().back() % headSize == 0); + uint32_t numHeads = gsl::narrow_cast(kernelInfo.GetOptionalAttribute(AttrName::NumHeads, 0)); + uint32_t rotaryEmbeddingDim = gsl::narrow_cast(kernelInfo.GetOptionalAttribute(AttrName::RotaryEmbeddingDim, 0)); - // We resize the data to be of shape [batchSize, sequenceLength, numHeads, headSize] const auto inputDataSizes = m_inputTensorDescs[inputDataIndex].GetSizes(); + const uint32_t hiddenSize = inputIs4D ? inputDataSizes[1] * inputDataSizes[3] : inputDataSizes.back(); + + const uint32_t headSize = numHeads == 0 + ? m_inputTensorDescs[cosCacheIndex].GetSizes().back() * 2 + : hiddenSize / numHeads; + + if (rotaryEmbeddingDim > 0) + { + ORT_ENFORCE(numHeads > 0, "num_heads must be provided if rotary_embedding_dim is specified"); + } + else + { + rotaryEmbeddingDim = headSize; + } + + if (numHeads == 0) + { + numHeads = hiddenSize / headSize; + } + else if (inputIs4D) + { + ORT_ENFORCE(numHeads == inputDataSizes[1], "When the input has 4 dimensions, num_heads must be 0 or have the same value as the second dimension of the input"); + } + const uint32_t batchSize = inputIs4D ? inputDataSizes[0] : inputDataSizes[1]; const uint32_t sequenceLength = inputDataSizes[2]; - const uint32_t numHeads = inputIs4D ? 
inputDataSizes[1] : inputDataSizes[3] / headSize; const auto cosCacheSizes = m_inputTensorDescs[cosCacheIndex].GetSizes(); const uint32_t maxSequenceLength = cosCacheSizes[cosCacheSizes.size() - 2]; - if (sequenceLength > maxSequenceLength) + const bool isPackedBatching = gsl::narrow_cast(kernelInfo.GetOptionalAttribute(AttrName::IsPackedBatching, 0)) == 1; + if (!isPackedBatching && sequenceLength > maxSequenceLength) { ORT_NOT_IMPLEMENTED("Updating cos_cache and sin_cache in RotaryEmbedding is not currently supported"); } @@ -84,64 +139,103 @@ class DmlOperatorRotaryEmbedding : public DmlOperator std::vector inputDescs = GetDmlInputDescs(); const MLOperatorTensorDataType dataType = kernelInfo.GetInputEdgeDescription(inputDataIndex).tensorDataType; - const std::array inputOutputShape = {batchSize, sequenceLength, numHeads, headSize}; + // We view the data as [batchSize, numHeads, sequenceLength, headSize] when the input is 4D, or as [batchSize, sequenceLength, numHeads, headSize] otherwise + const std::array inputOutputShape = inputIs4D + ? std::array({batchSize, numHeads, sequenceLength, headSize}) + : std::array({batchSize, sequenceLength, numHeads, headSize}); + + const std::array splitInputOutputShape1 = inputIs4D + ? std::array({batchSize, numHeads, sequenceLength, rotaryEmbeddingDim}) + : std::array({batchSize, sequenceLength, numHeads, rotaryEmbeddingDim}); + + const std::array splitInputOutputShape2 = inputIs4D + ? 
std::array({batchSize, numHeads, sequenceLength, headSize - rotaryEmbeddingDim}) + : std::array({batchSize, sequenceLength, numHeads, headSize - rotaryEmbeddingDim}); + TensorDesc inputOutputTensorDesc = TensorDesc::ConstructDefaultTensorDesc(dataType, inputOutputShape); - TensorDesc stridedInputOutputTensorDesc = TensorDesc::ConstructDefaultTensorDesc(dataType, inputOutputShape); + TensorDesc splitInputOutputTensorDesc1 = TensorDesc::ConstructDefaultTensorDesc(dataType, splitInputOutputShape1); + TensorDesc splitInputOutputTensorDesc2 = TensorDesc::ConstructDefaultTensorDesc(dataType, splitInputOutputShape2); - if (inputIs4D) + // Split the input to perform the rotary embedding only on a subregion of the tensor if needed. The split inputs + // will be joined back together at the end. + const DML_TENSOR_DESC inputOutputDmlTensorDesc = inputOutputTensorDesc.GetDmlDesc(); + + std::array splitTensorDescs = { + splitInputOutputTensorDesc1.GetDmlDesc(), + splitInputOutputTensorDesc2.GetDmlDesc(), + }; + + DML_SPLIT_OPERATOR_DESC splitInputOperatorDesc{}; + DML_OPERATOR_DESC splitInputDmlOperatorDesc{}; + if (headSize != rotaryEmbeddingDim) { - const std::array inputOutputStrides = {headSize * numHeads * sequenceLength, headSize, sequenceLength * headSize, 1}; - stridedInputOutputTensorDesc.SetStrides(inputOutputStrides); + splitInputOperatorDesc.InputTensor = &inputOutputDmlTensorDesc; + splitInputOperatorDesc.OutputCount = gsl::narrow_cast(splitTensorDescs.size()); + splitInputOperatorDesc.OutputTensors = splitTensorDescs.data(); + splitInputOperatorDesc.Axis = gsl::narrow_cast(inputOutputShape.size()) - 1; + splitInputDmlOperatorDesc.Type = DML_OPERATOR_SPLIT; + splitInputDmlOperatorDesc.Desc = &splitInputOperatorDesc; } - const DML_TENSOR_DESC inputOutputDmlTensorDesc = inputOutputTensorDesc.GetDmlDesc(); - const DML_TENSOR_DESC stridedInputOutputDmlTensorDesc = stridedInputOutputTensorDesc.GetDmlDesc(); - - // Copy the input to preserve its real input shape in the 
graph without reshaping it. This will disappear during DML's graph compilation phase. + // Copy the partial input to preserve its real input shape in the graph without reshaping it. This will disappear during DML's graph compilation phase. DML_SCALE_BIAS scaleBias = {1.0f, 0.0f}; + const std::array partialInputOutputShape = {batchSize, sequenceLength, numHeads, rotaryEmbeddingDim}; + TensorDesc partialStridedInputOutputTensorDesc = TensorDesc::ConstructDefaultTensorDesc(dataType, partialInputOutputShape); + TensorDesc partialInputOutputTensorDesc = TensorDesc::ConstructDefaultTensorDesc(dataType, partialInputOutputShape); + + if (inputIs4D) + { + const std::array partialInputOutputStrides = {rotaryEmbeddingDim * numHeads * sequenceLength, rotaryEmbeddingDim, sequenceLength * rotaryEmbeddingDim, 1}; + partialStridedInputOutputTensorDesc.SetStrides(partialInputOutputStrides); + } + + const DML_TENSOR_DESC partialStridedInputOutputDmlTensorDesc = partialStridedInputOutputTensorDesc.GetDmlDesc(); + const DML_TENSOR_DESC partialInputOutputDmlTensorDesc = partialInputOutputTensorDesc.GetDmlDesc(); + DML_ELEMENT_WISE_IDENTITY_OPERATOR_DESC copyInputDesc{}; - copyInputDesc.InputTensor = &stridedInputOutputDmlTensorDesc; - copyInputDesc.OutputTensor = &inputOutputDmlTensorDesc; + copyInputDesc.InputTensor = &partialStridedInputOutputDmlTensorDesc; + copyInputDesc.OutputTensor = &partialInputOutputDmlTensorDesc; copyInputDesc.ScaleBias = &scaleBias; const DML_OPERATOR_DESC copyInputDmlDesc = {DML_OPERATOR_ELEMENT_WISE_IDENTITY, ©InputDesc}; + const uint32_t halfRoraryEmbeddingDim = rotaryEmbeddingDim / 2; + // Split the input data into 2 equal parts - const std::vector inputDataTensorShape = interleaved - ? std::vector({batchSize, sequenceLength, numHeads, headSize / 2, 2}) - : std::vector({batchSize, sequenceLength, numHeads, 2, headSize / 2}); + const std::vector partialInputDataTensorShape = interleaved + ? 
std::vector({batchSize, sequenceLength, numHeads, rotaryEmbeddingDim / 2, 2}) + : std::vector({batchSize, sequenceLength, numHeads, 2, rotaryEmbeddingDim / 2}); const std::vector splitInputDataTensorShape = interleaved - ? std::vector({batchSize, sequenceLength, numHeads, headSize / 2, 1}) - : std::vector({batchSize, sequenceLength, numHeads, 1, headSize / 2}); + ? std::vector({batchSize, sequenceLength, numHeads, rotaryEmbeddingDim / 2, 1}) + : std::vector({batchSize, sequenceLength, numHeads, 1, rotaryEmbeddingDim / 2}); - TensorDesc inputDataTensorDesc = TensorDesc::ConstructDefaultTensorDesc(dataType, inputDataTensorShape); + TensorDesc partialInputDataTensorDesc = TensorDesc::ConstructDefaultTensorDesc(dataType, partialInputDataTensorShape); + const DML_TENSOR_DESC partialInputDataDmlTensorDesc = partialInputDataTensorDesc.GetDmlDesc(); - const DML_TENSOR_DESC inputDataDmlTensorDesc = inputDataTensorDesc.GetDmlDesc(); - - TensorDesc joinedDataTensorDesc = TensorDesc::ConstructDefaultTensorDesc(dataType, inputDataTensorShape); + TensorDesc joinedDataTensorDesc = TensorDesc::ConstructDefaultTensorDesc(dataType, partialInputDataTensorShape); const DML_TENSOR_DESC joinedDataDmlTensorDesc = joinedDataTensorDesc.GetDmlDesc(); TensorDesc splitInputDataTensorDesc = TensorDesc::ConstructDefaultTensorDesc(dataType, splitInputDataTensorShape); const std::array splitInputDataDmlTensorDescs = {splitInputDataTensorDesc.GetDmlDesc(), splitInputDataTensorDesc.GetDmlDesc()}; - DML_SPLIT_OPERATOR_DESC splitInputDesc{}; - splitInputDesc.InputTensor = &inputDataDmlTensorDesc; - splitInputDesc.OutputTensors = splitInputDataDmlTensorDescs.data(); - splitInputDesc.OutputCount = gsl::narrow_cast(splitInputDataDmlTensorDescs.size()); - splitInputDesc.Axis = interleaved + DML_SPLIT_OPERATOR_DESC splitPartialInputDesc{}; + splitPartialInputDesc.InputTensor = &partialInputDataDmlTensorDesc; + splitPartialInputDesc.OutputTensors = splitInputDataDmlTensorDescs.data(); + 
splitPartialInputDesc.OutputCount = gsl::narrow_cast(splitInputDataDmlTensorDescs.size()); + splitPartialInputDesc.Axis = interleaved ? gsl::narrow_cast(splitInputDataTensorShape.size()) - 1 : gsl::narrow_cast(splitInputDataTensorShape.size()) - 2; - const DML_OPERATOR_DESC splitInputDmlDesc = {DML_OPERATOR_SPLIT, &splitInputDesc}; + const DML_OPERATOR_DESC splitPartialInputDmlDesc = {DML_OPERATOR_SPLIT, &splitPartialInputDesc}; // Swap the 2 halves and join them together - DML_JOIN_OPERATOR_DESC joinInputDesc{}; - joinInputDesc.InputTensors = splitInputDataDmlTensorDescs.data(); - joinInputDesc.OutputTensor = &joinedDataDmlTensorDesc; - joinInputDesc.Axis = splitInputDesc.Axis; - joinInputDesc.InputCount = gsl::narrow_cast(splitInputDataDmlTensorDescs.size()); - const DML_OPERATOR_DESC joinInputDmlDesc = {DML_OPERATOR_JOIN, &joinInputDesc}; + DML_JOIN_OPERATOR_DESC joinPartialInputDesc{}; + joinPartialInputDesc.InputTensors = splitInputDataDmlTensorDescs.data(); + joinPartialInputDesc.OutputTensor = &joinedDataDmlTensorDesc; + joinPartialInputDesc.Axis = splitPartialInputDesc.Axis; + joinPartialInputDesc.InputCount = gsl::narrow_cast(splitInputDataDmlTensorDescs.size()); + const DML_OPERATOR_DESC joinPartialInputDmlDesc = {DML_OPERATOR_JOIN, &joinPartialInputDesc}; // We generate a sequence from 0 to sequenceLength and add the offset to it const std::array positionIdsRangeShape = {1, 1, 1, sequenceLength}; @@ -177,7 +271,7 @@ class DmlOperatorRotaryEmbedding : public DmlOperator const DML_OPERATOR_DESC positionIdsAddOffsetDmlDesc = {DML_OPERATOR_ELEMENT_WISE_ADD, &positionIdsAddOffset}; // Gather the cos/sin values based on the position ids - const std::array gatheredCosSinShape = {1, batchSize, sequenceLength, headSize / 2}; + const std::array gatheredCosSinShape = {1, batchSize, sequenceLength, rotaryEmbeddingDim / 2}; TensorDesc gatheredCosSinTensorDesc = TensorDesc::ConstructDefaultTensorDesc(dataType, gatheredCosSinShape); const DML_TENSOR_DESC 
gatheredCosSinDmlTensorDesc = gatheredCosSinTensorDesc.GetDmlDesc(); @@ -191,9 +285,9 @@ class DmlOperatorRotaryEmbedding : public DmlOperator // After gathering cos/sin, reshape and broadcast them to match the number of heads of the input data const std::vector reshapedCosSinShape = interleaved - ? std::vector({batchSize, sequenceLength, 1, headSize / 2, 1}) - : std::vector({batchSize, sequenceLength, 1, 1, headSize / 2}); - TensorDesc broadcastedCosSinTensorDesc = TensorDesc::ConstructBroadcastedTensorDesc(dataType, inputDataTensorShape, reshapedCosSinShape); + ? std::vector({batchSize, sequenceLength, 1, rotaryEmbeddingDim / 2, 1}) + : std::vector({batchSize, sequenceLength, 1, 1, rotaryEmbeddingDim / 2}); + TensorDesc broadcastedCosSinTensorDesc = TensorDesc::ConstructBroadcastedTensorDesc(dataType, partialInputDataTensorShape, reshapedCosSinShape); const DML_TENSOR_DESC broadcastedCosSinDmlTensorDesc = broadcastedCosSinTensorDesc.GetDmlDesc(); // Create a vector that contains the sign values {-1, 1} @@ -224,7 +318,7 @@ class DmlOperatorRotaryEmbedding : public DmlOperator const std::vector reshapedSignShape = interleaved ? 
std::vector({1, 1, 1, 1, 2}) : std::vector({1, 1, 1, 2, 1}); - TensorDesc broadcastedSignCosSinTensorDesc = TensorDesc::ConstructBroadcastedTensorDesc(dataType, inputDataTensorShape, reshapedSignShape); + TensorDesc broadcastedSignCosSinTensorDesc = TensorDesc::ConstructBroadcastedTensorDesc(dataType, partialInputDataTensorShape, reshapedSignShape); const DML_TENSOR_DESC broadcastedSignDmlTensorDesc = broadcastedSignCosSinTensorDesc.GetDmlDesc(); DML_ELEMENT_WISE_MULTIPLY_OPERATOR_DESC mulSignDesc{}; @@ -242,11 +336,23 @@ class DmlOperatorRotaryEmbedding : public DmlOperator // Add the multiplied cos and sin values together DML_ELEMENT_WISE_ADD_OPERATOR_DESC addDesc{}; - addDesc.ATensor = &inputOutputDmlTensorDesc; - addDesc.BTensor = &inputOutputDmlTensorDesc; - addDesc.OutputTensor = &stridedInputOutputDmlTensorDesc; + addDesc.ATensor = &partialInputOutputDmlTensorDesc; + addDesc.BTensor = &partialInputOutputDmlTensorDesc; + addDesc.OutputTensor = &partialStridedInputOutputDmlTensorDesc; const DML_OPERATOR_DESC addDmlDesc = {DML_OPERATOR_ELEMENT_WISE_ADD, &addDesc}; + DML_JOIN_OPERATOR_DESC joinOutputOperatorDesc{}; + DML_OPERATOR_DESC joinOutputDmlOperatorDesc{}; + if (headSize != rotaryEmbeddingDim) + { + joinOutputOperatorDesc.InputCount = gsl::narrow_cast(splitTensorDescs.size()); + joinOutputOperatorDesc.InputTensors = splitTensorDescs.data(); + joinOutputOperatorDesc.OutputTensor = &inputOutputDmlTensorDesc; + joinOutputOperatorDesc.Axis = gsl::narrow_cast(inputOutputShape.size()) - 1; + joinOutputDmlOperatorDesc.Type = DML_OPERATOR_JOIN; + joinOutputDmlOperatorDesc.Desc = &joinOutputOperatorDesc; + } + // Construct the graph std::vector inputEdges; std::vector intermediateEdges; @@ -254,12 +360,12 @@ class DmlOperatorRotaryEmbedding : public DmlOperator std::vector opDescs = { ©InputDmlDesc, // Copy the input data to preseve the real input shape - &splitInputDmlDesc, // Split the input data + &splitPartialInputDmlDesc, // Split the input data 
&gatherCosSinDmlDesc, // Gather cos &gatherCosSinDmlDesc, // Gather sin &signRangeDmlDesc, // Generate the signs - &joinInputDmlDesc, // Join the split data + &joinPartialInputDmlDesc, // Join the split data &mulCosSinDmlDesc, // Multiply cos with the non-rotated data &mulCosSinDmlDesc, // Multiply sin with the rotated data &mulSignDmlDesc, // Multiply the sign with the rotated data @@ -269,12 +375,12 @@ class DmlOperatorRotaryEmbedding : public DmlOperator enum NodeIndex : uint32_t { copyInputOpIndex, - splitInputOpIndex, + splitPartialInputOpIndex, gatherCosOpIndex, gatherSinOpIndex, signRangeOpIndex, - joinInputOpIndex, + joinPartialInputOpIndex, mulCosOpIndex, mulSinOpIndex, mulSignOpIndex, @@ -285,6 +391,9 @@ class DmlOperatorRotaryEmbedding : public DmlOperator positionIdsAddOffsetOpIndex, }; + uint32_t splitInputOpIndex = positionIdsIsOffset ? positionIdsAddOffsetOpIndex + 1 : addOpIndex + 1; + uint32_t joinOutputOpIndex = splitInputOpIndex + 1; + if (positionIdsIsOffset) { opDescs.push_back(&positionIdsRangeDmlDesc); @@ -332,11 +441,32 @@ class DmlOperatorRotaryEmbedding : public DmlOperator inputEdges.push_back(positionIdsToGatherSinEdge); } - DML_INPUT_GRAPH_EDGE_DESC inputToCopyInputEdge = {}; - inputToCopyInputEdge.GraphInputIndex = inputDataIndex; - inputToCopyInputEdge.ToNodeIndex = copyInputOpIndex; - inputToCopyInputEdge.ToNodeInputIndex = 0; - inputEdges.push_back(inputToCopyInputEdge); + if (splitInputDmlOperatorDesc.Desc) + { + opDescs.push_back(&splitInputDmlOperatorDesc); + opDescs.push_back(&joinOutputDmlOperatorDesc); + + DML_INPUT_GRAPH_EDGE_DESC inputToSplitInputEdge = {}; + inputToSplitInputEdge.GraphInputIndex = inputDataIndex; + inputToSplitInputEdge.ToNodeIndex = splitInputOpIndex; + inputToSplitInputEdge.ToNodeInputIndex = 0; + inputEdges.push_back(inputToSplitInputEdge); + + DML_INTERMEDIATE_GRAPH_EDGE_DESC partialInputToCopyInputEdge = {}; + partialInputToCopyInputEdge.FromNodeIndex = splitInputOpIndex; + 
partialInputToCopyInputEdge.FromNodeOutputIndex = 0; + partialInputToCopyInputEdge.ToNodeIndex = copyInputOpIndex; + partialInputToCopyInputEdge.ToNodeInputIndex = 0; + intermediateEdges.push_back(partialInputToCopyInputEdge); + } + else + { + DML_INPUT_GRAPH_EDGE_DESC inputToCopyInputEdge = {}; + inputToCopyInputEdge.GraphInputIndex = inputDataIndex; + inputToCopyInputEdge.ToNodeIndex = copyInputOpIndex; + inputToCopyInputEdge.ToNodeInputIndex = 0; + inputEdges.push_back(inputToCopyInputEdge); + } DML_INPUT_GRAPH_EDGE_DESC cosToGatherEdge = {}; cosToGatherEdge.GraphInputIndex = cosCacheIndex; @@ -353,7 +483,7 @@ class DmlOperatorRotaryEmbedding : public DmlOperator DML_INTERMEDIATE_GRAPH_EDGE_DESC inputToSplitEdge = {}; inputToSplitEdge.FromNodeIndex = copyInputOpIndex; inputToSplitEdge.FromNodeOutputIndex = 0; - inputToSplitEdge.ToNodeIndex = splitInputOpIndex; + inputToSplitEdge.ToNodeIndex = splitPartialInputOpIndex; inputToSplitEdge.ToNodeInputIndex = 0; intermediateEdges.push_back(inputToSplitEdge); @@ -365,16 +495,16 @@ class DmlOperatorRotaryEmbedding : public DmlOperator intermediateEdges.push_back(nonRotatedDataToMulEdge); DML_INTERMEDIATE_GRAPH_EDGE_DESC secondHalfDataToJoinEdge = {}; - secondHalfDataToJoinEdge.FromNodeIndex = splitInputOpIndex; + secondHalfDataToJoinEdge.FromNodeIndex = splitPartialInputOpIndex; secondHalfDataToJoinEdge.FromNodeOutputIndex = 1; - secondHalfDataToJoinEdge.ToNodeIndex = joinInputOpIndex; + secondHalfDataToJoinEdge.ToNodeIndex = joinPartialInputOpIndex; secondHalfDataToJoinEdge.ToNodeInputIndex = 0; intermediateEdges.push_back(secondHalfDataToJoinEdge); DML_INTERMEDIATE_GRAPH_EDGE_DESC firstHalfDataToJoinEdge = {}; - firstHalfDataToJoinEdge.FromNodeIndex = splitInputOpIndex; + firstHalfDataToJoinEdge.FromNodeIndex = splitPartialInputOpIndex; firstHalfDataToJoinEdge.FromNodeOutputIndex = 0; - firstHalfDataToJoinEdge.ToNodeIndex = joinInputOpIndex; + firstHalfDataToJoinEdge.ToNodeIndex = joinPartialInputOpIndex; 
firstHalfDataToJoinEdge.ToNodeInputIndex = 1; intermediateEdges.push_back(firstHalfDataToJoinEdge); @@ -386,7 +516,7 @@ class DmlOperatorRotaryEmbedding : public DmlOperator intermediateEdges.push_back(cosToMulEdge); DML_INTERMEDIATE_GRAPH_EDGE_DESC rotatedDataToMulEdge = {}; - rotatedDataToMulEdge.FromNodeIndex = joinInputOpIndex; + rotatedDataToMulEdge.FromNodeIndex = joinPartialInputOpIndex; rotatedDataToMulEdge.FromNodeOutputIndex = 0; rotatedDataToMulEdge.ToNodeIndex = mulSinOpIndex; rotatedDataToMulEdge.ToNodeInputIndex = 0; @@ -427,11 +557,36 @@ class DmlOperatorRotaryEmbedding : public DmlOperator rotatedSinToAddEdge.ToNodeInputIndex = 1; intermediateEdges.push_back(rotatedSinToAddEdge); - DML_OUTPUT_GRAPH_EDGE_DESC addToOutputEdge = {}; - addToOutputEdge.FromNodeIndex = addOpIndex; - addToOutputEdge.FromNodeOutputIndex = 0; - addToOutputEdge.GraphOutputIndex = 0; - outputEdges.push_back(addToOutputEdge); + if (splitInputDmlOperatorDesc.Desc) + { + DML_INTERMEDIATE_GRAPH_EDGE_DESC addToJoinOutputEdge = {}; + addToJoinOutputEdge.FromNodeIndex = addOpIndex; + addToJoinOutputEdge.FromNodeOutputIndex = 0; + addToJoinOutputEdge.ToNodeIndex = joinOutputOpIndex; + addToJoinOutputEdge.ToNodeInputIndex = 0; + intermediateEdges.push_back(addToJoinOutputEdge); + + DML_INTERMEDIATE_GRAPH_EDGE_DESC remainingInputToJoinOutputEdge = {}; + remainingInputToJoinOutputEdge.FromNodeIndex = splitInputOpIndex; + remainingInputToJoinOutputEdge.FromNodeOutputIndex = 1; + remainingInputToJoinOutputEdge.ToNodeIndex = joinOutputOpIndex; + remainingInputToJoinOutputEdge.ToNodeInputIndex = 1; + intermediateEdges.push_back(remainingInputToJoinOutputEdge); + + DML_OUTPUT_GRAPH_EDGE_DESC joinOutputToOutputEdge = {}; + joinOutputToOutputEdge.FromNodeIndex = joinOutputOpIndex; + joinOutputToOutputEdge.FromNodeOutputIndex = 0; + joinOutputToOutputEdge.GraphOutputIndex = 0; + outputEdges.push_back(joinOutputToOutputEdge); + } + else + { + DML_OUTPUT_GRAPH_EDGE_DESC addToOutputEdge = {}; + 
addToOutputEdge.FromNodeIndex = addOpIndex; + addToOutputEdge.FromNodeOutputIndex = 0; + addToOutputEdge.GraphOutputIndex = 0; + outputEdges.push_back(addToOutputEdge); + } MLOperatorGraphDesc operatorGraphDesc = {}; operatorGraphDesc.inputEdgeCount = gsl::narrow_cast(inputEdges.size()); diff --git a/onnxruntime/core/providers/dml/OperatorAuthorHelper/Attributes.h b/onnxruntime/core/providers/dml/OperatorAuthorHelper/Attributes.h index 0c5739554b800..3d23fb6206479 100644 --- a/onnxruntime/core/providers/dml/OperatorAuthorHelper/Attributes.h +++ b/onnxruntime/core/providers/dml/OperatorAuthorHelper/Attributes.h @@ -130,6 +130,8 @@ namespace AttrName static constexpr const char* UppercaseN = "N"; static constexpr const char* UppercaseK = "K"; static constexpr const char* MatMulNBitsBlockSize = "block_size"; + static constexpr const char* RotaryEmbeddingDim = "rotary_embedding_dim"; + static constexpr const char* IsPackedBatching = "is_packed_batching"; } // namespace AttrName diff --git a/onnxruntime/test/contrib_ops/rotary_embedding_op_test.cc b/onnxruntime/test/contrib_ops/rotary_embedding_op_test.cc index 89552da58b938..8675a997d29a1 100644 --- a/onnxruntime/test/contrib_ops/rotary_embedding_op_test.cc +++ b/onnxruntime/test/contrib_ops/rotary_embedding_op_test.cc @@ -135,8 +135,7 @@ static void RunTests(const std::vector& input_data, int max_sequence_length = 0, int64_t interleaved = 0, int64_t is_packed_batching = 0, - bool use_float16 = true, - bool disable_dml = false) { + bool use_float16 = true) { // FP32 test for CPU RunTest(input_data, position_ids, @@ -173,7 +172,7 @@ static void RunTests(const std::vector& input_data, TensorType::kFloat, false, /* disable_cpu */ false, /* disable_cuda */ - disable_dml || false /* disable_dml */); + false /* disable_dml */); // FP16 test for CUDA and DML if (use_float16) { @@ -193,7 +192,7 @@ static void RunTests(const std::vector& input_data, TensorType::kFloat16, true, /* disable_cpu */ false, /* disable_cuda*/ - 
disable_dml || false /* disable_dml */); + false /* disable_dml */); // RunTest(input_data, // position_ids, @@ -743,9 +742,8 @@ TEST(RotaryEmbeddingTest, RotaryEmbedding_CustomRotaryDim_SmallData_Phi) { num_heads, max_sequence_length, interleaved, - 0, // is_packed_batching - true, /*use_fp16*/ - true /*disable_dml*/); + 0, // is_packed_batching + true /*use_fp16*/); } TEST(RotaryEmbeddingTest, RotaryEmbedding_CustomRotaryDim_SmallData_Phi_Packed_Batching) { @@ -785,9 +783,8 @@ TEST(RotaryEmbeddingTest, RotaryEmbedding_CustomRotaryDim_SmallData_Phi_Packed_B num_heads, max_sequence_length, interleaved, - 1, // is_packed_batching - true, /*use_fp16*/ - true /*disable_dml*/); + 1, // is_packed_batching + true /*use_fp16*/); } } // namespace test