Skip to content

Commit

Permalink
[DML EP] Enable more MHA masks (#17882)
Browse files Browse the repository at this point in the history
Those masks are used for MHA in LLaMA.
  • Loading branch information
PatriceVignola authored Oct 18, 2023
1 parent 2ef7abf commit 6557538
Showing 1 changed file with 26 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -205,12 +205,34 @@ class DmlOperatorMultiHeadAttention : public DmlOperator
else
{
const auto keyPaddingMaskTensorShape = m_inputTensorDescs[dmlMaskIndex].GetSizes();
ML_CHECK_VALID_ARGUMENT(keyPaddingMaskTensorShape.size() == 2);
size_t maskDimCount = keyPaddingMaskTensorShape.size();
ML_CHECK_VALID_ARGUMENT(maskDimCount >= 2 || maskDimCount <= 4);
ML_CHECK_VALID_ARGUMENT(keyPaddingMaskTensorShape[0] == batchSize);
ML_CHECK_VALID_ARGUMENT(keyPaddingMaskTensorShape[1] == kvSequenceLength);

const uint32_t actualShape[4] = {batchSize, 1, 1, kvSequenceLength};
const uint32_t desiredShape[4] = {batchSize, numHeads, sequenceLength, kvSequenceLength};
std::array<uint32_t, 4> actualShape{};
std::array<uint32_t, 4> desiredShape{};

if (maskDimCount == 2)
{
ML_CHECK_VALID_ARGUMENT(keyPaddingMaskTensorShape[1] == kvSequenceLength);
actualShape = {batchSize, 1, 1, kvSequenceLength};
desiredShape = {batchSize, numHeads, sequenceLength, kvSequenceLength};
}
else if (maskDimCount == 3)
{
ML_CHECK_VALID_ARGUMENT(keyPaddingMaskTensorShape[1] == sequenceLength);
ML_CHECK_VALID_ARGUMENT(keyPaddingMaskTensorShape[2] == totalSequenceLength);
actualShape = {batchSize, 1, sequenceLength, totalSequenceLength};
desiredShape = {batchSize, numHeads, sequenceLength, totalSequenceLength};
}
else if (maskDimCount == 4)
{
ML_CHECK_VALID_ARGUMENT(keyPaddingMaskTensorShape[1] == numHeads);
ML_CHECK_VALID_ARGUMENT(keyPaddingMaskTensorShape[2] == sequenceLength);
ML_CHECK_VALID_ARGUMENT(keyPaddingMaskTensorShape[3] == totalSequenceLength);
actualShape = {batchSize, numHeads, sequenceLength, totalSequenceLength};
desiredShape = {batchSize, numHeads, sequenceLength, totalSequenceLength};
}

m_inputTensorDescs[dmlMaskIndex] = TensorDesc::ConstructBroadcastedTensorDesc(
m_inputTensorDescs[dmlMaskIndex].GetMlOperatorDataType(),
Expand Down

0 comments on commit 6557538

Please sign in to comment.