Skip to content

Commit

Permalink
Fix MHA when past/key is empty
Browse files Browse the repository at this point in the history
  • Loading branch information
PatriceVignola committed Jan 27, 2024
1 parent e9c073e commit d3e0ecc
Showing 1 changed file with 4 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -56,8 +56,8 @@ class DmlOperatorMultiHeadAttention : public DmlOperator
const bool hasBias = kernelCreationContext.IsInputValid(biasIndex);
const bool hasMask = kernelCreationContext.IsInputValid(maskIndex);
const bool hasRelativePositionBias = kernelCreationContext.IsInputValid(relativePositionBiasIndex);
const bool hasPastKey = keyValueIsPast || kernelCreationContext.IsInputValid(pastKeyIndex);
const bool hasPastValue = keyValueIsPast || kernelCreationContext.IsInputValid(pastValueIndex);
const bool hasPastKey = keyValueIsPast || (kernelCreationContext.IsInputValid(pastKeyIndex) && kernelCreationContext.GetInputTensorShape(pastKeyIndex)[2] != 0);
const bool hasPastValue = keyValueIsPast || (kernelCreationContext.IsInputValid(pastValueIndex) && kernelCreationContext.GetInputTensorShape(pastValueIndex)[2] != 0);
const bool hasPresentKeyOutput = kernelCreationContext.IsOutputValid(outputPresentKeyIndex);
const bool hasPresentValueOutput = kernelCreationContext.IsOutputValid(outputPresentValueIndex);
const bool stackedQkv = kernelCreationContext.GetInputTensorDimensionCount(queryIndex) == 5;
Expand All @@ -74,8 +74,8 @@ class DmlOperatorMultiHeadAttention : public DmlOperator
biasIndex,
hasMask ? std::optional<uint32_t>(maskIndex) : std::nullopt,
relativePositionBiasIndex,
keyValueIsPast ? keyIndex : pastKeyIndex,
keyValueIsPast ? valueIndex : pastValueIndex,
hasPastKey ? std::optional<uint32_t>(keyValueIsPast ? keyIndex : pastKeyIndex) : std::nullopt,
hasPastValue ? std::optional<uint32_t>(keyValueIsPast ? valueIndex : pastValueIndex) : std::nullopt,
};

std::vector<std::optional<uint32_t>> outputIndices = {
Expand Down

0 comments on commit d3e0ecc

Please sign in to comment.