Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[DML EP] Support split hidden size for RotaryEmbedding #18852

Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,10 @@ class DmlOperatorRotaryEmbedding : public DmlOperator
ML_CHECK_VALID_ARGUMENT(kernelInfo.GetInputCount() == 4);
ML_CHECK_VALID_ARGUMENT(kernelInfo.GetOutputCount() == 1);

// When the input is 4D, it has the shape [batchSize, numHeads, sequenceLength, headSize]. Otherwise,
// it has the shape [batchSize, sequenceLength, hiddenSize]
const bool inputIs4D = kernelInfo.GetInputTensorDimensionCount(inputDataIndex) == 4;

// When positionIds has a single dimension (i.e. it is effectively a scalar), it represents the start offset for each sequence
const bool positionIdsIsOffset = kernelInfo.GetInputTensorDimensionCount(positionIdsIndex) == 1;

Expand All @@ -63,9 +67,9 @@ class DmlOperatorRotaryEmbedding : public DmlOperator

// We resize the data to be of shape [batchSize, sequenceLength, numHeads, headSize]
const auto inputDataSizes = m_inputTensorDescs[inputDataIndex].GetSizes();
const uint32_t batchSize = inputDataSizes[1];
const uint32_t batchSize = inputIs4D ? inputDataSizes[0] : inputDataSizes[1];
const uint32_t sequenceLength = inputDataSizes[2];
const uint32_t numHeads = inputDataSizes[3] / headSize;
const uint32_t numHeads = inputIs4D ? inputDataSizes[1] : inputDataSizes[3] / headSize;

const auto cosCacheSizes = m_inputTensorDescs[cosCacheIndex].GetSizes();
const uint32_t maxSequenceLength = cosCacheSizes[cosCacheSizes.size() - 2];
Expand All @@ -80,16 +84,24 @@ class DmlOperatorRotaryEmbedding : public DmlOperator
std::vector<DML_TENSOR_DESC> inputDescs = GetDmlInputDescs();
const MLOperatorTensorDataType dataType = kernelInfo.GetInputEdgeDescription(inputDataIndex).tensorDataType;

// Splitting the hiddenSize into numHeads and headSize dimensions makes it easier for DML to handle
const std::array<uint32_t, 4> inputOutputShape = {batchSize, sequenceLength, numHeads, headSize};
TensorDesc inputOutputTensorDesc = TensorDesc::ConstructDefaultTensorDesc(dataType, inputOutputShape);
TensorDesc stridedInputOutputTensorDesc = TensorDesc::ConstructDefaultTensorDesc(dataType, inputOutputShape);

if (inputIs4D)
{
const std::array<uint32_t, 4> inputOutputStrides = {headSize * numHeads * sequenceLength, headSize, sequenceLength * headSize, 1};
stridedInputOutputTensorDesc.SetStrides(inputOutputStrides);
}

const DML_TENSOR_DESC inputOutputDmlTensorDesc = inputOutputTensorDesc.GetDmlDesc();
const DML_TENSOR_DESC stridedInputOutputDmlTensorDesc = stridedInputOutputTensorDesc.GetDmlDesc();

// Copy the input so that its real shape is preserved in the graph without reshaping it. This identity copy will disappear during DML's graph compilation phase.
DML_SCALE_BIAS scaleBias = {1.0f, 0.0f};

DML_ELEMENT_WISE_IDENTITY_OPERATOR_DESC copyInputDesc{};
copyInputDesc.InputTensor = &inputOutputDmlTensorDesc;
copyInputDesc.InputTensor = &stridedInputOutputDmlTensorDesc;
copyInputDesc.OutputTensor = &inputOutputDmlTensorDesc;
copyInputDesc.ScaleBias = &scaleBias;
const DML_OPERATOR_DESC copyInputDmlDesc = {DML_OPERATOR_ELEMENT_WISE_IDENTITY, &copyInputDesc};
Expand All @@ -104,8 +116,12 @@ class DmlOperatorRotaryEmbedding : public DmlOperator
: std::vector<uint32_t>({batchSize, sequenceLength, numHeads, 1, headSize / 2});

