Update QAttention reference diagram
raoanag committed Mar 9, 2024
1 parent 575fb02 commit 81572aa
Showing 1 changed file with 38 additions and 37 deletions.
@@ -6,43 +6,44 @@
/*
Abbreviations: B is batch_size, S is sequence_length, W is hidden_size
N is number of attention heads, H is head size, and W=N*H
-M is mask_index tensor
-M, A, B, C and P are Inputs
-        M              A   B    C
-        |              |   |   /
-        |          MatMulIntToFloat
-        |               /  |  \
-        |             /    |    \
-        |           /      |      \
-        |       Slice    Slice    Slice
-        |         |        |        |
-        |         |        |        |
-        |      Identity Identity Identity // The identities are used to transpose NCHW -> NHCW while
-        |         |        |        |     // keeping the GEMM strides as NCHW to better target metacommands
-        |         |        |        |
-        |         |        |        |      P
-        |         |        |        |     / \
-        |         |        |        |    /   \
-        |         |        |        | Slice Slice
-        |         |        |        |   |     |
-        |         |        |        |   |     |
-        |         |        |        |   |     |
-      --------------------------MHA -----------
-                               / | \
-                              /  |  \
-                             /   |   \
-                            /    |    \
-                           /     |     \
-                          /      |      \
-                         /  presentKey presentValue
-                        /          \       /
-                       /            \     /
-                      /              \   /
-                     /              Concat
-                    /                  |
-             Output1                Output2 (present)
+Input, Weight, Bias, Mask Index and Past are Inputs
+Mask Index/Causal Input Weight Bias
+        |           \      |      /
+        |            \     |     /
+        |             \    |    /
+        |          MatMulIntToFloat
+        |               /  |  \
+        |             /    |    \
+        |           /      |      \
+        |       Slice    Slice    Slice
+        |         |        |        |
+        |         |        |        |
+        |      Identity Identity Identity // The identities are used to transpose NCHW -> NHCW while
+        |         |        |        |     // keeping the GEMM strides as NCHW to better target metacommands
+        |         |        |        |
+        |         |        |        |    Past
+        |         |        |        |     / \
+        |         |        |        |    /   \
+        |         |        |        | Slice Slice
+        |         |        |        |   |     |
+        |         |        |        |   |     |
+        |         |        |        |   |     |
+      --------------------------MHA -----------
+                               / | \
+                              /  |  \
+                             /   |   \
+                            /    |    \
+                           /     |     \
+                          /      |      \
+                         /  presentKey presentValue
+                        /          \       /
+                       /            \     /
+                      /              \   /
+                     /              Concat
+                    /                  |
+             Output1                Output2 (present)
This kernel creates a DML_GRAPH, as mentioned above.
For reference, refer to this Doc:
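
To make the dataflow in the updated diagram easier to follow, here is a minimal CPU sketch of the same graph. It is not the DirectML kernel. The function names (`MatMulIntToFloat`, `AttentionWithPast`), the struct `QAttentionOutputs`, the tensor layouts (weight as [W, 3W], past as [2, N, P, H]), the quantization parameters and the 1/sqrt(H) softmax scaling are all assumptions chosen for illustration, following the usual definition of quantized attention rather than anything stated in this commit. The batch size is fixed at 1 and the Mask Index input is left out. The Identity nodes only change memory layout, so they have no counterpart here.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <limits>
#include <vector>

// Hypothetical reference sketch, batch size 1. Not the DML graph itself.
// MatMulIntToFloat: dequantized input [S, W] x weight [W, 3W], plus bias [3W].
std::vector<float> MatMulIntToFloat(
    const std::vector<uint8_t>& input, const std::vector<int8_t>& weight,
    const std::vector<float>& bias, float inputScale, float weightScale,
    int32_t inputZeroPoint, int32_t weightZeroPoint, size_t S, size_t W)
{
    assert(input.size() == S * W && weight.size() == W * 3 * W && bias.size() == 3 * W);
    std::vector<float> qkv(S * 3 * W);
    for (size_t s = 0; s < S; ++s)
        for (size_t j = 0; j < 3 * W; ++j)
        {
            int32_t acc = 0;
            for (size_t k = 0; k < W; ++k)
                acc += (int32_t(input[s * W + k]) - inputZeroPoint) *
                       (int32_t(weight[k * 3 * W + j]) - weightZeroPoint);
            qkv[s * 3 * W + j] = float(acc) * inputScale * weightScale + bias[j];
        }
    return qkv;
}

struct QAttentionOutputs
{
    std::vector<float> output;   // Output1: [S, W]
    std::vector<float> present;  // Output2: [2, N, P + S, H]
};

// Slices Q, K, V out of qkv [S, 3W], appends K and V to past [2, N, P, H]
// (assumed layout), then runs scaled dot-product attention per head.
QAttentionOutputs AttentionWithPast(
    const std::vector<float>& qkv, const std::vector<float>& past,
    size_t S, size_t P, size_t N, size_t H, bool causal)
{
    const size_t W = N * H, T = P + S;
    assert(qkv.size() == S * 3 * W && past.size() == 2 * N * P * H);
    QAttentionOutputs result{std::vector<float>(S * W), std::vector<float>(2 * N * T * H)};
    float* presentKey = result.present.data();
    float* presentValue = presentKey + N * T * H;

    // Slice(Past) and Slice(qkv) feed presentKey / presentValue; writing both
    // halves into one buffer plays the role of the final Concat.
    for (size_t n = 0; n < N; ++n)
        for (size_t t = 0; t < T; ++t)
            for (size_t h = 0; h < H; ++h)
            {
                const size_t dst = (n * T + t) * H + h;
                if (t < P)
                {
                    presentKey[dst] = past[(n * P + t) * H + h];
                    presentValue[dst] = past[N * P * H + (n * P + t) * H + h];
                }
                else
                {
                    const size_t row = (t - P) * 3 * W + n * H + h;
                    presentKey[dst] = qkv[row + W];        // K slice
                    presentValue[dst] = qkv[row + 2 * W];  // V slice
                }
            }

    // MHA: softmax(Q * K^T / sqrt(H)) * V for each head and query position.
    const float scale = 1.0f / std::sqrt(float(H));
    std::vector<float> scores(T);
    for (size_t n = 0; n < N; ++n)
        for (size_t s = 0; s < S; ++s)
        {
            const size_t limit = causal ? P + s + 1 : T;  // causal: no future keys
            float maxScore = std::numeric_limits<float>::lowest();
            for (size_t t = 0; t < limit; ++t)
            {
                float dot = 0.0f;
                for (size_t h = 0; h < H; ++h)
                    dot += qkv[s * 3 * W + n * H + h] * presentKey[(n * T + t) * H + h];
                scores[t] = dot * scale;
                maxScore = std::max(maxScore, scores[t]);
            }
            float sum = 0.0f;
            for (size_t t = 0; t < limit; ++t)
            {
                scores[t] = std::exp(scores[t] - maxScore);
                sum += scores[t];
            }
            for (size_t h = 0; h < H; ++h)
            {
                float acc = 0.0f;
                for (size_t t = 0; t < limit; ++t)
                    acc += scores[t] * presentValue[(n * T + t) * H + h];
                result.output[s * W + n * H + h] = acc / sum;
            }
        }
    return result;
}
```

Calling `MatMulIntToFloat` and passing its result to `AttentionWithPast` mirrors the diagram from top to bottom: `output` corresponds to Output1 and `present` to Output2. Keeping presentKey and presentValue as two halves of one buffer is a simplification of the Concat node, not a statement about how the DML graph allocates memory.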