Skip to content

Commit

Permalink
Fix striding
Browse files: browse the repository at this point in the history
  • Loading branch information
PatriceVignola committed Jan 18, 2024
1 parent 28af091 commit f7db464
Showing 1 changed file with 12 additions and 12 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -84,16 +84,24 @@ class DmlOperatorRotaryEmbedding : public DmlOperator
std::vector<DML_TENSOR_DESC> inputDescs = GetDmlInputDescs();
const MLOperatorTensorDataType dataType = kernelInfo.GetInputEdgeDescription(inputDataIndex).tensorDataType;

// We can collapse this size into a 1D tensor since it's only used as the input/output of elementwise operations
const std::array<uint32_t, 1> inputOutputShape = {batchSize * sequenceLength * numHeads * headSize};
const std::array<uint32_t, 4> inputOutputShape = {batchSize, sequenceLength, numHeads, headSize};
TensorDesc inputOutputTensorDesc = TensorDesc::ConstructDefaultTensorDesc(dataType, inputOutputShape);
TensorDesc stridedInputOutputTensorDesc = TensorDesc::ConstructDefaultTensorDesc(dataType, inputOutputShape);

if (inputIs4D)
{
const std::array<uint32_t, 4> inputOutputStrides = {headSize * numHeads * sequenceLength, headSize, sequenceLength * headSize, 1};
stridedInputOutputTensorDesc.SetStrides(inputOutputStrides);
}

const DML_TENSOR_DESC inputOutputDmlTensorDesc = inputOutputTensorDesc.GetDmlDesc();
const DML_TENSOR_DESC stridedInputOutputDmlTensorDesc = stridedInputOutputTensorDesc.GetDmlDesc();

// Copy the input through an identity operator so that its real shape is preserved in the graph without reshaping it. This copy will disappear during DML's graph compilation phase.
DML_SCALE_BIAS scaleBias = {1.0f, 0.0f};

DML_ELEMENT_WISE_IDENTITY_OPERATOR_DESC copyInputDesc{};
copyInputDesc.InputTensor = &inputOutputDmlTensorDesc;
copyInputDesc.InputTensor = &stridedInputOutputDmlTensorDesc;
copyInputDesc.OutputTensor = &inputOutputDmlTensorDesc;
copyInputDesc.ScaleBias = &scaleBias;
const DML_OPERATOR_DESC copyInputDmlDesc = {DML_OPERATOR_ELEMENT_WISE_IDENTITY, &copyInputDesc};
Expand All @@ -109,14 +117,6 @@ class DmlOperatorRotaryEmbedding : public DmlOperator

TensorDesc inputDataTensorDesc = TensorDesc::ConstructDefaultTensorDesc(dataType, inputDataTensorShape);

// We need to stride it if the input was 4D
if (inputIs4D)
{
const std::vector<uint32_t> inputDataStrides = interleaved
? std::vector<uint32_t>({sequenceLength * numHeads * headSize, headSize, sequenceLength * headSize, 2, 1})
: std::vector<uint32_t>({sequenceLength * numHeads * headSize, headSize, sequenceLength * headSize, headSize / 2, 1});
}

const DML_TENSOR_DESC inputDataDmlTensorDesc = inputDataTensorDesc.GetDmlDesc();

TensorDesc joinedDataTensorDesc = TensorDesc::ConstructDefaultTensorDesc(dataType, inputDataTensorShape);
Expand Down Expand Up @@ -244,7 +244,7 @@ class DmlOperatorRotaryEmbedding : public DmlOperator
DML_ELEMENT_WISE_ADD_OPERATOR_DESC addDesc{};
addDesc.ATensor = &inputOutputDmlTensorDesc;
addDesc.BTensor = &inputOutputDmlTensorDesc;
addDesc.OutputTensor = &inputOutputDmlTensorDesc;
addDesc.OutputTensor = &stridedInputOutputDmlTensorDesc;
const DML_OPERATOR_DESC addDmlDesc = {DML_OPERATOR_ELEMENT_WISE_ADD, &addDesc};

// Construct the graph
Expand Down

0 comments on commit f7db464

Please sign in to comment.