TensorDesc inputDataTensorDesc = TensorDesc::ConstructDefaultTensorDesc(dataType, inputDataTensorShape);

const DML_TENSOR_DESC inputDataDmlTensorDesc = inputDataTensorDesc.GetDmlDesc();

TensorDesc joinedDataTensorDesc = TensorDesc::ConstructDefaultTensorDesc(dataType, inputDataTensorShape);
const DML_TENSOR_DESC joinedDataDmlTensorDesc = joinedDataTensorDesc.GetDmlDesc();

TensorDesc splitInputDataTensorDesc = TensorDesc::ConstructDefaultTensorDesc(dataType, splitInputDataTensorShape);
const std::array<DML_TENSOR_DESC, 2> splitInputDataDmlTensorDescs = {splitInputDataTensorDesc.GetDmlDesc(), splitInputDataTensorDesc.GetDmlDesc()};

Expand All @@ -122,7 +138,7 @@ class DmlOperatorRotaryEmbedding : public DmlOperator
// Swap the 2 halves and join them together
DML_JOIN_OPERATOR_DESC joinInputDesc{};
joinInputDesc.InputTensors = splitInputDataDmlTensorDescs.data();
joinInputDesc.OutputTensor = &inputDataDmlTensorDesc;
joinInputDesc.OutputTensor = &joinedDataDmlTensorDesc;
joinInputDesc.Axis = splitInputDesc.Axis;
joinInputDesc.InputCount = gsl::narrow_cast<uint32_t>(splitInputDataDmlTensorDescs.size());
const DML_OPERATOR_DESC joinInputDmlDesc = {DML_OPERATOR_JOIN, &joinInputDesc};
Expand Down Expand Up @@ -212,23 +228,23 @@ class DmlOperatorRotaryEmbedding : public DmlOperator
const DML_TENSOR_DESC broadcastedSignDmlTensorDesc = broadcastedSignCosSinTensorDesc.GetDmlDesc();

DML_ELEMENT_WISE_MULTIPLY_OPERATOR_DESC mulSignDesc{};
mulSignDesc.ATensor = &inputDataDmlTensorDesc;
mulSignDesc.ATensor = &joinedDataDmlTensorDesc;
mulSignDesc.BTensor = &broadcastedSignDmlTensorDesc;
mulSignDesc.OutputTensor = &inputDataDmlTensorDesc;
mulSignDesc.OutputTensor = &joinedDataDmlTensorDesc;
const DML_OPERATOR_DESC mulSignDmlDesc = {DML_OPERATOR_ELEMENT_WISE_MULTIPLY, &mulSignDesc};

// Multiply the non-rotated data with the cos and the rotated data with the sin
DML_ELEMENT_WISE_MULTIPLY_OPERATOR_DESC mulCosSinDesc{};
mulCosSinDesc.ATensor = &inputDataDmlTensorDesc;
mulCosSinDesc.ATensor = &joinedDataDmlTensorDesc;
mulCosSinDesc.BTensor = &broadcastedCosSinDmlTensorDesc;
mulCosSinDesc.OutputTensor = &inputDataDmlTensorDesc;
mulCosSinDesc.OutputTensor = &joinedDataDmlTensorDesc;
const DML_OPERATOR_DESC mulCosSinDmlDesc = {DML_OPERATOR_ELEMENT_WISE_MULTIPLY, &mulCosSinDesc};

// Add the multiplied cos and sin values together
DML_ELEMENT_WISE_ADD_OPERATOR_DESC addDesc{};
addDesc.ATensor = &inputOutputDmlTensorDesc;
addDesc.BTensor = &inputOutputDmlTensorDesc;
addDesc.OutputTensor = &inputOutputDmlTensorDesc;
addDesc.OutputTensor = &stridedInputOutputDmlTensorDesc;
const DML_OPERATOR_DESC addDmlDesc = {DML_OPERATOR_ELEMENT_WISE_ADD, &addDesc};

// Construct the graph
Expand Down
Loading