DMLEP QAttention update causal #19533

Merged (6 commits, Mar 4, 2024)
Changes from 5 commits
@@ -361,25 +361,30 @@ class DmlOperatorQAttention : public DmlOperator
 const DML_OPERATOR_DESC pastKeySlicedDesc = { DML_OPERATOR_SLICE1, &pastKeySlicedOperatorDesc};
 const DML_OPERATOR_DESC pastValueSlicedDesc = { DML_OPERATOR_SLICE1, &pastValueSlicedOperatorDesc};

-// Causal Mask: [pastSequenceLength, pastSequenceLength + 1 ... pastSequenceLength + batchSize -1]
+// Causal Mask: Upper Triangular Boolean Matrix
+// Example: [[1, 0, 0, 0, 0],
+//           [1, 1, 0, 0, 0],
+//           [1, 1, 1, 0, 0],
+//           [1, 1, 1, 1, 0]]
+// DML adds maskFilterValue to the "off" bits in the mask and sets the "on" bits to 0
 // passed to MHA as maskIndex Tensor when unidirectional == 1
-std::array<uint32_t, 2> causalMaskOutputShape = {1, batchSize};
+std::array<uint32_t, 4> causalMaskOutputShape = {1, 1, sequenceLength, pastSequenceLength + sequenceLength};
 TensorDesc causalMaskTensorDesc;
-DML_FILL_VALUE_SEQUENCE_OPERATOR_DESC causalMaskOperatorDesc = {};
+DML_DIAGONAL_MATRIX1_OPERATOR_DESC causalMaskOperatorDesc = {};
 DML_TENSOR_DESC namedcausalMaskTensorDesc;

 if (unidirectional && !hasMask)
 {
     causalMaskTensorDesc = TensorDesc::ConstructDefaultTensorDesc(MLOperatorTensorDataType::Int32, causalMaskOutputShape);
     namedcausalMaskTensorDesc = causalMaskTensorDesc.GetDmlDesc();
     causalMaskOperatorDesc.ValueDataType = DML_TENSOR_DATA_TYPE_INT32;
-    causalMaskOperatorDesc.ValueStart.Int32 = pastSequenceLength;
-    causalMaskOperatorDesc.ValueDelta.Int32 = 1;
+    causalMaskOperatorDesc.DiagonalFillBegin = INT32_MIN;
+    causalMaskOperatorDesc.DiagonalFillEnd = pastSequenceLength + 1;
+    causalMaskOperatorDesc.Value.Int32 = 1;
     causalMaskOperatorDesc.OutputTensor = &namedcausalMaskTensorDesc;

-    maskType = DML_MULTIHEAD_ATTENTION_MASK_TYPE_KEY_SEQUENCE_LENGTH;
+    maskType = DML_MULTIHEAD_ATTENTION_MASK_TYPE_BOOLEAN;
 }
-DML_OPERATOR_DESC causalMaskDesc = { DML_OPERATOR_FILL_VALUE_SEQUENCE, &causalMaskOperatorDesc };
+DML_OPERATOR_DESC causalMaskDesc = { DML_OPERATOR_DIAGONAL_MATRIX1, &causalMaskOperatorDesc };

 DML_MULTIHEAD_ATTENTION_OPERATOR_DESC mhaOperatorDesc = {};
 std::array<uint32_t, 5> presentKeyOutputShape = {1, batchSize, numHeads, pastSequenceLength + sequenceLength, headSize};
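[Editorial note] The new mask construction can be emulated on the CPU. The sketch below is not the DML runtime; it assumes DML_DIAGONAL_MATRIX1 writes `Value` on every diagonal k = col - row with DiagonalFillBegin <= k < DiagonalFillEnd and 0 elsewhere, so that DiagonalFillBegin = INT32_MIN and DiagonalFillEnd = pastSequenceLength + 1 let query row i see only key columns j <= i + pastSequenceLength:

```cpp
#include <cassert>
#include <cstdint>
#include <limits>
#include <vector>

// CPU sketch of the causal mask built via DML_DIAGONAL_MATRIX1 above,
// under the assumption that the operator fills diagonals k = col - row
// in [fillBegin, fillEnd) with `value` and writes 0 everywhere else.
std::vector<int32_t> MakeCausalMask(
    uint32_t sequenceLength,
    uint32_t pastSequenceLength,
    int64_t fillBegin,
    int64_t fillEnd,
    int32_t value)
{
    const uint32_t totalLength = pastSequenceLength + sequenceLength;
    std::vector<int32_t> mask(sequenceLength * totalLength, 0);
    for (uint32_t row = 0; row < sequenceLength; ++row)
    {
        for (uint32_t col = 0; col < totalLength; ++col)
        {
            const int64_t diagonal = static_cast<int64_t>(col) - static_cast<int64_t>(row);
            if (diagonal >= fillBegin && diagonal < fillEnd)
            {
                mask[row * totalLength + col] = value; // "on" bit: position is attendable
            }
        }
    }
    return mask;
}
```

With sequenceLength = 3 and pastSequenceLength = 1 this yields rows [1,1,0,0], [1,1,1,0], [1,1,1,1]: each new query attends to all past keys plus its own prefix, matching the staircase pattern in the diff comment.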
@@ -393,7 +398,11 @@ class DmlOperatorQAttention : public DmlOperator

 if (unidirectional && !hasMask)
 {
-    mhaOperatorDesc.MaskTensor = &namedcausalMaskTensorDesc;
+    // Broadcast to MHA MaskTensor Shape
+    std::array<uint32_t, 4> mhaMaskTensorShape = {batchSize, numHeads, sequenceLength, pastSequenceLength + sequenceLength};
+    TensorDesc broadcastedcausalMaskTensorDesc = TensorDesc::ConstructBroadcastedTensorDesc(MLOperatorTensorDataType::Int32, mhaMaskTensorShape, causalMaskOutputShape);
+    const DML_TENSOR_DESC namedbroadcastedcausalMaskTensorDesc = broadcastedcausalMaskTensorDesc.GetDmlDesc();
+    mhaOperatorDesc.MaskTensor = &namedbroadcastedcausalMaskTensorDesc;
 }
 else if (hasMaxSequenceMask)
 {
@@ -407,7 +416,7 @@ class DmlOperatorQAttention : public DmlOperator
 mhaOperatorDesc.RelativePositionBiasTensor = nullptr;
 mhaOperatorDesc.OutputTensor = &outputDescs[outputIndex];
 mhaOperatorDesc.Scale = kernelCreationContext.GetOptionalAttribute<float>(AttrName::Scale, gsl::narrow_cast<float>(1.0f / std::sqrt(headSize)));
-mhaOperatorDesc.MaskFilterValue = kernelCreationContext.GetOptionalAttribute<float>(AttrName::MaskFilterValue, -10'000.0f);
+mhaOperatorDesc.MaskFilterValue = std::numeric_limits<float>::lowest();
Contributor commented:
I think we should still give precedence to the user-provided value, and we should just change the default value to lowest():
mhaOperatorDesc.MaskFilterValue = kernelCreationContext.GetOptionalAttribute<float>(AttrName::MaskFilterValue, std::numeric_limits<float>::lowest());

Contributor commented:
Did the CPU and CUDA EPs change to lowest() for the default? The contrib ops page still mentions -10000 as the default: https://github.com/microsoft/onnxruntime/blob/main/docs/ContribOperators.md#com.microsoft.QAttention

raoanag (Contributor, Author) commented Feb 29, 2024:
For unidirectional, that does not seem to be the case, per the reference: https://github.com/microsoft/onnxruntime/blob/d5606cd7ee394ba9444ef509021720ebe63c9856/onnxruntime/contrib_ops/cpu/bert/attention_helper.h#L142C1-L149C6

I will add std::numeric_limits<float>::lowest() only when unidirectional is set.
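[Editorial note] The resolution discussed in this thread can be sketched as a small helper. This is hypothetical plumbing, not the actual kernel API: a user-supplied MaskFilterValue attribute wins, and only the unidirectional (causal) path falls back to lowest(), keeping the documented -10000 default otherwise:

```cpp
#include <cassert>
#include <limits>
#include <optional>

// Hypothetical helper (illustrative name and optional-based attribute
// plumbing): give precedence to a user-provided MaskFilterValue, and
// choose the fallback default based on the unidirectional flag.
float ResolveMaskFilterValue(std::optional<float> userValue, bool unidirectional)
{
    if (userValue.has_value())
    {
        return *userValue; // explicit attribute always wins
    }
    // Causal masking wants true -inf-like filtering so masked logits
    // cannot leak through softmax; other paths keep the documented default.
    return unidirectional ? std::numeric_limits<float>::lowest() : -10'000.0f;
}
```

The design point from the thread: lowest() matters for the causal path because -10000 can still contribute non-negligible probability mass after softmax at large sequence lengths, while non-causal paths keep backward-compatible behavior.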

 mhaOperatorDesc.HeadCount = numHeads;
 mhaOperatorDesc.MaskType = maskType;
 if (hasPast)