From 6e5757698870d029afaeeb50663e8326dddf0390 Mon Sep 17 00:00:00 2001 From: Tianlei Wu Date: Mon, 26 Aug 2024 23:13:15 -0700 Subject: [PATCH] Support Smooth Softmax in GroupQueryAttention (#21867) ### Description Softmax (formula 1) is like the following: ```math y_{i} = \frac{exp(x_{i})}{\sum_{i} exp(x_{i})} ``` After applying softmax, each element will be in the range of $(0, 1)$, and the elements will add up to 1, so that they can be interpreted as probabilities. However, in language model, softmax has two issues: * When all elements are -inf (for example, a whole row is masked when a query token is padding), the result is not defined since exp(-inf)=0 and divided-by-zero is encountered in the above formula. * Why do we need normalize in a way that each query word are treated as equal important (each row has sum equals to1)? **Smooth Softmax** (formula 2) is a modified version that introduces a smooth factor like the following: ```math s_{i} = \frac{exp(x_{i})}{1+ \sum_{i} exp(x_{i})} ``` This formula could tackle the above two issues: * It could handle the special case that all elements are -inf: the result $s_{i}$ is 0 for every element in such case. * Sum of all elements $\sum_{i}{s_{i}} = \frac{\sum_{i}{exp(x_{i})}}{1+ \sum_{i} exp(x_{i})}$ is in the range of (0, 1), so that we can train the model to assign different importance to different query words. Since exponential is prone to overflow or underflow, to get stable result, formula 3 can be used: ```math s_{i} = \frac{exp(x_{i} + c)}{exp(c)+ \sum_{i} exp(x_{i} +c)} ``` c can be any value in theory. In practical, choice of constant c shall avoid $exp(c)$ and $exp(x_{i} +c)$ overflow (or underflow) at the same time. A reasonable choice is like formula 4: ```math c=-\max_{i} \{ x_i \} ``` or apply a constraint that c <=0 like the following formula 5: ```math c=-\max(0, \max_{i} \{ x_i \}) ``` The latter one (formula 5) ensures that $s_{i}$ will fallback to formula 2 when all elements are negative. For CPU provider, smooth softmax is implemented in MLAS. CPU implementation uses formula 5. @wangyems implemented the smooth softmax in flash attention for CUDA, which requires Ampere or newer GPU. The implementation of smooth softmax in flash attention uses formula 4. --------- Co-authored-by: Ye Wang --- docs/ContribOperators.md | 2 + .../contrib_ops/cpu/bert/attention_common.h | 1 + .../contrib_ops/cpu/bert/attention_helper.h | 43 +++- .../contrib_ops/cpu/bert/gqa_attention_base.h | 19 +- .../contrib_ops/cuda/bert/attention_impl.cu | 2 +- .../cuda/bert/flash_attention/flash.h | 2 + .../cuda/bert/flash_attention/flash_api.cc | 7 + .../cuda/bert/flash_attention/flash_api.h | 2 + .../bert/flash_attention/flash_fwd_kernel.h | 4 +- .../cuda/bert/flash_attention/softmax.h | 4 +- .../cuda/bert/group_query_attention.cc | 3 + .../cuda/bert/group_query_attention.h | 1 + .../cuda/bert/group_query_attention_impl.cu | 6 +- .../core/graph/contrib_ops/bert_defs.cc | 4 + onnxruntime/core/mlas/inc/mlas.h | 1 + onnxruntime/core/mlas/lib/compute.cpp | 93 +++------ .../core/providers/cpu/math/softmax_shared.cc | 2 +- onnxruntime/core/providers/cpu/ml/ml_common.h | 2 +- .../test/mlas/bench/bench_computesoftmax.cpp | 4 +- .../test/mlas/unittest/test_softmax.cpp | 22 +- .../test/python/transformers/benchmark_gqa.py | 51 +++-- .../transformers/benchmark_gqa_windows.py | 18 +- .../transformers/test_flash_attn_cuda.py | 94 ++++++++- .../test/python/transformers/test_gqa_cpu.py | 195 +++++++++++++----- .../transformers/test_sparse_attention.py | 14 +- 25 files changed, 435 insertions(+), 161 deletions(-) diff --git a/docs/ContribOperators.md b/docs/ContribOperators.md index 0048190f9063b..33d872254a255 100644 --- a/docs/ContribOperators.md +++ b/docs/ContribOperators.md @@ -2541,6 +2541,8 @@ This version of the operator has been available since version 1 of the 'com.micr
Rotate using interleaved pattern. Default value is 0 (False).
scale : float
Custom scale will be used if specified. Default value is 1/sqrt(head_size)
+
smooth_softmax : int
+
Use a smooth factor in softmax.
#### Inputs (7 - 9) diff --git a/onnxruntime/contrib_ops/cpu/bert/attention_common.h b/onnxruntime/contrib_ops/cpu/bert/attention_common.h index 1e01aa765ca6d..9e6671c26cf59 100644 --- a/onnxruntime/contrib_ops/cpu/bert/attention_common.h +++ b/onnxruntime/contrib_ops/cpu/bert/attention_common.h @@ -116,6 +116,7 @@ struct GroupQueryAttentionParameters { bool is_prompt; // determines if seqlens_k is past or kv sequence length tensor bool do_rotary; bool rotary_interleaved; + bool use_smooth_softmax; float scale; AttentionQkvFormat qkv_format; AttentionQkvFormat past_kv_format; diff --git a/onnxruntime/contrib_ops/cpu/bert/attention_helper.h b/onnxruntime/contrib_ops/cpu/bert/attention_helper.h index 29ae769ed89f1..04e120863d39e 100644 --- a/onnxruntime/contrib_ops/cpu/bert/attention_helper.h +++ b/onnxruntime/contrib_ops/cpu/bert/attention_helper.h @@ -16,6 +16,47 @@ using onnxruntime::concurrency::ThreadPool; namespace onnxruntime { namespace contrib { +template +void ComputeSmoothSoftmaxInplace(T* score, int N, int D, ThreadPool* tp) { + ThreadPool::TryParallelFor(tp, N, D * 2.0, [&](std::ptrdiff_t begin, std::ptrdiff_t end) { + for (std::ptrdiff_t j = begin; j != end; ++j) { + float* x = reinterpret_cast(score) + j * D; + float* y = x; + + float max = -std::numeric_limits::infinity(); + for (int i = 0; i < D; i++) { + if (max < x[i]) + max = x[i]; + } + + if (max < 0.0f) { + max = 0.0f; + } + + for (int i = 0; i < D; i++) { + y[i] = expf(x[i] - max); + } + + double sum = 0.0; + + for (int i = 0; i < D; i++) { + sum += x[i]; + } + + sum += exp(static_cast(-max)); + + for (int i = 0; i < D; i++) { + y[i] = x[i] / (float)sum; + } + } + }); +} + +template <> +inline void ComputeSmoothSoftmaxInplace(float* score, int N, int D, ThreadPool* tp) { + MlasComputeSoftmax(score, score, N, D, false, true, tp); +} + template void ComputeAttentionSoftmaxInplace(T* score, int N, int D, ThreadPool* tp) { ThreadPool::TryParallelFor(tp, N, D * 2.0, [&](std::ptrdiff_t begin, std::ptrdiff_t end) { @@ -58,7 +99,7 @@ void ComputeAttentionSoftmaxInplace(T* score, int N, int D, ThreadPool* tp) { template <> inline void ComputeAttentionSoftmaxInplace(float* score, int N, int D, ThreadPool* tp) { - MlasComputeSoftmax(score, score, N, D, false, tp); + MlasComputeSoftmax(score, score, N, D, false, false, tp); } template diff --git a/onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h b/onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h index 137612a4bf902..70f8564a2cbf2 100644 --- a/onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h +++ b/onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h @@ -30,6 +30,8 @@ class GQAAttentionBase { do_rotary_ = info.GetAttrOrDefault("do_rotary", 0) == 1; rotary_interleaved_ = info.GetAttrOrDefault("rotary_interleaved", 0) == 1; + use_smooth_softmax_ = info.GetAttrOrDefault("smooth_softmax", 0) == 1; + local_window_size_ = has_local ? static_cast(info.GetAttrOrDefault("local_window_size", -1)) : -1; } @@ -40,6 +42,8 @@ class GQAAttentionBase { bool rotary_interleaved_; int local_window_size_; + bool use_smooth_softmax_; + template Status ApplyAttention(const T* Q, // Q data with shape BxNxSxH const T* K, // K data with shape BxN_kvxSxH @@ -195,10 +199,19 @@ class GQAAttentionBase { for (int total_seq_id = 0; total_seq_id < seq_causal_length - local_window_size_ - 1; total_seq_id++) { output_softmax[total_seq_id] = 0.f; } - ComputeAttentionSoftmaxInplace(output_softmax + seq_causal_length - local_window_size_ - 1, 1, - local_window_size_ + 1, nullptr); + if (use_smooth_softmax_) { + ComputeSmoothSoftmaxInplace(output_softmax + seq_causal_length - local_window_size_ - 1, 1, + local_window_size_ + 1, nullptr); + } else { + ComputeAttentionSoftmaxInplace(output_softmax + seq_causal_length - local_window_size_ - 1, 1, + local_window_size_ + 1, nullptr); + } } else { - ComputeAttentionSoftmaxInplace(output_softmax, 1, seq_causal_length, nullptr); + if (use_smooth_softmax_) { + ComputeSmoothSoftmaxInplace(output_softmax, 1, seq_causal_length, nullptr); + } else { + ComputeAttentionSoftmaxInplace(output_softmax, 1, seq_causal_length, nullptr); + } } // set causal [seq_causal_length, total_seqlen) to 0.f diff --git a/onnxruntime/contrib_ops/cuda/bert/attention_impl.cu b/onnxruntime/contrib_ops/cuda/bert/attention_impl.cu index a02f5c7329b9a..347cf946e6ff3 100644 --- a/onnxruntime/contrib_ops/cuda/bert/attention_impl.cu +++ b/onnxruntime/contrib_ops/cuda/bert/attention_impl.cu @@ -303,7 +303,7 @@ Status FlashAttention( ORT_RETURN_IF_ERROR(onnxruntime::flash::mha_fwd( device_prop, stream, data.q, data.k, data.v, data.output, reinterpret_cast(data.scratch), parameters.batch_size, parameters.num_heads, parameters.num_heads, parameters.head_size, - parameters.sequence_length, parameters.total_sequence_length, scale, parameters.is_unidirectional, is_bf16, + parameters.sequence_length, parameters.total_sequence_length, scale, parameters.is_unidirectional, is_bf16, false, parameters.num_splits, reinterpret_cast(data.softmax_lse_accum), reinterpret_cast(data.out_accum), data.qkv_format == AttentionQkvFormat::Q_K_V_BSNH)); diff --git a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash.h b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash.h index 0463d3795b446..bcd87c1ab6251 100644 --- a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash.h +++ b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash.h @@ -121,6 +121,8 @@ struct Flash_fwd_params : public Qkv_params { bool is_rotary_interleaved = false; + bool smooth_softmax = false; + int num_splits = 0; // For split-KV version void* __restrict__ alibi_slopes_ptr = nullptr; diff --git a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_api.cc b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_api.cc index 967c04c52b182..f875d31f5ca7a 100644 --- a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_api.cc +++ b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_api.cc @@ -37,6 +37,7 @@ void set_params_fprop(Flash_fwd_params& params, float softmax_scale, bool is_causal, bool is_bf16, + bool use_smooth_softmax, bool kv_bsnh = true, int window_size_left = -1, int window_size_right = -1) { @@ -47,6 +48,7 @@ void set_params_fprop(Flash_fwd_params& params, params.o_ptr = out; params.is_bf16 = is_bf16; + params.smooth_softmax = use_smooth_softmax; // All stride are in elements, not bytes. if (kv_bsnh) { @@ -267,6 +269,7 @@ Status mha_fwd(const cudaDeviceProp& dprops, float softmax_scale, bool is_causal, bool is_bf16, + bool use_smooth_softmax, int num_splits, void* softmax_lse_accum, // num_splits x batch_size x seqlen_q x num_heads void* out_accum, // num_splits x batch_size x seqlen_q x num_heads x head_size_rounded @@ -293,6 +296,7 @@ Status mha_fwd(const cudaDeviceProp& dprops, softmax_scale, is_causal, is_bf16, + use_smooth_softmax, kv_bsnh, local_window_size, is_causal ? 0 : -1); @@ -365,6 +369,7 @@ Status mha_varlen_fwd(const cudaDeviceProp& dprops, softmax_scale, is_causal, is_bf16, + false, true, -1, is_causal ? 0 : -1); @@ -424,6 +429,7 @@ Status mha_fwd_kvcache(const cudaDeviceProp& dprops, const float softmax_scale, bool is_causal, bool is_bf16, + bool use_smooth_softmax, bool past_bsnh, // otherwise bnsh int num_splits, void* softmax_lse_accum, // num_splits x batch_size x seqlen_q x num_heads @@ -456,6 +462,7 @@ Status mha_fwd_kvcache(const cudaDeviceProp& dprops, softmax_scale, is_causal, is_bf16, + use_smooth_softmax, past_bsnh, local_window_size, is_causal ? 0 : -1); diff --git a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_api.h b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_api.h index 4c59561449851..baad0a938d377 100644 --- a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_api.h +++ b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_api.h @@ -52,6 +52,7 @@ Status mha_fwd(const cudaDeviceProp& dprops, float softmax_scale, bool is_causal, bool is_bf16, + bool use_smooth_softmax, int num_splits = 0, void* softmax_lse_accum = nullptr, // num_splits x batch_size x seqlen_q x num_heads void* out_accum = nullptr, // num_splits x batch_size x seqlen_q x num_heads x head_size_rounded @@ -105,6 +106,7 @@ Status mha_fwd_kvcache(const cudaDeviceProp& dprops, const float softmax_scale, bool is_causal, bool is_bf16, + bool use_smooth_softmax, bool past_bsnh, // otherwise bnsh int num_splits = 0, void* softmax_lse_accum = nullptr, // num_splits x batch_size x seqlen_q x num_heads diff --git a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_kernel.h b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_kernel.h index 1c8a93674a80b..b2aa3668a5be1 100644 --- a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_kernel.h +++ b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_kernel.h @@ -346,7 +346,7 @@ inline __device__ void compute_attn_1rowblock(const Params& params, const int bi // Epilogue - Tensor lse = softmax.template normalize_softmax_lse<>(acc_o, params.scale_softmax); + Tensor lse = softmax.template normalize_softmax_lse<>(acc_o, params.scale_softmax, params.smooth_softmax); // Convert acc_o from fp32 to fp16/bf16 Tensor rO = flash::convert_type(acc_o); @@ -902,7 +902,7 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params& params, cons // Epilogue - Tensor lse = softmax.template normalize_softmax_lse(acc_o, params.scale_softmax); + Tensor lse = softmax.template normalize_softmax_lse(acc_o, params.scale_softmax, params.smooth_softmax); Tensor sOaccum = make_tensor(make_smem_ptr(reinterpret_cast(smem_)), typename Kernel_traits::SmemLayoutO{}); // (SMEM_M,SMEM_N) // Partition sO to match the accumulator partitioning diff --git a/onnxruntime/contrib_ops/cuda/bert/flash_attention/softmax.h b/onnxruntime/contrib_ops/cuda/bert/flash_attention/softmax.h index ba678b740d376..7e0095cb39bd9 100644 --- a/onnxruntime/contrib_ops/cuda/bert/flash_attention/softmax.h +++ b/onnxruntime/contrib_ops/cuda/bert/flash_attention/softmax.h @@ -159,7 +159,7 @@ struct Softmax { }; template - __forceinline__ __device__ TensorT normalize_softmax_lse(Tensor0& acc_o, float softmax_scale) { + __forceinline__ __device__ TensorT normalize_softmax_lse(Tensor0& acc_o, float softmax_scale, bool smooth_softmax) { SumOp sum_op; quad_allreduce_(row_sum, row_sum, sum_op); TensorT lse = make_fragment_like(row_sum); @@ -167,7 +167,7 @@ struct Softmax { static_assert(decltype(size<0>(acc_o_rowcol))::value == kNRows); #pragma unroll for (int mi = 0; mi < size<0>(acc_o_rowcol); ++mi) { - float sum = row_sum(mi); + float sum = smooth_softmax ? row_sum(mi) + expf(-row_max(mi) * softmax_scale) : row_sum(mi); float inv_sum = (sum == 0.f || sum != sum) ? 1.f : 1.f / sum; lse(mi) = (sum == 0.f || sum != sum) ? (Split ? -INFINITY : INFINITY) : row_max(mi) * softmax_scale + __logf(sum); float scale = inv_sum; diff --git a/onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc b/onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc index 797f9b0a1ea47..48ecfd7304f4b 100644 --- a/onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc +++ b/onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc @@ -51,6 +51,7 @@ GroupQueryAttention::GroupQueryAttention(const OpKernelInfo& info) do_rotary_ = info.GetAttrOrDefault("do_rotary", 0) == 1; rotary_interleaved_ = info.GetAttrOrDefault("rotary_interleaved", 0) == 1; scale_ = info.GetAttrOrDefault("scale", 0.0f); + use_smooth_softmax_ = info.GetAttrOrDefault("smooth_softmax", 0) == 1; kernel_options_ = this->GetAttentionKernelOptions(); @@ -98,6 +99,7 @@ Status GroupQueryAttention::ComputeInternal(OpKernelContext* context) const { device_prop.maxThreadsPerBlock)); parameters.local_window_size = local_window_size_; parameters.is_unidirectional = is_unidirectional_; + parameters.use_smooth_softmax = use_smooth_softmax_; parameters.zeros_count = kZerosCount; parameters.zero_ptr = zeros_.get(); // parameters.left_padding = left_padding_; @@ -151,6 +153,7 @@ Status GroupQueryAttention::ComputeInternal(OpKernelContext* context) const { #if USE_MEMORY_EFFICIENT_ATTENTION int sm = (device_prop.major * 10) + device_prop.minor; bool use_memory_efficient_attention = + !use_smooth_softmax_ && !use_flash_attention && !disable_memory_efficient_attention_ && local_window_size_ == -1 && diff --git a/onnxruntime/contrib_ops/cuda/bert/group_query_attention.h b/onnxruntime/contrib_ops/cuda/bert/group_query_attention.h index 4ff5b0a59f021..872fe9fe05ad2 100644 --- a/onnxruntime/contrib_ops/cuda/bert/group_query_attention.h +++ b/onnxruntime/contrib_ops/cuda/bert/group_query_attention.h @@ -28,6 +28,7 @@ class GroupQueryAttention final : public CudaKernel { bool is_past_bsnh_; bool do_rotary_; bool rotary_interleaved_; + bool use_smooth_softmax_; float scale_; bool disable_flash_attention_; bool disable_memory_efficient_attention_; diff --git a/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.cu b/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.cu index b694de48d2961..63e94f95b04ff 100644 --- a/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.cu +++ b/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.cu @@ -678,9 +678,9 @@ Status FlashAttention( reinterpret_cast(data.softmax_lse), seqlens_k, cos_cache, sin_cache, /*block_table*/ nullptr, batch_size, num_heads, kv_num_heads, head_size, sequence_length, parameters.seqlen_present_kv_cache, kv_sequence_length, parameters.rotary_dim, - scale, is_causal, is_bf16, past_bsnh, parameters.num_splits, reinterpret_cast(data.softmax_lse_accum), - reinterpret_cast(data.out_accum), parameters.local_window_size, parameters.rotary_interleaved, - parameters.is_packed_qkv)); + scale, is_causal, is_bf16, parameters.use_smooth_softmax, past_bsnh, parameters.num_splits, + reinterpret_cast(data.softmax_lse_accum), reinterpret_cast(data.out_accum), + parameters.local_window_size, parameters.rotary_interleaved, parameters.is_packed_qkv)); // if (parameters.left_padding && parameters.is_prompt) { // ORT_RETURN_IF_ERROR(LaunchLeftPadLast(parameters, data, stream, device_prop.maxThreadsPerBlock)); diff --git a/onnxruntime/core/graph/contrib_ops/bert_defs.cc b/onnxruntime/core/graph/contrib_ops/bert_defs.cc index 334090e8f305f..dd3a06e3eb4ba 100644 --- a/onnxruntime/core/graph/contrib_ops/bert_defs.cc +++ b/onnxruntime/core/graph/contrib_ops/bert_defs.cc @@ -1073,6 +1073,10 @@ ONNX_MS_OPERATOR_SET_SCHEMA( "Rotate using interleaved pattern. Default value is 0 (False).", AttributeProto::INT, OPTIONAL_VALUE) + .Attr("smooth_softmax", + "Use a smooth factor in softmax.", + AttributeProto::INT, + static_cast(-1)) .Input(0, "query", "Query with shape (batch_size, sequence_length, hidden_size), or packed QKV with shape" diff --git a/onnxruntime/core/mlas/inc/mlas.h b/onnxruntime/core/mlas/inc/mlas.h index e46105324a7fb..bea4b91ebaa79 100644 --- a/onnxruntime/core/mlas/inc/mlas.h +++ b/onnxruntime/core/mlas/inc/mlas.h @@ -1013,6 +1013,7 @@ MlasComputeSoftmax( size_t N, size_t D, bool LogSoftmax, + bool SmoothSoftmax, MLAS_THREADPOOL* ThreadPool ); diff --git a/onnxruntime/core/mlas/lib/compute.cpp b/onnxruntime/core/mlas/lib/compute.cpp index f4c1e3da69289..73df23e64ca1f 100644 --- a/onnxruntime/core/mlas/lib/compute.cpp +++ b/onnxruntime/core/mlas/lib/compute.cpp @@ -71,6 +71,7 @@ MLAS_INTERNAL_DATA const float MlasMinimumF32Value = std::numeric_limits: struct MLAS_SOFTMAX_WORK_BLOCK { ptrdiff_t ThreadCountN; bool LogSoftmax; + bool SmoothSoftmax; const float* Input; float* Output; size_t N; @@ -81,7 +82,7 @@ MLAS_FORCEINLINE MLAS_FLOAT32X4 MlasComputeExpVector( MLAS_FLOAT32X4 Vector - ) +) /*++ Routine Description: @@ -186,7 +187,7 @@ MlasComputeExpF32Kernel( const float* Input, float* Output, size_t N - ) +) /*++ Routine Description: @@ -208,7 +209,6 @@ Return Value: --*/ { while (N > 0) { - MLAS_FLOAT32X4 Vector; if (N >= 4) { @@ -228,7 +228,6 @@ Return Value: Vector = MlasComputeExpVector(Vector); if (N >= 4) { - MlasStoreFloat32x4(Output, Vector); Input += 4; @@ -236,7 +235,6 @@ Return Value: N -= 4; } else { - MlasStoreLaneFloat32x4<0>(Output, Vector); Input += 1; @@ -252,7 +250,7 @@ MlasComputeExp( const float* Input, float* Output, size_t N - ) +) /*++ Routine Description: @@ -287,7 +285,7 @@ MLAS_FLOAT32X4 MlasComputeSumExpVector( MLAS_FLOAT32X4 Vector, MLAS_FLOAT32X4 NegativeMaximumVector - ) +) /*++ Routine Description: @@ -379,7 +377,7 @@ MlasComputeSumExpF32Kernel( float* Output, size_t N, const float* NegativeMaximum - ) +) /*++ Routine Description: @@ -411,7 +409,6 @@ Return Value: float Accumulator = 0.0f; if (N >= 4) { - MLAS_FLOAT32X4 AccumulatorVector = MlasZeroFloat32x4(); #if !defined(MLAS_SSE2_INTRINSICS) @@ -426,7 +423,6 @@ Return Value: // while (N >= 8) { - MLAS_FLOAT32X4 Vector0 = MlasLoadFloat32x4(Input); MLAS_FLOAT32X4 Vector1 = MlasLoadFloat32x4(Input + 4); @@ -448,7 +444,6 @@ Return Value: #endif while (N >= 4) { - MLAS_FLOAT32X4 Vector = MlasLoadFloat32x4(Input); Vector = MlasComputeSumExpVector(Vector, NegativeMaximumVector); @@ -467,7 +462,6 @@ Return Value: } while (N > 0) { - #if defined(MLAS_SSE2_INTRINSICS) // N.B. SSE2 lacks a broadcast load instruction, so avoid a shuffle and // use zeroes for the upper elements. @@ -498,7 +492,7 @@ MLASCALL MlasReduceMaximumF32Kernel( const float* Input, size_t N - ) +) /*++ Routine Description: @@ -521,17 +515,14 @@ Return Value: float Maximum = MlasMinimumF32Value; if (N >= 4) { - MLAS_FLOAT32X4 MaximumVector0 = MlasBroadcastFloat32x4(Maximum); if (N >= 16) { - MLAS_FLOAT32X4 MaximumVector1 = MaximumVector0; MLAS_FLOAT32X4 MaximumVector2 = MaximumVector0; MLAS_FLOAT32X4 MaximumVector3 = MaximumVector0; while (N >= 16) { - MaximumVector0 = MlasMaximumFloat32x4(MaximumVector0, MlasLoadFloat32x4(Input)); MaximumVector1 = MlasMaximumFloat32x4(MaximumVector1, MlasLoadFloat32x4(Input + 4)); MaximumVector2 = MlasMaximumFloat32x4(MaximumVector2, MlasLoadFloat32x4(Input + 8)); @@ -547,7 +538,6 @@ Return Value: } while (N >= 4) { - MaximumVector0 = MlasMaximumFloat32x4(MaximumVector0, MlasLoadFloat32x4(Input)); Input += 4; @@ -558,7 +548,6 @@ Return Value: } while (N > 0) { - Maximum = std::max(Maximum, *Input); Input += 1; @@ -575,18 +564,16 @@ MlasReduceMinimumMaximumF32Kernel( float* Min, float* Max, size_t N - ) +) { float tmp_min = std::numeric_limits::max(); float tmp_max = std::numeric_limits::lowest(); if (N >= 4) { - MLAS_FLOAT32X4 MaximumVector0 = MlasBroadcastFloat32x4(tmp_max); MLAS_FLOAT32X4 MinimumVector0 = MlasBroadcastFloat32x4(tmp_min); if (N >= 16) { - MLAS_FLOAT32X4 MaximumVector1 = MaximumVector0; MLAS_FLOAT32X4 MaximumVector2 = MaximumVector0; MLAS_FLOAT32X4 MaximumVector3 = MaximumVector0; @@ -596,7 +583,6 @@ MlasReduceMinimumMaximumF32Kernel( MLAS_FLOAT32X4 MinimumVector3 = MinimumVector0; while (N >= 16) { - MLAS_FLOAT32X4 InputVector0 = MlasLoadFloat32x4(Input); MLAS_FLOAT32X4 InputVector1 = MlasLoadFloat32x4(Input + 4); MLAS_FLOAT32X4 InputVector2 = MlasLoadFloat32x4(Input + 8); @@ -626,7 +612,6 @@ MlasReduceMinimumMaximumF32Kernel( } while (N >= 4) { - MLAS_FLOAT32X4 InputVector0 = MlasLoadFloat32x4(Input); MaximumVector0 = MlasMaximumFloat32x4(MaximumVector0, InputVector0); @@ -641,7 +626,6 @@ MlasReduceMinimumMaximumF32Kernel( } while (N > 0) { - tmp_max = std::max(tmp_max, *Input); tmp_min = std::min(tmp_min, *Input); @@ -659,7 +643,7 @@ MlasComputeSoftmaxOutputF32Kernel( float* Output, size_t N, const float* Parameters - ) +) /*++ Routine Description: @@ -686,7 +670,6 @@ Return Value: const MLAS_FLOAT32X4 ScaleVector = MlasBroadcastFloat32x4(Scale); while (N >= 16) { - MLAS_FLOAT32X4 Vector0 = MlasMultiplyFloat32x4(ScaleVector, MlasLoadFloat32x4(Output)); MLAS_FLOAT32X4 Vector1 = MlasMultiplyFloat32x4(ScaleVector, MlasLoadFloat32x4(Output + 4)); MLAS_FLOAT32X4 Vector2 = MlasMultiplyFloat32x4(ScaleVector, MlasLoadFloat32x4(Output + 8)); @@ -702,7 +685,6 @@ Return Value: } while (N >= 4) { - MlasStoreFloat32x4(Output, MlasMultiplyFloat32x4(ScaleVector, MlasLoadFloat32x4(Output))); Output += 4; @@ -710,7 +692,6 @@ Return Value: } while (N > 0) { - *Output *= Scale; Output += 1; @@ -725,7 +706,7 @@ MlasComputeLogSoftmaxOutputF32Kernel( float* Output, size_t N, const float* Parameters - ) +) /*++ Routine Description: @@ -757,7 +738,6 @@ Return Value: const MLAS_FLOAT32X4 LogarithmVector = MlasBroadcastFloat32x4(Logarithm); while (N >= 16) { - MLAS_FLOAT32X4 Vector0 = MlasLoadFloat32x4(Input); MLAS_FLOAT32X4 Vector1 = MlasLoadFloat32x4(Input + 4); MLAS_FLOAT32X4 Vector2 = MlasLoadFloat32x4(Input + 8); @@ -784,7 +764,6 @@ Return Value: } while (N >= 4) { - MLAS_FLOAT32X4 Vector = MlasLoadFloat32x4(Input); Vector = MlasAddFloat32x4(Vector, NegativeMaximumVector); Vector = MlasSubtractFloat32x4(Vector, LogarithmVector); @@ -796,7 +775,6 @@ Return Value: } while (N > 0) { - *Output = *Input + NegativeMaximum - Logarithm; Input += 1; @@ -809,7 +787,7 @@ void MlasComputeSoftmaxThreaded( void* Context, ptrdiff_t Index - ) +) /*++ Routine Description: @@ -846,6 +824,7 @@ Return Value: const size_t D = WorkBlock->D; const bool LogSoftmax = WorkBlock->LogSoftmax; + const bool SmoothSoftmax = WorkBlock->SmoothSoftmax; const float* Input = WorkBlock->Input + n * D; float* Output = WorkBlock->Output + n * D; @@ -857,7 +836,6 @@ Return Value: #endif while (CountN > 0) { - #if defined(MLAS_SSE2_INTRINSICS) // // Prefetch the next row of the input buffer. @@ -878,24 +856,30 @@ Return Value: float Maximum = MlasReduceMaximumF32Kernel(Input, D); #endif float NegativeMaximum = -Maximum; + if (SmoothSoftmax && NegativeMaximum > 0.0f) { + NegativeMaximum = 0.0f; + } - if (LogSoftmax) { - - // - // Compute the sum of the exponential functions for the row. - // - + // + // Compute the exponential function for each element of the row (save to Temp if provided) and + // compute the sum of these exponential functions. + // + float* Temp = LogSoftmax ? nullptr : Output; #if defined(MLAS_TARGET_AMD64) - float Accumulation = GetMlasPlatform().ComputeSumExpF32Kernel(Input, nullptr, D, &NegativeMaximum); + float Accumulation = GetMlasPlatform().ComputeSumExpF32Kernel(Input, Temp, D, &NegativeMaximum); #else - float Accumulation = MlasComputeSumExpF32Kernel(Input, nullptr, D, &NegativeMaximum); + float Accumulation = MlasComputeSumExpF32Kernel(Input, Temp, D, &NegativeMaximum); #endif + if (SmoothSoftmax) { + Accumulation += expf(NegativeMaximum); + } + + if (LogSoftmax) { // // Compute the log softmax output. // - - float Parameters[] = { NegativeMaximum, std::log(Accumulation)}; + float Parameters[] = {NegativeMaximum, std::log(Accumulation)}; #if defined(MLAS_TARGET_AMD64) || defined(MLAS_TARGET_LARCH64) GetMlasPlatform().ComputeLogSoftmaxOutputF32Kernel(Input, Output, D, Parameters); @@ -904,23 +888,10 @@ Return Value: #endif } else { - - // - // Compute the exponential function for each element of the row and - // compute the sum of these exponential functions. - // - -#if defined(MLAS_TARGET_AMD64) - float Accumulation = GetMlasPlatform().ComputeSumExpF32Kernel(Input, Output, D, &NegativeMaximum); -#else - float Accumulation = MlasComputeSumExpF32Kernel(Input, Output, D, &NegativeMaximum); -#endif - // // Normalize the softmax output. // - - float Parameters[] = { 1.0f / Accumulation }; + float Parameters[] = {1.0f / Accumulation}; #if defined(MLAS_TARGET_AMD64) || defined(MLAS_TARGET_LARCH64) GetMlasPlatform().ComputeSoftmaxOutputF32Kernel(Output, D, Parameters); @@ -943,8 +914,9 @@ MlasComputeSoftmax( size_t N, size_t D, bool LogSoftmax, + bool SmoothSoftmax, MLAS_THREADPOOL* ThreadPool - ) +) /*++ Routine Description: @@ -966,6 +938,8 @@ Routine Description: LogSoftmax - Supplies true if this is a log softmax operation, else false if this is a softmax operation. + SmoothSoftmax - Supplies true if a smooth factor is used in softmax operation. + ThreadPool - Supplies the thread pool object to use, else nullptr if the base library threading support should be used. @@ -982,6 +956,7 @@ Return Value: // WorkBlock.LogSoftmax = LogSoftmax; + WorkBlock.SmoothSoftmax = SmoothSoftmax; WorkBlock.Input = Input; WorkBlock.Output = Output; WorkBlock.N = N; diff --git a/onnxruntime/core/providers/cpu/math/softmax_shared.cc b/onnxruntime/core/providers/cpu/math/softmax_shared.cc index cae20b42725b8..2817dda9d0085 100644 --- a/onnxruntime/core/providers/cpu/math/softmax_shared.cc +++ b/onnxruntime/core/providers/cpu/math/softmax_shared.cc @@ -99,7 +99,7 @@ common::Status SoftmaxCPU(size_t N, float* Ydata, bool logarithmic, onnxruntime::concurrency::ThreadPool* thread_pool) { - MlasComputeSoftmax(Xdata, Ydata, N, D, logarithmic, thread_pool); + MlasComputeSoftmax(Xdata, Ydata, N, D, logarithmic, false, thread_pool); return Status::OK(); } diff --git a/onnxruntime/core/providers/cpu/ml/ml_common.h b/onnxruntime/core/providers/cpu/ml/ml_common.h index ed108eade05ab..2f4ebeabe043e 100644 --- a/onnxruntime/core/providers/cpu/ml/ml_common.h +++ b/onnxruntime/core/providers/cpu/ml/ml_common.h @@ -441,7 +441,7 @@ void batched_update_scores_inplace(gsl::span scores, int64_t num_batches_in, } if (use_mlas) { - MlasComputeSoftmax(s, s, num_batches, onnxruntime::narrow(batch_size), false, threadpool); + MlasComputeSoftmax(s, s, num_batches, onnxruntime::narrow(batch_size), false, false, threadpool); } else { while (s < s_end) { gsl::span scores_for_batch(s, s + batch_size); diff --git a/onnxruntime/test/mlas/bench/bench_computesoftmax.cpp b/onnxruntime/test/mlas/bench/bench_computesoftmax.cpp index 6181be873f73e..65822eb294d7d 100644 --- a/onnxruntime/test/mlas/bench/bench_computesoftmax.cpp +++ b/onnxruntime/test/mlas/bench/bench_computesoftmax.cpp @@ -58,10 +58,10 @@ void COMPUTESOFTMAXINPLACE(benchmark::State& state) { std::copy(data.begin(), data.end(), input); // Copy the data to the aligned memory // warming up run - MlasComputeSoftmax(input, output, N, D, false, tp.get()); + MlasComputeSoftmax(input, output, N, D, false, false, tp.get()); for (auto _ : state) { - MlasComputeSoftmax(input, output, N, D, false, tp.get()); + MlasComputeSoftmax(input, output, N, D, false, false, tp.get()); } free(ptr.underlying_buffer); diff --git a/onnxruntime/test/mlas/unittest/test_softmax.cpp b/onnxruntime/test/mlas/unittest/test_softmax.cpp index 4c5e11bbe9566..fb4ebbee77faf 100644 --- a/onnxruntime/test/mlas/unittest/test_softmax.cpp +++ b/onnxruntime/test/mlas/unittest/test_softmax.cpp @@ -23,13 +23,15 @@ class MlasSoftmaxTest : public MlasTestBase { Input[nd] = distribution(generator); } - Test(Input, Output, OutputReference, N, D, false); - Test(Input, Output, OutputReference, N, D, true); + Test(Input, Output, OutputReference, N, D, false, true); + Test(Input, Output, OutputReference, N, D, true, true); + Test(Input, Output, OutputReference, N, D, false, false); + Test(Input, Output, OutputReference, N, D, true, false); } - void Test(const float* Input, float* Output, float* OutputReference, size_t N, size_t D, bool LogSoftmax) { - MlasComputeSoftmax(Input, Output, N, D, LogSoftmax, threadpool_); - ReferenceSoftmax(Input, OutputReference, N, D, LogSoftmax); + void Test(const float* Input, float* Output, float* OutputReference, size_t N, size_t D, bool LogSoftmax, bool SmoothSoftmax) { + MlasComputeSoftmax(Input, Output, N, D, LogSoftmax, SmoothSoftmax, threadpool_); + ReferenceSoftmax(Input, OutputReference, N, D, LogSoftmax, SmoothSoftmax); constexpr float AbsoluteTolerance = 1e-6f; constexpr float RelativeTolerance = 1e-6f; @@ -42,7 +44,7 @@ class MlasSoftmaxTest : public MlasTestBase { } } - void ReferenceSoftmax(const float* Input, float* Output, size_t N, size_t D, bool LogSoftmax) { + void ReferenceSoftmax(const float* Input, float* Output, size_t N, size_t D, bool LogSoftmax, bool SmoothSoftmax) { for (size_t n = 0; n < N; n++) { float MaximumValue = std::numeric_limits::lowest(); @@ -50,6 +52,10 @@ class MlasSoftmaxTest : public MlasTestBase { MaximumValue = (std::max)(MaximumValue, Input[d]); } + if (SmoothSoftmax && MaximumValue < 0.0f) { + MaximumValue = 0.0f; + } + double Sum = 0.0; for (size_t d = 0; d < D; d++) { @@ -58,6 +64,10 @@ class MlasSoftmaxTest : public MlasTestBase { Output[d] = float(e); } + if (SmoothSoftmax) { + Sum += expf(-MaximumValue); + } + if (LogSoftmax) { float Scale = float(std::log(Sum)); diff --git a/onnxruntime/test/python/transformers/benchmark_gqa.py b/onnxruntime/test/python/transformers/benchmark_gqa.py index 5e028519b9f34..53d015a029083 100644 --- a/onnxruntime/test/python/transformers/benchmark_gqa.py +++ b/onnxruntime/test/python/transformers/benchmark_gqa.py @@ -37,6 +37,7 @@ def plot_prompt_performance( head_size: int, max_seq_len: int, local_window_size: Optional[int] = None, + use_smooth_softmax: bool = False, ): import triton @@ -55,6 +56,7 @@ def plot_prompt_performance( "kv_num_heads": kv_num_heads, "head_size": head_size, "local_window_size": local_window_size, + "use_smooth_softmax": use_smooth_softmax, }, ) ] @@ -68,6 +70,7 @@ def benchmark( kv_num_heads: int, head_size: int, local_window_size: Optional[int] = None, + use_smooth_softmax: bool = False, device="cuda", ): warmup = 15 @@ -82,6 +85,7 @@ def benchmark( kv_num_heads=kv_num_heads, head_size=head_size, local_window_size=local_window_size if provider in ["ort_gqa_local", "ort_gqa_local_packed"] else -1, + use_smooth_softmax=use_smooth_softmax, device=device, is_packed_qkv=provider in ["ort_gqa_packed", "ort_gqa_local_packed"], ) @@ -103,6 +107,7 @@ def plot_token_performance( head_size: int, max_seq_len: int, local_window_size: Optional[int] = None, + use_smooth_softmax: bool = False, ): import triton @@ -121,6 +126,7 @@ def plot_token_performance( "kv_num_heads": kv_num_heads, "head_size": head_size, "local_window_size": local_window_size, + "use_smooth_softmax": use_smooth_softmax, }, ) ] @@ -134,6 +140,7 @@ def benchmark( kv_num_heads: int, head_size: int, local_window_size: Optional[int] = None, + use_smooth_softmax: bool = False, device="cuda", ): warmup = 15 @@ -150,6 +157,7 @@ def benchmark( local_window_size=local_window_size if provider in ["ort_gqa_local", "ort_gqa_local_packed"] else -1, do_rotary=True, # Most models use rotary positional embeddings is_packed_qkv=provider in ["ort_gqa_packed", "ort_gqa_local_packed"], + use_smooth_softmax=use_smooth_softmax, device=device, ) @@ -186,26 +194,29 @@ def run_performance_test(sm: int): for num_heads, head_size, kv_num_heads, max_seq_len, local_window_size, model_name in configures: for batch_size in [1, 4]: - plot_prompt_performance( - sm=sm, - batch_size=batch_size, - num_heads=num_heads, - kv_num_heads=kv_num_heads, - head_size=head_size, - max_seq_len=min(threshold, max_seq_len), - local_window_size=local_window_size, - model_name=model_name, - ) - plot_token_performance( - sm=sm, - batch_size=batch_size, - num_heads=num_heads, - kv_num_heads=kv_num_heads, - head_size=head_size, - max_seq_len=min(threshold, max_seq_len), - local_window_size=local_window_size, - model_name=model_name, - ) + for use_smooth_softmax in [False, True]: + plot_prompt_performance( + sm=sm, + batch_size=batch_size, + num_heads=num_heads, + kv_num_heads=kv_num_heads, + head_size=head_size, + max_seq_len=min(threshold, max_seq_len), + local_window_size=local_window_size, + use_smooth_softmax=use_smooth_softmax, + model_name=model_name, + ) + plot_token_performance( + sm=sm, + batch_size=batch_size, + num_heads=num_heads, + kv_num_heads=kv_num_heads, + head_size=head_size, + max_seq_len=min(threshold, max_seq_len), + local_window_size=local_window_size, + use_smooth_softmax=use_smooth_softmax, + model_name=model_name, + ) if __name__ == "__main__": diff --git a/onnxruntime/test/python/transformers/benchmark_gqa_windows.py b/onnxruntime/test/python/transformers/benchmark_gqa_windows.py index b781ccf03f138..79cc8e41bf343 100644 --- a/onnxruntime/test/python/transformers/benchmark_gqa_windows.py +++ b/onnxruntime/test/python/transformers/benchmark_gqa_windows.py @@ -19,6 +19,7 @@ def save_results(results, filename): "Max Sequence Length", "Sequence Length", "Past Sequence Length", + "Smooth Softmax", "Model Name", ], ) @@ -36,6 +37,7 @@ def benchmark( sequence_length: int = 1, past_sequence_length: int = 0, local_window_size: Optional[int] = None, + use_smooth_softmax: bool = False, model_name: str = "Llama3-8B", ): warmup = 15 @@ -50,6 +52,7 @@ def benchmark( kv_num_heads=kv_num_heads, head_size=head_size, local_window_size=local_window_size if local_window_size else -1, + use_smooth_softmax=use_smooth_softmax, do_rotary=True, # Most models use rotary positional embeddings is_packed_qkv=model_name in ["Phi-3-mini-128k", "Phi-3-small-128k"], device="cuda", @@ -93,6 +96,8 @@ def run_performance_tests(args): # Reduce max sequence length when GPU memory is not enough. threshold = 131072 if memory_in_gb > 24 else 65536 if memory_in_gb > 12 else 32768 + smooth_softmax = args.use_smooth_softmax + all_metrics = [] for num_heads, head_size, kv_num_heads, max_seq_len, local_window_size, model_name in configures: prompt_metrics_model = [] @@ -131,6 +136,7 @@ def run_performance_tests(args): sequence_length=sequence_length, max_seq_len=min(threshold, max_seq_len), local_window_size=local_window_size, + use_smooth_softmax=smooth_softmax, model_name=model_name, ) metrics = [*metrics, batch_size, max_seq_len, sequence_length, 0, model_name] @@ -169,9 +175,10 @@ def run_performance_tests(args): past_sequence_length=past_sequence_length, max_seq_len=min(threshold, max_seq_len), local_window_size=local_window_size, + use_smooth_softmax=smooth_softmax, model_name=model_name, ) - metrics = [*metrics, batch_size, max_seq_len, 1, past_sequence_length, model_name] + metrics = [*metrics, batch_size, max_seq_len, 1, past_sequence_length, smooth_softmax, model_name] token_metrics_model.append(metrics) all_metrics.append(metrics) # Calculate average inference interval and throughput for each model @@ -209,6 +216,15 @@ def run_performance_tests(args): default="flash_attention", help="GQA Kernel to use for benchmarking. Options: flash_attention, memory_efficient", ) + + parser.add_argument( + "--use_smooth_softmax", + required=False, + action="store_true", + help="test smooth softmax", + ) + parser.set_defaults(use_smooth_softmax=False) + args = parser.parse_args() if args.kernel == "memory_efficient": diff --git a/onnxruntime/test/python/transformers/test_flash_attn_cuda.py b/onnxruntime/test/python/transformers/test_flash_attn_cuda.py index 84bf30b65a742..17b9276a882eb 100644 --- a/onnxruntime/test/python/transformers/test_flash_attn_cuda.py +++ b/onnxruntime/test/python/transformers/test_flash_attn_cuda.py @@ -22,6 +22,7 @@ from onnx import TensorProto, helper from packaging import version from parameterized import parameterized +from test_gqa_cpu import smooth_softmax_ref from onnxruntime import InferenceSession, OrtValue, SessionOptions @@ -222,6 +223,7 @@ def create_group_query_attention_graph_prompt( rotary=False, rotary_interleaved=False, packed=False, + use_smooth_softmax=False, ): past_kv_seqlen = config.buffer_sequence_length if share_buffer else 0 present_kv_seqlen = config.buffer_sequence_length if share_buffer else config.kv_sequence_length @@ -246,6 +248,7 @@ def create_group_query_attention_graph_prompt( local_window_size=local_window_size, do_rotary=rotary, rotary_interleaved=rotary_interleaved, + smooth_softmax=1 if use_smooth_softmax else 0, # is_past_bsnh=1 if past_kv_format == Formats.BSNH else 0, # kv_share_buffer=1 if share_buffer else 0, domain="com.microsoft", @@ -408,6 +411,7 @@ def create_group_query_attention_graph_past( rotary=False, rotary_interleaved=False, packed=False, + use_smooth_softmax=False, ): past_kv_seqlen = config.kv_sequence_length present_kv_seqlen = ( @@ -434,6 +438,7 @@ def create_group_query_attention_graph_past( local_window_size=local_window_size, do_rotary=rotary, rotary_interleaved=rotary_interleaved, + smooth_softmax=1 if use_smooth_softmax else 0, # is_past_bsnh=1 if past_kv_format == Formats.BSNH else 0, # kv_share_buffer=1 if share_buffer else 0, domain="com.microsoft", @@ -783,6 +788,7 @@ def gqa_prompt_func( past_kv_format=Formats.BSNH, share_buffer=True, rotary_interleaved=False, + use_smooth_softmax=False, ): onnx_model_str = create_group_query_attention_graph_prompt( config, @@ -792,6 +798,7 @@ def gqa_prompt_func( rotary=cos is not None, rotary_interleaved=rotary_interleaved, packed=new_k is None, + use_smooth_softmax=use_smooth_softmax, ) q = torch.reshape(q, (config.batch_size, config.q_sequence_length, -1)) past_k = k.clone() if share_buffer else None @@ -888,6 +895,7 @@ def gqa_past_func( share_buffer=True, window_size=-1, rotary_interleaved=False, + use_smooth_softmax=False, ): onnx_model_str = create_group_query_attention_graph_past( config, @@ -897,6 +905,7 @@ def gqa_past_func( rotary=cos is not None, rotary_interleaved=rotary_interleaved, packed=new_k is None, + use_smooth_softmax=use_smooth_softmax, ) q = torch.reshape(q, (config.batch_size, config.sequence_length, -1)) past_k = k.clone() @@ -1033,6 +1042,7 @@ def attention_ref( window_size=(-1, -1), # -1 means infinite window size upcast=True, reorder_ops=False, + use_smooth_softmax=False, ): """ Arguments: @@ -1079,7 +1089,12 @@ def attention_ref( q.device, ) scores.masked_fill_(local_mask, float("-inf")) - attention = torch.softmax(scores, dim=-1) + + if use_smooth_softmax: + attention = smooth_softmax_ref(scores) + else: + attention = torch.softmax(scores, dim=-1) + # Some rows might be completely masked out so we fill them with zero instead of NaN if window_size[0] >= 0 or window_size[1] >= 0: attention = attention.masked_fill(torch.all(local_mask, dim=-1, keepdim=True), 0.0) @@ -1099,7 +1114,14 @@ def attention_ref( def attention_qkvpacked_ref( - qkv, key_padding_mask=None, dropout_p=0.0, dropout_mask=None, causal=False, upcast=True, reorder_ops=False + qkv, + key_padding_mask=None, + dropout_p=0.0, + dropout_mask=None, + causal=False, + upcast=True, + reorder_ops=False, + use_smooth_softmax=False, ): return attention_ref( qkv[:, :, 0], @@ -1112,6 +1134,7 @@ def attention_qkvpacked_ref( upcast=upcast, causal=causal, reorder_ops=reorder_ops, + use_smooth_softmax=use_smooth_softmax, ) @@ -1192,6 +1215,7 @@ def parity_check_gqa_prompt( rotary=False, rotary_interleaved=False, packed=False, + use_smooth_softmax=False, rtol=1e-3, atol=1e-3, ): @@ -1306,7 +1330,16 @@ def parity_check_gqa_prompt( v_cache_rep = repeat(v_cache_ref, "b s h d -> b s (h g) d", g=config.num_heads // config.kv_num_heads) key_padding_mask = arange < cache_seqlens_expanded out_ref, _ = attention_ref( - q_ro, k_cache_rep, v_cache_rep, None, key_padding_mask, 0.0, None, causal=True, window_size=window_size + q_ro, + k_cache_rep, + v_cache_rep, + None, + key_padding_mask, + 0.0, + None, + causal=True, + window_size=window_size, + use_smooth_softmax=use_smooth_softmax, ) out_ref = out_ref.detach().cpu().numpy() if past_format == Formats.BNSH: @@ -1330,6 +1363,7 @@ def parity_check_gqa_prompt( past_format, True, rotary_interleaved, + use_smooth_softmax=use_smooth_softmax, ) else: out, present_k, present_v = gqa_prompt_func( @@ -1346,6 +1380,7 @@ def parity_check_gqa_prompt( past_format, True, rotary_interleaved, + use_smooth_softmax=use_smooth_softmax, ) out = torch.squeeze(out, 0) out = torch.reshape(out, (config.batch_size, config.q_sequence_length, config.num_heads, config.head_size)) @@ -1374,6 +1409,7 @@ def parity_check_gqa_prompt_no_buff( rotary=False, rotary_interleaved=False, packed=False, + use_smooth_softmax=False, rtol=1e-3, atol=1e-3, ): @@ -1465,7 +1501,16 @@ def parity_check_gqa_prompt_no_buff( k_cache_rep = repeat(k_cache_ref, "b s h d -> b s (h g) d", g=config.num_heads // config.kv_num_heads) v_cache_rep = repeat(v_cache_ref, "b s h d -> b s (h g) d", g=config.num_heads // config.kv_num_heads) out_ref, _ = attention_ref( - q_ro, k_cache_rep, v_cache_rep, None, new_mask, 0.0, None, causal=True, window_size=window_size + q_ro, + k_cache_rep, + v_cache_rep, + None, + new_mask, + 0.0, + None, + causal=True, + window_size=window_size, + use_smooth_softmax=use_smooth_softmax, ) out_ref = out_ref.detach().cpu().numpy() if past_format == Formats.BNSH: @@ -1489,6 +1534,7 @@ def parity_check_gqa_prompt_no_buff( past_format, False, rotary_interleaved, + use_smooth_softmax=use_smooth_softmax, ) else: out, present_k, present_v = gqa_prompt_func( @@ -1505,6 +1551,7 @@ def parity_check_gqa_prompt_no_buff( past_format, False, rotary_interleaved, + use_smooth_softmax=use_smooth_softmax, ) out = torch.squeeze(out, 0) out = torch.reshape(out, (config.batch_size, config.q_sequence_length, config.num_heads, config.head_size)) @@ -1512,7 +1559,8 @@ def parity_check_gqa_prompt_no_buff( err_msg = ( f" with {config}, causal={causal}, local={local}, past_format={past_format}," - f" rotary={rotary}, rotary_interleaved={rotary_interleaved}, packed={packed}" + f" rotary={rotary}, rotary_interleaved={rotary_interleaved}, packed={packed}," + f" use_smooth_softmax={use_smooth_softmax}" ) # Make sure past-present buffer updating correctly numpy.testing.assert_allclose( @@ -1533,6 +1581,7 @@ def parity_check_gqa_past( rotary=False, rotary_interleaved=False, packed=False, + use_smooth_softmax=False, rtol=1e-3, atol=1e-3, ): @@ -1643,7 +1692,16 @@ def parity_check_gqa_past( v_cache_rep = repeat(v_cache_ref, "b s h d -> b s (h g) d", g=config.num_heads // config.kv_num_heads) key_padding_mask = arange < cache_seqlens_expanded + config.sequence_length out_ref, _ = attention_ref( - q_ro, k_cache_rep, v_cache_rep, None, key_padding_mask, 0.0, None, causal=True, window_size=window_size + q_ro, + k_cache_rep, + v_cache_rep, + None, + key_padding_mask, + 0.0, + None, + causal=True, + window_size=window_size, + use_smooth_softmax=use_smooth_softmax, ) out_ref = out_ref.detach().cpu().numpy() if past_format == Formats.BNSH: @@ -1667,6 +1725,7 @@ def parity_check_gqa_past( True, left_window_size, rotary_interleaved, + use_smooth_softmax=use_smooth_softmax, ) else: out, present_k, present_v = gqa_past_func( @@ -1683,6 +1742,7 @@ def parity_check_gqa_past( True, left_window_size, rotary_interleaved, + use_smooth_softmax=use_smooth_softmax, ) out = torch.squeeze(out, 0) out = torch.reshape(out, (config.batch_size, config.sequence_length, config.num_heads, config.head_size)) @@ -1711,6 +1771,7 @@ def parity_check_gqa_past_no_buff( rotary=False, rotary_interleaved=False, packed=False, + use_smooth_softmax=False, rtol=1e-3, atol=1e-3, ): @@ -1827,7 +1888,16 @@ def parity_check_gqa_past_no_buff( v_cache_rep = repeat(v_cache_ref, "b s h d -> b s (h g) d", g=config.num_heads // config.kv_num_heads) key_padding_mask = arange < cache_seqlens_expanded + config.sequence_length out_ref, _ = attention_ref( - q_ro, k_cache_rep, v_cache_rep, None, key_padding_mask, 0.0, None, causal=True, window_size=window_size + q_ro, + k_cache_rep, + v_cache_rep, + None, + key_padding_mask, + 0.0, + None, + causal=True, + window_size=window_size, + use_smooth_softmax=use_smooth_softmax, ) out_ref = out_ref.detach().cpu().numpy() if past_format == Formats.BNSH: @@ -1851,6 +1921,7 @@ def parity_check_gqa_past_no_buff( False, window_size=left_window_size, rotary_interleaved=rotary_interleaved, + use_smooth_softmax=use_smooth_softmax, ) else: out, present_k, present_v = gqa_past_func( @@ -1867,6 +1938,7 @@ def parity_check_gqa_past_no_buff( False, window_size=left_window_size, rotary_interleaved=rotary_interleaved, + use_smooth_softmax=use_smooth_softmax, ) out = torch.squeeze(out, 0) out = torch.reshape(out, (config.batch_size, config.sequence_length, config.num_heads, config.head_size)) @@ -2137,6 +2209,7 @@ def test_gqa_no_past_memory_efficient(self, _, config, rotary, rotary_interleave rotary=rotary, rotary_interleaved=rotary_interleaved, packed=packed, + use_smooth_softmax=False, ) parity_check_gqa_prompt_no_buff( config, @@ -2146,6 +2219,7 @@ def test_gqa_no_past_memory_efficient(self, _, config, rotary, rotary_interleave rotary=rotary, rotary_interleaved=rotary_interleaved, packed=packed, + use_smooth_softmax=False, ) @parameterized.expand(gqa_no_past_flash_attention_test_cases()) @@ -2162,6 +2236,7 @@ def test_gqa_no_past_flash_attention(self, _, config, local, rotary, rotary_inte rotary=rotary, rotary_interleaved=rotary_interleaved, packed=packed, + use_smooth_softmax=True, ) parity_check_gqa_prompt_no_buff( config, @@ -2170,6 +2245,7 @@ def test_gqa_no_past_flash_attention(self, _, config, local, rotary, rotary_inte rotary=rotary, rotary_interleaved=rotary_interleaved, packed=packed, + use_smooth_softmax=False, ) @parameterized.expand(gqa_past_memory_efficient_test_cases()) @@ -2187,6 +2263,7 @@ def test_gqa_past_memory_efficient(self, _, config, rotary, rotary_interleaved, rotary=rotary, rotary_interleaved=rotary_interleaved, packed=packed, + use_smooth_softmax=False, ) parity_check_gqa_past_no_buff( config, @@ -2196,6 +2273,7 @@ def test_gqa_past_memory_efficient(self, _, config, rotary, rotary_interleaved, rotary=rotary, rotary_interleaved=rotary_interleaved, packed=packed, + use_smooth_softmax=False, ) @parameterized.expand(gqa_past_flash_attention_test_cases()) @@ -2214,6 +2292,7 @@ def test_gqa_past_flash_attention(self, _, config, local, rotary, rotary_interle rotary=rotary, rotary_interleaved=rotary_interleaved, packed=packed, + use_smooth_softmax=False, ) parity_check_gqa_past_no_buff( config, @@ -2224,6 +2303,7 @@ def test_gqa_past_flash_attention(self, _, config, local, rotary, rotary_interle rotary=rotary, rotary_interleaved=rotary_interleaved, packed=packed, + use_smooth_softmax=True, ) diff --git a/onnxruntime/test/python/transformers/test_gqa_cpu.py b/onnxruntime/test/python/transformers/test_gqa_cpu.py index b6b8aee15852f..eeba0baccf15b 100644 --- a/onnxruntime/test/python/transformers/test_gqa_cpu.py +++ b/onnxruntime/test/python/transformers/test_gqa_cpu.py @@ -145,6 +145,7 @@ def create_group_query_attention_graph_prompt( rotary=False, rotary_interleaved=False, packed=False, + use_smooth_softmax=False, ): past_kv_seqlen = config.buffer_sequence_length if share_buffer else 0 present_kv_seqlen = config.buffer_sequence_length if share_buffer else config.kv_sequence_length @@ -169,6 +170,7 @@ def create_group_query_attention_graph_prompt( local_window_size=local_window_size, do_rotary=rotary, rotary_interleaved=rotary_interleaved, + smooth_softmax=1 if use_smooth_softmax else 0, # is_past_bsnh=1 if past_kv_format == Formats.BSNH else 0, # kv_share_buffer=1 if share_buffer else 0, domain="com.microsoft", @@ -331,6 +333,7 @@ def create_group_query_attention_graph_past( rotary=False, rotary_interleaved=False, packed=False, + use_smooth_softmax=False, ): past_kv_seqlen = config.kv_sequence_length present_kv_seqlen = ( @@ -357,6 +360,7 @@ def create_group_query_attention_graph_past( local_window_size=local_window_size, do_rotary=rotary, rotary_interleaved=rotary_interleaved, + smooth_softmax=1 if use_smooth_softmax else 0, # is_past_bsnh=1 if past_kv_format == Formats.BSNH else 0, # kv_share_buffer=1 if share_buffer else 0, domain="com.microsoft", @@ -667,6 +671,7 @@ def gqa_prompt_func( past_kv_format=Formats.BSNH, share_buffer=True, rotary_interleaved=False, + use_smooth_softmax=False, ): onnx_model_str = create_group_query_attention_graph_prompt( config, @@ -676,6 +681,7 @@ def gqa_prompt_func( rotary=cos is not None, rotary_interleaved=rotary_interleaved, packed=new_k is None, + use_smooth_softmax=use_smooth_softmax, ) q = torch.reshape(q, (config.batch_size, config.q_sequence_length, -1)) past_k = k.clone() if share_buffer else None @@ -773,6 +779,7 @@ def gqa_past_func( share_buffer=True, window_size=-1, rotary_interleaved=False, + use_smooth_softmax=False, ): onnx_model_str = create_group_query_attention_graph_past( config, @@ -782,6 +789,7 @@ def gqa_past_func( rotary=cos is not None, rotary_interleaved=rotary_interleaved, packed=new_k is None, + use_smooth_softmax=use_smooth_softmax, ) q = torch.reshape(q, (config.batch_size, config.sequence_length, -1)) past_k = k.clone() @@ -906,6 +914,13 @@ def construct_local_mask( ) +def smooth_softmax_ref(x): + x_max = x.amax(axis=-1, keepdim=True) + x_max = torch.maximum(x_max, torch.zeros_like(x_max)) + w = torch.exp(x - x_max) + return w * torch.reciprocal(w.sum(axis=-1, keepdim=True) + torch.exp(-x_max)) + + def attention_ref( q, k, @@ -918,6 +933,7 @@ def attention_ref( window_size=(-1, -1), # -1 means infinite window size upcast=True, reorder_ops=False, + use_smooth_softmax=False, ): """ Arguments: @@ -935,6 +951,7 @@ def attention_ref( reorder_ops: whether to change the order of operations (scaling k instead of scaling k, etc.) without changing the math. This is to estimate the numerical error from operation reordering. + use_smooth_softmax: whether use smooth softmax or not Output: output: (batch_size, seqlen_q, nheads, head_dim) attention: (batch_size, nheads, seqlen_q, seqlen_k), softmax after dropout @@ -964,10 +981,16 @@ def attention_ref( q.device, ) scores.masked_fill_(local_mask, float("-inf")) - attention = torch.softmax(scores, dim=-1) + + if use_smooth_softmax: + attention = smooth_softmax_ref(scores) + else: + attention = torch.softmax(scores, dim=-1) + # Some rows might be completely masked out so we fill them with zero instead of NaN if window_size[0] >= 0 or window_size[1] >= 0: attention = attention.masked_fill(torch.all(local_mask, dim=-1, keepdim=True), 0.0) + # We want to mask here so that the attention matrix doesn't have any NaNs # Otherwise we'll get NaN in dV if query_padding_mask is not None: @@ -984,7 +1007,14 @@ def attention_ref( def attention_qkvpacked_ref( - qkv, key_padding_mask=None, dropout_p=0.0, dropout_mask=None, causal=False, upcast=True, reorder_ops=False + qkv, + key_padding_mask=None, + dropout_p=0.0, + dropout_mask=None, + causal=False, + upcast=True, + reorder_ops=False, + use_smooth_softmax=False, ): return attention_ref( qkv[:, :, 0], @@ -997,6 +1027,7 @@ def attention_qkvpacked_ref( upcast=upcast, causal=causal, reorder_ops=reorder_ops, + use_smooth_softmax=use_smooth_softmax, ) @@ -1008,6 +1039,7 @@ def parity_check_gqa_prompt( rotary=False, rotary_interleaved=False, packed=False, + use_smooth_softmax=False, rtol=1e-3, atol=1e-3, ): @@ -1108,7 +1140,16 @@ def parity_check_gqa_prompt( v_cache_rep = repeat(v_cache_ref, "b s h d -> b s (h g) d", g=config.num_heads // config.kv_num_heads) key_padding_mask = arange < cache_seqlens_expanded out_ref, _ = attention_ref( - q_ro, k_cache_rep, v_cache_rep, None, key_padding_mask, 0.0, None, causal=True, window_size=window_size + q_ro, + k_cache_rep, + v_cache_rep, + None, + key_padding_mask, + 0.0, + None, + causal=True, + window_size=window_size, + use_smooth_softmax=use_smooth_softmax, ) out_ref = out_ref.detach().cpu().numpy() if past_format == Formats.BNSH: @@ -1132,6 +1173,7 @@ def parity_check_gqa_prompt( past_format, True, rotary_interleaved, + use_smooth_softmax=use_smooth_softmax, ) else: out, present_k, present_v = gqa_prompt_func( @@ -1148,6 +1190,7 @@ def parity_check_gqa_prompt( past_format, True, rotary_interleaved, + use_smooth_softmax=use_smooth_softmax, ) out = torch.squeeze(out, 0) out = torch.reshape(out, (config.batch_size, config.q_sequence_length, config.num_heads, config.head_size)) @@ -1172,6 +1215,8 @@ def parity_check_gqa_prompt( rotary, " rotary_interleaved:", rotary_interleaved, + " smooth_softmax:", + use_smooth_softmax, "past kv format:", "BSNH" if past_format == Formats.BSNH else "BNSH", " B:", @@ -1201,6 +1246,7 @@ def parity_check_gqa_prompt_no_buff( rotary=False, rotary_interleaved=False, packed=False, + use_smooth_softmax=False, rtol=1e-3, atol=1e-3, ): @@ -1275,7 +1321,16 @@ def parity_check_gqa_prompt_no_buff( k_cache_rep = repeat(k_cache_ref, "b s h d -> b s (h g) d", g=config.num_heads // config.kv_num_heads) v_cache_rep = repeat(v_cache_ref, "b s h d -> b s (h g) d", g=config.num_heads // config.kv_num_heads) out_ref, _ = attention_ref( - q_ro, k_cache_rep, v_cache_rep, None, new_mask, 0.0, None, causal=True, window_size=window_size + q_ro, + k_cache_rep, + v_cache_rep, + None, + new_mask, + 0.0, + None, + causal=True, + window_size=window_size, + use_smooth_softmax=use_smooth_softmax, ) out_ref = out_ref.detach().cpu().numpy() if past_format == Formats.BNSH: @@ -1299,6 +1354,7 @@ def parity_check_gqa_prompt_no_buff( past_format, False, rotary_interleaved, + use_smooth_softmax=use_smooth_softmax, ) else: out, present_k, present_v = gqa_prompt_func( @@ -1315,6 +1371,7 @@ def parity_check_gqa_prompt_no_buff( past_format, False, rotary_interleaved, + use_smooth_softmax=use_smooth_softmax, ) out = torch.squeeze(out, 0) out = torch.reshape(out, (config.batch_size, config.q_sequence_length, config.num_heads, config.head_size)) @@ -1339,6 +1396,8 @@ def parity_check_gqa_prompt_no_buff( rotary, " rotary_interleaved:", rotary_interleaved, + " smooth_softmax:", + use_smooth_softmax, "past kv format:", "BSNH" if past_format == Formats.BSNH else "BNSH", " B:", @@ -1368,6 +1427,7 @@ def parity_check_gqa_past( rotary=False, rotary_interleaved=False, packed=False, + use_smooth_softmax=False, rtol=1e-3, atol=1e-3, ): @@ -1473,7 +1533,16 @@ def parity_check_gqa_past( v_cache_rep = repeat(v_cache_ref, "b s h d -> b s (h g) d", g=config.num_heads // config.kv_num_heads) key_padding_mask = arange < cache_seqlens_expanded + config.sequence_length out_ref, _ = attention_ref( - q_ro, k_cache_rep, v_cache_rep, None, key_padding_mask, 0.0, None, causal=True, window_size=window_size + q_ro, + k_cache_rep, + v_cache_rep, + None, + key_padding_mask, + 0.0, + None, + causal=True, + window_size=window_size, + use_smooth_softmax=use_smooth_softmax, ) out_ref = out_ref.detach().cpu().numpy() if past_format == Formats.BNSH: @@ -1497,6 +1566,7 @@ def parity_check_gqa_past( True, left_window_size, rotary_interleaved, + use_smooth_softmax=use_smooth_softmax, ) else: out, present_k, present_v = gqa_past_func( @@ -1513,6 +1583,7 @@ def parity_check_gqa_past( True, left_window_size, rotary_interleaved, + use_smooth_softmax=use_smooth_softmax, ) out = torch.squeeze(out, 0) out = torch.reshape(out, (config.batch_size, config.sequence_length, config.num_heads, config.head_size)) @@ -1539,6 +1610,8 @@ def parity_check_gqa_past( rotary, " rotary_interleaved:", rotary_interleaved, + " smooth_softmax:", + use_smooth_softmax, " B:", config.batch_size, " S:", @@ -1566,6 +1639,7 @@ def parity_check_gqa_past_no_buff( rotary=False, rotary_interleaved=False, packed=False, + use_smooth_softmax=False, rtol=1e-3, atol=1e-3, ): @@ -1677,7 +1751,16 @@ def parity_check_gqa_past_no_buff( v_cache_rep = repeat(v_cache_ref, "b s h d -> b s (h g) d", g=config.num_heads // config.kv_num_heads) key_padding_mask = arange < cache_seqlens_expanded + config.sequence_length out_ref, _ = attention_ref( - q_ro, k_cache_rep, v_cache_rep, None, key_padding_mask, 0.0, None, causal=True, window_size=window_size + q_ro, + k_cache_rep, + v_cache_rep, + None, + key_padding_mask, + 0.0, + None, + causal=True, + window_size=window_size, + use_smooth_softmax=use_smooth_softmax, ) out_ref = out_ref.detach().cpu().numpy() if past_format == Formats.BNSH: @@ -1701,6 +1784,7 @@ def parity_check_gqa_past_no_buff( False, window_size=left_window_size, rotary_interleaved=rotary_interleaved, + use_smooth_softmax=use_smooth_softmax, ) else: out, present_k, present_v = gqa_past_func( @@ -1717,6 +1801,7 @@ def parity_check_gqa_past_no_buff( False, window_size=left_window_size, rotary_interleaved=rotary_interleaved, + use_smooth_softmax=use_smooth_softmax, ) out = torch.squeeze(out, 0) out = torch.reshape(out, (config.batch_size, config.sequence_length, config.num_heads, config.head_size)) @@ -1737,6 +1822,8 @@ def parity_check_gqa_past_no_buff( rotary, " rotary_interleaved:", rotary_interleaved, + " smooth_softmax:", + use_smooth_softmax, "past kv format:", "BSNH" if past_format == Formats.BSNH else "BNSH", " B:", @@ -1787,26 +1874,29 @@ def test_gqa_no_past(self): for local in [False, True]: for rotary, rotary_interleaved in [(False, False), (True, False), (True, True)]: for packed in [False, True]: - config = PromptConfig(b, sq, skv, sq + skv + 8, n, n2, h) - past_kv_format = Formats.BNSH - all_close = parity_check_gqa_prompt( - config, - local=local, - past_format=past_kv_format, - rotary=rotary, - rotary_interleaved=rotary_interleaved, - packed=packed, - ) - self.assertTrue(all_close) - all_close = parity_check_gqa_prompt_no_buff( - config, - local=local, - past_format=past_kv_format, - rotary=rotary, - rotary_interleaved=rotary_interleaved, - packed=packed, - ) - self.assertTrue(all_close) + for use_smooth_softmax in [False, True]: + config = PromptConfig(b, sq, skv, sq + skv + 8, n, n2, h) + past_kv_format = Formats.BNSH + all_close = parity_check_gqa_prompt( + config, + local=local, + past_format=past_kv_format, + rotary=rotary, + rotary_interleaved=rotary_interleaved, + packed=packed, + use_smooth_softmax=use_smooth_softmax, + ) + self.assertTrue(all_close) + all_close = parity_check_gqa_prompt_no_buff( + config, + local=local, + past_format=past_kv_format, + rotary=rotary, + rotary_interleaved=rotary_interleaved, + packed=packed, + use_smooth_softmax=use_smooth_softmax, + ) + self.assertTrue(all_close) def test_gqa_past(self): print("-------- TEST GQA PAST (TOKEN GEN) ---------") @@ -1838,31 +1928,34 @@ def test_gqa_past(self): for local in [False, True]: for rotary, rotary_interleaved in [(False, False), (True, False), (True, True)]: for packed in [False, True]: - sp = random.randint(1, s2 - s) if s2 - s > 0 else 0 - config = Config(b, s, s2, sp, n, n2, h) - past_kv_format = Formats.BNSH - all_close = parity_check_gqa_past( - config, - local=local, - past_format=past_kv_format, - rtol=1e-3, - atol=1e-3, - rotary=rotary, - rotary_interleaved=rotary_interleaved, - packed=packed, - ) - self.assertTrue(all_close) - all_close = parity_check_gqa_past_no_buff( - config, - local=local, - past_format=past_kv_format, - rtol=1e-3, - atol=1e-3, - rotary=rotary, - rotary_interleaved=rotary_interleaved, - packed=packed, - ) - self.assertTrue(all_close) + for use_smooth_softmax in [False, True]: + sp = random.randint(1, s2 - s) if s2 - s > 0 else 0 + config = Config(b, s, s2, sp, n, n2, h) + past_kv_format = Formats.BNSH + all_close = parity_check_gqa_past( + config, + local=local, + past_format=past_kv_format, + rtol=1e-3, + atol=1e-3, + rotary=rotary, + rotary_interleaved=rotary_interleaved, + packed=packed, + use_smooth_softmax=use_smooth_softmax, + ) + self.assertTrue(all_close) + all_close = parity_check_gqa_past_no_buff( + config, + local=local, + past_format=past_kv_format, + rtol=1e-3, + atol=1e-3, + rotary=rotary, + rotary_interleaved=rotary_interleaved, + packed=packed, + use_smooth_softmax=use_smooth_softmax, + ) + self.assertTrue(all_close) if __name__ == "__main__": diff --git a/onnxruntime/test/python/transformers/test_sparse_attention.py b/onnxruntime/test/python/transformers/test_sparse_attention.py index 688e6250fecbd..6a08d2101b100 100644 --- a/onnxruntime/test/python/transformers/test_sparse_attention.py +++ b/onnxruntime/test/python/transformers/test_sparse_attention.py @@ -13,6 +13,7 @@ import torch from benchmark_mha import InputFormats from onnx import TensorProto, helper +from test_gqa_cpu import smooth_softmax_ref from torch import Tensor from onnxruntime import InferenceSession, SessionOptions, get_available_providers @@ -42,6 +43,7 @@ def __init__( is_packed_qkv: bool = False, max_cache_sequence_length=None, max_rotary_sequence_length=None, + use_smooth_softmax: bool = False, ): self.operator = operator self.batch_size = batch_size @@ -72,6 +74,8 @@ def __init__( self.share_buffer = share_buffer self.is_packed_qkv = is_packed_qkv + self.use_smooth_softmax = use_smooth_softmax + def shape_dict(self): shapes = { "query": ( @@ -165,6 +169,7 @@ def __init__( is_packed_qkv=False, max_cache_sequence_length=None, max_rotary_sequence_length=None, + use_smooth_softmax: bool = False, ): super().__init__( "GroupQueryAttention", @@ -184,6 +189,7 @@ def __init__( is_packed_qkv=is_packed_qkv, max_cache_sequence_length=max_cache_sequence_length, max_rotary_sequence_length=max_rotary_sequence_length, + use_smooth_softmax=use_smooth_softmax, ) # local_window_size is for ORT only, not for Torch implementation. self.local_window_size = local_window_size @@ -528,6 +534,7 @@ def create_group_query_attention_onnx_model(config: GroupQueryAttentionConfig): local_window_size=config.local_window_size, do_rotary=1 if config.do_rotary else 0, rotary_interleaved=config.rotary_interleaved, + smooth_softmax=1 if config.use_smooth_softmax else 0, domain="com.microsoft", ), ] @@ -611,7 +618,12 @@ def group_query_attention_reference( attn = torch.einsum("bhmd,bhnd->bhmn", query, key).float() * scale if mask is not None: attn = attn.masked_fill((1 - mask).bool(), float("-inf")) - attn = attn.softmax(-1) + + if config.use_smooth_softmax: + attn = smooth_softmax_ref(attn) + else: + attn = attn.softmax(-1) + attn_output = torch.einsum("bhmn,bhnd->bhmd", attn.type_as(value), value) result = attn_output.transpose(1, 2).contiguous()