Skip to content

Commit

Permalink
[DML EP] Support split hidden size for RotaryEmbedding (#18852)
Browse files Browse the repository at this point in the history
RotaryEmbedding now also supports the 4D `[batchSize, numHeads, sequenceLength,
headSize]` format for its input (in addition to the 3D `[batchSize, sequenceLength,
hiddenSize]` format), which is used in Mistral.
  • Loading branch information
PatriceVignola authored Feb 13, 2024
1 parent a622710 commit 61e07a4
Showing 1 changed file with 26 additions and 10 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,10 @@ class DmlOperatorRotaryEmbedding : public DmlOperator
ML_CHECK_VALID_ARGUMENT(kernelInfo.GetInputCount() == 4);
ML_CHECK_VALID_ARGUMENT(kernelInfo.GetOutputCount() == 1);

// When the input is 4D, it has the shape [batchSize, numHeads, sequenceLength, headSize]. Otherwise,
// it has the shape [batchSize, sequenceLength, hiddenSize], where hiddenSize == numHeads * headSize.
const bool inputIs4D = kernelInfo.GetInputTensorDimensionCount(inputDataIndex) == 4;

// When positionIds is a 1D tensor (rather than a per-token tensor), it represents the start offset for each sequence
const bool positionIdsIsOffset = kernelInfo.GetInputTensorDimensionCount(positionIdsIndex) == 1;

Expand All @@ -63,9 +67,9 @@ class DmlOperatorRotaryEmbedding : public DmlOperator

// We resize the data to be of shape [batchSize, sequenceLength, numHeads, headSize]
const auto inputDataSizes = m_inputTensorDescs[inputDataIndex].GetSizes();
const uint32_t batchSize = inputDataSizes[1];
const uint32_t batchSize = inputIs4D ? inputDataSizes[0] : inputDataSizes[1];
const uint32_t sequenceLength = inputDataSizes[2];
const uint32_t numHeads = inputDataSizes[3] / headSize;
const uint32_t numHeads = inputIs4D ? inputDataSizes[1] : inputDataSizes[3] / headSize;

const auto cosCacheSizes = m_inputTensorDescs[cosCacheIndex].GetSizes();
const uint32_t maxSequenceLength = cosCacheSizes[cosCacheSizes.size() - 2];
Expand All @@ -80,16 +84,24 @@ class DmlOperatorRotaryEmbedding : public DmlOperator
std::vector<DML_TENSOR_DESC> inputDescs = GetDmlInputDescs();
const MLOperatorTensorDataType dataType = kernelInfo.GetInputEdgeDescription(inputDataIndex).tensorDataType;

// Splitting the hiddenSize into numHeads and headSize dimensions makes it easier for DML to handle
const std::array<uint32_t, 4> inputOutputShape = {batchSize, sequenceLength, numHeads, headSize};
TensorDesc inputOutputTensorDesc = TensorDesc::ConstructDefaultTensorDesc(dataType, inputOutputShape);
TensorDesc stridedInputOutputTensorDesc = TensorDesc::ConstructDefaultTensorDesc(dataType, inputOutputShape);

if (inputIs4D)
{
const std::array<uint32_t, 4> inputOutputStrides = {headSize * numHeads * sequenceLength, headSize, sequenceLength * headSize, 1};
stridedInputOutputTensorDesc.SetStrides(inputOutputStrides);
}

const DML_TENSOR_DESC inputOutputDmlTensorDesc = inputOutputTensorDesc.GetDmlDesc();
const DML_TENSOR_DESC stridedInputOutputDmlTensorDesc = stridedInputOutputTensorDesc.GetDmlDesc();

// Copy the input to preserve its real input shape in the graph without reshaping it. This will disappear during DML's graph compilation phase.
DML_SCALE_BIAS scaleBias = {1.0f, 0.0f};

DML_ELEMENT_WISE_IDENTITY_OPERATOR_DESC copyInputDesc{};
copyInputDesc.InputTensor = &inputOutputDmlTensorDesc;
copyInputDesc.InputTensor = &stridedInputOutputDmlTensorDesc;
copyInputDesc.OutputTensor = &inputOutputDmlTensorDesc;
copyInputDesc.ScaleBias = &scaleBias;
const DML_OPERATOR_DESC copyInputDmlDesc = {DML_OPERATOR_ELEMENT_WISE_IDENTITY, &copyInputDesc};
Expand All @@ -104,8 +116,12 @@ class DmlOperatorRotaryEmbedding : public DmlOperator
: std::vector<uint32_t>({batchSize, sequenceLength, numHeads, 1, headSize / 2});

TensorDesc inputDataTensorDesc = TensorDesc::ConstructDefaultTensorDesc(dataType, inputDataTensorShape);

const DML_TENSOR_DESC inputDataDmlTensorDesc = inputDataTensorDesc.GetDmlDesc();

TensorDesc joinedDataTensorDesc = TensorDesc::ConstructDefaultTensorDesc(dataType, inputDataTensorShape);
const DML_TENSOR_DESC joinedDataDmlTensorDesc = joinedDataTensorDesc.GetDmlDesc();

TensorDesc splitInputDataTensorDesc = TensorDesc::ConstructDefaultTensorDesc(dataType, splitInputDataTensorShape);
const std::array<DML_TENSOR_DESC, 2> splitInputDataDmlTensorDescs = {splitInputDataTensorDesc.GetDmlDesc(), splitInputDataTensorDesc.GetDmlDesc()};

Expand All @@ -122,7 +138,7 @@ class DmlOperatorRotaryEmbedding : public DmlOperator
// Swap the 2 halves and join them together
DML_JOIN_OPERATOR_DESC joinInputDesc{};
joinInputDesc.InputTensors = splitInputDataDmlTensorDescs.data();
joinInputDesc.OutputTensor = &inputDataDmlTensorDesc;
joinInputDesc.OutputTensor = &joinedDataDmlTensorDesc;
joinInputDesc.Axis = splitInputDesc.Axis;
joinInputDesc.InputCount = gsl::narrow_cast<uint32_t>(splitInputDataDmlTensorDescs.size());
const DML_OPERATOR_DESC joinInputDmlDesc = {DML_OPERATOR_JOIN, &joinInputDesc};
Expand Down Expand Up @@ -212,23 +228,23 @@ class DmlOperatorRotaryEmbedding : public DmlOperator
const DML_TENSOR_DESC broadcastedSignDmlTensorDesc = broadcastedSignCosSinTensorDesc.GetDmlDesc();

DML_ELEMENT_WISE_MULTIPLY_OPERATOR_DESC mulSignDesc{};
mulSignDesc.ATensor = &inputDataDmlTensorDesc;
mulSignDesc.ATensor = &joinedDataDmlTensorDesc;
mulSignDesc.BTensor = &broadcastedSignDmlTensorDesc;
mulSignDesc.OutputTensor = &inputDataDmlTensorDesc;
mulSignDesc.OutputTensor = &joinedDataDmlTensorDesc;
const DML_OPERATOR_DESC mulSignDmlDesc = {DML_OPERATOR_ELEMENT_WISE_MULTIPLY, &mulSignDesc};

// Multiply the non-rotated data with the cos and the rotated data with the sin
DML_ELEMENT_WISE_MULTIPLY_OPERATOR_DESC mulCosSinDesc{};
mulCosSinDesc.ATensor = &inputDataDmlTensorDesc;
mulCosSinDesc.ATensor = &joinedDataDmlTensorDesc;
mulCosSinDesc.BTensor = &broadcastedCosSinDmlTensorDesc;
mulCosSinDesc.OutputTensor = &inputDataDmlTensorDesc;
mulCosSinDesc.OutputTensor = &joinedDataDmlTensorDesc;
const DML_OPERATOR_DESC mulCosSinDmlDesc = {DML_OPERATOR_ELEMENT_WISE_MULTIPLY, &mulCosSinDesc};

// Add the multiplied cos and sin values together
DML_ELEMENT_WISE_ADD_OPERATOR_DESC addDesc{};
addDesc.ATensor = &inputOutputDmlTensorDesc;
addDesc.BTensor = &inputOutputDmlTensorDesc;
addDesc.OutputTensor = &inputOutputDmlTensorDesc;
addDesc.OutputTensor = &stridedInputOutputDmlTensorDesc;
const DML_OPERATOR_DESC addDmlDesc = {DML_OPERATOR_ELEMENT_WISE_ADD, &addDesc};

// Construct the graph
Expand Down

0 comments on commit 61e07a4

Please sign in to comment.