From de93f40240459953a6e3bbb86b6ad83eaeab681f Mon Sep 17 00:00:00 2001 From: Tianlei Wu Date: Mon, 14 Oct 2024 14:49:37 -0700 Subject: [PATCH] [CUDA] Lean Attention (#22352) ### Description Add [Lean Attention](https://arxiv.org/abs/2405.10480) and the integration with MultiHeadAttention operator for LLM in GPU. LeanAttention speeds up self-attention for the token-generation phase (decode-phase) of decoder-only transformer models, especially on long context lengths. - [x] Initial implementation of Lean Attention (by Srikant Bharadwaj) - [x] Integration with MultiHeadAttention operator - [x] Add parity tests - [x] Add benchmark #### Implementation Details (1) Lean Attention is enabled in build for Linux, and disabled for Windows (2) Lean Attention is disabled by default. Need enable it through cuda provider option sdpa_kernel, or use environment variable `ORT_ENABLE_LEAN_ATTENTION=1` (3) It only works for token-generation (sequence_length==1, past_sequence_length > 0). (4) Like flash attention, it only works in Ampere or newer GPU. We can revisit #1 and #2 after comparing with DecoderMaskedMultiHeadAttention and XQA kernels. #### Benchmark ``` cd onnxruntime/test/python/transformers /bin/bash benchmark_mha.sh lean ``` Example outputs in H100: Note that past and present does not share buffer for MHA for now, so we can see low tflops. The relative ratio will change after buffer sharing is enabled. But we expect that the order (kernel A is faster than B) will remain the same after buffer sharing is enabled. Note that common settings `sequence_length=1; causal=True;attn_bias=None;cuda_graph=False` are not shown in the below table. batch_size | past_sequence_length | num_heads | head_size | average_latency | tflops | kernel -- | -- | -- | -- | -- | -- | -- 1 | 512 | 16 | 64 | 0.000059 | 0.0178 | ort:flash 1 | 512 | 16 | 64 | 0.000068 | 0.0155 | ort:efficient 1 | 512 | 16 | 64 | 0.000065 | 0.0161 | ort:math 1 | 512 | 16 | 64 | 0.000060 | 0.0176 | ort:lean 1 | 512 | 32 | 128 | 0.000062 | 0.0674 | ort:flash 1 | 512 | 32 | 128 | 0.000064 | 0.0661 | ort:efficient 1 | 512 | 32 | 128 | 0.000067 | 0.0625 | ort:math 1 | 512 | 32 | 128 | 0.000062 | 0.0678 | ort:lean 1 | 1024 | 16 | 64 | 0.000061 | 0.0345 | ort:flash 1 | 1024 | 16 | 64 | 0.000086 | 0.0244 | ort:efficient 1 | 1024 | 16 | 64 | 0.000065 | 0.0322 | ort:math 1 | 1024 | 16 | 64 | 0.000063 | 0.0332 | ort:lean 1 | 1024 | 32 | 128 | 0.000075 | 0.1125 | ort:flash 1 | 1024 | 32 | 128 | 0.000088 | 0.0951 | ort:efficient 1 | 1024 | 32 | 128 | 0.000079 | 0.1068 | ort:math 1 | 1024 | 32 | 128 | 0.000072 | 0.1171 | ort:lean 1 | 2048 | 16 | 64 | 0.000069 | 0.0606 | ort:flash 1 | 2048 | 16 | 64 | 0.000125 | 0.0336 | ort:efficient 1 | 2048 | 16 | 64 | 0.000064 | 0.0655 | ort:lean 1 | 2048 | 32 | 128 | 0.000098 | 0.1720 | ort:flash 1 | 2048 | 32 | 128 | 0.000132 | 0.1270 | ort:efficient 1 | 2048 | 32 | 128 | 0.000092 | 0.1828 | ort:lean 1 | 4096 | 16 | 64 | 0.000076 | 0.1097 | ort:flash 1 | 4096 | 16 | 64 | 0.000207 | 0.0406 | ort:efficient 1 | 4096 | 16 | 64 | 0.000069 | 0.1209 | ort:lean 1 | 4096 | 32 | 128 | 0.000140 | 0.2394 | ort:flash 1 | 4096 | 32 | 128 | 0.000213 | 0.1575 | ort:efficient 1 | 4096 | 32 | 128 | 0.000139 | 0.2419 | ort:lean 1 | 8192 | 16 | 64 | 0.000104 | 0.1609 | ort:flash 1 | 8192 | 16 | 64 | 0.000392 | 0.0428 | ort:efficient 1 | 8192 | 16 | 64 | 0.000093 | 0.1809 | ort:lean 1 | 8192 | 32 | 128 | 0.000212 | 0.3160 | ort:flash 1 | 8192 | 32 | 128 | 0.000360 | 0.1866 | ort:efficient 1 | 8192 | 32 | 128 | 0.000212 | 0.3162 | ort:lean 1 | 16384 | 16 | 64 | 0.000139 | 0.2410 | ort:flash 1 | 16384 | 16 | 64 | 0.000731 | 0.0459 | ort:efficient 1 | 16384 | 16 | 64 | 0.000136 | 0.2465 | ort:lean 1 | 16384 | 32 | 128 | 0.000361 | 0.3722 | ort:flash 1 | 16384 | 32 | 128 | 0.000667 | 0.2014 | ort:efficient 1 | 16384 | 32 | 128 | 0.000357 | 0.3765 | ort:lean 1 | 32768 | 16 | 64 | 0.000210 | 0.3194 | ort:flash 1 | 32768 | 16 | 64 | 0.001428 | 0.0470 | ort:efficient 1 | 32768 | 16 | 64 | 0.000209 | 0.3211 | ort:lean 1 | 32768 | 32 | 128 | 0.000659 | 0.4074 | ort:flash 1 | 32768 | 32 | 128 | 0.001270 | 0.2114 | ort:efficient 1 | 32768 | 32 | 128 | 0.000651 | 0.4123 | ort:lean 1 | 65536 | 16 | 64 | 0.000355 | 0.3785 | ort:flash 1 | 65536 | 16 | 64 | 0.002736 | 0.0491 | ort:efficient 1 | 65536 | 16 | 64 | 0.000349 | 0.3845 | ort:lean 1 | 65536 | 32 | 128 | 0.001251 | 0.4290 | ort:flash 1 | 65536 | 32 | 128 | 0.002480 | 0.2165 | ort:efficient 1 | 65536 | 32 | 128 | 0.001239 | 0.4333 | ort:lean 4 | 512 | 16 | 64 | 0.000063 | 0.0665 | ort:flash 4 | 512 | 16 | 64 | 0.000069 | 0.0607 | ort:efficient 4 | 512 | 16 | 64 | 0.000066 | 0.0634 | ort:math 4 | 512 | 16 | 64 | 0.000062 | 0.0674 | ort:lean 4 | 512 | 32 | 128 | 0.000100 | 0.1677 | ort:flash 4 | 512 | 32 | 128 | 0.000099 | 0.1703 | ort:efficient 4 | 512 | 32 | 128 | 0.000108 | 0.1557 | ort:math 4 | 512 | 32 | 128 | 0.000092 | 0.1818 | ort:lean 4 | 1024 | 16 | 64 | 0.000077 | 0.1094 | ort:flash 4 | 1024 | 16 | 64 | 0.000099 | 0.0850 | ort:efficient 4 | 1024 | 16 | 64 | 0.000081 | 0.1038 | ort:math 4 | 1024 | 16 | 64 | 0.000072 | 0.1161 | ort:lean 4 | 1024 | 32 | 128 | 0.000143 | 0.2343 | ort:flash 4 | 1024 | 32 | 128 | 0.000137 | 0.2447 | ort:efficient 4 | 1024 | 32 | 128 | 0.000150 | 0.2245 | ort:math 4 | 1024 | 32 | 128 | 0.000135 | 0.2496 | ort:lean 4 | 2048 | 16 | 64 | 0.000096 | 0.1757 | ort:flash 4 | 2048 | 16 | 64 | 0.000156 | 0.1078 | ort:efficient 4 | 2048 | 16 | 64 | 0.000089 | 0.1892 | ort:lean 4 | 2048 | 32 | 128 | 0.000223 | 0.3010 | ort:flash 4 | 2048 | 32 | 128 | 0.000217 | 0.3101 | ort:efficient 4 | 2048 | 32 | 128 | 0.000209 | 0.3209 | ort:lean 4 | 4096 | 16 | 64 | 0.000137 | 0.2448 | ort:flash 4 | 4096 | 16 | 64 | 0.000256 | 0.1312 | ort:efficient 4 | 4096 | 16 | 64 | 0.000133 | 0.2530 | ort:lean 4 | 4096 | 32 | 128 | 0.000389 | 0.3450 | ort:flash 4 | 4096 | 32 | 128 | 0.000376 | 0.3574 | ort:efficient 4 | 4096 | 32 | 128 | 0.000354 | 0.3794 | ort:lean 4 | 8192 | 16 | 64 | 0.000210 | 0.3198 | ort:flash 4 | 8192 | 16 | 64 | 0.000453 | 0.1480 | ort:efficient 4 | 8192 | 16 | 64 | 0.000206 | 0.3260 | ort:lean 4 | 8192 | 32 | 128 | 0.000725 | 0.3705 | ort:flash 4 | 8192 | 32 | 128 | 0.000693 | 0.3874 | ort:efficient 4 | 8192 | 32 | 128 | 0.000653 | 0.4114 | ort:lean 4 | 16384 | 16 | 64 | 0.000355 | 0.3782 | ort:flash 4 | 16384 | 16 | 64 | 0.000849 | 0.1581 | ort:efficient 4 | 16384 | 16 | 64 | 0.000346 | 0.3874 | ort:lean 4 | 16384 | 32 | 128 | 0.001395 | 0.3848 | ort:flash 4 | 16384 | 32 | 128 | 0.001337 | 0.4017 | ort:efficient 4 | 16384 | 32 | 128 | 0.001252 | 0.4288 | ort:lean 4 | 32768 | 16 | 64 | 0.000647 | 0.4146 | ort:flash 4 | 32768 | 16 | 64 | 0.001649 | 0.1628 | ort:efficient 4 | 32768 | 16 | 64 | 0.000639 | 0.4204 | ort:lean 4 | 32768 | 32 | 128 | 0.002721 | 0.3947 | ort:flash 4 | 32768 | 32 | 128 | 0.002601 | 0.4128 | ort:efficient 4 | 32768 | 32 | 128 | 0.002434 | 0.4411 | ort:lean 4 | 65536 | 16 | 64 | 0.001231 | 0.4361 | ort:flash 4 | 65536 | 16 | 64 | 0.003238 | 0.1658 | ort:efficient 4 | 65536 | 16 | 64 | 0.001217 | 0.4412 | ort:lean 4 | 65536 | 32 | 128 | 0.005357 | 0.4009 | ort:flash 4 | 65536 | 32 | 128 | 0.005118 | 0.4196 | ort:efficient 4 | 65536 | 32 | 128 | 0.004781 | 0.4492 | ort:lean 16 | 512 | 16 | 64 | 0.000098 | 0.1724 | ort:flash 16 | 512 | 16 | 64 | 0.000104 | 0.1616 | ort:efficient 16 | 512 | 16 | 64 | 0.000118 | 0.1420 | ort:math 16 | 512 | 16 | 64 | 0.000087 | 0.1926 | ort:lean 16 | 512 | 32 | 128 | 0.000220 | 0.3062 | ort:flash 16 | 512 | 32 | 128 | 0.000208 | 0.3237 | ort:efficient 16 | 512 | 32 | 128 | 0.000237 | 0.2838 | ort:math 16 | 512 | 32 | 128 | 0.000209 | 0.3216 | ort:lean 16 | 1024 | 16 | 64 | 0.000136 | 0.2465 | ort:flash 16 | 1024 | 16 | 64 | 0.000150 | 0.2235 | ort:efficient 16 | 1024 | 16 | 64 | 0.000148 | 0.2266 | ort:math 16 | 1024 | 16 | 64 | 0.000129 | 0.2611 | ort:lean 16 | 1024 | 32 | 128 | 0.000367 | 0.3663 | ort:flash 16 | 1024 | 32 | 128 | 0.000351 | 0.3829 | ort:efficient 16 | 1024 | 32 | 128 | 0.000400 | 0.3357 | ort:math 16 | 1024 | 32 | 128 | 0.000349 | 0.3853 | ort:lean 16 | 2048 | 16 | 64 | 0.000209 | 0.3206 | ort:flash 16 | 2048 | 16 | 64 | 0.000243 | 0.2762 | ort:efficient 16 | 2048 | 16 | 64 | 0.000201 | 0.3338 | ort:lean 16 | 2048 | 32 | 128 | 0.000671 | 0.4002 | ort:flash 16 | 2048 | 32 | 128 | 0.000645 | 0.4163 | ort:efficient 16 | 2048 | 32 | 128 | 0.000642 | 0.4185 | ort:lean 16 | 4096 | 16 | 64 | 0.000360 | 0.3732 | ort:flash 16 | 4096 | 16 | 64 | 0.000425 | 0.3162 | ort:efficient 16 | 4096 | 16 | 64 | 0.000341 | 0.3933 | ort:lean 16 | 4096 | 32 | 128 | 0.001292 | 0.4156 | ort:flash 16 | 4096 | 32 | 128 | 0.001251 | 0.4291 | ort:efficient 16 | 4096 | 32 | 128 | 0.001241 | 0.4327 | ort:lean 16 | 8192 | 16 | 64 | 0.000666 | 0.4030 | ort:flash 16 | 8192 | 16 | 64 | 0.000804 | 0.3339 | ort:efficient 16 | 8192 | 16 | 64 | 0.000627 | 0.4283 | ort:lean 16 | 8192 | 32 | 128 | 0.002541 | 0.4226 | ort:flash 16 | 8192 | 32 | 128 | 0.002454 | 0.4376 | ort:efficient 16 | 8192 | 32 | 128 | 0.002438 | 0.4405 | ort:lean 16 | 16384 | 16 | 64 | 0.001292 | 0.4156 | ort:flash 16 | 16384 | 16 | 64 | 0.001571 | 0.3417 | ort:efficient 16 | 16384 | 16 | 64 | 0.001217 | 0.4411 | ort:lean 16 | 16384 | 32 | 128 | 0.005042 | 0.4260 | ort:flash 16 | 16384 | 32 | 128 | 0.004859 | 0.4420 | ort:efficient 16 | 16384 | 32 | 128 | 0.004827 | 0.4449 | ort:lean 16 | 32768 | 16 | 64 | 0.002537 | 0.4233 | ort:flash 16 | 32768 | 16 | 64 | 0.003103 | 0.3461 | ort:efficient 16 | 32768 | 16 | 64 | 0.002385 | 0.4501 | ort:lean 16 | 32768 | 32 | 128 | 0.009961 | 0.4312 | ort:flash 16 | 32768 | 32 | 128 | 0.009605 | 0.4472 | ort:efficient 16 | 32768 | 32 | 128 | 0.009524 | 0.4510 | ort:lean 16 | 65536 | 16 | 64 | 0.005019 | 0.4279 | ort:flash 16 | 65536 | 16 | 64 | 0.006133 | 0.3502 | ort:efficient 16 | 65536 | 16 | 64 | 0.004703 | 0.4566 | ort:lean 16 | 65536 | 32 | 128 | 0.019746 | 0.4350 | ort:flash 16 | 65536 | 32 | 128 | 0.019027 | 0.4515 | ort:efficient 16 | 65536 | 32 | 128 | 0.018864 | 0.4554 | ort:lean ### Motivation and Context --- cmake/CMakeLists.txt | 17 + .../contrib_ops/cpu/bert/attention_common.h | 10 +- .../contrib_ops/cuda/bert/attention.cc | 18 +- .../contrib_ops/cuda/bert/attention_impl.cu | 96 +- .../contrib_ops/cuda/bert/attention_impl.h | 15 + .../cuda/bert/attention_kernel_options.cc | 17 + .../cuda/bert/attention_kernel_options.h | 3 + .../cuda/bert/attention_prepare_qkv.cu | 1 + .../cuda/bert/lean_attention/block_info.h | 45 + .../cuda/bert/lean_attention/flash.h | 148 +++ .../cuda/bert/lean_attention/kernel_traits.h | 315 +++++ .../cuda/bert/lean_attention/lean_api.cc | 453 +++++++ .../cuda/bert/lean_attention/lean_api.h | 64 + .../lean_attention/lean_fwd_hdim128_fp16.cu | 15 + .../lean_attention/lean_fwd_hdim64_fp16.cu | 15 + .../bert/lean_attention/lean_fwd_kernel.h | 1066 +++++++++++++++++ .../lean_attention/lean_fwd_launch_template.h | 73 ++ .../cuda/bert/lean_attention/mask.h | 209 ++++ .../cuda/bert/lean_attention/softmax.h | 196 +++ .../cuda/bert/lean_attention/static_switch.h | 109 ++ .../cuda/bert/lean_attention/utils.h | 411 +++++++ .../cuda/bert/multihead_attention.cc | 98 +- .../cuda/bert/multihead_attention.h | 3 + .../quantization/attention_quantization.cc | 2 + .../test/python/transformers/benchmark_mha.py | 77 +- .../test/python/transformers/benchmark_mha.sh | 119 +- .../test/python/transformers/test_mha.py | 51 + 27 files changed, 3578 insertions(+), 68 deletions(-) create mode 100644 onnxruntime/contrib_ops/cuda/bert/lean_attention/block_info.h create mode 100644 onnxruntime/contrib_ops/cuda/bert/lean_attention/flash.h create mode 100644 onnxruntime/contrib_ops/cuda/bert/lean_attention/kernel_traits.h create mode 100644 onnxruntime/contrib_ops/cuda/bert/lean_attention/lean_api.cc create mode 100644 onnxruntime/contrib_ops/cuda/bert/lean_attention/lean_api.h create mode 100644 onnxruntime/contrib_ops/cuda/bert/lean_attention/lean_fwd_hdim128_fp16.cu create mode 100644 onnxruntime/contrib_ops/cuda/bert/lean_attention/lean_fwd_hdim64_fp16.cu create mode 100644 onnxruntime/contrib_ops/cuda/bert/lean_attention/lean_fwd_kernel.h create mode 100644 onnxruntime/contrib_ops/cuda/bert/lean_attention/lean_fwd_launch_template.h create mode 100644 onnxruntime/contrib_ops/cuda/bert/lean_attention/mask.h create mode 100644 onnxruntime/contrib_ops/cuda/bert/lean_attention/softmax.h create mode 100644 onnxruntime/contrib_ops/cuda/bert/lean_attention/static_switch.h create mode 100644 onnxruntime/contrib_ops/cuda/bert/lean_attention/utils.h diff --git a/cmake/CMakeLists.txt b/cmake/CMakeLists.txt index ef208f59f63b0..d90a2a355045e 100644 --- a/cmake/CMakeLists.txt +++ b/cmake/CMakeLists.txt @@ -106,6 +106,7 @@ option(onnxruntime_USE_LLVM "Build TVM with LLVM" OFF) option(onnxruntime_USE_VSINPU "Build with VSINPU support" OFF) cmake_dependent_option(onnxruntime_USE_FLASH_ATTENTION "Build flash attention kernel for scaled dot product attention" ON "onnxruntime_USE_CUDA" OFF) +cmake_dependent_option(onnxruntime_USE_LEAN_ATTENTION "Build lean attention kernel for scaled dot product attention" ON "onnxruntime_USE_CUDA; NOT WIN32" OFF) option(onnxruntime_USE_MEMORY_EFFICIENT_ATTENTION "Build memory efficient attention kernel for scaled dot product attention" ON) option(onnxruntime_BUILD_FOR_NATIVE_MACHINE "Enable this option for turning on optimization specific to this machine" OFF) @@ -751,21 +752,30 @@ if (onnxruntime_USE_CUDA) if (onnxruntime_DISABLE_CONTRIB_OPS) set(onnxruntime_USE_FLASH_ATTENTION OFF) + set(onnxruntime_USE_LEAN_ATTENTION OFF) set(onnxruntime_USE_MEMORY_EFFICIENT_ATTENTION OFF) endif() + if (CMAKE_CUDA_COMPILER_VERSION VERSION_LESS 11.6) message( STATUS "Turn off flash attention since CUDA compiler version < 11.6") set(onnxruntime_USE_FLASH_ATTENTION OFF) + set(onnxruntime_USE_LEAN_ATTENTION OFF) set(onnxruntime_USE_MEMORY_EFFICIENT_ATTENTION OFF) elseif(WIN32 AND CMAKE_CUDA_COMPILER_VERSION VERSION_LESS 12) message( STATUS "Flash-Attention unsupported in Windows with CUDA compiler version < 12.0") set(onnxruntime_USE_FLASH_ATTENTION OFF) endif() + if (CMAKE_CUDA_COMPILER_VERSION VERSION_LESS 11.4) message( FATAL_ERROR "Failed build due to CUDA compiler version < 11.4") endif() + if (WIN32) + message( STATUS "Lean Attention unsupported in Windows") + set(onnxruntime_USE_LEAN_ATTENTION OFF) + endif() else() set(onnxruntime_USE_FLASH_ATTENTION OFF) + set(onnxruntime_USE_LEAN_ATTENTION OFF) set(onnxruntime_USE_MEMORY_EFFICIENT_ATTENTION OFF) endif() @@ -779,6 +789,13 @@ if (onnxruntime_USE_CUDA) list(APPEND ORT_PROVIDER_FLAGS -DUSE_FLASH_ATTENTION=1) list(APPEND ORT_PROVIDER_CMAKE_FLAGS -Donnxruntime_USE_FLASH_ATTENTION=1) endif() + + if (onnxruntime_USE_LEAN_ATTENTION) + message( STATUS "Enable lean attention for CUDA EP") + list(APPEND ORT_PROVIDER_FLAGS -DUSE_LEAN_ATTENTION=1) + list(APPEND ORT_PROVIDER_CMAKE_FLAGS -Donnxruntime_USE_LEAN_ATTENTION=1) + endif() + if (onnxruntime_USE_MEMORY_EFFICIENT_ATTENTION) message( STATUS "Enable memory efficient attention for CUDA EP") list(APPEND ORT_PROVIDER_FLAGS -DUSE_MEMORY_EFFICIENT_ATTENTION=1) diff --git a/onnxruntime/contrib_ops/cpu/bert/attention_common.h b/onnxruntime/contrib_ops/cpu/bert/attention_common.h index 46638555576a9..97d6cc1ce7d66 100644 --- a/onnxruntime/contrib_ops/cpu/bert/attention_common.h +++ b/onnxruntime/contrib_ops/cpu/bert/attention_common.h @@ -48,6 +48,7 @@ enum AttentionKernelType { AttentionKernel_CutlassMemoryEfficientAttention, AttentionKernel_FlashAttention, AttentionKernel_CudnnFlashAttention, + AttentionKernel_LeanAttention, AttentionKernel_Default }; @@ -65,7 +66,6 @@ struct AttentionParameters { int v_hidden_size; // hidden size of V int v_head_size; // hidden size per head of V int num_heads; - int num_splits; int rotary_embedding; bool is_unidirectional; bool past_present_share_buffer; @@ -208,10 +208,13 @@ enum class AttentionBackend : int { CUDNN_FLASH_ATTENTION = 8, // reserved for cuDNN flash attention. MATH = 16, // unfused kernel cannot be disabled right now. - // The following kernels might be deprecated in the future. + // The following TRT kernels might be deprecated in the future. TRT_FLASH_ATTENTION = 32, TRT_CROSS_ATTENTION = 64, TRT_CAUSAL_ATTENTION = 128, + + // Experimental kernels + LEAN_ATTENTION = 256, }; // Environment variable to enable debug information of attention kernel to be printed. Default is 0 (disabled). @@ -239,6 +242,9 @@ constexpr const char* kDisableMemoryEfficientAttention = "ORT_DISABLE_MEMORY_EFF // Environment variable to enable or disable flash attention. Default is 0 (enabled). constexpr const char* kDisableFlashAttention = "ORT_DISABLE_FLASH_ATTENTION"; +// Environment variable to enable or disable lean attention. Default is 0 (disabled). +constexpr const char* kEnableLeanAttention = "ORT_ENABLE_LEAN_ATTENTION"; + // Minimum sequence length to perfer memory efficient attention when data type is float32 constexpr const char* kMinSeqLenForEfficientAttentionFp32 = "ORT_MIN_SEQ_LEN_EFFICIENT_ATTENTION_FP32"; diff --git a/onnxruntime/contrib_ops/cuda/bert/attention.cc b/onnxruntime/contrib_ops/cuda/bert/attention.cc index efbc0b5031657..22e2879a5be15 100644 --- a/onnxruntime/contrib_ops/cuda/bert/attention.cc +++ b/onnxruntime/contrib_ops/cuda/bert/attention.cc @@ -102,6 +102,9 @@ Status Attention::ComputeInternal(OpKernelContext* context) const { const int sm = device_prop.major * 10 + device_prop.minor; const bool is_mask_1d_seq_len = parameters.mask_type == AttentionMaskType::MASK_1D_KEY_SEQ_LEN; + typedef typename ToCudaType::MappedType CudaT; + AttentionData data; + #if USE_FLASH_ATTENTION bool use_flash_attention = !disable_flash_attention_ && (nullptr == attention_bias) && @@ -118,21 +121,26 @@ Status Attention::ComputeInternal(OpKernelContext* context) const { use_flash_attention = false; } // Allocate buffers + size_t softmax_lse_bytes = 0; size_t softmax_lse_accum_bytes = 0; size_t out_accum_bytes = 0; if (use_flash_attention) { + softmax_lse_bytes = onnxruntime::flash::get_softmax_lse_size(sequence_length, batch_size, parameters.num_heads); + using namespace std; auto [num_splits, slse_accum_bytes, o_accum_bytes] = onnxruntime::flash::get_num_splits_and_buffer_sizes( parameters.batch_size, parameters.sequence_length, parameters.total_sequence_length, parameters.num_heads, parameters.head_size, device_prop.multiProcessorCount); - parameters.num_splits = static_cast(num_splits); + data.num_splits = static_cast(num_splits); softmax_lse_accum_bytes = slse_accum_bytes; out_accum_bytes = o_accum_bytes; } + auto softmax_lse_buffer = GetScratchBuffer(softmax_lse_bytes, context->GetComputeStream()); auto softmax_lse_accum_buffer = GetScratchBuffer(softmax_lse_accum_bytes, context->GetComputeStream()); auto out_accum_buffer = GetScratchBuffer(out_accum_bytes, context->GetComputeStream()); #else constexpr bool use_flash_attention = false; + auto softmax_lse_buffer = GetScratchBuffer(0, context->GetComputeStream()); auto softmax_lse_accum_buffer = GetScratchBuffer(0, context->GetComputeStream()); // nullptr auto out_accum_buffer = GetScratchBuffer(0, context->GetComputeStream()); // nullptr #endif @@ -247,6 +255,7 @@ Status Attention::ComputeInternal(OpKernelContext* context) const { constexpr size_t element_size = sizeof(T); constexpr bool use_fused_cross_attention = false; constexpr bool use_cudnn_flash_attention = false; + constexpr bool use_lean_attention = false; size_t workSpaceSize = GetAttentionWorkspaceSize(element_size, parameters.batch_size, parameters.num_heads, @@ -257,14 +266,13 @@ Status Attention::ComputeInternal(OpKernelContext* context) const { parameters.total_sequence_length, fused_runner, use_flash_attention, + use_lean_attention, use_fused_cross_attention, use_memory_efficient_attention, use_cudnn_flash_attention, false); IAllocatorUniquePtr work_space = IAllocator::MakeUniquePtr(allocator, workSpaceSize, false, context->GetComputeStream()); - typedef typename ToCudaType::MappedType CudaT; - AttentionData data; data.gemm_buffer = reinterpret_cast(gemm_buffer.get()); if (nullptr != bias) { data.bias = reinterpret_cast(bias->Data()); @@ -289,6 +297,10 @@ Status Attention::ComputeInternal(OpKernelContext* context) const { data.fused_runner = reinterpret_cast(fused_runner); data.use_flash_attention = use_flash_attention; data.use_memory_efficient_attention = use_memory_efficient_attention; + if (softmax_lse_buffer != nullptr) { + data.softmax_lse = reinterpret_cast(softmax_lse_buffer.get()); + } + if (softmax_lse_accum_buffer != nullptr) { data.softmax_lse_accum = reinterpret_cast(softmax_lse_accum_buffer.get()); } diff --git a/onnxruntime/contrib_ops/cuda/bert/attention_impl.cu b/onnxruntime/contrib_ops/cuda/bert/attention_impl.cu index eff58c0080012..9e017544d7cff 100644 --- a/onnxruntime/contrib_ops/cuda/bert/attention_impl.cu +++ b/onnxruntime/contrib_ops/cuda/bert/attention_impl.cu @@ -39,6 +39,7 @@ limitations under the License. #include "contrib_ops/cuda/bert/cutlass_fmha/memory_efficient_attention.h" #include "contrib_ops/cuda/bert/cudnn_fmha/cudnn_flash_attention.h" #include "contrib_ops/cuda/bert/flash_attention/flash_api.h" +#include "contrib_ops/cuda/bert/lean_attention/lean_api.h" #include "contrib_ops/cuda/bert/attention_impl.h" using namespace onnxruntime::cuda; @@ -108,6 +109,7 @@ size_t GetAttentionWorkspaceSize( size_t total_sequence_length, void* fused_runner, bool use_flash_attention, + bool use_lean_attention, bool use_fused_cross_attention, bool use_memory_efficient_attention, bool use_cudnn_flash_attention, @@ -119,12 +121,20 @@ size_t GetAttentionWorkspaceSize( #if USE_FLASH_ATTENTION if (use_flash_attention) { - return qkv_bytes + onnxruntime::flash::get_softmax_lse_size(sequence_length, batch_size, num_heads); + return qkv_bytes; } #else ORT_UNUSED_PARAMETER(use_flash_attention); #endif +#if USE_LEAN_ATTENTION + if (use_lean_attention) { + return qkv_bytes; + } +#else + ORT_UNUSED_PARAMETER(use_lean_attention); +#endif + #if USE_MEMORY_EFFICIENT_ATTENTION if (use_memory_efficient_attention) { size_t fmha_buffer_bytes = 0; @@ -301,10 +311,10 @@ Status FlashAttention( constexpr bool is_bf16 = false; ORT_RETURN_IF_ERROR(onnxruntime::flash::mha_fwd( - device_prop, stream, data.q, data.k, data.v, data.output, reinterpret_cast(data.scratch), + device_prop, stream, data.q, data.k, data.v, data.output, reinterpret_cast(data.softmax_lse), parameters.batch_size, parameters.num_heads, parameters.num_heads, parameters.head_size, parameters.sequence_length, parameters.total_sequence_length, scale, 0.0, parameters.is_unidirectional, is_bf16, - false, parameters.num_splits, reinterpret_cast(data.softmax_lse_accum), + false, data.num_splits, reinterpret_cast(data.softmax_lse_accum), reinterpret_cast(data.out_accum), data.qkv_format == AttentionQkvFormat::Q_K_V_BSNH)); return Status::OK(); @@ -326,6 +336,81 @@ Status FlashAttention( } #endif +#if USE_LEAN_ATTENTION +template +Status LeanAttention( + const cudaDeviceProp& device_prop, + cudaStream_t stream, + contrib::AttentionParameters& parameters, + AttentionData& data, + float scale) { + assert(data.qkv_format == AttentionQkvFormat::Q_K_V_BSNH || + data.qkv_format == AttentionQkvFormat::Q_K_V_BSNH_BNSH_BNSH); + assert(nullptr == data.mask_index); + assert(nullptr == data.attention_bias); + assert(parameters.head_size == parameters.v_head_size); + + constexpr bool is_bf16 = false; + + ORT_RETURN_IF_ERROR(onnxruntime::lean::mha_fwd_kvcache( + device_prop, stream, + data.q, + data.k, // k_cache + data.v, // v_cache + nullptr, // new_k (we have appended new_k to k_cache) + nullptr, // new_v (we have appended new_v to k_cache) + data.output, + reinterpret_cast(data.softmax_lse), + nullptr, // seqlens_k + nullptr, // cos_cache + nullptr, // sin_cache + nullptr, // block_table + parameters.batch_size, + parameters.num_heads, + parameters.num_heads, // num_heads_k + parameters.head_size, + parameters.sequence_length, // seqlen_q + parameters.total_sequence_length, // seqlen_k + 0, // seqlen_k_new + 0, // rotary_dim + scale, // softmax_scale + parameters.is_unidirectional, + is_bf16, + false, // past_bsnh + data.num_splits, + data.grid_dim_z, + data.max_tiles_per_tb, + data.high_load_tbs, + data.tiles_per_head, + reinterpret_cast(data.softmax_lse_accum), + reinterpret_cast(data.out_accum), + data.lean_sync_flag, + -1, // local_window_size + false, // is_rotary_interleaved + false // is_packed_qkv + )); + + return Status::OK(); +} + +template <> +Status LeanAttention( + const cudaDeviceProp& device_prop, + cudaStream_t stream, + contrib::AttentionParameters& parameters, + AttentionData& data, + float scale) { + ORT_UNUSED_PARAMETER(device_prop); + ORT_UNUSED_PARAMETER(stream); + ORT_UNUSED_PARAMETER(parameters); + ORT_UNUSED_PARAMETER(data); + ORT_UNUSED_PARAMETER(scale); + return ORT_MAKE_STATUS(ONNXRUNTIME, StatusCode::NOT_IMPLEMENTED, "lean attention does not support float tensor"); +} +#endif + + + template Status CudnnFlashAttention( cudnnHandle_t cudnn_handle, @@ -641,6 +726,11 @@ Status QkvToContext( // For raw attention mask, the scalar 1/sqrt(H) is moved to combine with softmax computation. const float scale = parameters.scale == 0.0f ? 1.f / sqrt(static_cast(qk_head_size)) : parameters.scale; +#if USE_LEAN_ATTENTION + if (data.use_lean_attention) { + return LeanAttention(device_prop, stream, parameters, data, scale); + } +#endif #if USE_FLASH_ATTENTION if (data.use_flash_attention) { diff --git a/onnxruntime/contrib_ops/cuda/bert/attention_impl.h b/onnxruntime/contrib_ops/cuda/bert/attention_impl.h index fcc9af9681223..7d111a1ee21bf 100644 --- a/onnxruntime/contrib_ops/cuda/bert/attention_impl.h +++ b/onnxruntime/contrib_ops/cuda/bert/attention_impl.h @@ -53,6 +53,7 @@ size_t GetAttentionWorkspaceSize( size_t total_sequence_length, void* fused_runner, bool use_flash_attention, + bool use_lean_attention, bool use_fused_cross_attention, bool use_memory_efficient_attention, bool use_cudnn_flash_attention, @@ -102,6 +103,19 @@ struct AttentionData { T* softmax_lse_accum = nullptr; T* out_accum = nullptr; + // Flash Atttention and Lean Attention + int num_splits; + + // Lean Attention + bool use_lean_attention = false; +#if USE_LEAN_ATTENTION + int grid_dim_z = 0; + int max_tiles_per_tb = 0; + int high_load_tbs = 0; + int tiles_per_head = 0; + int* lean_sync_flag = nullptr; +#endif + // For Debugging size_t workspace_bytes = 0; bool allow_debug_info = false; @@ -115,6 +129,7 @@ struct AttentionData { void PrintDebugInfo() const { std::cout << "flash=" << use_flash_attention + << ", lean=" << use_lean_attention << ", efficient=" << use_memory_efficient_attention << ", fused_runner=" << (fused_runner != nullptr) << ", fused_cross=" << (fused_cross_attention_kernel != nullptr) diff --git a/onnxruntime/contrib_ops/cuda/bert/attention_kernel_options.cc b/onnxruntime/contrib_ops/cuda/bert/attention_kernel_options.cc index 7d21451df5b86..8b8b764e7c785 100644 --- a/onnxruntime/contrib_ops/cuda/bert/attention_kernel_options.cc +++ b/onnxruntime/contrib_ops/cuda/bert/attention_kernel_options.cc @@ -17,6 +17,9 @@ namespace onnxruntime { void AttentionKernelOptions::Initialize(int value, bool use_build_flag, bool check_cudnn_version) { if (value > 0) { use_flash_attention_ = (value & static_cast(AttentionBackend::FLASH_ATTENTION)) > 0; +#if USE_LEAN_ATTENTION + use_lean_attention_ = (value & static_cast(AttentionBackend::LEAN_ATTENTION)) > 0; +#endif use_efficient_attention_ = (value & static_cast(AttentionBackend::EFFICIENT_ATTENTION)) > 0; use_trt_fused_attention_ = (value & static_cast(AttentionBackend::TRT_FUSED_ATTENTION)) > 0; use_cudnn_flash_attention_ = (value & static_cast(AttentionBackend::CUDNN_FLASH_ATTENTION)) > 0; @@ -26,6 +29,9 @@ void AttentionKernelOptions::Initialize(int value, bool use_build_flag, bool che use_trt_causal_attention_ = (value & static_cast(AttentionBackend::TRT_CAUSAL_ATTENTION)) > 0; } else { use_flash_attention_ = !ParseEnvironmentVariableWithDefault(kDisableFlashAttention, false); +#if USE_LEAN_ATTENTION + use_lean_attention_ = ParseEnvironmentVariableWithDefault(kEnableLeanAttention, false); +#endif use_efficient_attention_ = !ParseEnvironmentVariableWithDefault(kDisableMemoryEfficientAttention, false); use_trt_fused_attention_ = !ParseEnvironmentVariableWithDefault(kDisableFusedSelfAttention, false); use_cudnn_flash_attention_ = ParseEnvironmentVariableWithDefault(kEnableCudnnFlashAttention, false); @@ -61,6 +67,10 @@ void AttentionKernelOptions::Initialize(int value, bool use_build_flag, bool che use_flash_attention_ = false; #endif +#ifndef USE_LEAN_ATTENTION + use_lean_attention_ = false; +#endif + #ifndef USE_MEMORY_EFFICIENT_ATTENTION use_efficient_attention_ = false; #endif @@ -81,6 +91,9 @@ void AttentionKernelOptions::Print() const { std::stringstream sstream; sstream << "AttentionKernelOptions:"; sstream << " FLASH_ATTENTION=" << int(use_flash_attention_); +#if USE_LEAN_ATTENTION + sstream << " LEAN_ATTENTION=" << int(use_lean_attention_); +#endif sstream << " EFFICIENT_ATTENTION=" << int(use_efficient_attention_); sstream << " TRT_FUSED_ATTENTION=" << int(use_trt_fused_attention_); sstream << " CUDNN_FLASH_ATTENTION=" << int(use_cudnn_flash_attention_); @@ -131,6 +144,10 @@ void AttentionKernelDebugInfo::Print(const char* operator_name, sstream << " SdpaKernel="; if (use_flash_attention.has_value() && use_flash_attention.value()) { sstream << "FLASH_ATTENTION"; +#if USE_LEAN_ATTENTION + } else if (use_lean_attention.has_value() && use_lean_attention.value()) { + sstream << "LEAN_ATTENTION"; +#endif } else if (use_efficient_attention.has_value() && use_efficient_attention.value()) { sstream << "EFFICIENT_ATTENTION"; } else if (use_trt_fused_attention.has_value() && use_trt_fused_attention.value()) { diff --git a/onnxruntime/contrib_ops/cuda/bert/attention_kernel_options.h b/onnxruntime/contrib_ops/cuda/bert/attention_kernel_options.h index a27fb199a6272..caed704564c3b 100644 --- a/onnxruntime/contrib_ops/cuda/bert/attention_kernel_options.h +++ b/onnxruntime/contrib_ops/cuda/bert/attention_kernel_options.h @@ -9,6 +9,7 @@ namespace onnxruntime { struct AttentionKernelDebugInfo { std::optional use_flash_attention = std::nullopt; + std::optional use_lean_attention = std::nullopt; std::optional use_efficient_attention = std::nullopt; std::optional use_trt_fused_attention = std::nullopt; std::optional use_cudnn_flash_attention = std::nullopt; @@ -24,6 +25,7 @@ class AttentionKernelOptions { void InitializeOnce(int sdpa_kernel, bool use_build_flag, bool check_cudnn_version = false); bool UseFlashAttention() const { return use_flash_attention_; } + bool UseLeanAttention() const { return use_lean_attention_; } bool UseEfficientAttention() const { return use_efficient_attention_; } bool UseTrtFusedAttention() const { return use_trt_fused_attention_; } bool UseCudnnFlashAttention() const { return use_cudnn_flash_attention_; } @@ -44,6 +46,7 @@ class AttentionKernelOptions { private: bool use_flash_attention_{true}; + bool use_lean_attention_{false}; bool use_efficient_attention_{true}; bool use_trt_fused_attention_{true}; bool use_cudnn_flash_attention_{false}; diff --git a/onnxruntime/contrib_ops/cuda/bert/attention_prepare_qkv.cu b/onnxruntime/contrib_ops/cuda/bert/attention_prepare_qkv.cu index a079076f2881b..c8c0191967d40 100644 --- a/onnxruntime/contrib_ops/cuda/bert/attention_prepare_qkv.cu +++ b/onnxruntime/contrib_ops/cuda/bert/attention_prepare_qkv.cu @@ -384,6 +384,7 @@ Status PrepareQkv_MHA_WithPast_NoBias(contrib::AttentionParameters& parameters, if (data.use_memory_efficient_attention || data.use_flash_attention || + data.use_lean_attention || data.kernel_type == AttentionKernelType::AttentionKernel_CudnnFlashAttention) { // Use oiginal Query (BSNH) since there is no bias. data.q = const_cast(data.query); diff --git a/onnxruntime/contrib_ops/cuda/bert/lean_attention/block_info.h b/onnxruntime/contrib_ops/cuda/bert/lean_attention/block_info.h new file mode 100644 index 0000000000000..6d9ed824b4b76 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/bert/lean_attention/block_info.h @@ -0,0 +1,45 @@ +/****************************************************************************** + * Copyright (c) 2023, Tri Dao. + ******************************************************************************/ + +#pragma once + +namespace onnxruntime { +namespace lean { + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct BlockInfo { + template + __device__ BlockInfo(const Params& params, const int bidb) + : sum_s_q(!Varlen || params.cu_seqlens_q == nullptr ? -1 : params.cu_seqlens_q[bidb]), sum_s_k(!Varlen || params.cu_seqlens_k == nullptr || !params.is_seqlens_k_cumulative ? -1 : params.cu_seqlens_k[bidb]), actual_seqlen_q(!Varlen || params.cu_seqlens_q == nullptr ? params.seqlen_q : params.cu_seqlens_q[bidb + 1] - sum_s_q) + // If is_seqlens_k_cumulative, then seqlen_k is cu_seqlens_k[bidb + 1] - cu_seqlens_k[bidb]. + // Otherwise it's cu_seqlens_k[bidb], i.e., we use cu_seqlens_k to store the sequence lengths of K. + , + seqlen_k_cache(!Varlen || params.cu_seqlens_k == nullptr ? params.seqlen_k : (params.is_seqlens_k_cumulative ? params.cu_seqlens_k[bidb + 1] - sum_s_k : params.cu_seqlens_k[bidb])), + actual_seqlen_k(params.seqused_k ? params.seqused_k[bidb] : seqlen_k_cache + (params.knew_ptr == nullptr ? 0 : params.seqlen_knew)) { + } + + template + __forceinline__ __device__ index_t q_offset(const index_t batch_stride, const index_t row_stride, const int bidb) const { + return sum_s_q == -1 ? bidb * batch_stride : uint32_t(sum_s_q) * row_stride; + } + + template + __forceinline__ __device__ index_t k_offset(const index_t batch_stride, const index_t row_stride, const int bidb) const { + return sum_s_k == -1 ? bidb * batch_stride : uint32_t(sum_s_k) * row_stride; + } + + const int sum_s_q; + const int sum_s_k; + const int actual_seqlen_q; + // We have to have seqlen_k_cache declared before actual_seqlen_k, otherwise actual_seqlen_k is set to 0. + const int seqlen_k_cache; + const int actual_seqlen_k; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace lean +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/bert/lean_attention/flash.h b/onnxruntime/contrib_ops/cuda/bert/lean_attention/flash.h new file mode 100644 index 0000000000000..a2058d8805ebd --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/bert/lean_attention/flash.h @@ -0,0 +1,148 @@ +/****************************************************************************** + * Copyright (c) 2023, Tri Dao. + ******************************************************************************/ + +#pragma once + +#include +#include + +namespace onnxruntime { +namespace lean { + +constexpr int TOTAL_DIM = 0; +constexpr int H_DIM = 1; +constexpr int D_DIM = 2; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +struct Qkv_params { + using index_t = int64_t; + // The QKV matrices. + void* __restrict__ q_ptr; + void* __restrict__ k_ptr; + void* __restrict__ v_ptr; + + // The stride between rows of the Q, K and V matrices. + index_t q_batch_stride; + index_t k_batch_stride; + index_t v_batch_stride; + index_t q_row_stride; + index_t k_row_stride; + index_t v_row_stride; + index_t q_head_stride; + index_t k_head_stride; + index_t v_head_stride; + + // The number of heads. + int h, h_k; + // In the case of multi-query and grouped-query attention (MQA/GQA), nheads_k could be + // different from nheads (query). + int h_h_k_ratio; // precompute h / h_k, +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +struct Flash_fwd_params : public Qkv_params { + // The O matrix (output). + void* __restrict__ o_ptr; + void* __restrict__ oaccum_ptr; + + // The stride between rows of O. + index_t o_batch_stride; + index_t o_row_stride; + index_t o_head_stride; + + // The pointer to the P matrix. + void* __restrict__ p_ptr; + + // The pointer to the softmax sum. + void* __restrict__ softmax_lse_ptr; + void* __restrict__ softmax_lseaccum_ptr; + + // The dimensions. + int b, seqlen_q, seqlen_k, seqlen_knew, d, seqlen_q_rounded, seqlen_k_rounded, d_rounded, rotary_dim; + + // The scaling factors for the kernel. + float scale_softmax; + float scale_softmax_log2; + + // array of length b+1 holding starting offset of each sequence. + int* __restrict__ cu_seqlens_q; + int* __restrict__ cu_seqlens_k; + + // If provided, the actual length of each k sequence. + int* __restrict__ seqused_k; + + int* __restrict__ blockmask; + + // The K_new and V_new matrices. + void* __restrict__ knew_ptr; + void* __restrict__ vnew_ptr; + + // The stride between rows of the Q, K and V matrices. + index_t knew_batch_stride; + index_t vnew_batch_stride; + index_t knew_row_stride; + index_t vnew_row_stride; + index_t knew_head_stride; + index_t vnew_head_stride; + + // The cos and sin matrices for rotary embedding. + void* __restrict__ rotary_cos_ptr; + void* __restrict__ rotary_sin_ptr; + + // The indices to index into the KV cache. + int* __restrict__ cache_batch_idx; + + // Paged KV cache + int* __restrict__ block_table; + index_t block_table_batch_stride; + int page_block_size; + + // The dropout probability (probability of keeping an activation). + float p_dropout; + // uint32_t p_dropout_in_uint; + // uint16_t p_dropout_in_uint16_t; + uint8_t p_dropout_in_uint8_t; + + // Scale factor of 1 / (1 - p_dropout). + float rp_dropout; + float scale_softmax_rp_dropout; + + // Local window size + int window_size_left, window_size_right; + + // Pointer to the RNG seed (idx 0) and offset (idx 1). + uint64_t* rng_state; + + bool is_bf16; + bool is_causal; + + // If is_seqlens_k_cumulative, then seqlen_k is cu_seqlens_k[bidb + 1] - cu_seqlens_k[bidb]. + // Otherwise it's cu_seqlens_k[bidb], i.e., we use cu_seqlens_k to store the sequence lengths of K. + bool is_seqlens_k_cumulative; + + bool is_rotary_interleaved; + + int num_splits; // For split-KV version and lean + + void* __restrict__ alibi_slopes_ptr; + index_t alibi_slopes_batch_stride; + + // LEAN Additional Params + int lean_griddimz; + int tiles_per_head; + int max_tiles_per_tb; + int high_load_tbs; + void* __restrict__ sync_flag; + + const cudaDeviceProp* dprops = nullptr; +}; +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +void run_mha_fwd_lean_dispatch(Flash_fwd_params& params, cudaStream_t stream); + +} // namespace lean +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/bert/lean_attention/kernel_traits.h b/onnxruntime/contrib_ops/cuda/bert/lean_attention/kernel_traits.h new file mode 100644 index 0000000000000..85be5d3e031ac --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/bert/lean_attention/kernel_traits.h @@ -0,0 +1,315 @@ +/****************************************************************************** + * Copyright (c) 2024, Tri Dao. + ******************************************************************************/ + +#pragma once + +#include "cute/algorithm/copy.hpp" + +#include "cutlass/cutlass.h" +#include "cutlass/layout/layout.h" +#include + +using namespace cute; + +template +struct Flash_kernel_traits { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 + using Element = elem_type; + static constexpr bool Has_cp_async = true; +#else + using Element = cutlass::half_t; + static constexpr bool Has_cp_async = false; +#endif + + using ElementAccum = float; + using index_t = int64_t; + +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 + using MMA_Atom_Arch = std::conditional_t< + std::is_same_v, + MMA_Atom, + MMA_Atom>; +#else + using MMA_Atom_Arch = MMA_Atom; +#endif + +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 750 + using SmemCopyAtom = Copy_Atom; + using SmemCopyAtomTransposed = Copy_Atom; +#else + using SmemCopyAtom = Copy_Atom; + using SmemCopyAtomTransposed = Copy_Atom; +#endif +}; + +// If Share_Q_K_smem is true, that forces Is_Q_in_regs to be true +template > +struct Flash_fwd_kernel_traits : public Base { + using Element = typename Base::Element; + using ElementAccum = typename Base::ElementAccum; + using index_t = typename Base::index_t; + static constexpr bool Has_cp_async = Base::Has_cp_async; + using SmemCopyAtom = typename Base::SmemCopyAtom; + using SmemCopyAtomTransposed = typename Base::SmemCopyAtomTransposed; + + static constexpr bool Share_Q_K_smem = Share_Q_K_smem_; + static constexpr bool Is_Q_in_regs = Is_Q_in_regs_ || Share_Q_K_smem; + + // The number of threads. + static constexpr int kNWarps = kNWarps_; + static constexpr int kNThreads = kNWarps * 32; + + static constexpr int kBlockM = kBlockM_; + static constexpr int kBlockN = kBlockN_; + static constexpr int kHeadDim = kHeadDim_; + static_assert(kHeadDim % 32 == 0); + static constexpr int kBlockKSmem = kHeadDim % 64 == 0 ? 64 : 32; + static constexpr int kBlockKGmem = kHeadDim % 128 == 0 ? 128 : (kHeadDim % 64 == 0 ? 64 : 32); + static constexpr int kSwizzle = kBlockKSmem == 32 ? 2 : 3; + + using TiledMma = TiledMMA< + typename Base::MMA_Atom_Arch, + Layout, _1, _1>>, // 4x1x1 or 8x1x1 thread group + Tile, _16, _16>>; + + using SmemLayoutAtomQ = decltype(composition(Swizzle{}, + // This has to be kBlockKSmem, using kHeadDim gives wrong results for d=128 + Layout>, + Stride, _1>>{})); + using SmemLayoutQ = decltype(tile_to_shape( + SmemLayoutAtomQ{}, + Shape, Int>{})); + + using SmemLayoutKV = decltype(tile_to_shape( + SmemLayoutAtomQ{}, + Shape, Int>{})); + + // https://github.com/ColfaxResearch/cutlass-kernels/blob/a222587e6d59b93ba704853d3946fb686d8b8892/src/fmha/fmha_forward.cu#L434 + using SmemLayoutVtransposed = decltype(composition(SmemLayoutKV{}, make_layout(Shape, Int>{}, GenRowMajor{}))); + using SmemLayoutVtransposedNoSwizzle = decltype(get_nonswizzle_portion(SmemLayoutVtransposed{})); + + using SmemLayoutAtomO = decltype(composition(Swizzle{}, + Layout, Int>, + Stride, _1>>{})); + using SmemLayoutO = decltype(tile_to_shape( + SmemLayoutAtomO{}, + Shape, Int>{})); + using SmemCopyAtomO = Copy_Atom; + using SmemCopyAtomOaccum = Copy_Atom; + + static constexpr int kSmemQSize = size(SmemLayoutQ{}) * sizeof(Element); + static constexpr int kSmemKVSize = size(SmemLayoutKV{}) * 2 * sizeof(Element); + static constexpr int kSmemOSize = size(SmemLayoutO{}) * sizeof(ElementAccum); + // static constexpr int kSmemSize = Share_Q_K_smem ? std::max(kSmemQSize, kSmemKVSize) : kSmemQSize + kSmemKVSize + kSmemOSize; + static constexpr int kSmemSize = Share_Q_K_smem ? std::max(kSmemQSize, kSmemKVSize) : kSmemQSize + kSmemKVSize; + + static constexpr int kGmemElemsPerLoad = sizeof(cute::uint128_t) / sizeof(Element); + static_assert(kHeadDim % kGmemElemsPerLoad == 0, "kHeadDim must be a multiple of kGmemElemsPerLoad"); + // Using kBlockKSmem here is 6-10% faster than kBlockKGmem for d=128 because of bank conflicts. + // For example, for d=128, smem is split into 2 "pages", each page takes care of columns + // 0-63 and 64-127. If we have 16 threads per row for gmem read, when we write to smem, + // thread 0 - 7 will write to the first page and thread 8 - 15 will write to the second page, + // to the same banks. + static constexpr int kGmemThreadsPerRow = kBlockKSmem / kGmemElemsPerLoad; + static_assert(kNThreads % kGmemThreadsPerRow == 0, "kNThreads must be a multiple of kGmemThreadsPerRow"); + using GmemLayoutAtom = Layout, Int>, + Stride, _1>>; + + // We use CACHEGLOBAL instead of CACHEALWAYS for both Q and K/V, since we won't be reading + // from the same address by the same threadblock. This is slightly faster. + using Gmem_copy_struct = std::conditional_t< + Has_cp_async, + SM80_CP_ASYNC_CACHEGLOBAL, + DefaultCopy>; + using GmemTiledCopyQKV = decltype(make_tiled_copy(Copy_Atom{}, + GmemLayoutAtom{}, + Layout>{})); // Val layout, 8 vals per read + using GmemTiledCopyO = decltype(make_tiled_copy(Copy_Atom{}, + GmemLayoutAtom{}, + Layout>{})); // Val layout, 8 vals per store + + using GmemLayoutAtomOaccum = std::conditional_t< + kBlockKSmem == 32, + Layout, // Thread layout, 8 threads per row + Stride<_8, _1>>, + Layout, // Thread layout, 16 threads per row + Stride<_16, _1>>>; + using GmemTiledCopyOaccum = decltype(make_tiled_copy(Copy_Atom{}, + GmemLayoutAtomOaccum{}, + Layout>{})); // Val layout, 4 vals per store + using GmemLayoutAtomRotcossin = GmemLayoutAtom; + using GmemTiledCopyRotcossin = decltype(make_tiled_copy(Copy_Atom, Element>{}, + GmemLayoutAtomRotcossin{}, + Layout>{})); // Val layout, 4 vals per load + using GmemTiledCopyRotcossinCont = decltype(make_tiled_copy(Copy_Atom{}, + GmemLayoutAtomRotcossin{}, + Layout>{})); // Val layout, 8 vals per load +}; + +// Is_V_in_regs is an option to reduce smem usage, but will increase register pressue. +// No_double_buffer is another option to reduce smem usage, but will slow things down. +template > +struct Flash_bwd_kernel_traits : public Base { + using Element = typename Base::Element; + using ElementAccum = typename Base::ElementAccum; + using index_t = typename Base::index_t; + static constexpr bool Has_cp_async = Base::Has_cp_async; + using SmemCopyAtom = typename Base::SmemCopyAtom; + using SmemCopyAtomTransposed = typename Base::SmemCopyAtomTransposed; + + static constexpr bool Is_V_in_regs = Is_V_in_regs_; + static constexpr bool No_double_buffer = No_double_buffer_; + + // The number of threads. + static constexpr int kNWarps = kNWarps_; + static constexpr int kNThreads = kNWarps * 32; + + static constexpr int kBlockM = kBlockM_; + static constexpr int kBlockN = kBlockN_; + static constexpr int kHeadDim = kHeadDim_; + static_assert(kHeadDim % 32 == 0); + static constexpr int kBlockKSmem = kHeadDim % 64 == 0 ? 64 : 32; + static constexpr int kBlockKGmem = kHeadDim % 128 == 0 ? 128 : (kHeadDim % 64 == 0 ? 64 : 32); + static constexpr int kSwizzle = kBlockKSmem == 32 ? 2 : 3; + + static constexpr int AtomLayoutMSdP = AtomLayoutMSdP_; + static_assert(kNWarps % AtomLayoutMSdP == 0); + static_assert(kNWarps % AtomLayoutNdKV == 0); + static_assert(kNWarps % AtomLayoutMdQ == 0); + + using TiledMmaSdP = TiledMMA< + typename Base::MMA_Atom_Arch, + Layout, Int, _1>>, + Tile, Int<16 * kNWarps / AtomLayoutMSdP>, _16>>; + + using TiledMmadKV = TiledMMA< + typename Base::MMA_Atom_Arch, + Layout, Int, _1>>, + Tile, Int<16 * kNWarps / AtomLayoutNdKV>, _16>>; + + using TiledMmadQ = TiledMMA< + typename Base::MMA_Atom_Arch, + Layout, Int, _1>>, // 2x4x1 or 4x2x1 thread group + Tile, Int<16 * kNWarps / AtomLayoutMdQ>, _16>>; + + using SmemLayoutAtomQdO = decltype(composition(Swizzle{}, + Layout>, + Stride, _1>>{})); + using SmemLayoutQdO = decltype(tile_to_shape( + SmemLayoutAtomQdO{}, + make_shape(Int{}, Int{}))); + + using SmemLayoutAtomKV = decltype(composition(Swizzle{}, + Layout, Int>, + Stride, _1>>{})); + using SmemLayoutKV = decltype(tile_to_shape( + // SmemLayoutAtomQdO{}, + SmemLayoutAtomKV{}, + make_shape(Int{}, Int{}))); + + using SmemLayoutKtransposed = decltype(composition(SmemLayoutKV{}, make_layout(Shape, Int>{}, GenRowMajor{}))); + using SmemLayoutKtransposedNoSwizzle = decltype(get_nonswizzle_portion(SmemLayoutKtransposed{})); + + // TODO: generalize to other values of kBlockN + // TODO: what should be the Swizzle here? 3 is faster than 1, and 1 is faster than 2 + // static constexpr int kPBlockN = kBlockN; + // Temporarily disabling this for hdim 256 on sm86 and sm89 + // static_assert(kBlockN >= 64); + static_assert(kBlockN >= 32); + // TD [2023-03-19]: Idk why kPBlockN = 16 and kSwizzlePdS=3 is the fastest. + static constexpr int kPBlockN = kBlockN >= 64 ? 64 : 32; + static_assert(kPBlockN == 16 || kPBlockN == 32 || kPBlockN == 64); + // static constexpr int kSwizzlePdS = kPBlockN == 16 ? 1 : (kPBlockN == 32 ? 2 : 3); + static constexpr int kSwizzlePdS = 3; + using SmemLayoutAtomPdS = decltype(composition(Swizzle{}, + Layout, Int>, + Stride, _1>>{})); + using SmemLayoutPdS = decltype(tile_to_shape( + SmemLayoutAtomPdS{}, + make_shape(Int{}, Int{}))); + using SmemLayoutPdStransposed = decltype(composition(SmemLayoutPdS{}, make_layout(Shape, Int>{}, GenRowMajor{}))); + using SmemLayoutPdStransposedNoSwizzle = decltype(get_nonswizzle_portion(SmemLayoutPdStransposed{})); + + using SmemCopyAtomPdS = Copy_Atom; + + using SmemLayoutQdOtransposed = decltype(composition(SmemLayoutQdO{}, make_layout(Shape, Int>{}, GenRowMajor{}))); + using SmemLayoutQdOtransposedNoSwizzle = decltype(get_nonswizzle_portion(SmemLayoutQdOtransposed{})); + + using SmemLayoutAtomdKV = decltype(composition(Swizzle{}, + Layout>, + Stride, _1>>{})); + using SmemLayoutdKV = decltype(tile_to_shape( + SmemLayoutAtomdKV{}, + make_shape(Int{}, Int{}))); + using SmemCopyAtomdKV = Copy_Atom; + + using SmemLayoutAtomdQ = decltype(composition(Swizzle{}, + Layout>, + Stride, _1>>{})); + using SmemLayoutdQ = decltype(tile_to_shape( + SmemLayoutAtomdQ{}, + make_shape(Int{}, Int{}))); + using SmemCopyAtomdQ = Copy_Atom; + + // Double buffer for sQ + static constexpr int kSmemQdOSize = size(SmemLayoutQdO{}) * (No_double_buffer ? 2 : 3) * sizeof(Element); + static constexpr int kSmemKVSize = size(SmemLayoutKV{}) * 2 * sizeof(Element); + static constexpr int kSmemdSSize = size(SmemLayoutPdS{}) * sizeof(Element); + static constexpr int kSmemPSize = size(SmemLayoutPdS{}) * sizeof(Element); + static constexpr int kSmemdQSize = size(SmemLayoutdQ{}) * sizeof(Element); + static constexpr int kSmemSize = kSmemQdOSize + (!Is_V_in_regs + ? kSmemKVSize + kSmemdSSize + std::max(kSmemPSize, kSmemdQSize) + : std::max(kSmemKVSize, kSmemKVSize / 2 + kSmemdSSize + std::max(kSmemPSize, kSmemdQSize))); + static constexpr int kSmemSize1colblock = kSmemQdOSize + (!Is_V_in_regs + ? kSmemKVSize + kSmemdSSize + kSmemPSize + : std::max(kSmemKVSize, kSmemKVSize / 2 + kSmemdSSize + kSmemPSize)); + + static constexpr int kGmemElemsPerLoad = sizeof(cute::uint128_t) / sizeof(Element); + static_assert(kHeadDim % kGmemElemsPerLoad == 0, "kHeadDim must be a multiple of kGmemElemsPerLoad"); + // Using kBlockKSmem instead of kHeadDim here to avoid bank conflicts, but doesn't seem + // to affect speed in practice. + static constexpr int kGmemThreadsPerRow = kBlockKSmem / kGmemElemsPerLoad; + static_assert(kNThreads % kGmemThreadsPerRow == 0, "kNThreads must be a multiple of kGmemThreadsPerRow"); + using GmemLayoutAtom = Layout, Int>, + Stride, _1>>; + + // We use CACHEGLOBAL instead of CACHEALWAYS for both Q and K/V, since we won't be reading + // from the same address by the same threadblock. This is slightly faster. + using Gmem_copy_struct = std::conditional_t< + Has_cp_async, + SM80_CP_ASYNC_CACHEGLOBAL, + DefaultCopy>; + using GmemTiledCopyQKV = decltype(make_tiled_copy(Copy_Atom{}, + GmemLayoutAtom{}, + Layout>{})); // Val layout, 8 vals per read + using GmemTiledCopydO = decltype(make_tiled_copy(Copy_Atom{}, + GmemLayoutAtom{}, + Layout>{})); // Val layout, 8 vals per store + using GmemTiledCopydKV = decltype(make_tiled_copy(Copy_Atom{}, + GmemLayoutAtom{}, + Layout>{})); // Val layout, 8 vals per store + using GmemTiledCopydQ = decltype(make_tiled_copy(Copy_Atom{}, + GmemLayoutAtom{}, + Layout>{})); // Val layout, 8 vals per store + using GmemLayoutAtomdQaccum = std::conditional_t< + kBlockKSmem == 32, + Layout, // Thread layout, 8 threads per row + Stride<_8, _1>>, + Layout, // Thread layout, 16 threads per row + Stride<_16, _1>>>; + using GmemTiledCopydQaccum = decltype(make_tiled_copy(Copy_Atom{}, + GmemLayoutAtomdQaccum{}, + Layout>{})); // Val layout, 4 vals per store + + using GmemTiledCopydQaccumAtomicAdd = decltype(make_tiled_copy(Copy_Atom{}, + Layout, // Thread layout, 8 threads per row + Stride<_32, _1>>{}, + Layout>{})); // Val layout, 1 val per store +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/onnxruntime/contrib_ops/cuda/bert/lean_attention/lean_api.cc b/onnxruntime/contrib_ops/cuda/bert/lean_attention/lean_api.cc new file mode 100644 index 0000000000000..81301ebc7ba64 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/bert/lean_attention/lean_api.cc @@ -0,0 +1,453 @@ +/****************************************************************************** + * Copyright (c) 2023, Tri Dao. + ******************************************************************************/ + +// Modifications: support lean attention. +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#if USE_LEAN_ATTENTION + +#include "contrib_ops/cuda/bert/lean_attention/lean_api.h" +#include + +#include "contrib_ops/cuda/bert/lean_attention/flash.h" +#include "contrib_ops/cuda/bert/lean_attention/static_switch.h" + +namespace onnxruntime { +namespace lean { + +#define CHECK_DEVICE(x) TORCH_CHECK(x.is_cuda(), #x " must be on CUDA") +#define CHECK_SHAPE(x, ...) TORCH_CHECK(x.sizes() == torch::IntArrayRef({__VA_ARGS__}), #x " must have shape (" #__VA_ARGS__ ")") +#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") + +void set_params_fprop(Flash_fwd_params& params, + // sizes + size_t batch_size, + size_t seqlen_q, + size_t seqlen_k, + size_t seqlen_q_rounded, + size_t seqlen_k_rounded, + size_t num_heads, + size_t num_heads_k, + size_t head_size, + size_t head_size_rounded, + // device pointers + void* q, + void* k, + void* v, + void* out, + void* cu_seqlens_q_d, + void* cu_seqlens_k_d, + void* seqused_k, + void* p_d, + void* softmax_lse_d, + float softmax_scale, + bool is_causal, + bool is_bf16, + bool kv_bsnh = true, + int window_size_left = -1, + int window_size_right = -1) { + // Set the pointers and strides. + params.q_ptr = q; + params.k_ptr = k; + params.v_ptr = v; + params.o_ptr = out; + + params.is_bf16 = is_bf16; + + // All stride are in elements, not bytes. + if (kv_bsnh) { + params.q_row_stride = num_heads * head_size; + params.k_row_stride = num_heads_k * head_size; + params.v_row_stride = num_heads_k * head_size; + params.q_head_stride = head_size; + params.k_head_stride = head_size; + params.v_head_stride = head_size; + params.o_row_stride = num_heads * head_size; + params.o_head_stride = head_size; + } else { + params.q_row_stride = num_heads * head_size; + params.k_row_stride = head_size; + params.v_row_stride = head_size; + params.q_head_stride = head_size; + params.k_head_stride = seqlen_k * head_size; + params.v_head_stride = seqlen_k * head_size; + params.o_row_stride = num_heads * head_size; + params.o_head_stride = head_size; + } + + if (cu_seqlens_q_d == nullptr) { + params.q_batch_stride = seqlen_q * num_heads * head_size; // stride(0) + params.k_batch_stride = seqlen_k * num_heads_k * head_size; // stride(0) + params.v_batch_stride = seqlen_k * num_heads_k * head_size; // stride(0) + params.o_batch_stride = seqlen_q * num_heads * head_size; // stride(0) + } else { + params.q_batch_stride = 0; + params.k_batch_stride = 0; + params.v_batch_stride = 0; + params.o_batch_stride = 0; + } + + params.cu_seqlens_q = static_cast(cu_seqlens_q_d); + params.cu_seqlens_k = static_cast(cu_seqlens_k_d); + params.seqused_k = static_cast(seqused_k); + + // P = softmax(QK^T) + params.p_ptr = p_d; + + // Softmax sum + params.softmax_lse_ptr = softmax_lse_d; + + // Set the dimensions. +#if defined(_MSC_VER) +#pragma warning(push) +#pragma warning(disable : 4267) // Ignore conversion from 'size_t' to 'int', possible loss of data +#pragma warning(disable : 4244) // Ignore conversion from 'double' to 'float', possible loss of data +#endif + params.b = batch_size; + params.h = num_heads; + params.h_k = num_heads_k; + params.h_h_k_ratio = num_heads / num_heads_k; + params.seqlen_q = seqlen_q; + params.seqlen_k = seqlen_k; + params.seqlen_q_rounded = seqlen_q_rounded; + params.seqlen_k_rounded = seqlen_k_rounded; + params.d = head_size; + params.d_rounded = head_size_rounded; + + // Set the different scale values. + params.scale_softmax = softmax_scale; + params.scale_softmax_log2 = softmax_scale * M_LOG2E; + + // In our API, causal/unidirectional determines if we only look at prior tokens. However, the flash API separates + // local and causal, meaning when we have local window size + params.is_causal = is_causal; + if (is_causal && (window_size_left >= 0 || window_size_right != 0)) { + params.is_causal = false; + } + if (window_size_left < 0 && window_size_right >= 0) { + window_size_left = seqlen_k; + } + if (window_size_left >= 0 && window_size_right < 0) { + window_size_right = seqlen_k; + } +#if defined(_MSC_VER) +#pragma warning(pop) +#endif + params.window_size_left = window_size_left; + params.window_size_right = window_size_right; + + params.is_seqlens_k_cumulative = true; +} + +size_t get_softmax_lse_size(size_t seqlen, size_t batch_size, size_t num_heads) { + size_t bytes = sizeof(float) * batch_size * num_heads * seqlen; + return bytes; +} + +size_t get_softmax_lse_accum_size(size_t num_splits, size_t batch_size, size_t num_heads, size_t seqlen_q) { + size_t bytes = sizeof(float) * num_splits * batch_size * seqlen_q * num_heads; + return bytes; +} + +size_t get_out_accum_size(size_t num_splits, size_t batch_size, size_t num_heads, + size_t seqlen_q, size_t head_size_rounded) { + size_t bytes = sizeof(float) * num_splits * batch_size * seqlen_q * num_heads * head_size_rounded; + return bytes; +} + +size_t get_sync_flag_size(size_t num_m_blocks, size_t batch_size, size_t num_heads) { + size_t bytes = sizeof(int) * batch_size * num_heads * num_m_blocks; + return bytes; +} + +void run_mha_fwd(Flash_fwd_params& params, cudaStream_t stream) { + FP16_SWITCH(!params.is_bf16, [&] { + HEADDIM_SWITCH(params.d, [&] { + run_mha_fwd_lean_dispatch(params, stream); + }); + }); +} + +std::tuple get_num_splits_and_buffer_sizes(size_t batch_size, size_t max_seqlen_q, size_t max_seqlen_k, + size_t num_heads, size_t num_heads_k, size_t head_size, size_t num_SMs, bool is_causal) { + // This needs to match with run_mha_fwd_splitkv_dispatch + const int block_n = head_size <= 64 ? 256 : (head_size <= 128 ? 128 : 64); + const int block_m = head_size <= 64 ? 64 : (head_size <= 128 ? 64 : 64); + const int num_m_blocks = (max_seqlen_q + block_m - 1) / block_m; + const int num_n_blocks = (max_seqlen_k + block_n - 1) / block_n; + if (max_seqlen_q == 1) { + is_causal = false; + } + + max_seqlen_q = max_seqlen_q * num_heads / num_heads_k; + +#if defined(DEBUG_LEAN_ATTENTION) + printf("block_n: %d\n", block_n); + printf("block_m: %d\n", block_m); + printf("num_m_blocks: %d\n", num_m_blocks); + printf("num_n_blocks: %d\n", num_n_blocks); + printf("max_seqlen_q: %lu\n", max_seqlen_q); + printf("max_seqlen_k: %lu\n", max_seqlen_k); + printf("is_causal: %d\n", is_causal); + printf("num_heads: %lu\n", num_heads); + printf("num_heads_k: %lu\n", num_heads_k); +#endif + + size_t tiles_per_head = 0; + if (is_causal) { + // Prefill - Causal + for (int i = 0; i < num_m_blocks; i++) { + tiles_per_head += (((i + 1) * block_m) + block_n - 1) / block_n; + } + } else { + // Decode or Not Causal + // Tiles per head is the number of blocks in the first block + tiles_per_head = num_m_blocks * num_n_blocks; + } + size_t total_tiles = tiles_per_head * batch_size * num_heads_k; + + // StreamK Lean has as many threadblocks as SMs + // This should be a function of tile size and number of scratchpad space + + // We want at least two tiles per CTA to be efficient + // And then 2 CTAs per SM + size_t lean_griddimz = num_SMs * 2; + if (total_tiles <= 2 * 2 * num_SMs) { + lean_griddimz = std::min((total_tiles + 1) / 2, (32 * total_tiles + num_n_blocks - 1) / num_n_blocks); + // params.lean_griddimz = num_m_blocks * batch_size * num_heads; + } else { + // Max split of 64 per block is allowed, so we conservatively set it to 32 + // to account for ceil + lean_griddimz = std::min(2 * num_SMs, 32 * num_heads_k * batch_size * num_m_blocks); + } + size_t max_tiles_per_tb = (total_tiles + lean_griddimz - 1) / lean_griddimz; + // Find max number of splits + size_t num_splits = 0; + if (total_tiles % lean_griddimz == 0) { + num_splits = 1 + ((num_n_blocks + max_tiles_per_tb - 2) / (max_tiles_per_tb)); + } else { + num_splits = 1 + ((num_n_blocks + max_tiles_per_tb - 3) / (max_tiles_per_tb - 1)); + } + size_t high_load_tbs = total_tiles - ((max_tiles_per_tb - 1) * lean_griddimz); + +#if defined(DEBUG_LEAN_ATTENTION) + printf("Causal: %d params.tiles_per_head : %lu\n", is_causal, tiles_per_head); + printf("num_splits = %lu\n", num_splits); + printf("total_tiles = %lu\n", total_tiles); + printf("lean_griddimz = %lu\n", lean_griddimz); + printf("max_tiles_per_tb = %lu\n", max_tiles_per_tb); + printf("high_load_tbs = %lu\n", high_load_tbs); +#endif + + if (num_splits > 1) { + size_t softmax_lse_accum_bytes = get_softmax_lse_accum_size(num_splits, batch_size, num_heads_k, max_seqlen_q); + auto round_multiple = [](size_t x, size_t m) { return (x + m - 1) / m * m; }; + const size_t head_size_rounded = round_multiple(head_size, 32); + size_t out_accum_bytes = get_out_accum_size(num_splits, batch_size, num_heads_k, max_seqlen_q, head_size_rounded); + size_t sync_flag_bytes = get_sync_flag_size(num_m_blocks, batch_size, num_heads_k); + return {num_splits, softmax_lse_accum_bytes, out_accum_bytes, sync_flag_bytes, lean_griddimz, max_tiles_per_tb, high_load_tbs, tiles_per_head}; + } else { + return {0, 0, 0, 0, lean_griddimz, max_tiles_per_tb, high_load_tbs, tiles_per_head}; + } +} + +bool is_supported(const cudaDeviceProp& dprops, size_t head_size, size_t num_heads, size_t num_heads_k) { + bool is_sm8x = dprops.major == 8 && dprops.minor >= 0; + bool is_sm90 = dprops.major == 9 && dprops.minor == 0; + return (is_sm8x || is_sm90) && (head_size == 64 || head_size == 128) && (num_heads % num_heads_k == 0); +} + +// This API is used when past key and value are present... since cached, these are assumed to have sequence length +// of max_sequence_length, so seqlen_k == max_sequence_length. The actual past sequence length is held in seqlens_k_. +Status mha_fwd_kvcache(const cudaDeviceProp& dprops, + cudaStream_t stream, + void* q, // batch_size x seqlen_q x num_heads x head_size + void* kcache, // batch_size x seqlen_k_max x num_heads_k x head_size or batch_size x num_heads_k x seqlen_k_max x head_size + void* vcache, // batch_size x seqlen_k_max x num_heads_k x head_size or batch_size x num_heads_k x seqlen_k_max x head_size + void* k_new, // (optional) batch_size x seqlen_k_new x num_heads_k x head_size + void* v_new, // (optional) batch_size x seqlen_k_new x num_heads_k x head_size + void* out, // batch_size x seqlen_q x num_heads x head_size + void* softmax_lse, // batch_size x num_heads x seqlen_q + void* seqlens_k_, // batch_size + void* rotary_cos, // seqlen_ro x (rotary_dim / 2) + void* rotary_sin, // seqlen_ro x (rotary_dim / 2) + int* block_table, // batch_size x max_num_blocks_per_seq + int batch_size, + int num_heads, + int num_heads_k, + int head_size, + int seqlen_q, + int seqlen_k, + int seqlen_k_new, + int rotary_dim, + const float softmax_scale, + bool is_causal, + bool is_bf16, + bool past_bsnh, // otherwise bnsh + int num_splits, + int grid_dimz, + int max_tiles_per_tb, + int high_load_tbs, + int tiles_per_head, + void* softmax_lse_accum, // num_splits x batch_size x seqlen_q x num_heads + void* out_accum, // num_splits x batch_size x seqlen_q x num_heads x head_size_rounded + int* sync_flag, + int local_window_size, + bool is_rotary_interleaved, + bool is_packed_qkv, + int max_num_blocks_per_seq, + int page_block_size) { + auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; }; + const int head_size_rounded = round_multiple(head_size, 32); + const int seqlen_q_rounded = round_multiple(seqlen_q, 128); + const int seqlen_k_rounded = round_multiple(seqlen_k, 128); + const bool paged_KV = block_table != nullptr; + +#if defined(DEBUG_LEAN_ATTENTION) + printf( + "batch_size: %d num_heads %d num_heads_k %d head_size %d seqlen_q %d seqlen_k %d seqlen_k_new %d " + "softmax_scale %f is_causal %d is_bf16 %d past_bsnh %d num_splits %d grid_dimz %d max_tiles_per_tb %d " + "high_load_tbs %d tiles_per_head %d local_window_size %d is_rotary_interleaved %d is_packed_qkv %d " + "max_num_blocks_per_seq %d page_block_size %d\n", + batch_size, num_heads, num_heads_k, head_size, seqlen_q, seqlen_k, seqlen_k_new, + softmax_scale, is_causal, is_bf16, past_bsnh, num_splits, grid_dimz, max_tiles_per_tb, + high_load_tbs, tiles_per_head, local_window_size, is_rotary_interleaved, is_packed_qkv, + max_num_blocks_per_seq, page_block_size); +#endif + + // Lean attention treats decode as non-causal + if (seqlen_q == 1) { + is_causal = false; + } + + const int seqlenq_ngroups_swapped = seqlen_q == 1 && num_heads > num_heads_k && head_size % 8 == 0; + if (seqlenq_ngroups_swapped) { + const int ngroups = num_heads / num_heads_k; + seqlen_q = ngroups; + num_heads = num_heads_k; + } + + // In kv-cache case, seqlen_k_max as kv sequence length + Flash_fwd_params params; + set_params_fprop(params, + batch_size, + seqlen_q, seqlen_k, + seqlen_q_rounded, seqlen_k_rounded, + num_heads, num_heads_k, + head_size, head_size_rounded, + q, kcache, vcache, out, + /*cu_seqlens_q_d=*/nullptr, + /*cu_seqlens_k_d=*/nullptr, + /*seqused_k=*/nullptr, + /*p_ptr=*/nullptr, + softmax_lse, + softmax_scale, + is_causal, + is_bf16, + past_bsnh, + local_window_size, + is_causal ? 0 : -1); + params.dprops = &dprops; + + if (k_new != nullptr && v_new != nullptr) { + params.seqlen_knew = seqlen_k_new; + params.knew_ptr = k_new; + params.vnew_ptr = v_new; + // All stride are in elements, not bytes. + params.q_batch_stride = seqlen_q * num_heads * head_size; // stride(0) + params.k_batch_stride = seqlen_k * num_heads_k * head_size; // stride(0) + params.v_batch_stride = seqlen_k * num_heads_k * head_size; // stride(0) + params.o_batch_stride = seqlen_q * num_heads * head_size; // stride(0) + if (is_packed_qkv) { + params.q_batch_stride = (seqlen_q * num_heads * head_size) + (2 * seqlen_k_new * num_heads_k * head_size); + params.q_row_stride = (num_heads * head_size) + (2 * num_heads_k * head_size); + params.knew_batch_stride = (seqlen_q * num_heads * head_size) + (2 * seqlen_k_new * num_heads_k * head_size); + params.vnew_batch_stride = (seqlen_q * num_heads * head_size) + (2 * seqlen_k_new * num_heads_k * head_size); + params.knew_row_stride = (num_heads * head_size) + (2 * num_heads_k * head_size); + params.vnew_row_stride = (num_heads * head_size) + (2 * num_heads_k * head_size); + } else { + params.knew_batch_stride = seqlen_k_new * num_heads_k * head_size; + params.vnew_batch_stride = seqlen_k_new * num_heads_k * head_size; + params.knew_row_stride = num_heads_k * head_size; + params.vnew_row_stride = num_heads_k * head_size; + } + params.knew_head_stride = head_size; + params.vnew_head_stride = head_size; + } else { + params.seqlen_knew = 0; + params.knew_ptr = nullptr; + params.vnew_ptr = nullptr; + params.knew_batch_stride = 0; + params.vnew_batch_stride = 0; + params.knew_row_stride = 0; + params.vnew_row_stride = 0; + params.knew_head_stride = 0; + params.vnew_head_stride = 0; + } + + if (seqlenq_ngroups_swapped) { + if (is_packed_qkv) { + params.q_batch_stride = (seqlen_q * num_heads_k * head_size) + (2 * seqlen_k_new * num_heads_k * head_size); + } else { + params.q_batch_stride = seqlen_q * num_heads_k * head_size; + } + params.q_row_stride = head_size; + params.q_head_stride = seqlen_q * head_size; + params.o_row_stride = head_size; + params.o_head_stride = seqlen_q * head_size; + params.o_batch_stride = seqlen_q * num_heads_k * head_size; + } + + params.is_seqlens_k_cumulative = seqlens_k_ == nullptr; + if (seqlens_k_ != nullptr) { + params.cu_seqlens_k = static_cast(seqlens_k_); + } + + if (rotary_cos != nullptr) { + params.rotary_cos_ptr = rotary_cos; + params.rotary_sin_ptr = rotary_sin; + params.is_rotary_interleaved = is_rotary_interleaved; + params.rotary_dim = rotary_dim; + } + + params.num_splits = num_splits; + params.lean_griddimz = grid_dimz; + params.max_tiles_per_tb = max_tiles_per_tb; + params.high_load_tbs = high_load_tbs; + params.tiles_per_head = tiles_per_head; + if (params.num_splits > 1 && softmax_lse_accum != nullptr && out_accum != nullptr) { + params.softmax_lseaccum_ptr = softmax_lse_accum; + params.oaccum_ptr = out_accum; + params.sync_flag = sync_flag; + } else { + params.softmax_lseaccum_ptr = nullptr; + params.oaccum_ptr = nullptr; + } + + params.alibi_slopes_ptr = nullptr; + if (paged_KV) { + params.block_table = block_table; // TODO(aciddelgado): cast to int pointer + params.block_table_batch_stride = max_num_blocks_per_seq; + // params.num_blocks = num_blocks; + params.page_block_size = page_block_size; + params.k_batch_stride = page_block_size * num_heads_k * head_size; + params.v_batch_stride = page_block_size * num_heads_k * head_size; + } else { + params.block_table = nullptr; + params.block_table_batch_stride = 0; + // params.num_blocks = 0; + params.page_block_size = 1; + } + + // Only split kernel supports appending to KV cache + run_mha_fwd(params, stream); + + return Status::OK(); +} + +} // namespace lean +} // namespace onnxruntime + +#endif // USE_LEAN_ATTENTION diff --git a/onnxruntime/contrib_ops/cuda/bert/lean_attention/lean_api.h b/onnxruntime/contrib_ops/cuda/bert/lean_attention/lean_api.h new file mode 100644 index 0000000000000..3b9bd1c24f08c --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/bert/lean_attention/lean_api.h @@ -0,0 +1,64 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#if USE_LEAN_ATTENTION + +#include "core/providers/cuda/cuda_common.h" +#include + +namespace onnxruntime { +namespace lean { + +Status mha_fwd_kvcache(const cudaDeviceProp& dprops, + cudaStream_t stream, + void* q, // batch_size x seqlen_q x num_heads x head_size + void* kcache, // batch_size x seqlen_k x num_heads_k x head_size or batch_size x num_heads_k seqlen_k x x head_size + void* vcache, // batch_size x seqlen_k x num_heads_k x head_size or batch_size x num_heads_k seqlen_k x x head_size + void* k, // batch_size x seqlen_k_new x num_heads_k x head_size + void* v, // batch_size x seqlen_k_new x num_heads_k x head_size + void* out, // batch_size x seqlen_q x num_heads x head_size + void* softmax_lse, // batch_size x num_heads x seqlen_q + void* seqlens_k_, // batch_size + void* rotary_cos, // seqlen_ro x (rotary_dim / 2) + void* rotary_sin, // seqlen_ro x (rotary_dim / 2) + int* block_table, // batch_size x max_num_blocks_per_seq + int batch_size, + int num_heads, + int num_heads_k, + int head_size, + int seqlen_q, + int seqlen_k, + int seqlen_k_new, + int rotary_dim, + const float softmax_scale, + bool is_causal, + bool is_bf16, + bool past_bsnh, // otherwise bnsh + int num_splits = 0, + int grid_dimz = 0, + int max_tiles_per_tb = 0, + int high_load_tbs = 0, + int tiles_per_head = 0, + void* softmax_lse_accum = nullptr, // num_splits x batch_size x seqlen_q x num_heads + void* out_accum = nullptr, // num_splits x batch_size x seqlen_q x num_heads x head_size_rounded + int* sync_flag = nullptr, + int local_window_size = -1, + bool is_rotary_interleaved = false, + bool is_packed_qkv = false, + int max_num_blocks_per_seq = 0, + int page_block_size = 1); + +size_t get_softmax_lse_size(size_t max_seqlen_q, size_t batch_size, size_t num_heads); + +std::tuple +get_num_splits_and_buffer_sizes(size_t batch_size, size_t seqlen_q, size_t seqlen_k, size_t num_heads, + size_t num_heads_k, size_t head_size, size_t num_SMs, bool is_causal); + +bool is_supported(const cudaDeviceProp& dprops, size_t head_size, size_t num_heads, size_t num_heads_k); + +} // namespace lean +} // namespace onnxruntime + +#endif // USE_LEAN_ATTENTION diff --git a/onnxruntime/contrib_ops/cuda/bert/lean_attention/lean_fwd_hdim128_fp16.cu b/onnxruntime/contrib_ops/cuda/bert/lean_attention/lean_fwd_hdim128_fp16.cu new file mode 100644 index 0000000000000..cfcacbabb3cb9 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/bert/lean_attention/lean_fwd_hdim128_fp16.cu @@ -0,0 +1,15 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#if USE_LEAN_ATTENTION + +#include "contrib_ops/cuda/bert/lean_attention/lean_fwd_launch_template.h" + +namespace onnxruntime { +namespace lean { + +template void run_mha_fwd_lean_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); + +} // namespace flash +} // namespace onnxruntime +#endif diff --git a/onnxruntime/contrib_ops/cuda/bert/lean_attention/lean_fwd_hdim64_fp16.cu b/onnxruntime/contrib_ops/cuda/bert/lean_attention/lean_fwd_hdim64_fp16.cu new file mode 100644 index 0000000000000..44c870f6ab35b --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/bert/lean_attention/lean_fwd_hdim64_fp16.cu @@ -0,0 +1,15 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#if USE_LEAN_ATTENTION + +#include "contrib_ops/cuda/bert/lean_attention/lean_fwd_launch_template.h" + +namespace onnxruntime { +namespace lean { + +template void run_mha_fwd_lean_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); + +} // namespace flash +} // namespace onnxruntime +#endif diff --git a/onnxruntime/contrib_ops/cuda/bert/lean_attention/lean_fwd_kernel.h b/onnxruntime/contrib_ops/cuda/bert/lean_attention/lean_fwd_kernel.h new file mode 100644 index 0000000000000..5be69ea0af55c --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/bert/lean_attention/lean_fwd_kernel.h @@ -0,0 +1,1066 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include +#include + +#include +#include +#include +#include + +#include "contrib_ops/cuda/bert/lean_attention/block_info.h" +#include "contrib_ops/cuda/bert/lean_attention/kernel_traits.h" +#include "contrib_ops/cuda/bert/lean_attention/utils.h" +#include "contrib_ops/cuda/bert/lean_attention/softmax.h" +#include "contrib_ops/cuda/bert/lean_attention/mask.h" + +namespace onnxruntime { +namespace lean { + +using namespace cute; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// Specialized for Prefill +template +inline __device__ void lean_compute_attn_impl_ver3(const Params& params, const int cta_id, int start_tile_gid, int start_tile_hid, int num_tiles, const int num_tiles_per_head) { +#if defined(DEBUG_LEAN_ATTENTION) + // Timing + auto kernel_start = clock64(); + long long int comp1_duration = 0; + long long int comp2_duration = 0; + long long int epilogue_duration = 0; + long long int prologue_duration = 0; + long long int epil1_duration = 0; + long long int epil2_duration = 0; + long long int epil3_duration = 0; + + const int tracing_block = 0; +#endif + + using Element = typename Kernel_traits::Element; + using ElementAccum = typename Kernel_traits::ElementAccum; + using index_t = typename Kernel_traits::index_t; + + // Shared memory. + extern __shared__ char smem_[]; + + // The thread index. + const int tidx = threadIdx.x; + + constexpr int kBlockM = Kernel_traits::kBlockM; + constexpr int kBlockN = Kernel_traits::kBlockN; + constexpr int kHeadDim = Kernel_traits::kHeadDim; + constexpr int kNWarps = Kernel_traits::kNWarps; + + using GmemTiledCopyO = typename Kernel_traits::GmemTiledCopyO; + using GmemTiledCopyOaccum = typename Kernel_traits::GmemTiledCopyOaccum; + + const int num_m_blocks_per_head = (params.seqlen_q + kBlockM - 1) / kBlockM; + + // // This is the solution to the summation series (n+1)(n+2)/2 = start_tile_hid + 1 + // int cur_m_block = Is_causal ? (int)ceilf((sqrtf(9 + (8*start_tile_hid)) - 3) / 2) : start_tile_hid/num_tiles_per_head; + float block_scale = (float)kBlockM / (float)kBlockN; + int cur_m_block = Is_causal ? kBlockM > kBlockN ? (int)ceilf((sqrtf(1 + (8 * start_tile_hid + 8) / block_scale) - 3) / 2) + // : (int)((-1 + sqrt(1 + 8 * block_scale * start_tile_hid)) / 2) * (1 / block_scale) + (int)((start_tile_hid - (1 / block_scale) * ((int)((-1 + sqrt(1 + 8 * block_scale * start_tile_hid)) / 2) * ((int)((-1 + sqrt(1 + 8 * block_scale * start_tile_hid)) / 2) + 1) / 2)) / ((int)((-1 + sqrt(1 + 8 * block_scale * start_tile_hid)) / 2) + 1)) + : static_cast((-1 + sqrt(1 + 8 * start_tile_hid * block_scale)) / (2 * block_scale)) + : start_tile_hid / num_tiles_per_head; + int num_tiles_in_block = Is_causal ? (int)ceilf(block_scale * (cur_m_block + 1)) : num_tiles_per_head; + int cur_bidb = start_tile_gid / (num_tiles_per_head * params.h); + int cur_bidh = (start_tile_gid - (cur_bidb * num_tiles_per_head * params.h)) / num_tiles_per_head; + + int num_tiles_left = num_tiles; + +#if defined(DEBUG_LEAN_ATTENTION) + if (threadIdx.x == 0 && blockIdx.z == tracing_block) { + printf("Debugging block = %d\n", tracing_block); + printf("kBlockM = %d\n", kBlockM); + printf("kBlockN = %d\n", kBlockN); + printf("kHeadDim = %d\n", kHeadDim); + printf("kNWarps = %d\n", kNWarps); + printf("IsEvenMN = %d\n", Is_even_MN); + printf("block_scale = %f\n", block_scale); + printf("seq_len_q -change = %d\n", params.seqlen_q); + printf("seq_len_k = %d\n", params.seqlen_k); + printf("q_batch_stride = %ld\n", params.q_batch_stride); + printf("q_head_stride = %ld\n", params.q_head_stride); + printf("q_row_stride = %ld\n", params.q_row_stride); + printf("k_batch_stride = %ld\n", params.k_batch_stride); + printf("k_head_stride = %ld\n", params.k_head_stride); + printf("k_row_stride = %ld\n", params.k_row_stride); + printf("v_row_stride = %ld\n", params.v_row_stride); + printf("o_row_stride = %ld\n", params.o_row_stride); + printf("start_m_block = %d\n", cur_m_block); + printf("start_tile_gid = %d\n", start_tile_gid); + printf("start_tile_hid = %d\n", start_tile_hid); + printf("cur_bidb = %d/%d\n", cur_bidb, params.b); + printf("cur_bidh = %d/%d\n", cur_bidh, params.h); + printf("num_m_blocks_per_head = %d\n", num_m_blocks_per_head); + printf("cur_m_block = %d\n", cur_m_block); + printf("num_tiles_in_block = %d\n", num_tiles_in_block); + printf("Total tiles = %d\n", num_tiles); + } +#endif + + // Prologue + int n_tile_min = kBlockM > kBlockN ? start_tile_hid - (block_scale * cur_m_block * (cur_m_block + 1) / 2) + : start_tile_hid - (int)(((int)floorf(cur_m_block * block_scale) * ((int)floorf(cur_m_block * block_scale) + 1) / 2) / block_scale) - ((cur_m_block % int(1 / block_scale)) * (floorf(cur_m_block * block_scale) + 1)); + int n_tile = n_tile_min + num_tiles_left - 1 >= num_tiles_in_block ? num_tiles_in_block - 1 : n_tile_min + num_tiles_left - 1; + + index_t row_offset_q = cur_bidb * params.q_batch_stride + + +cur_m_block * kBlockM * params.q_row_stride + cur_bidh * params.q_head_stride; + index_t row_offset_k = cur_bidb * params.k_batch_stride + + +n_tile * kBlockN * params.k_row_stride + (cur_bidh / params.h_h_k_ratio) * params.k_head_stride; + + Tensor gQ = make_tensor(make_gmem_ptr(reinterpret_cast(params.q_ptr) + row_offset_q), + Shape, Int>{}, + make_stride(params.q_row_stride, _1{})); + Tensor gK = make_tensor(make_gmem_ptr(reinterpret_cast(params.k_ptr) + row_offset_k), + Shape, Int>{}, + make_stride(params.k_row_stride, _1{})); + + Tensor sQ = make_tensor(make_smem_ptr(reinterpret_cast(smem_)), + typename Kernel_traits::SmemLayoutQ{}); + Tensor sK = make_tensor(sQ.data() + size(sQ), typename Kernel_traits::SmemLayoutKV{}); + + typename Kernel_traits::GmemTiledCopyQKV gmem_tiled_copy_QKV; + auto gmem_thr_copy_QKV = gmem_tiled_copy_QKV.get_thread_slice(tidx); + + Tensor tQgQ = gmem_thr_copy_QKV.partition_S(gQ); + Tensor tQsQ = gmem_thr_copy_QKV.partition_D(sQ); + Tensor tKgK = gmem_thr_copy_QKV.partition_S(gK); // (KCPY, KCPY_N, KCPY_K) + Tensor tKsK = gmem_thr_copy_QKV.partition_D(sK); + + // PREDICATES + // + + // Construct identity layout for sQ and sK + Tensor cQ = make_identity_tensor(make_shape(size<0>(sQ), size<1>(sQ))); // (BLK_M,BLK_K) -> (blk_m,blk_k) + Tensor cKV = make_identity_tensor(make_shape(size<0>(sK), size<1>(sK))); // (BLK_N,BLK_K) -> (blk_n,blk_k) + + // Repeat the partitioning with identity layouts + Tensor tQcQ = gmem_thr_copy_QKV.partition_S(cQ); // (ACPY,ACPY_M,ACPY_K) -> (blk_m,blk_k) + Tensor tKVcKV = gmem_thr_copy_QKV.partition_S(cKV); // (BCPY,BCPY_N,BCPY_K) -> (blk_n,blk_k) + + // Allocate predicate tensors for k + Tensor tQpQ = make_tensor(make_shape(size<2>(tQsQ))); + Tensor tKVpKV = make_tensor(make_shape(size<2>(tKsK))); + + // Set predicates for k bounds + if (!Is_even_K) { +#pragma unroll + for (int k = 0; k < size(tQpQ); ++k) { + tQpQ(k) = get<1>(tQcQ(0, 0, k)) < params.d; + } +#pragma unroll + for (int k = 0; k < size(tKVpKV); ++k) { + tKVpKV(k) = get<1>(tKVcKV(0, 0, k)) < params.d; + } + } + + // // Start from the last block of first head + // lean::copy(gmem_tiled_copy_QKV, tQgQ, tQsQ, tQcQ, tQpQ, + // params.seqlen_q - cur_m_block * kBlockM); + + // // We don't need to clear the sK smem tiles since we'll mask out the scores anyway. + // lean::copy(gmem_tiled_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV, + // params.seqlen_k - n_tile * kBlockN); + // cute::cp_async_fence(); + + index_t row_offset_v = cur_bidb * params.v_batch_stride + + +n_tile * kBlockN * params.v_row_stride + (cur_bidh / params.h_h_k_ratio) * params.v_head_stride; + Tensor gV = make_tensor(make_gmem_ptr(reinterpret_cast(params.v_ptr) + row_offset_v), + Shape, Int>{}, + make_stride(params.v_row_stride, _1{})); + Tensor sV = make_tensor(sK.data() + size(sK), typename Kernel_traits::SmemLayoutKV{}); + Tensor sVt = make_tensor(sV.data(), typename Kernel_traits::SmemLayoutVtransposed{}); + Tensor sVtNoSwizzle = make_tensor(sV.data(), typename Kernel_traits::SmemLayoutVtransposedNoSwizzle{}); + + Tensor tVgV = gmem_thr_copy_QKV.partition_S(gV); // (VCPY, VCPY_N, VCPY_K) + Tensor tVsV = gmem_thr_copy_QKV.partition_D(sV); + + // Tiled Matrix Multiply + typename Kernel_traits::TiledMma tiled_mma; + auto thr_mma = tiled_mma.get_thread_slice(tidx); + Tensor tSrQ = thr_mma.partition_fragment_A(sQ); // (MMA,MMA_M,MMA_K) + Tensor tSrK = thr_mma.partition_fragment_B(sK); // (MMA,MMA_N,MMA_K) + Tensor tOrVt = thr_mma.partition_fragment_B(sVtNoSwizzle); // (MMA, MMA_K,MMA_N) + + Tensor acc_o = partition_fragment_C(tiled_mma, Shape, Int>{}); // MMA, MMA_M, MMA_K + + // + // Copy Atom retiling - Can be moved + // + + auto smem_tiled_copy_Q = make_tiled_copy_A(typename Kernel_traits::SmemCopyAtom{}, tiled_mma); + auto smem_thr_copy_Q = smem_tiled_copy_Q.get_thread_slice(tidx); + Tensor tSsQ = smem_thr_copy_Q.partition_S(sQ); + + auto smem_tiled_copy_K = make_tiled_copy_B(typename Kernel_traits::SmemCopyAtom{}, tiled_mma); + auto smem_thr_copy_K = smem_tiled_copy_K.get_thread_slice(tidx); + Tensor tSsK = smem_thr_copy_K.partition_S(sK); + + auto smem_tiled_copy_V = make_tiled_copy_B(typename Kernel_traits::SmemCopyAtomTransposed{}, tiled_mma); + auto smem_thr_copy_V = smem_tiled_copy_V.get_thread_slice(tidx); + Tensor tOsVt = smem_thr_copy_V.partition_S(sVt); + +#if defined(DEBUG_LEAN_ATTENTION) + if (threadIdx.x == 0 && blockIdx.z == tracing_block) { + printf("n_tile_min = %d\n", n_tile_min); + printf("n_tile = %d\n", n_tile); + printf("row_offset_q = %" PRId64 "\n", row_offset_q); + printf("row_offset_k = %" PRId64 "\n", row_offset_k); + printf("row_offset_v = %" PRId64 "\n", row_offset_v); + } + + int num_blocks = 0; +#endif + + for (; num_tiles_left > 0;) { +#if defined(DEBUG_LEAN_ATTENTION) + num_blocks += 1; + auto prologue_start = clock64(); +#endif + + cur_bidb = start_tile_gid / (num_tiles_per_head * params.h); + cur_bidh = (start_tile_gid - (cur_bidb * num_tiles_per_head * params.h)) / num_tiles_per_head; + // Scheduling Policy - below + + // Calculate split ID + int block_start_gid = start_tile_gid - n_tile_min; + int cta_id_block_start = block_start_gid > params.high_load_tbs * params.max_tiles_per_tb + ? params.high_load_tbs + ((block_start_gid - (params.high_load_tbs * params.max_tiles_per_tb)) / (params.max_tiles_per_tb - 1)) + : block_start_gid / params.max_tiles_per_tb; + int n_split_idx = cta_id - cta_id_block_start; + + // Check host/ + int host_cta = 0; + int total_splits = 1; + if (n_tile_min == 0) { + host_cta = 1; + int block_end_gid = start_tile_gid + num_tiles_in_block - 1; + int cta_id_block_end = block_end_gid > params.high_load_tbs * params.max_tiles_per_tb + ? params.high_load_tbs + ((block_end_gid - (params.high_load_tbs * params.max_tiles_per_tb)) / (params.max_tiles_per_tb - 1)) + : block_end_gid / params.max_tiles_per_tb; + total_splits = cta_id_block_end - cta_id + 1; + } + + int end_cta = 0; + if (n_tile == num_tiles_in_block - 1) { + end_cta = 1; + } + + start_tile_gid += n_tile - n_tile_min + 1; + start_tile_hid += n_tile - n_tile_min + 1; + if (start_tile_hid >= num_tiles_per_head) { + // Next head + start_tile_hid = 0; + } + num_tiles_left -= n_tile - n_tile_min + 1; + + const BlockInfo binfo(params, cur_bidb); + // This is a hack, we really need to handle this outside the kernel + // But can't figure out a way to get actual seqlen_k in host-side code. + int max_actual_tiles = (binfo.actual_seqlen_k + kBlockN - 1) / kBlockN; + int num_actual_tiles_in_block = Is_causal ? std::max(max_actual_tiles, (int)ceilf(block_scale * (cur_m_block + 1))) : max_actual_tiles; + if (n_tile >= max_actual_tiles) { + tKgK.data() = tKgK.data() + (-int((n_tile - max_actual_tiles - 1) * kBlockN * params.k_row_stride)); + tVgV.data() = tVgV.data() + (-int((n_tile - max_actual_tiles - 1) * kBlockN * params.v_row_stride)); + n_tile = max_actual_tiles - 1; + } + if constexpr (Append_KV) { + if (end_cta) { + // Even if we have MQA / GQA, all threadblocks responsible for the same KV head are writing to + // gmem. Technically it's a race condition, but they all write the same content anyway, and it's safe. + // We want to do this so that all threadblocks can proceed right after they finish writing the KV cache. + + const index_t row_offset_knew = binfo.k_offset(params.knew_batch_stride, params.knew_row_stride, cur_bidb) + (n_tile * kBlockN) * params.knew_row_stride + (cur_bidh / params.h_h_k_ratio) * params.knew_head_stride; + const index_t row_offset_vnew = binfo.k_offset(params.vnew_batch_stride, params.vnew_row_stride, cur_bidb) + (n_tile * kBlockN) * params.vnew_row_stride + (cur_bidh / params.h_h_k_ratio) * params.vnew_head_stride; + // Subtract seqlen_k_cache * row stride so that conceptually gK and gKnew "line up". When we access them, + // e.g. if gK has 128 rows and gKnew has 64 rows, we access gK[:128] and gKNew[128:128 + 64]. + // This maps to accessing the first 64 rows of knew_ptr. + Tensor gKnew = make_tensor(make_gmem_ptr(reinterpret_cast(params.knew_ptr) + row_offset_knew - binfo.seqlen_k_cache * params.knew_row_stride), + Shape, Int>{}, + make_stride(params.knew_row_stride, _1{})); +#if defined(DEBUG_LEAN_ATTENTION) + if (threadIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0) { + printf("knew_ptr = %p, row_offset_knew = %d, gKnew_ptr = %p\n", params.knew_ptr, row_offset_knew, gKnew.data()); + } +#endif + Tensor gVnew = make_tensor(make_gmem_ptr(reinterpret_cast(params.vnew_ptr) + row_offset_vnew - binfo.seqlen_k_cache * params.vnew_row_stride), + Shape, Int>{}, + make_stride(params.vnew_row_stride, _1{})); + Tensor tKgKnew = gmem_thr_copy_QKV.partition_S(gKnew); // (KCPY, KCPY_N, KCPY_K) + Tensor tVgVnew = gmem_thr_copy_QKV.partition_S(gVnew); // (VCPY, VCPY_N, VCPY_K) + + const int n_block_copy_min = std::max(n_tile_min, binfo.seqlen_k_cache / kBlockN); + auto tKgK_data = tKgK.data(); + auto tVgV_data = tVgV.data(); + +#if defined(DEBUG_LEAN_ATTENTION) + if (threadIdx.x == 0 && (blockIdx.z == tracing_block || blockIdx.z == tracing_block + 1)) { + printf("Block %d n_tile_min %d n_tile %d n_block_copy_min %d\n", blockIdx.z, n_tile_min, n_tile, n_block_copy_min); + } +#endif + for (int n_block = n_tile; n_block >= n_block_copy_min; n_block--) { + lean::copy_w_min_idx( + tVgVnew, tVgV, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN, binfo.seqlen_k_cache - n_block * kBlockN); + tVgVnew.data() = tVgVnew.data() + (-int(kBlockN * params.vnew_row_stride)); + + lean::copy_w_min_idx( + tKgKnew, tKgK, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN, binfo.seqlen_k_cache - n_block * kBlockN); + tKgKnew.data() = tKgKnew.data() + (-int(kBlockN * params.knew_row_stride)); + tVgV.data() = tVgV.data() + (-int(kBlockN * params.v_row_stride)); + tKgK.data() = tKgK.data() + (-int(kBlockN * params.k_row_stride)); + } + // Need this before we can read in K again, so that we'll see the updated K values. + __syncthreads(); + tKgK.data() = tKgK_data; + tVgV.data() = tVgV_data; + } + } + lean::copy(gmem_tiled_copy_QKV, tQgQ, tQsQ, tQcQ, tQpQ, + binfo.actual_seqlen_q - cur_m_block * kBlockM); + lean::copy(gmem_tiled_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV, + binfo.actual_seqlen_k - n_tile * kBlockN); + cute::cp_async_fence(); + +#if defined(DEBUG_LEAN_ATTENTION) + if (threadIdx.x == 0 && blockIdx.z == tracing_block) { + printf("##### CTA : %d\n", blockIdx.z); + printf("cur_bidb = %d/%d\n", cur_bidb, params.b); + printf("cur_bidh = %d/%d\n", cur_bidh, params.h); + printf("cur_m_block = %d\n", cur_m_block); + printf("seqlen_k_cache = %d\n", binfo.seqlen_k_cache); + printf("actual_seqlen_q = %d\n", binfo.actual_seqlen_q); + printf("actual_seqlen_k = %d\n", binfo.actual_seqlen_k); + printf("num_tiles_in_block = %d\n", num_tiles_in_block); + printf("n_tile(new) = %d\n", n_tile); + printf("n_tile_min = %d\n", n_tile_min); + printf("host_cta = %d\n", host_cta); + printf("end_cta = %d\n", end_cta); + printf("n_split_idx = %d\n", n_split_idx); + printf("total_splits = %d\n", total_splits); + printf("\n#### For next block:\n"); + printf("start_tile_gid = %d\n", start_tile_gid); + printf("start_tile_hid = %d\n", start_tile_hid); + printf("num_tiles_left = %d\n", num_tiles_left); + printf("\n"); + } +#endif + + // All scheduling policy decisions should be made above this line + clear(acc_o); + + lean::Softmax<2 * size<1>(acc_o)> softmax; + + lean::Mask mask(binfo.actual_seqlen_k, binfo.actual_seqlen_q, params.window_size_left, params.window_size_right, 0.0f); + + // For performance reason, we separate out two kinds of iterations: + // those that need masking on S, and those that don't. + // We need masking on S for the very last block when K and V has length not multiple of kBlockN. + // We also need masking on S if it's causal, for the last ceil_div(kBlockM, kBlockN) blocks. + // We will have at least 1 "masking" iteration. + + // If not even_N, then seqlen_k might end in the middle of a block. In that case we need to + // mask 2 blocks (e.g. when kBlockM == kBlockN), not just 1. + Tensor acc_s = partition_fragment_C(tiled_mma, Shape, Int>{}); // (MMA=4, MMA_M, MMA_N) + clear(acc_s); + lean::cp_async_wait<0>(); + __syncthreads(); + +#if defined(DEBUG_LEAN_ATTENTION) + prologue_duration += clock64() - prologue_start; + auto compute_start = clock64(); +#endif + + // Clear the smem tiles to account for predicated off loads + lean::copy( + gmem_tiled_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_tile * kBlockN); + cute::cp_async_fence(); + + lean::gemm( + acc_s, tSrQ, tSrK, tSsQ, tSsK, tiled_mma, smem_tiled_copy_Q, smem_tiled_copy_K, + smem_thr_copy_Q, smem_thr_copy_K); + +#if defined(DEBUG_LEAN_ATTENTION) + if (threadIdx.x == 0 && blockIdx.z == tracing_block) { + printf("Tile 0 - Svalue: acc_s[0] = %f\n", acc_s(0)); + } +#endif + + mask.template apply_mask( + acc_s, n_tile * kBlockN, cur_m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4, kNWarps * 16); + + lean::cp_async_wait<0>(); + __syncthreads(); + +#if defined(DEBUG_LEAN_ATTENTION) + if (tidx == 0 && blockIdx.y == 0 && blockIdx.z == 0) { + print(tVsV); + } + // __syncthreads(); +#endif + + if (n_tile > n_tile_min) { + // Advance gK + tKgK.data() = tKgK.data() + (-int(kBlockN * params.k_row_stride)); + lean::copy(gmem_tiled_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV); + // This cp_async_fence needs to be in the if block, otherwise the synchronization + // isn't right and we get race conditions. + cute::cp_async_fence(); + } + + // We have key_padding_mask so we'll need to Check_inf + softmax.template softmax_rescale_o(acc_s, acc_o, params.scale_softmax_log2); + +#if defined(DEBUG_LEAN_ATTENTION) + if (threadIdx.x == 0 && blockIdx.z == tracing_block) { + printf("Tile 0 - PValue[0] = %f\n", acc_s(0)); + } +#endif + + // Convert acc_s from fp32 to fp16/bf16 + Tensor rP = lean::convert_type(acc_s); + // Reshape rP from (MMA=4, MMA_M, MMA_N) to ((4, 2), MMA_M, MMA_N / 2) + // if using m16n8k16 or (4, MMA_M, MMA_N) if using m16n8k8. + Tensor tOrP = make_tensor(rP.data(), lean::convert_layout_acc_Aregs(rP.layout())); + + lean::gemm_rs(acc_o, tOrP, tOrVt, tOsVt, tiled_mma, smem_tiled_copy_V, smem_thr_copy_V); + +#if defined(DEBUG_LEAN_ATTENTION) + if (threadIdx.x == 0 && blockIdx.z == tracing_block) { + printf("Tile 0 - AfterPV[0] = %f\n", acc_o(0)); + } +#endif + + n_tile -= 1; + +#if defined(DEBUG_LEAN_ATTENTION) + comp1_duration += clock64() - compute_start; + compute_start = clock64(); +#endif + + // These are the iterations where we don't need masking on S + for (; n_tile >= n_tile_min; --n_tile) { + Tensor acc_s = partition_fragment_C(tiled_mma, Shape, Int>{}); // (MMA=4, MMA_M, MMA_N) + clear(acc_s); + lean::cp_async_wait<0>(); + __syncthreads(); + + // Advance gV + tVgV.data() = tVgV.data() + (-int(kBlockN * params.v_row_stride)); + + lean::copy(gmem_tiled_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV); + cute::cp_async_fence(); + + lean::gemm( + acc_s, tSrQ, tSrK, tSsQ, tSsK, tiled_mma, smem_tiled_copy_Q, smem_tiled_copy_K, + smem_thr_copy_Q, smem_thr_copy_K); +#if defined(DEBUG_LEAN_ATTENTION) + if (threadIdx.x == 0 && blockIdx.z == tracing_block) { + printf("ntile %d Svalue: acc_s[0] = %f\n", n_tile, acc_s(0)); + } +#endif + + lean::cp_async_wait<0>(); + __syncthreads(); + if (n_tile > n_tile_min) { + // Advance gK + tKgK.data() = tKgK.data() + (-int(kBlockN * params.k_row_stride)); + lean::copy(gmem_tiled_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV); + // This cp_async_fence needs to be in the if block, otherwise the synchronization + // isn't right and we get race conditions. + cute::cp_async_fence(); + } + + mask.template apply_mask( + acc_s, n_tile * kBlockN, cur_m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4, kNWarps * 16); + softmax.template softmax_rescale_o(acc_s, acc_o, params.scale_softmax_log2); + +#if defined(DEBUG_LEAN_ATTENTION) + if (threadIdx.x == 0 && blockIdx.z == tracing_block) { + printf("ntile %d Pvalue: acc_s[0] = %f\n", n_tile, acc_s(0)); + } +#endif + Tensor rP = lean::convert_type(acc_s); + + // Reshape rP from (MMA=4, MMA_M, MMA_N) to ((4, 2), MMA_M, MMA_N / 2) + // if using m16n8k16 or (4, MMA_M, MMA_N) if using m16n8k8. + Tensor tOrP = make_tensor(rP.data(), lean::convert_layout_acc_Aregs(rP.layout())); + + lean::gemm_rs(acc_o, tOrP, tOrVt, tOsVt, tiled_mma, smem_tiled_copy_V, smem_thr_copy_V); +#if defined(DEBUG_LEAN_ATTENTION) + if (threadIdx.x == 0 && blockIdx.z == tracing_block) { + printf("ntile %d AfterPV[0] = %f\n", n_tile, acc_o(0)); + } +#endif + } + +#if defined(DEBUG_LEAN_ATTENTION) + // Epilogue + comp2_duration += clock64() - compute_start; + auto epilogue_start = clock64(); +#endif + + if (host_cta && end_cta) { +#if defined(DEBUG_LEAN_ATTENTION) + if (threadIdx.x == 0 && blockIdx.z == tracing_block) { + printf("acc_o[0] = %f\n", acc_o(0)); + } +#endif + + Tensor lse = softmax.template normalize_softmax_lse(acc_o, params.scale_softmax, params.rp_dropout); + +#if defined(DEBUG_LEAN_ATTENTION) + if (threadIdx.x == 0 && blockIdx.z == tracing_block) { + printf("lse[0] = %f\n", lse(0)); + printf("acc_o[0] = %f\n", acc_o(0)); + } +#endif + + // Convert acc_o from fp32 to fp16/bf16 + Tensor rO = lean::convert_type(acc_o); + + Tensor sO = make_tensor(sQ.data(), typename Kernel_traits::SmemLayoutO{}); // (SMEM_M,SMEM_N) + // Partition sO to match the accumulator partitioning + auto smem_tiled_copy_O = make_tiled_copy_C(typename Kernel_traits::SmemCopyAtomO{}, tiled_mma); + auto smem_thr_copy_O = smem_tiled_copy_O.get_thread_slice(tidx); + Tensor taccOrO = smem_thr_copy_O.retile_S(rO); // ((Atom,AtomNum), MMA_M, MMA_N) + Tensor taccOsO = smem_thr_copy_O.partition_D(sO); // ((Atom,AtomNum),PIPE_M,PIPE_N) + + // sO has the same size as sQ, so we don't need to sync here. + if (Kernel_traits::Share_Q_K_smem) { + __syncthreads(); + } + + cute::copy(smem_tiled_copy_O, taccOrO, taccOsO); + + const index_t row_offset_o = cur_bidb * params.o_batch_stride + + cur_m_block * kBlockM * params.o_row_stride + cur_bidh * params.o_head_stride; + + Tensor gO = make_tensor(make_gmem_ptr(reinterpret_cast(params.o_ptr) + row_offset_o), + Shape, Int>{}, + make_stride(params.o_row_stride, _1{})); + + typename Kernel_traits::GmemTiledCopyO gmem_tiled_copy_O; + auto gmem_thr_copy_O = gmem_tiled_copy_O.get_thread_slice(tidx); + Tensor tOsO = gmem_thr_copy_O.partition_S(sO); // ((Atom,AtomNum),ATOM_M,ATOM_N) + Tensor tOgO = gmem_thr_copy_O.partition_D(gO); + + __syncthreads(); + + Tensor tOrO = make_tensor(shape(tOgO)); + cute::copy(gmem_tiled_copy_O, tOsO, tOrO); + + // Construct identity layout for sO + Tensor cO = make_identity_tensor(make_shape(size<0>(sO), size<1>(sO))); // (BLK_M,BLK_K) -> (blk_m,blk_k) + // Repeat the partitioning with identity layouts + Tensor tOcO = gmem_thr_copy_O.partition_D(cO); // (ACPY,ACPY_M,ACPY_K) -> (blk_m,blk_k) + Tensor tOpO = make_tensor(make_shape(size<2>(tOgO))); + if (!Is_even_K) { +#pragma unroll + for (int k = 0; k < size(tOpO); ++k) { + tOpO(k) = get<1>(tOcO(0, 0, k)) < params.d; + } + } +#if defined(DEBUG_LEAN_ATTENTION) + if (threadIdx.x == 0 && blockIdx.z == tracing_block) { + printf("tOpO[0] = %d\n", tOpO(0)); + printf("tOrO[0] = %f\n", tOrO(0)); + } +#endif + // Clear_OOB_K must be false since we don't want to write zeros to gmem + lean::copy( + gmem_tiled_copy_O, tOrO, tOgO, tOcO, tOpO, params.seqlen_q - cur_m_block * kBlockM); + // epil1_duration += clock64() - epilogue_start; + } else if (!host_cta) { + Tensor lse = softmax.template normalize_softmax_lse(acc_o, params.scale_softmax); + + Tensor sOaccum = make_tensor(make_smem_ptr(reinterpret_cast(smem_)), typename Kernel_traits::SmemLayoutO{}); // (SMEM_M,SMEM_N) + // Partition sO to match the accumulator partitioning + using SmemTiledCopyO = typename Kernel_traits::SmemCopyAtomOaccum; + auto smem_tiled_copy_Oaccum = make_tiled_copy_C(SmemTiledCopyO{}, tiled_mma); + auto smem_thr_copy_Oaccum = smem_tiled_copy_Oaccum.get_thread_slice(tidx); + Tensor rO = lean::convert_type(acc_o); + Tensor taccOrOaccum = smem_thr_copy_Oaccum.retile_S(rO); // ((Atom,AtomNum), MMA_M, MMA_N) + Tensor taccOsOaccum = smem_thr_copy_Oaccum.partition_D(sOaccum); // ((Atom,AtomNum),PIPE_M,PIPE_N) + + // sOaccum is larger than sQ, so we need to syncthreads here + // TODO: allocate enough smem for sOaccum + __syncthreads(); + + cute::copy(smem_tiled_copy_Oaccum, taccOrOaccum, taccOsOaccum); + + const index_t row_offset_oaccum = (((index_t)(n_split_idx * params.b + cur_bidb) * params.h + cur_bidh) * params.seqlen_q + cur_m_block * kBlockM) * params.d_rounded; + const index_t row_offset_lseaccum = ((n_split_idx * params.b + cur_bidb) * params.h + cur_bidh) * params.seqlen_q + cur_m_block * kBlockM; + +#if defined(DEBUG_LEAN_ATTENTION) + if (threadIdx.x == 0 && blockIdx.z == tracing_block) { + printf("n_split_idx = %d\n", n_split_idx); + // printf("row_offset_o = %" PRId64 "\n", row_offset_o); + printf("row_offset_oaccum = %" PRId64 "\n", row_offset_oaccum); + printf("row_offset_lseaccum = %" PRId64 "\n", row_offset_lseaccum); + } +#endif + + Tensor gOaccum = make_tensor(make_gmem_ptr(reinterpret_cast(params.oaccum_ptr) + (row_offset_oaccum)), + Shape, Int>{}, + make_stride(kHeadDim, _1{})); + Tensor gLSEaccum = make_tensor(make_gmem_ptr(reinterpret_cast(params.softmax_lseaccum_ptr) + row_offset_lseaccum), + Shape>{}, Stride<_1>{}); + + GmemTiledCopyOaccum gmem_tiled_copy_Oaccum; + auto gmem_thr_copy_Oaccum = gmem_tiled_copy_Oaccum.get_thread_slice(tidx); + Tensor tOsOaccum = gmem_thr_copy_Oaccum.partition_S(sOaccum); // ((Atom,AtomNum),ATOM_M,ATOM_N) + Tensor tOgOaccum = gmem_thr_copy_Oaccum.partition_D(gOaccum); + + __syncthreads(); + + Tensor tOrOaccum = make_tensor(shape(tOgOaccum)); + cute::copy(gmem_tiled_copy_Oaccum, tOsOaccum, tOrOaccum); + + Tensor caccO = make_identity_tensor(Shape, Int>{}); // (BLK_M,BLK_K) -> (blk_m,blk_k) + Tensor taccOcO = thr_mma.partition_C(caccO); // (MMA,MMA_M,MMA_K) + static_assert(decltype(size<0>(taccOcO))::value == 4); + // Convert to ((2, 2), MMA_M, MMA_K) then take only the row indices. + Tensor taccOcO_row = logical_divide(taccOcO, Shape<_2>{})(make_coord(0, _), _, 0); + CUTE_STATIC_ASSERT_V(size(lse) == size(taccOcO_row)); // MMA_M + // This partitioning is unequal because only threads 0,4,8,etc write to gLSE + // and the rest are unused. + if (get<1>(taccOcO_row(0)) == 0) { +#pragma unroll + for (int mi = 0; mi < size(lse); ++mi) { + const int row = get<0>(taccOcO_row(mi)); + if (row < params.seqlen_q - cur_m_block * kBlockM) { + gLSEaccum(row) = lse(mi); + } + } + } + + // Construct identity layout for sO + Tensor cO = make_identity_tensor(make_shape(size<0>(sOaccum), size<1>(sOaccum))); // (BLK_M,BLK_K) -> (blk_m,blk_k) + // Repeat the partitioning with identity layouts + Tensor tOcO = gmem_thr_copy_Oaccum.partition_D(cO); // (ACPY,ACPY_M,ACPY_K) -> (blk_m,blk_k) + Tensor tOpO = make_tensor(make_shape(size<2>(tOgOaccum))); + if (!Is_even_K) { +#pragma unroll + for (int k = 0; k < size(tOpO); ++k) { + tOpO(k) = get<1>(tOcO(0, 0, k)) < params.d; + } + } + // Clear_OOB_K must be false since we don't want to write zeros to gmem + lean::copy( + gmem_tiled_copy_Oaccum, tOrOaccum, tOgOaccum, tOcO, tOpO, params.seqlen_q - cur_m_block * kBlockM); + + __threadfence(); + +#if defined(DEBUG_LEAN_ATTENTION) + if (threadIdx.x == 0 && (blockIdx.z == tracing_block || blockIdx.z == tracing_block + 1)) { + printf("Block %d Writing Flag %d\n", blockIdx.z, (cur_bidb * params.h * num_m_blocks_per_head) + (cur_bidh * num_m_blocks_per_head) + cur_m_block); + } +#endif + + atomicAdd(reinterpret_cast(params.sync_flag) + (cur_bidb * params.h * num_m_blocks_per_head) + (cur_bidh * num_m_blocks_per_head) + cur_m_block, 1); + +#if defined(DEBUG_LEAN_ATTENTION) + epil2_duration += clock64() - epilogue_start; +#endif + } else { + constexpr int kNThreads = Kernel_traits::kNThreads; + + static_assert(kMaxSplits <= 128, "kMaxSplits must be <= 128"); + static_assert(kNThreads == 128, "We assume that each block has 128 threads"); + + //////////////////////////////////////////////////////////////////////////////// +#if defined(DEBUG_LEAN_ATTENTION) + if (threadIdx.x == 0 && blockIdx.z == tracing_block) { + printf("Before LSE acc_o[0] = %f\n", acc_o(0)); + } +#endif + + Tensor lse = softmax.template normalize_softmax_lse(acc_o, params.scale_softmax); + +#if defined(DEBUG_LEAN_ATTENTION) + if (threadIdx.x == 0 && blockIdx.z == tracing_block) { + printf("After LSE acc_o[0] = %f\n", acc_o(0)); + printf("lse[0] = %f\n", lse(0)); + } +#endif + + Tensor sOaccum = make_tensor(make_smem_ptr(reinterpret_cast(smem_)), typename Kernel_traits::SmemLayoutO{}); // (SMEM_M,SMEM_N) + // Partition sO to match the accumulator partitioning + using SmemTiledCopyO = typename Kernel_traits::SmemCopyAtomOaccum; + auto smem_tiled_copy_Oaccum = make_tiled_copy_C(SmemTiledCopyO{}, tiled_mma); + auto smem_thr_copy_Oaccum = smem_tiled_copy_Oaccum.get_thread_slice(tidx); + Tensor rO = lean::convert_type(acc_o); + Tensor taccOrOaccum = smem_thr_copy_Oaccum.retile_S(rO); // ((Atom,AtomNum), MMA_M, MMA_N) + Tensor taccOsOaccum = smem_thr_copy_Oaccum.partition_D(sOaccum); // ((Atom,AtomNum),PIPE_M,PIPE_N) + + // sOaccum is larger than sQ, so we need to syncthreads here + // TODO: allocate enough smem for sOaccum + __syncthreads(); + + // We move to SMEM and back because we need equal distribution of + // accum registers. Initially only threads 0,4,8,etc have oaccum values. + // So, first move them to SMEM. + cute::copy(smem_tiled_copy_Oaccum, taccOrOaccum, taccOsOaccum); + + const index_t row_offset_oaccum = ((cur_bidb * params.h + cur_bidh) * (index_t)params.seqlen_q + cur_m_block * kBlockM) * params.d_rounded; + const index_t row_offset_lseaccum = (cur_bidb * params.h + cur_bidh) * (index_t)params.seqlen_q + cur_m_block * kBlockM; + + Tensor gOaccum = make_tensor(make_gmem_ptr(reinterpret_cast(params.oaccum_ptr) + (row_offset_oaccum)), + Shape, Int>{}, + make_stride(kHeadDim, _1{})); + Tensor gLSEaccum = make_tensor(make_gmem_ptr(reinterpret_cast(params.softmax_lseaccum_ptr) + row_offset_lseaccum), + Shape>{}, Stride<_1>{}); + +#if defined(DEBUG_LEAN_ATTENTION) + if (threadIdx.x == 0 && blockIdx.z == tracing_block) { + printf("Block %d row_offset_oaccum = %" PRId64 "\n", blockIdx.z, row_offset_oaccum); + printf("Block %d row_offset_lseaccum = %" PRId64 "\n", blockIdx.z, row_offset_lseaccum); + } +#endif + + // GmemTiledCopyOaccum gmem_tiled_copy_Oaccum; + // auto gmem_thr_copy_Oaccum = gmem_tiled_copy_Oaccum.get_thread_slice(tidx); + // Tensor tOsOaccum = gmem_thr_copy_Oaccum.partition_S(sOaccum); // ((Atom,AtomNum),ATOM_M,ATOM_N) + // Tensor tOgOaccum = gmem_thr_copy_Oaccum.partition_D(gOaccum); + + constexpr int kBlockN = kNThreads / kBlockM; + using GmemLayoutAtomOaccum = Layout, Int>, Stride, _1>>; + using GmemTiledCopyOaccum = decltype(make_tiled_copy(Copy_Atom{}, + GmemLayoutAtomOaccum{}, + Layout>{})); // Val layout, 4 vals per store + GmemTiledCopyOaccum gmem_tiled_copy_Oaccum; + auto gmem_thr_copy_Oaccum = gmem_tiled_copy_Oaccum.get_thread_slice(tidx); + + Tensor tOsOaccum = gmem_thr_copy_Oaccum.partition_S(sOaccum); + Tensor tOgOaccum = gmem_thr_copy_Oaccum.partition_S(gOaccum); + Tensor tOgOaccumReg = gmem_thr_copy_Oaccum.partition_D(gOaccum); + Tensor tOrOaccum = make_tensor(shape(tOgOaccumReg)); + +#if defined(DEBUG_LEAN_ATTENTION) + if (threadIdx.x == 0 && blockIdx.z == tracing_block) { + printf("First split t0g0accum.data() %p\n", tOgOaccum.data()); + } +#endif + + __syncthreads(); + + // Bring the oaccum back from SMEM to registers + // Now all threads have oaccum values equaly distributed. + cute::copy(gmem_tiled_copy_Oaccum, tOsOaccum, tOrOaccum); + + ///////////////////////////////////////////////////////////////////////////// + + // Shared memory. + // kBlockM + 1 instead of kBlockM to reduce bank conflicts. + Tensor sLSE = make_tensor(sV.data(), Shape, Int>{}); // (SMEM_M,SMEM_N) + + Tensor caccO = make_identity_tensor(Shape, Int>{}); // (BLK_M,BLK_K) -> (blk_m,blk_k) + Tensor taccOcO = thr_mma.partition_C(caccO); // (MMA,MMA_M,MMA_K) + static_assert(decltype(size<0>(taccOcO))::value == 4); + // Convert to ((2, 2), MMA_M, MMA_K) then take only the row indices. + Tensor taccOcO_row = logical_divide(taccOcO, Shape<_2>{})(make_coord(0, _), _, 0); + CUTE_STATIC_ASSERT_V(size(lse) == size(taccOcO_row)); // MMA_M + + // This partitioning is unequal because only threads 0,4,8,etc write to gLSE + // and the rest are unused. + if (get<1>(taccOcO_row(0)) == 0) { +#pragma unroll + for (int mi = 0; mi < size(lse); ++mi) { + const int col = get<0>(taccOcO_row(mi)); + if (col < params.seqlen_q - cur_m_block * kBlockM) { + sLSE(0, col) = lse(mi); +#if defined(DEBUG_LEAN_ATTENTION) + if (threadIdx.x == 0 && blockIdx.z == tracing_block) { + printf("threadIdx.x %d col %d mi%d slSE %f\n", threadIdx.x, col, mi, lse(mi)); + } +#endif + } + } + } + + // Synchronize here to make sure all atomics are visible to all threads. + // Not exactly sure why we need this, but it seems to be necessary. + __threadfence(); + while (atomicAdd(reinterpret_cast(params.sync_flag) + + (cur_bidb * params.h * num_m_blocks_per_head) + + (cur_bidh * num_m_blocks_per_head) + cur_m_block, + 0) < (total_splits - 1) * kNThreads) { + __threadfence(); +#if defined(DEBUG_LEAN_ATTENTION) + if (threadIdx.x % 32 == 0 && blockIdx.z == tracing_block) { + printf("Waiting Block: %d target-value: %d\n", blockIdx.z, (total_splits - 1) * kNThreads); + } +#endif + } + +#if defined(DEBUG_LEAN_ATTENTION) + // Print sync flag value + if (threadIdx.x == 0 && blockIdx.z == tracing_block) { + int32_t sync_flag = atomicAdd(reinterpret_cast(params.sync_flag) + + (cur_bidb * params.h * num_m_blocks_per_head) + + (cur_bidh * num_m_blocks_per_head) + cur_m_block, + 0); + if (threadIdx.x % 32 == 0 && blockIdx.z == tracing_block) { + printf("Sync flag value: %d\n", sync_flag); + } + } +#endif + + Tensor gLSEaccumRead = make_tensor(make_gmem_ptr(reinterpret_cast(params.softmax_lseaccum_ptr) + row_offset_lseaccum), + Shape, Int>{}, + make_stride(params.b * params.h * params.seqlen_q, _1{})); + // Read the LSE values from gmem and store them in shared memory, then tranpose them. + constexpr int kNLsePerThread = (kMaxSplits * kBlockM + kNThreads - 1) / kNThreads; // R + constexpr int kRowsPerLoadLSE = kNThreads / kBlockM; // R + +#pragma unroll + for (int l = 0; l < kNLsePerThread; ++l) { + const int row = l * kRowsPerLoadLSE + tidx / kBlockM; + const int col = tidx % kBlockM; + // We skip the first row = 0, as we already populated it in shared memory. + ElementAccum lse = (row > 0 && row < total_splits && col < params.b * params.h * (index_t)params.seqlen_q - row_offset_lseaccum) ? gLSEaccumRead(row, col) : -INFINITY; + if (row > 0 && row < kMaxSplits) { + sLSE(row, col) = lse; + +#if defined(DEBUG_LEAN_ATTENTION) + if (threadIdx.x % 32 == 0 && blockIdx.z == tracing_block) { + printf("ThreadIdx %d l %d row %d col %d lse %f\n", threadIdx.x, l, row, col, lse); + } +#endif + } + } + __syncthreads(); // For all LSEs to reach shared memory + Tensor lse_accum = make_tensor(Shape>{}); + constexpr int kRowsPerLoadTranspose = std::min(kRowsPerLoadLSE, kMaxSplits); + +#if defined(DEBUG_LEAN_ATTENTION) + if (threadIdx.x == 0 && blockIdx.z == tracing_block) { + printf("kNLsePerThread %d kRowsPerLoadLSE %d kRowsPerLoadTranspose %d\n", kNLsePerThread, kRowsPerLoadLSE, kRowsPerLoadTranspose); + } +#endif + + // To make sure that kMaxSplits is within 1 warp: we decide how many elements within kMaxSplits + // each thread should hold. If kMaxSplits = 16, then each thread holds 2 elements (128 threads, + // kBlockM rows, so each time we load we can load 128 / kBlockM rows). + // constexpr int kThreadsPerSplit = kMaxSplits / kRowsPerLoadTranspose; + // static_assert(kThreadsPerSplit <= 32); + static_assert(kRowsPerLoadTranspose <= 32); + static_assert(kNLsePerThread * kRowsPerLoadTranspose <= kMaxSplits); +#pragma unroll + for (int l = 0; l < kNLsePerThread; ++l) { + const int row = l * kRowsPerLoadTranspose + tidx % kRowsPerLoadTranspose; + const int col = tidx / kRowsPerLoadTranspose; + lse_accum(l) = (row < kMaxSplits && col < kBlockM) ? sLSE(row, col) : -INFINITY; + +#if defined(DEBUG_LEAN_ATTENTION) + if (threadIdx.x == 0 && blockIdx.z == tracing_block) { + printf("ThreadIdx %d l %d row %d col %d lse_accum %f\n", threadIdx.x, l, row, col, lse_accum(l)); + } +#endif + } + + // Compute the logsumexp of the LSE along the split dimension. + ElementAccum lse_max = lse_accum(0); +#pragma unroll + for (int l = 1; l < kNLsePerThread; ++l) { + lse_max = max(lse_max, lse_accum(l)); + } + MaxOp max_op; + lse_max = Allreduce::run(lse_max, max_op); + lse_max = lse_max == -INFINITY ? 0.0f : lse_max; // In case all local LSEs are -inf + float lse_sum = expf(lse_accum(0) - lse_max); +#pragma unroll + for (int l = 1; l < kNLsePerThread; ++l) { + lse_sum += expf(lse_accum(l) - lse_max); + } + SumOp sum_op; + lse_sum = Allreduce::run(lse_sum, sum_op); + // For the case where all local lse == -INFINITY, we want to set lse_logsum to INFINITY. Otherwise + // lse_logsum is log(0.0) = -INFINITY and we get NaN when we do lse_accum(l) - lse_logsum. + ElementAccum lse_logsum = (lse_sum == 0.f || lse_sum != lse_sum) ? INFINITY : logf(lse_sum) + lse_max; +// if (tidx % kRowsPerLoadTranspose == 0 && tidx / kRowsPerLoadTranspose < kBlockM) { gLSE(tidx / kRowsPerLoadTranspose) = lse_logsum; } +// Store the scales exp(lse - lse_logsum) in shared memory. +#pragma unroll + for (int l = 0; l < kNLsePerThread; ++l) { + const int row = l * kRowsPerLoadTranspose + tidx % kRowsPerLoadTranspose; + const int col = tidx / kRowsPerLoadTranspose; + if (row < total_splits && col < kBlockM) { + sLSE(row, col) = expf(lse_accum(l) - lse_logsum); + ElementAccum lse_scale = sLSE(row, col); +#if defined(DEBUG_LEAN_ATTENTION) + if (threadIdx.x == 0 && blockIdx.z == tracing_block) { + printf("ThreadIdx %d l %d row %d col %d lse_accum %f lse_logsum %f sLSE %f\n", threadIdx.x, l, row, col, lse_accum(l), lse_logsum, lse_scale); + } +#endif + } + } + + Tensor tOrO = make_tensor(shape(tOgOaccum)); + clear(tOrO); + + // Predicates + Tensor cOaccum = make_identity_tensor(Shape, Int>{}); + // Repeat the partitioning with identity layouts + Tensor tOcOaccum = gmem_thr_copy_Oaccum.partition_S(cOaccum); + Tensor tOpOaccum = make_tensor(make_shape(size<2>(tOgOaccum))); + if (!Is_even_K) { +#pragma unroll + for (int k = 0; k < size(tOpOaccum); ++k) { + tOpOaccum(k) = get<1>(tOcOaccum(0, 0, k)) < params.d; + } + } + + // Sync here for sLSE stores to go through + __syncthreads(); +// First reduce self Oaccum +#pragma unroll + for (int m = 0; m < size<1>(tOrOaccum); ++m) { + int row = get<0>(tOcOaccum(0, m, 0)); + ElementAccum lse_scale = sLSE(0, row); +#pragma unroll + for (int k = 0; k < size<2>(tOrOaccum); ++k) { +#pragma unroll + for (int i = 0; i < size<0>(tOrOaccum); ++i) { + tOrO(i, m, k) += lse_scale * tOrOaccum(i, m, k); +#if defined(DEBUG_LEAN_ATTENTION) + if (threadIdx.x == 0 && blockIdx.z == tracing_block) { + printf("ThreadIdx %d Split %d m %d Row %d k %d i %d LSE %f Oaccum %f O %f\n", threadIdx.x, 0, m, row, k, i, lse_scale, tOrOaccum(i, m, k), tOrO(i, m, k)); + } +#endif + } + } + } + + tOgOaccum.data() = tOgOaccum.data() + params.b * params.h * (index_t)params.seqlen_q * params.d_rounded; + +#if defined(DEBUG_LEAN_ATTENTION) + if (threadIdx.x == 0 && blockIdx.z == tracing_block) { + printf("After First Split t0g0accum.data() %p\n", tOgOaccum.data()); + } +#endif + // Load Oaccum in then scale and accumulate to O + // Here m is each row of 0accum along token dimension + // k is + for (int split = 1; split < total_splits; ++split) { + lean::copy( + gmem_tiled_copy_Oaccum, tOgOaccum, tOrOaccum, tOcOaccum, tOpOaccum, params.b * params.h * (index_t)params.seqlen_q - row_offset_lseaccum); +#pragma unroll + for (int m = 0; m < size<1>(tOrOaccum); ++m) { + int row = get<0>(tOcOaccum(0, m, 0)); + ElementAccum lse_scale = sLSE(split, row); +#pragma unroll + for (int k = 0; k < size<2>(tOrOaccum); ++k) { +#pragma unroll + for (int i = 0; i < size<0>(tOrOaccum); ++i) { + tOrO(i, m, k) += lse_scale * tOrOaccum(i, m, k); +#if defined(DEBUG_LEAN_ATTENTION) + if (threadIdx.x == 0 && blockIdx.z == tracing_block) { + printf("ThreadIdx %d Split %d m %d Row %d k %d i %d LSE %f Oaccum %f O %f\n", threadIdx.x, split, m, row, k, i, lse_scale, tOrOaccum(i, m, k), tOrO(i, m, k)); + } +#endif + } + } + } + tOgOaccum.data() = tOgOaccum.data() + params.b * params.h * (index_t)params.seqlen_q * params.d_rounded; + } + + Tensor r1 = lean::convert_type(tOrO); + +// Write to gO +#pragma unroll + for (int m = 0; m < size<1>(r1); ++m) { + const int idx = cur_m_block * kBlockM + get<0>(tOcOaccum(0, m, 0)); + if (idx < params.seqlen_q) { + // The index to the rows of Q + const int row = idx; + auto o_ptr = reinterpret_cast(params.o_ptr) + cur_bidb * params.o_batch_stride + cur_bidh * params.o_head_stride + row * params.o_row_stride; +#pragma unroll + for (int k = 0; k < size<2>(r1); ++k) { + if (Is_even_K || tOpOaccum(k)) { + const int col = get<1>(tOcOaccum(0, m, k)); + Tensor gO = make_tensor(make_gmem_ptr(o_ptr + col), + Shape(r1))::value>>{}, Stride<_1>{}); + copy(r1(_, m, k), gO); + } + } + } + } +#if defined(DEBUG_LEAN_ATTENTION) + epil3_duration += clock64() - epilogue_start; +#endif + } + + if (num_tiles_left) { + // We can probably do better than this + // We first decrement the pointers back to starting. + // We can probably just use q_ptr and k_ptr directly. But can't figure out how to do it. + // Without disturbing the gQ, gK, gV tensor pointer CUTE objects. + tQgQ.data() = tQgQ.data() + (-int(row_offset_q)); + tKgK.data() = tKgK.data() + (((num_tiles_in_block - n_tile_min - 1) * kBlockN) * params.k_row_stride - row_offset_k); + tVgV.data() = tVgV.data() + (((num_tiles_in_block - n_tile_min - 1) * kBlockN) * params.v_row_stride - row_offset_v); + cur_m_block = cur_m_block + 1 >= num_m_blocks_per_head ? 0 : cur_m_block + 1; + num_tiles_in_block = Is_causal ? (int)ceilf(block_scale * (cur_m_block + 1)) : num_tiles_per_head; + n_tile = num_tiles_left - 1 >= num_tiles_in_block ? num_tiles_in_block - 1 : num_tiles_left - 1; + n_tile_min = 0; + cur_bidb = start_tile_gid / (num_tiles_per_head * params.h); + cur_bidh = (start_tile_gid - (cur_bidb * num_tiles_per_head * params.h)) / num_tiles_per_head; + + row_offset_q = cur_bidb * params.q_batch_stride + + +cur_m_block * kBlockM * params.q_row_stride + cur_bidh * params.q_head_stride; + row_offset_k = cur_bidb * params.k_batch_stride + + +n_tile * kBlockN * params.k_row_stride + (cur_bidh / params.h_h_k_ratio) * params.k_head_stride; + row_offset_v = cur_bidb * params.v_batch_stride + + +n_tile * kBlockN * params.v_row_stride + (cur_bidh / params.h_h_k_ratio) * params.v_head_stride; + + tQgQ.data() = tQgQ.data() + row_offset_q; + tKgK.data() = tKgK.data() + row_offset_k; + tVgV.data() = tVgV.data() + row_offset_v; + +#if defined(DEBUG_LEAN_ATTENTION) + if (threadIdx.x == 0 && blockIdx.z == tracing_block) { + printf("#### Ready for next block:\n"); + printf("next_block %d\n", cur_m_block); + printf("n_tile %d\n", n_tile); + printf("row_offset_q = %" PRId64 "\n", row_offset_q); + printf("row_offset_k = %" PRId64 "\n", row_offset_k); + printf("row_offset_v = %" PRId64 "\n", row_offset_v); + } +#endif + } + +#if defined(DEBUG_LEAN_ATTENTION) + epilogue_duration += clock64() - epilogue_start; +#endif + } + +#if defined(DEBUG_LEAN_ATTENTION) + if (threadIdx.x == 0) { + uint smid; + asm("mov.u32 %0, %smid;" : "=r"(smid)); + printf("%d %d %d %d %lld %lld %lld %lld %lld %lld %lld %lld\n", + blockIdx.z, num_blocks, smid, cta_id, clock64() - kernel_start, prologue_duration, comp1_duration, + comp2_duration, epilogue_duration, epil1_duration, epil2_duration, epil3_duration); + } +#endif +} + +template +inline __device__ void lean_compute_attn(const Params& params) { + // const int cta_id = blockIdx.z < 54 ? 4*blockIdx.z : blockIdx.z < 108 ? 4*(blockIdx.z % 54) + 2 : blockIdx.z < 162 ? 4*(blockIdx.z % 108) + 1 : 4*(blockIdx.z % 162) + 3; + const int cta_id = blockIdx.z; + int start_tile_gid = cta_id < params.high_load_tbs ? params.max_tiles_per_tb * cta_id : (params.max_tiles_per_tb - 1) * cta_id + params.high_load_tbs; + int start_tile_hid = start_tile_gid % params.tiles_per_head; + int num_tiles = cta_id < params.high_load_tbs ? params.max_tiles_per_tb : params.max_tiles_per_tb - 1; + + lean::lean_compute_attn_impl_ver3(params, cta_id, start_tile_gid, start_tile_hid, num_tiles, params.tiles_per_head); +} + +} // namespace lean +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/bert/lean_attention/lean_fwd_launch_template.h b/onnxruntime/contrib_ops/cuda/bert/lean_attention/lean_fwd_launch_template.h new file mode 100644 index 0000000000000..fcccb54ebf4e8 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/bert/lean_attention/lean_fwd_launch_template.h @@ -0,0 +1,73 @@ +/****************************************************************************** + * Copyright (c) 2023, Tri Dao. + ******************************************************************************/ + +#pragma once + +#include "contrib_ops/cuda/bert/lean_attention/static_switch.h" +#include "contrib_ops/cuda/bert/lean_attention/flash.h" +#include "contrib_ops/cuda/bert/lean_attention/lean_fwd_kernel.h" + +namespace onnxruntime { +namespace lean { + +// Determine if the architecture supports FLASH and define a macro to handle parameter modifiers +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 +#define ARCH_SUPPORTS_FLASH +#define KERNEL_PARAM_MODIFIER __grid_constant__ +#else +#define KERNEL_PARAM_MODIFIER +#endif + +// Define a macro for unsupported architecture handling to centralize the error message +#define FLASH_UNSUPPORTED_ARCH printf("FATAL: FlashAttention requires building with sm version sm80-sm90, but was built for < 8.0!"); + +// Use a macro to clean up kernel definitions +#define DEFINE_FLASH_FORWARD_KERNEL(kernelName, ...) \ + template \ + __global__ void kernelName(KERNEL_PARAM_MODIFIER const Flash_fwd_params params) + +DEFINE_FLASH_FORWARD_KERNEL(lean_fwd_kernel, bool Is_causal, bool Is_even_MN, bool Is_even_K, int kMaxSplits, bool Append_KV) { +#if defined(ARCH_SUPPORTS_FLASH) + lean::lean_compute_attn(params); +#else + FLASH_UNSUPPORTED_ARCH +#endif +} + +template +void run_lean_fwd(Flash_fwd_params& params, cudaStream_t stream) { + static_assert(!Kernel_traits::Is_Q_in_regs, "SplitKV implementation does not support Is_Q_in_regs"); + static_assert(!Kernel_traits::Share_Q_K_smem, "SplitKV implementation does not support Share_Q_K_smem"); + constexpr size_t smem_size = Kernel_traits::kSmemSize; + dim3 grid(1, 1, params.lean_griddimz); + const bool is_even_MN = params.cu_seqlens_q == nullptr && params.cu_seqlens_k == nullptr && params.seqlen_k % Kernel_traits::kBlockN == 0 && params.seqlen_q % Kernel_traits::kBlockM == 0; + const bool is_even_K = params.d == Kernel_traits::kHeadDim; + BOOL_SWITCH(params.is_causal, Is_causal, [&] { + BOOL_SWITCH(is_even_MN, IsEvenMNConst, [&] { + EVENK_SWITCH(is_even_K, IsEvenKConst, [&] { + MAXSPLIT_SWITCH(params.num_splits, [&] { + BOOL_SWITCH(params.knew_ptr != nullptr, Append_KV_Const, [&] { + auto kernel = &lean_fwd_kernel < Kernel_traits, Is_causal, IsEvenMNConst && IsEvenKConst && Kernel_traits::kHeadDim <= 128, IsEvenKConst, kMaxSplits, Append_KV_Const > ; + if (2 * smem_size >= 48 * 1024) { + cudaFuncSetAttribute( + kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, 2 * smem_size); + } + kernel<<>>(params); + }); + }); + }); + }); + }); +} + +template +void run_mha_fwd_lean_dispatch(Flash_fwd_params& params, cudaStream_t stream) { + // This should be modified according to optimal lean tile size + constexpr static int kBlockM = Headdim <= 64 ? 64 : (Headdim <= 128 ? 64 : 64); + constexpr static int kBlockN = Headdim <= 64 ? 256 : (Headdim <= 128 ? 128 : 64); + run_lean_fwd>(params, stream); +} + +} // namespace lean +} // namespace onnxruntime \ No newline at end of file diff --git a/onnxruntime/contrib_ops/cuda/bert/lean_attention/mask.h b/onnxruntime/contrib_ops/cuda/bert/lean_attention/mask.h new file mode 100644 index 0000000000000..d63c80b012de6 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/bert/lean_attention/mask.h @@ -0,0 +1,209 @@ +/****************************************************************************** + * Copyright (c) 2024, Tri Dao. + ******************************************************************************/ + +#pragma once + +#include + +namespace onnxruntime { +namespace lean { + +using namespace cute; + +template +__forceinline__ __device__ void apply_mask(Tensor& tensor, const int max_seqlen_k, + const int col_idx_offset_ = 0) { + // tensor has shape (ncol=(2, MMA_M), nrow=(2, MMA_N)) + static_assert(Layout::rank == 2, "Only support 2D Tensor"); + const int lane_id = threadIdx.x % 32; + const int col_idx_offset = col_idx_offset_ + (lane_id % 4) * 2; +#pragma unroll + for (int nj = 0; nj < size<1, 1>(tensor); ++nj) { + const int col_idx_base = col_idx_offset + nj * 8; +#pragma unroll + for (int j = 0; j < size<1, 0>(tensor); ++j) { + const int col_idx = col_idx_base + j; + if (col_idx >= max_seqlen_k) { +// Without the "make_coord" we get wrong results +#pragma unroll + for (int mi = 0; mi < size<0>(tensor); ++mi) { + tensor(mi, make_coord(j, nj)) = -INFINITY; + } + } + } + } +} + +template +__forceinline__ __device__ void apply_mask_local(Tensor& tensor, const int col_idx_offset_, + const int max_seqlen_k, const int row_idx_offset, + const int max_seqlen_q, const int warp_row_stride, + const int window_size_left, const int window_size_right) { + // tensor has shape (ncol=(2, MMA_M), nrow=(2, MMA_N)) + static_assert(Layout::rank == 2, "Only support 2D Tensor"); + const int lane_id = threadIdx.x % 32; + const int col_idx_offset = col_idx_offset_ + (lane_id % 4) * 2; +#pragma unroll + for (int mi = 0; mi < size<0, 1>(tensor); ++mi) { + const int row_idx_base = row_idx_offset + mi * warp_row_stride; +#pragma unroll + for (int i = 0; i < size<0, 0>(tensor); ++i) { + const int row_idx = row_idx_base + i * 8; + const int col_idx_limit_left = std::max(0, row_idx + max_seqlen_k - max_seqlen_q - window_size_left); + const int col_idx_limit_right = std::min(max_seqlen_k, row_idx + 1 + max_seqlen_k - max_seqlen_q + window_size_right); +#pragma unroll + for (int nj = 0; nj < size<1, 1>(tensor); ++nj) { + const int col_idx_base = col_idx_offset + nj * 8; +#pragma unroll + for (int j = 0; j < size<1, 0>(tensor); ++j) { + const int col_idx = col_idx_base + j; + if (col_idx >= col_idx_limit_right || (HasWSLeft && col_idx < col_idx_limit_left)) { + tensor(make_coord(i, mi), make_coord(j, nj)) = -INFINITY; + } + } + } + // if (cute::thread0()) { + // printf("mi = %d, i = %d, row_idx = %d, max_seqlen_k = %d\n", mi, i, row_idx, max_seqlen_k); + // print(tensor(make_coord(i, mi), _)); + // // print(tensor(_, j + nj * size<1, 0>(tensor))); + // } + } + } +} + +template +__forceinline__ __device__ void apply_mask_causal(Tensor& tensor, const int col_idx_offset_, + const int max_seqlen_k, const int row_idx_offset, + const int max_seqlen_q, const int warp_row_stride) { + // Causal masking is equivalent to local masking with window_size_left = infinity and window_size_right = 0 + apply_mask_local(tensor, col_idx_offset_, max_seqlen_k, row_idx_offset, + max_seqlen_q, warp_row_stride, -1, 0); +} + +template +__forceinline__ __device__ void apply_mask_causal_w_idx( + Tensor& tensor, Tensor const& idx_rowcol, + const int col_idx_offset_, const int max_seqlen_k, const int row_idx_offset) { + // tensor has shape (ncol=(2, MMA_M), nrow=(2, MMA_N)) + static_assert(Layout0::rank == 2, "Only support 2D Tensor"); + static_assert(Layout1::rank == 2, "Only support 2D Tensor"); + CUTE_STATIC_ASSERT_V(size<0>(tensor) == size<0>(idx_rowcol)); + CUTE_STATIC_ASSERT_V(size<1>(tensor) == size<1>(idx_rowcol)); +#pragma unroll + for (int mi = 0; mi < size<0>(tensor); ++mi) { + const int col_idx_limit = std::min(max_seqlen_k, 1 + row_idx_offset + get<0>(idx_rowcol(mi, 0))); +#pragma unroll + for (int ni = 0; ni < size<1, 1>(tensor); ++ni) { + if (col_idx_offset_ + get<1>(idx_rowcol(0, ni)) >= col_idx_limit) { + tensor(mi, ni) = -INFINITY; + } + } + // if (cute::thread0()) { + // printf("ni = %d, j = %d, col_idx = %d, max_seqlen_k = %d\n", ni, j, col_idx, max_seqlen_k); + // print(tensor(_, make_coord(j, ni))); + // // print(tensor(_, j + ni * size<1, 0>(tensor))); + // } + } +} + +template +struct Mask { + const int max_seqlen_k, max_seqlen_q; + const int window_size_left, window_size_right; + const float alibi_slope; + + __forceinline__ __device__ Mask(const int max_seqlen_k, const int max_seqlen_q, + const int window_size_left, const int window_size_right, + const float alibi_slope = 0.f) + : max_seqlen_k(max_seqlen_k), max_seqlen_q(max_seqlen_q), window_size_left(window_size_left), window_size_right(window_size_right), alibi_slope(!Has_alibi ? 0.0 : alibi_slope) { + }; + + // Causal_mask: whether this particular iteration needs causal masking + template + __forceinline__ __device__ void apply_mask(Tensor& tensor_, + const int col_idx_offset_, + const int row_idx_offset, + const int warp_row_stride) { + static_assert(!(Causal_mask && Is_local), "Cannot be both causal and local"); + static_assert(Layout::rank == 3, "Only support 3D Tensor"); + static_assert(decltype(size<0>(tensor_))::value == 4, "First dimension must be 4"); + static constexpr bool Need_masking = Has_alibi || Causal_mask || Is_local || !Is_even_MN; + // if (cute::thread0()) { printf("Has_alibi = %d, Causal_mask=%d, Is_local=%d, Is_even_MN = %d, Need_masking = %d\n", Has_alibi, Causal_mask, Is_local, Is_even_MN, Need_masking); } + if constexpr (Need_masking) { + // Reshape tensor_ from (MMA=4, MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, MMA_N)) + Tensor tensor = make_tensor(tensor_.data(), lean::convert_layout_acc_rowcol(tensor_.layout())); + // Do we need both row and column indices, or just column incides? + static constexpr bool Col_idx_only = !(Has_alibi && !Is_causal) && !Is_local && !Causal_mask; + const int lane_id = threadIdx.x % 32; + const int col_idx_offset = col_idx_offset_ + (lane_id % 4) * 2; + if constexpr (Col_idx_only) { +#pragma unroll + for (int nj = 0; nj < size<1, 1>(tensor); ++nj) { + const int col_idx_base = col_idx_offset + nj * 8; +#pragma unroll + for (int j = 0; j < size<1, 0>(tensor); ++j) { + const int col_idx = col_idx_base + j; +#pragma unroll + for (int mi = 0; mi < size<0>(tensor); ++mi) { + // No causal, no local + if constexpr (Has_alibi) { + tensor(mi, make_coord(j, nj)) += alibi_slope * col_idx; + } + if constexpr (!Is_even_MN) { + if (col_idx >= max_seqlen_k) { + tensor(mi, make_coord(j, nj)) = -INFINITY; + } + } + } + } + } + } else { +#pragma unroll + for (int mi = 0; mi < size<0, 1>(tensor); ++mi) { + const int row_idx_base = row_idx_offset + mi * warp_row_stride; +#pragma unroll + for (int i = 0; i < size<0, 0>(tensor); ++i) { + const int row_idx = row_idx_base + i * 8; + const int col_idx_limit_left = std::max(0, row_idx + max_seqlen_k - max_seqlen_q - window_size_left); + const int col_idx_limit_right = std::min(max_seqlen_k, row_idx + 1 + max_seqlen_k - max_seqlen_q + window_size_right); +#pragma unroll + for (int nj = 0; nj < size<1, 1>(tensor); ++nj) { + const int col_idx_base = col_idx_offset + nj * 8; +#pragma unroll + for (int j = 0; j < size<1, 0>(tensor); ++j) { + const int col_idx = col_idx_base + j; + if constexpr (Has_alibi) { + if constexpr (Is_causal) { + tensor(make_coord(i, mi), make_coord(j, nj)) += alibi_slope * col_idx; + } else { + tensor(make_coord(i, mi), make_coord(j, nj)) -= alibi_slope * abs(row_idx + max_seqlen_k - max_seqlen_q - col_idx); + } + } + if constexpr (Causal_mask) { + if (col_idx >= col_idx_limit_right) { + tensor(make_coord(i, mi), make_coord(j, nj)) = -INFINITY; + } + } + if constexpr (Is_local) { + if (col_idx >= col_idx_limit_right || col_idx < col_idx_limit_left) { + tensor(make_coord(i, mi), make_coord(j, nj)) = -INFINITY; + } + } + if constexpr (!Causal_mask && !Is_local && !Is_even_MN) { + // Causal and Local already handles MN masking + if (col_idx >= max_seqlen_k) { + tensor(make_coord(i, mi), make_coord(j, nj)) = -INFINITY; + } + } + } + } + } + } + } + } + }; +}; + +} // namespace lean +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/bert/lean_attention/softmax.h b/onnxruntime/contrib_ops/cuda/bert/lean_attention/softmax.h new file mode 100644 index 0000000000000..ad66389848e6e --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/bert/lean_attention/softmax.h @@ -0,0 +1,196 @@ +/****************************************************************************** + * Copyright (c) 2024, Tri Dao. + ******************************************************************************/ + +#pragma once + +#include + +#include + +#include + +#include "contrib_ops/cuda/bert/lean_attention/utils.h" + +namespace onnxruntime { +namespace lean { + +using namespace cute; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +__device__ __forceinline__ void thread_reduce_(Tensor const& tensor, Tensor& summary, Operator& op) { + static_assert(Layout0::rank == 2, "Only support 2D Tensor"); + static_assert(Layout1::rank == 1, "Only support 1D Tensor"); + CUTE_STATIC_ASSERT_V(size<0>(summary) == size<0>(tensor)); +#pragma unroll + for (int mi = 0; mi < size<0>(tensor); mi++) { + summary(mi) = zero_init ? tensor(mi, 0) : op(summary(mi), tensor(mi, 0)); +#pragma unroll + for (int ni = 1; ni < size<1>(tensor); ni++) { + summary(mi) = op(summary(mi), tensor(mi, ni)); + } + } +} + +template +__device__ __forceinline__ void quad_allreduce_(Tensor& dst, Tensor& src, Operator& op) { + CUTE_STATIC_ASSERT_V(size(dst) == size(src)); +#pragma unroll + for (int i = 0; i < size(dst); i++) { + dst(i) = Allreduce<4>::run(src(i), op); + } +} + +template +__device__ __forceinline__ void reduce_(Tensor const& tensor, Tensor& summary, Operator& op) { + thread_reduce_(tensor, summary, op); + quad_allreduce_(summary, summary, op); +} + +template +__device__ __forceinline__ void reduce_max(Tensor const& tensor, Tensor& max) { + MaxOp max_op; + reduce_(tensor, max, max_op); +} + +template +__device__ __forceinline__ void reduce_sum(Tensor const& tensor, Tensor& sum) { + SumOp sum_op; + thread_reduce_(tensor, sum, sum_op); +} + +// Apply the exp to all the elements. +template +__forceinline__ __device__ void scale_apply_exp2(Tensor& tensor, Tensor const& max, const float scale) { + static_assert(Layout0::rank == 2, "Only support 2D Tensor"); + static_assert(Layout1::rank == 1, "Only support 1D Tensor"); + CUTE_STATIC_ASSERT_V(size<0>(max) == size<0>(tensor)); +#pragma unroll + for (int mi = 0; mi < size<0>(tensor); ++mi) { + // If max is -inf, then all elements must have been -inf (possibly due to masking). + // We don't want (-inf - (-inf)) since that would give NaN. + // If we don't have float around M_LOG2E the multiplication is done in fp64. + const float max_scaled = max(mi) == -INFINITY ? 0.f : max(mi) * (Scale_max ? scale : float(M_LOG2E)); +#pragma unroll + for (int ni = 0; ni < size<1>(tensor); ++ni) { +// Instead of computing exp(x - max), we compute exp2(x * log_2(e) - +// max * log_2(e)) This allows the compiler to use the ffma +// instruction instead of fadd and fmul separately. +// The following macro will disable the use of fma. +// See: https://github.com/pytorch/pytorch/issues/121558 for more details +// This macro is set in PyTorch and not FlashAttention +#ifdef UNFUSE_FMA + tensor(mi, ni) = exp2f(__fmul_rn(tensor(mi, ni), scale) - max_scaled); +#else + tensor(mi, ni) = exp2f(tensor(mi, ni) * scale - max_scaled); +#endif + } + } +} + +// Apply the exp to all the elements. +template +__forceinline__ __device__ void max_scale_exp2_sum(Tensor& tensor, Tensor& max, Tensor& sum, const float scale) { + static_assert(Layout0::rank == 2, "Only support 2D Tensor"); + static_assert(Layout1::rank == 1, "Only support 1D Tensor"); + CUTE_STATIC_ASSERT_V(size<0>(max) == size<0>(tensor)); +#pragma unroll + for (int mi = 0; mi < size<0>(tensor); ++mi) { + MaxOp max_op; + max(mi) = zero_init ? tensor(mi, 0) : max_op(max(mi), tensor(mi, 0)); +#pragma unroll + for (int ni = 1; ni < size<1>(tensor); ni++) { + max(mi) = max_op(max(mi), tensor(mi, ni)); + } + max(mi) = Allreduce<4>::run(max(mi), max_op); + // If max is -inf, then all elements must have been -inf (possibly due to masking). + // We don't want (-inf - (-inf)) since that would give NaN. + const float max_scaled = max(mi) == -INFINITY ? 0.f : max(mi) * scale; + sum(mi) = 0; +#pragma unroll + for (int ni = 0; ni < size<1>(tensor); ++ni) { + // Instead of computing exp(x - max), we compute exp2(x * log_2(e) - + // max * log_2(e)) This allows the compiler to use the ffma + // instruction instead of fadd and fmul separately. + tensor(mi, ni) = exp2f(tensor(mi, ni) * scale - max_scaled); + sum(mi) += tensor(mi, ni); + } + SumOp sum_op; + sum(mi) = Allreduce<4>::run(sum(mi), sum_op); + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Softmax { + using TensorT = decltype(make_tensor(Shape>{})); + TensorT row_max, row_sum; + + __forceinline__ __device__ Softmax() {}; + + template + __forceinline__ __device__ void softmax_rescale_o(Tensor0& acc_s, Tensor1& acc_o, float softmax_scale_log2) { + // Reshape acc_s from (MMA=4, MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, MMA_N)) + Tensor scores = make_tensor(acc_s.data(), lean::convert_layout_acc_rowcol(acc_s.layout())); + static_assert(decltype(size<0>(scores))::value == kNRows); + if (Is_first) { + lean::template reduce_max(scores, row_max); + lean::scale_apply_exp2(scores, row_max, softmax_scale_log2); + lean::reduce_sum(scores, row_sum); + } else { + Tensor scores_max_prev = make_fragment_like(row_max); + cute::copy(row_max, scores_max_prev); + lean::template reduce_max(scores, row_max); + // Reshape acc_o from (MMA=4, MMA_M, MMA_K) to (nrow=(2, MMA_M), ncol=(2, MMA_K)) + Tensor acc_o_rowcol = make_tensor(acc_o.data(), lean::convert_layout_acc_rowcol(acc_o.layout())); + static_assert(decltype(size<0>(acc_o_rowcol))::value == kNRows); +#pragma unroll + for (int mi = 0; mi < size(row_max); ++mi) { + float scores_max_cur = !Check_inf + ? row_max(mi) + : (row_max(mi) == -INFINITY ? 0.0f : row_max(mi)); + float scores_scale = exp2f((scores_max_prev(mi) - scores_max_cur) * softmax_scale_log2); + row_sum(mi) *= scores_scale; +#pragma unroll + for (int ni = 0; ni < size<1>(acc_o_rowcol); ++ni) { + acc_o_rowcol(mi, ni) *= scores_scale; + } + } + lean::scale_apply_exp2(scores, row_max, softmax_scale_log2); + // We don't do the reduce across threads here since we don't need to use the row_sum. + // We do that reduce at the end when we need to normalize the softmax. + lean::reduce_sum(scores, row_sum); + } + }; + + template + __forceinline__ __device__ TensorT normalize_softmax_lse(Tensor0& acc_o, float softmax_scale, float rp_dropout = 1.0) { + SumOp sum_op; + quad_allreduce_(row_sum, row_sum, sum_op); + TensorT lse = make_fragment_like(row_sum); + Tensor acc_o_rowcol = make_tensor(acc_o.data(), lean::convert_layout_acc_rowcol(acc_o.layout())); + static_assert(decltype(size<0>(acc_o_rowcol))::value == kNRows); +#pragma unroll + for (int mi = 0; mi < size<0>(acc_o_rowcol); ++mi) { + float sum = row_sum(mi); + float inv_sum = (sum == 0.f || sum != sum) ? 1.f : 1.f / sum; + // if (threadIdx.x == 0 && blockIdx.z == 0) { + // printf("sum: %f, inv_sum: %f\n", sum, inv_sum); + // printf("mi %d row_max %f softmax_scale %f\n", mi, row_max(mi), softmax_scale); + // } + lse(mi) = (sum == 0.f || sum != sum) ? (Split ? -INFINITY : INFINITY) : row_max(mi) * softmax_scale + __logf(sum); + float scale = !Is_dropout ? inv_sum : inv_sum * rp_dropout; +#pragma unroll + for (int ni = 0; ni < size<1>(acc_o_rowcol); ++ni) { + acc_o_rowcol(mi, ni) *= scale; + } + } + return lse; + }; +}; + +} // namespace lean +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/bert/lean_attention/static_switch.h b/onnxruntime/contrib_ops/cuda/bert/lean_attention/static_switch.h new file mode 100644 index 0000000000000..7873f67471d5d --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/bert/lean_attention/static_switch.h @@ -0,0 +1,109 @@ +// Inspired by +// https://github.com/NVIDIA/DALI/blob/main/include/dali/core/static_switch.h +// and https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/Dispatch.h + +#pragma once + +/// @param COND - a boolean expression to switch by +/// @param CONST_NAME - a name given for the constexpr bool variable. +/// @param ... - code to execute for true and false +/// +/// Usage: +/// ``` +/// BOOL_SWITCH(flag, BoolConst, [&] { +/// some_function(...); +/// }); +/// ``` + +#define BOOL_SWITCH(COND, CONST_NAME, ...) \ + [&] { \ + if (COND) { \ + constexpr static bool CONST_NAME = true; \ + return __VA_ARGS__(); \ + } else { \ + constexpr static bool CONST_NAME = false; \ + return __VA_ARGS__(); \ + } \ + }() + +#ifdef FLASHATTENTION_DISABLE_DROPOUT +#define DROPOUT_SWITCH(COND, CONST_NAME, ...) \ + [&] { \ + constexpr static bool CONST_NAME = false; \ + return __VA_ARGS__(); \ + }() +#else +#define DROPOUT_SWITCH BOOL_SWITCH +#endif + +#ifdef FLASHATTENTION_DISABLE_ALIBI +#define ALIBI_SWITCH(COND, CONST_NAME, ...) \ + [&] { \ + constexpr static bool CONST_NAME = false; \ + return __VA_ARGS__(); \ + }() +#else +#define ALIBI_SWITCH BOOL_SWITCH +#endif + +#ifdef FLASHATTENTION_DISABLE_UNEVEN_K +#define EVENK_SWITCH(COND, CONST_NAME, ...) \ + [&] { \ + constexpr static bool CONST_NAME = true; \ + return __VA_ARGS__(); \ + }() +#else +#define EVENK_SWITCH BOOL_SWITCH +#endif + +#ifdef FLASHATTENTION_DISABLE_LOCAL +#define LOCAL_SWITCH(COND, CONST_NAME, ...) \ + [&] { \ + constexpr static bool CONST_NAME = false; \ + return __VA_ARGS__(); \ + }() +#else +#define LOCAL_SWITCH BOOL_SWITCH +#endif + +#define FP16_SWITCH(COND, ...) \ + [&] { \ + if (COND) { \ + using elem_type = cutlass::half_t; \ + return __VA_ARGS__(); \ + } \ + }() + +#define HEADDIM_SWITCH(HEADDIM, ...) \ + [&] { \ + if (HEADDIM <= 64) { \ + constexpr static int kHeadDim = 64; \ + return __VA_ARGS__(); \ + } else if (HEADDIM <= 128) { \ + constexpr static int kHeadDim = 128; \ + return __VA_ARGS__(); \ + } \ + }() + +#define MAXSPLIT_SWITCH(MAXSPLITS, ...) \ + [&] { \ + if (MAXSPLITS <= 2) { \ + constexpr static int kMaxSplits = 2; \ + return __VA_ARGS__(); \ + } else if (MAXSPLITS <= 4) { \ + constexpr static int kMaxSplits = 4; \ + return __VA_ARGS__(); \ + } else if (MAXSPLITS <= 8) { \ + constexpr static int kMaxSplits = 8; \ + return __VA_ARGS__(); \ + } else if (MAXSPLITS <= 16) { \ + constexpr static int kMaxSplits = 16; \ + return __VA_ARGS__(); \ + } else if (MAXSPLITS <= 32) { \ + constexpr static int kMaxSplits = 32; \ + return __VA_ARGS__(); \ + } else if (MAXSPLITS <= 64) { \ + constexpr static int kMaxSplits = 64; \ + return __VA_ARGS__(); \ + } \ + }() diff --git a/onnxruntime/contrib_ops/cuda/bert/lean_attention/utils.h b/onnxruntime/contrib_ops/cuda/bert/lean_attention/utils.h new file mode 100644 index 0000000000000..c76849686d539 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/bert/lean_attention/utils.h @@ -0,0 +1,411 @@ +/****************************************************************************** + * Copyright (c) 2023, Tri Dao. + ******************************************************************************/ + +#pragma once + +#include +#include +#include + +#include + +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 +#include +#endif + +#include +#include + +#include +#include +#include +#include + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace onnxruntime { +namespace lean { + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +__forceinline__ __device__ uint32_t relu2(const uint32_t x); + +template <> +__forceinline__ __device__ uint32_t relu2(const uint32_t x) { + uint32_t res; + const uint32_t zero = 0u; +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 + asm volatile("max.f16x2 %0, %1, %2;\n" : "=r"(res) : "r"(x), "r"(zero)); +#else + asm volatile( + "{\n" + "\t .reg .f16x2 sela;\n" + "\t set.gtu.u32.f16x2 sela, %1, %2;\n" + "\t and.b32 %0, sela, %1;\n" + "}\n" : "=r"(res) : "r"(x), "r"(zero)); +#endif + return res; +} + +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 +template <> +__forceinline__ __device__ uint32_t relu2(const uint32_t x) { + uint32_t res; + const uint32_t zero = 0u; + asm volatile("max.bf16x2 %0, %1, %2;\n" : "=r"(res) : "r"(x), "r"(zero)); + return res; +} +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 + +template +__forceinline__ __device__ uint32_t convert_relu2(const float2 x); + +template <> +__forceinline__ __device__ uint32_t convert_relu2(const float2 x) { + uint32_t res; + const uint32_t a = reinterpret_cast(x.x); + const uint32_t b = reinterpret_cast(x.y); + asm volatile("cvt.rn.relu.f16x2.f32 %0, %1, %2;\n" : "=r"(res) : "r"(b), "r"(a)); + return res; +} + +template <> +__forceinline__ __device__ uint32_t convert_relu2(const float2 x) { + uint32_t res; + const uint32_t a = reinterpret_cast(x.x); + const uint32_t b = reinterpret_cast(x.y); + asm volatile("cvt.rn.relu.bf16x2.f32 %0, %1, %2;\n" : "=r"(res) : "r"(b), "r"(a)); + return res; +} + +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MaxOp { + __device__ __forceinline__ T operator()(T const& x, T const& y) { return x > y ? x : y; } +}; + +template <> +struct MaxOp { + // This is slightly faster + __device__ __forceinline__ float operator()(float const& x, float const& y) { return max(x, y); } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct SumOp { + __device__ __forceinline__ T operator()(T const& x, T const& y) { return x + y; } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Allreduce { + static_assert(THREADS == 32 || THREADS == 16 || THREADS == 8 || THREADS == 4); + template + static __device__ __forceinline__ T run(T x, Operator& op) { + constexpr int OFFSET = THREADS / 2; + x = op(x, __shfl_xor_sync(uint32_t(-1), x, OFFSET)); + return Allreduce::run(x, op); + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template <> +struct Allreduce<2> { + template + static __device__ __forceinline__ T run(T x, Operator& op) { + x = op(x, __shfl_xor_sync(uint32_t(-1), x, 1)); + return x; + } +}; + +template <> +struct Allreduce<1> { + template + static __device__ __forceinline__ T run(T x, Operator& op) { + return x; + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +__forceinline__ __device__ void gemm(Tensor0& acc, Tensor1& tCrA, Tensor2& tCrB, Tensor3 const& tCsA, + Tensor4 const& tCsB, TiledMma tiled_mma, + TiledCopyA smem_tiled_copy_A, TiledCopyB smem_tiled_copy_B, + ThrCopyA smem_thr_copy_A, ThrCopyB smem_thr_copy_B) { + CUTE_STATIC_ASSERT_V(size<1>(tCrA) == size<1>(acc)); // MMA_M + CUTE_STATIC_ASSERT_V(size<1>(tCrB) == size<2>(acc)); // MMA_N + CUTE_STATIC_ASSERT_V(size<2>(tCrA) == size<2>(tCrB)); // MMA_K + Tensor tCrA_copy_view = smem_thr_copy_A.retile_D(tCrA); + CUTE_STATIC_ASSERT_V(size<1>(tCsA) == size<1>(tCrA_copy_view)); // M + Tensor tCrB_copy_view = smem_thr_copy_B.retile_D(tCrB); + CUTE_STATIC_ASSERT_V(size<1>(tCsB) == size<1>(tCrB_copy_view)); // N + if (!A_in_regs) { + cute::copy(smem_tiled_copy_A, tCsA(_, _, _0{}), tCrA_copy_view(_, _, _0{})); + } + if (!B_in_regs) { + cute::copy(smem_tiled_copy_B, tCsB(_, _, _0{}), tCrB_copy_view(_, _, _0{})); + } +#pragma unroll + for (int i = 0; i < size<2>(tCrA); ++i) { + if (i < size<2>(tCrA) - 1) { + if (!A_in_regs) { + cute::copy(smem_tiled_copy_A, tCsA(_, _, i + 1), tCrA_copy_view(_, _, i + 1)); + } + if (!B_in_regs) { + cute::copy(smem_tiled_copy_B, tCsB(_, _, i + 1), tCrB_copy_view(_, _, i + 1)); + } + } + cute::gemm(tiled_mma, tCrA(_, _, i), tCrB(_, _, i), acc); + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +__forceinline__ __device__ void gemm_rs(Tensor0& acc, Tensor1& tCrA, Tensor2& tCrB, Tensor3 const& tCsB, + TiledMma tiled_mma, TiledCopy smem_tiled_copy_B, + ThrCopy smem_thr_copy_B) { + CUTE_STATIC_ASSERT_V(size<1>(tCrA) == size<1>(acc)); // MMA_M + CUTE_STATIC_ASSERT_V(size<1>(tCrB) == size<2>(acc)); // MMA_N + CUTE_STATIC_ASSERT_V(size<2>(tCrA) == size<2>(tCrB)); // MMA_K + Tensor tCrB_copy_view = smem_thr_copy_B.retile_D(tCrB); + CUTE_STATIC_ASSERT_V(size<1>(tCsB) == size<1>(tCrB_copy_view)); // N + cute::copy(smem_tiled_copy_B, tCsB(_, _, _0{}), tCrB_copy_view(_, _, _0{})); +#pragma unroll + for (int i = 0; i < size<2>(tCrA); ++i) { + if (i < size<2>(tCrA) - 1) { + cute::copy(smem_tiled_copy_B, tCsB(_, _, i + 1), tCrB_copy_view(_, _, i + 1)); + } + cute::gemm(tiled_mma, tCrA(_, _, i), tCrB(_, _, i), acc); + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// Convert acc_layout from (MMA=4, MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, MMA_N)) +template +__forceinline__ __device__ auto convert_layout_acc_rowcol(Layout acc_layout) { + static_assert(decltype(size<0>(acc_layout))::value == 4); + static_assert(decltype(rank(acc_layout))::value == 3); + auto l = logical_divide(acc_layout, Shape<_2>{}); // ((2, 2), MMA_M, MMA_N) + return make_layout(make_layout(get<0, 1>(l), get<1>(l)), make_layout(get<0, 0>(l), get<2>(l))); +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// Convert acc_layout from (MMA=4, MMA_M, MMA_N) to ((4, 2), MMA_M, MMA_N / 2) +// if using m16n8k16, or to (4, MMA_M, MMA_N) if using m16n8k8. +template +__forceinline__ __device__ auto convert_layout_acc_Aregs(Layout acc_layout) { + using X = Underscore; + static_assert(decltype(size<0>(acc_layout))::value == 4); + static_assert(decltype(rank(acc_layout))::value == 3); + constexpr int mma_shape_K = get<2>(typename MMA_traits::Shape_MNK{}); + static_assert(mma_shape_K == 8 || mma_shape_K == 16); + if constexpr (mma_shape_K == 8) { + return acc_layout; + } else { + auto l = logical_divide(acc_layout, Shape{}); // (4, MMA_M, (2, MMA_N / 2))) + return make_layout(make_layout(get<0>(l), get<2, 0>(l)), get<1>(l), get<2, 1>(l)); + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// Convert acc_layout from (MMA=4, MMA_M, MMA_N) to ((4, 2), MMA_M, MMA_N / 2) +template +__forceinline__ __device__ auto convert_layout_acc_dropout(Layout acc_layout) { + using X = Underscore; + static_assert(decltype(size<0>(acc_layout))::value == 4); + static_assert(decltype(rank(acc_layout))::value == 3); + auto l = logical_divide(acc_layout, Shape{}); // (4, MMA_M, (2, MMA_N / 2))) + return make_layout(make_layout(get<0>(l), get<2, 0>(l)), get<1>(l), get<2, 1>(l)); +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +__forceinline__ __device__ auto convert_type(Tensor const& tensor) { + using From_type = typename Engine::value_type; + constexpr int numel = decltype(size(tensor))::value; + cutlass::NumericArrayConverter convert_op; + // HACK: this requires tensor to be "contiguous" + auto frag = convert_op(*reinterpret_cast*>(tensor.data())); + return make_tensor(make_rmem_ptr(&frag), tensor.layout()); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +__forceinline__ __device__ void relu_(Tensor& tensor) { + constexpr int numel = decltype(size(tensor))::value; + static_assert(numel % 2 == 0); + using value_t = typename Engine::value_type; + // HACK: this requires tensor to be "contiguous" + Tensor tensor_uint32 = recast(tensor); +#pragma unroll + for (int i = 0; i < size(tensor_uint32); ++i) { + tensor_uint32(i) = relu2(tensor_uint32(i)); + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// On SM80 and above, we can fuse fp32 -> fp16/bf16 conversion and relu into 1 instruction +template +__forceinline__ __device__ auto convert_type_relu(Tensor const& tensor) { + using From_type = typename Engine::value_type; + static_assert(std::is_same_v || std::is_same_v); + static_assert(std::is_same_v); + constexpr int numel = decltype(size(tensor))::value; + static_assert(numel % 2 == 0); +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 + // HACK: this requires tensor to be "contiguous" + Tensor tensor_float2 = recast(tensor); + Tensor out_uint32 = make_tensor(tensor_float2.layout()); +#pragma unroll + for (int i = 0; i < size(out_uint32); ++i) { + out_uint32(i) = convert_relu2(tensor_float2(i)); + } + Tensor out = make_tensor(make_rmem_ptr(out_uint32.data()), tensor.layout()); +#else + Tensor out = lean::convert_type(tensor); + lean::relu_(out); +#endif + return out; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// Blocks until all but N previous cp.async.commit_group operations have committed. +// This differs from cute::cp_async_wait in that when N = 0 we don't call cp.async.wait_all +// (which is equivalent to commit_group then wait_group 0). +// Instead we just call cp.async.wait_group 0, which is slightly faster. +// https://github.com/NVIDIA/cutlass/blob/master/include/cute/arch/copy_sm80.hpp#L113 +template +CUTE_HOST_DEVICE void cp_async_wait() { +#if defined(CUTE_ARCH_CP_ASYNC_SM80_ENABLED) + asm volatile("cp.async.wait_group %0;\n" ::"n"(N)); +#endif +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +__forceinline__ __device__ void copy(TiledCopy tiled_copy, Tensor const& S, + Tensor& D, Tensor const& identity_MN, + Tensor const& predicate_K, const int max_MN = 0) { + CUTE_STATIC_ASSERT_V(rank(S) == Int<3>{}); + CUTE_STATIC_ASSERT_V(rank(D) == Int<3>{}); + CUTE_STATIC_ASSERT_V(size<0>(S) == size<0>(D)); // MMA + CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(D)); // MMA_M + CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(D)); // MMA_K + // There's no case where !Clear_OOB_K && Clear_OOB_MN + static_assert(!(Clear_OOB_MN && !Clear_OOB_K)); +#pragma unroll + for (int m = 0; m < size<1>(S); ++m) { + if (Is_even_MN || get<0>(identity_MN(0, m, 0)) < max_MN) { +#pragma unroll + for (int k = 0; k < size<2>(S); ++k) { + if (Is_even_K || predicate_K(k)) { + cute::copy(tiled_copy, S(_, m, k), D(_, m, k)); + } else if (Clear_OOB_K) { + cute::clear(D(_, m, k)); + } + } + } else if (Clear_OOB_MN) { + cute::clear(D(_, m, _)); + } + } + // TD [2023-04-13]: Strange that the code below can cause race condition. + // I think it's because the copies are under an if statement. + // if (Is_even_K) { + // #pragma unroll + // for (int m = 0; m < size<1>(S); ++m) { + // if (Is_even_MN || get<0>(identity_MN(0, m, 0)) < max_MN) { + // copy(tiled_copy, S(_, m, _), D(_, m, _)); + // } else if (Clear_OOB_MN) { + // clear(D(_, m, _)); + // } + // } + // } else { // It's slightly faster in this case if iterate over K first + // #pragma unroll + // for (int k = 0; k < size<2>(S); ++k) { + // if (predicate_K(k)) { + // #pragma unroll + // for (int m = 0; m < size<1>(S); ++m) { + // if (Is_even_MN || get<0>(identity_MN(0, m, 0)) < max_MN) { + // copy(tiled_copy, S(_, m, k), D(_, m, k)); + // } else if (Clear_OOB_MN) { + // clear(D(_, m, k)); + // } + // } + // } else if (Clear_OOB_K) { // There's no case where !Clear_OOB_K && Clear_OOB_MN + // if (Clear_OOB_MN || Is_even_MN) { + // clear(D(_, _, k)); + // } else { + // #pragma unroll + // for (int m = 0; m < size<1>(S); ++m) { + // if (!(Is_even_MN || get<0>(identity_MN(0, m, 0)) < max_MN)) { + // clear(D(_, m, k)); + // } + // } + // } + // } + // } + // } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +__forceinline__ __device__ void copy_w_min_idx(Tensor const& S, + Tensor& D, Tensor const& identity_MN, + Tensor const& predicate_K, + const int max_MN = 0, const int min_MN = 0) { + CUTE_STATIC_ASSERT_V(rank(S) == Int<3>{}); + CUTE_STATIC_ASSERT_V(rank(D) == Int<3>{}); + CUTE_STATIC_ASSERT_V(size<0>(S) == size<0>(D)); // MMA + CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(D)); // MMA_M + CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(D)); // MMA_K +// if (threadIdx.x == 0 && blockIdx.z == 0) { printf("blockIdx.y = %d, max_MN = %d, min_MN = %d\n", blockIdx.y, max_MN, min_MN); } +#pragma unroll + for (int m = 0; m < size<1>(S); ++m) { + // if (threadIdx.x == 0 && blockIdx.z == 0) { printf("blockIdx.y = %d, m = %d\n", blockIdx.y, get<0>(identity_MN(0, m, 0))); } + if (get<0>(identity_MN(0, m, 0)) >= min_MN && get<0>(identity_MN(0, m, 0)) < max_MN) { +// if (threadIdx.x == 0 && blockIdx.z == 0) { printf("Inner loop, blockIdx.y = %d, m = %d\n", blockIdx.y, get<0>(identity_MN(0, m, 0))); } +#pragma unroll + for (int k = 0; k < size<2>(S); ++k) { + if (Is_even_K || predicate_K(k)) { + cute::copy(S(_, m, k), D(_, m, k)); + } + } + } + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace lean +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/bert/multihead_attention.cc b/onnxruntime/contrib_ops/cuda/bert/multihead_attention.cc index 9c558900d1fdb..e2587d172af94 100644 --- a/onnxruntime/contrib_ops/cuda/bert/multihead_attention.cc +++ b/onnxruntime/contrib_ops/cuda/bert/multihead_attention.cc @@ -9,6 +9,7 @@ #include "contrib_ops/cuda/bert/cudnn_fmha/cudnn_flash_attention.h" #include "contrib_ops/cuda/bert/flash_attention/flash_api.h" #include "contrib_ops/cuda/utils/dump_cuda_tensor.h" +#include "contrib_ops/cuda/bert/lean_attention/lean_api.h" using namespace onnxruntime::cuda; using namespace ::onnxruntime::common; @@ -54,6 +55,10 @@ MultiHeadAttention::MultiHeadAttention(const OpKernelInfo& info) disable_flash_attention_ = sizeof(T) != 2 || !kernel_options_->UseFlashAttention(); +#if USE_LEAN_ATTENTION + enable_lean_attention_ = sizeof(T) == 2 && kernel_options_->UseLeanAttention(); +#endif + disable_memory_efficient_attention_ = !kernel_options_->UseEfficientAttention(); disable_fused_cross_attention_ = sizeof(T) != 2 || !kernel_options_->UseTrtCrossAttention(); @@ -151,8 +156,64 @@ Status MultiHeadAttention::ComputeInternal(OpKernelContext* context) const { AttentionKernelType kernel_type = AttentionKernelType::AttentionKernel_Default; + typedef typename ToCudaType::MappedType CudaT; + AttentionData data; + +#if USE_LEAN_ATTENTION || USE_FLASH_ATTENTION + size_t softmax_lse_bytes = 0; + size_t softmax_lse_accum_bytes = 0; + size_t out_accum_bytes = 0; +#endif + +#if USE_LEAN_ATTENTION + // Lean attention only supports token-generation phase with sequence_length == 1. + bool use_lean_attention = enable_lean_attention_ && + parameters.sequence_length == 1 && + parameters.past_sequence_length > 0 && + nullptr == attention_bias && + nullptr == key_padding_mask && + parameters.head_size == parameters.v_head_size && + onnxruntime::lean::is_supported(device_prop, + parameters.head_size, + parameters.num_heads, + parameters.num_heads); + + size_t sync_flag_bytes = 0; + if (use_lean_attention) { + softmax_lse_bytes = onnxruntime::lean::get_softmax_lse_size(parameters.sequence_length, + parameters.batch_size, + parameters.num_heads); + + auto [num_splits, slse_accum_bytes, o_accum_bytes, sflag_bytes, griddimz, max_tiles_tb, hload_tbs, tiles_per_head] = onnxruntime::lean::get_num_splits_and_buffer_sizes( + parameters.batch_size, + parameters.sequence_length, + parameters.total_sequence_length, + parameters.num_heads, // q heads + parameters.num_heads, // kv heads + parameters.head_size, + device_prop.multiProcessorCount, + parameters.is_unidirectional); + + data.num_splits = static_cast(num_splits); + data.grid_dim_z = static_cast(griddimz); + data.max_tiles_per_tb = static_cast(max_tiles_tb); + data.high_load_tbs = static_cast(hload_tbs); + data.tiles_per_head = static_cast(tiles_per_head); + softmax_lse_accum_bytes = slse_accum_bytes; + out_accum_bytes = o_accum_bytes; + sync_flag_bytes = sflag_bytes; + kernel_type = AttentionKernelType::AttentionKernel_LeanAttention; + } + + auto lean_sync_flag_buffer = GetScratchBuffer(sync_flag_bytes, context->GetComputeStream()); + data.lean_sync_flag = reinterpret_cast(lean_sync_flag_buffer.get()); +#else + constexpr bool use_lean_attention = false; +#endif + #if USE_FLASH_ATTENTION - bool use_flash_attention = !disable_flash_attention_ && + bool use_flash_attention = kernel_type == AttentionKernelType::AttentionKernel_Default && + !disable_flash_attention_ && nullptr == attention_bias && nullptr == key_padding_mask && parameters.head_size == parameters.v_head_size && @@ -165,25 +226,35 @@ Status MultiHeadAttention::ComputeInternal(OpKernelContext* context) const { parameters.sequence_length < kernel_options_->MinSeqLenForFlashAttentionPackedQkv()) { use_flash_attention = false; } + // Allocate buffers - size_t softmax_lse_accum_bytes = 0; - size_t out_accum_bytes = 0; if (use_flash_attention) { + softmax_lse_bytes = onnxruntime::flash::get_softmax_lse_size(parameters.sequence_length, + parameters.batch_size, + parameters.num_heads); + using namespace std; auto [num_splits, slse_accum_bytes, o_accum_bytes] = onnxruntime::flash::get_num_splits_and_buffer_sizes( parameters.batch_size, parameters.sequence_length, parameters.total_sequence_length, parameters.num_heads, parameters.head_size, device_prop.multiProcessorCount); - parameters.num_splits = static_cast(num_splits); + data.num_splits = static_cast(num_splits); softmax_lse_accum_bytes = slse_accum_bytes; out_accum_bytes = o_accum_bytes; kernel_type = AttentionKernelType::AttentionKernel_FlashAttention; } - auto softmax_lse_accum_buffer = GetScratchBuffer(softmax_lse_accum_bytes, context->GetComputeStream()); - auto out_accum_buffer = GetScratchBuffer(out_accum_bytes, context->GetComputeStream()); #else constexpr bool use_flash_attention = false; - auto softmax_lse_accum_buffer = GetScratchBuffer(0, context->GetComputeStream()); // nullptr - auto out_accum_buffer = GetScratchBuffer(0, context->GetComputeStream()); // nullptr +#endif + +#if USE_LEAN_ATTENTION || USE_FLASH_ATTENTION + auto softmax_lse_buffer = GetScratchBuffer(softmax_lse_bytes, context->GetComputeStream()); + auto softmax_lse_accum_buffer = GetScratchBuffer(softmax_lse_accum_bytes, context->GetComputeStream()); + auto out_accum_buffer = GetScratchBuffer(out_accum_bytes, context->GetComputeStream()); + if (use_flash_attention || use_lean_attention) { + data.softmax_lse = reinterpret_cast(softmax_lse_buffer.get()); + data.softmax_lse_accum = reinterpret_cast(softmax_lse_accum_buffer.get()); + data.out_accum = reinterpret_cast(out_accum_buffer.get()); + } #endif bool is_mask_none_or_1d_k_len = parameters.mask_type == AttentionMaskType::MASK_NONE || @@ -284,8 +355,6 @@ Status MultiHeadAttention::ComputeInternal(OpKernelContext* context) const { kernel_type = AttentionKernelType::AttentionKernel_Unfused; } - typedef typename ToCudaType::MappedType CudaT; - AttentionData data; data.bias = (nullptr == bias) ? nullptr : reinterpret_cast(bias->Data()); data.query = reinterpret_cast(query->Data()); data.key = (nullptr == key) ? nullptr : reinterpret_cast(key->Data()); @@ -303,6 +372,7 @@ Status MultiHeadAttention::ComputeInternal(OpKernelContext* context) const { data.fused_runner = reinterpret_cast(fused_runner); data.fused_cross_attention_kernel = fused_cross_attention_kernel; data.use_flash_attention = use_flash_attention; + data.use_lean_attention = use_lean_attention; data.use_memory_efficient_attention = use_memory_efficient_attention; data.kernel_type = kernel_type; data.allocator = Info().GetAllocator(OrtMemType::OrtMemTypeDefault); @@ -331,6 +401,7 @@ Status MultiHeadAttention::ComputeInternal(OpKernelContext* context) const { parameters.total_sequence_length, fused_runner, use_flash_attention, + use_lean_attention, use_fused_cross_attention, use_memory_efficient_attention, use_cudnn_sdpa, @@ -342,16 +413,11 @@ Status MultiHeadAttention::ComputeInternal(OpKernelContext* context) const { data.workspace_bytes = workspace_bytes; data.allow_debug_info = kernel_options_->AllowDebugInfo(); - if (softmax_lse_accum_buffer != nullptr) { - data.softmax_lse_accum = reinterpret_cast(softmax_lse_accum_buffer.get()); - } - if (out_accum_buffer != nullptr) { - data.out_accum = reinterpret_cast(out_accum_buffer.get()); - } if (data.allow_debug_info) { AttentionKernelDebugInfo debug_info; debug_info.use_flash_attention = use_flash_attention; + debug_info.use_lean_attention = use_lean_attention; debug_info.use_cudnn_flash_attention = use_cudnn_sdpa; debug_info.use_trt_cross_attention = fused_cross_attention_kernel != nullptr; debug_info.use_efficient_attention = use_memory_efficient_attention; diff --git a/onnxruntime/contrib_ops/cuda/bert/multihead_attention.h b/onnxruntime/contrib_ops/cuda/bert/multihead_attention.h index 8edc1d0e6ac06..b093b226c50b0 100644 --- a/onnxruntime/contrib_ops/cuda/bert/multihead_attention.h +++ b/onnxruntime/contrib_ops/cuda/bert/multihead_attention.h @@ -32,6 +32,9 @@ class MultiHeadAttention final : public CudaKernel { bool enable_trt_flash_attention_; bool disable_fused_cross_attention_; bool disable_flash_attention_; +#if USE_LEAN_ATTENTION + bool enable_lean_attention_; +#endif bool disable_memory_efficient_attention_; bool enable_cudnn_flash_attention_; diff --git a/onnxruntime/contrib_ops/cuda/quantization/attention_quantization.cc b/onnxruntime/contrib_ops/cuda/quantization/attention_quantization.cc index 1b774b163888f..33cd906508bcf 100644 --- a/onnxruntime/contrib_ops/cuda/quantization/attention_quantization.cc +++ b/onnxruntime/contrib_ops/cuda/quantization/attention_quantization.cc @@ -179,6 +179,7 @@ Status QAttention::ComputeInternal(OpKernelContext* context) const { constexpr bool use_fused_cross_attention = false; constexpr bool use_memory_efficient_attention = false; constexpr bool use_flash_attention = false; + constexpr bool use_lean_attention = false; constexpr bool use_cudnn_flash_attention = false; size_t workSpaceSize = GetAttentionWorkspaceSize(element_size, batch_size, @@ -190,6 +191,7 @@ Status QAttention::ComputeInternal(OpKernelContext* context) const { parameters.total_sequence_length, fused_runner, use_flash_attention, + use_lean_attention, use_fused_cross_attention, use_memory_efficient_attention, use_cudnn_flash_attention, diff --git a/onnxruntime/test/python/transformers/benchmark_mha.py b/onnxruntime/test/python/transformers/benchmark_mha.py index d8acb66158ed2..d922f153b4b91 100644 --- a/onnxruntime/test/python/transformers/benchmark_mha.py +++ b/onnxruntime/test/python/transformers/benchmark_mha.py @@ -72,6 +72,7 @@ class SdpaKernel(IntEnum): TRT_FLASH_ATTENTION = 32 TRT_CROSS_ATTENTION = 64 TRT_CAUSAL_ATTENTION = 128 + LEAN_ATTENTION = 256 # Since we support attention bias, so we only need support up to 2D mask. @@ -598,8 +599,8 @@ def measure_latency(cuda_session: CudaSession, input_dict): return end - start -def flops(batch, sequence_length, head_size, num_heads, causal): - return 4 * batch * sequence_length**2 * num_heads * head_size // (2 if causal else 1) +def flops(batch, sequence_length_q, sequence_length_kv, head_size, num_heads, causal): + return 4 * batch * sequence_length_q * sequence_length_kv * num_heads * head_size // (2 if causal else 1) def tflops_per_second(flop, time): @@ -613,6 +614,7 @@ def get_gpu_kernel_name(attention_kernel: SdpaKernel) -> str: kernel_names = { SdpaKernel.DEFAULT: "ort:default", SdpaKernel.FLASH_ATTENTION: "ort:flash", + SdpaKernel.LEAN_ATTENTION: "ort:lean", SdpaKernel.EFFICIENT_ATTENTION: "ort:efficient", SdpaKernel.CUDNN_FLASH_ATTENTION: "ort:cudnn", SdpaKernel.MATH: "ort:math", @@ -808,16 +810,17 @@ def sdpa_kernel_from_debug_info( ): os.environ["ORT_ENABLE_ATTENTION_KERNEL_DEBUG_INFO"] = "1" captured_text = None + try: with CaptureStdout() as captured: session = create_session(config, sess_options, attention_kernel=attention_kernel) input_dict = config.random_inputs() session.infer(input_dict) - captured_text = captured.output.decode() + captured_text = captured.output.decode() except Exception as e: print(f"Failed to run {attention_kernel=} for {config=}. Exception: {e}") - finally: - os.environ["ORT_ENABLE_ATTENTION_KERNEL_DEBUG_INFO"] = "0" + + os.environ["ORT_ENABLE_ATTENTION_KERNEL_DEBUG_INFO"] = "0" if captured_text is not None: m = re.search("SdpaKernel=(?P[A-Z_]+)", captured_text) @@ -825,6 +828,7 @@ def sdpa_kernel_from_debug_info( name = m.group("kernel") kernel_names = { "FLASH_ATTENTION": "ort:flash", + "LEAN_ATTENTION": "ort:lean", "EFFICIENT_ATTENTION": "ort:efficient", "CUDNN_FLASH_ATTENTION": "ort:cudnn", "MATH": "ort:math", @@ -867,6 +871,15 @@ def run_tflops_test( SdpaKernel.CUDNN_FLASH_ATTENTION, SdpaKernel.MATH, ] + + if args.past_sequence_length > 0: + backends.append(SdpaKernel.LEAN_ATTENTION) + + if args.past_sequence_length > 0 and causal: + backends.remove(SdpaKernel.CUDNN_FLASH_ATTENTION) + + if args.past_sequence_length > 4096: + backends.remove(SdpaKernel.MATH) else: backends = [SdpaKernel.DEFAULT, SdpaKernel.EFFICIENT_ATTENTION, SdpaKernel.MATH] else: @@ -884,6 +897,8 @@ def run_tflops_test( for input_format in formats: for batch_size, sequence_length, past_sequence_length, num_heads, head_size, enable_unfused in configs: + if past_sequence_length > 0 and input_format not in [InputFormats.Q_K_V_BSNH_BSNH_BSNH]: + continue config = MultiHeadAttentionConfig( batch_size=batch_size, sequence_length=sequence_length, @@ -900,6 +915,7 @@ def run_tflops_test( dtype=torch.float16 if use_gpu else torch.float, share_past_present_buffer=False, input_format=input_format, + has_past_input=past_sequence_length > 0, has_attn_bias=args.has_attn_bias, broadcast_attn_bias_dim_0=args.broadcast_attn_bias_dim_0, broadcast_attn_bias_dim_1=args.broadcast_attn_bias_dim_1, @@ -926,11 +942,19 @@ def run_tflops_test( print(f"skip input_format for {vars(config)}") continue + if use_gpu and config.total_sequence_length > 8192: + if config.verbose: + print(f"skip large sequence length for {vars(config)}") + continue + if use_gpu: actual_kernel = sdpa_kernel_from_debug_info(config, attention_kernel, sess_options) if actual_kernel is None: print(f"Warning: skip {config} since kernel from debug info is None") continue + if actual_kernel != request_kernel and request_kernel != "ort:default": + print(f"Skip since {actual_kernel=} != {request_kernel=}") + continue else: # CPU has no debug info for now. actual_kernel = request_kernel @@ -956,11 +980,17 @@ def run_tflops_test( format_str = InputFormats.input_format_str(input_format) # compute TFLOPS per second - speed = None - if past_sequence_length == 0: - speed = tflops_per_second( - flops(batch_size, sequence_length, head_size, num_heads, causal), average_latency - ) + speed = tflops_per_second( + flops( + batch_size, + sequence_length, + sequence_length + past_sequence_length, + head_size, + num_heads, + causal, + ), + average_latency, + ) row = { "use_gpu": use_gpu, @@ -983,11 +1013,11 @@ def run_tflops_test( } csv_writer.writerow(row) - speed = f"{speed:.2f}" if speed is not None else "NA" + speed = f"{speed:.3f}" if speed is not None else "NA" print( f"{format_str}\t{causal}\t{args.has_attn_bias}\t{batch_size}\t" f"{sequence_length}\t{past_sequence_length}\t{num_heads}\t{head_size}\t" - f"{intra_op_num_threads}\t{average_latency * 1000:.2f}\t{speed}\t{actual_kernel}\t{request_kernel}" + f"{intra_op_num_threads}\t{average_latency * 1000:.3f}\t{speed}\t{actual_kernel}\t{request_kernel}" ) @@ -1055,7 +1085,17 @@ def run_torch_test( except RuntimeError: continue - speed = tflops_per_second(flops(batch_size, sequence_length, head_size, num_heads, causal), torch_latency) + speed = tflops_per_second( + flops( + batch_size, + sequence_length, + sequence_length + past_sequence_length, + head_size, + num_heads, + causal, + ), + torch_latency, + ) input_format = "Q,K,V" print( f"{input_format}\t{causal}\t{False}\t{batch_size}\t" @@ -1090,7 +1130,8 @@ def run_tflops_tests(args): features += "_causal" if args.past_sequence_length > 0: features += "_past" - csv_filename = "benchmark_mha_{}_{}_{}.csv".format( + csv_filename = "{}_{}_{}_{}.csv".format( + args.csv_filename_prefix, features, "torch" if args.torch else "ort", datetime.now().strftime("%Y%m%d-%H%M%S"), @@ -1343,6 +1384,14 @@ def _parse_arguments(): ) parser.set_defaults(broadcast_attn_bias_dim_1=False) + parser.add_argument( + "--csv_filename_prefix", + required=False, + type=str, + default="benchmark_mha", + help="Prefix of csv filename", + ) + args = parser.parse_args() return args diff --git a/onnxruntime/test/python/transformers/benchmark_mha.sh b/onnxruntime/test/python/transformers/benchmark_mha.sh index ff6dd16e698df..8d811219d4dac 100644 --- a/onnxruntime/test/python/transformers/benchmark_mha.sh +++ b/onnxruntime/test/python/transformers/benchmark_mha.sh @@ -5,45 +5,104 @@ # Licensed under the MIT License. # -------------------------------------------------------------------------- -echo "Benchmark Scaled Dot Product Attention (SDPA) performance on GPU:" +# Usage: benchmark_mha.sh [gpu|cpu|lean] +task="${1:-gpu}" -export CUDA_VISIBLE_DEVICES=0 -python benchmark_mha.py --use_gpu +# Function to lock GPU clocks and set power limit for a GPU +configure_gpu() { + local gpu_id=$1 -echo "Benchmark BERT-Large performance on GPU without attention bias" -python benchmark_mha.py --use_gpu -b 16 + # Ensure nvidia-smi is available + if ! command -v nvidia-smi &> /dev/null + then + echo "nvidia-smi not found. Please ensure NVIDIA drivers are installed." + exit + fi -echo "Benchmark BERT-Large performance on GPU with attention bias" -python benchmark_mha.py --use_gpu -b 16 -r 1000 --has_attn_bias -python benchmark_mha.py --use_gpu -b 16 -r 1000 --has_attn_bias --broadcast_attn_bias_dim_0 -python benchmark_mha.py --use_gpu -b 16 -r 1000 --has_attn_bias --broadcast_attn_bias_dim_0 --broadcast_attn_bias_dim_1 + # Enable Persistence Mode + sudo nvidia-smi -pm 1 -i $gpu_id -python benchmark_mha.py --use_gpu --use_cuda_graph -python benchmark_mha.py --use_gpu --torch + # Get the maximum clock speeds for graphics and memory. + nvidia-smi -q -d CLOCK -i ${gpu_id} | grep -A3 "Max Clocks" + max_graphics_clock=$(nvidia-smi -q -d CLOCK -i ${gpu_id} | grep -A1 "Max Clocks" | grep "Graphics" | awk '{print $3}') + max_memory_clock=$(nvidia-smi -q -d CLOCK -i ${gpu_id} | grep -A3 "Max Clocks" | grep "Memory" | awk '{print $3}') -cat benchmark_mha_gpu_*.csv > mha_gpu_benchmark_results.csv + # Lock the GPU clocks to maximum frequencies + sudo nvidia-smi -i $gpu_id --lock-gpu-clocks=$max_graphics_clock,$max_graphics_clock + sudo nvidia-smi -i $gpu_id --lock-memory-clocks=$max_memory_clock,$max_memory_clock -echo "Benchmark performance on CPU with number of threads:" -MKL_DYNAMIC=FALSE OMP_NUM_THREADS=1 python benchmark_mha.py --torch -MKL_DYNAMIC=FALSE OMP_NUM_THREADS=2 python benchmark_mha.py --torch -MKL_DYNAMIC=FALSE OMP_NUM_THREADS=4 python benchmark_mha.py --torch -MKL_DYNAMIC=FALSE OMP_NUM_THREADS=8 python benchmark_mha.py --torch + nvidia-smi --query-gpu=clocks.gr,clocks.sm,clocks.mem --format=csv + echo "GPU $gpu_id clocks locked to $max_graphics_clock MHz (graphics) and $max_memory_clock MHz (memory)" -python benchmark_mha.py --intra_op_num_threads 1 -python benchmark_mha.py --intra_op_num_threads 2 -python benchmark_mha.py --intra_op_num_threads 4 -python benchmark_mha.py --intra_op_num_threads 8 + # Set Power Limit to maximum + power_limit=$(nvidia-smi --query-gpu=power.limit -i 0 --format=csv | grep "0" | awk '{print $1}') + power_limit=${power_limit%.*} + sudo nvidia-smi -pl $power_limit -i $gpu_id + export CUDA_VISIBLE_DEVICES=$gpu_id +} -echo "Benchmark performance on CPU with default threads settings:" -python benchmark_mha.py -ORT_DISABLE_FLASH_ATTENTION=1 python benchmark_mha.py -python benchmark_mha.py --torch +run_gpu_benchmarks() { + echo "Benchmark Scaled Dot Product Attention (SDPA) performance on GPU:" -python benchmark_mha.py --causal -python benchmark_mha.py --torch --causal + python benchmark_mha.py --use_gpu -# Pytorch SDPA does not support causal attention with past state, we only test ORT here. -python benchmark_mha.py --causal --has_past + echo "Benchmark BERT-Large performance on GPU without attention bias" + python benchmark_mha.py --use_gpu -b 16 -cat benchmark_mha_cpu_*.csv > mha_cpu_benchmark_results.csv + echo "Benchmark BERT-Large performance on GPU with attention bias" + python benchmark_mha.py --use_gpu -b 16 -r 1000 --has_attn_bias + python benchmark_mha.py --use_gpu -b 16 -r 1000 --has_attn_bias --broadcast_attn_bias_dim_0 + python benchmark_mha.py --use_gpu -b 16 -r 1000 --has_attn_bias --broadcast_attn_bias_dim_0 --broadcast_attn_bias_dim_1 + + python benchmark_mha.py --use_gpu --use_cuda_graph + python benchmark_mha.py --use_gpu --torch + + cat benchmark_mha_gpu_*.csv > mha_gpu_benchmark_results.csv +} + +run_lean_benchmarks() { + echo "Benchmark long context decoding performance on GPU" + for b in 1 4 16; do + for s in 32 64 128 256 512 1024 2048 4096 8192 16384 32768 65536; do + python benchmark_mha.py --use_gpu --causal -b $b -s 1 -p $s -n 16 -d 64 -r 1000 --csv_filename_prefix benchmark_lean + python benchmark_mha.py --use_gpu --causal -b $b -s 1 -p $s -n 32 -d 128 -r 1000 --csv_filename_prefix benchmark_lean + done + done + cat benchmark_lean_*.csv > lean_benchmark_results.csv +} + +run_cpu_benchmarks() { + echo "Benchmark performance on CPU with number of threads:" + MKL_DYNAMIC=FALSE OMP_NUM_THREADS=1 python benchmark_mha.py --torch + MKL_DYNAMIC=FALSE OMP_NUM_THREADS=2 python benchmark_mha.py --torch + MKL_DYNAMIC=FALSE OMP_NUM_THREADS=4 python benchmark_mha.py --torch + MKL_DYNAMIC=FALSE OMP_NUM_THREADS=8 python benchmark_mha.py --torch + + python benchmark_mha.py --intra_op_num_threads 1 + python benchmark_mha.py --intra_op_num_threads 2 + python benchmark_mha.py --intra_op_num_threads 4 + python benchmark_mha.py --intra_op_num_threads 8 + + + echo "Benchmark performance on CPU with default threads settings:" + python benchmark_mha.py + ORT_DISABLE_FLASH_ATTENTION=1 python benchmark_mha.py + python benchmark_mha.py --torch + + python benchmark_mha.py --causal + python benchmark_mha.py --torch --causal + + # Pytorch SDPA does not support causal attention with past state, we only test ORT here. + python benchmark_mha.py --causal --has_past + + cat benchmark_mha_cpu_*.csv > mha_cpu_benchmark_results.csv +} + +[ "$task" != "cpu" ] && configure_gpu 0 + +[ "$task" == "gpu" ] && run_gpu_benchmarks + +[ "$task" == "cpu" ] && run_cpu_benchmarks + +[ "$task" == "lean" ] && run_lean_benchmarks diff --git a/onnxruntime/test/python/transformers/test_mha.py b/onnxruntime/test/python/transformers/test_mha.py index 69f0035ef8a17..9e7c7378370c1 100644 --- a/onnxruntime/test/python/transformers/test_mha.py +++ b/onnxruntime/test/python/transformers/test_mha.py @@ -9,6 +9,7 @@ import concurrent.futures import itertools +import os import unittest from typing import Dict, List, Optional @@ -400,6 +401,49 @@ def kv_cache_test_cases(provider: str, comprehensive: bool): yield config +def lean_attention_test_cases(provider: str, comprehensive: bool): + if provider == "CUDAExecutionProvider" and get_compute_capability() < 80: + return + yield + + batch_sizes = [1, 2, 3] if comprehensive else [1, 2] + sequence_lengths = [2, 15, 16, 255, 256, 512, 1024, 2048, 4096, 8192] if comprehensive else [2, 255, 512] + heads = [1, 4, 16] if comprehensive else [1, 4] + head_sizes = [64, 128] + device, dtype, formats = get_provider_support_info(provider, True) + mask_formats = [AttentionMaskFormat.Mask_None] + + sequence_lengths = [*sequence_lengths, 2048] # Large sequence length is slow and need a lot of memory + for batch_size in batch_sizes: + for total_seq_len in sequence_lengths: + for num_heads in heads: + for head_size in head_sizes: + for format in formats: + for causal in get_causal_support(format): + for is_prompt in [False]: + for mask_format in mask_formats: + sequence_length = total_seq_len if is_prompt else 1 + config = MultiHeadAttentionConfig( + batch_size=batch_size, + sequence_length=sequence_length, + num_heads=num_heads, + head_size=head_size, + causal=causal, + past_sequence_length=total_seq_len - sequence_length, + kv_sequence_length=sequence_length, + max_cache_sequence_length=None, + provider=provider, + device=device, + dtype=dtype, + use_kv_cache=True, + has_past_input=True, + share_past_present_buffer=False, + input_format=format, + mask_format=mask_format, + ) + yield config + + def no_kv_cache_multi_thread_test_cases(provider: str, comprehensive: bool): if provider == "CUDAExecutionProvider" and get_compute_capability() < 60: return @@ -787,6 +831,12 @@ def run_mha_cuda(self): for config in mha_test_cases("CUDAExecutionProvider", comprehensive_mode): parity_check_mha(config, rtol=5e-3, atol=5e-3) + def run_lean_attention(self): + os.environ["ORT_ENABLE_LEAN_ATTENTION"] = "1" + for config in lean_attention_test_cases("CUDAExecutionProvider", comprehensive_mode): + parity_check_mha(config, rtol=5e-3, atol=5e-3 if config.total_sequence_length <= 512 else 5e-2) + os.environ.pop("ORT_ENABLE_LEAN_ATTENTION", None) + def run_mha_cpu(self): for config in mha_test_cases("CPUExecutionProvider", comprehensive_mode): parity_check_mha(config, rtol=5e-3, atol=5e-3) @@ -842,6 +892,7 @@ def test_all(self): # Run tests sequentially to avoid out of memory issue. self.run_mha_cpu() self.run_mha_cuda() + self.run_lean_attention() self.run_mha_cuda_multi_threading_default() self.run_mha_cuda_multi_threading_cudnn() self.run_mha_cuda_multi_threading_efficient()