From c34c28854c6cf278ac35f7283f1f50978eda6726 Mon Sep 17 00:00:00 2001 From: Anagha Rao Date: Fri, 8 Mar 2024 14:43:20 -0800 Subject: [PATCH] Update QAttention reference diagram --- .../src/Operators/DmlOperatorQAttention.cpp | 75 ++++++++++--------- 1 file changed, 38 insertions(+), 37 deletions(-) diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorQAttention.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorQAttention.cpp index 1f2a7eb6ee05c..f9519b26bb4e3 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorQAttention.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorQAttention.cpp @@ -6,43 +6,44 @@ /* Abbreviations: B is batch_size, S is sequence_length, W is hidden_size N is number of attention heads, H is head size, and W=N*H - M is mask_index tensor - -M, A, B, C and P are Inputs - - M A B C - | | | / - | MatMulIntToFloat - | / | \ - | / | \ - | / | \ - | Slice Slice Slice - | | | | - | | | | - | Identity Identity Identity // The identities are used to transpose NCHW -> NHCW while - | | | | // keeping the GEMM strides as NCHW to better target metacommands - | | | | - | | | | P - | | | | / \ - | | | | / \ - | | | | Slice Slice - | | | | | | - | | | | | | - | | | | | | - --------------------------MHA ----------- - / | \ - / | \ - / | \ - / | \ - / | \ - / | \ - / presentKey presentValue - / \ / - / \ / - / \ / - / Concat - / | - Output1 Output2 (present) + +Input, Weight, Bias, Mask Index and Past are Inputs + +Mask Index/Causal Input Weight Bias + | \ | / + | \ | / + | \ | / + | MatMulIntToFloat + | / | \ + | / | \ + | / | \ + | Slice Slice Slice + | | | | + | | | | + | Identity Identity Identity // The identities are used to transpose NCHW -> NHCW while + | | | | // keeping the GEMM strides as NCHW to better target metacommands + | | | | + | | | | Past + | | | | / \ + | | | | / \ + | | | | Slice Slice + | | | | | | + | | | | | | + | | | | | | + --------------------------MHA ----------- + / | \ + / | \ + / | \ + / | \ + / | \ + / | \ + / presentKey presentValue + / \ / + / \ / + / \ / + / Concat + / | + Output1 Output2 (present) This kernel creates a DML_GRAPH, as mentioned above. For reference, refer to this Doc: