Skip to content

Commit

Permalink
Add comments
Browse files Browse the repository at this point in the history
  • Loading branch information
raoanag committed Feb 28, 2024
1 parent 471081e commit 6a904a7
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -361,7 +361,8 @@ class DmlOperatorQAttention : public DmlOperator
const DML_OPERATOR_DESC pastKeySlicedDesc = { DML_OPERATOR_SLICE1, &pastKeySlicedOperatorDesc};
const DML_OPERATOR_DESC pastValueSlicedDesc = { DML_OPERATOR_SLICE1, &pastValueSlicedOperatorDesc};

// Causal Mask: [pastSequenceLength, pastSequenceLength + 1 ... pastSequenceLength + batchSize -1]
// Causal Mask: Upper Triangular Boolean Matrix
// DML adds maskFilterValue to the "off" bits in the mask and sets the "on" bits to 0
// passed to MHA as maskIndex Tensor when unidirectional == 1
std::array<uint32_t, 4> causalMaskOutputShape = {batchSize, numHeads, sequenceLength, pastSequenceLength + sequenceLength};
TensorDesc causalMaskTensorDesc;
Expand Down
1 change: 1 addition & 0 deletions onnxruntime/test/contrib_ops/quantize_attention_op_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -932,6 +932,7 @@ void TestQuantizedAttentionPastState(int64_t batch,

std::vector<int64_t> past_dims{2, batch, head_number, past_seq_len, head_size};
std::vector<float> past_data = random.Gaussian<float>(past_dims, 0.0f, 0.3f);

OpTester test("QAttention", 1, onnxruntime::kMSDomain);
test.AddAttribute<int64_t>("num_heads", head_number);
test.AddAttribute<int64_t>("unidirectional", 1);
Expand Down

0 comments on commit 6a904a7

Please sign in to comment.