From 509cb54d6f9b38ce21b897317a03e2fc61555dd0 Mon Sep 17 00:00:00 2001 From: aciddelgado <139922440+aciddelgado@users.noreply.github.com> Date: Fri, 30 Aug 2024 19:11:04 -0700 Subject: [PATCH] softcap gqa (#21683) ### Description Implement softcap for gqa. ### Motivation and Context Fixes certain models like Gemma-2 which need softcap to work so they don't output nan's. --- cmake/patches/cutlass/cutlass_3.5.0.patch | 40 ++++- docs/ContribOperators.md | 2 + .../contrib_ops/cpu/bert/attention_common.h | 1 + .../contrib_ops/cpu/bert/attention_helper.h | 11 ++ .../contrib_ops/cpu/bert/gqa_attention_base.h | 11 +- .../cpu/bert/group_query_attention.cc | 4 +- .../cpu/bert/group_query_attention_helper.h | 7 +- .../contrib_ops/cuda/bert/attention_impl.cu | 6 +- .../bert/cutlass_fmha/fmha_launch_template.h | 1 + .../cutlass_fmha/memory_efficient_attention.h | 1 + .../cuda/bert/flash_attention/flash.h | 1 + .../cuda/bert/flash_attention/flash_api.cc | 21 ++- .../cuda/bert/flash_attention/flash_api.h | 3 + .../bert/flash_attention/flash_fwd_kernel.h | 24 ++- .../flash_fwd_launch_template.h | 72 +++++---- .../cuda/bert/flash_attention/static_switch.h | 10 ++ .../cuda/bert/flash_attention/utils.h | 10 ++ .../cuda/bert/group_query_attention.cc | 2 + .../cuda/bert/group_query_attention.h | 1 + .../cuda/bert/group_query_attention_helper.h | 7 +- .../cuda/bert/group_query_attention_impl.cu | 3 +- .../bert/packed_multihead_attention_impl.cu | 1 + .../core/graph/contrib_ops/bert_defs.cc | 4 + .../transformers/test_flash_attn_cuda.py | 134 ++++++++++------ .../transformers/test_flash_attn_rocm.py | 4 +- .../test/python/transformers/test_gqa_cpu.py | 145 ++++++++++++------ 26 files changed, 366 insertions(+), 160 deletions(-) diff --git a/cmake/patches/cutlass/cutlass_3.5.0.patch b/cmake/patches/cutlass/cutlass_3.5.0.patch index 93b8c474af9ed..a02a745b612ae 100644 --- a/cmake/patches/cutlass/cutlass_3.5.0.patch +++ b/cmake/patches/cutlass/cutlass_3.5.0.patch @@ -1,8 +1,16 @@ diff --git a/examples/41_fused_multi_head_attention/kernel_forward.h b/examples/41_fused_multi_head_attention/kernel_forward.h -index 4c80f549..34327633 100644 +index 4c80f549..5ad610c8 100644 --- a/examples/41_fused_multi_head_attention/kernel_forward.h +++ b/examples/41_fused_multi_head_attention/kernel_forward.h -@@ -221,6 +221,8 @@ struct AttentionKernel { +@@ -189,6 +189,7 @@ struct AttentionKernel { + + // Scale + accum_t scale = 0.0; ++ accum_t softcap = 0.0; + + // Dimensions/strides + int32_t head_dim = 0; +@@ -221,6 +222,8 @@ struct AttentionKernel { int32_t num_batches = 0; int32_t num_heads = 0; @@ -11,7 +19,23 @@ index 4c80f549..34327633 100644 // dropout bool use_dropout = false; unsigned long long dropout_batch_head_rng_offset = 0; -@@ -897,7 +899,8 @@ struct AttentionKernel { +@@ -818,6 +821,15 @@ struct AttentionKernel { + accum = + cutlass::multiplies()(p.scale, accum); + } ++ ++ // apply softcap if applicable ++ if (p.softcap > 0.0) { ++ accum = cutlass::multiplies()(1.0 / p.softcap, accum); ++ for (int i = 0; i < accum.size(); ++i) { ++ accum[i] = cutlass::fast_tanh(accum[i]); ++ } ++ accum = cutlass::multiplies()(p.softcap, accum); ++ } + + // apply attention bias if applicable + if (kSupportsBias && p.attn_bias_ptr != nullptr) { +@@ -897,7 +909,8 @@ struct AttentionKernel { p.num_keys - iter_key_start, iter_key_start == 0, iteratorC_tile_offset, @@ -21,7 +45,7 @@ index 4c80f549..34327633 100644 // Output results to shared-memory int warp_idx_mn_0 = my_warp_id % -@@ -1166,7 +1169,8 @@ struct AttentionKernel { +@@ 
-1166,7 +1179,8 @@ struct AttentionKernel { int max_col, bool is_first, typename WarpIteratorC::TensorCoord const& tile_offset, @@ -31,7 +55,7 @@ index 4c80f549..34327633 100644 /* Iterates on the accumulator and corresponding position on result matrix (1) Update `mi[r]` to the max value of the row `r` -@@ -1257,7 +1261,7 @@ struct AttentionKernel { +@@ -1257,7 +1271,7 @@ struct AttentionKernel { accum_t mi_row, total_row; LambdaIterator::iterateRows( lane_offset, @@ -40,7 +64,7 @@ index 4c80f549..34327633 100644 [&](int accum_m, int accum_n, int idx) { frag[idx] = (accum_n < max_col) ? exp2f(frag[idx] - mi_row) : accum_t(0.0); -@@ -1294,7 +1298,7 @@ struct AttentionKernel { +@@ -1294,7 +1308,7 @@ struct AttentionKernel { for (int i = 0; i < MM0::MmaCore::WarpCount::kN; ++i) { total_row += addition_storage[id + kQueriesPerBlock * i]; } @@ -50,7 +74,7 @@ index 4c80f549..34327633 100644 } diff --git a/include/cutlass/functional.h b/include/cutlass/functional.h -index 964d2ff3..b366bc14 100644 +index 964d2ff3..676ba768 100644 --- a/include/cutlass/functional.h +++ b/include/cutlass/functional.h @@ -39,6 +39,7 @@ @@ -73,4 +97,4 @@ index 964d2ff3..b366bc14 100644 +#endif #else return half_t(1.f / std::sqrt(half_t::convert(lhs))); - #endif \ No newline at end of file + #endif diff --git a/docs/ContribOperators.md b/docs/ContribOperators.md index 8a13505fe0fc7..aadf4ebe2f488 100644 --- a/docs/ContribOperators.md +++ b/docs/ContribOperators.md @@ -2543,6 +2543,8 @@ This version of the operator has been available since version 1 of the 'com.micr
 <dd>Custom scale will be used if specified. Default value is 1/sqrt(head_size)</dd>
 <dt><tt>smooth_softmax</tt> : int</dt>
 <dd>Use a smooth factor in softmax.</dd>
+<dt><tt>softcap</tt> : float</dt>
+<dd>Softcap value for attention weights. Default value is 0.</dd>
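For context, the new `softcap` attribute applies the same tanh cap to the attention logits before softmax that the CPU helper `ComputeAttentionSoftcapInplace` and the Python reference `attention_ref` in the tests use. A minimal NumPy sketch of that transform (the helper name below is illustrative, not part of this patch):

```python
import numpy as np

def softcap_scores(scores: np.ndarray, softcap: float) -> np.ndarray:
    """Cap attention logits into (-softcap, softcap) with a smooth tanh before softmax."""
    if softcap <= 0.0:  # softcap == 0.0 means the attribute is unset, so it is a no-op
        return scores
    return softcap * np.tanh(scores / softcap)

# Large logits like those produced by Gemma-2 would otherwise push softmax toward inf/NaN.
logits = np.array([-300.0, 0.0, 75.0, 300.0], dtype=np.float32)
print(softcap_scores(logits, 50.0))  # every value now lies strictly inside (-50, 50)
```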
#### Inputs (7 - 9) diff --git a/onnxruntime/contrib_ops/cpu/bert/attention_common.h b/onnxruntime/contrib_ops/cpu/bert/attention_common.h index 516ef57d8cd18..45acb90ba68b0 100644 --- a/onnxruntime/contrib_ops/cpu/bert/attention_common.h +++ b/onnxruntime/contrib_ops/cpu/bert/attention_common.h @@ -119,6 +119,7 @@ struct GroupQueryAttentionParameters { bool rotary_interleaved; bool use_smooth_softmax; float scale; + float softcap; AttentionQkvFormat qkv_format; AttentionQkvFormat past_kv_format; int zeros_count; diff --git a/onnxruntime/contrib_ops/cpu/bert/attention_helper.h b/onnxruntime/contrib_ops/cpu/bert/attention_helper.h index 04e120863d39e..e6c948acb0d6c 100644 --- a/onnxruntime/contrib_ops/cpu/bert/attention_helper.h +++ b/onnxruntime/contrib_ops/cpu/bert/attention_helper.h @@ -102,6 +102,17 @@ inline void ComputeAttentionSoftmaxInplace(float* score, int N, int D, ThreadPoo MlasComputeSoftmax(score, score, N, D, false, false, tp); } +template +void ComputeAttentionSoftcapInplace(T* scores, int sequence_length, float softcap) { + for (int i = 0; i < sequence_length; i++) { + scores[i] = scores[i] / softcap; + scores[i] = std::tanh(scores[i]); + scores[i] = scores[i] * softcap; + } +} + +template void ComputeAttentionSoftcapInplace(float* scores, int sequence_length, float softcap); + template void PrepareMask(const int32_t* mask_index, gsl::span mask_index_dims, diff --git a/onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h b/onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h index 70f8564a2cbf2..2bf0aa0915c2d 100644 --- a/onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h +++ b/onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h @@ -26,6 +26,7 @@ class GQAAttentionBase { kv_num_heads_ = static_cast(kv_num_heads); scale_ = info.GetAttrOrDefault("scale", 0.0f); + softcap_ = info.GetAttrOrDefault("softcap", 0.0f); do_rotary_ = info.GetAttrOrDefault("do_rotary", 0) == 1; rotary_interleaved_ = info.GetAttrOrDefault("rotary_interleaved", 0) == 1; @@ -38,7 +39,8 @@ class GQAAttentionBase { int num_heads_; // number of attention heads of Q int kv_num_heads_; // number of attention heads of K or V float scale_; // the scaling factor applied before softmax - bool do_rotary_; // whether or not to use rotary embeddings + float softcap_; + bool do_rotary_; // whether or not to use rotary embeddings bool rotary_interleaved_; int local_window_size_; @@ -199,6 +201,10 @@ class GQAAttentionBase { for (int total_seq_id = 0; total_seq_id < seq_causal_length - local_window_size_ - 1; total_seq_id++) { output_softmax[total_seq_id] = 0.f; } + if (softcap_ > 0.f) { + ComputeAttentionSoftcapInplace(output_softmax + seq_causal_length - local_window_size_ - 1, + local_window_size_ + 1, softcap_); + } if (use_smooth_softmax_) { ComputeSmoothSoftmaxInplace(output_softmax + seq_causal_length - local_window_size_ - 1, 1, local_window_size_ + 1, nullptr); @@ -207,6 +213,9 @@ class GQAAttentionBase { local_window_size_ + 1, nullptr); } } else { + if (softcap_ > 0.f) { + ComputeAttentionSoftcapInplace(output_softmax, seq_causal_length, softcap_); + } if (use_smooth_softmax_) { ComputeSmoothSoftmaxInplace(output_softmax, 1, seq_causal_length, nullptr); } else { diff --git a/onnxruntime/contrib_ops/cpu/bert/group_query_attention.cc b/onnxruntime/contrib_ops/cpu/bert/group_query_attention.cc index 97388a9d6bce8..87675255f5ba4 100644 --- a/onnxruntime/contrib_ops/cpu/bert/group_query_attention.cc +++ b/onnxruntime/contrib_ops/cpu/bert/group_query_attention.cc @@ -50,7 +50,6 @@ Status 
GroupQueryAttention::Compute(OpKernelContext* context) const { const Tensor* sin_cache = context->Input(8); GroupQueryAttentionParameters parameters = {}; - constexpr float scale = 1.0f; ORT_RETURN_IF_ERROR(group_query_attention_helper::CheckInputs(query, key, value, @@ -63,7 +62,8 @@ Status GroupQueryAttention::Compute(OpKernelContext* context) const { kv_num_heads_, seqlens_k, total_seqlen, - scale)); + scale_, + softcap_)); const int batch_size = parameters.batch_size; const int sequence_length = parameters.sequence_length; diff --git a/onnxruntime/contrib_ops/cpu/bert/group_query_attention_helper.h b/onnxruntime/contrib_ops/cpu/bert/group_query_attention_helper.h index 7ffb72fe55d25..3342052260ff9 100644 --- a/onnxruntime/contrib_ops/cpu/bert/group_query_attention_helper.h +++ b/onnxruntime/contrib_ops/cpu/bert/group_query_attention_helper.h @@ -23,7 +23,8 @@ Status CheckInputs(const Tensor* query, int kv_num_heads, const Tensor* seqlens_k, const Tensor* total_seqlen, - float scale) { + float scale, + float softcap) { // Note: Here S* is seqlen_past_kv_cache, S+ is seqlen_present_kv_cache // past_key : (B, N_k, S*, H) or (B, N_k, S+, H) or nullptr // past_value : (B, N_k, S*, H) or (B, N_k, S+, H) or nullptr @@ -236,6 +237,7 @@ Status CheckInputs(const Tensor* query, output_parameters->is_unidirectional = true; output_parameters->is_prompt = is_prompt; output_parameters->scale = scale; + output_parameters->softcap = softcap; output_parameters->qkv_format = qkv_format; output_parameters->past_kv_format = past_kv_format; } @@ -256,12 +258,13 @@ Status CheckInputs(const Tensor* query, const Tensor* seqlens_k, const Tensor* total_seqlen, float scale, + float softcap, int max_threads_per_block) { if (max_threads_per_block > 0 && num_heads > max_threads_per_block) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "num_heads should be no larger than ", max_threads_per_block); } - return CheckInputs(query, key, value, past_key, past_value, cos_cache, sin_cache, parameters, num_heads, kv_num_heads, seqlens_k, total_seqlen, scale); + return CheckInputs(query, key, value, past_key, past_value, cos_cache, sin_cache, parameters, num_heads, kv_num_heads, seqlens_k, total_seqlen, scale, softcap); } } // namespace group_query_attention_helper } // namespace contrib diff --git a/onnxruntime/contrib_ops/cuda/bert/attention_impl.cu b/onnxruntime/contrib_ops/cuda/bert/attention_impl.cu index 3af3751ba0e51..eff58c0080012 100644 --- a/onnxruntime/contrib_ops/cuda/bert/attention_impl.cu +++ b/onnxruntime/contrib_ops/cuda/bert/attention_impl.cu @@ -303,9 +303,9 @@ Status FlashAttention( ORT_RETURN_IF_ERROR(onnxruntime::flash::mha_fwd( device_prop, stream, data.q, data.k, data.v, data.output, reinterpret_cast(data.scratch), parameters.batch_size, parameters.num_heads, parameters.num_heads, parameters.head_size, - parameters.sequence_length, parameters.total_sequence_length, scale, parameters.is_unidirectional, is_bf16, false, - parameters.num_splits, reinterpret_cast(data.softmax_lse_accum), reinterpret_cast(data.out_accum), - data.qkv_format == AttentionQkvFormat::Q_K_V_BSNH)); + parameters.sequence_length, parameters.total_sequence_length, scale, 0.0, parameters.is_unidirectional, is_bf16, + false, parameters.num_splits, reinterpret_cast(data.softmax_lse_accum), + reinterpret_cast(data.out_accum), data.qkv_format == AttentionQkvFormat::Q_K_V_BSNH)); return Status::OK(); } diff --git a/onnxruntime/contrib_ops/cuda/bert/cutlass_fmha/fmha_launch_template.h 
b/onnxruntime/contrib_ops/cuda/bert/cutlass_fmha/fmha_launch_template.h index 5ffa63c54c8fb..a10d2548fa7b8 100644 --- a/onnxruntime/contrib_ops/cuda/bert/cutlass_fmha/fmha_launch_template.h +++ b/onnxruntime/contrib_ops/cuda/bert/cutlass_fmha/fmha_launch_template.h @@ -169,6 +169,7 @@ void LaunchCutlassFmha(const MemoryEfficientAttentionParams& params) { p.head_dim_value = params.v_head_size; p.scale = params.scale; + p.softcap = params.softcap; // When params.cu_seqlens_q is provided, num_queries is max_seq_q and num_keys will be set inside the kernel p.num_queries = params.sequence_length; diff --git a/onnxruntime/contrib_ops/cuda/bert/cutlass_fmha/memory_efficient_attention.h b/onnxruntime/contrib_ops/cuda/bert/cutlass_fmha/memory_efficient_attention.h index ec2c92c437283..9fe66c6fe992e 100644 --- a/onnxruntime/contrib_ops/cuda/bert/cutlass_fmha/memory_efficient_attention.h +++ b/onnxruntime/contrib_ops/cuda/bert/cutlass_fmha/memory_efficient_attention.h @@ -28,6 +28,7 @@ struct MemoryEfficientAttentionParams { bool use_smooth_softmax; float scale; + float softcap = 0.0; int32_t* seqstart_q_ptr; int32_t* seqstart_k_ptr; diff --git a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash.h b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash.h index bcd87c1ab6251..4aa633ca45e2b 100644 --- a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash.h +++ b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash.h @@ -74,6 +74,7 @@ struct Flash_fwd_params : public Qkv_params { // The scaling factors for the kernel. float scale_softmax = 0.0; float scale_softmax_log2 = 0.0; + float softcap = 0.0; // array of length b+1 holding starting offset of each sequence. int* __restrict__ cu_seqlens_q = nullptr; diff --git a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_api.cc b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_api.cc index f875d31f5ca7a..6a3e52bee3995 100644 --- a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_api.cc +++ b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_api.cc @@ -35,6 +35,7 @@ void set_params_fprop(Flash_fwd_params& params, void* p_d, void* softmax_lse_d, float softmax_scale, + float softcap, bool is_causal, bool is_bf16, bool use_smooth_softmax, @@ -111,8 +112,16 @@ void set_params_fprop(Flash_fwd_params& params, params.d_rounded = head_size_rounded; // Set the different scale values. - params.scale_softmax = softmax_scale; - params.scale_softmax_log2 = softmax_scale * M_LOG2E; + if (softcap > 0.0) { + params.softcap = softmax_scale / softcap; + params.scale_softmax = softcap; + params.scale_softmax_log2 = softcap * M_LOG2E; + } else { + // Remove potential NaN + params.softcap = 0.0; + params.scale_softmax = softmax_scale; + params.scale_softmax_log2 = softmax_scale * M_LOG2E; + } // In our API, causal/unidirectional determines if we only look at prior tokens. 
However, the flash API separates // local and causal, meaning when we have local window size @@ -267,6 +276,7 @@ Status mha_fwd(const cudaDeviceProp& dprops, int seqlen_q, int seqlen_k, float softmax_scale, + const float softcap, bool is_causal, bool is_bf16, bool use_smooth_softmax, @@ -294,6 +304,7 @@ Status mha_fwd(const cudaDeviceProp& dprops, /*p_ptr=*/nullptr, softmax_lse, softmax_scale, + softcap, is_causal, is_bf16, use_smooth_softmax, @@ -343,6 +354,7 @@ Status mha_varlen_fwd(const cudaDeviceProp& dprops, int max_seqlen_q, int max_seqlen_k, float softmax_scale, + const float softcap, bool is_causal, bool is_bf16, int max_num_blocks_per_seq, @@ -367,6 +379,7 @@ Status mha_varlen_fwd(const cudaDeviceProp& dprops, /*p_ptr=*/nullptr, softmax_lse, softmax_scale, + softcap, is_causal, is_bf16, false, @@ -427,6 +440,7 @@ Status mha_fwd_kvcache(const cudaDeviceProp& dprops, int seqlen_k_new, int rotary_dim, const float softmax_scale, + const float softcap, bool is_causal, bool is_bf16, bool use_smooth_softmax, @@ -440,7 +454,7 @@ Status mha_fwd_kvcache(const cudaDeviceProp& dprops, int max_num_blocks_per_seq, int page_block_size) { auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; }; - const int head_size_rounded = round_multiple(head_size, 32); + const int head_size_rounded = head_size <= 192 ? round_multiple(head_size, 32) : 256; const int seqlen_q_rounded = round_multiple(seqlen_q, 128); const int seqlen_k_rounded = round_multiple(seqlen_k, 128); const bool paged_KV = block_table != nullptr; @@ -460,6 +474,7 @@ Status mha_fwd_kvcache(const cudaDeviceProp& dprops, /*p_ptr=*/nullptr, softmax_lse, softmax_scale, + softcap, is_causal, is_bf16, use_smooth_softmax, diff --git a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_api.h b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_api.h index baad0a938d377..57752e8237d6e 100644 --- a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_api.h +++ b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_api.h @@ -50,6 +50,7 @@ Status mha_fwd(const cudaDeviceProp& dprops, int seqlen_q, int seqlen_k, float softmax_scale, + const float softcap, bool is_causal, bool is_bf16, bool use_smooth_softmax, @@ -77,6 +78,7 @@ Status mha_varlen_fwd(const cudaDeviceProp& dprops, int max_seqlen_q, int max_seqlen_k, float softmax_scale, + const float softcap, bool is_causal, bool is_bf16, int max_num_blocks_per_seq = 0, @@ -104,6 +106,7 @@ Status mha_fwd_kvcache(const cudaDeviceProp& dprops, int seqlen_k_new, int rotary_dim, const float softmax_scale, + const float softcap, bool is_causal, bool is_bf16, bool use_smooth_softmax, diff --git a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_kernel.h b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_kernel.h index b2aa3668a5be1..e961bab399326 100644 --- a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_kernel.h +++ b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_kernel.h @@ -34,7 +34,7 @@ using namespace cute; //////////////////////////////////////////////////////////////////////////////////////////////////// -template +template inline __device__ void compute_attn_1rowblock(const Params& params, const int bidb, const int bidh, const int m_block) { using Element = typename Kernel_traits::Element; using ElementAccum = typename Kernel_traits::ElementAccum; @@ -278,6 +278,9 @@ inline __device__ void compute_attn_1rowblock(const Params& params, const int bi acc_s, tSrQ, tSrK, tSsQ, tSsK, tiled_mma, smem_tiled_copy_Q, 
smem_tiled_copy_K, smem_thr_copy_Q, smem_thr_copy_K); // if (cute::thread0()) { print(acc_s); } + if constexpr (Is_softcap) { + flash::apply_softcap(acc_s, params.softcap); + } mask.template apply_mask( acc_s, n_block * kBlockN, m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4, kNWarps * 16); @@ -322,6 +325,9 @@ inline __device__ void compute_attn_1rowblock(const Params& params, const int bi flash::gemm( acc_s, tSrQ, tSrK, tSsQ, tSsK, tiled_mma, smem_tiled_copy_Q, smem_tiled_copy_K, smem_thr_copy_Q, smem_thr_copy_K); + if constexpr (Is_softcap) { + flash::apply_softcap(acc_s, params.softcap); + } flash::cp_async_wait<0>(); __syncthreads(); @@ -418,7 +424,7 @@ inline __device__ void compute_attn_1rowblock(const Params& params, const int bi //////////////////////////////////////////////////////////////////////////////////////////////////// -template +template inline __device__ void compute_attn_1rowblock_splitkv(const Params& params, const int bidb, const int bidh, const int m_block, const int n_split_idx, const int num_n_splits) { @@ -799,6 +805,9 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params& params, cons acc_s, tSrQ, tSrK, tSsQ, tSsK, tiled_mma, smem_tiled_copy_Q, smem_tiled_copy_K, smem_thr_copy_Q, smem_thr_copy_K); // if (cute::thread0()) { print(acc_s); } + if constexpr (Is_softcap) { + flash::apply_softcap(acc_s, params.softcap); + } mask.template apply_mask( acc_s, n_block * kBlockN, m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4, kNWarps * 16); @@ -868,6 +877,9 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params& params, cons flash::gemm( acc_s, tSrQ, tSrK, tSsQ, tSsK, tiled_mma, smem_tiled_copy_Q, smem_tiled_copy_K, smem_thr_copy_Q, smem_thr_copy_K); + if constexpr (Is_softcap) { + flash::apply_softcap(acc_s, params.softcap); + } flash::cp_async_wait<0>(); __syncthreads(); @@ -979,7 +991,7 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params& params, cons //////////////////////////////////////////////////////////////////////////////////////////////////// -template +template inline __device__ void compute_attn(const Params& params) { const int m_block = blockIdx.x; // The block index for the batch. @@ -995,12 +1007,12 @@ inline __device__ void compute_attn(const Params& params) { // the attention matrix. This way, as long as we have the batch, head, and the location of // the 16 x 32 block within the attention matrix, we can generate the exact same dropout pattern. - flash::compute_attn_1rowblock(params, bidb, bidh, m_block); + flash::compute_attn_1rowblock(params, bidb, bidh, m_block); } //////////////////////////////////////////////////////////////////////////////////////////////////// -template +template inline __device__ void compute_attn_splitkv(const Params& params) { const int m_block = blockIdx.x; // The block index for the batch. @@ -1009,7 +1021,7 @@ inline __device__ void compute_attn_splitkv(const Params& params) { const int bidh = Split ? blockIdx.z - bidb * params.h : blockIdx.z; const int n_split_idx = Split ? blockIdx.y : 0; const int num_n_splits = Split ? 
gridDim.y : 1; - flash::compute_attn_1rowblock_splitkv(params, bidb, bidh, m_block, n_split_idx, num_n_splits); + flash::compute_attn_1rowblock_splitkv(params, bidb, bidh, m_block, n_split_idx, num_n_splits); } //////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_launch_template.h b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_launch_template.h index b1941df75be2c..d8465d54e6be8 100644 --- a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_launch_template.h +++ b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_launch_template.h @@ -26,18 +26,18 @@ namespace flash { template \ __global__ void kernelName(KERNEL_PARAM_MODIFIER const Flash_fwd_params params) -DEFINE_FLASH_FORWARD_KERNEL(flash_fwd_kernel, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K, bool Return_softmax) { +DEFINE_FLASH_FORWARD_KERNEL(flash_fwd_kernel, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K, bool Is_softcap, bool Return_softmax) { #if defined(ARCH_SUPPORTS_FLASH) static_assert(!(Is_causal && Is_local)); // Enforce constraints - flash::compute_attn(params); + flash::compute_attn(params); #else FLASH_UNSUPPORTED_ARCH #endif } -DEFINE_FLASH_FORWARD_KERNEL(flash_fwd_splitkv_kernel, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K, bool Split, bool Append_KV) { +DEFINE_FLASH_FORWARD_KERNEL(flash_fwd_splitkv_kernel, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K, bool Is_softcap, bool Split, bool Append_KV) { #if defined(ARCH_SUPPORTS_FLASH) - flash::compute_attn_splitkv(params); + flash::compute_attn_splitkv(params); #else FLASH_UNSUPPORTED_ARCH #endif @@ -64,23 +64,25 @@ void run_flash_fwd(Flash_fwd_params& params, cudaStream_t stream) { EVENK_SWITCH(is_even_K, IsEvenKConst, [&] { LOCAL_SWITCH((params.window_size_left >= 0 || params.window_size_right >= 0) && !Is_causal, Is_local, [&] { ALIBI_SWITCH(params.alibi_slopes_ptr != nullptr, Has_alibi, [&] { - // Will only return softmax if dropout, to reduce compilation time. - // If not IsEvenKConst, we also set IsEvenMNConst to false to reduce number of templates. - // If head dim > 128, set IsEvenMNConst to false to reduce number of templates - // If Is_local, set Is_causal to false - auto kernel = &flash_fwd_kernel < Kernel_traits, Is_causal, Is_local && !Is_causal, Has_alibi, IsEvenMNConst && IsEvenKConst && !Is_local && Kernel_traits::kHeadDim <= 128, IsEvenKConst, false > ; - // auto kernel = &flash_fwd_kernel; - if (smem_size >= 48 * 1024) { - cudaFuncSetAttribute( - kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, static_cast(smem_size)); - // ORT_ENFORCE(cudaFuncSetAttribute( - // kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); - } - // int ctas_per_sm; - // cudaError status_ = cudaOccupancyMaxActiveBlocksPerMultiprocessor( - // &ctas_per_sm, kernel, Kernel_traits::kNThreads, smem_size); - // printf("smem_size = %d, CTAs per SM = %d\n", int(smem_size), ctas_per_sm); - kernel<<(smem_size), stream>>>(params); + SOFTCAP_SWITCH(params.softcap > 0.0, Is_softcap, [&] { + // Will only return softmax if dropout, to reduce compilation time. + // If not IsEvenKConst, we also set IsEvenMNConst to false to reduce number of templates. 
+ // If head dim > 128, set IsEvenMNConst to false to reduce number of templates + // If Is_local, set Is_causal to false + auto kernel = &flash_fwd_kernel < Kernel_traits, Is_causal, Is_local && !Is_causal, Has_alibi, IsEvenMNConst && IsEvenKConst && !Is_local && Kernel_traits::kHeadDim <= 128, IsEvenKConst, Is_softcap, false > ; + // auto kernel = &flash_fwd_kernel; + if (smem_size >= 48 * 1024) { + cudaFuncSetAttribute( + kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, static_cast(smem_size)); + // ORT_ENFORCE(cudaFuncSetAttribute( + // kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); + } + // int ctas_per_sm; + // cudaError status_ = cudaOccupancyMaxActiveBlocksPerMultiprocessor( + // &ctas_per_sm, kernel, Kernel_traits::kNThreads, smem_size); + // printf("smem_size = %d, CTAs per SM = %d\n", int(smem_size), ctas_per_sm); + kernel<<(smem_size), stream>>>(params); + }); }); }); }); @@ -103,19 +105,21 @@ void run_flash_splitkv_fwd(Flash_fwd_params& params, cudaStream_t stream) { BOOL_SWITCH(params.num_splits > 1, SplitConst, [&] { BOOL_SWITCH(params.knew_ptr != nullptr, Append_KV_Const, [&] { ALIBI_SWITCH(params.alibi_slopes_ptr != nullptr, Has_alibi, [&] { - // If Append_KV_Const, then we must have seqlen_offsets, which means cu_seqlens_k != nullptr. - // If not IsEvenKConst, we also set IsEvenMNConst to false to reduce number of templates. - // If Is_Local_Const, set Is_causal to false - auto kernel = &flash_fwd_splitkv_kernel < Kernel_traits, Is_causal, Is_Local_Const && !Is_causal, Has_alibi, - IsEvenMNConst && !Append_KV_Const && IsEvenKConst && !Is_Local_Const && Kernel_traits::kHeadDim <= 128, - IsEvenKConst, SplitConst, Append_KV_Const >; - // auto kernel = &flash_fwd_splitkv_kernel; - // auto kernel = &flash_fwd_splitkv_kernel; - if (smem_size >= 48 * 1024) { - cudaFuncSetAttribute( - kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, static_cast(smem_size)); - } - kernel<<(smem_size), stream>>>(params); + SOFTCAP_SWITCH(params.softcap > 0.0, Is_softcap, [&] { + // If Append_KV_Const, then we must have seqlen_offsets, which means cu_seqlens_k != nullptr. + // If not IsEvenKConst, we also set IsEvenMNConst to false to reduce number of templates. + // If Is_Local_Const, set Is_causal to false + auto kernel = &flash_fwd_splitkv_kernel < Kernel_traits, Is_causal, Is_Local_Const && !Is_causal, Has_alibi, + IsEvenMNConst && !Append_KV_Const && IsEvenKConst && !Is_Local_Const && Kernel_traits::kHeadDim <= 128, + IsEvenKConst, Is_softcap, SplitConst, Append_KV_Const >; + // auto kernel = &flash_fwd_splitkv_kernel; + // auto kernel = &flash_fwd_splitkv_kernel; + if (smem_size >= 48 * 1024) { + cudaFuncSetAttribute( + kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, static_cast(smem_size)); + } + kernel<<(smem_size), stream>>>(params); + }); }); }); }); diff --git a/onnxruntime/contrib_ops/cuda/bert/flash_attention/static_switch.h b/onnxruntime/contrib_ops/cuda/bert/flash_attention/static_switch.h index 02bd7effd7da6..d8124eb032b32 100644 --- a/onnxruntime/contrib_ops/cuda/bert/flash_attention/static_switch.h +++ b/onnxruntime/contrib_ops/cuda/bert/flash_attention/static_switch.h @@ -44,6 +44,16 @@ #define EVENK_SWITCH BOOL_SWITCH #endif +#ifdef FLASHATTENTION_DISABLE_SOFTCAP +#define SOFTCAP_SWITCH(COND, CONST_NAME, ...) \ + [&] { \ + constexpr static bool CONST_NAME = false; \ + return __VA_ARGS__(); \ + }() +#else +#define SOFTCAP_SWITCH BOOL_SWITCH +#endif + #ifdef FLASHATTENTION_DISABLE_LOCAL #define LOCAL_SWITCH(COND, CONST_NAME, ...) 
\ [&] { \ diff --git a/onnxruntime/contrib_ops/cuda/bert/flash_attention/utils.h b/onnxruntime/contrib_ops/cuda/bert/flash_attention/utils.h index 9ef75120881e4..76b1aaefebeff 100644 --- a/onnxruntime/contrib_ops/cuda/bert/flash_attention/utils.h +++ b/onnxruntime/contrib_ops/cuda/bert/flash_attention/utils.h @@ -369,5 +369,15 @@ inline __device__ void copy_w_min_idx(Tensor const& S, //////////////////////////////////////////////////////////////////////////////////////////////////// +template +__forceinline__ __device__ void apply_softcap(Tensor& tensor, const float softcap) { +#pragma unroll + for (int i = 0; i < size(tensor); ++i) { + tensor(i) = cutlass::fast_tanh(tensor(i) * softcap); + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + } // namespace flash } // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc b/onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc index 58d1d7f0e4af4..d0ae812bb4fa2 100644 --- a/onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc +++ b/onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc @@ -51,6 +51,7 @@ GroupQueryAttention::GroupQueryAttention(const OpKernelInfo& info) do_rotary_ = info.GetAttrOrDefault("do_rotary", 0) == 1; rotary_interleaved_ = info.GetAttrOrDefault("rotary_interleaved", 0) == 1; scale_ = info.GetAttrOrDefault("scale", 0.0f); + softcap_ = info.GetAttrOrDefault("softcap", 0.0f); use_smooth_softmax_ = info.GetAttrOrDefault("smooth_softmax", 0) == 1; kernel_options_ = this->GetAttentionKernelOptions(); @@ -96,6 +97,7 @@ Status GroupQueryAttention::ComputeInternal(OpKernelContext* context) const { total_seqlen, is_past_bsnh_, scale_, + softcap_, device_prop.maxThreadsPerBlock)); parameters.local_window_size = local_window_size_; parameters.is_unidirectional = is_unidirectional_; diff --git a/onnxruntime/contrib_ops/cuda/bert/group_query_attention.h b/onnxruntime/contrib_ops/cuda/bert/group_query_attention.h index 872fe9fe05ad2..08457feb099b3 100644 --- a/onnxruntime/contrib_ops/cuda/bert/group_query_attention.h +++ b/onnxruntime/contrib_ops/cuda/bert/group_query_attention.h @@ -30,6 +30,7 @@ class GroupQueryAttention final : public CudaKernel { bool rotary_interleaved_; bool use_smooth_softmax_; float scale_; + float softcap_; bool disable_flash_attention_; bool disable_memory_efficient_attention_; static constexpr int kZerosCount = 256; // In prompt case we create a zero buffer of size 256 for seqlen (assume batch_size <= 256) diff --git a/onnxruntime/contrib_ops/cuda/bert/group_query_attention_helper.h b/onnxruntime/contrib_ops/cuda/bert/group_query_attention_helper.h index 39efdfd66bcc6..e65827e4ccdd5 100644 --- a/onnxruntime/contrib_ops/cuda/bert/group_query_attention_helper.h +++ b/onnxruntime/contrib_ops/cuda/bert/group_query_attention_helper.h @@ -24,7 +24,8 @@ Status CheckInputs(const Tensor* query, const Tensor* seqlens_k, const Tensor* total_seqlen, bool is_past_bsnh, - float scale) { + float scale, + float softcap) { // Note: Here S* is past_cache_sequence_length, S- is past_sequence_length, S+ is sequence_length // past_key : (B, N_k, S*, H) or (B, N_k, S-, H) or nullptr // past_value : (B, N_k, S*, H) or (B, N_k, S-, H) or nullptr @@ -261,6 +262,7 @@ Status CheckInputs(const Tensor* query, output_parameters->is_packed_qkv = is_packed_qkv; output_parameters->is_prompt = is_prompt; output_parameters->scale = scale; + output_parameters->softcap = softcap; output_parameters->qkv_format = qkv_format; 
output_parameters->past_kv_format = past_kv_format; } @@ -282,12 +284,13 @@ Status CheckInputs(const Tensor* query, const Tensor* total_seqlen, bool is_past_bsnh, float scale, + float softcap, int max_threads_per_block) { if (max_threads_per_block > 0 && num_heads > max_threads_per_block) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "num_heads should be no larger than ", max_threads_per_block); } - return CheckInputs(query, key, value, past_key, past_value, cos_cache, sin_cache, parameters, num_heads, kv_num_heads, seqlens_k, total_seqlen, is_past_bsnh, scale); + return CheckInputs(query, key, value, past_key, past_value, cos_cache, sin_cache, parameters, num_heads, kv_num_heads, seqlens_k, total_seqlen, is_past_bsnh, scale, softcap); } } // namespace group_query_attention_helper diff --git a/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.cu b/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.cu index 04aa1c14a0f69..be94f26ec298f 100644 --- a/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.cu +++ b/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.cu @@ -678,7 +678,7 @@ Status FlashAttention( reinterpret_cast(data.softmax_lse), seqlens_k, cos_cache, sin_cache, /*block_table*/ nullptr, batch_size, num_heads, kv_num_heads, head_size, sequence_length, parameters.seqlen_present_kv_cache, kv_sequence_length, parameters.rotary_dim, - scale, is_causal, is_bf16, parameters.use_smooth_softmax, past_bsnh, parameters.num_splits, + scale, parameters.softcap, is_causal, is_bf16, parameters.use_smooth_softmax, past_bsnh, parameters.num_splits, reinterpret_cast(data.softmax_lse_accum), reinterpret_cast(data.out_accum), parameters.local_window_size, parameters.rotary_interleaved, parameters.is_packed_qkv)); @@ -829,6 +829,7 @@ Status EfficientAttention( p.v_head_size = head_size; p.causal = true; p.scale = scale; + p.softcap = parameters.softcap; p.seqlen_k_ptr = data.seqlens_k_total; // Note: seqlens_k is total sequence length for efficient p.seqstart_q_ptr = nullptr; p.seqstart_k_ptr = nullptr; diff --git a/onnxruntime/contrib_ops/cuda/bert/packed_multihead_attention_impl.cu b/onnxruntime/contrib_ops/cuda/bert/packed_multihead_attention_impl.cu index 9bb93b6d06167..f3b9fd310f46f 100644 --- a/onnxruntime/contrib_ops/cuda/bert/packed_multihead_attention_impl.cu +++ b/onnxruntime/contrib_ops/cuda/bert/packed_multihead_attention_impl.cu @@ -639,6 +639,7 @@ Status FlashAttention( sequence_length, sequence_length, scale, + 0.0, false, // is causal false // is bf16 )); diff --git a/onnxruntime/core/graph/contrib_ops/bert_defs.cc b/onnxruntime/core/graph/contrib_ops/bert_defs.cc index dd3a06e3eb4ba..5185205f1dde1 100644 --- a/onnxruntime/core/graph/contrib_ops/bert_defs.cc +++ b/onnxruntime/core/graph/contrib_ops/bert_defs.cc @@ -1061,6 +1061,10 @@ ONNX_MS_OPERATOR_SET_SCHEMA( "Custom scale will be used if specified. Default value is 1/sqrt(head_size)", AttributeProto::FLOAT, OPTIONAL_VALUE) + .Attr("softcap", + "Softcap value for attention weights. Default value is 0.", + AttributeProto::FLOAT, + OPTIONAL_VALUE) .Attr("local_window_size", "left_window_size for local attention (like Mistral). 
Default value is -1 meaning unused.", AttributeProto::INT, diff --git a/onnxruntime/test/python/transformers/test_flash_attn_cuda.py b/onnxruntime/test/python/transformers/test_flash_attn_cuda.py index 13bf51f74389a..c04929a3b603e 100644 --- a/onnxruntime/test/python/transformers/test_flash_attn_cuda.py +++ b/onnxruntime/test/python/transformers/test_flash_attn_cuda.py @@ -223,6 +223,7 @@ def create_group_query_attention_graph_prompt( rotary=False, rotary_interleaved=False, packed=False, + softcap=0.0, use_smooth_softmax=False, ): past_kv_seqlen = config.buffer_sequence_length if share_buffer else 0 @@ -248,6 +249,7 @@ def create_group_query_attention_graph_prompt( local_window_size=local_window_size, do_rotary=rotary, rotary_interleaved=rotary_interleaved, + softcap=softcap, smooth_softmax=1 if use_smooth_softmax else 0, # is_past_bsnh=1 if past_kv_format == Formats.BSNH else 0, # kv_share_buffer=1 if share_buffer else 0, @@ -411,6 +413,7 @@ def create_group_query_attention_graph_past( rotary=False, rotary_interleaved=False, packed=False, + softcap=0.0, use_smooth_softmax=False, ): past_kv_seqlen = config.kv_sequence_length @@ -438,6 +441,7 @@ def create_group_query_attention_graph_past( local_window_size=local_window_size, do_rotary=rotary, rotary_interleaved=rotary_interleaved, + softcap=softcap, smooth_softmax=1 if use_smooth_softmax else 0, # is_past_bsnh=1 if past_kv_format == Formats.BSNH else 0, # kv_share_buffer=1 if share_buffer else 0, @@ -788,6 +792,7 @@ def gqa_prompt_func( past_kv_format=Formats.BSNH, share_buffer=True, rotary_interleaved=False, + softcap=0.0, use_smooth_softmax=False, ): onnx_model_str = create_group_query_attention_graph_prompt( @@ -798,6 +803,7 @@ def gqa_prompt_func( rotary=cos is not None, rotary_interleaved=rotary_interleaved, packed=new_k is None, + softcap=softcap, use_smooth_softmax=use_smooth_softmax, ) q = torch.reshape(q, (config.batch_size, config.q_sequence_length, -1)) @@ -895,6 +901,7 @@ def gqa_past_func( share_buffer=True, window_size=-1, rotary_interleaved=False, + softcap=0.0, use_smooth_softmax=False, ): onnx_model_str = create_group_query_attention_graph_past( @@ -905,6 +912,7 @@ def gqa_past_func( rotary=cos is not None, rotary_interleaved=rotary_interleaved, packed=new_k is None, + softcap=softcap, use_smooth_softmax=use_smooth_softmax, ) q = torch.reshape(q, (config.batch_size, config.sequence_length, -1)) @@ -1040,6 +1048,7 @@ def attention_ref( dropout_mask=None, causal=False, window_size=(-1, -1), # -1 means infinite window size + softcap=0.0, upcast=True, reorder_ops=False, use_smooth_softmax=False, @@ -1077,6 +1086,10 @@ def attention_ref( scores = torch.einsum("bthd,bshd->bhts", q / math.sqrt(d), k) else: scores = torch.einsum("bthd,bshd->bhts", q, k / math.sqrt(d)) + if softcap > 0: + scores = scores / softcap + scores = scores.tanh() + scores = scores * softcap if key_padding_mask is not None: scores.masked_fill_(rearrange(~key_padding_mask, "b s -> b 1 1 s"), float("-inf")) if window_size[0] >= 0 or window_size[1] >= 0: @@ -1215,6 +1228,7 @@ def parity_check_gqa_prompt( rotary=False, rotary_interleaved=False, packed=False, + softcap=0.0, use_smooth_softmax=False, rtol=1e-3, atol=1e-3, @@ -1339,6 +1353,7 @@ def parity_check_gqa_prompt( None, causal=True, window_size=window_size, + softcap=softcap, use_smooth_softmax=use_smooth_softmax, ) out_ref = out_ref.detach().cpu().numpy() @@ -1363,6 +1378,7 @@ def parity_check_gqa_prompt( past_format, True, rotary_interleaved, + softcap, use_smooth_softmax=use_smooth_softmax, ) else: @@ 
-1380,6 +1396,7 @@ def parity_check_gqa_prompt( past_format, True, rotary_interleaved, + softcap, use_smooth_softmax=use_smooth_softmax, ) out = torch.squeeze(out, 0) @@ -1388,7 +1405,7 @@ def parity_check_gqa_prompt( err_msg = ( f" with {config}, causal={causal}, local={local}, past_format={past_format}," - f" rotary={rotary}, rotary_interleaved={rotary_interleaved}, packed={packed}" + f" rotary={rotary}, rotary_interleaved={rotary_interleaved}, packed={packed}, softcap={softcap}" ) # Make sure past-present buffer updating correctly numpy.testing.assert_allclose( @@ -1409,6 +1426,7 @@ def parity_check_gqa_prompt_no_buff( rotary=False, rotary_interleaved=False, packed=False, + softcap=0.0, use_smooth_softmax=False, rtol=1e-3, atol=1e-3, @@ -1510,6 +1528,7 @@ def parity_check_gqa_prompt_no_buff( None, causal=True, window_size=window_size, + softcap=softcap, use_smooth_softmax=use_smooth_softmax, ) out_ref = out_ref.detach().cpu().numpy() @@ -1534,6 +1553,7 @@ def parity_check_gqa_prompt_no_buff( past_format, False, rotary_interleaved, + softcap, use_smooth_softmax=use_smooth_softmax, ) else: @@ -1551,6 +1571,7 @@ def parity_check_gqa_prompt_no_buff( past_format, False, rotary_interleaved, + softcap, use_smooth_softmax=use_smooth_softmax, ) out = torch.squeeze(out, 0) @@ -1559,8 +1580,7 @@ def parity_check_gqa_prompt_no_buff( err_msg = ( f" with {config}, causal={causal}, local={local}, past_format={past_format}," - f" rotary={rotary}, rotary_interleaved={rotary_interleaved}, packed={packed}," - f" use_smooth_softmax={use_smooth_softmax}" + f" rotary={rotary}, rotary_interleaved={rotary_interleaved}, packed={packed}, softcap={softcap}, use_smooth_softmax={use_smooth_softmax}" ) # Make sure past-present buffer updating correctly numpy.testing.assert_allclose( @@ -1581,6 +1601,7 @@ def parity_check_gqa_past( rotary=False, rotary_interleaved=False, packed=False, + softcap=0.0, use_smooth_softmax=False, rtol=1e-3, atol=1e-3, @@ -1701,6 +1722,7 @@ def parity_check_gqa_past( None, causal=True, window_size=window_size, + softcap=softcap, use_smooth_softmax=use_smooth_softmax, ) out_ref = out_ref.detach().cpu().numpy() @@ -1725,6 +1747,7 @@ def parity_check_gqa_past( True, left_window_size, rotary_interleaved, + softcap, use_smooth_softmax=use_smooth_softmax, ) else: @@ -1742,6 +1765,7 @@ def parity_check_gqa_past( True, left_window_size, rotary_interleaved, + softcap, use_smooth_softmax=use_smooth_softmax, ) out = torch.squeeze(out, 0) @@ -1750,7 +1774,7 @@ def parity_check_gqa_past( err_msg = ( f" with {config}, causal={causal}, local={local}, past_format={past_format}," - f" rotary={rotary}, rotary_interleaved={rotary_interleaved}, packed={packed}" + f" rotary={rotary}, rotary_interleaved={rotary_interleaved}, packed={packed}, softcap={softcap}" ) # Make sure past-present buffer updating correctly numpy.testing.assert_allclose( @@ -1771,6 +1795,7 @@ def parity_check_gqa_past_no_buff( rotary=False, rotary_interleaved=False, packed=False, + softcap=0.0, use_smooth_softmax=False, rtol=1e-3, atol=1e-3, @@ -1897,6 +1922,7 @@ def parity_check_gqa_past_no_buff( None, causal=True, window_size=window_size, + softcap=softcap, use_smooth_softmax=use_smooth_softmax, ) out_ref = out_ref.detach().cpu().numpy() @@ -1921,6 +1947,7 @@ def parity_check_gqa_past_no_buff( False, window_size=left_window_size, rotary_interleaved=rotary_interleaved, + softcap=softcap, use_smooth_softmax=use_smooth_softmax, ) else: @@ -1938,6 +1965,7 @@ def parity_check_gqa_past_no_buff( False, window_size=left_window_size, 
rotary_interleaved=rotary_interleaved, + softcap=softcap, use_smooth_softmax=use_smooth_softmax, ) out = torch.squeeze(out, 0) @@ -1946,7 +1974,7 @@ def parity_check_gqa_past_no_buff( err_msg = ( f" with {config}, causal={causal}, local={local}, past_format={past_format}," - f" rotary={rotary}, rotary_interleaved={rotary_interleaved}, packed={packed}" + f" rotary={rotary}, rotary_interleaved={rotary_interleaved}, packed={packed}, softcap={softcap}" ) numpy.testing.assert_allclose(out, out_ref, rtol=rtol, atol=atol, equal_nan=True, err_msg=err_msg) @@ -2060,14 +2088,16 @@ def gqa_no_past_memory_efficient_test_cases(): for h in h_sizes: for rotary, rotary_interleaved in rotary_options_for_current_os(): for packed in [False, True]: - config = PromptConfig(b, sq, skv, sq + skv + 8, n, n2, h) - yield ( - str(config) + f"{rotary}_{rotary_interleaved}_{packed}", - config, - rotary, - rotary_interleaved, - packed, - ) + for softcap in [0.0, 50.0]: + config = PromptConfig(b, sq, skv, sq + skv + 8, n, n2, h) + yield ( + str(config) + f"{rotary}_{rotary_interleaved}_{packed}", + config, + rotary, + rotary_interleaved, + packed, + softcap, + ) def gqa_no_past_flash_attention_test_cases(): @@ -2100,15 +2130,17 @@ def gqa_no_past_flash_attention_test_cases(): for local in [False, True]: for rotary, rotary_interleaved in rotary_options_for_current_os(): for packed in [False, True]: - config = PromptConfig(b, sq, skv, sq + skv + 8, n, n2, h) - yield ( - str(config) + f"{local}_{rotary}_{rotary_interleaved}_{packed}", - config, - local, - rotary, - rotary_interleaved, - packed, - ) + for softcap in [0.0, 50.0]: + config = PromptConfig(b, sq, skv, sq + skv + 8, n, n2, h) + yield ( + str(config) + f"{local}_{rotary}_{rotary_interleaved}_{packed}", + config, + local, + rotary, + rotary_interleaved, + packed, + softcap, + ) def gqa_past_memory_efficient_test_cases(): @@ -2140,15 +2172,17 @@ def gqa_past_memory_efficient_test_cases(): for h in h_sizes: for rotary, rotary_interleaved in rotary_options_for_current_os(): for packed in [False, True]: - sp = random.randint(1, s2 - s) if s2 - s > 0 else 0 - config = Config(b, s, s2, sp, n, n2, h) - yield ( - str(config) + f"{rotary}_{rotary_interleaved}_{packed}", - config, - rotary, - rotary_interleaved, - packed, - ) + for softcap in [0.0, 50.0]: + sp = random.randint(1, s2 - s) if s2 - s > 0 else 0 + config = Config(b, s, s2, sp, n, n2, h) + yield ( + str(config) + f"{rotary}_{rotary_interleaved}_{packed}", + config, + rotary, + rotary_interleaved, + packed, + softcap, + ) def gqa_past_flash_attention_test_cases(): @@ -2181,21 +2215,23 @@ def gqa_past_flash_attention_test_cases(): for local in [False, True]: for rotary, rotary_interleaved in rotary_options_for_current_os(): for packed in [False, True]: - sp = random.randint(1, s2 - s) if s2 - s > 0 else 0 - config = Config(b, s, s2, sp, n, n2, h) - yield ( - str(config) + f"{local}_{rotary}_{rotary_interleaved}_{packed}", - config, - local, - rotary, - rotary_interleaved, - packed, - ) + for softcap in [0.0, 50.0]: + sp = random.randint(1, s2 - s) if s2 - s > 0 else 0 + config = Config(b, s, s2, sp, n, n2, h) + yield ( + str(config) + f"{local}_{rotary}_{rotary_interleaved}_{packed}", + config, + local, + rotary, + rotary_interleaved, + packed, + softcap, + ) class TestGQA(unittest.TestCase): @parameterized.expand(gqa_no_past_memory_efficient_test_cases()) - def test_gqa_no_past_memory_efficient(self, _, config, rotary, rotary_interleaved, packed): + def test_gqa_no_past_memory_efficient(self, _, config, rotary, 
rotary_interleaved, packed, softcap): if not has_memory_efficient(): return os.environ["ORT_DISABLE_FLASH_ATTENTION"] = "1" @@ -2209,6 +2245,7 @@ def test_gqa_no_past_memory_efficient(self, _, config, rotary, rotary_interleave rotary=rotary, rotary_interleaved=rotary_interleaved, packed=packed, + softcap=softcap, use_smooth_softmax=False, ) parity_check_gqa_prompt_no_buff( @@ -2219,11 +2256,12 @@ def test_gqa_no_past_memory_efficient(self, _, config, rotary, rotary_interleave rotary=rotary, rotary_interleaved=rotary_interleaved, packed=packed, + softcap=softcap, use_smooth_softmax=True, ) @parameterized.expand(gqa_no_past_flash_attention_test_cases()) - def test_gqa_no_past_flash_attention(self, _, config, local, rotary, rotary_interleaved, packed): + def test_gqa_no_past_flash_attention(self, _, config, local, rotary, rotary_interleaved, packed, softcap): if not has_flash_attention(): return print("------- FLASH ATTENTION (PROMPT CASE) --------") @@ -2236,6 +2274,7 @@ def test_gqa_no_past_flash_attention(self, _, config, local, rotary, rotary_inte rotary=rotary, rotary_interleaved=rotary_interleaved, packed=packed, + softcap=softcap, use_smooth_softmax=True, ) parity_check_gqa_prompt_no_buff( @@ -2245,11 +2284,12 @@ def test_gqa_no_past_flash_attention(self, _, config, local, rotary, rotary_inte rotary=rotary, rotary_interleaved=rotary_interleaved, packed=packed, + softcap=softcap, use_smooth_softmax=False, ) @parameterized.expand(gqa_past_memory_efficient_test_cases()) - def test_gqa_past_memory_efficient(self, _, config, rotary, rotary_interleaved, packed): + def test_gqa_past_memory_efficient(self, _, config, rotary, rotary_interleaved, packed, softcap): if not has_memory_efficient(): return os.environ["ORT_DISABLE_FLASH_ATTENTION"] = "1" @@ -2263,6 +2303,7 @@ def test_gqa_past_memory_efficient(self, _, config, rotary, rotary_interleaved, rotary=rotary, rotary_interleaved=rotary_interleaved, packed=packed, + softcap=softcap, use_smooth_softmax=True, ) parity_check_gqa_past_no_buff( @@ -2273,11 +2314,12 @@ def test_gqa_past_memory_efficient(self, _, config, rotary, rotary_interleaved, rotary=rotary, rotary_interleaved=rotary_interleaved, packed=packed, + softcap=softcap, use_smooth_softmax=False, ) @parameterized.expand(gqa_past_flash_attention_test_cases()) - def test_gqa_past_flash_attention(self, _, config, local, rotary, rotary_interleaved, packed): + def test_gqa_past_flash_attention(self, _, config, local, rotary, rotary_interleaved, packed, softcap): if not has_flash_attention(): return print("------- FLASH ATTENTION (TOKEN GEN) -------") @@ -2292,6 +2334,7 @@ def test_gqa_past_flash_attention(self, _, config, local, rotary, rotary_interle rotary=rotary, rotary_interleaved=rotary_interleaved, packed=packed, + softcap=softcap, use_smooth_softmax=False, ) parity_check_gqa_past_no_buff( @@ -2303,6 +2346,7 @@ def test_gqa_past_flash_attention(self, _, config, local, rotary, rotary_interle rotary=rotary, rotary_interleaved=rotary_interleaved, packed=packed, + softcap=softcap, use_smooth_softmax=True, ) diff --git a/onnxruntime/test/python/transformers/test_flash_attn_rocm.py b/onnxruntime/test/python/transformers/test_flash_attn_rocm.py index 880f4175e00b7..99460722c2469 100644 --- a/onnxruntime/test/python/transformers/test_flash_attn_rocm.py +++ b/onnxruntime/test/python/transformers/test_flash_attn_rocm.py @@ -18,7 +18,7 @@ class TestGQA(unittest.TestCase): @parameterized.expand(gqa_no_past_flash_attention_test_cases()) - def test_gqa_no_past_flash_attention(self, _, config, local, 
rotary, rotary_interleaved, packed): + def test_gqa_no_past_flash_attention(self, _, config, local, rotary, rotary_interleaved, packed, softcap): config.ep = "ROCMExecutionProvider" if not torch.cuda.is_available(): return @@ -50,7 +50,7 @@ def test_gqa_no_past_flash_attention(self, _, config, local, rotary, rotary_inte ) @parameterized.expand(gqa_past_flash_attention_test_cases()) - def test_gqa_past_flash_attention(self, _, config, local, rotary, rotary_interleaved, packed): + def test_gqa_past_flash_attention(self, _, config, local, rotary, rotary_interleaved, packed, softcap): config.ep = "ROCMExecutionProvider" if not torch.cuda.is_available(): return diff --git a/onnxruntime/test/python/transformers/test_gqa_cpu.py b/onnxruntime/test/python/transformers/test_gqa_cpu.py index eeba0baccf15b..cc9d7ff51a5c6 100644 --- a/onnxruntime/test/python/transformers/test_gqa_cpu.py +++ b/onnxruntime/test/python/transformers/test_gqa_cpu.py @@ -145,6 +145,7 @@ def create_group_query_attention_graph_prompt( rotary=False, rotary_interleaved=False, packed=False, + softcap=0.0, use_smooth_softmax=False, ): past_kv_seqlen = config.buffer_sequence_length if share_buffer else 0 @@ -170,6 +171,7 @@ def create_group_query_attention_graph_prompt( local_window_size=local_window_size, do_rotary=rotary, rotary_interleaved=rotary_interleaved, + softcap=softcap, smooth_softmax=1 if use_smooth_softmax else 0, # is_past_bsnh=1 if past_kv_format == Formats.BSNH else 0, # kv_share_buffer=1 if share_buffer else 0, @@ -333,6 +335,7 @@ def create_group_query_attention_graph_past( rotary=False, rotary_interleaved=False, packed=False, + softcap=0.0, use_smooth_softmax=False, ): past_kv_seqlen = config.kv_sequence_length @@ -360,6 +363,7 @@ def create_group_query_attention_graph_past( local_window_size=local_window_size, do_rotary=rotary, rotary_interleaved=rotary_interleaved, + softcap=softcap, smooth_softmax=1 if use_smooth_softmax else 0, # is_past_bsnh=1 if past_kv_format == Formats.BSNH else 0, # kv_share_buffer=1 if share_buffer else 0, @@ -671,6 +675,7 @@ def gqa_prompt_func( past_kv_format=Formats.BSNH, share_buffer=True, rotary_interleaved=False, + softcap=0.0, use_smooth_softmax=False, ): onnx_model_str = create_group_query_attention_graph_prompt( @@ -681,6 +686,7 @@ def gqa_prompt_func( rotary=cos is not None, rotary_interleaved=rotary_interleaved, packed=new_k is None, + softcap=softcap, use_smooth_softmax=use_smooth_softmax, ) q = torch.reshape(q, (config.batch_size, config.q_sequence_length, -1)) @@ -779,6 +785,7 @@ def gqa_past_func( share_buffer=True, window_size=-1, rotary_interleaved=False, + softcap=0.0, use_smooth_softmax=False, ): onnx_model_str = create_group_query_attention_graph_past( @@ -789,6 +796,7 @@ def gqa_past_func( rotary=cos is not None, rotary_interleaved=rotary_interleaved, packed=new_k is None, + softcap=softcap, use_smooth_softmax=use_smooth_softmax, ) q = torch.reshape(q, (config.batch_size, config.sequence_length, -1)) @@ -931,6 +939,7 @@ def attention_ref( dropout_mask=None, causal=False, window_size=(-1, -1), # -1 means infinite window size + softcap=0.0, upcast=True, reorder_ops=False, use_smooth_softmax=False, @@ -969,6 +978,10 @@ def attention_ref( scores = torch.einsum("bthd,bshd->bhts", q / math.sqrt(d), k) else: scores = torch.einsum("bthd,bshd->bhts", q, k / math.sqrt(d)) + if softcap > 0: + scores = scores / softcap + scores = scores.tanh() + scores = scores * softcap if key_padding_mask is not None: scores.masked_fill_(rearrange(~key_padding_mask, "b s -> b 1 1 s"), 
float("-inf")) if window_size[0] >= 0 or window_size[1] >= 0: @@ -1039,6 +1052,7 @@ def parity_check_gqa_prompt( rotary=False, rotary_interleaved=False, packed=False, + softcap=0.0, use_smooth_softmax=False, rtol=1e-3, atol=1e-3, @@ -1149,6 +1163,7 @@ def parity_check_gqa_prompt( None, causal=True, window_size=window_size, + softcap=softcap, use_smooth_softmax=use_smooth_softmax, ) out_ref = out_ref.detach().cpu().numpy() @@ -1173,6 +1188,7 @@ def parity_check_gqa_prompt( past_format, True, rotary_interleaved, + softcap, use_smooth_softmax=use_smooth_softmax, ) else: @@ -1190,6 +1206,7 @@ def parity_check_gqa_prompt( past_format, True, rotary_interleaved, + softcap, use_smooth_softmax=use_smooth_softmax, ) out = torch.squeeze(out, 0) @@ -1215,6 +1232,8 @@ def parity_check_gqa_prompt( rotary, " rotary_interleaved:", rotary_interleaved, + " softcap:", + softcap, " smooth_softmax:", use_smooth_softmax, "past kv format:", @@ -1246,6 +1265,7 @@ def parity_check_gqa_prompt_no_buff( rotary=False, rotary_interleaved=False, packed=False, + softcap=0.0, use_smooth_softmax=False, rtol=1e-3, atol=1e-3, @@ -1330,6 +1350,7 @@ def parity_check_gqa_prompt_no_buff( None, causal=True, window_size=window_size, + softcap=softcap, use_smooth_softmax=use_smooth_softmax, ) out_ref = out_ref.detach().cpu().numpy() @@ -1354,6 +1375,7 @@ def parity_check_gqa_prompt_no_buff( past_format, False, rotary_interleaved, + softcap, use_smooth_softmax=use_smooth_softmax, ) else: @@ -1371,6 +1393,7 @@ def parity_check_gqa_prompt_no_buff( past_format, False, rotary_interleaved, + softcap, use_smooth_softmax=use_smooth_softmax, ) out = torch.squeeze(out, 0) @@ -1396,6 +1419,8 @@ def parity_check_gqa_prompt_no_buff( rotary, " rotary_interleaved:", rotary_interleaved, + " softcap:", + softcap, " smooth_softmax:", use_smooth_softmax, "past kv format:", @@ -1427,6 +1452,7 @@ def parity_check_gqa_past( rotary=False, rotary_interleaved=False, packed=False, + softcap=0.0, use_smooth_softmax=False, rtol=1e-3, atol=1e-3, @@ -1542,6 +1568,7 @@ def parity_check_gqa_past( None, causal=True, window_size=window_size, + softcap=softcap, use_smooth_softmax=use_smooth_softmax, ) out_ref = out_ref.detach().cpu().numpy() @@ -1566,6 +1593,7 @@ def parity_check_gqa_past( True, left_window_size, rotary_interleaved, + softcap, use_smooth_softmax=use_smooth_softmax, ) else: @@ -1583,6 +1611,7 @@ def parity_check_gqa_past( True, left_window_size, rotary_interleaved, + softcap, use_smooth_softmax=use_smooth_softmax, ) out = torch.squeeze(out, 0) @@ -1610,6 +1639,8 @@ def parity_check_gqa_past( rotary, " rotary_interleaved:", rotary_interleaved, + " softcap:", + softcap, " smooth_softmax:", use_smooth_softmax, " B:", @@ -1639,6 +1670,7 @@ def parity_check_gqa_past_no_buff( rotary=False, rotary_interleaved=False, packed=False, + softcap=0.0, use_smooth_softmax=False, rtol=1e-3, atol=1e-3, @@ -1760,6 +1792,7 @@ def parity_check_gqa_past_no_buff( None, causal=True, window_size=window_size, + softcap=softcap, use_smooth_softmax=use_smooth_softmax, ) out_ref = out_ref.detach().cpu().numpy() @@ -1784,6 +1817,7 @@ def parity_check_gqa_past_no_buff( False, window_size=left_window_size, rotary_interleaved=rotary_interleaved, + softcap=softcap, use_smooth_softmax=use_smooth_softmax, ) else: @@ -1801,6 +1835,7 @@ def parity_check_gqa_past_no_buff( False, window_size=left_window_size, rotary_interleaved=rotary_interleaved, + softcap=softcap, use_smooth_softmax=use_smooth_softmax, ) out = torch.squeeze(out, 0) @@ -1822,6 +1857,8 @@ def parity_check_gqa_past_no_buff( 
rotary, " rotary_interleaved:", rotary_interleaved, + "softcap", + softcap, " smooth_softmax:", use_smooth_softmax, "past kv format:", @@ -1874,29 +1911,32 @@ def test_gqa_no_past(self): for local in [False, True]: for rotary, rotary_interleaved in [(False, False), (True, False), (True, True)]: for packed in [False, True]: - for use_smooth_softmax in [False, True]: - config = PromptConfig(b, sq, skv, sq + skv + 8, n, n2, h) - past_kv_format = Formats.BNSH - all_close = parity_check_gqa_prompt( - config, - local=local, - past_format=past_kv_format, - rotary=rotary, - rotary_interleaved=rotary_interleaved, - packed=packed, - use_smooth_softmax=use_smooth_softmax, - ) - self.assertTrue(all_close) - all_close = parity_check_gqa_prompt_no_buff( - config, - local=local, - past_format=past_kv_format, - rotary=rotary, - rotary_interleaved=rotary_interleaved, - packed=packed, - use_smooth_softmax=use_smooth_softmax, - ) - self.assertTrue(all_close) + for softcap in [0.0, 50.0]: + for use_smooth_softmax in [False, True]: + config = PromptConfig(b, sq, skv, sq + skv + 8, n, n2, h) + past_kv_format = Formats.BNSH + all_close = parity_check_gqa_prompt( + config, + local=local, + past_format=past_kv_format, + rotary=rotary, + rotary_interleaved=rotary_interleaved, + packed=packed, + softcap=softcap, + use_smooth_softmax=use_smooth_softmax, + ) + self.assertTrue(all_close) + all_close = parity_check_gqa_prompt_no_buff( + config, + local=local, + past_format=past_kv_format, + rotary=rotary, + rotary_interleaved=rotary_interleaved, + packed=packed, + softcap=softcap, + use_smooth_softmax=use_smooth_softmax, + ) + self.assertTrue(all_close) def test_gqa_past(self): print("-------- TEST GQA PAST (TOKEN GEN) ---------") @@ -1928,34 +1968,37 @@ def test_gqa_past(self): for local in [False, True]: for rotary, rotary_interleaved in [(False, False), (True, False), (True, True)]: for packed in [False, True]: - for use_smooth_softmax in [False, True]: - sp = random.randint(1, s2 - s) if s2 - s > 0 else 0 - config = Config(b, s, s2, sp, n, n2, h) - past_kv_format = Formats.BNSH - all_close = parity_check_gqa_past( - config, - local=local, - past_format=past_kv_format, - rtol=1e-3, - atol=1e-3, - rotary=rotary, - rotary_interleaved=rotary_interleaved, - packed=packed, - use_smooth_softmax=use_smooth_softmax, - ) - self.assertTrue(all_close) - all_close = parity_check_gqa_past_no_buff( - config, - local=local, - past_format=past_kv_format, - rtol=1e-3, - atol=1e-3, - rotary=rotary, - rotary_interleaved=rotary_interleaved, - packed=packed, - use_smooth_softmax=use_smooth_softmax, - ) - self.assertTrue(all_close) + for softcap in [0.0, 50.0]: + for use_smooth_softmax in [False, True]: + sp = random.randint(1, s2 - s) if s2 - s > 0 else 0 + config = Config(b, s, s2, sp, n, n2, h) + past_kv_format = Formats.BNSH + all_close = parity_check_gqa_past( + config, + local=local, + past_format=past_kv_format, + rtol=1e-3, + atol=1e-3, + rotary=rotary, + rotary_interleaved=rotary_interleaved, + packed=packed, + softcap=softcap, + use_smooth_softmax=use_smooth_softmax, + ) + self.assertTrue(all_close) + all_close = parity_check_gqa_past_no_buff( + config, + local=local, + past_format=past_kv_format, + rtol=1e-3, + atol=1e-3, + rotary=rotary, + rotary_interleaved=rotary_interleaved, + packed=packed, + softcap=softcap, + use_smooth_softmax=use_smooth_softmax, + ) + self.assertTrue(all_close) if __name__ == "__main__":
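A note on the flash attention wiring in `set_params_fprop` above: when `softcap > 0` it folds the softmax scale into `params.softcap` (`softmax_scale / softcap`) and repurposes `scale_softmax` as the cap itself, so the kernel's single `apply_softcap` tanh plus its usual softmax rescale still produce `softcap * tanh(scale * qk / softcap)` per logit. A quick numeric sanity check of that identity (a sketch only; function names are illustrative):

```python
import math

def reference(qk: float, scale: float, softcap: float) -> float:
    """What GQA semantically computes per logit: softcap * tanh(scale * qk / softcap)."""
    return softcap * math.tanh(scale * qk / softcap)

def flash_folded(qk: float, scale: float, softcap: float) -> float:
    """How the patch splits it: apply_softcap uses params.softcap = scale / softcap,
    and the softmax path then multiplies by params.scale_softmax = softcap."""
    params_softcap = scale / softcap
    scale_softmax = softcap
    return scale_softmax * math.tanh(qk * params_softcap)

qk, scale, cap = 37.5, 1.0 / math.sqrt(128), 50.0
assert math.isclose(reference(qk, scale, cap), flash_folded(qk, scale, cap))
print(reference(qk, scale, cap))
```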