diff --git a/cmake/CMakeLists.txt b/cmake/CMakeLists.txt index ef208f59f63b0..d90a2a355045e 100644 --- a/cmake/CMakeLists.txt +++ b/cmake/CMakeLists.txt @@ -106,6 +106,7 @@ option(onnxruntime_USE_LLVM "Build TVM with LLVM" OFF) option(onnxruntime_USE_VSINPU "Build with VSINPU support" OFF) cmake_dependent_option(onnxruntime_USE_FLASH_ATTENTION "Build flash attention kernel for scaled dot product attention" ON "onnxruntime_USE_CUDA" OFF) +cmake_dependent_option(onnxruntime_USE_LEAN_ATTENTION "Build lean attention kernel for scaled dot product attention" ON "onnxruntime_USE_CUDA; NOT WIN32" OFF) option(onnxruntime_USE_MEMORY_EFFICIENT_ATTENTION "Build memory efficient attention kernel for scaled dot product attention" ON) option(onnxruntime_BUILD_FOR_NATIVE_MACHINE "Enable this option for turning on optimization specific to this machine" OFF) @@ -751,21 +752,30 @@ if (onnxruntime_USE_CUDA) if (onnxruntime_DISABLE_CONTRIB_OPS) set(onnxruntime_USE_FLASH_ATTENTION OFF) + set(onnxruntime_USE_LEAN_ATTENTION OFF) set(onnxruntime_USE_MEMORY_EFFICIENT_ATTENTION OFF) endif() + if (CMAKE_CUDA_COMPILER_VERSION VERSION_LESS 11.6) message( STATUS "Turn off flash attention since CUDA compiler version < 11.6") set(onnxruntime_USE_FLASH_ATTENTION OFF) + set(onnxruntime_USE_LEAN_ATTENTION OFF) set(onnxruntime_USE_MEMORY_EFFICIENT_ATTENTION OFF) elseif(WIN32 AND CMAKE_CUDA_COMPILER_VERSION VERSION_LESS 12) message( STATUS "Flash-Attention unsupported in Windows with CUDA compiler version < 12.0") set(onnxruntime_USE_FLASH_ATTENTION OFF) endif() + if (CMAKE_CUDA_COMPILER_VERSION VERSION_LESS 11.4) message( FATAL_ERROR "Failed build due to CUDA compiler version < 11.4") endif() + if (WIN32) + message( STATUS "Lean Attention unsupported in Windows") + set(onnxruntime_USE_LEAN_ATTENTION OFF) + endif() else() set(onnxruntime_USE_FLASH_ATTENTION OFF) + set(onnxruntime_USE_LEAN_ATTENTION OFF) set(onnxruntime_USE_MEMORY_EFFICIENT_ATTENTION OFF) endif() @@ -779,6 +789,13 @@ if (onnxruntime_USE_CUDA) list(APPEND ORT_PROVIDER_FLAGS -DUSE_FLASH_ATTENTION=1) list(APPEND ORT_PROVIDER_CMAKE_FLAGS -Donnxruntime_USE_FLASH_ATTENTION=1) endif() + + if (onnxruntime_USE_LEAN_ATTENTION) + message( STATUS "Enable lean attention for CUDA EP") + list(APPEND ORT_PROVIDER_FLAGS -DUSE_LEAN_ATTENTION=1) + list(APPEND ORT_PROVIDER_CMAKE_FLAGS -Donnxruntime_USE_LEAN_ATTENTION=1) + endif() + if (onnxruntime_USE_MEMORY_EFFICIENT_ATTENTION) message( STATUS "Enable memory efficient attention for CUDA EP") list(APPEND ORT_PROVIDER_FLAGS -DUSE_MEMORY_EFFICIENT_ATTENTION=1) diff --git a/onnxruntime/contrib_ops/cpu/bert/attention_common.h b/onnxruntime/contrib_ops/cpu/bert/attention_common.h index 46638555576a9..97d6cc1ce7d66 100644 --- a/onnxruntime/contrib_ops/cpu/bert/attention_common.h +++ b/onnxruntime/contrib_ops/cpu/bert/attention_common.h @@ -48,6 +48,7 @@ enum AttentionKernelType { AttentionKernel_CutlassMemoryEfficientAttention, AttentionKernel_FlashAttention, AttentionKernel_CudnnFlashAttention, + AttentionKernel_LeanAttention, AttentionKernel_Default }; @@ -65,7 +66,6 @@ struct AttentionParameters { int v_hidden_size; // hidden size of V int v_head_size; // hidden size per head of V int num_heads; - int num_splits; int rotary_embedding; bool is_unidirectional; bool past_present_share_buffer; @@ -208,10 +208,13 @@ enum class AttentionBackend : int { CUDNN_FLASH_ATTENTION = 8, // reserved for cuDNN flash attention. MATH = 16, // unfused kernel cannot be disabled right now. - // The following kernels might be deprecated in the future. + // The following TRT kernels might be deprecated in the future. TRT_FLASH_ATTENTION = 32, TRT_CROSS_ATTENTION = 64, TRT_CAUSAL_ATTENTION = 128, + + // Experimental kernels + LEAN_ATTENTION = 256, }; // Environment variable to enable debug information of attention kernel to be printed. Default is 0 (disabled). @@ -239,6 +242,9 @@ constexpr const char* kDisableMemoryEfficientAttention = "ORT_DISABLE_MEMORY_EFF // Environment variable to enable or disable flash attention. Default is 0 (enabled). constexpr const char* kDisableFlashAttention = "ORT_DISABLE_FLASH_ATTENTION"; +// Environment variable to enable or disable lean attention. Default is 0 (disabled). +constexpr const char* kEnableLeanAttention = "ORT_ENABLE_LEAN_ATTENTION"; + // Minimum sequence length to perfer memory efficient attention when data type is float32 constexpr const char* kMinSeqLenForEfficientAttentionFp32 = "ORT_MIN_SEQ_LEN_EFFICIENT_ATTENTION_FP32"; diff --git a/onnxruntime/contrib_ops/cuda/bert/attention.cc b/onnxruntime/contrib_ops/cuda/bert/attention.cc index efbc0b5031657..22e2879a5be15 100644 --- a/onnxruntime/contrib_ops/cuda/bert/attention.cc +++ b/onnxruntime/contrib_ops/cuda/bert/attention.cc @@ -102,6 +102,9 @@ Status Attention::ComputeInternal(OpKernelContext* context) const { const int sm = device_prop.major * 10 + device_prop.minor; const bool is_mask_1d_seq_len = parameters.mask_type == AttentionMaskType::MASK_1D_KEY_SEQ_LEN; + typedef typename ToCudaType::MappedType CudaT; + AttentionData data; + #if USE_FLASH_ATTENTION bool use_flash_attention = !disable_flash_attention_ && (nullptr == attention_bias) && @@ -118,21 +121,26 @@ Status Attention::ComputeInternal(OpKernelContext* context) const { use_flash_attention = false; } // Allocate buffers + size_t softmax_lse_bytes = 0; size_t softmax_lse_accum_bytes = 0; size_t out_accum_bytes = 0; if (use_flash_attention) { + softmax_lse_bytes = onnxruntime::flash::get_softmax_lse_size(sequence_length, batch_size, parameters.num_heads); + using namespace std; auto [num_splits, slse_accum_bytes, o_accum_bytes] = onnxruntime::flash::get_num_splits_and_buffer_sizes( parameters.batch_size, parameters.sequence_length, parameters.total_sequence_length, parameters.num_heads, parameters.head_size, device_prop.multiProcessorCount); - parameters.num_splits = static_cast(num_splits); + data.num_splits = static_cast(num_splits); softmax_lse_accum_bytes = slse_accum_bytes; out_accum_bytes = o_accum_bytes; } + auto softmax_lse_buffer = GetScratchBuffer(softmax_lse_bytes, context->GetComputeStream()); auto softmax_lse_accum_buffer = GetScratchBuffer(softmax_lse_accum_bytes, context->GetComputeStream()); auto out_accum_buffer = GetScratchBuffer(out_accum_bytes, context->GetComputeStream()); #else constexpr bool use_flash_attention = false; + auto softmax_lse_buffer = GetScratchBuffer(0, context->GetComputeStream()); auto softmax_lse_accum_buffer = GetScratchBuffer(0, context->GetComputeStream()); // nullptr auto out_accum_buffer = GetScratchBuffer(0, context->GetComputeStream()); // nullptr #endif @@ -247,6 +255,7 @@ Status Attention::ComputeInternal(OpKernelContext* context) const { constexpr size_t element_size = sizeof(T); constexpr bool use_fused_cross_attention = false; constexpr bool use_cudnn_flash_attention = false; + constexpr bool use_lean_attention = false; size_t workSpaceSize = GetAttentionWorkspaceSize(element_size, parameters.batch_size, parameters.num_heads, @@ -257,14 +266,13 @@ Status Attention::ComputeInternal(OpKernelContext* context) const { parameters.total_sequence_length, fused_runner, use_flash_attention, + use_lean_attention, use_fused_cross_attention, use_memory_efficient_attention, use_cudnn_flash_attention, false); IAllocatorUniquePtr work_space = IAllocator::MakeUniquePtr(allocator, workSpaceSize, false, context->GetComputeStream()); - typedef typename ToCudaType::MappedType CudaT; - AttentionData data; data.gemm_buffer = reinterpret_cast(gemm_buffer.get()); if (nullptr != bias) { data.bias = reinterpret_cast(bias->Data()); @@ -289,6 +297,10 @@ Status Attention::ComputeInternal(OpKernelContext* context) const { data.fused_runner = reinterpret_cast(fused_runner); data.use_flash_attention = use_flash_attention; data.use_memory_efficient_attention = use_memory_efficient_attention; + if (softmax_lse_buffer != nullptr) { + data.softmax_lse = reinterpret_cast(softmax_lse_buffer.get()); + } + if (softmax_lse_accum_buffer != nullptr) { data.softmax_lse_accum = reinterpret_cast(softmax_lse_accum_buffer.get()); } diff --git a/onnxruntime/contrib_ops/cuda/bert/attention_impl.cu b/onnxruntime/contrib_ops/cuda/bert/attention_impl.cu index eff58c0080012..9e017544d7cff 100644 --- a/onnxruntime/contrib_ops/cuda/bert/attention_impl.cu +++ b/onnxruntime/contrib_ops/cuda/bert/attention_impl.cu @@ -39,6 +39,7 @@ limitations under the License. #include "contrib_ops/cuda/bert/cutlass_fmha/memory_efficient_attention.h" #include "contrib_ops/cuda/bert/cudnn_fmha/cudnn_flash_attention.h" #include "contrib_ops/cuda/bert/flash_attention/flash_api.h" +#include "contrib_ops/cuda/bert/lean_attention/lean_api.h" #include "contrib_ops/cuda/bert/attention_impl.h" using namespace onnxruntime::cuda; @@ -108,6 +109,7 @@ size_t GetAttentionWorkspaceSize( size_t total_sequence_length, void* fused_runner, bool use_flash_attention, + bool use_lean_attention, bool use_fused_cross_attention, bool use_memory_efficient_attention, bool use_cudnn_flash_attention, @@ -119,12 +121,20 @@ size_t GetAttentionWorkspaceSize( #if USE_FLASH_ATTENTION if (use_flash_attention) { - return qkv_bytes + onnxruntime::flash::get_softmax_lse_size(sequence_length, batch_size, num_heads); + return qkv_bytes; } #else ORT_UNUSED_PARAMETER(use_flash_attention); #endif +#if USE_LEAN_ATTENTION + if (use_lean_attention) { + return qkv_bytes; + } +#else + ORT_UNUSED_PARAMETER(use_lean_attention); +#endif + #if USE_MEMORY_EFFICIENT_ATTENTION if (use_memory_efficient_attention) { size_t fmha_buffer_bytes = 0; @@ -301,10 +311,10 @@ Status FlashAttention( constexpr bool is_bf16 = false; ORT_RETURN_IF_ERROR(onnxruntime::flash::mha_fwd( - device_prop, stream, data.q, data.k, data.v, data.output, reinterpret_cast(data.scratch), + device_prop, stream, data.q, data.k, data.v, data.output, reinterpret_cast(data.softmax_lse), parameters.batch_size, parameters.num_heads, parameters.num_heads, parameters.head_size, parameters.sequence_length, parameters.total_sequence_length, scale, 0.0, parameters.is_unidirectional, is_bf16, - false, parameters.num_splits, reinterpret_cast(data.softmax_lse_accum), + false, data.num_splits, reinterpret_cast(data.softmax_lse_accum), reinterpret_cast(data.out_accum), data.qkv_format == AttentionQkvFormat::Q_K_V_BSNH)); return Status::OK(); @@ -326,6 +336,81 @@ Status FlashAttention( } #endif +#if USE_LEAN_ATTENTION +template +Status LeanAttention( + const cudaDeviceProp& device_prop, + cudaStream_t stream, + contrib::AttentionParameters& parameters, + AttentionData& data, + float scale) { + assert(data.qkv_format == AttentionQkvFormat::Q_K_V_BSNH || + data.qkv_format == AttentionQkvFormat::Q_K_V_BSNH_BNSH_BNSH); + assert(nullptr == data.mask_index); + assert(nullptr == data.attention_bias); + assert(parameters.head_size == parameters.v_head_size); + + constexpr bool is_bf16 = false; + + ORT_RETURN_IF_ERROR(onnxruntime::lean::mha_fwd_kvcache( + device_prop, stream, + data.q, + data.k, // k_cache + data.v, // v_cache + nullptr, // new_k (we have appended new_k to k_cache) + nullptr, // new_v (we have appended new_v to k_cache) + data.output, + reinterpret_cast(data.softmax_lse), + nullptr, // seqlens_k + nullptr, // cos_cache + nullptr, // sin_cache + nullptr, // block_table + parameters.batch_size, + parameters.num_heads, + parameters.num_heads, // num_heads_k + parameters.head_size, + parameters.sequence_length, // seqlen_q + parameters.total_sequence_length, // seqlen_k + 0, // seqlen_k_new + 0, // rotary_dim + scale, // softmax_scale + parameters.is_unidirectional, + is_bf16, + false, // past_bsnh + data.num_splits, + data.grid_dim_z, + data.max_tiles_per_tb, + data.high_load_tbs, + data.tiles_per_head, + reinterpret_cast(data.softmax_lse_accum), + reinterpret_cast(data.out_accum), + data.lean_sync_flag, + -1, // local_window_size + false, // is_rotary_interleaved + false // is_packed_qkv + )); + + return Status::OK(); +} + +template <> +Status LeanAttention( + const cudaDeviceProp& device_prop, + cudaStream_t stream, + contrib::AttentionParameters& parameters, + AttentionData& data, + float scale) { + ORT_UNUSED_PARAMETER(device_prop); + ORT_UNUSED_PARAMETER(stream); + ORT_UNUSED_PARAMETER(parameters); + ORT_UNUSED_PARAMETER(data); + ORT_UNUSED_PARAMETER(scale); + return ORT_MAKE_STATUS(ONNXRUNTIME, StatusCode::NOT_IMPLEMENTED, "lean attention does not support float tensor"); +} +#endif + + + template Status CudnnFlashAttention( cudnnHandle_t cudnn_handle, @@ -641,6 +726,11 @@ Status QkvToContext( // For raw attention mask, the scalar 1/sqrt(H) is moved to combine with softmax computation. const float scale = parameters.scale == 0.0f ? 1.f / sqrt(static_cast(qk_head_size)) : parameters.scale; +#if USE_LEAN_ATTENTION + if (data.use_lean_attention) { + return LeanAttention(device_prop, stream, parameters, data, scale); + } +#endif #if USE_FLASH_ATTENTION if (data.use_flash_attention) { diff --git a/onnxruntime/contrib_ops/cuda/bert/attention_impl.h b/onnxruntime/contrib_ops/cuda/bert/attention_impl.h index fcc9af9681223..7d111a1ee21bf 100644 --- a/onnxruntime/contrib_ops/cuda/bert/attention_impl.h +++ b/onnxruntime/contrib_ops/cuda/bert/attention_impl.h @@ -53,6 +53,7 @@ size_t GetAttentionWorkspaceSize( size_t total_sequence_length, void* fused_runner, bool use_flash_attention, + bool use_lean_attention, bool use_fused_cross_attention, bool use_memory_efficient_attention, bool use_cudnn_flash_attention, @@ -102,6 +103,19 @@ struct AttentionData { T* softmax_lse_accum = nullptr; T* out_accum = nullptr; + // Flash Atttention and Lean Attention + int num_splits; + + // Lean Attention + bool use_lean_attention = false; +#if USE_LEAN_ATTENTION + int grid_dim_z = 0; + int max_tiles_per_tb = 0; + int high_load_tbs = 0; + int tiles_per_head = 0; + int* lean_sync_flag = nullptr; +#endif + // For Debugging size_t workspace_bytes = 0; bool allow_debug_info = false; @@ -115,6 +129,7 @@ struct AttentionData { void PrintDebugInfo() const { std::cout << "flash=" << use_flash_attention + << ", lean=" << use_lean_attention << ", efficient=" << use_memory_efficient_attention << ", fused_runner=" << (fused_runner != nullptr) << ", fused_cross=" << (fused_cross_attention_kernel != nullptr) diff --git a/onnxruntime/contrib_ops/cuda/bert/attention_kernel_options.cc b/onnxruntime/contrib_ops/cuda/bert/attention_kernel_options.cc index 7d21451df5b86..8b8b764e7c785 100644 --- a/onnxruntime/contrib_ops/cuda/bert/attention_kernel_options.cc +++ b/onnxruntime/contrib_ops/cuda/bert/attention_kernel_options.cc @@ -17,6 +17,9 @@ namespace onnxruntime { void AttentionKernelOptions::Initialize(int value, bool use_build_flag, bool check_cudnn_version) { if (value > 0) { use_flash_attention_ = (value & static_cast(AttentionBackend::FLASH_ATTENTION)) > 0; +#if USE_LEAN_ATTENTION + use_lean_attention_ = (value & static_cast(AttentionBackend::LEAN_ATTENTION)) > 0; +#endif use_efficient_attention_ = (value & static_cast(AttentionBackend::EFFICIENT_ATTENTION)) > 0; use_trt_fused_attention_ = (value & static_cast(AttentionBackend::TRT_FUSED_ATTENTION)) > 0; use_cudnn_flash_attention_ = (value & static_cast(AttentionBackend::CUDNN_FLASH_ATTENTION)) > 0; @@ -26,6 +29,9 @@ void AttentionKernelOptions::Initialize(int value, bool use_build_flag, bool che use_trt_causal_attention_ = (value & static_cast(AttentionBackend::TRT_CAUSAL_ATTENTION)) > 0; } else { use_flash_attention_ = !ParseEnvironmentVariableWithDefault(kDisableFlashAttention, false); +#if USE_LEAN_ATTENTION + use_lean_attention_ = ParseEnvironmentVariableWithDefault(kEnableLeanAttention, false); +#endif use_efficient_attention_ = !ParseEnvironmentVariableWithDefault(kDisableMemoryEfficientAttention, false); use_trt_fused_attention_ = !ParseEnvironmentVariableWithDefault(kDisableFusedSelfAttention, false); use_cudnn_flash_attention_ = ParseEnvironmentVariableWithDefault(kEnableCudnnFlashAttention, false); @@ -61,6 +67,10 @@ void AttentionKernelOptions::Initialize(int value, bool use_build_flag, bool che use_flash_attention_ = false; #endif +#ifndef USE_LEAN_ATTENTION + use_lean_attention_ = false; +#endif + #ifndef USE_MEMORY_EFFICIENT_ATTENTION use_efficient_attention_ = false; #endif @@ -81,6 +91,9 @@ void AttentionKernelOptions::Print() const { std::stringstream sstream; sstream << "AttentionKernelOptions:"; sstream << " FLASH_ATTENTION=" << int(use_flash_attention_); +#if USE_LEAN_ATTENTION + sstream << " LEAN_ATTENTION=" << int(use_lean_attention_); +#endif sstream << " EFFICIENT_ATTENTION=" << int(use_efficient_attention_); sstream << " TRT_FUSED_ATTENTION=" << int(use_trt_fused_attention_); sstream << " CUDNN_FLASH_ATTENTION=" << int(use_cudnn_flash_attention_); @@ -131,6 +144,10 @@ void AttentionKernelDebugInfo::Print(const char* operator_name, sstream << " SdpaKernel="; if (use_flash_attention.has_value() && use_flash_attention.value()) { sstream << "FLASH_ATTENTION"; +#if USE_LEAN_ATTENTION + } else if (use_lean_attention.has_value() && use_lean_attention.value()) { + sstream << "LEAN_ATTENTION"; +#endif } else if (use_efficient_attention.has_value() && use_efficient_attention.value()) { sstream << "EFFICIENT_ATTENTION"; } else if (use_trt_fused_attention.has_value() && use_trt_fused_attention.value()) { diff --git a/onnxruntime/contrib_ops/cuda/bert/attention_kernel_options.h b/onnxruntime/contrib_ops/cuda/bert/attention_kernel_options.h index a27fb199a6272..caed704564c3b 100644 --- a/onnxruntime/contrib_ops/cuda/bert/attention_kernel_options.h +++ b/onnxruntime/contrib_ops/cuda/bert/attention_kernel_options.h @@ -9,6 +9,7 @@ namespace onnxruntime { struct AttentionKernelDebugInfo { std::optional use_flash_attention = std::nullopt; + std::optional use_lean_attention = std::nullopt; std::optional use_efficient_attention = std::nullopt; std::optional use_trt_fused_attention = std::nullopt; std::optional use_cudnn_flash_attention = std::nullopt; @@ -24,6 +25,7 @@ class AttentionKernelOptions { void InitializeOnce(int sdpa_kernel, bool use_build_flag, bool check_cudnn_version = false); bool UseFlashAttention() const { return use_flash_attention_; } + bool UseLeanAttention() const { return use_lean_attention_; } bool UseEfficientAttention() const { return use_efficient_attention_; } bool UseTrtFusedAttention() const { return use_trt_fused_attention_; } bool UseCudnnFlashAttention() const { return use_cudnn_flash_attention_; } @@ -44,6 +46,7 @@ class AttentionKernelOptions { private: bool use_flash_attention_{true}; + bool use_lean_attention_{false}; bool use_efficient_attention_{true}; bool use_trt_fused_attention_{true}; bool use_cudnn_flash_attention_{false}; diff --git a/onnxruntime/contrib_ops/cuda/bert/attention_prepare_qkv.cu b/onnxruntime/contrib_ops/cuda/bert/attention_prepare_qkv.cu index a079076f2881b..c8c0191967d40 100644 --- a/onnxruntime/contrib_ops/cuda/bert/attention_prepare_qkv.cu +++ b/onnxruntime/contrib_ops/cuda/bert/attention_prepare_qkv.cu @@ -384,6 +384,7 @@ Status PrepareQkv_MHA_WithPast_NoBias(contrib::AttentionParameters& parameters, if (data.use_memory_efficient_attention || data.use_flash_attention || + data.use_lean_attention || data.kernel_type == AttentionKernelType::AttentionKernel_CudnnFlashAttention) { // Use oiginal Query (BSNH) since there is no bias. data.q = const_cast(data.query); diff --git a/onnxruntime/contrib_ops/cuda/bert/lean_attention/block_info.h b/onnxruntime/contrib_ops/cuda/bert/lean_attention/block_info.h new file mode 100644 index 0000000000000..6d9ed824b4b76 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/bert/lean_attention/block_info.h @@ -0,0 +1,45 @@ +/****************************************************************************** + * Copyright (c) 2023, Tri Dao. + ******************************************************************************/ + +#pragma once + +namespace onnxruntime { +namespace lean { + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct BlockInfo { + template + __device__ BlockInfo(const Params& params, const int bidb) + : sum_s_q(!Varlen || params.cu_seqlens_q == nullptr ? -1 : params.cu_seqlens_q[bidb]), sum_s_k(!Varlen || params.cu_seqlens_k == nullptr || !params.is_seqlens_k_cumulative ? -1 : params.cu_seqlens_k[bidb]), actual_seqlen_q(!Varlen || params.cu_seqlens_q == nullptr ? params.seqlen_q : params.cu_seqlens_q[bidb + 1] - sum_s_q) + // If is_seqlens_k_cumulative, then seqlen_k is cu_seqlens_k[bidb + 1] - cu_seqlens_k[bidb]. + // Otherwise it's cu_seqlens_k[bidb], i.e., we use cu_seqlens_k to store the sequence lengths of K. + , + seqlen_k_cache(!Varlen || params.cu_seqlens_k == nullptr ? params.seqlen_k : (params.is_seqlens_k_cumulative ? params.cu_seqlens_k[bidb + 1] - sum_s_k : params.cu_seqlens_k[bidb])), + actual_seqlen_k(params.seqused_k ? params.seqused_k[bidb] : seqlen_k_cache + (params.knew_ptr == nullptr ? 0 : params.seqlen_knew)) { + } + + template + __forceinline__ __device__ index_t q_offset(const index_t batch_stride, const index_t row_stride, const int bidb) const { + return sum_s_q == -1 ? bidb * batch_stride : uint32_t(sum_s_q) * row_stride; + } + + template + __forceinline__ __device__ index_t k_offset(const index_t batch_stride, const index_t row_stride, const int bidb) const { + return sum_s_k == -1 ? bidb * batch_stride : uint32_t(sum_s_k) * row_stride; + } + + const int sum_s_q; + const int sum_s_k; + const int actual_seqlen_q; + // We have to have seqlen_k_cache declared before actual_seqlen_k, otherwise actual_seqlen_k is set to 0. + const int seqlen_k_cache; + const int actual_seqlen_k; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace lean +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/bert/lean_attention/flash.h b/onnxruntime/contrib_ops/cuda/bert/lean_attention/flash.h new file mode 100644 index 0000000000000..a2058d8805ebd --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/bert/lean_attention/flash.h @@ -0,0 +1,148 @@ +/****************************************************************************** + * Copyright (c) 2023, Tri Dao. + ******************************************************************************/ + +#pragma once + +#include +#include + +namespace onnxruntime { +namespace lean { + +constexpr int TOTAL_DIM = 0; +constexpr int H_DIM = 1; +constexpr int D_DIM = 2; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +struct Qkv_params { + using index_t = int64_t; + // The QKV matrices. + void* __restrict__ q_ptr; + void* __restrict__ k_ptr; + void* __restrict__ v_ptr; + + // The stride between rows of the Q, K and V matrices. + index_t q_batch_stride; + index_t k_batch_stride; + index_t v_batch_stride; + index_t q_row_stride; + index_t k_row_stride; + index_t v_row_stride; + index_t q_head_stride; + index_t k_head_stride; + index_t v_head_stride; + + // The number of heads. + int h, h_k; + // In the case of multi-query and grouped-query attention (MQA/GQA), nheads_k could be + // different from nheads (query). + int h_h_k_ratio; // precompute h / h_k, +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +struct Flash_fwd_params : public Qkv_params { + // The O matrix (output). + void* __restrict__ o_ptr; + void* __restrict__ oaccum_ptr; + + // The stride between rows of O. + index_t o_batch_stride; + index_t o_row_stride; + index_t o_head_stride; + + // The pointer to the P matrix. + void* __restrict__ p_ptr; + + // The pointer to the softmax sum. + void* __restrict__ softmax_lse_ptr; + void* __restrict__ softmax_lseaccum_ptr; + + // The dimensions. + int b, seqlen_q, seqlen_k, seqlen_knew, d, seqlen_q_rounded, seqlen_k_rounded, d_rounded, rotary_dim; + + // The scaling factors for the kernel. + float scale_softmax; + float scale_softmax_log2; + + // array of length b+1 holding starting offset of each sequence. + int* __restrict__ cu_seqlens_q; + int* __restrict__ cu_seqlens_k; + + // If provided, the actual length of each k sequence. + int* __restrict__ seqused_k; + + int* __restrict__ blockmask; + + // The K_new and V_new matrices. + void* __restrict__ knew_ptr; + void* __restrict__ vnew_ptr; + + // The stride between rows of the Q, K and V matrices. + index_t knew_batch_stride; + index_t vnew_batch_stride; + index_t knew_row_stride; + index_t vnew_row_stride; + index_t knew_head_stride; + index_t vnew_head_stride; + + // The cos and sin matrices for rotary embedding. + void* __restrict__ rotary_cos_ptr; + void* __restrict__ rotary_sin_ptr; + + // The indices to index into the KV cache. + int* __restrict__ cache_batch_idx; + + // Paged KV cache + int* __restrict__ block_table; + index_t block_table_batch_stride; + int page_block_size; + + // The dropout probability (probability of keeping an activation). + float p_dropout; + // uint32_t p_dropout_in_uint; + // uint16_t p_dropout_in_uint16_t; + uint8_t p_dropout_in_uint8_t; + + // Scale factor of 1 / (1 - p_dropout). + float rp_dropout; + float scale_softmax_rp_dropout; + + // Local window size + int window_size_left, window_size_right; + + // Pointer to the RNG seed (idx 0) and offset (idx 1). + uint64_t* rng_state; + + bool is_bf16; + bool is_causal; + + // If is_seqlens_k_cumulative, then seqlen_k is cu_seqlens_k[bidb + 1] - cu_seqlens_k[bidb]. + // Otherwise it's cu_seqlens_k[bidb], i.e., we use cu_seqlens_k to store the sequence lengths of K. + bool is_seqlens_k_cumulative; + + bool is_rotary_interleaved; + + int num_splits; // For split-KV version and lean + + void* __restrict__ alibi_slopes_ptr; + index_t alibi_slopes_batch_stride; + + // LEAN Additional Params + int lean_griddimz; + int tiles_per_head; + int max_tiles_per_tb; + int high_load_tbs; + void* __restrict__ sync_flag; + + const cudaDeviceProp* dprops = nullptr; +}; +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +void run_mha_fwd_lean_dispatch(Flash_fwd_params& params, cudaStream_t stream); + +} // namespace lean +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/bert/lean_attention/kernel_traits.h b/onnxruntime/contrib_ops/cuda/bert/lean_attention/kernel_traits.h new file mode 100644 index 0000000000000..85be5d3e031ac --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/bert/lean_attention/kernel_traits.h @@ -0,0 +1,315 @@ +/****************************************************************************** + * Copyright (c) 2024, Tri Dao. + ******************************************************************************/ + +#pragma once + +#include "cute/algorithm/copy.hpp" + +#include "cutlass/cutlass.h" +#include "cutlass/layout/layout.h" +#include + +using namespace cute; + +template +struct Flash_kernel_traits { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 + using Element = elem_type; + static constexpr bool Has_cp_async = true; +#else + using Element = cutlass::half_t; + static constexpr bool Has_cp_async = false; +#endif + + using ElementAccum = float; + using index_t = int64_t; + +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 + using MMA_Atom_Arch = std::conditional_t< + std::is_same_v, + MMA_Atom, + MMA_Atom>; +#else + using MMA_Atom_Arch = MMA_Atom; +#endif + +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 750 + using SmemCopyAtom = Copy_Atom; + using SmemCopyAtomTransposed = Copy_Atom; +#else + using SmemCopyAtom = Copy_Atom; + using SmemCopyAtomTransposed = Copy_Atom; +#endif +}; + +// If Share_Q_K_smem is true, that forces Is_Q_in_regs to be true +template > +struct Flash_fwd_kernel_traits : public Base { + using Element = typename Base::Element; + using ElementAccum = typename Base::ElementAccum; + using index_t = typename Base::index_t; + static constexpr bool Has_cp_async = Base::Has_cp_async; + using SmemCopyAtom = typename Base::SmemCopyAtom; + using SmemCopyAtomTransposed = typename Base::SmemCopyAtomTransposed; + + static constexpr bool Share_Q_K_smem = Share_Q_K_smem_; + static constexpr bool Is_Q_in_regs = Is_Q_in_regs_ || Share_Q_K_smem; + + // The number of threads. + static constexpr int kNWarps = kNWarps_; + static constexpr int kNThreads = kNWarps * 32; + + static constexpr int kBlockM = kBlockM_; + static constexpr int kBlockN = kBlockN_; + static constexpr int kHeadDim = kHeadDim_; + static_assert(kHeadDim % 32 == 0); + static constexpr int kBlockKSmem = kHeadDim % 64 == 0 ? 64 : 32; + static constexpr int kBlockKGmem = kHeadDim % 128 == 0 ? 128 : (kHeadDim % 64 == 0 ? 64 : 32); + static constexpr int kSwizzle = kBlockKSmem == 32 ? 2 : 3; + + using TiledMma = TiledMMA< + typename Base::MMA_Atom_Arch, + Layout, _1, _1>>, // 4x1x1 or 8x1x1 thread group + Tile, _16, _16>>; + + using SmemLayoutAtomQ = decltype(composition(Swizzle{}, + // This has to be kBlockKSmem, using kHeadDim gives wrong results for d=128 + Layout>, + Stride, _1>>{})); + using SmemLayoutQ = decltype(tile_to_shape( + SmemLayoutAtomQ{}, + Shape, Int>{})); + + using SmemLayoutKV = decltype(tile_to_shape( + SmemLayoutAtomQ{}, + Shape, Int>{})); + + // https://github.com/ColfaxResearch/cutlass-kernels/blob/a222587e6d59b93ba704853d3946fb686d8b8892/src/fmha/fmha_forward.cu#L434 + using SmemLayoutVtransposed = decltype(composition(SmemLayoutKV{}, make_layout(Shape, Int>{}, GenRowMajor{}))); + using SmemLayoutVtransposedNoSwizzle = decltype(get_nonswizzle_portion(SmemLayoutVtransposed{})); + + using SmemLayoutAtomO = decltype(composition(Swizzle{}, + Layout, Int>, + Stride, _1>>{})); + using SmemLayoutO = decltype(tile_to_shape( + SmemLayoutAtomO{}, + Shape, Int>{})); + using SmemCopyAtomO = Copy_Atom; + using SmemCopyAtomOaccum = Copy_Atom; + + static constexpr int kSmemQSize = size(SmemLayoutQ{}) * sizeof(Element); + static constexpr int kSmemKVSize = size(SmemLayoutKV{}) * 2 * sizeof(Element); + static constexpr int kSmemOSize = size(SmemLayoutO{}) * sizeof(ElementAccum); + // static constexpr int kSmemSize = Share_Q_K_smem ? std::max(kSmemQSize, kSmemKVSize) : kSmemQSize + kSmemKVSize + kSmemOSize; + static constexpr int kSmemSize = Share_Q_K_smem ? std::max(kSmemQSize, kSmemKVSize) : kSmemQSize + kSmemKVSize; + + static constexpr int kGmemElemsPerLoad = sizeof(cute::uint128_t) / sizeof(Element); + static_assert(kHeadDim % kGmemElemsPerLoad == 0, "kHeadDim must be a multiple of kGmemElemsPerLoad"); + // Using kBlockKSmem here is 6-10% faster than kBlockKGmem for d=128 because of bank conflicts. + // For example, for d=128, smem is split into 2 "pages", each page takes care of columns + // 0-63 and 64-127. If we have 16 threads per row for gmem read, when we write to smem, + // thread 0 - 7 will write to the first page and thread 8 - 15 will write to the second page, + // to the same banks. + static constexpr int kGmemThreadsPerRow = kBlockKSmem / kGmemElemsPerLoad; + static_assert(kNThreads % kGmemThreadsPerRow == 0, "kNThreads must be a multiple of kGmemThreadsPerRow"); + using GmemLayoutAtom = Layout, Int>, + Stride, _1>>; + + // We use CACHEGLOBAL instead of CACHEALWAYS for both Q and K/V, since we won't be reading + // from the same address by the same threadblock. This is slightly faster. + using Gmem_copy_struct = std::conditional_t< + Has_cp_async, + SM80_CP_ASYNC_CACHEGLOBAL, + DefaultCopy>; + using GmemTiledCopyQKV = decltype(make_tiled_copy(Copy_Atom{}, + GmemLayoutAtom{}, + Layout>{})); // Val layout, 8 vals per read + using GmemTiledCopyO = decltype(make_tiled_copy(Copy_Atom{}, + GmemLayoutAtom{}, + Layout>{})); // Val layout, 8 vals per store + + using GmemLayoutAtomOaccum = std::conditional_t< + kBlockKSmem == 32, + Layout, // Thread layout, 8 threads per row + Stride<_8, _1>>, + Layout, // Thread layout, 16 threads per row + Stride<_16, _1>>>; + using GmemTiledCopyOaccum = decltype(make_tiled_copy(Copy_Atom{}, + GmemLayoutAtomOaccum{}, + Layout>{})); // Val layout, 4 vals per store + using GmemLayoutAtomRotcossin = GmemLayoutAtom; + using GmemTiledCopyRotcossin = decltype(make_tiled_copy(Copy_Atom, Element>{}, + GmemLayoutAtomRotcossin{}, + Layout>{})); // Val layout, 4 vals per load + using GmemTiledCopyRotcossinCont = decltype(make_tiled_copy(Copy_Atom{}, + GmemLayoutAtomRotcossin{}, + Layout>{})); // Val layout, 8 vals per load +}; + +// Is_V_in_regs is an option to reduce smem usage, but will increase register pressue. +// No_double_buffer is another option to reduce smem usage, but will slow things down. +template > +struct Flash_bwd_kernel_traits : public Base { + using Element = typename Base::Element; + using ElementAccum = typename Base::ElementAccum; + using index_t = typename Base::index_t; + static constexpr bool Has_cp_async = Base::Has_cp_async; + using SmemCopyAtom = typename Base::SmemCopyAtom; + using SmemCopyAtomTransposed = typename Base::SmemCopyAtomTransposed; + + static constexpr bool Is_V_in_regs = Is_V_in_regs_; + static constexpr bool No_double_buffer = No_double_buffer_; + + // The number of threads. + static constexpr int kNWarps = kNWarps_; + static constexpr int kNThreads = kNWarps * 32; + + static constexpr int kBlockM = kBlockM_; + static constexpr int kBlockN = kBlockN_; + static constexpr int kHeadDim = kHeadDim_; + static_assert(kHeadDim % 32 == 0); + static constexpr int kBlockKSmem = kHeadDim % 64 == 0 ? 64 : 32; + static constexpr int kBlockKGmem = kHeadDim % 128 == 0 ? 128 : (kHeadDim % 64 == 0 ? 64 : 32); + static constexpr int kSwizzle = kBlockKSmem == 32 ? 2 : 3; + + static constexpr int AtomLayoutMSdP = AtomLayoutMSdP_; + static_assert(kNWarps % AtomLayoutMSdP == 0); + static_assert(kNWarps % AtomLayoutNdKV == 0); + static_assert(kNWarps % AtomLayoutMdQ == 0); + + using TiledMmaSdP = TiledMMA< + typename Base::MMA_Atom_Arch, + Layout, Int, _1>>, + Tile, Int<16 * kNWarps / AtomLayoutMSdP>, _16>>; + + using TiledMmadKV = TiledMMA< + typename Base::MMA_Atom_Arch, + Layout, Int, _1>>, + Tile, Int<16 * kNWarps / AtomLayoutNdKV>, _16>>; + + using TiledMmadQ = TiledMMA< + typename Base::MMA_Atom_Arch, + Layout, Int, _1>>, // 2x4x1 or 4x2x1 thread group + Tile, Int<16 * kNWarps / AtomLayoutMdQ>, _16>>; + + using SmemLayoutAtomQdO = decltype(composition(Swizzle{}, + Layout>, + Stride, _1>>{})); + using SmemLayoutQdO = decltype(tile_to_shape( + SmemLayoutAtomQdO{}, + make_shape(Int{}, Int{}))); + + using SmemLayoutAtomKV = decltype(composition(Swizzle{}, + Layout, Int>, + Stride, _1>>{})); + using SmemLayoutKV = decltype(tile_to_shape( + // SmemLayoutAtomQdO{}, + SmemLayoutAtomKV{}, + make_shape(Int{}, Int{}))); + + using SmemLayoutKtransposed = decltype(composition(SmemLayoutKV{}, make_layout(Shape, Int>{}, GenRowMajor{}))); + using SmemLayoutKtransposedNoSwizzle = decltype(get_nonswizzle_portion(SmemLayoutKtransposed{})); + + // TODO: generalize to other values of kBlockN + // TODO: what should be the Swizzle here? 3 is faster than 1, and 1 is faster than 2 + // static constexpr int kPBlockN = kBlockN; + // Temporarily disabling this for hdim 256 on sm86 and sm89 + // static_assert(kBlockN >= 64); + static_assert(kBlockN >= 32); + // TD [2023-03-19]: Idk why kPBlockN = 16 and kSwizzlePdS=3 is the fastest. + static constexpr int kPBlockN = kBlockN >= 64 ? 64 : 32; + static_assert(kPBlockN == 16 || kPBlockN == 32 || kPBlockN == 64); + // static constexpr int kSwizzlePdS = kPBlockN == 16 ? 1 : (kPBlockN == 32 ? 2 : 3); + static constexpr int kSwizzlePdS = 3; + using SmemLayoutAtomPdS = decltype(composition(Swizzle{}, + Layout, Int>, + Stride, _1>>{})); + using SmemLayoutPdS = decltype(tile_to_shape( + SmemLayoutAtomPdS{}, + make_shape(Int{}, Int{}))); + using SmemLayoutPdStransposed = decltype(composition(SmemLayoutPdS{}, make_layout(Shape, Int>{}, GenRowMajor{}))); + using SmemLayoutPdStransposedNoSwizzle = decltype(get_nonswizzle_portion(SmemLayoutPdStransposed{})); + + using SmemCopyAtomPdS = Copy_Atom; + + using SmemLayoutQdOtransposed = decltype(composition(SmemLayoutQdO{}, make_layout(Shape, Int>{}, GenRowMajor{}))); + using SmemLayoutQdOtransposedNoSwizzle = decltype(get_nonswizzle_portion(SmemLayoutQdOtransposed{})); + + using SmemLayoutAtomdKV = decltype(composition(Swizzle{}, + Layout>, + Stride, _1>>{})); + using SmemLayoutdKV = decltype(tile_to_shape( + SmemLayoutAtomdKV{}, + make_shape(Int{}, Int{}))); + using SmemCopyAtomdKV = Copy_Atom; + + using SmemLayoutAtomdQ = decltype(composition(Swizzle{}, + Layout>, + Stride, _1>>{})); + using SmemLayoutdQ = decltype(tile_to_shape( + SmemLayoutAtomdQ{}, + make_shape(Int{}, Int{}))); + using SmemCopyAtomdQ = Copy_Atom; + + // Double buffer for sQ + static constexpr int kSmemQdOSize = size(SmemLayoutQdO{}) * (No_double_buffer ? 2 : 3) * sizeof(Element); + static constexpr int kSmemKVSize = size(SmemLayoutKV{}) * 2 * sizeof(Element); + static constexpr int kSmemdSSize = size(SmemLayoutPdS{}) * sizeof(Element); + static constexpr int kSmemPSize = size(SmemLayoutPdS{}) * sizeof(Element); + static constexpr int kSmemdQSize = size(SmemLayoutdQ{}) * sizeof(Element); + static constexpr int kSmemSize = kSmemQdOSize + (!Is_V_in_regs + ? kSmemKVSize + kSmemdSSize + std::max(kSmemPSize, kSmemdQSize) + : std::max(kSmemKVSize, kSmemKVSize / 2 + kSmemdSSize + std::max(kSmemPSize, kSmemdQSize))); + static constexpr int kSmemSize1colblock = kSmemQdOSize + (!Is_V_in_regs + ? kSmemKVSize + kSmemdSSize + kSmemPSize + : std::max(kSmemKVSize, kSmemKVSize / 2 + kSmemdSSize + kSmemPSize)); + + static constexpr int kGmemElemsPerLoad = sizeof(cute::uint128_t) / sizeof(Element); + static_assert(kHeadDim % kGmemElemsPerLoad == 0, "kHeadDim must be a multiple of kGmemElemsPerLoad"); + // Using kBlockKSmem instead of kHeadDim here to avoid bank conflicts, but doesn't seem + // to affect speed in practice. + static constexpr int kGmemThreadsPerRow = kBlockKSmem / kGmemElemsPerLoad; + static_assert(kNThreads % kGmemThreadsPerRow == 0, "kNThreads must be a multiple of kGmemThreadsPerRow"); + using GmemLayoutAtom = Layout, Int>, + Stride, _1>>; + + // We use CACHEGLOBAL instead of CACHEALWAYS for both Q and K/V, since we won't be reading + // from the same address by the same threadblock. This is slightly faster. + using Gmem_copy_struct = std::conditional_t< + Has_cp_async, + SM80_CP_ASYNC_CACHEGLOBAL, + DefaultCopy>; + using GmemTiledCopyQKV = decltype(make_tiled_copy(Copy_Atom{}, + GmemLayoutAtom{}, + Layout>{})); // Val layout, 8 vals per read + using GmemTiledCopydO = decltype(make_tiled_copy(Copy_Atom{}, + GmemLayoutAtom{}, + Layout>{})); // Val layout, 8 vals per store + using GmemTiledCopydKV = decltype(make_tiled_copy(Copy_Atom{}, + GmemLayoutAtom{}, + Layout>{})); // Val layout, 8 vals per store + using GmemTiledCopydQ = decltype(make_tiled_copy(Copy_Atom{}, + GmemLayoutAtom{}, + Layout>{})); // Val layout, 8 vals per store + using GmemLayoutAtomdQaccum = std::conditional_t< + kBlockKSmem == 32, + Layout, // Thread layout, 8 threads per row + Stride<_8, _1>>, + Layout, // Thread layout, 16 threads per row + Stride<_16, _1>>>; + using GmemTiledCopydQaccum = decltype(make_tiled_copy(Copy_Atom{}, + GmemLayoutAtomdQaccum{}, + Layout>{})); // Val layout, 4 vals per store + + using GmemTiledCopydQaccumAtomicAdd = decltype(make_tiled_copy(Copy_Atom{}, + Layout, // Thread layout, 8 threads per row + Stride<_32, _1>>{}, + Layout>{})); // Val layout, 1 val per store +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/onnxruntime/contrib_ops/cuda/bert/lean_attention/lean_api.cc b/onnxruntime/contrib_ops/cuda/bert/lean_attention/lean_api.cc new file mode 100644 index 0000000000000..81301ebc7ba64 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/bert/lean_attention/lean_api.cc @@ -0,0 +1,453 @@ +/****************************************************************************** + * Copyright (c) 2023, Tri Dao. + ******************************************************************************/ + +// Modifications: support lean attention. +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#if USE_LEAN_ATTENTION + +#include "contrib_ops/cuda/bert/lean_attention/lean_api.h" +#include + +#include "contrib_ops/cuda/bert/lean_attention/flash.h" +#include "contrib_ops/cuda/bert/lean_attention/static_switch.h" + +namespace onnxruntime { +namespace lean { + +#define CHECK_DEVICE(x) TORCH_CHECK(x.is_cuda(), #x " must be on CUDA") +#define CHECK_SHAPE(x, ...) TORCH_CHECK(x.sizes() == torch::IntArrayRef({__VA_ARGS__}), #x " must have shape (" #__VA_ARGS__ ")") +#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") + +void set_params_fprop(Flash_fwd_params& params, + // sizes + size_t batch_size, + size_t seqlen_q, + size_t seqlen_k, + size_t seqlen_q_rounded, + size_t seqlen_k_rounded, + size_t num_heads, + size_t num_heads_k, + size_t head_size, + size_t head_size_rounded, + // device pointers + void* q, + void* k, + void* v, + void* out, + void* cu_seqlens_q_d, + void* cu_seqlens_k_d, + void* seqused_k, + void* p_d, + void* softmax_lse_d, + float softmax_scale, + bool is_causal, + bool is_bf16, + bool kv_bsnh = true, + int window_size_left = -1, + int window_size_right = -1) { + // Set the pointers and strides. + params.q_ptr = q; + params.k_ptr = k; + params.v_ptr = v; + params.o_ptr = out; + + params.is_bf16 = is_bf16; + + // All stride are in elements, not bytes. + if (kv_bsnh) { + params.q_row_stride = num_heads * head_size; + params.k_row_stride = num_heads_k * head_size; + params.v_row_stride = num_heads_k * head_size; + params.q_head_stride = head_size; + params.k_head_stride = head_size; + params.v_head_stride = head_size; + params.o_row_stride = num_heads * head_size; + params.o_head_stride = head_size; + } else { + params.q_row_stride = num_heads * head_size; + params.k_row_stride = head_size; + params.v_row_stride = head_size; + params.q_head_stride = head_size; + params.k_head_stride = seqlen_k * head_size; + params.v_head_stride = seqlen_k * head_size; + params.o_row_stride = num_heads * head_size; + params.o_head_stride = head_size; + } + + if (cu_seqlens_q_d == nullptr) { + params.q_batch_stride = seqlen_q * num_heads * head_size; // stride(0) + params.k_batch_stride = seqlen_k * num_heads_k * head_size; // stride(0) + params.v_batch_stride = seqlen_k * num_heads_k * head_size; // stride(0) + params.o_batch_stride = seqlen_q * num_heads * head_size; // stride(0) + } else { + params.q_batch_stride = 0; + params.k_batch_stride = 0; + params.v_batch_stride = 0; + params.o_batch_stride = 0; + } + + params.cu_seqlens_q = static_cast(cu_seqlens_q_d); + params.cu_seqlens_k = static_cast(cu_seqlens_k_d); + params.seqused_k = static_cast(seqused_k); + + // P = softmax(QK^T) + params.p_ptr = p_d; + + // Softmax sum + params.softmax_lse_ptr = softmax_lse_d; + + // Set the dimensions. +#if defined(_MSC_VER) +#pragma warning(push) +#pragma warning(disable : 4267) // Ignore conversion from 'size_t' to 'int', possible loss of data +#pragma warning(disable : 4244) // Ignore conversion from 'double' to 'float', possible loss of data +#endif + params.b = batch_size; + params.h = num_heads; + params.h_k = num_heads_k; + params.h_h_k_ratio = num_heads / num_heads_k; + params.seqlen_q = seqlen_q; + params.seqlen_k = seqlen_k; + params.seqlen_q_rounded = seqlen_q_rounded; + params.seqlen_k_rounded = seqlen_k_rounded; + params.d = head_size; + params.d_rounded = head_size_rounded; + + // Set the different scale values. + params.scale_softmax = softmax_scale; + params.scale_softmax_log2 = softmax_scale * M_LOG2E; + + // In our API, causal/unidirectional determines if we only look at prior tokens. However, the flash API separates + // local and causal, meaning when we have local window size + params.is_causal = is_causal; + if (is_causal && (window_size_left >= 0 || window_size_right != 0)) { + params.is_causal = false; + } + if (window_size_left < 0 && window_size_right >= 0) { + window_size_left = seqlen_k; + } + if (window_size_left >= 0 && window_size_right < 0) { + window_size_right = seqlen_k; + } +#if defined(_MSC_VER) +#pragma warning(pop) +#endif + params.window_size_left = window_size_left; + params.window_size_right = window_size_right; + + params.is_seqlens_k_cumulative = true; +} + +size_t get_softmax_lse_size(size_t seqlen, size_t batch_size, size_t num_heads) { + size_t bytes = sizeof(float) * batch_size * num_heads * seqlen; + return bytes; +} + +size_t get_softmax_lse_accum_size(size_t num_splits, size_t batch_size, size_t num_heads, size_t seqlen_q) { + size_t bytes = sizeof(float) * num_splits * batch_size * seqlen_q * num_heads; + return bytes; +} + +size_t get_out_accum_size(size_t num_splits, size_t batch_size, size_t num_heads, + size_t seqlen_q, size_t head_size_rounded) { + size_t bytes = sizeof(float) * num_splits * batch_size * seqlen_q * num_heads * head_size_rounded; + return bytes; +} + +size_t get_sync_flag_size(size_t num_m_blocks, size_t batch_size, size_t num_heads) { + size_t bytes = sizeof(int) * batch_size * num_heads * num_m_blocks; + return bytes; +} + +void run_mha_fwd(Flash_fwd_params& params, cudaStream_t stream) { + FP16_SWITCH(!params.is_bf16, [&] { + HEADDIM_SWITCH(params.d, [&] { + run_mha_fwd_lean_dispatch(params, stream); + }); + }); +} + +std::tuple get_num_splits_and_buffer_sizes(size_t batch_size, size_t max_seqlen_q, size_t max_seqlen_k, + size_t num_heads, size_t num_heads_k, size_t head_size, size_t num_SMs, bool is_causal) { + // This needs to match with run_mha_fwd_splitkv_dispatch + const int block_n = head_size <= 64 ? 256 : (head_size <= 128 ? 128 : 64); + const int block_m = head_size <= 64 ? 64 : (head_size <= 128 ? 64 : 64); + const int num_m_blocks = (max_seqlen_q + block_m - 1) / block_m; + const int num_n_blocks = (max_seqlen_k + block_n - 1) / block_n; + if (max_seqlen_q == 1) { + is_causal = false; + } + + max_seqlen_q = max_seqlen_q * num_heads / num_heads_k; + +#if defined(DEBUG_LEAN_ATTENTION) + printf("block_n: %d\n", block_n); + printf("block_m: %d\n", block_m); + printf("num_m_blocks: %d\n", num_m_blocks); + printf("num_n_blocks: %d\n", num_n_blocks); + printf("max_seqlen_q: %lu\n", max_seqlen_q); + printf("max_seqlen_k: %lu\n", max_seqlen_k); + printf("is_causal: %d\n", is_causal); + printf("num_heads: %lu\n", num_heads); + printf("num_heads_k: %lu\n", num_heads_k); +#endif + + size_t tiles_per_head = 0; + if (is_causal) { + // Prefill - Causal + for (int i = 0; i < num_m_blocks; i++) { + tiles_per_head += (((i + 1) * block_m) + block_n - 1) / block_n; + } + } else { + // Decode or Not Causal + // Tiles per head is the number of blocks in the first block + tiles_per_head = num_m_blocks * num_n_blocks; + } + size_t total_tiles = tiles_per_head * batch_size * num_heads_k; + + // StreamK Lean has as many threadblocks as SMs + // This should be a function of tile size and number of scratchpad space + + // We want at least two tiles per CTA to be efficient + // And then 2 CTAs per SM + size_t lean_griddimz = num_SMs * 2; + if (total_tiles <= 2 * 2 * num_SMs) { + lean_griddimz = std::min((total_tiles + 1) / 2, (32 * total_tiles + num_n_blocks - 1) / num_n_blocks); + // params.lean_griddimz = num_m_blocks * batch_size * num_heads; + } else { + // Max split of 64 per block is allowed, so we conservatively set it to 32 + // to account for ceil + lean_griddimz = std::min(2 * num_SMs, 32 * num_heads_k * batch_size * num_m_blocks); + } + size_t max_tiles_per_tb = (total_tiles + lean_griddimz - 1) / lean_griddimz; + // Find max number of splits + size_t num_splits = 0; + if (total_tiles % lean_griddimz == 0) { + num_splits = 1 + ((num_n_blocks + max_tiles_per_tb - 2) / (max_tiles_per_tb)); + } else { + num_splits = 1 + ((num_n_blocks + max_tiles_per_tb - 3) / (max_tiles_per_tb - 1)); + } + size_t high_load_tbs = total_tiles - ((max_tiles_per_tb - 1) * lean_griddimz); + +#if defined(DEBUG_LEAN_ATTENTION) + printf("Causal: %d params.tiles_per_head : %lu\n", is_causal, tiles_per_head); + printf("num_splits = %lu\n", num_splits); + printf("total_tiles = %lu\n", total_tiles); + printf("lean_griddimz = %lu\n", lean_griddimz); + printf("max_tiles_per_tb = %lu\n", max_tiles_per_tb); + printf("high_load_tbs = %lu\n", high_load_tbs); +#endif + + if (num_splits > 1) { + size_t softmax_lse_accum_bytes = get_softmax_lse_accum_size(num_splits, batch_size, num_heads_k, max_seqlen_q); + auto round_multiple = [](size_t x, size_t m) { return (x + m - 1) / m * m; }; + const size_t head_size_rounded = round_multiple(head_size, 32); + size_t out_accum_bytes = get_out_accum_size(num_splits, batch_size, num_heads_k, max_seqlen_q, head_size_rounded); + size_t sync_flag_bytes = get_sync_flag_size(num_m_blocks, batch_size, num_heads_k); + return {num_splits, softmax_lse_accum_bytes, out_accum_bytes, sync_flag_bytes, lean_griddimz, max_tiles_per_tb, high_load_tbs, tiles_per_head}; + } else { + return {0, 0, 0, 0, lean_griddimz, max_tiles_per_tb, high_load_tbs, tiles_per_head}; + } +} + +bool is_supported(const cudaDeviceProp& dprops, size_t head_size, size_t num_heads, size_t num_heads_k) { + bool is_sm8x = dprops.major == 8 && dprops.minor >= 0; + bool is_sm90 = dprops.major == 9 && dprops.minor == 0; + return (is_sm8x || is_sm90) && (head_size == 64 || head_size == 128) && (num_heads % num_heads_k == 0); +} + +// This API is used when past key and value are present... since cached, these are assumed to have sequence length +// of max_sequence_length, so seqlen_k == max_sequence_length. The actual past sequence length is held in seqlens_k_. +Status mha_fwd_kvcache(const cudaDeviceProp& dprops, + cudaStream_t stream, + void* q, // batch_size x seqlen_q x num_heads x head_size + void* kcache, // batch_size x seqlen_k_max x num_heads_k x head_size or batch_size x num_heads_k x seqlen_k_max x head_size + void* vcache, // batch_size x seqlen_k_max x num_heads_k x head_size or batch_size x num_heads_k x seqlen_k_max x head_size + void* k_new, // (optional) batch_size x seqlen_k_new x num_heads_k x head_size + void* v_new, // (optional) batch_size x seqlen_k_new x num_heads_k x head_size + void* out, // batch_size x seqlen_q x num_heads x head_size + void* softmax_lse, // batch_size x num_heads x seqlen_q + void* seqlens_k_, // batch_size + void* rotary_cos, // seqlen_ro x (rotary_dim / 2) + void* rotary_sin, // seqlen_ro x (rotary_dim / 2) + int* block_table, // batch_size x max_num_blocks_per_seq + int batch_size, + int num_heads, + int num_heads_k, + int head_size, + int seqlen_q, + int seqlen_k, + int seqlen_k_new, + int rotary_dim, + const float softmax_scale, + bool is_causal, + bool is_bf16, + bool past_bsnh, // otherwise bnsh + int num_splits, + int grid_dimz, + int max_tiles_per_tb, + int high_load_tbs, + int tiles_per_head, + void* softmax_lse_accum, // num_splits x batch_size x seqlen_q x num_heads + void* out_accum, // num_splits x batch_size x seqlen_q x num_heads x head_size_rounded + int* sync_flag, + int local_window_size, + bool is_rotary_interleaved, + bool is_packed_qkv, + int max_num_blocks_per_seq, + int page_block_size) { + auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; }; + const int head_size_rounded = round_multiple(head_size, 32); + const int seqlen_q_rounded = round_multiple(seqlen_q, 128); + const int seqlen_k_rounded = round_multiple(seqlen_k, 128); + const bool paged_KV = block_table != nullptr; + +#if defined(DEBUG_LEAN_ATTENTION) + printf( + "batch_size: %d num_heads %d num_heads_k %d head_size %d seqlen_q %d seqlen_k %d seqlen_k_new %d " + "softmax_scale %f is_causal %d is_bf16 %d past_bsnh %d num_splits %d grid_dimz %d max_tiles_per_tb %d " + "high_load_tbs %d tiles_per_head %d local_window_size %d is_rotary_interleaved %d is_packed_qkv %d " + "max_num_blocks_per_seq %d page_block_size %d\n", + batch_size, num_heads, num_heads_k, head_size, seqlen_q, seqlen_k, seqlen_k_new, + softmax_scale, is_causal, is_bf16, past_bsnh, num_splits, grid_dimz, max_tiles_per_tb, + high_load_tbs, tiles_per_head, local_window_size, is_rotary_interleaved, is_packed_qkv, + max_num_blocks_per_seq, page_block_size); +#endif + + // Lean attention treats decode as non-causal + if (seqlen_q == 1) { + is_causal = false; + } + + const int seqlenq_ngroups_swapped = seqlen_q == 1 && num_heads > num_heads_k && head_size % 8 == 0; + if (seqlenq_ngroups_swapped) { + const int ngroups = num_heads / num_heads_k; + seqlen_q = ngroups; + num_heads = num_heads_k; + } + + // In kv-cache case, seqlen_k_max as kv sequence length + Flash_fwd_params params; + set_params_fprop(params, + batch_size, + seqlen_q, seqlen_k, + seqlen_q_rounded, seqlen_k_rounded, + num_heads, num_heads_k, + head_size, head_size_rounded, + q, kcache, vcache, out, + /*cu_seqlens_q_d=*/nullptr, + /*cu_seqlens_k_d=*/nullptr, + /*seqused_k=*/nullptr, + /*p_ptr=*/nullptr, + softmax_lse, + softmax_scale, + is_causal, + is_bf16, + past_bsnh, + local_window_size, + is_causal ? 0 : -1); + params.dprops = &dprops; + + if (k_new != nullptr && v_new != nullptr) { + params.seqlen_knew = seqlen_k_new; + params.knew_ptr = k_new; + params.vnew_ptr = v_new; + // All stride are in elements, not bytes. + params.q_batch_stride = seqlen_q * num_heads * head_size; // stride(0) + params.k_batch_stride = seqlen_k * num_heads_k * head_size; // stride(0) + params.v_batch_stride = seqlen_k * num_heads_k * head_size; // stride(0) + params.o_batch_stride = seqlen_q * num_heads * head_size; // stride(0) + if (is_packed_qkv) { + params.q_batch_stride = (seqlen_q * num_heads * head_size) + (2 * seqlen_k_new * num_heads_k * head_size); + params.q_row_stride = (num_heads * head_size) + (2 * num_heads_k * head_size); + params.knew_batch_stride = (seqlen_q * num_heads * head_size) + (2 * seqlen_k_new * num_heads_k * head_size); + params.vnew_batch_stride = (seqlen_q * num_heads * head_size) + (2 * seqlen_k_new * num_heads_k * head_size); + params.knew_row_stride = (num_heads * head_size) + (2 * num_heads_k * head_size); + params.vnew_row_stride = (num_heads * head_size) + (2 * num_heads_k * head_size); + } else { + params.knew_batch_stride = seqlen_k_new * num_heads_k * head_size; + params.vnew_batch_stride = seqlen_k_new * num_heads_k * head_size; + params.knew_row_stride = num_heads_k * head_size; + params.vnew_row_stride = num_heads_k * head_size; + } + params.knew_head_stride = head_size; + params.vnew_head_stride = head_size; + } else { + params.seqlen_knew = 0; + params.knew_ptr = nullptr; + params.vnew_ptr = nullptr; + params.knew_batch_stride = 0; + params.vnew_batch_stride = 0; + params.knew_row_stride = 0; + params.vnew_row_stride = 0; + params.knew_head_stride = 0; + params.vnew_head_stride = 0; + } + + if (seqlenq_ngroups_swapped) { + if (is_packed_qkv) { + params.q_batch_stride = (seqlen_q * num_heads_k * head_size) + (2 * seqlen_k_new * num_heads_k * head_size); + } else { + params.q_batch_stride = seqlen_q * num_heads_k * head_size; + } + params.q_row_stride = head_size; + params.q_head_stride = seqlen_q * head_size; + params.o_row_stride = head_size; + params.o_head_stride = seqlen_q * head_size; + params.o_batch_stride = seqlen_q * num_heads_k * head_size; + } + + params.is_seqlens_k_cumulative = seqlens_k_ == nullptr; + if (seqlens_k_ != nullptr) { + params.cu_seqlens_k = static_cast(seqlens_k_); + } + + if (rotary_cos != nullptr) { + params.rotary_cos_ptr = rotary_cos; + params.rotary_sin_ptr = rotary_sin; + params.is_rotary_interleaved = is_rotary_interleaved; + params.rotary_dim = rotary_dim; + } + + params.num_splits = num_splits; + params.lean_griddimz = grid_dimz; + params.max_tiles_per_tb = max_tiles_per_tb; + params.high_load_tbs = high_load_tbs; + params.tiles_per_head = tiles_per_head; + if (params.num_splits > 1 && softmax_lse_accum != nullptr && out_accum != nullptr) { + params.softmax_lseaccum_ptr = softmax_lse_accum; + params.oaccum_ptr = out_accum; + params.sync_flag = sync_flag; + } else { + params.softmax_lseaccum_ptr = nullptr; + params.oaccum_ptr = nullptr; + } + + params.alibi_slopes_ptr = nullptr; + if (paged_KV) { + params.block_table = block_table; // TODO(aciddelgado): cast to int pointer + params.block_table_batch_stride = max_num_blocks_per_seq; + // params.num_blocks = num_blocks; + params.page_block_size = page_block_size; + params.k_batch_stride = page_block_size * num_heads_k * head_size; + params.v_batch_stride = page_block_size * num_heads_k * head_size; + } else { + params.block_table = nullptr; + params.block_table_batch_stride = 0; + // params.num_blocks = 0; + params.page_block_size = 1; + } + + // Only split kernel supports appending to KV cache + run_mha_fwd(params, stream); + + return Status::OK(); +} + +} // namespace lean +} // namespace onnxruntime + +#endif // USE_LEAN_ATTENTION diff --git a/onnxruntime/contrib_ops/cuda/bert/lean_attention/lean_api.h b/onnxruntime/contrib_ops/cuda/bert/lean_attention/lean_api.h new file mode 100644 index 0000000000000..3b9bd1c24f08c --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/bert/lean_attention/lean_api.h @@ -0,0 +1,64 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#if USE_LEAN_ATTENTION + +#include "core/providers/cuda/cuda_common.h" +#include + +namespace onnxruntime { +namespace lean { + +Status mha_fwd_kvcache(const cudaDeviceProp& dprops, + cudaStream_t stream, + void* q, // batch_size x seqlen_q x num_heads x head_size + void* kcache, // batch_size x seqlen_k x num_heads_k x head_size or batch_size x num_heads_k seqlen_k x x head_size + void* vcache, // batch_size x seqlen_k x num_heads_k x head_size or batch_size x num_heads_k seqlen_k x x head_size + void* k, // batch_size x seqlen_k_new x num_heads_k x head_size + void* v, // batch_size x seqlen_k_new x num_heads_k x head_size + void* out, // batch_size x seqlen_q x num_heads x head_size + void* softmax_lse, // batch_size x num_heads x seqlen_q + void* seqlens_k_, // batch_size + void* rotary_cos, // seqlen_ro x (rotary_dim / 2) + void* rotary_sin, // seqlen_ro x (rotary_dim / 2) + int* block_table, // batch_size x max_num_blocks_per_seq + int batch_size, + int num_heads, + int num_heads_k, + int head_size, + int seqlen_q, + int seqlen_k, + int seqlen_k_new, + int rotary_dim, + const float softmax_scale, + bool is_causal, + bool is_bf16, + bool past_bsnh, // otherwise bnsh + int num_splits = 0, + int grid_dimz = 0, + int max_tiles_per_tb = 0, + int high_load_tbs = 0, + int tiles_per_head = 0, + void* softmax_lse_accum = nullptr, // num_splits x batch_size x seqlen_q x num_heads + void* out_accum = nullptr, // num_splits x batch_size x seqlen_q x num_heads x head_size_rounded + int* sync_flag = nullptr, + int local_window_size = -1, + bool is_rotary_interleaved = false, + bool is_packed_qkv = false, + int max_num_blocks_per_seq = 0, + int page_block_size = 1); + +size_t get_softmax_lse_size(size_t max_seqlen_q, size_t batch_size, size_t num_heads); + +std::tuple +get_num_splits_and_buffer_sizes(size_t batch_size, size_t seqlen_q, size_t seqlen_k, size_t num_heads, + size_t num_heads_k, size_t head_size, size_t num_SMs, bool is_causal); + +bool is_supported(const cudaDeviceProp& dprops, size_t head_size, size_t num_heads, size_t num_heads_k); + +} // namespace lean +} // namespace onnxruntime + +#endif // USE_LEAN_ATTENTION diff --git a/onnxruntime/contrib_ops/cuda/bert/lean_attention/lean_fwd_hdim128_fp16.cu b/onnxruntime/contrib_ops/cuda/bert/lean_attention/lean_fwd_hdim128_fp16.cu new file mode 100644 index 0000000000000..cfcacbabb3cb9 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/bert/lean_attention/lean_fwd_hdim128_fp16.cu @@ -0,0 +1,15 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#if USE_LEAN_ATTENTION + +#include "contrib_ops/cuda/bert/lean_attention/lean_fwd_launch_template.h" + +namespace onnxruntime { +namespace lean { + +template void run_mha_fwd_lean_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); + +} // namespace flash +} // namespace onnxruntime +#endif diff --git a/onnxruntime/contrib_ops/cuda/bert/lean_attention/lean_fwd_hdim64_fp16.cu b/onnxruntime/contrib_ops/cuda/bert/lean_attention/lean_fwd_hdim64_fp16.cu new file mode 100644 index 0000000000000..44c870f6ab35b --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/bert/lean_attention/lean_fwd_hdim64_fp16.cu @@ -0,0 +1,15 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#if USE_LEAN_ATTENTION + +#include "contrib_ops/cuda/bert/lean_attention/lean_fwd_launch_template.h" + +namespace onnxruntime { +namespace lean { + +template void run_mha_fwd_lean_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); + +} // namespace flash +} // namespace onnxruntime +#endif diff --git a/onnxruntime/contrib_ops/cuda/bert/lean_attention/lean_fwd_kernel.h b/onnxruntime/contrib_ops/cuda/bert/lean_attention/lean_fwd_kernel.h new file mode 100644 index 0000000000000..5be69ea0af55c --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/bert/lean_attention/lean_fwd_kernel.h @@ -0,0 +1,1066 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include +#include + +#include +#include +#include +#include + +#include "contrib_ops/cuda/bert/lean_attention/block_info.h" +#include "contrib_ops/cuda/bert/lean_attention/kernel_traits.h" +#include "contrib_ops/cuda/bert/lean_attention/utils.h" +#include "contrib_ops/cuda/bert/lean_attention/softmax.h" +#include "contrib_ops/cuda/bert/lean_attention/mask.h" + +namespace onnxruntime { +namespace lean { + +using namespace cute; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// Specialized for Prefill +template +inline __device__ void lean_compute_attn_impl_ver3(const Params& params, const int cta_id, int start_tile_gid, int start_tile_hid, int num_tiles, const int num_tiles_per_head) { +#if defined(DEBUG_LEAN_ATTENTION) + // Timing + auto kernel_start = clock64(); + long long int comp1_duration = 0; + long long int comp2_duration = 0; + long long int epilogue_duration = 0; + long long int prologue_duration = 0; + long long int epil1_duration = 0; + long long int epil2_duration = 0; + long long int epil3_duration = 0; + + const int tracing_block = 0; +#endif + + using Element = typename Kernel_traits::Element; + using ElementAccum = typename Kernel_traits::ElementAccum; + using index_t = typename Kernel_traits::index_t; + + // Shared memory. + extern __shared__ char smem_[]; + + // The thread index. + const int tidx = threadIdx.x; + + constexpr int kBlockM = Kernel_traits::kBlockM; + constexpr int kBlockN = Kernel_traits::kBlockN; + constexpr int kHeadDim = Kernel_traits::kHeadDim; + constexpr int kNWarps = Kernel_traits::kNWarps; + + using GmemTiledCopyO = typename Kernel_traits::GmemTiledCopyO; + using GmemTiledCopyOaccum = typename Kernel_traits::GmemTiledCopyOaccum; + + const int num_m_blocks_per_head = (params.seqlen_q + kBlockM - 1) / kBlockM; + + // // This is the solution to the summation series (n+1)(n+2)/2 = start_tile_hid + 1 + // int cur_m_block = Is_causal ? (int)ceilf((sqrtf(9 + (8*start_tile_hid)) - 3) / 2) : start_tile_hid/num_tiles_per_head; + float block_scale = (float)kBlockM / (float)kBlockN; + int cur_m_block = Is_causal ? kBlockM > kBlockN ? (int)ceilf((sqrtf(1 + (8 * start_tile_hid + 8) / block_scale) - 3) / 2) + // : (int)((-1 + sqrt(1 + 8 * block_scale * start_tile_hid)) / 2) * (1 / block_scale) + (int)((start_tile_hid - (1 / block_scale) * ((int)((-1 + sqrt(1 + 8 * block_scale * start_tile_hid)) / 2) * ((int)((-1 + sqrt(1 + 8 * block_scale * start_tile_hid)) / 2) + 1) / 2)) / ((int)((-1 + sqrt(1 + 8 * block_scale * start_tile_hid)) / 2) + 1)) + : static_cast((-1 + sqrt(1 + 8 * start_tile_hid * block_scale)) / (2 * block_scale)) + : start_tile_hid / num_tiles_per_head; + int num_tiles_in_block = Is_causal ? (int)ceilf(block_scale * (cur_m_block + 1)) : num_tiles_per_head; + int cur_bidb = start_tile_gid / (num_tiles_per_head * params.h); + int cur_bidh = (start_tile_gid - (cur_bidb * num_tiles_per_head * params.h)) / num_tiles_per_head; + + int num_tiles_left = num_tiles; + +#if defined(DEBUG_LEAN_ATTENTION) + if (threadIdx.x == 0 && blockIdx.z == tracing_block) { + printf("Debugging block = %d\n", tracing_block); + printf("kBlockM = %d\n", kBlockM); + printf("kBlockN = %d\n", kBlockN); + printf("kHeadDim = %d\n", kHeadDim); + printf("kNWarps = %d\n", kNWarps); + printf("IsEvenMN = %d\n", Is_even_MN); + printf("block_scale = %f\n", block_scale); + printf("seq_len_q -change = %d\n", params.seqlen_q); + printf("seq_len_k = %d\n", params.seqlen_k); + printf("q_batch_stride = %ld\n", params.q_batch_stride); + printf("q_head_stride = %ld\n", params.q_head_stride); + printf("q_row_stride = %ld\n", params.q_row_stride); + printf("k_batch_stride = %ld\n", params.k_batch_stride); + printf("k_head_stride = %ld\n", params.k_head_stride); + printf("k_row_stride = %ld\n", params.k_row_stride); + printf("v_row_stride = %ld\n", params.v_row_stride); + printf("o_row_stride = %ld\n", params.o_row_stride); + printf("start_m_block = %d\n", cur_m_block); + printf("start_tile_gid = %d\n", start_tile_gid); + printf("start_tile_hid = %d\n", start_tile_hid); + printf("cur_bidb = %d/%d\n", cur_bidb, params.b); + printf("cur_bidh = %d/%d\n", cur_bidh, params.h); + printf("num_m_blocks_per_head = %d\n", num_m_blocks_per_head); + printf("cur_m_block = %d\n", cur_m_block); + printf("num_tiles_in_block = %d\n", num_tiles_in_block); + printf("Total tiles = %d\n", num_tiles); + } +#endif + + // Prologue + int n_tile_min = kBlockM > kBlockN ? start_tile_hid - (block_scale * cur_m_block * (cur_m_block + 1) / 2) + : start_tile_hid - (int)(((int)floorf(cur_m_block * block_scale) * ((int)floorf(cur_m_block * block_scale) + 1) / 2) / block_scale) - ((cur_m_block % int(1 / block_scale)) * (floorf(cur_m_block * block_scale) + 1)); + int n_tile = n_tile_min + num_tiles_left - 1 >= num_tiles_in_block ? num_tiles_in_block - 1 : n_tile_min + num_tiles_left - 1; + + index_t row_offset_q = cur_bidb * params.q_batch_stride + + +cur_m_block * kBlockM * params.q_row_stride + cur_bidh * params.q_head_stride; + index_t row_offset_k = cur_bidb * params.k_batch_stride + + +n_tile * kBlockN * params.k_row_stride + (cur_bidh / params.h_h_k_ratio) * params.k_head_stride; + + Tensor gQ = make_tensor(make_gmem_ptr(reinterpret_cast(params.q_ptr) + row_offset_q), + Shape, Int>{}, + make_stride(params.q_row_stride, _1{})); + Tensor gK = make_tensor(make_gmem_ptr(reinterpret_cast(params.k_ptr) + row_offset_k), + Shape, Int>{}, + make_stride(params.k_row_stride, _1{})); + + Tensor sQ = make_tensor(make_smem_ptr(reinterpret_cast(smem_)), + typename Kernel_traits::SmemLayoutQ{}); + Tensor sK = make_tensor(sQ.data() + size(sQ), typename Kernel_traits::SmemLayoutKV{}); + + typename Kernel_traits::GmemTiledCopyQKV gmem_tiled_copy_QKV; + auto gmem_thr_copy_QKV = gmem_tiled_copy_QKV.get_thread_slice(tidx); + + Tensor tQgQ = gmem_thr_copy_QKV.partition_S(gQ); + Tensor tQsQ = gmem_thr_copy_QKV.partition_D(sQ); + Tensor tKgK = gmem_thr_copy_QKV.partition_S(gK); // (KCPY, KCPY_N, KCPY_K) + Tensor tKsK = gmem_thr_copy_QKV.partition_D(sK); + + // PREDICATES + // + + // Construct identity layout for sQ and sK + Tensor cQ = make_identity_tensor(make_shape(size<0>(sQ), size<1>(sQ))); // (BLK_M,BLK_K) -> (blk_m,blk_k) + Tensor cKV = make_identity_tensor(make_shape(size<0>(sK), size<1>(sK))); // (BLK_N,BLK_K) -> (blk_n,blk_k) + + // Repeat the partitioning with identity layouts + Tensor tQcQ = gmem_thr_copy_QKV.partition_S(cQ); // (ACPY,ACPY_M,ACPY_K) -> (blk_m,blk_k) + Tensor tKVcKV = gmem_thr_copy_QKV.partition_S(cKV); // (BCPY,BCPY_N,BCPY_K) -> (blk_n,blk_k) + + // Allocate predicate tensors for k + Tensor tQpQ = make_tensor(make_shape(size<2>(tQsQ))); + Tensor tKVpKV = make_tensor(make_shape(size<2>(tKsK))); + + // Set predicates for k bounds + if (!Is_even_K) { +#pragma unroll + for (int k = 0; k < size(tQpQ); ++k) { + tQpQ(k) = get<1>(tQcQ(0, 0, k)) < params.d; + } +#pragma unroll + for (int k = 0; k < size(tKVpKV); ++k) { + tKVpKV(k) = get<1>(tKVcKV(0, 0, k)) < params.d; + } + } + + // // Start from the last block of first head + // lean::copy(gmem_tiled_copy_QKV, tQgQ, tQsQ, tQcQ, tQpQ, + // params.seqlen_q - cur_m_block * kBlockM); + + // // We don't need to clear the sK smem tiles since we'll mask out the scores anyway. + // lean::copy(gmem_tiled_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV, + // params.seqlen_k - n_tile * kBlockN); + // cute::cp_async_fence(); + + index_t row_offset_v = cur_bidb * params.v_batch_stride + + +n_tile * kBlockN * params.v_row_stride + (cur_bidh / params.h_h_k_ratio) * params.v_head_stride; + Tensor gV = make_tensor(make_gmem_ptr(reinterpret_cast(params.v_ptr) + row_offset_v), + Shape, Int>{}, + make_stride(params.v_row_stride, _1{})); + Tensor sV = make_tensor(sK.data() + size(sK), typename Kernel_traits::SmemLayoutKV{}); + Tensor sVt = make_tensor(sV.data(), typename Kernel_traits::SmemLayoutVtransposed{}); + Tensor sVtNoSwizzle = make_tensor(sV.data(), typename Kernel_traits::SmemLayoutVtransposedNoSwizzle{}); + + Tensor tVgV = gmem_thr_copy_QKV.partition_S(gV); // (VCPY, VCPY_N, VCPY_K) + Tensor tVsV = gmem_thr_copy_QKV.partition_D(sV); + + // Tiled Matrix Multiply + typename Kernel_traits::TiledMma tiled_mma; + auto thr_mma = tiled_mma.get_thread_slice(tidx); + Tensor tSrQ = thr_mma.partition_fragment_A(sQ); // (MMA,MMA_M,MMA_K) + Tensor tSrK = thr_mma.partition_fragment_B(sK); // (MMA,MMA_N,MMA_K) + Tensor tOrVt = thr_mma.partition_fragment_B(sVtNoSwizzle); // (MMA, MMA_K,MMA_N) + + Tensor acc_o = partition_fragment_C(tiled_mma, Shape, Int>{}); // MMA, MMA_M, MMA_K + + // + // Copy Atom retiling - Can be moved + // + + auto smem_tiled_copy_Q = make_tiled_copy_A(typename Kernel_traits::SmemCopyAtom{}, tiled_mma); + auto smem_thr_copy_Q = smem_tiled_copy_Q.get_thread_slice(tidx); + Tensor tSsQ = smem_thr_copy_Q.partition_S(sQ); + + auto smem_tiled_copy_K = make_tiled_copy_B(typename Kernel_traits::SmemCopyAtom{}, tiled_mma); + auto smem_thr_copy_K = smem_tiled_copy_K.get_thread_slice(tidx); + Tensor tSsK = smem_thr_copy_K.partition_S(sK); + + auto smem_tiled_copy_V = make_tiled_copy_B(typename Kernel_traits::SmemCopyAtomTransposed{}, tiled_mma); + auto smem_thr_copy_V = smem_tiled_copy_V.get_thread_slice(tidx); + Tensor tOsVt = smem_thr_copy_V.partition_S(sVt); + +#if defined(DEBUG_LEAN_ATTENTION) + if (threadIdx.x == 0 && blockIdx.z == tracing_block) { + printf("n_tile_min = %d\n", n_tile_min); + printf("n_tile = %d\n", n_tile); + printf("row_offset_q = %" PRId64 "\n", row_offset_q); + printf("row_offset_k = %" PRId64 "\n", row_offset_k); + printf("row_offset_v = %" PRId64 "\n", row_offset_v); + } + + int num_blocks = 0; +#endif + + for (; num_tiles_left > 0;) { +#if defined(DEBUG_LEAN_ATTENTION) + num_blocks += 1; + auto prologue_start = clock64(); +#endif + + cur_bidb = start_tile_gid / (num_tiles_per_head * params.h); + cur_bidh = (start_tile_gid - (cur_bidb * num_tiles_per_head * params.h)) / num_tiles_per_head; + // Scheduling Policy - below + + // Calculate split ID + int block_start_gid = start_tile_gid - n_tile_min; + int cta_id_block_start = block_start_gid > params.high_load_tbs * params.max_tiles_per_tb + ? params.high_load_tbs + ((block_start_gid - (params.high_load_tbs * params.max_tiles_per_tb)) / (params.max_tiles_per_tb - 1)) + : block_start_gid / params.max_tiles_per_tb; + int n_split_idx = cta_id - cta_id_block_start; + + // Check host/ + int host_cta = 0; + int total_splits = 1; + if (n_tile_min == 0) { + host_cta = 1; + int block_end_gid = start_tile_gid + num_tiles_in_block - 1; + int cta_id_block_end = block_end_gid > params.high_load_tbs * params.max_tiles_per_tb + ? params.high_load_tbs + ((block_end_gid - (params.high_load_tbs * params.max_tiles_per_tb)) / (params.max_tiles_per_tb - 1)) + : block_end_gid / params.max_tiles_per_tb; + total_splits = cta_id_block_end - cta_id + 1; + } + + int end_cta = 0; + if (n_tile == num_tiles_in_block - 1) { + end_cta = 1; + } + + start_tile_gid += n_tile - n_tile_min + 1; + start_tile_hid += n_tile - n_tile_min + 1; + if (start_tile_hid >= num_tiles_per_head) { + // Next head + start_tile_hid = 0; + } + num_tiles_left -= n_tile - n_tile_min + 1; + + const BlockInfo binfo(params, cur_bidb); + // This is a hack, we really need to handle this outside the kernel + // But can't figure out a way to get actual seqlen_k in host-side code. + int max_actual_tiles = (binfo.actual_seqlen_k + kBlockN - 1) / kBlockN; + int num_actual_tiles_in_block = Is_causal ? std::max(max_actual_tiles, (int)ceilf(block_scale * (cur_m_block + 1))) : max_actual_tiles; + if (n_tile >= max_actual_tiles) { + tKgK.data() = tKgK.data() + (-int((n_tile - max_actual_tiles - 1) * kBlockN * params.k_row_stride)); + tVgV.data() = tVgV.data() + (-int((n_tile - max_actual_tiles - 1) * kBlockN * params.v_row_stride)); + n_tile = max_actual_tiles - 1; + } + if constexpr (Append_KV) { + if (end_cta) { + // Even if we have MQA / GQA, all threadblocks responsible for the same KV head are writing to + // gmem. Technically it's a race condition, but they all write the same content anyway, and it's safe. + // We want to do this so that all threadblocks can proceed right after they finish writing the KV cache. + + const index_t row_offset_knew = binfo.k_offset(params.knew_batch_stride, params.knew_row_stride, cur_bidb) + (n_tile * kBlockN) * params.knew_row_stride + (cur_bidh / params.h_h_k_ratio) * params.knew_head_stride; + const index_t row_offset_vnew = binfo.k_offset(params.vnew_batch_stride, params.vnew_row_stride, cur_bidb) + (n_tile * kBlockN) * params.vnew_row_stride + (cur_bidh / params.h_h_k_ratio) * params.vnew_head_stride; + // Subtract seqlen_k_cache * row stride so that conceptually gK and gKnew "line up". When we access them, + // e.g. if gK has 128 rows and gKnew has 64 rows, we access gK[:128] and gKNew[128:128 + 64]. + // This maps to accessing the first 64 rows of knew_ptr. + Tensor gKnew = make_tensor(make_gmem_ptr(reinterpret_cast(params.knew_ptr) + row_offset_knew - binfo.seqlen_k_cache * params.knew_row_stride), + Shape, Int>{}, + make_stride(params.knew_row_stride, _1{})); +#if defined(DEBUG_LEAN_ATTENTION) + if (threadIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0) { + printf("knew_ptr = %p, row_offset_knew = %d, gKnew_ptr = %p\n", params.knew_ptr, row_offset_knew, gKnew.data()); + } +#endif + Tensor gVnew = make_tensor(make_gmem_ptr(reinterpret_cast(params.vnew_ptr) + row_offset_vnew - binfo.seqlen_k_cache * params.vnew_row_stride), + Shape, Int>{}, + make_stride(params.vnew_row_stride, _1{})); + Tensor tKgKnew = gmem_thr_copy_QKV.partition_S(gKnew); // (KCPY, KCPY_N, KCPY_K) + Tensor tVgVnew = gmem_thr_copy_QKV.partition_S(gVnew); // (VCPY, VCPY_N, VCPY_K) + + const int n_block_copy_min = std::max(n_tile_min, binfo.seqlen_k_cache / kBlockN); + auto tKgK_data = tKgK.data(); + auto tVgV_data = tVgV.data(); + +#if defined(DEBUG_LEAN_ATTENTION) + if (threadIdx.x == 0 && (blockIdx.z == tracing_block || blockIdx.z == tracing_block + 1)) { + printf("Block %d n_tile_min %d n_tile %d n_block_copy_min %d\n", blockIdx.z, n_tile_min, n_tile, n_block_copy_min); + } +#endif + for (int n_block = n_tile; n_block >= n_block_copy_min; n_block--) { + lean::copy_w_min_idx( + tVgVnew, tVgV, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN, binfo.seqlen_k_cache - n_block * kBlockN); + tVgVnew.data() = tVgVnew.data() + (-int(kBlockN * params.vnew_row_stride)); + + lean::copy_w_min_idx( + tKgKnew, tKgK, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN, binfo.seqlen_k_cache - n_block * kBlockN); + tKgKnew.data() = tKgKnew.data() + (-int(kBlockN * params.knew_row_stride)); + tVgV.data() = tVgV.data() + (-int(kBlockN * params.v_row_stride)); + tKgK.data() = tKgK.data() + (-int(kBlockN * params.k_row_stride)); + } + // Need this before we can read in K again, so that we'll see the updated K values. + __syncthreads(); + tKgK.data() = tKgK_data; + tVgV.data() = tVgV_data; + } + } + lean::copy(gmem_tiled_copy_QKV, tQgQ, tQsQ, tQcQ, tQpQ, + binfo.actual_seqlen_q - cur_m_block * kBlockM); + lean::copy(gmem_tiled_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV, + binfo.actual_seqlen_k - n_tile * kBlockN); + cute::cp_async_fence(); + +#if defined(DEBUG_LEAN_ATTENTION) + if (threadIdx.x == 0 && blockIdx.z == tracing_block) { + printf("##### CTA : %d\n", blockIdx.z); + printf("cur_bidb = %d/%d\n", cur_bidb, params.b); + printf("cur_bidh = %d/%d\n", cur_bidh, params.h); + printf("cur_m_block = %d\n", cur_m_block); + printf("seqlen_k_cache = %d\n", binfo.seqlen_k_cache); + printf("actual_seqlen_q = %d\n", binfo.actual_seqlen_q); + printf("actual_seqlen_k = %d\n", binfo.actual_seqlen_k); + printf("num_tiles_in_block = %d\n", num_tiles_in_block); + printf("n_tile(new) = %d\n", n_tile); + printf("n_tile_min = %d\n", n_tile_min); + printf("host_cta = %d\n", host_cta); + printf("end_cta = %d\n", end_cta); + printf("n_split_idx = %d\n", n_split_idx); + printf("total_splits = %d\n", total_splits); + printf("\n#### For next block:\n"); + printf("start_tile_gid = %d\n", start_tile_gid); + printf("start_tile_hid = %d\n", start_tile_hid); + printf("num_tiles_left = %d\n", num_tiles_left); + printf("\n"); + } +#endif + + // All scheduling policy decisions should be made above this line + clear(acc_o); + + lean::Softmax<2 * size<1>(acc_o)> softmax; + + lean::Mask mask(binfo.actual_seqlen_k, binfo.actual_seqlen_q, params.window_size_left, params.window_size_right, 0.0f); + + // For performance reason, we separate out two kinds of iterations: + // those that need masking on S, and those that don't. + // We need masking on S for the very last block when K and V has length not multiple of kBlockN. + // We also need masking on S if it's causal, for the last ceil_div(kBlockM, kBlockN) blocks. + // We will have at least 1 "masking" iteration. + + // If not even_N, then seqlen_k might end in the middle of a block. In that case we need to + // mask 2 blocks (e.g. when kBlockM == kBlockN), not just 1. + Tensor acc_s = partition_fragment_C(tiled_mma, Shape, Int>{}); // (MMA=4, MMA_M, MMA_N) + clear(acc_s); + lean::cp_async_wait<0>(); + __syncthreads(); + +#if defined(DEBUG_LEAN_ATTENTION) + prologue_duration += clock64() - prologue_start; + auto compute_start = clock64(); +#endif + + // Clear the smem tiles to account for predicated off loads + lean::copy( + gmem_tiled_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_tile * kBlockN); + cute::cp_async_fence(); + + lean::gemm( + acc_s, tSrQ, tSrK, tSsQ, tSsK, tiled_mma, smem_tiled_copy_Q, smem_tiled_copy_K, + smem_thr_copy_Q, smem_thr_copy_K); + +#if defined(DEBUG_LEAN_ATTENTION) + if (threadIdx.x == 0 && blockIdx.z == tracing_block) { + printf("Tile 0 - Svalue: acc_s[0] = %f\n", acc_s(0)); + } +#endif + + mask.template apply_mask( + acc_s, n_tile * kBlockN, cur_m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4, kNWarps * 16); + + lean::cp_async_wait<0>(); + __syncthreads(); + +#if defined(DEBUG_LEAN_ATTENTION) + if (tidx == 0 && blockIdx.y == 0 && blockIdx.z == 0) { + print(tVsV); + } + // __syncthreads(); +#endif + + if (n_tile > n_tile_min) { + // Advance gK + tKgK.data() = tKgK.data() + (-int(kBlockN * params.k_row_stride)); + lean::copy(gmem_tiled_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV); + // This cp_async_fence needs to be in the if block, otherwise the synchronization + // isn't right and we get race conditions. + cute::cp_async_fence(); + } + + // We have key_padding_mask so we'll need to Check_inf + softmax.template softmax_rescale_o(acc_s, acc_o, params.scale_softmax_log2); + +#if defined(DEBUG_LEAN_ATTENTION) + if (threadIdx.x == 0 && blockIdx.z == tracing_block) { + printf("Tile 0 - PValue[0] = %f\n", acc_s(0)); + } +#endif + + // Convert acc_s from fp32 to fp16/bf16 + Tensor rP = lean::convert_type(acc_s); + // Reshape rP from (MMA=4, MMA_M, MMA_N) to ((4, 2), MMA_M, MMA_N / 2) + // if using m16n8k16 or (4, MMA_M, MMA_N) if using m16n8k8. + Tensor tOrP = make_tensor(rP.data(), lean::convert_layout_acc_Aregs(rP.layout())); + + lean::gemm_rs(acc_o, tOrP, tOrVt, tOsVt, tiled_mma, smem_tiled_copy_V, smem_thr_copy_V); + +#if defined(DEBUG_LEAN_ATTENTION) + if (threadIdx.x == 0 && blockIdx.z == tracing_block) { + printf("Tile 0 - AfterPV[0] = %f\n", acc_o(0)); + } +#endif + + n_tile -= 1; + +#if defined(DEBUG_LEAN_ATTENTION) + comp1_duration += clock64() - compute_start; + compute_start = clock64(); +#endif + + // These are the iterations where we don't need masking on S + for (; n_tile >= n_tile_min; --n_tile) { + Tensor acc_s = partition_fragment_C(tiled_mma, Shape, Int>{}); // (MMA=4, MMA_M, MMA_N) + clear(acc_s); + lean::cp_async_wait<0>(); + __syncthreads(); + + // Advance gV + tVgV.data() = tVgV.data() + (-int(kBlockN * params.v_row_stride)); + + lean::copy(gmem_tiled_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV); + cute::cp_async_fence(); + + lean::gemm( + acc_s, tSrQ, tSrK, tSsQ, tSsK, tiled_mma, smem_tiled_copy_Q, smem_tiled_copy_K, + smem_thr_copy_Q, smem_thr_copy_K); +#if defined(DEBUG_LEAN_ATTENTION) + if (threadIdx.x == 0 && blockIdx.z == tracing_block) { + printf("ntile %d Svalue: acc_s[0] = %f\n", n_tile, acc_s(0)); + } +#endif + + lean::cp_async_wait<0>(); + __syncthreads(); + if (n_tile > n_tile_min) { + // Advance gK + tKgK.data() = tKgK.data() + (-int(kBlockN * params.k_row_stride)); + lean::copy(gmem_tiled_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV); + // This cp_async_fence needs to be in the if block, otherwise the synchronization + // isn't right and we get race conditions. + cute::cp_async_fence(); + } + + mask.template apply_mask( + acc_s, n_tile * kBlockN, cur_m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4, kNWarps * 16); + softmax.template softmax_rescale_o(acc_s, acc_o, params.scale_softmax_log2); + +#if defined(DEBUG_LEAN_ATTENTION) + if (threadIdx.x == 0 && blockIdx.z == tracing_block) { + printf("ntile %d Pvalue: acc_s[0] = %f\n", n_tile, acc_s(0)); + } +#endif + Tensor rP = lean::convert_type(acc_s); + + // Reshape rP from (MMA=4, MMA_M, MMA_N) to ((4, 2), MMA_M, MMA_N / 2) + // if using m16n8k16 or (4, MMA_M, MMA_N) if using m16n8k8. + Tensor tOrP = make_tensor(rP.data(), lean::convert_layout_acc_Aregs(rP.layout())); + + lean::gemm_rs(acc_o, tOrP, tOrVt, tOsVt, tiled_mma, smem_tiled_copy_V, smem_thr_copy_V); +#if defined(DEBUG_LEAN_ATTENTION) + if (threadIdx.x == 0 && blockIdx.z == tracing_block) { + printf("ntile %d AfterPV[0] = %f\n", n_tile, acc_o(0)); + } +#endif + } + +#if defined(DEBUG_LEAN_ATTENTION) + // Epilogue + comp2_duration += clock64() - compute_start; + auto epilogue_start = clock64(); +#endif + + if (host_cta && end_cta) { +#if defined(DEBUG_LEAN_ATTENTION) + if (threadIdx.x == 0 && blockIdx.z == tracing_block) { + printf("acc_o[0] = %f\n", acc_o(0)); + } +#endif + + Tensor lse = softmax.template normalize_softmax_lse(acc_o, params.scale_softmax, params.rp_dropout); + +#if defined(DEBUG_LEAN_ATTENTION) + if (threadIdx.x == 0 && blockIdx.z == tracing_block) { + printf("lse[0] = %f\n", lse(0)); + printf("acc_o[0] = %f\n", acc_o(0)); + } +#endif + + // Convert acc_o from fp32 to fp16/bf16 + Tensor rO = lean::convert_type(acc_o); + + Tensor sO = make_tensor(sQ.data(), typename Kernel_traits::SmemLayoutO{}); // (SMEM_M,SMEM_N) + // Partition sO to match the accumulator partitioning + auto smem_tiled_copy_O = make_tiled_copy_C(typename Kernel_traits::SmemCopyAtomO{}, tiled_mma); + auto smem_thr_copy_O = smem_tiled_copy_O.get_thread_slice(tidx); + Tensor taccOrO = smem_thr_copy_O.retile_S(rO); // ((Atom,AtomNum), MMA_M, MMA_N) + Tensor taccOsO = smem_thr_copy_O.partition_D(sO); // ((Atom,AtomNum),PIPE_M,PIPE_N) + + // sO has the same size as sQ, so we don't need to sync here. + if (Kernel_traits::Share_Q_K_smem) { + __syncthreads(); + } + + cute::copy(smem_tiled_copy_O, taccOrO, taccOsO); + + const index_t row_offset_o = cur_bidb * params.o_batch_stride + + cur_m_block * kBlockM * params.o_row_stride + cur_bidh * params.o_head_stride; + + Tensor gO = make_tensor(make_gmem_ptr(reinterpret_cast(params.o_ptr) + row_offset_o), + Shape, Int>{}, + make_stride(params.o_row_stride, _1{})); + + typename Kernel_traits::GmemTiledCopyO gmem_tiled_copy_O; + auto gmem_thr_copy_O = gmem_tiled_copy_O.get_thread_slice(tidx); + Tensor tOsO = gmem_thr_copy_O.partition_S(sO); // ((Atom,AtomNum),ATOM_M,ATOM_N) + Tensor tOgO = gmem_thr_copy_O.partition_D(gO); + + __syncthreads(); + + Tensor tOrO = make_tensor(shape(tOgO)); + cute::copy(gmem_tiled_copy_O, tOsO, tOrO); + + // Construct identity layout for sO + Tensor cO = make_identity_tensor(make_shape(size<0>(sO), size<1>(sO))); // (BLK_M,BLK_K) -> (blk_m,blk_k) + // Repeat the partitioning with identity layouts + Tensor tOcO = gmem_thr_copy_O.partition_D(cO); // (ACPY,ACPY_M,ACPY_K) -> (blk_m,blk_k) + Tensor tOpO = make_tensor(make_shape(size<2>(tOgO))); + if (!Is_even_K) { +#pragma unroll + for (int k = 0; k < size(tOpO); ++k) { + tOpO(k) = get<1>(tOcO(0, 0, k)) < params.d; + } + } +#if defined(DEBUG_LEAN_ATTENTION) + if (threadIdx.x == 0 && blockIdx.z == tracing_block) { + printf("tOpO[0] = %d\n", tOpO(0)); + printf("tOrO[0] = %f\n", tOrO(0)); + } +#endif + // Clear_OOB_K must be false since we don't want to write zeros to gmem + lean::copy( + gmem_tiled_copy_O, tOrO, tOgO, tOcO, tOpO, params.seqlen_q - cur_m_block * kBlockM); + // epil1_duration += clock64() - epilogue_start; + } else if (!host_cta) { + Tensor lse = softmax.template normalize_softmax_lse(acc_o, params.scale_softmax); + + Tensor sOaccum = make_tensor(make_smem_ptr(reinterpret_cast(smem_)), typename Kernel_traits::SmemLayoutO{}); // (SMEM_M,SMEM_N) + // Partition sO to match the accumulator partitioning + using SmemTiledCopyO = typename Kernel_traits::SmemCopyAtomOaccum; + auto smem_tiled_copy_Oaccum = make_tiled_copy_C(SmemTiledCopyO{}, tiled_mma); + auto smem_thr_copy_Oaccum = smem_tiled_copy_Oaccum.get_thread_slice(tidx); + Tensor rO = lean::convert_type(acc_o); + Tensor taccOrOaccum = smem_thr_copy_Oaccum.retile_S(rO); // ((Atom,AtomNum), MMA_M, MMA_N) + Tensor taccOsOaccum = smem_thr_copy_Oaccum.partition_D(sOaccum); // ((Atom,AtomNum),PIPE_M,PIPE_N) + + // sOaccum is larger than sQ, so we need to syncthreads here + // TODO: allocate enough smem for sOaccum + __syncthreads(); + + cute::copy(smem_tiled_copy_Oaccum, taccOrOaccum, taccOsOaccum); + + const index_t row_offset_oaccum = (((index_t)(n_split_idx * params.b + cur_bidb) * params.h + cur_bidh) * params.seqlen_q + cur_m_block * kBlockM) * params.d_rounded; + const index_t row_offset_lseaccum = ((n_split_idx * params.b + cur_bidb) * params.h + cur_bidh) * params.seqlen_q + cur_m_block * kBlockM; + +#if defined(DEBUG_LEAN_ATTENTION) + if (threadIdx.x == 0 && blockIdx.z == tracing_block) { + printf("n_split_idx = %d\n", n_split_idx); + // printf("row_offset_o = %" PRId64 "\n", row_offset_o); + printf("row_offset_oaccum = %" PRId64 "\n", row_offset_oaccum); + printf("row_offset_lseaccum = %" PRId64 "\n", row_offset_lseaccum); + } +#endif + + Tensor gOaccum = make_tensor(make_gmem_ptr(reinterpret_cast(params.oaccum_ptr) + (row_offset_oaccum)), + Shape, Int>{}, + make_stride(kHeadDim, _1{})); + Tensor gLSEaccum = make_tensor(make_gmem_ptr(reinterpret_cast(params.softmax_lseaccum_ptr) + row_offset_lseaccum), + Shape>{}, Stride<_1>{}); + + GmemTiledCopyOaccum gmem_tiled_copy_Oaccum; + auto gmem_thr_copy_Oaccum = gmem_tiled_copy_Oaccum.get_thread_slice(tidx); + Tensor tOsOaccum = gmem_thr_copy_Oaccum.partition_S(sOaccum); // ((Atom,AtomNum),ATOM_M,ATOM_N) + Tensor tOgOaccum = gmem_thr_copy_Oaccum.partition_D(gOaccum); + + __syncthreads(); + + Tensor tOrOaccum = make_tensor(shape(tOgOaccum)); + cute::copy(gmem_tiled_copy_Oaccum, tOsOaccum, tOrOaccum); + + Tensor caccO = make_identity_tensor(Shape, Int>{}); // (BLK_M,BLK_K) -> (blk_m,blk_k) + Tensor taccOcO = thr_mma.partition_C(caccO); // (MMA,MMA_M,MMA_K) + static_assert(decltype(size<0>(taccOcO))::value == 4); + // Convert to ((2, 2), MMA_M, MMA_K) then take only the row indices. + Tensor taccOcO_row = logical_divide(taccOcO, Shape<_2>{})(make_coord(0, _), _, 0); + CUTE_STATIC_ASSERT_V(size(lse) == size(taccOcO_row)); // MMA_M + // This partitioning is unequal because only threads 0,4,8,etc write to gLSE + // and the rest are unused. + if (get<1>(taccOcO_row(0)) == 0) { +#pragma unroll + for (int mi = 0; mi < size(lse); ++mi) { + const int row = get<0>(taccOcO_row(mi)); + if (row < params.seqlen_q - cur_m_block * kBlockM) { + gLSEaccum(row) = lse(mi); + } + } + } + + // Construct identity layout for sO + Tensor cO = make_identity_tensor(make_shape(size<0>(sOaccum), size<1>(sOaccum))); // (BLK_M,BLK_K) -> (blk_m,blk_k) + // Repeat the partitioning with identity layouts + Tensor tOcO = gmem_thr_copy_Oaccum.partition_D(cO); // (ACPY,ACPY_M,ACPY_K) -> (blk_m,blk_k) + Tensor tOpO = make_tensor(make_shape(size<2>(tOgOaccum))); + if (!Is_even_K) { +#pragma unroll + for (int k = 0; k < size(tOpO); ++k) { + tOpO(k) = get<1>(tOcO(0, 0, k)) < params.d; + } + } + // Clear_OOB_K must be false since we don't want to write zeros to gmem + lean::copy( + gmem_tiled_copy_Oaccum, tOrOaccum, tOgOaccum, tOcO, tOpO, params.seqlen_q - cur_m_block * kBlockM); + + __threadfence(); + +#if defined(DEBUG_LEAN_ATTENTION) + if (threadIdx.x == 0 && (blockIdx.z == tracing_block || blockIdx.z == tracing_block + 1)) { + printf("Block %d Writing Flag %d\n", blockIdx.z, (cur_bidb * params.h * num_m_blocks_per_head) + (cur_bidh * num_m_blocks_per_head) + cur_m_block); + } +#endif + + atomicAdd(reinterpret_cast(params.sync_flag) + (cur_bidb * params.h * num_m_blocks_per_head) + (cur_bidh * num_m_blocks_per_head) + cur_m_block, 1); + +#if defined(DEBUG_LEAN_ATTENTION) + epil2_duration += clock64() - epilogue_start; +#endif + } else { + constexpr int kNThreads = Kernel_traits::kNThreads; + + static_assert(kMaxSplits <= 128, "kMaxSplits must be <= 128"); + static_assert(kNThreads == 128, "We assume that each block has 128 threads"); + + //////////////////////////////////////////////////////////////////////////////// +#if defined(DEBUG_LEAN_ATTENTION) + if (threadIdx.x == 0 && blockIdx.z == tracing_block) { + printf("Before LSE acc_o[0] = %f\n", acc_o(0)); + } +#endif + + Tensor lse = softmax.template normalize_softmax_lse(acc_o, params.scale_softmax); + +#if defined(DEBUG_LEAN_ATTENTION) + if (threadIdx.x == 0 && blockIdx.z == tracing_block) { + printf("After LSE acc_o[0] = %f\n", acc_o(0)); + printf("lse[0] = %f\n", lse(0)); + } +#endif + + Tensor sOaccum = make_tensor(make_smem_ptr(reinterpret_cast(smem_)), typename Kernel_traits::SmemLayoutO{}); // (SMEM_M,SMEM_N) + // Partition sO to match the accumulator partitioning + using SmemTiledCopyO = typename Kernel_traits::SmemCopyAtomOaccum; + auto smem_tiled_copy_Oaccum = make_tiled_copy_C(SmemTiledCopyO{}, tiled_mma); + auto smem_thr_copy_Oaccum = smem_tiled_copy_Oaccum.get_thread_slice(tidx); + Tensor rO = lean::convert_type(acc_o); + Tensor taccOrOaccum = smem_thr_copy_Oaccum.retile_S(rO); // ((Atom,AtomNum), MMA_M, MMA_N) + Tensor taccOsOaccum = smem_thr_copy_Oaccum.partition_D(sOaccum); // ((Atom,AtomNum),PIPE_M,PIPE_N) + + // sOaccum is larger than sQ, so we need to syncthreads here + // TODO: allocate enough smem for sOaccum + __syncthreads(); + + // We move to SMEM and back because we need equal distribution of + // accum registers. Initially only threads 0,4,8,etc have oaccum values. + // So, first move them to SMEM. + cute::copy(smem_tiled_copy_Oaccum, taccOrOaccum, taccOsOaccum); + + const index_t row_offset_oaccum = ((cur_bidb * params.h + cur_bidh) * (index_t)params.seqlen_q + cur_m_block * kBlockM) * params.d_rounded; + const index_t row_offset_lseaccum = (cur_bidb * params.h + cur_bidh) * (index_t)params.seqlen_q + cur_m_block * kBlockM; + + Tensor gOaccum = make_tensor(make_gmem_ptr(reinterpret_cast(params.oaccum_ptr) + (row_offset_oaccum)), + Shape, Int>{}, + make_stride(kHeadDim, _1{})); + Tensor gLSEaccum = make_tensor(make_gmem_ptr(reinterpret_cast(params.softmax_lseaccum_ptr) + row_offset_lseaccum), + Shape>{}, Stride<_1>{}); + +#if defined(DEBUG_LEAN_ATTENTION) + if (threadIdx.x == 0 && blockIdx.z == tracing_block) { + printf("Block %d row_offset_oaccum = %" PRId64 "\n", blockIdx.z, row_offset_oaccum); + printf("Block %d row_offset_lseaccum = %" PRId64 "\n", blockIdx.z, row_offset_lseaccum); + } +#endif + + // GmemTiledCopyOaccum gmem_tiled_copy_Oaccum; + // auto gmem_thr_copy_Oaccum = gmem_tiled_copy_Oaccum.get_thread_slice(tidx); + // Tensor tOsOaccum = gmem_thr_copy_Oaccum.partition_S(sOaccum); // ((Atom,AtomNum),ATOM_M,ATOM_N) + // Tensor tOgOaccum = gmem_thr_copy_Oaccum.partition_D(gOaccum); + + constexpr int kBlockN = kNThreads / kBlockM; + using GmemLayoutAtomOaccum = Layout, Int>, Stride, _1>>; + using GmemTiledCopyOaccum = decltype(make_tiled_copy(Copy_Atom{}, + GmemLayoutAtomOaccum{}, + Layout>{})); // Val layout, 4 vals per store + GmemTiledCopyOaccum gmem_tiled_copy_Oaccum; + auto gmem_thr_copy_Oaccum = gmem_tiled_copy_Oaccum.get_thread_slice(tidx); + + Tensor tOsOaccum = gmem_thr_copy_Oaccum.partition_S(sOaccum); + Tensor tOgOaccum = gmem_thr_copy_Oaccum.partition_S(gOaccum); + Tensor tOgOaccumReg = gmem_thr_copy_Oaccum.partition_D(gOaccum); + Tensor tOrOaccum = make_tensor(shape(tOgOaccumReg)); + +#if defined(DEBUG_LEAN_ATTENTION) + if (threadIdx.x == 0 && blockIdx.z == tracing_block) { + printf("First split t0g0accum.data() %p\n", tOgOaccum.data()); + } +#endif + + __syncthreads(); + + // Bring the oaccum back from SMEM to registers + // Now all threads have oaccum values equaly distributed. + cute::copy(gmem_tiled_copy_Oaccum, tOsOaccum, tOrOaccum); + + ///////////////////////////////////////////////////////////////////////////// + + // Shared memory. + // kBlockM + 1 instead of kBlockM to reduce bank conflicts. + Tensor sLSE = make_tensor(sV.data(), Shape, Int>{}); // (SMEM_M,SMEM_N) + + Tensor caccO = make_identity_tensor(Shape, Int>{}); // (BLK_M,BLK_K) -> (blk_m,blk_k) + Tensor taccOcO = thr_mma.partition_C(caccO); // (MMA,MMA_M,MMA_K) + static_assert(decltype(size<0>(taccOcO))::value == 4); + // Convert to ((2, 2), MMA_M, MMA_K) then take only the row indices. + Tensor taccOcO_row = logical_divide(taccOcO, Shape<_2>{})(make_coord(0, _), _, 0); + CUTE_STATIC_ASSERT_V(size(lse) == size(taccOcO_row)); // MMA_M + + // This partitioning is unequal because only threads 0,4,8,etc write to gLSE + // and the rest are unused. + if (get<1>(taccOcO_row(0)) == 0) { +#pragma unroll + for (int mi = 0; mi < size(lse); ++mi) { + const int col = get<0>(taccOcO_row(mi)); + if (col < params.seqlen_q - cur_m_block * kBlockM) { + sLSE(0, col) = lse(mi); +#if defined(DEBUG_LEAN_ATTENTION) + if (threadIdx.x == 0 && blockIdx.z == tracing_block) { + printf("threadIdx.x %d col %d mi%d slSE %f\n", threadIdx.x, col, mi, lse(mi)); + } +#endif + } + } + } + + // Synchronize here to make sure all atomics are visible to all threads. + // Not exactly sure why we need this, but it seems to be necessary. + __threadfence(); + while (atomicAdd(reinterpret_cast(params.sync_flag) + + (cur_bidb * params.h * num_m_blocks_per_head) + + (cur_bidh * num_m_blocks_per_head) + cur_m_block, + 0) < (total_splits - 1) * kNThreads) { + __threadfence(); +#if defined(DEBUG_LEAN_ATTENTION) + if (threadIdx.x % 32 == 0 && blockIdx.z == tracing_block) { + printf("Waiting Block: %d target-value: %d\n", blockIdx.z, (total_splits - 1) * kNThreads); + } +#endif + } + +#if defined(DEBUG_LEAN_ATTENTION) + // Print sync flag value + if (threadIdx.x == 0 && blockIdx.z == tracing_block) { + int32_t sync_flag = atomicAdd(reinterpret_cast(params.sync_flag) + + (cur_bidb * params.h * num_m_blocks_per_head) + + (cur_bidh * num_m_blocks_per_head) + cur_m_block, + 0); + if (threadIdx.x % 32 == 0 && blockIdx.z == tracing_block) { + printf("Sync flag value: %d\n", sync_flag); + } + } +#endif + + Tensor gLSEaccumRead = make_tensor(make_gmem_ptr(reinterpret_cast(params.softmax_lseaccum_ptr) + row_offset_lseaccum), + Shape, Int>{}, + make_stride(params.b * params.h * params.seqlen_q, _1{})); + // Read the LSE values from gmem and store them in shared memory, then tranpose them. + constexpr int kNLsePerThread = (kMaxSplits * kBlockM + kNThreads - 1) / kNThreads; // R + constexpr int kRowsPerLoadLSE = kNThreads / kBlockM; // R + +#pragma unroll + for (int l = 0; l < kNLsePerThread; ++l) { + const int row = l * kRowsPerLoadLSE + tidx / kBlockM; + const int col = tidx % kBlockM; + // We skip the first row = 0, as we already populated it in shared memory. + ElementAccum lse = (row > 0 && row < total_splits && col < params.b * params.h * (index_t)params.seqlen_q - row_offset_lseaccum) ? gLSEaccumRead(row, col) : -INFINITY; + if (row > 0 && row < kMaxSplits) { + sLSE(row, col) = lse; + +#if defined(DEBUG_LEAN_ATTENTION) + if (threadIdx.x % 32 == 0 && blockIdx.z == tracing_block) { + printf("ThreadIdx %d l %d row %d col %d lse %f\n", threadIdx.x, l, row, col, lse); + } +#endif + } + } + __syncthreads(); // For all LSEs to reach shared memory + Tensor lse_accum = make_tensor(Shape>{}); + constexpr int kRowsPerLoadTranspose = std::min(kRowsPerLoadLSE, kMaxSplits); + +#if defined(DEBUG_LEAN_ATTENTION) + if (threadIdx.x == 0 && blockIdx.z == tracing_block) { + printf("kNLsePerThread %d kRowsPerLoadLSE %d kRowsPerLoadTranspose %d\n", kNLsePerThread, kRowsPerLoadLSE, kRowsPerLoadTranspose); + } +#endif + + // To make sure that kMaxSplits is within 1 warp: we decide how many elements within kMaxSplits + // each thread should hold. If kMaxSplits = 16, then each thread holds 2 elements (128 threads, + // kBlockM rows, so each time we load we can load 128 / kBlockM rows). + // constexpr int kThreadsPerSplit = kMaxSplits / kRowsPerLoadTranspose; + // static_assert(kThreadsPerSplit <= 32); + static_assert(kRowsPerLoadTranspose <= 32); + static_assert(kNLsePerThread * kRowsPerLoadTranspose <= kMaxSplits); +#pragma unroll + for (int l = 0; l < kNLsePerThread; ++l) { + const int row = l * kRowsPerLoadTranspose + tidx % kRowsPerLoadTranspose; + const int col = tidx / kRowsPerLoadTranspose; + lse_accum(l) = (row < kMaxSplits && col < kBlockM) ? sLSE(row, col) : -INFINITY; + +#if defined(DEBUG_LEAN_ATTENTION) + if (threadIdx.x == 0 && blockIdx.z == tracing_block) { + printf("ThreadIdx %d l %d row %d col %d lse_accum %f\n", threadIdx.x, l, row, col, lse_accum(l)); + } +#endif + } + + // Compute the logsumexp of the LSE along the split dimension. + ElementAccum lse_max = lse_accum(0); +#pragma unroll + for (int l = 1; l < kNLsePerThread; ++l) { + lse_max = max(lse_max, lse_accum(l)); + } + MaxOp max_op; + lse_max = Allreduce::run(lse_max, max_op); + lse_max = lse_max == -INFINITY ? 0.0f : lse_max; // In case all local LSEs are -inf + float lse_sum = expf(lse_accum(0) - lse_max); +#pragma unroll + for (int l = 1; l < kNLsePerThread; ++l) { + lse_sum += expf(lse_accum(l) - lse_max); + } + SumOp sum_op; + lse_sum = Allreduce::run(lse_sum, sum_op); + // For the case where all local lse == -INFINITY, we want to set lse_logsum to INFINITY. Otherwise + // lse_logsum is log(0.0) = -INFINITY and we get NaN when we do lse_accum(l) - lse_logsum. + ElementAccum lse_logsum = (lse_sum == 0.f || lse_sum != lse_sum) ? INFINITY : logf(lse_sum) + lse_max; +// if (tidx % kRowsPerLoadTranspose == 0 && tidx / kRowsPerLoadTranspose < kBlockM) { gLSE(tidx / kRowsPerLoadTranspose) = lse_logsum; } +// Store the scales exp(lse - lse_logsum) in shared memory. +#pragma unroll + for (int l = 0; l < kNLsePerThread; ++l) { + const int row = l * kRowsPerLoadTranspose + tidx % kRowsPerLoadTranspose; + const int col = tidx / kRowsPerLoadTranspose; + if (row < total_splits && col < kBlockM) { + sLSE(row, col) = expf(lse_accum(l) - lse_logsum); + ElementAccum lse_scale = sLSE(row, col); +#if defined(DEBUG_LEAN_ATTENTION) + if (threadIdx.x == 0 && blockIdx.z == tracing_block) { + printf("ThreadIdx %d l %d row %d col %d lse_accum %f lse_logsum %f sLSE %f\n", threadIdx.x, l, row, col, lse_accum(l), lse_logsum, lse_scale); + } +#endif + } + } + + Tensor tOrO = make_tensor(shape(tOgOaccum)); + clear(tOrO); + + // Predicates + Tensor cOaccum = make_identity_tensor(Shape, Int>{}); + // Repeat the partitioning with identity layouts + Tensor tOcOaccum = gmem_thr_copy_Oaccum.partition_S(cOaccum); + Tensor tOpOaccum = make_tensor(make_shape(size<2>(tOgOaccum))); + if (!Is_even_K) { +#pragma unroll + for (int k = 0; k < size(tOpOaccum); ++k) { + tOpOaccum(k) = get<1>(tOcOaccum(0, 0, k)) < params.d; + } + } + + // Sync here for sLSE stores to go through + __syncthreads(); +// First reduce self Oaccum +#pragma unroll + for (int m = 0; m < size<1>(tOrOaccum); ++m) { + int row = get<0>(tOcOaccum(0, m, 0)); + ElementAccum lse_scale = sLSE(0, row); +#pragma unroll + for (int k = 0; k < size<2>(tOrOaccum); ++k) { +#pragma unroll + for (int i = 0; i < size<0>(tOrOaccum); ++i) { + tOrO(i, m, k) += lse_scale * tOrOaccum(i, m, k); +#if defined(DEBUG_LEAN_ATTENTION) + if (threadIdx.x == 0 && blockIdx.z == tracing_block) { + printf("ThreadIdx %d Split %d m %d Row %d k %d i %d LSE %f Oaccum %f O %f\n", threadIdx.x, 0, m, row, k, i, lse_scale, tOrOaccum(i, m, k), tOrO(i, m, k)); + } +#endif + } + } + } + + tOgOaccum.data() = tOgOaccum.data() + params.b * params.h * (index_t)params.seqlen_q * params.d_rounded; + +#if defined(DEBUG_LEAN_ATTENTION) + if (threadIdx.x == 0 && blockIdx.z == tracing_block) { + printf("After First Split t0g0accum.data() %p\n", tOgOaccum.data()); + } +#endif + // Load Oaccum in then scale and accumulate to O + // Here m is each row of 0accum along token dimension + // k is + for (int split = 1; split < total_splits; ++split) { + lean::copy( + gmem_tiled_copy_Oaccum, tOgOaccum, tOrOaccum, tOcOaccum, tOpOaccum, params.b * params.h * (index_t)params.seqlen_q - row_offset_lseaccum); +#pragma unroll + for (int m = 0; m < size<1>(tOrOaccum); ++m) { + int row = get<0>(tOcOaccum(0, m, 0)); + ElementAccum lse_scale = sLSE(split, row); +#pragma unroll + for (int k = 0; k < size<2>(tOrOaccum); ++k) { +#pragma unroll + for (int i = 0; i < size<0>(tOrOaccum); ++i) { + tOrO(i, m, k) += lse_scale * tOrOaccum(i, m, k); +#if defined(DEBUG_LEAN_ATTENTION) + if (threadIdx.x == 0 && blockIdx.z == tracing_block) { + printf("ThreadIdx %d Split %d m %d Row %d k %d i %d LSE %f Oaccum %f O %f\n", threadIdx.x, split, m, row, k, i, lse_scale, tOrOaccum(i, m, k), tOrO(i, m, k)); + } +#endif + } + } + } + tOgOaccum.data() = tOgOaccum.data() + params.b * params.h * (index_t)params.seqlen_q * params.d_rounded; + } + + Tensor r1 = lean::convert_type(tOrO); + +// Write to gO +#pragma unroll + for (int m = 0; m < size<1>(r1); ++m) { + const int idx = cur_m_block * kBlockM + get<0>(tOcOaccum(0, m, 0)); + if (idx < params.seqlen_q) { + // The index to the rows of Q + const int row = idx; + auto o_ptr = reinterpret_cast(params.o_ptr) + cur_bidb * params.o_batch_stride + cur_bidh * params.o_head_stride + row * params.o_row_stride; +#pragma unroll + for (int k = 0; k < size<2>(r1); ++k) { + if (Is_even_K || tOpOaccum(k)) { + const int col = get<1>(tOcOaccum(0, m, k)); + Tensor gO = make_tensor(make_gmem_ptr(o_ptr + col), + Shape(r1))::value>>{}, Stride<_1>{}); + copy(r1(_, m, k), gO); + } + } + } + } +#if defined(DEBUG_LEAN_ATTENTION) + epil3_duration += clock64() - epilogue_start; +#endif + } + + if (num_tiles_left) { + // We can probably do better than this + // We first decrement the pointers back to starting. + // We can probably just use q_ptr and k_ptr directly. But can't figure out how to do it. + // Without disturbing the gQ, gK, gV tensor pointer CUTE objects. + tQgQ.data() = tQgQ.data() + (-int(row_offset_q)); + tKgK.data() = tKgK.data() + (((num_tiles_in_block - n_tile_min - 1) * kBlockN) * params.k_row_stride - row_offset_k); + tVgV.data() = tVgV.data() + (((num_tiles_in_block - n_tile_min - 1) * kBlockN) * params.v_row_stride - row_offset_v); + cur_m_block = cur_m_block + 1 >= num_m_blocks_per_head ? 0 : cur_m_block + 1; + num_tiles_in_block = Is_causal ? (int)ceilf(block_scale * (cur_m_block + 1)) : num_tiles_per_head; + n_tile = num_tiles_left - 1 >= num_tiles_in_block ? num_tiles_in_block - 1 : num_tiles_left - 1; + n_tile_min = 0; + cur_bidb = start_tile_gid / (num_tiles_per_head * params.h); + cur_bidh = (start_tile_gid - (cur_bidb * num_tiles_per_head * params.h)) / num_tiles_per_head; + + row_offset_q = cur_bidb * params.q_batch_stride + + +cur_m_block * kBlockM * params.q_row_stride + cur_bidh * params.q_head_stride; + row_offset_k = cur_bidb * params.k_batch_stride + + +n_tile * kBlockN * params.k_row_stride + (cur_bidh / params.h_h_k_ratio) * params.k_head_stride; + row_offset_v = cur_bidb * params.v_batch_stride + + +n_tile * kBlockN * params.v_row_stride + (cur_bidh / params.h_h_k_ratio) * params.v_head_stride; + + tQgQ.data() = tQgQ.data() + row_offset_q; + tKgK.data() = tKgK.data() + row_offset_k; + tVgV.data() = tVgV.data() + row_offset_v; + +#if defined(DEBUG_LEAN_ATTENTION) + if (threadIdx.x == 0 && blockIdx.z == tracing_block) { + printf("#### Ready for next block:\n"); + printf("next_block %d\n", cur_m_block); + printf("n_tile %d\n", n_tile); + printf("row_offset_q = %" PRId64 "\n", row_offset_q); + printf("row_offset_k = %" PRId64 "\n", row_offset_k); + printf("row_offset_v = %" PRId64 "\n", row_offset_v); + } +#endif + } + +#if defined(DEBUG_LEAN_ATTENTION) + epilogue_duration += clock64() - epilogue_start; +#endif + } + +#if defined(DEBUG_LEAN_ATTENTION) + if (threadIdx.x == 0) { + uint smid; + asm("mov.u32 %0, %smid;" : "=r"(smid)); + printf("%d %d %d %d %lld %lld %lld %lld %lld %lld %lld %lld\n", + blockIdx.z, num_blocks, smid, cta_id, clock64() - kernel_start, prologue_duration, comp1_duration, + comp2_duration, epilogue_duration, epil1_duration, epil2_duration, epil3_duration); + } +#endif +} + +template +inline __device__ void lean_compute_attn(const Params& params) { + // const int cta_id = blockIdx.z < 54 ? 4*blockIdx.z : blockIdx.z < 108 ? 4*(blockIdx.z % 54) + 2 : blockIdx.z < 162 ? 4*(blockIdx.z % 108) + 1 : 4*(blockIdx.z % 162) + 3; + const int cta_id = blockIdx.z; + int start_tile_gid = cta_id < params.high_load_tbs ? params.max_tiles_per_tb * cta_id : (params.max_tiles_per_tb - 1) * cta_id + params.high_load_tbs; + int start_tile_hid = start_tile_gid % params.tiles_per_head; + int num_tiles = cta_id < params.high_load_tbs ? params.max_tiles_per_tb : params.max_tiles_per_tb - 1; + + lean::lean_compute_attn_impl_ver3(params, cta_id, start_tile_gid, start_tile_hid, num_tiles, params.tiles_per_head); +} + +} // namespace lean +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/bert/lean_attention/lean_fwd_launch_template.h b/onnxruntime/contrib_ops/cuda/bert/lean_attention/lean_fwd_launch_template.h new file mode 100644 index 0000000000000..fcccb54ebf4e8 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/bert/lean_attention/lean_fwd_launch_template.h @@ -0,0 +1,73 @@ +/****************************************************************************** + * Copyright (c) 2023, Tri Dao. + ******************************************************************************/ + +#pragma once + +#include "contrib_ops/cuda/bert/lean_attention/static_switch.h" +#include "contrib_ops/cuda/bert/lean_attention/flash.h" +#include "contrib_ops/cuda/bert/lean_attention/lean_fwd_kernel.h" + +namespace onnxruntime { +namespace lean { + +// Determine if the architecture supports FLASH and define a macro to handle parameter modifiers +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 +#define ARCH_SUPPORTS_FLASH +#define KERNEL_PARAM_MODIFIER __grid_constant__ +#else +#define KERNEL_PARAM_MODIFIER +#endif + +// Define a macro for unsupported architecture handling to centralize the error message +#define FLASH_UNSUPPORTED_ARCH printf("FATAL: FlashAttention requires building with sm version sm80-sm90, but was built for < 8.0!"); + +// Use a macro to clean up kernel definitions +#define DEFINE_FLASH_FORWARD_KERNEL(kernelName, ...) \ + template \ + __global__ void kernelName(KERNEL_PARAM_MODIFIER const Flash_fwd_params params) + +DEFINE_FLASH_FORWARD_KERNEL(lean_fwd_kernel, bool Is_causal, bool Is_even_MN, bool Is_even_K, int kMaxSplits, bool Append_KV) { +#if defined(ARCH_SUPPORTS_FLASH) + lean::lean_compute_attn(params); +#else + FLASH_UNSUPPORTED_ARCH +#endif +} + +template +void run_lean_fwd(Flash_fwd_params& params, cudaStream_t stream) { + static_assert(!Kernel_traits::Is_Q_in_regs, "SplitKV implementation does not support Is_Q_in_regs"); + static_assert(!Kernel_traits::Share_Q_K_smem, "SplitKV implementation does not support Share_Q_K_smem"); + constexpr size_t smem_size = Kernel_traits::kSmemSize; + dim3 grid(1, 1, params.lean_griddimz); + const bool is_even_MN = params.cu_seqlens_q == nullptr && params.cu_seqlens_k == nullptr && params.seqlen_k % Kernel_traits::kBlockN == 0 && params.seqlen_q % Kernel_traits::kBlockM == 0; + const bool is_even_K = params.d == Kernel_traits::kHeadDim; + BOOL_SWITCH(params.is_causal, Is_causal, [&] { + BOOL_SWITCH(is_even_MN, IsEvenMNConst, [&] { + EVENK_SWITCH(is_even_K, IsEvenKConst, [&] { + MAXSPLIT_SWITCH(params.num_splits, [&] { + BOOL_SWITCH(params.knew_ptr != nullptr, Append_KV_Const, [&] { + auto kernel = &lean_fwd_kernel < Kernel_traits, Is_causal, IsEvenMNConst && IsEvenKConst && Kernel_traits::kHeadDim <= 128, IsEvenKConst, kMaxSplits, Append_KV_Const > ; + if (2 * smem_size >= 48 * 1024) { + cudaFuncSetAttribute( + kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, 2 * smem_size); + } + kernel<<>>(params); + }); + }); + }); + }); + }); +} + +template +void run_mha_fwd_lean_dispatch(Flash_fwd_params& params, cudaStream_t stream) { + // This should be modified according to optimal lean tile size + constexpr static int kBlockM = Headdim <= 64 ? 64 : (Headdim <= 128 ? 64 : 64); + constexpr static int kBlockN = Headdim <= 64 ? 256 : (Headdim <= 128 ? 128 : 64); + run_lean_fwd>(params, stream); +} + +} // namespace lean +} // namespace onnxruntime \ No newline at end of file diff --git a/onnxruntime/contrib_ops/cuda/bert/lean_attention/mask.h b/onnxruntime/contrib_ops/cuda/bert/lean_attention/mask.h new file mode 100644 index 0000000000000..d63c80b012de6 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/bert/lean_attention/mask.h @@ -0,0 +1,209 @@ +/****************************************************************************** + * Copyright (c) 2024, Tri Dao. + ******************************************************************************/ + +#pragma once + +#include + +namespace onnxruntime { +namespace lean { + +using namespace cute; + +template +__forceinline__ __device__ void apply_mask(Tensor& tensor, const int max_seqlen_k, + const int col_idx_offset_ = 0) { + // tensor has shape (ncol=(2, MMA_M), nrow=(2, MMA_N)) + static_assert(Layout::rank == 2, "Only support 2D Tensor"); + const int lane_id = threadIdx.x % 32; + const int col_idx_offset = col_idx_offset_ + (lane_id % 4) * 2; +#pragma unroll + for (int nj = 0; nj < size<1, 1>(tensor); ++nj) { + const int col_idx_base = col_idx_offset + nj * 8; +#pragma unroll + for (int j = 0; j < size<1, 0>(tensor); ++j) { + const int col_idx = col_idx_base + j; + if (col_idx >= max_seqlen_k) { +// Without the "make_coord" we get wrong results +#pragma unroll + for (int mi = 0; mi < size<0>(tensor); ++mi) { + tensor(mi, make_coord(j, nj)) = -INFINITY; + } + } + } + } +} + +template +__forceinline__ __device__ void apply_mask_local(Tensor& tensor, const int col_idx_offset_, + const int max_seqlen_k, const int row_idx_offset, + const int max_seqlen_q, const int warp_row_stride, + const int window_size_left, const int window_size_right) { + // tensor has shape (ncol=(2, MMA_M), nrow=(2, MMA_N)) + static_assert(Layout::rank == 2, "Only support 2D Tensor"); + const int lane_id = threadIdx.x % 32; + const int col_idx_offset = col_idx_offset_ + (lane_id % 4) * 2; +#pragma unroll + for (int mi = 0; mi < size<0, 1>(tensor); ++mi) { + const int row_idx_base = row_idx_offset + mi * warp_row_stride; +#pragma unroll + for (int i = 0; i < size<0, 0>(tensor); ++i) { + const int row_idx = row_idx_base + i * 8; + const int col_idx_limit_left = std::max(0, row_idx + max_seqlen_k - max_seqlen_q - window_size_left); + const int col_idx_limit_right = std::min(max_seqlen_k, row_idx + 1 + max_seqlen_k - max_seqlen_q + window_size_right); +#pragma unroll + for (int nj = 0; nj < size<1, 1>(tensor); ++nj) { + const int col_idx_base = col_idx_offset + nj * 8; +#pragma unroll + for (int j = 0; j < size<1, 0>(tensor); ++j) { + const int col_idx = col_idx_base + j; + if (col_idx >= col_idx_limit_right || (HasWSLeft && col_idx < col_idx_limit_left)) { + tensor(make_coord(i, mi), make_coord(j, nj)) = -INFINITY; + } + } + } + // if (cute::thread0()) { + // printf("mi = %d, i = %d, row_idx = %d, max_seqlen_k = %d\n", mi, i, row_idx, max_seqlen_k); + // print(tensor(make_coord(i, mi), _)); + // // print(tensor(_, j + nj * size<1, 0>(tensor))); + // } + } + } +} + +template +__forceinline__ __device__ void apply_mask_causal(Tensor& tensor, const int col_idx_offset_, + const int max_seqlen_k, const int row_idx_offset, + const int max_seqlen_q, const int warp_row_stride) { + // Causal masking is equivalent to local masking with window_size_left = infinity and window_size_right = 0 + apply_mask_local(tensor, col_idx_offset_, max_seqlen_k, row_idx_offset, + max_seqlen_q, warp_row_stride, -1, 0); +} + +template +__forceinline__ __device__ void apply_mask_causal_w_idx( + Tensor& tensor, Tensor const& idx_rowcol, + const int col_idx_offset_, const int max_seqlen_k, const int row_idx_offset) { + // tensor has shape (ncol=(2, MMA_M), nrow=(2, MMA_N)) + static_assert(Layout0::rank == 2, "Only support 2D Tensor"); + static_assert(Layout1::rank == 2, "Only support 2D Tensor"); + CUTE_STATIC_ASSERT_V(size<0>(tensor) == size<0>(idx_rowcol)); + CUTE_STATIC_ASSERT_V(size<1>(tensor) == size<1>(idx_rowcol)); +#pragma unroll + for (int mi = 0; mi < size<0>(tensor); ++mi) { + const int col_idx_limit = std::min(max_seqlen_k, 1 + row_idx_offset + get<0>(idx_rowcol(mi, 0))); +#pragma unroll + for (int ni = 0; ni < size<1, 1>(tensor); ++ni) { + if (col_idx_offset_ + get<1>(idx_rowcol(0, ni)) >= col_idx_limit) { + tensor(mi, ni) = -INFINITY; + } + } + // if (cute::thread0()) { + // printf("ni = %d, j = %d, col_idx = %d, max_seqlen_k = %d\n", ni, j, col_idx, max_seqlen_k); + // print(tensor(_, make_coord(j, ni))); + // // print(tensor(_, j + ni * size<1, 0>(tensor))); + // } + } +} + +template +struct Mask { + const int max_seqlen_k, max_seqlen_q; + const int window_size_left, window_size_right; + const float alibi_slope; + + __forceinline__ __device__ Mask(const int max_seqlen_k, const int max_seqlen_q, + const int window_size_left, const int window_size_right, + const float alibi_slope = 0.f) + : max_seqlen_k(max_seqlen_k), max_seqlen_q(max_seqlen_q), window_size_left(window_size_left), window_size_right(window_size_right), alibi_slope(!Has_alibi ? 0.0 : alibi_slope) { + }; + + // Causal_mask: whether this particular iteration needs causal masking + template + __forceinline__ __device__ void apply_mask(Tensor& tensor_, + const int col_idx_offset_, + const int row_idx_offset, + const int warp_row_stride) { + static_assert(!(Causal_mask && Is_local), "Cannot be both causal and local"); + static_assert(Layout::rank == 3, "Only support 3D Tensor"); + static_assert(decltype(size<0>(tensor_))::value == 4, "First dimension must be 4"); + static constexpr bool Need_masking = Has_alibi || Causal_mask || Is_local || !Is_even_MN; + // if (cute::thread0()) { printf("Has_alibi = %d, Causal_mask=%d, Is_local=%d, Is_even_MN = %d, Need_masking = %d\n", Has_alibi, Causal_mask, Is_local, Is_even_MN, Need_masking); } + if constexpr (Need_masking) { + // Reshape tensor_ from (MMA=4, MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, MMA_N)) + Tensor tensor = make_tensor(tensor_.data(), lean::convert_layout_acc_rowcol(tensor_.layout())); + // Do we need both row and column indices, or just column incides? + static constexpr bool Col_idx_only = !(Has_alibi && !Is_causal) && !Is_local && !Causal_mask; + const int lane_id = threadIdx.x % 32; + const int col_idx_offset = col_idx_offset_ + (lane_id % 4) * 2; + if constexpr (Col_idx_only) { +#pragma unroll + for (int nj = 0; nj < size<1, 1>(tensor); ++nj) { + const int col_idx_base = col_idx_offset + nj * 8; +#pragma unroll + for (int j = 0; j < size<1, 0>(tensor); ++j) { + const int col_idx = col_idx_base + j; +#pragma unroll + for (int mi = 0; mi < size<0>(tensor); ++mi) { + // No causal, no local + if constexpr (Has_alibi) { + tensor(mi, make_coord(j, nj)) += alibi_slope * col_idx; + } + if constexpr (!Is_even_MN) { + if (col_idx >= max_seqlen_k) { + tensor(mi, make_coord(j, nj)) = -INFINITY; + } + } + } + } + } + } else { +#pragma unroll + for (int mi = 0; mi < size<0, 1>(tensor); ++mi) { + const int row_idx_base = row_idx_offset + mi * warp_row_stride; +#pragma unroll + for (int i = 0; i < size<0, 0>(tensor); ++i) { + const int row_idx = row_idx_base + i * 8; + const int col_idx_limit_left = std::max(0, row_idx + max_seqlen_k - max_seqlen_q - window_size_left); + const int col_idx_limit_right = std::min(max_seqlen_k, row_idx + 1 + max_seqlen_k - max_seqlen_q + window_size_right); +#pragma unroll + for (int nj = 0; nj < size<1, 1>(tensor); ++nj) { + const int col_idx_base = col_idx_offset + nj * 8; +#pragma unroll + for (int j = 0; j < size<1, 0>(tensor); ++j) { + const int col_idx = col_idx_base + j; + if constexpr (Has_alibi) { + if constexpr (Is_causal) { + tensor(make_coord(i, mi), make_coord(j, nj)) += alibi_slope * col_idx; + } else { + tensor(make_coord(i, mi), make_coord(j, nj)) -= alibi_slope * abs(row_idx + max_seqlen_k - max_seqlen_q - col_idx); + } + } + if constexpr (Causal_mask) { + if (col_idx >= col_idx_limit_right) { + tensor(make_coord(i, mi), make_coord(j, nj)) = -INFINITY; + } + } + if constexpr (Is_local) { + if (col_idx >= col_idx_limit_right || col_idx < col_idx_limit_left) { + tensor(make_coord(i, mi), make_coord(j, nj)) = -INFINITY; + } + } + if constexpr (!Causal_mask && !Is_local && !Is_even_MN) { + // Causal and Local already handles MN masking + if (col_idx >= max_seqlen_k) { + tensor(make_coord(i, mi), make_coord(j, nj)) = -INFINITY; + } + } + } + } + } + } + } + } + }; +}; + +} // namespace lean +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/bert/lean_attention/softmax.h b/onnxruntime/contrib_ops/cuda/bert/lean_attention/softmax.h new file mode 100644 index 0000000000000..ad66389848e6e --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/bert/lean_attention/softmax.h @@ -0,0 +1,196 @@ +/****************************************************************************** + * Copyright (c) 2024, Tri Dao. + ******************************************************************************/ + +#pragma once + +#include + +#include + +#include + +#include "contrib_ops/cuda/bert/lean_attention/utils.h" + +namespace onnxruntime { +namespace lean { + +using namespace cute; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +__device__ __forceinline__ void thread_reduce_(Tensor const& tensor, Tensor& summary, Operator& op) { + static_assert(Layout0::rank == 2, "Only support 2D Tensor"); + static_assert(Layout1::rank == 1, "Only support 1D Tensor"); + CUTE_STATIC_ASSERT_V(size<0>(summary) == size<0>(tensor)); +#pragma unroll + for (int mi = 0; mi < size<0>(tensor); mi++) { + summary(mi) = zero_init ? tensor(mi, 0) : op(summary(mi), tensor(mi, 0)); +#pragma unroll + for (int ni = 1; ni < size<1>(tensor); ni++) { + summary(mi) = op(summary(mi), tensor(mi, ni)); + } + } +} + +template +__device__ __forceinline__ void quad_allreduce_(Tensor& dst, Tensor& src, Operator& op) { + CUTE_STATIC_ASSERT_V(size(dst) == size(src)); +#pragma unroll + for (int i = 0; i < size(dst); i++) { + dst(i) = Allreduce<4>::run(src(i), op); + } +} + +template +__device__ __forceinline__ void reduce_(Tensor const& tensor, Tensor& summary, Operator& op) { + thread_reduce_(tensor, summary, op); + quad_allreduce_(summary, summary, op); +} + +template +__device__ __forceinline__ void reduce_max(Tensor const& tensor, Tensor& max) { + MaxOp max_op; + reduce_(tensor, max, max_op); +} + +template +__device__ __forceinline__ void reduce_sum(Tensor const& tensor, Tensor& sum) { + SumOp sum_op; + thread_reduce_(tensor, sum, sum_op); +} + +// Apply the exp to all the elements. +template +__forceinline__ __device__ void scale_apply_exp2(Tensor& tensor, Tensor const& max, const float scale) { + static_assert(Layout0::rank == 2, "Only support 2D Tensor"); + static_assert(Layout1::rank == 1, "Only support 1D Tensor"); + CUTE_STATIC_ASSERT_V(size<0>(max) == size<0>(tensor)); +#pragma unroll + for (int mi = 0; mi < size<0>(tensor); ++mi) { + // If max is -inf, then all elements must have been -inf (possibly due to masking). + // We don't want (-inf - (-inf)) since that would give NaN. + // If we don't have float around M_LOG2E the multiplication is done in fp64. + const float max_scaled = max(mi) == -INFINITY ? 0.f : max(mi) * (Scale_max ? scale : float(M_LOG2E)); +#pragma unroll + for (int ni = 0; ni < size<1>(tensor); ++ni) { +// Instead of computing exp(x - max), we compute exp2(x * log_2(e) - +// max * log_2(e)) This allows the compiler to use the ffma +// instruction instead of fadd and fmul separately. +// The following macro will disable the use of fma. +// See: https://github.com/pytorch/pytorch/issues/121558 for more details +// This macro is set in PyTorch and not FlashAttention +#ifdef UNFUSE_FMA + tensor(mi, ni) = exp2f(__fmul_rn(tensor(mi, ni), scale) - max_scaled); +#else + tensor(mi, ni) = exp2f(tensor(mi, ni) * scale - max_scaled); +#endif + } + } +} + +// Apply the exp to all the elements. +template +__forceinline__ __device__ void max_scale_exp2_sum(Tensor& tensor, Tensor& max, Tensor& sum, const float scale) { + static_assert(Layout0::rank == 2, "Only support 2D Tensor"); + static_assert(Layout1::rank == 1, "Only support 1D Tensor"); + CUTE_STATIC_ASSERT_V(size<0>(max) == size<0>(tensor)); +#pragma unroll + for (int mi = 0; mi < size<0>(tensor); ++mi) { + MaxOp max_op; + max(mi) = zero_init ? tensor(mi, 0) : max_op(max(mi), tensor(mi, 0)); +#pragma unroll + for (int ni = 1; ni < size<1>(tensor); ni++) { + max(mi) = max_op(max(mi), tensor(mi, ni)); + } + max(mi) = Allreduce<4>::run(max(mi), max_op); + // If max is -inf, then all elements must have been -inf (possibly due to masking). + // We don't want (-inf - (-inf)) since that would give NaN. + const float max_scaled = max(mi) == -INFINITY ? 0.f : max(mi) * scale; + sum(mi) = 0; +#pragma unroll + for (int ni = 0; ni < size<1>(tensor); ++ni) { + // Instead of computing exp(x - max), we compute exp2(x * log_2(e) - + // max * log_2(e)) This allows the compiler to use the ffma + // instruction instead of fadd and fmul separately. + tensor(mi, ni) = exp2f(tensor(mi, ni) * scale - max_scaled); + sum(mi) += tensor(mi, ni); + } + SumOp sum_op; + sum(mi) = Allreduce<4>::run(sum(mi), sum_op); + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Softmax { + using TensorT = decltype(make_tensor(Shape>{})); + TensorT row_max, row_sum; + + __forceinline__ __device__ Softmax() {}; + + template + __forceinline__ __device__ void softmax_rescale_o(Tensor0& acc_s, Tensor1& acc_o, float softmax_scale_log2) { + // Reshape acc_s from (MMA=4, MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, MMA_N)) + Tensor scores = make_tensor(acc_s.data(), lean::convert_layout_acc_rowcol(acc_s.layout())); + static_assert(decltype(size<0>(scores))::value == kNRows); + if (Is_first) { + lean::template reduce_max(scores, row_max); + lean::scale_apply_exp2(scores, row_max, softmax_scale_log2); + lean::reduce_sum(scores, row_sum); + } else { + Tensor scores_max_prev = make_fragment_like(row_max); + cute::copy(row_max, scores_max_prev); + lean::template reduce_max(scores, row_max); + // Reshape acc_o from (MMA=4, MMA_M, MMA_K) to (nrow=(2, MMA_M), ncol=(2, MMA_K)) + Tensor acc_o_rowcol = make_tensor(acc_o.data(), lean::convert_layout_acc_rowcol(acc_o.layout())); + static_assert(decltype(size<0>(acc_o_rowcol))::value == kNRows); +#pragma unroll + for (int mi = 0; mi < size(row_max); ++mi) { + float scores_max_cur = !Check_inf + ? row_max(mi) + : (row_max(mi) == -INFINITY ? 0.0f : row_max(mi)); + float scores_scale = exp2f((scores_max_prev(mi) - scores_max_cur) * softmax_scale_log2); + row_sum(mi) *= scores_scale; +#pragma unroll + for (int ni = 0; ni < size<1>(acc_o_rowcol); ++ni) { + acc_o_rowcol(mi, ni) *= scores_scale; + } + } + lean::scale_apply_exp2(scores, row_max, softmax_scale_log2); + // We don't do the reduce across threads here since we don't need to use the row_sum. + // We do that reduce at the end when we need to normalize the softmax. + lean::reduce_sum(scores, row_sum); + } + }; + + template + __forceinline__ __device__ TensorT normalize_softmax_lse(Tensor0& acc_o, float softmax_scale, float rp_dropout = 1.0) { + SumOp sum_op; + quad_allreduce_(row_sum, row_sum, sum_op); + TensorT lse = make_fragment_like(row_sum); + Tensor acc_o_rowcol = make_tensor(acc_o.data(), lean::convert_layout_acc_rowcol(acc_o.layout())); + static_assert(decltype(size<0>(acc_o_rowcol))::value == kNRows); +#pragma unroll + for (int mi = 0; mi < size<0>(acc_o_rowcol); ++mi) { + float sum = row_sum(mi); + float inv_sum = (sum == 0.f || sum != sum) ? 1.f : 1.f / sum; + // if (threadIdx.x == 0 && blockIdx.z == 0) { + // printf("sum: %f, inv_sum: %f\n", sum, inv_sum); + // printf("mi %d row_max %f softmax_scale %f\n", mi, row_max(mi), softmax_scale); + // } + lse(mi) = (sum == 0.f || sum != sum) ? (Split ? -INFINITY : INFINITY) : row_max(mi) * softmax_scale + __logf(sum); + float scale = !Is_dropout ? inv_sum : inv_sum * rp_dropout; +#pragma unroll + for (int ni = 0; ni < size<1>(acc_o_rowcol); ++ni) { + acc_o_rowcol(mi, ni) *= scale; + } + } + return lse; + }; +}; + +} // namespace lean +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/bert/lean_attention/static_switch.h b/onnxruntime/contrib_ops/cuda/bert/lean_attention/static_switch.h new file mode 100644 index 0000000000000..7873f67471d5d --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/bert/lean_attention/static_switch.h @@ -0,0 +1,109 @@ +// Inspired by +// https://github.com/NVIDIA/DALI/blob/main/include/dali/core/static_switch.h +// and https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/Dispatch.h + +#pragma once + +/// @param COND - a boolean expression to switch by +/// @param CONST_NAME - a name given for the constexpr bool variable. +/// @param ... - code to execute for true and false +/// +/// Usage: +/// ``` +/// BOOL_SWITCH(flag, BoolConst, [&] { +/// some_function(...); +/// }); +/// ``` + +#define BOOL_SWITCH(COND, CONST_NAME, ...) \ + [&] { \ + if (COND) { \ + constexpr static bool CONST_NAME = true; \ + return __VA_ARGS__(); \ + } else { \ + constexpr static bool CONST_NAME = false; \ + return __VA_ARGS__(); \ + } \ + }() + +#ifdef FLASHATTENTION_DISABLE_DROPOUT +#define DROPOUT_SWITCH(COND, CONST_NAME, ...) \ + [&] { \ + constexpr static bool CONST_NAME = false; \ + return __VA_ARGS__(); \ + }() +#else +#define DROPOUT_SWITCH BOOL_SWITCH +#endif + +#ifdef FLASHATTENTION_DISABLE_ALIBI +#define ALIBI_SWITCH(COND, CONST_NAME, ...) \ + [&] { \ + constexpr static bool CONST_NAME = false; \ + return __VA_ARGS__(); \ + }() +#else +#define ALIBI_SWITCH BOOL_SWITCH +#endif + +#ifdef FLASHATTENTION_DISABLE_UNEVEN_K +#define EVENK_SWITCH(COND, CONST_NAME, ...) \ + [&] { \ + constexpr static bool CONST_NAME = true; \ + return __VA_ARGS__(); \ + }() +#else +#define EVENK_SWITCH BOOL_SWITCH +#endif + +#ifdef FLASHATTENTION_DISABLE_LOCAL +#define LOCAL_SWITCH(COND, CONST_NAME, ...) \ + [&] { \ + constexpr static bool CONST_NAME = false; \ + return __VA_ARGS__(); \ + }() +#else +#define LOCAL_SWITCH BOOL_SWITCH +#endif + +#define FP16_SWITCH(COND, ...) \ + [&] { \ + if (COND) { \ + using elem_type = cutlass::half_t; \ + return __VA_ARGS__(); \ + } \ + }() + +#define HEADDIM_SWITCH(HEADDIM, ...) \ + [&] { \ + if (HEADDIM <= 64) { \ + constexpr static int kHeadDim = 64; \ + return __VA_ARGS__(); \ + } else if (HEADDIM <= 128) { \ + constexpr static int kHeadDim = 128; \ + return __VA_ARGS__(); \ + } \ + }() + +#define MAXSPLIT_SWITCH(MAXSPLITS, ...) \ + [&] { \ + if (MAXSPLITS <= 2) { \ + constexpr static int kMaxSplits = 2; \ + return __VA_ARGS__(); \ + } else if (MAXSPLITS <= 4) { \ + constexpr static int kMaxSplits = 4; \ + return __VA_ARGS__(); \ + } else if (MAXSPLITS <= 8) { \ + constexpr static int kMaxSplits = 8; \ + return __VA_ARGS__(); \ + } else if (MAXSPLITS <= 16) { \ + constexpr static int kMaxSplits = 16; \ + return __VA_ARGS__(); \ + } else if (MAXSPLITS <= 32) { \ + constexpr static int kMaxSplits = 32; \ + return __VA_ARGS__(); \ + } else if (MAXSPLITS <= 64) { \ + constexpr static int kMaxSplits = 64; \ + return __VA_ARGS__(); \ + } \ + }() diff --git a/onnxruntime/contrib_ops/cuda/bert/lean_attention/utils.h b/onnxruntime/contrib_ops/cuda/bert/lean_attention/utils.h new file mode 100644 index 0000000000000..c76849686d539 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/bert/lean_attention/utils.h @@ -0,0 +1,411 @@ +/****************************************************************************** + * Copyright (c) 2023, Tri Dao. + ******************************************************************************/ + +#pragma once + +#include +#include +#include + +#include + +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 +#include +#endif + +#include +#include + +#include +#include +#include +#include + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace onnxruntime { +namespace lean { + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +__forceinline__ __device__ uint32_t relu2(const uint32_t x); + +template <> +__forceinline__ __device__ uint32_t relu2(const uint32_t x) { + uint32_t res; + const uint32_t zero = 0u; +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 + asm volatile("max.f16x2 %0, %1, %2;\n" : "=r"(res) : "r"(x), "r"(zero)); +#else + asm volatile( + "{\n" + "\t .reg .f16x2 sela;\n" + "\t set.gtu.u32.f16x2 sela, %1, %2;\n" + "\t and.b32 %0, sela, %1;\n" + "}\n" : "=r"(res) : "r"(x), "r"(zero)); +#endif + return res; +} + +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 +template <> +__forceinline__ __device__ uint32_t relu2(const uint32_t x) { + uint32_t res; + const uint32_t zero = 0u; + asm volatile("max.bf16x2 %0, %1, %2;\n" : "=r"(res) : "r"(x), "r"(zero)); + return res; +} +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 + +template +__forceinline__ __device__ uint32_t convert_relu2(const float2 x); + +template <> +__forceinline__ __device__ uint32_t convert_relu2(const float2 x) { + uint32_t res; + const uint32_t a = reinterpret_cast(x.x); + const uint32_t b = reinterpret_cast(x.y); + asm volatile("cvt.rn.relu.f16x2.f32 %0, %1, %2;\n" : "=r"(res) : "r"(b), "r"(a)); + return res; +} + +template <> +__forceinline__ __device__ uint32_t convert_relu2(const float2 x) { + uint32_t res; + const uint32_t a = reinterpret_cast(x.x); + const uint32_t b = reinterpret_cast(x.y); + asm volatile("cvt.rn.relu.bf16x2.f32 %0, %1, %2;\n" : "=r"(res) : "r"(b), "r"(a)); + return res; +} + +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MaxOp { + __device__ __forceinline__ T operator()(T const& x, T const& y) { return x > y ? x : y; } +}; + +template <> +struct MaxOp { + // This is slightly faster + __device__ __forceinline__ float operator()(float const& x, float const& y) { return max(x, y); } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct SumOp { + __device__ __forceinline__ T operator()(T const& x, T const& y) { return x + y; } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Allreduce { + static_assert(THREADS == 32 || THREADS == 16 || THREADS == 8 || THREADS == 4); + template + static __device__ __forceinline__ T run(T x, Operator& op) { + constexpr int OFFSET = THREADS / 2; + x = op(x, __shfl_xor_sync(uint32_t(-1), x, OFFSET)); + return Allreduce::run(x, op); + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template <> +struct Allreduce<2> { + template + static __device__ __forceinline__ T run(T x, Operator& op) { + x = op(x, __shfl_xor_sync(uint32_t(-1), x, 1)); + return x; + } +}; + +template <> +struct Allreduce<1> { + template + static __device__ __forceinline__ T run(T x, Operator& op) { + return x; + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +__forceinline__ __device__ void gemm(Tensor0& acc, Tensor1& tCrA, Tensor2& tCrB, Tensor3 const& tCsA, + Tensor4 const& tCsB, TiledMma tiled_mma, + TiledCopyA smem_tiled_copy_A, TiledCopyB smem_tiled_copy_B, + ThrCopyA smem_thr_copy_A, ThrCopyB smem_thr_copy_B) { + CUTE_STATIC_ASSERT_V(size<1>(tCrA) == size<1>(acc)); // MMA_M + CUTE_STATIC_ASSERT_V(size<1>(tCrB) == size<2>(acc)); // MMA_N + CUTE_STATIC_ASSERT_V(size<2>(tCrA) == size<2>(tCrB)); // MMA_K + Tensor tCrA_copy_view = smem_thr_copy_A.retile_D(tCrA); + CUTE_STATIC_ASSERT_V(size<1>(tCsA) == size<1>(tCrA_copy_view)); // M + Tensor tCrB_copy_view = smem_thr_copy_B.retile_D(tCrB); + CUTE_STATIC_ASSERT_V(size<1>(tCsB) == size<1>(tCrB_copy_view)); // N + if (!A_in_regs) { + cute::copy(smem_tiled_copy_A, tCsA(_, _, _0{}), tCrA_copy_view(_, _, _0{})); + } + if (!B_in_regs) { + cute::copy(smem_tiled_copy_B, tCsB(_, _, _0{}), tCrB_copy_view(_, _, _0{})); + } +#pragma unroll + for (int i = 0; i < size<2>(tCrA); ++i) { + if (i < size<2>(tCrA) - 1) { + if (!A_in_regs) { + cute::copy(smem_tiled_copy_A, tCsA(_, _, i + 1), tCrA_copy_view(_, _, i + 1)); + } + if (!B_in_regs) { + cute::copy(smem_tiled_copy_B, tCsB(_, _, i + 1), tCrB_copy_view(_, _, i + 1)); + } + } + cute::gemm(tiled_mma, tCrA(_, _, i), tCrB(_, _, i), acc); + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +__forceinline__ __device__ void gemm_rs(Tensor0& acc, Tensor1& tCrA, Tensor2& tCrB, Tensor3 const& tCsB, + TiledMma tiled_mma, TiledCopy smem_tiled_copy_B, + ThrCopy smem_thr_copy_B) { + CUTE_STATIC_ASSERT_V(size<1>(tCrA) == size<1>(acc)); // MMA_M + CUTE_STATIC_ASSERT_V(size<1>(tCrB) == size<2>(acc)); // MMA_N + CUTE_STATIC_ASSERT_V(size<2>(tCrA) == size<2>(tCrB)); // MMA_K + Tensor tCrB_copy_view = smem_thr_copy_B.retile_D(tCrB); + CUTE_STATIC_ASSERT_V(size<1>(tCsB) == size<1>(tCrB_copy_view)); // N + cute::copy(smem_tiled_copy_B, tCsB(_, _, _0{}), tCrB_copy_view(_, _, _0{})); +#pragma unroll + for (int i = 0; i < size<2>(tCrA); ++i) { + if (i < size<2>(tCrA) - 1) { + cute::copy(smem_tiled_copy_B, tCsB(_, _, i + 1), tCrB_copy_view(_, _, i + 1)); + } + cute::gemm(tiled_mma, tCrA(_, _, i), tCrB(_, _, i), acc); + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// Convert acc_layout from (MMA=4, MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, MMA_N)) +template +__forceinline__ __device__ auto convert_layout_acc_rowcol(Layout acc_layout) { + static_assert(decltype(size<0>(acc_layout))::value == 4); + static_assert(decltype(rank(acc_layout))::value == 3); + auto l = logical_divide(acc_layout, Shape<_2>{}); // ((2, 2), MMA_M, MMA_N) + return make_layout(make_layout(get<0, 1>(l), get<1>(l)), make_layout(get<0, 0>(l), get<2>(l))); +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// Convert acc_layout from (MMA=4, MMA_M, MMA_N) to ((4, 2), MMA_M, MMA_N / 2) +// if using m16n8k16, or to (4, MMA_M, MMA_N) if using m16n8k8. +template +__forceinline__ __device__ auto convert_layout_acc_Aregs(Layout acc_layout) { + using X = Underscore; + static_assert(decltype(size<0>(acc_layout))::value == 4); + static_assert(decltype(rank(acc_layout))::value == 3); + constexpr int mma_shape_K = get<2>(typename MMA_traits::Shape_MNK{}); + static_assert(mma_shape_K == 8 || mma_shape_K == 16); + if constexpr (mma_shape_K == 8) { + return acc_layout; + } else { + auto l = logical_divide(acc_layout, Shape{}); // (4, MMA_M, (2, MMA_N / 2))) + return make_layout(make_layout(get<0>(l), get<2, 0>(l)), get<1>(l), get<2, 1>(l)); + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// Convert acc_layout from (MMA=4, MMA_M, MMA_N) to ((4, 2), MMA_M, MMA_N / 2) +template +__forceinline__ __device__ auto convert_layout_acc_dropout(Layout acc_layout) { + using X = Underscore; + static_assert(decltype(size<0>(acc_layout))::value == 4); + static_assert(decltype(rank(acc_layout))::value == 3); + auto l = logical_divide(acc_layout, Shape{}); // (4, MMA_M, (2, MMA_N / 2))) + return make_layout(make_layout(get<0>(l), get<2, 0>(l)), get<1>(l), get<2, 1>(l)); +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +__forceinline__ __device__ auto convert_type(Tensor const& tensor) { + using From_type = typename Engine::value_type; + constexpr int numel = decltype(size(tensor))::value; + cutlass::NumericArrayConverter convert_op; + // HACK: this requires tensor to be "contiguous" + auto frag = convert_op(*reinterpret_cast*>(tensor.data())); + return make_tensor(make_rmem_ptr(&frag), tensor.layout()); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +__forceinline__ __device__ void relu_(Tensor& tensor) { + constexpr int numel = decltype(size(tensor))::value; + static_assert(numel % 2 == 0); + using value_t = typename Engine::value_type; + // HACK: this requires tensor to be "contiguous" + Tensor tensor_uint32 = recast(tensor); +#pragma unroll + for (int i = 0; i < size(tensor_uint32); ++i) { + tensor_uint32(i) = relu2(tensor_uint32(i)); + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// On SM80 and above, we can fuse fp32 -> fp16/bf16 conversion and relu into 1 instruction +template +__forceinline__ __device__ auto convert_type_relu(Tensor const& tensor) { + using From_type = typename Engine::value_type; + static_assert(std::is_same_v || std::is_same_v); + static_assert(std::is_same_v); + constexpr int numel = decltype(size(tensor))::value; + static_assert(numel % 2 == 0); +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 + // HACK: this requires tensor to be "contiguous" + Tensor tensor_float2 = recast(tensor); + Tensor out_uint32 = make_tensor(tensor_float2.layout()); +#pragma unroll + for (int i = 0; i < size(out_uint32); ++i) { + out_uint32(i) = convert_relu2(tensor_float2(i)); + } + Tensor out = make_tensor(make_rmem_ptr(out_uint32.data()), tensor.layout()); +#else + Tensor out = lean::convert_type(tensor); + lean::relu_(out); +#endif + return out; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// Blocks until all but N previous cp.async.commit_group operations have committed. +// This differs from cute::cp_async_wait in that when N = 0 we don't call cp.async.wait_all +// (which is equivalent to commit_group then wait_group 0). +// Instead we just call cp.async.wait_group 0, which is slightly faster. +// https://github.com/NVIDIA/cutlass/blob/master/include/cute/arch/copy_sm80.hpp#L113 +template +CUTE_HOST_DEVICE void cp_async_wait() { +#if defined(CUTE_ARCH_CP_ASYNC_SM80_ENABLED) + asm volatile("cp.async.wait_group %0;\n" ::"n"(N)); +#endif +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +__forceinline__ __device__ void copy(TiledCopy tiled_copy, Tensor const& S, + Tensor& D, Tensor const& identity_MN, + Tensor const& predicate_K, const int max_MN = 0) { + CUTE_STATIC_ASSERT_V(rank(S) == Int<3>{}); + CUTE_STATIC_ASSERT_V(rank(D) == Int<3>{}); + CUTE_STATIC_ASSERT_V(size<0>(S) == size<0>(D)); // MMA + CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(D)); // MMA_M + CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(D)); // MMA_K + // There's no case where !Clear_OOB_K && Clear_OOB_MN + static_assert(!(Clear_OOB_MN && !Clear_OOB_K)); +#pragma unroll + for (int m = 0; m < size<1>(S); ++m) { + if (Is_even_MN || get<0>(identity_MN(0, m, 0)) < max_MN) { +#pragma unroll + for (int k = 0; k < size<2>(S); ++k) { + if (Is_even_K || predicate_K(k)) { + cute::copy(tiled_copy, S(_, m, k), D(_, m, k)); + } else if (Clear_OOB_K) { + cute::clear(D(_, m, k)); + } + } + } else if (Clear_OOB_MN) { + cute::clear(D(_, m, _)); + } + } + // TD [2023-04-13]: Strange that the code below can cause race condition. + // I think it's because the copies are under an if statement. + // if (Is_even_K) { + // #pragma unroll + // for (int m = 0; m < size<1>(S); ++m) { + // if (Is_even_MN || get<0>(identity_MN(0, m, 0)) < max_MN) { + // copy(tiled_copy, S(_, m, _), D(_, m, _)); + // } else if (Clear_OOB_MN) { + // clear(D(_, m, _)); + // } + // } + // } else { // It's slightly faster in this case if iterate over K first + // #pragma unroll + // for (int k = 0; k < size<2>(S); ++k) { + // if (predicate_K(k)) { + // #pragma unroll + // for (int m = 0; m < size<1>(S); ++m) { + // if (Is_even_MN || get<0>(identity_MN(0, m, 0)) < max_MN) { + // copy(tiled_copy, S(_, m, k), D(_, m, k)); + // } else if (Clear_OOB_MN) { + // clear(D(_, m, k)); + // } + // } + // } else if (Clear_OOB_K) { // There's no case where !Clear_OOB_K && Clear_OOB_MN + // if (Clear_OOB_MN || Is_even_MN) { + // clear(D(_, _, k)); + // } else { + // #pragma unroll + // for (int m = 0; m < size<1>(S); ++m) { + // if (!(Is_even_MN || get<0>(identity_MN(0, m, 0)) < max_MN)) { + // clear(D(_, m, k)); + // } + // } + // } + // } + // } + // } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +__forceinline__ __device__ void copy_w_min_idx(Tensor const& S, + Tensor& D, Tensor const& identity_MN, + Tensor const& predicate_K, + const int max_MN = 0, const int min_MN = 0) { + CUTE_STATIC_ASSERT_V(rank(S) == Int<3>{}); + CUTE_STATIC_ASSERT_V(rank(D) == Int<3>{}); + CUTE_STATIC_ASSERT_V(size<0>(S) == size<0>(D)); // MMA + CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(D)); // MMA_M + CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(D)); // MMA_K +// if (threadIdx.x == 0 && blockIdx.z == 0) { printf("blockIdx.y = %d, max_MN = %d, min_MN = %d\n", blockIdx.y, max_MN, min_MN); } +#pragma unroll + for (int m = 0; m < size<1>(S); ++m) { + // if (threadIdx.x == 0 && blockIdx.z == 0) { printf("blockIdx.y = %d, m = %d\n", blockIdx.y, get<0>(identity_MN(0, m, 0))); } + if (get<0>(identity_MN(0, m, 0)) >= min_MN && get<0>(identity_MN(0, m, 0)) < max_MN) { +// if (threadIdx.x == 0 && blockIdx.z == 0) { printf("Inner loop, blockIdx.y = %d, m = %d\n", blockIdx.y, get<0>(identity_MN(0, m, 0))); } +#pragma unroll + for (int k = 0; k < size<2>(S); ++k) { + if (Is_even_K || predicate_K(k)) { + cute::copy(S(_, m, k), D(_, m, k)); + } + } + } + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace lean +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/bert/multihead_attention.cc b/onnxruntime/contrib_ops/cuda/bert/multihead_attention.cc index 9c558900d1fdb..e2587d172af94 100644 --- a/onnxruntime/contrib_ops/cuda/bert/multihead_attention.cc +++ b/onnxruntime/contrib_ops/cuda/bert/multihead_attention.cc @@ -9,6 +9,7 @@ #include "contrib_ops/cuda/bert/cudnn_fmha/cudnn_flash_attention.h" #include "contrib_ops/cuda/bert/flash_attention/flash_api.h" #include "contrib_ops/cuda/utils/dump_cuda_tensor.h" +#include "contrib_ops/cuda/bert/lean_attention/lean_api.h" using namespace onnxruntime::cuda; using namespace ::onnxruntime::common; @@ -54,6 +55,10 @@ MultiHeadAttention::MultiHeadAttention(const OpKernelInfo& info) disable_flash_attention_ = sizeof(T) != 2 || !kernel_options_->UseFlashAttention(); +#if USE_LEAN_ATTENTION + enable_lean_attention_ = sizeof(T) == 2 && kernel_options_->UseLeanAttention(); +#endif + disable_memory_efficient_attention_ = !kernel_options_->UseEfficientAttention(); disable_fused_cross_attention_ = sizeof(T) != 2 || !kernel_options_->UseTrtCrossAttention(); @@ -151,8 +156,64 @@ Status MultiHeadAttention::ComputeInternal(OpKernelContext* context) const { AttentionKernelType kernel_type = AttentionKernelType::AttentionKernel_Default; + typedef typename ToCudaType::MappedType CudaT; + AttentionData data; + +#if USE_LEAN_ATTENTION || USE_FLASH_ATTENTION + size_t softmax_lse_bytes = 0; + size_t softmax_lse_accum_bytes = 0; + size_t out_accum_bytes = 0; +#endif + +#if USE_LEAN_ATTENTION + // Lean attention only supports token-generation phase with sequence_length == 1. + bool use_lean_attention = enable_lean_attention_ && + parameters.sequence_length == 1 && + parameters.past_sequence_length > 0 && + nullptr == attention_bias && + nullptr == key_padding_mask && + parameters.head_size == parameters.v_head_size && + onnxruntime::lean::is_supported(device_prop, + parameters.head_size, + parameters.num_heads, + parameters.num_heads); + + size_t sync_flag_bytes = 0; + if (use_lean_attention) { + softmax_lse_bytes = onnxruntime::lean::get_softmax_lse_size(parameters.sequence_length, + parameters.batch_size, + parameters.num_heads); + + auto [num_splits, slse_accum_bytes, o_accum_bytes, sflag_bytes, griddimz, max_tiles_tb, hload_tbs, tiles_per_head] = onnxruntime::lean::get_num_splits_and_buffer_sizes( + parameters.batch_size, + parameters.sequence_length, + parameters.total_sequence_length, + parameters.num_heads, // q heads + parameters.num_heads, // kv heads + parameters.head_size, + device_prop.multiProcessorCount, + parameters.is_unidirectional); + + data.num_splits = static_cast(num_splits); + data.grid_dim_z = static_cast(griddimz); + data.max_tiles_per_tb = static_cast(max_tiles_tb); + data.high_load_tbs = static_cast(hload_tbs); + data.tiles_per_head = static_cast(tiles_per_head); + softmax_lse_accum_bytes = slse_accum_bytes; + out_accum_bytes = o_accum_bytes; + sync_flag_bytes = sflag_bytes; + kernel_type = AttentionKernelType::AttentionKernel_LeanAttention; + } + + auto lean_sync_flag_buffer = GetScratchBuffer(sync_flag_bytes, context->GetComputeStream()); + data.lean_sync_flag = reinterpret_cast(lean_sync_flag_buffer.get()); +#else + constexpr bool use_lean_attention = false; +#endif + #if USE_FLASH_ATTENTION - bool use_flash_attention = !disable_flash_attention_ && + bool use_flash_attention = kernel_type == AttentionKernelType::AttentionKernel_Default && + !disable_flash_attention_ && nullptr == attention_bias && nullptr == key_padding_mask && parameters.head_size == parameters.v_head_size && @@ -165,25 +226,35 @@ Status MultiHeadAttention::ComputeInternal(OpKernelContext* context) const { parameters.sequence_length < kernel_options_->MinSeqLenForFlashAttentionPackedQkv()) { use_flash_attention = false; } + // Allocate buffers - size_t softmax_lse_accum_bytes = 0; - size_t out_accum_bytes = 0; if (use_flash_attention) { + softmax_lse_bytes = onnxruntime::flash::get_softmax_lse_size(parameters.sequence_length, + parameters.batch_size, + parameters.num_heads); + using namespace std; auto [num_splits, slse_accum_bytes, o_accum_bytes] = onnxruntime::flash::get_num_splits_and_buffer_sizes( parameters.batch_size, parameters.sequence_length, parameters.total_sequence_length, parameters.num_heads, parameters.head_size, device_prop.multiProcessorCount); - parameters.num_splits = static_cast(num_splits); + data.num_splits = static_cast(num_splits); softmax_lse_accum_bytes = slse_accum_bytes; out_accum_bytes = o_accum_bytes; kernel_type = AttentionKernelType::AttentionKernel_FlashAttention; } - auto softmax_lse_accum_buffer = GetScratchBuffer(softmax_lse_accum_bytes, context->GetComputeStream()); - auto out_accum_buffer = GetScratchBuffer(out_accum_bytes, context->GetComputeStream()); #else constexpr bool use_flash_attention = false; - auto softmax_lse_accum_buffer = GetScratchBuffer(0, context->GetComputeStream()); // nullptr - auto out_accum_buffer = GetScratchBuffer(0, context->GetComputeStream()); // nullptr +#endif + +#if USE_LEAN_ATTENTION || USE_FLASH_ATTENTION + auto softmax_lse_buffer = GetScratchBuffer(softmax_lse_bytes, context->GetComputeStream()); + auto softmax_lse_accum_buffer = GetScratchBuffer(softmax_lse_accum_bytes, context->GetComputeStream()); + auto out_accum_buffer = GetScratchBuffer(out_accum_bytes, context->GetComputeStream()); + if (use_flash_attention || use_lean_attention) { + data.softmax_lse = reinterpret_cast(softmax_lse_buffer.get()); + data.softmax_lse_accum = reinterpret_cast(softmax_lse_accum_buffer.get()); + data.out_accum = reinterpret_cast(out_accum_buffer.get()); + } #endif bool is_mask_none_or_1d_k_len = parameters.mask_type == AttentionMaskType::MASK_NONE || @@ -284,8 +355,6 @@ Status MultiHeadAttention::ComputeInternal(OpKernelContext* context) const { kernel_type = AttentionKernelType::AttentionKernel_Unfused; } - typedef typename ToCudaType::MappedType CudaT; - AttentionData data; data.bias = (nullptr == bias) ? nullptr : reinterpret_cast(bias->Data()); data.query = reinterpret_cast(query->Data()); data.key = (nullptr == key) ? nullptr : reinterpret_cast(key->Data()); @@ -303,6 +372,7 @@ Status MultiHeadAttention::ComputeInternal(OpKernelContext* context) const { data.fused_runner = reinterpret_cast(fused_runner); data.fused_cross_attention_kernel = fused_cross_attention_kernel; data.use_flash_attention = use_flash_attention; + data.use_lean_attention = use_lean_attention; data.use_memory_efficient_attention = use_memory_efficient_attention; data.kernel_type = kernel_type; data.allocator = Info().GetAllocator(OrtMemType::OrtMemTypeDefault); @@ -331,6 +401,7 @@ Status MultiHeadAttention::ComputeInternal(OpKernelContext* context) const { parameters.total_sequence_length, fused_runner, use_flash_attention, + use_lean_attention, use_fused_cross_attention, use_memory_efficient_attention, use_cudnn_sdpa, @@ -342,16 +413,11 @@ Status MultiHeadAttention::ComputeInternal(OpKernelContext* context) const { data.workspace_bytes = workspace_bytes; data.allow_debug_info = kernel_options_->AllowDebugInfo(); - if (softmax_lse_accum_buffer != nullptr) { - data.softmax_lse_accum = reinterpret_cast(softmax_lse_accum_buffer.get()); - } - if (out_accum_buffer != nullptr) { - data.out_accum = reinterpret_cast(out_accum_buffer.get()); - } if (data.allow_debug_info) { AttentionKernelDebugInfo debug_info; debug_info.use_flash_attention = use_flash_attention; + debug_info.use_lean_attention = use_lean_attention; debug_info.use_cudnn_flash_attention = use_cudnn_sdpa; debug_info.use_trt_cross_attention = fused_cross_attention_kernel != nullptr; debug_info.use_efficient_attention = use_memory_efficient_attention; diff --git a/onnxruntime/contrib_ops/cuda/bert/multihead_attention.h b/onnxruntime/contrib_ops/cuda/bert/multihead_attention.h index 8edc1d0e6ac06..b093b226c50b0 100644 --- a/onnxruntime/contrib_ops/cuda/bert/multihead_attention.h +++ b/onnxruntime/contrib_ops/cuda/bert/multihead_attention.h @@ -32,6 +32,9 @@ class MultiHeadAttention final : public CudaKernel { bool enable_trt_flash_attention_; bool disable_fused_cross_attention_; bool disable_flash_attention_; +#if USE_LEAN_ATTENTION + bool enable_lean_attention_; +#endif bool disable_memory_efficient_attention_; bool enable_cudnn_flash_attention_; diff --git a/onnxruntime/contrib_ops/cuda/quantization/attention_quantization.cc b/onnxruntime/contrib_ops/cuda/quantization/attention_quantization.cc index 1b774b163888f..33cd906508bcf 100644 --- a/onnxruntime/contrib_ops/cuda/quantization/attention_quantization.cc +++ b/onnxruntime/contrib_ops/cuda/quantization/attention_quantization.cc @@ -179,6 +179,7 @@ Status QAttention::ComputeInternal(OpKernelContext* context) const { constexpr bool use_fused_cross_attention = false; constexpr bool use_memory_efficient_attention = false; constexpr bool use_flash_attention = false; + constexpr bool use_lean_attention = false; constexpr bool use_cudnn_flash_attention = false; size_t workSpaceSize = GetAttentionWorkspaceSize(element_size, batch_size, @@ -190,6 +191,7 @@ Status QAttention::ComputeInternal(OpKernelContext* context) const { parameters.total_sequence_length, fused_runner, use_flash_attention, + use_lean_attention, use_fused_cross_attention, use_memory_efficient_attention, use_cudnn_flash_attention, diff --git a/onnxruntime/test/python/transformers/benchmark_mha.py b/onnxruntime/test/python/transformers/benchmark_mha.py index d8acb66158ed2..d922f153b4b91 100644 --- a/onnxruntime/test/python/transformers/benchmark_mha.py +++ b/onnxruntime/test/python/transformers/benchmark_mha.py @@ -72,6 +72,7 @@ class SdpaKernel(IntEnum): TRT_FLASH_ATTENTION = 32 TRT_CROSS_ATTENTION = 64 TRT_CAUSAL_ATTENTION = 128 + LEAN_ATTENTION = 256 # Since we support attention bias, so we only need support up to 2D mask. @@ -598,8 +599,8 @@ def measure_latency(cuda_session: CudaSession, input_dict): return end - start -def flops(batch, sequence_length, head_size, num_heads, causal): - return 4 * batch * sequence_length**2 * num_heads * head_size // (2 if causal else 1) +def flops(batch, sequence_length_q, sequence_length_kv, head_size, num_heads, causal): + return 4 * batch * sequence_length_q * sequence_length_kv * num_heads * head_size // (2 if causal else 1) def tflops_per_second(flop, time): @@ -613,6 +614,7 @@ def get_gpu_kernel_name(attention_kernel: SdpaKernel) -> str: kernel_names = { SdpaKernel.DEFAULT: "ort:default", SdpaKernel.FLASH_ATTENTION: "ort:flash", + SdpaKernel.LEAN_ATTENTION: "ort:lean", SdpaKernel.EFFICIENT_ATTENTION: "ort:efficient", SdpaKernel.CUDNN_FLASH_ATTENTION: "ort:cudnn", SdpaKernel.MATH: "ort:math", @@ -808,16 +810,17 @@ def sdpa_kernel_from_debug_info( ): os.environ["ORT_ENABLE_ATTENTION_KERNEL_DEBUG_INFO"] = "1" captured_text = None + try: with CaptureStdout() as captured: session = create_session(config, sess_options, attention_kernel=attention_kernel) input_dict = config.random_inputs() session.infer(input_dict) - captured_text = captured.output.decode() + captured_text = captured.output.decode() except Exception as e: print(f"Failed to run {attention_kernel=} for {config=}. Exception: {e}") - finally: - os.environ["ORT_ENABLE_ATTENTION_KERNEL_DEBUG_INFO"] = "0" + + os.environ["ORT_ENABLE_ATTENTION_KERNEL_DEBUG_INFO"] = "0" if captured_text is not None: m = re.search("SdpaKernel=(?P[A-Z_]+)", captured_text) @@ -825,6 +828,7 @@ def sdpa_kernel_from_debug_info( name = m.group("kernel") kernel_names = { "FLASH_ATTENTION": "ort:flash", + "LEAN_ATTENTION": "ort:lean", "EFFICIENT_ATTENTION": "ort:efficient", "CUDNN_FLASH_ATTENTION": "ort:cudnn", "MATH": "ort:math", @@ -867,6 +871,15 @@ def run_tflops_test( SdpaKernel.CUDNN_FLASH_ATTENTION, SdpaKernel.MATH, ] + + if args.past_sequence_length > 0: + backends.append(SdpaKernel.LEAN_ATTENTION) + + if args.past_sequence_length > 0 and causal: + backends.remove(SdpaKernel.CUDNN_FLASH_ATTENTION) + + if args.past_sequence_length > 4096: + backends.remove(SdpaKernel.MATH) else: backends = [SdpaKernel.DEFAULT, SdpaKernel.EFFICIENT_ATTENTION, SdpaKernel.MATH] else: @@ -884,6 +897,8 @@ def run_tflops_test( for input_format in formats: for batch_size, sequence_length, past_sequence_length, num_heads, head_size, enable_unfused in configs: + if past_sequence_length > 0 and input_format not in [InputFormats.Q_K_V_BSNH_BSNH_BSNH]: + continue config = MultiHeadAttentionConfig( batch_size=batch_size, sequence_length=sequence_length, @@ -900,6 +915,7 @@ def run_tflops_test( dtype=torch.float16 if use_gpu else torch.float, share_past_present_buffer=False, input_format=input_format, + has_past_input=past_sequence_length > 0, has_attn_bias=args.has_attn_bias, broadcast_attn_bias_dim_0=args.broadcast_attn_bias_dim_0, broadcast_attn_bias_dim_1=args.broadcast_attn_bias_dim_1, @@ -926,11 +942,19 @@ def run_tflops_test( print(f"skip input_format for {vars(config)}") continue + if use_gpu and config.total_sequence_length > 8192: + if config.verbose: + print(f"skip large sequence length for {vars(config)}") + continue + if use_gpu: actual_kernel = sdpa_kernel_from_debug_info(config, attention_kernel, sess_options) if actual_kernel is None: print(f"Warning: skip {config} since kernel from debug info is None") continue + if actual_kernel != request_kernel and request_kernel != "ort:default": + print(f"Skip since {actual_kernel=} != {request_kernel=}") + continue else: # CPU has no debug info for now. actual_kernel = request_kernel @@ -956,11 +980,17 @@ def run_tflops_test( format_str = InputFormats.input_format_str(input_format) # compute TFLOPS per second - speed = None - if past_sequence_length == 0: - speed = tflops_per_second( - flops(batch_size, sequence_length, head_size, num_heads, causal), average_latency - ) + speed = tflops_per_second( + flops( + batch_size, + sequence_length, + sequence_length + past_sequence_length, + head_size, + num_heads, + causal, + ), + average_latency, + ) row = { "use_gpu": use_gpu, @@ -983,11 +1013,11 @@ def run_tflops_test( } csv_writer.writerow(row) - speed = f"{speed:.2f}" if speed is not None else "NA" + speed = f"{speed:.3f}" if speed is not None else "NA" print( f"{format_str}\t{causal}\t{args.has_attn_bias}\t{batch_size}\t" f"{sequence_length}\t{past_sequence_length}\t{num_heads}\t{head_size}\t" - f"{intra_op_num_threads}\t{average_latency * 1000:.2f}\t{speed}\t{actual_kernel}\t{request_kernel}" + f"{intra_op_num_threads}\t{average_latency * 1000:.3f}\t{speed}\t{actual_kernel}\t{request_kernel}" ) @@ -1055,7 +1085,17 @@ def run_torch_test( except RuntimeError: continue - speed = tflops_per_second(flops(batch_size, sequence_length, head_size, num_heads, causal), torch_latency) + speed = tflops_per_second( + flops( + batch_size, + sequence_length, + sequence_length + past_sequence_length, + head_size, + num_heads, + causal, + ), + torch_latency, + ) input_format = "Q,K,V" print( f"{input_format}\t{causal}\t{False}\t{batch_size}\t" @@ -1090,7 +1130,8 @@ def run_tflops_tests(args): features += "_causal" if args.past_sequence_length > 0: features += "_past" - csv_filename = "benchmark_mha_{}_{}_{}.csv".format( + csv_filename = "{}_{}_{}_{}.csv".format( + args.csv_filename_prefix, features, "torch" if args.torch else "ort", datetime.now().strftime("%Y%m%d-%H%M%S"), @@ -1343,6 +1384,14 @@ def _parse_arguments(): ) parser.set_defaults(broadcast_attn_bias_dim_1=False) + parser.add_argument( + "--csv_filename_prefix", + required=False, + type=str, + default="benchmark_mha", + help="Prefix of csv filename", + ) + args = parser.parse_args() return args diff --git a/onnxruntime/test/python/transformers/benchmark_mha.sh b/onnxruntime/test/python/transformers/benchmark_mha.sh index ff6dd16e698df..8d811219d4dac 100644 --- a/onnxruntime/test/python/transformers/benchmark_mha.sh +++ b/onnxruntime/test/python/transformers/benchmark_mha.sh @@ -5,45 +5,104 @@ # Licensed under the MIT License. # -------------------------------------------------------------------------- -echo "Benchmark Scaled Dot Product Attention (SDPA) performance on GPU:" +# Usage: benchmark_mha.sh [gpu|cpu|lean] +task="${1:-gpu}" -export CUDA_VISIBLE_DEVICES=0 -python benchmark_mha.py --use_gpu +# Function to lock GPU clocks and set power limit for a GPU +configure_gpu() { + local gpu_id=$1 -echo "Benchmark BERT-Large performance on GPU without attention bias" -python benchmark_mha.py --use_gpu -b 16 + # Ensure nvidia-smi is available + if ! command -v nvidia-smi &> /dev/null + then + echo "nvidia-smi not found. Please ensure NVIDIA drivers are installed." + exit + fi -echo "Benchmark BERT-Large performance on GPU with attention bias" -python benchmark_mha.py --use_gpu -b 16 -r 1000 --has_attn_bias -python benchmark_mha.py --use_gpu -b 16 -r 1000 --has_attn_bias --broadcast_attn_bias_dim_0 -python benchmark_mha.py --use_gpu -b 16 -r 1000 --has_attn_bias --broadcast_attn_bias_dim_0 --broadcast_attn_bias_dim_1 + # Enable Persistence Mode + sudo nvidia-smi -pm 1 -i $gpu_id -python benchmark_mha.py --use_gpu --use_cuda_graph -python benchmark_mha.py --use_gpu --torch + # Get the maximum clock speeds for graphics and memory. + nvidia-smi -q -d CLOCK -i ${gpu_id} | grep -A3 "Max Clocks" + max_graphics_clock=$(nvidia-smi -q -d CLOCK -i ${gpu_id} | grep -A1 "Max Clocks" | grep "Graphics" | awk '{print $3}') + max_memory_clock=$(nvidia-smi -q -d CLOCK -i ${gpu_id} | grep -A3 "Max Clocks" | grep "Memory" | awk '{print $3}') -cat benchmark_mha_gpu_*.csv > mha_gpu_benchmark_results.csv + # Lock the GPU clocks to maximum frequencies + sudo nvidia-smi -i $gpu_id --lock-gpu-clocks=$max_graphics_clock,$max_graphics_clock + sudo nvidia-smi -i $gpu_id --lock-memory-clocks=$max_memory_clock,$max_memory_clock -echo "Benchmark performance on CPU with number of threads:" -MKL_DYNAMIC=FALSE OMP_NUM_THREADS=1 python benchmark_mha.py --torch -MKL_DYNAMIC=FALSE OMP_NUM_THREADS=2 python benchmark_mha.py --torch -MKL_DYNAMIC=FALSE OMP_NUM_THREADS=4 python benchmark_mha.py --torch -MKL_DYNAMIC=FALSE OMP_NUM_THREADS=8 python benchmark_mha.py --torch + nvidia-smi --query-gpu=clocks.gr,clocks.sm,clocks.mem --format=csv + echo "GPU $gpu_id clocks locked to $max_graphics_clock MHz (graphics) and $max_memory_clock MHz (memory)" -python benchmark_mha.py --intra_op_num_threads 1 -python benchmark_mha.py --intra_op_num_threads 2 -python benchmark_mha.py --intra_op_num_threads 4 -python benchmark_mha.py --intra_op_num_threads 8 + # Set Power Limit to maximum + power_limit=$(nvidia-smi --query-gpu=power.limit -i 0 --format=csv | grep "0" | awk '{print $1}') + power_limit=${power_limit%.*} + sudo nvidia-smi -pl $power_limit -i $gpu_id + export CUDA_VISIBLE_DEVICES=$gpu_id +} -echo "Benchmark performance on CPU with default threads settings:" -python benchmark_mha.py -ORT_DISABLE_FLASH_ATTENTION=1 python benchmark_mha.py -python benchmark_mha.py --torch +run_gpu_benchmarks() { + echo "Benchmark Scaled Dot Product Attention (SDPA) performance on GPU:" -python benchmark_mha.py --causal -python benchmark_mha.py --torch --causal + python benchmark_mha.py --use_gpu -# Pytorch SDPA does not support causal attention with past state, we only test ORT here. -python benchmark_mha.py --causal --has_past + echo "Benchmark BERT-Large performance on GPU without attention bias" + python benchmark_mha.py --use_gpu -b 16 -cat benchmark_mha_cpu_*.csv > mha_cpu_benchmark_results.csv + echo "Benchmark BERT-Large performance on GPU with attention bias" + python benchmark_mha.py --use_gpu -b 16 -r 1000 --has_attn_bias + python benchmark_mha.py --use_gpu -b 16 -r 1000 --has_attn_bias --broadcast_attn_bias_dim_0 + python benchmark_mha.py --use_gpu -b 16 -r 1000 --has_attn_bias --broadcast_attn_bias_dim_0 --broadcast_attn_bias_dim_1 + + python benchmark_mha.py --use_gpu --use_cuda_graph + python benchmark_mha.py --use_gpu --torch + + cat benchmark_mha_gpu_*.csv > mha_gpu_benchmark_results.csv +} + +run_lean_benchmarks() { + echo "Benchmark long context decoding performance on GPU" + for b in 1 4 16; do + for s in 32 64 128 256 512 1024 2048 4096 8192 16384 32768 65536; do + python benchmark_mha.py --use_gpu --causal -b $b -s 1 -p $s -n 16 -d 64 -r 1000 --csv_filename_prefix benchmark_lean + python benchmark_mha.py --use_gpu --causal -b $b -s 1 -p $s -n 32 -d 128 -r 1000 --csv_filename_prefix benchmark_lean + done + done + cat benchmark_lean_*.csv > lean_benchmark_results.csv +} + +run_cpu_benchmarks() { + echo "Benchmark performance on CPU with number of threads:" + MKL_DYNAMIC=FALSE OMP_NUM_THREADS=1 python benchmark_mha.py --torch + MKL_DYNAMIC=FALSE OMP_NUM_THREADS=2 python benchmark_mha.py --torch + MKL_DYNAMIC=FALSE OMP_NUM_THREADS=4 python benchmark_mha.py --torch + MKL_DYNAMIC=FALSE OMP_NUM_THREADS=8 python benchmark_mha.py --torch + + python benchmark_mha.py --intra_op_num_threads 1 + python benchmark_mha.py --intra_op_num_threads 2 + python benchmark_mha.py --intra_op_num_threads 4 + python benchmark_mha.py --intra_op_num_threads 8 + + + echo "Benchmark performance on CPU with default threads settings:" + python benchmark_mha.py + ORT_DISABLE_FLASH_ATTENTION=1 python benchmark_mha.py + python benchmark_mha.py --torch + + python benchmark_mha.py --causal + python benchmark_mha.py --torch --causal + + # Pytorch SDPA does not support causal attention with past state, we only test ORT here. + python benchmark_mha.py --causal --has_past + + cat benchmark_mha_cpu_*.csv > mha_cpu_benchmark_results.csv +} + +[ "$task" != "cpu" ] && configure_gpu 0 + +[ "$task" == "gpu" ] && run_gpu_benchmarks + +[ "$task" == "cpu" ] && run_cpu_benchmarks + +[ "$task" == "lean" ] && run_lean_benchmarks diff --git a/onnxruntime/test/python/transformers/test_mha.py b/onnxruntime/test/python/transformers/test_mha.py index 69f0035ef8a17..9e7c7378370c1 100644 --- a/onnxruntime/test/python/transformers/test_mha.py +++ b/onnxruntime/test/python/transformers/test_mha.py @@ -9,6 +9,7 @@ import concurrent.futures import itertools +import os import unittest from typing import Dict, List, Optional @@ -400,6 +401,49 @@ def kv_cache_test_cases(provider: str, comprehensive: bool): yield config +def lean_attention_test_cases(provider: str, comprehensive: bool): + if provider == "CUDAExecutionProvider" and get_compute_capability() < 80: + return + yield + + batch_sizes = [1, 2, 3] if comprehensive else [1, 2] + sequence_lengths = [2, 15, 16, 255, 256, 512, 1024, 2048, 4096, 8192] if comprehensive else [2, 255, 512] + heads = [1, 4, 16] if comprehensive else [1, 4] + head_sizes = [64, 128] + device, dtype, formats = get_provider_support_info(provider, True) + mask_formats = [AttentionMaskFormat.Mask_None] + + sequence_lengths = [*sequence_lengths, 2048] # Large sequence length is slow and need a lot of memory + for batch_size in batch_sizes: + for total_seq_len in sequence_lengths: + for num_heads in heads: + for head_size in head_sizes: + for format in formats: + for causal in get_causal_support(format): + for is_prompt in [False]: + for mask_format in mask_formats: + sequence_length = total_seq_len if is_prompt else 1 + config = MultiHeadAttentionConfig( + batch_size=batch_size, + sequence_length=sequence_length, + num_heads=num_heads, + head_size=head_size, + causal=causal, + past_sequence_length=total_seq_len - sequence_length, + kv_sequence_length=sequence_length, + max_cache_sequence_length=None, + provider=provider, + device=device, + dtype=dtype, + use_kv_cache=True, + has_past_input=True, + share_past_present_buffer=False, + input_format=format, + mask_format=mask_format, + ) + yield config + + def no_kv_cache_multi_thread_test_cases(provider: str, comprehensive: bool): if provider == "CUDAExecutionProvider" and get_compute_capability() < 60: return @@ -787,6 +831,12 @@ def run_mha_cuda(self): for config in mha_test_cases("CUDAExecutionProvider", comprehensive_mode): parity_check_mha(config, rtol=5e-3, atol=5e-3) + def run_lean_attention(self): + os.environ["ORT_ENABLE_LEAN_ATTENTION"] = "1" + for config in lean_attention_test_cases("CUDAExecutionProvider", comprehensive_mode): + parity_check_mha(config, rtol=5e-3, atol=5e-3 if config.total_sequence_length <= 512 else 5e-2) + os.environ.pop("ORT_ENABLE_LEAN_ATTENTION", None) + def run_mha_cpu(self): for config in mha_test_cases("CPUExecutionProvider", comprehensive_mode): parity_check_mha(config, rtol=5e-3, atol=5e-3) @@ -842,6 +892,7 @@ def test_all(self): # Run tests sequentially to avoid out of memory issue. self.run_mha_cpu() self.run_mha_cuda() + self.run_lean_attention() self.run_mha_cuda_multi_threading_default() self.run_mha_cuda_multi_threading_cudnn() self.run_mha_cuda_multi_threading_efficient()