Support Smooth Softmax in GroupQueryAttention (#21867)
### Description

Softmax (formula 1) is defined as follows:
```math
y_{i} = \frac{exp(x_{i})}{\sum_{j} exp(x_{j})}
```
After applying softmax, each element lies in the range $(0, 1)$ and the elements sum to 1, so they can be interpreted as probabilities.

However, in language models, softmax has two issues:
* When all elements are -inf (for example, when a whole row is masked because the query token is padding), the result is undefined: every exp(-inf) is 0, so the formula above divides by zero.
* Why should every query token be normalized to equal importance (each row summing to 1)?

**Smooth Softmax** (formula 2) is a modified version that introduces a smooth factor as follows:
```math
s_{i} = \frac{exp(x_{i})}{1 + \sum_{j} exp(x_{j})}
```

This formula addresses both issues:
* It handles the special case where all elements are -inf: in that case, $s_{i}$ is 0 for every element.
* The sum of all elements, $\sum_{i}{s_{i}} = \frac{\sum_{i}{exp(x_{i})}}{1 + \sum_{i}{exp(x_{i})}}$, lies in the range $(0, 1)$, so the model can be trained to assign different importance to different query tokens (a quick numeric illustration follows this list).
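As a quick illustration of both properties, take two logits $x = (0, -1)$. Formula 2 gives
```math
s_{1} = \frac{e^{0}}{1 + e^{0} + e^{-1}} \approx 0.42, \qquad
s_{2} = \frac{e^{-1}}{1 + e^{0} + e^{-1}} \approx 0.16, \qquad
s_{1} + s_{2} \approx 0.58 < 1
```
whereas standard softmax would force $y_{1} + y_{2} = 1$. If both logits are $-\infty$, the numerators are 0 while the denominator stays 1, so $s_{1} = s_{2} = 0$ instead of NaN.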

Since the exponential is prone to overflow or underflow, formula 3 can be used to obtain a stable result:
```math
s_{i} = \frac{exp(x_{i} + c)}{exp(c) + \sum_{j} exp(x_{j} + c)}
```
In theory, c can be any value. In practice, the constant c should be chosen so that $exp(c)$ and $exp(x_{i} + c)$ do not overflow (or underflow) at the same time. A reasonable choice is formula 4:
```math
c=-\max_{i} \{ x_i \}
```
or to apply the constraint $c \le 0$ as in formula 5:

```math
c=-\max(0, \max_{i} \{ x_i \})
```
The latter (formula 5) ensures that $s_{i}$ falls back to formula 2 when all elements are negative.
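To make the arithmetic concrete, here is a minimal scalar sketch of formula 5 (an illustration only, not the MLAS kernel; the function name and use of `std::vector` are assumptions for the example):

```cpp
#include <algorithm>
#include <cmath>
#include <vector>

// Smooth softmax per formula 5: c = -max(0, max_i x_i), and the denominator
// gains an extra exp(c) term. When every x_i is -inf, all numerators are 0
// while the denominator is exp(0) = 1, so every output is 0 instead of NaN.
std::vector<float> SmoothSoftmax(const std::vector<float>& x) {
  float max_val = 0.0f;  // start at 0 so that c = -max_val <= 0
  for (float v : x) max_val = std::max(max_val, v);

  std::vector<float> y(x.size());
  float sum = std::exp(-max_val);  // the extra exp(c) term
  for (size_t i = 0; i < x.size(); ++i) {
    y[i] = std::exp(x[i] - max_val);
    sum += y[i];
  }
  for (float& v : y) v /= sum;
  return y;
}
```

The real CPU path dispatches to `MlasComputeSoftmax` with the new `SmoothSoftmax` flag, as shown in the `attention_helper.h` and `mlas.h` changes below.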

For the CPU execution provider, smooth softmax is implemented in MLAS; the CPU implementation uses formula 5.

@wangyems implemented smooth softmax in flash attention for CUDA, which requires an Ampere or newer GPU. The flash attention implementation of smooth softmax uses formula 4.
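In the flash attention epilogue, formula 4 shows up as one extra term added to the per-row sum before normalization, mirroring the change to `normalize_softmax_lse` in `softmax.h` below. A simplified host-side sketch (the function name and scaling convention are assumptions; `row_max` and `row_sum` come from the kernel's online softmax):

```cpp
#include <cmath>

// Returns the factor used to rescale this row of the output accumulator.
// With smooth softmax, exp(-row_max * softmax_scale) is added to the row sum,
// which is the exp(c) term of formula 4 applied to the scaled scores.
float NormalizeFactor(float row_max, float row_sum, float softmax_scale,
                      bool smooth_softmax) {
  float sum = smooth_softmax ? row_sum + std::exp(-row_max * softmax_scale)
                             : row_sum;
  return (sum == 0.f || sum != sum) ? 1.f : 1.f / sum;  // guard 0 and NaN
}
```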

---------

Co-authored-by: Ye Wang
tianleiwu authored Aug 27, 2024
1 parent 99bc45d commit 6e57576
Showing 25 changed files with 435 additions and 161 deletions.
2 changes: 2 additions & 0 deletions docs/ContribOperators.md
@@ -2541,6 +2541,8 @@ This version of the operator has been available since version 1 of the 'com.micr
<dd>Rotate using interleaved pattern. Default value is 0 (False).</dd>
<dt><tt>scale</tt> : float</dt>
<dd>Custom scale will be used if specified. Default value is 1/sqrt(head_size)</dd>
<dt><tt>smooth_softmax</tt> : int</dt>
<dd>Use a smooth factor in softmax.</dd>
</dl>

#### Inputs (7 - 9)
1 change: 1 addition & 0 deletions onnxruntime/contrib_ops/cpu/bert/attention_common.h
@@ -116,6 +116,7 @@ struct GroupQueryAttentionParameters {
bool is_prompt; // determines if seqlens_k is past or kv sequence length tensor
bool do_rotary;
bool rotary_interleaved;
bool use_smooth_softmax;
float scale;
AttentionQkvFormat qkv_format;
AttentionQkvFormat past_kv_format;
43 changes: 42 additions & 1 deletion onnxruntime/contrib_ops/cpu/bert/attention_helper.h
@@ -16,6 +16,47 @@ using onnxruntime::concurrency::ThreadPool;
namespace onnxruntime {
namespace contrib {

template <typename T>
void ComputeSmoothSoftmaxInplace(T* score, int N, int D, ThreadPool* tp) {
ThreadPool::TryParallelFor(tp, N, D * 2.0, [&](std::ptrdiff_t begin, std::ptrdiff_t end) {
for (std::ptrdiff_t j = begin; j != end; ++j) {
float* x = reinterpret_cast<T*>(score) + j * D;
float* y = x;

float max = -std::numeric_limits<float>::infinity();
for (int i = 0; i < D; i++) {
if (max < x[i])
max = x[i];
}

if (max < 0.0f) {
max = 0.0f;
}

for (int i = 0; i < D; i++) {
y[i] = expf(x[i] - max);
}

double sum = 0.0;

for (int i = 0; i < D; i++) {
sum += x[i];
}

sum += exp(static_cast<double>(-max));

for (int i = 0; i < D; i++) {
y[i] = x[i] / (float)sum;
}
}
});
}

template <>
inline void ComputeSmoothSoftmaxInplace(float* score, int N, int D, ThreadPool* tp) {
MlasComputeSoftmax(score, score, N, D, false, true, tp);
}

template <typename T>
void ComputeAttentionSoftmaxInplace(T* score, int N, int D, ThreadPool* tp) {
ThreadPool::TryParallelFor(tp, N, D * 2.0, [&](std::ptrdiff_t begin, std::ptrdiff_t end) {
@@ -58,7 +99,7 @@ void ComputeAttentionSoftmaxInplace(T* score, int N, int D, ThreadPool* tp) {

template <>
inline void ComputeAttentionSoftmaxInplace(float* score, int N, int D, ThreadPool* tp) {
MlasComputeSoftmax(score, score, N, D, false, tp);
MlasComputeSoftmax(score, score, N, D, false, false, tp);
}

template <typename T>
19 changes: 16 additions & 3 deletions onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h
@@ -30,6 +30,8 @@ class GQAAttentionBase {
do_rotary_ = info.GetAttrOrDefault<int64_t>("do_rotary", 0) == 1;
rotary_interleaved_ = info.GetAttrOrDefault<int64_t>("rotary_interleaved", 0) == 1;

use_smooth_softmax_ = info.GetAttrOrDefault<int64_t>("smooth_softmax", 0) == 1;

local_window_size_ = has_local ? static_cast<int>(info.GetAttrOrDefault<int64_t>("local_window_size", -1)) : -1;
}

@@ -40,6 +42,8 @@ class GQAAttentionBase {
bool rotary_interleaved_;
int local_window_size_;

bool use_smooth_softmax_;

template <typename T>
Status ApplyAttention(const T* Q, // Q data with shape BxNxSxH
const T* K, // K data with shape BxN_kvxSxH
@@ -195,10 +199,19 @@
for (int total_seq_id = 0; total_seq_id < seq_causal_length - local_window_size_ - 1; total_seq_id++) {
output_softmax[total_seq_id] = 0.f;
}
ComputeAttentionSoftmaxInplace(output_softmax + seq_causal_length - local_window_size_ - 1, 1,
local_window_size_ + 1, nullptr);
if (use_smooth_softmax_) {
ComputeSmoothSoftmaxInplace(output_softmax + seq_causal_length - local_window_size_ - 1, 1,
local_window_size_ + 1, nullptr);
} else {
ComputeAttentionSoftmaxInplace(output_softmax + seq_causal_length - local_window_size_ - 1, 1,
local_window_size_ + 1, nullptr);
}
} else {
ComputeAttentionSoftmaxInplace(output_softmax, 1, seq_causal_length, nullptr);
if (use_smooth_softmax_) {
ComputeSmoothSoftmaxInplace(output_softmax, 1, seq_causal_length, nullptr);
} else {
ComputeAttentionSoftmaxInplace(output_softmax, 1, seq_causal_length, nullptr);
}
}

// set causal [seq_causal_length, total_seqlen) to 0.f
2 changes: 1 addition & 1 deletion onnxruntime/contrib_ops/cuda/bert/attention_impl.cu
@@ -303,7 +303,7 @@ Status FlashAttention(
ORT_RETURN_IF_ERROR(onnxruntime::flash::mha_fwd(
device_prop, stream, data.q, data.k, data.v, data.output, reinterpret_cast<void*>(data.scratch),
parameters.batch_size, parameters.num_heads, parameters.num_heads, parameters.head_size,
parameters.sequence_length, parameters.total_sequence_length, scale, parameters.is_unidirectional, is_bf16,
parameters.sequence_length, parameters.total_sequence_length, scale, parameters.is_unidirectional, is_bf16, false,
parameters.num_splits, reinterpret_cast<void*>(data.softmax_lse_accum), reinterpret_cast<void*>(data.out_accum),
data.qkv_format == AttentionQkvFormat::Q_K_V_BSNH));

2 changes: 2 additions & 0 deletions onnxruntime/contrib_ops/cuda/bert/flash_attention/flash.h
@@ -121,6 +121,8 @@ struct Flash_fwd_params : public Qkv_params {

bool is_rotary_interleaved = false;

bool smooth_softmax = false;

int num_splits = 0; // For split-KV version

void* __restrict__ alibi_slopes_ptr = nullptr;
@@ -37,6 +37,7 @@ void set_params_fprop(Flash_fwd_params& params,
float softmax_scale,
bool is_causal,
bool is_bf16,
bool use_smooth_softmax,
bool kv_bsnh = true,
int window_size_left = -1,
int window_size_right = -1) {
@@ -47,6 +48,7 @@
params.o_ptr = out;

params.is_bf16 = is_bf16;
params.smooth_softmax = use_smooth_softmax;

// All stride are in elements, not bytes.
if (kv_bsnh) {
@@ -267,6 +269,7 @@ Status mha_fwd(const cudaDeviceProp& dprops,
float softmax_scale,
bool is_causal,
bool is_bf16,
bool use_smooth_softmax,
int num_splits,
void* softmax_lse_accum, // num_splits x batch_size x seqlen_q x num_heads
void* out_accum, // num_splits x batch_size x seqlen_q x num_heads x head_size_rounded
@@ -293,6 +296,7 @@ Status mha_fwd(const cudaDeviceProp& dprops,
softmax_scale,
is_causal,
is_bf16,
use_smooth_softmax,
kv_bsnh,
local_window_size,
is_causal ? 0 : -1);
@@ -365,6 +369,7 @@ Status mha_varlen_fwd(const cudaDeviceProp& dprops,
softmax_scale,
is_causal,
is_bf16,
false,
true,
-1,
is_causal ? 0 : -1);
@@ -424,6 +429,7 @@ Status mha_fwd_kvcache(const cudaDeviceProp& dprops,
const float softmax_scale,
bool is_causal,
bool is_bf16,
bool use_smooth_softmax,
bool past_bsnh, // otherwise bnsh
int num_splits,
void* softmax_lse_accum, // num_splits x batch_size x seqlen_q x num_heads
Expand Down Expand Up @@ -456,6 +462,7 @@ Status mha_fwd_kvcache(const cudaDeviceProp& dprops,
softmax_scale,
is_causal,
is_bf16,
use_smooth_softmax,
past_bsnh,
local_window_size,
is_causal ? 0 : -1);
2 changes: 2 additions & 0 deletions onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_api.h
@@ -52,6 +52,7 @@ Status mha_fwd(const cudaDeviceProp& dprops,
float softmax_scale,
bool is_causal,
bool is_bf16,
bool use_smooth_softmax,
int num_splits = 0,
void* softmax_lse_accum = nullptr, // num_splits x batch_size x seqlen_q x num_heads
void* out_accum = nullptr, // num_splits x batch_size x seqlen_q x num_heads x head_size_rounded
@@ -105,6 +106,7 @@ Status mha_fwd_kvcache(const cudaDeviceProp& dprops,
const float softmax_scale,
bool is_causal,
bool is_bf16,
bool use_smooth_softmax,
bool past_bsnh, // otherwise bnsh
int num_splits = 0,
void* softmax_lse_accum = nullptr, // num_splits x batch_size x seqlen_q x num_heads
@@ -346,7 +346,7 @@ inline __device__ void compute_attn_1rowblock(const Params& params, const int bi

// Epilogue

Tensor lse = softmax.template normalize_softmax_lse<>(acc_o, params.scale_softmax);
Tensor lse = softmax.template normalize_softmax_lse<>(acc_o, params.scale_softmax, params.smooth_softmax);

// Convert acc_o from fp32 to fp16/bf16
Tensor rO = flash::convert_type<Element>(acc_o);
@@ -902,7 +902,7 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params& params, cons

// Epilogue

Tensor lse = softmax.template normalize_softmax_lse<Split>(acc_o, params.scale_softmax);
Tensor lse = softmax.template normalize_softmax_lse<Split>(acc_o, params.scale_softmax, params.smooth_softmax);

Tensor sOaccum = make_tensor(make_smem_ptr(reinterpret_cast<ElementO*>(smem_)), typename Kernel_traits::SmemLayoutO{}); // (SMEM_M,SMEM_N)
// Partition sO to match the accumulator partitioning
4 changes: 2 additions & 2 deletions onnxruntime/contrib_ops/cuda/bert/flash_attention/softmax.h
@@ -159,15 +159,15 @@ struct Softmax {
};

template <bool Split = false, typename Tensor0>
__forceinline__ __device__ TensorT normalize_softmax_lse(Tensor0& acc_o, float softmax_scale) {
__forceinline__ __device__ TensorT normalize_softmax_lse(Tensor0& acc_o, float softmax_scale, bool smooth_softmax) {
SumOp<float> sum_op;
quad_allreduce_(row_sum, row_sum, sum_op);
TensorT lse = make_fragment_like(row_sum);
Tensor acc_o_rowcol = make_tensor(acc_o.data(), flash::convert_layout_acc_rowcol(acc_o.layout()));
static_assert(decltype(size<0>(acc_o_rowcol))::value == kNRows);
#pragma unroll
for (int mi = 0; mi < size<0>(acc_o_rowcol); ++mi) {
float sum = row_sum(mi);
float sum = smooth_softmax ? row_sum(mi) + expf(-row_max(mi) * softmax_scale) : row_sum(mi);
float inv_sum = (sum == 0.f || sum != sum) ? 1.f : 1.f / sum;
lse(mi) = (sum == 0.f || sum != sum) ? (Split ? -INFINITY : INFINITY) : row_max(mi) * softmax_scale + __logf(sum);
float scale = inv_sum;
3 changes: 3 additions & 0 deletions onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc
@@ -51,6 +51,7 @@ GroupQueryAttention<T>::GroupQueryAttention(const OpKernelInfo& info)
do_rotary_ = info.GetAttrOrDefault<int64_t>("do_rotary", 0) == 1;
rotary_interleaved_ = info.GetAttrOrDefault<int64_t>("rotary_interleaved", 0) == 1;
scale_ = info.GetAttrOrDefault<float>("scale", 0.0f);
use_smooth_softmax_ = info.GetAttrOrDefault<int64_t>("smooth_softmax", 0) == 1;

kernel_options_ = this->GetAttentionKernelOptions();

@@ -98,6 +99,7 @@ Status GroupQueryAttention<T>::ComputeInternal(OpKernelContext* context) const {
device_prop.maxThreadsPerBlock));
parameters.local_window_size = local_window_size_;
parameters.is_unidirectional = is_unidirectional_;
parameters.use_smooth_softmax = use_smooth_softmax_;
parameters.zeros_count = kZerosCount;
parameters.zero_ptr = zeros_.get();
// parameters.left_padding = left_padding_;
@@ -151,6 +153,7 @@ Status GroupQueryAttention<T>::ComputeInternal(OpKernelContext* context) const {
#if USE_MEMORY_EFFICIENT_ATTENTION
int sm = (device_prop.major * 10) + device_prop.minor;
bool use_memory_efficient_attention =
!use_smooth_softmax_ &&
!use_flash_attention &&
!disable_memory_efficient_attention_ &&
local_window_size_ == -1 &&
1 change: 1 addition & 0 deletions onnxruntime/contrib_ops/cuda/bert/group_query_attention.h
@@ -28,6 +28,7 @@ class GroupQueryAttention final : public CudaKernel {
bool is_past_bsnh_;
bool do_rotary_;
bool rotary_interleaved_;
bool use_smooth_softmax_;
float scale_;
bool disable_flash_attention_;
bool disable_memory_efficient_attention_;
@@ -678,9 +678,9 @@ Status FlashAttention(
reinterpret_cast<void*>(data.softmax_lse), seqlens_k, cos_cache, sin_cache, /*block_table*/ nullptr,
batch_size, num_heads, kv_num_heads, head_size, sequence_length,
parameters.seqlen_present_kv_cache, kv_sequence_length, parameters.rotary_dim,
scale, is_causal, is_bf16, past_bsnh, parameters.num_splits, reinterpret_cast<void*>(data.softmax_lse_accum),
reinterpret_cast<void*>(data.out_accum), parameters.local_window_size, parameters.rotary_interleaved,
parameters.is_packed_qkv));
scale, is_causal, is_bf16, parameters.use_smooth_softmax, past_bsnh, parameters.num_splits,
reinterpret_cast<void*>(data.softmax_lse_accum), reinterpret_cast<void*>(data.out_accum),
parameters.local_window_size, parameters.rotary_interleaved, parameters.is_packed_qkv));

// if (parameters.left_padding && parameters.is_prompt) {
// ORT_RETURN_IF_ERROR(LaunchLeftPadLast(parameters, data, stream, device_prop.maxThreadsPerBlock));
4 changes: 4 additions & 0 deletions onnxruntime/core/graph/contrib_ops/bert_defs.cc
@@ -1073,6 +1073,10 @@ ONNX_MS_OPERATOR_SET_SCHEMA(
"Rotate using interleaved pattern. Default value is 0 (False).",
AttributeProto::INT,
OPTIONAL_VALUE)
.Attr("smooth_softmax",
"Use a smooth factor in softmax.",
AttributeProto::INT,
static_cast<int64_t>(-1))
.Input(0,
"query",
"Query with shape (batch_size, sequence_length, hidden_size), or packed QKV with shape"
1 change: 1 addition & 0 deletions onnxruntime/core/mlas/inc/mlas.h
@@ -1013,6 +1013,7 @@ MlasComputeSoftmax(
size_t N,
size_t D,
bool LogSoftmax,
bool SmoothSoftmax,
MLAS_THREADPOOL* ThreadPool
);
