softcap gqa (#21683)
### Description
Implement softcap for GQA (GroupQueryAttention).

### Motivation and Context
Fixes models such as Gemma-2 that require softcap in attention so they do not output NaNs.
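
For context, softcap bounds the pre-softmax attention logits with a scaled tanh: each logit is divided by the cap, passed through tanh, and multiplied by the cap again, so it stays in (-softcap, +softcap). A minimal sketch of the transform the changes below implement (illustrative helper name, not part of this PR):

```cpp
#include <cmath>

// Softcap as used by Gemma-2-style attention: bound each logit to (-softcap, +softcap).
// A value of 0 disables the transform, matching the operator's default.
inline float ApplySoftcap(float logit, float softcap) {
  if (softcap <= 0.0f) return logit;
  return softcap * std::tanh(logit / softcap);
}
```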
aciddelgado authored Aug 31, 2024
1 parent 5dee95f commit 509cb54
Showing 26 changed files with 366 additions and 160 deletions.
40 changes: 32 additions & 8 deletions cmake/patches/cutlass/cutlass_3.5.0.patch
@@ -1,8 +1,16 @@
diff --git a/examples/41_fused_multi_head_attention/kernel_forward.h b/examples/41_fused_multi_head_attention/kernel_forward.h
index 4c80f549..34327633 100644
index 4c80f549..5ad610c8 100644
--- a/examples/41_fused_multi_head_attention/kernel_forward.h
+++ b/examples/41_fused_multi_head_attention/kernel_forward.h
@@ -221,6 +221,8 @@ struct AttentionKernel {
@@ -189,6 +189,7 @@ struct AttentionKernel {

// Scale
accum_t scale = 0.0;
+ accum_t softcap = 0.0;

// Dimensions/strides
int32_t head_dim = 0;
@@ -221,6 +222,8 @@ struct AttentionKernel {
int32_t num_batches = 0;
int32_t num_heads = 0;

@@ -11,7 +19,23 @@ index 4c80f549..34327633 100644
// dropout
bool use_dropout = false;
unsigned long long dropout_batch_head_rng_offset = 0;
@@ -897,7 +899,8 @@ struct AttentionKernel {
@@ -818,6 +821,15 @@ struct AttentionKernel {
accum =
cutlass::multiplies<typename MM0::Mma::FragmentC>()(p.scale, accum);
}
+
+ // apply softcap if applicable
+ if (p.softcap > 0.0) {
+ accum = cutlass::multiplies<typename MM0::Mma::FragmentC>()(1.0 / p.softcap, accum);
+ for (int i = 0; i < accum.size(); ++i) {
+ accum[i] = cutlass::fast_tanh(accum[i]);
+ }
+ accum = cutlass::multiplies<typename MM0::Mma::FragmentC>()(p.softcap, accum);
+ }

// apply attention bias if applicable
if (kSupportsBias && p.attn_bias_ptr != nullptr) {
@@ -897,7 +909,8 @@ struct AttentionKernel {
p.num_keys - iter_key_start,
iter_key_start == 0,
iteratorC_tile_offset,
@@ -21,7 +45,7 @@ index 4c80f549..34327633 100644

// Output results to shared-memory
int warp_idx_mn_0 = my_warp_id %
@@ -1166,7 +1169,8 @@ struct AttentionKernel {
@@ -1166,7 +1179,8 @@ struct AttentionKernel {
int max_col,
bool is_first,
typename WarpIteratorC::TensorCoord const& tile_offset,
@@ -31,7 +55,7 @@ index 4c80f549..34327633 100644
/* Iterates on the accumulator and corresponding position on result matrix

(1) Update `mi[r]` to the max value of the row `r`
@@ -1257,7 +1261,7 @@ struct AttentionKernel {
@@ -1257,7 +1271,7 @@ struct AttentionKernel {
accum_t mi_row, total_row;
LambdaIterator::iterateRows(
lane_offset,
@@ -40,7 +64,7 @@ index 4c80f549..34327633 100644
[&](int accum_m, int accum_n, int idx) {
frag[idx] =
(accum_n < max_col) ? exp2f(frag[idx] - mi_row) : accum_t(0.0);
@@ -1294,7 +1298,7 @@ struct AttentionKernel {
@@ -1294,7 +1308,7 @@ struct AttentionKernel {
for (int i = 0; i < MM0::MmaCore::WarpCount::kN; ++i) {
total_row += addition_storage[id + kQueriesPerBlock * i];
}
@@ -50,7 +74,7 @@ index 4c80f549..34327633 100644
}

diff --git a/include/cutlass/functional.h b/include/cutlass/functional.h
index 964d2ff3..b366bc14 100644
index 964d2ff3..676ba768 100644
--- a/include/cutlass/functional.h
+++ b/include/cutlass/functional.h
@@ -39,6 +39,7 @@
Expand All @@ -73,4 +97,4 @@ index 964d2ff3..b366bc14 100644
+#endif
#else
return half_t(1.f / std::sqrt(half_t::convert(lhs)));
#endif
#endif
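
In the CUTLASS patch above, softcap is applied elementwise to the already-scaled Q·Kᵀ accumulator, after `p.scale` and before any attention bias. A rough host-side sketch of that loop, with illustrative types in place of the CUTLASS fragment machinery:

```cpp
#include <array>
#include <cmath>

// Host-side sketch of the patched accumulator loop (illustrative types only).
// accum holds scale * (Q . K^T); softcap == 0 leaves it untouched.
template <size_t N>
void SoftcapFragment(std::array<float, N>& accum, float softcap) {
  if (softcap <= 0.0f) return;
  for (float& v : accum) {
    v *= 1.0f / softcap;  // multiplies<FragmentC>(1.0 / p.softcap, accum)
    v = std::tanh(v);     // cutlass::fast_tanh in the kernel
    v *= softcap;         // multiplies<FragmentC>(p.softcap, accum)
  }
}
```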
2 changes: 2 additions & 0 deletions docs/ContribOperators.md
@@ -2543,6 +2543,8 @@ This version of the operator has been available since version 1 of the 'com.micr
<dd>Custom scale will be used if specified. Default value is 1/sqrt(head_size)</dd>
<dt><tt>smooth_softmax</tt> : int</dt>
<dd>Use a smooth factor in softmax.</dd>
<dt><tt>softcap</tt> : float</dt>
<dd>Softcap value for attention weights. Default value is 0.</dd>
</dl>

#### Inputs (7 - 9)
1 change: 1 addition & 0 deletions onnxruntime/contrib_ops/cpu/bert/attention_common.h
@@ -119,6 +119,7 @@ struct GroupQueryAttentionParameters {
bool rotary_interleaved;
bool use_smooth_softmax;
float scale;
float softcap;
AttentionQkvFormat qkv_format;
AttentionQkvFormat past_kv_format;
int zeros_count;
11 changes: 11 additions & 0 deletions onnxruntime/contrib_ops/cpu/bert/attention_helper.h
@@ -102,6 +102,17 @@ inline void ComputeAttentionSoftmaxInplace(float* score, int N, int D, ThreadPoo
MlasComputeSoftmax(score, score, N, D, false, false, tp);
}

template <typename T>
void ComputeAttentionSoftcapInplace(T* scores, int sequence_length, float softcap) {
for (int i = 0; i < sequence_length; i++) {
scores[i] = scores[i] / softcap;
scores[i] = std::tanh(scores[i]);
scores[i] = scores[i] * softcap;
}
}

template void ComputeAttentionSoftcapInplace<float>(float* scores, int sequence_length, float softcap);

template <typename T>
void PrepareMask(const int32_t* mask_index,
gsl::span<const int64_t> mask_index_dims,
11 changes: 10 additions & 1 deletion onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h
@@ -26,6 +26,7 @@ class GQAAttentionBase {
kv_num_heads_ = static_cast<int>(kv_num_heads);

scale_ = info.GetAttrOrDefault<float>("scale", 0.0f);
softcap_ = info.GetAttrOrDefault<float>("softcap", 0.0f);

do_rotary_ = info.GetAttrOrDefault<int64_t>("do_rotary", 0) == 1;
rotary_interleaved_ = info.GetAttrOrDefault<int64_t>("rotary_interleaved", 0) == 1;
@@ -38,7 +39,8 @@ class GQAAttentionBase {
int num_heads_; // number of attention heads of Q
int kv_num_heads_; // number of attention heads of K or V
float scale_; // the scaling factor applied before softmax
bool do_rotary_; // whether or not to use rotary embeddings
float softcap_;
bool do_rotary_; // whether or not to use rotary embeddings
bool rotary_interleaved_;
int local_window_size_;

@@ -199,6 +201,10 @@ class GQAAttentionBase {
for (int total_seq_id = 0; total_seq_id < seq_causal_length - local_window_size_ - 1; total_seq_id++) {
output_softmax[total_seq_id] = 0.f;
}
if (softcap_ > 0.f) {
ComputeAttentionSoftcapInplace(output_softmax + seq_causal_length - local_window_size_ - 1,
local_window_size_ + 1, softcap_);
}
if (use_smooth_softmax_) {
ComputeSmoothSoftmaxInplace(output_softmax + seq_causal_length - local_window_size_ - 1, 1,
local_window_size_ + 1, nullptr);
@@ -207,6 +213,9 @@ class GQAAttentionBase {
local_window_size_ + 1, nullptr);
}
} else {
if (softcap_ > 0.f) {
ComputeAttentionSoftcapInplace(output_softmax, seq_causal_length, softcap_);
}
if (use_smooth_softmax_) {
ComputeSmoothSoftmaxInplace(output_softmax, 1, seq_causal_length, nullptr);
} else {
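
In the CPU path above, softcap is applied to each row of scaled Q·Kᵀ scores after the local-window masking and before the (optionally smooth) softmax. A rough per-row sketch of that ordering, using illustrative names rather than the helpers in gqa_attention_base.h:

```cpp
#include <algorithm>
#include <cmath>
#include <vector>

// Per-row sketch of the CPU ordering used above (illustrative only):
// the row holds scale * (Q . K^T); apply softcap, then a standard softmax in place.
void SoftcapThenSoftmaxRow(std::vector<float>& scores, float softcap) {
  if (softcap > 0.0f) {
    for (float& s : scores) s = softcap * std::tanh(s / softcap);
  }
  const float max_score = *std::max_element(scores.begin(), scores.end());
  float sum = 0.0f;
  for (float& s : scores) {
    s = std::exp(s - max_score);
    sum += s;
  }
  for (float& s : scores) s /= sum;
}
```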
4 changes: 2 additions & 2 deletions onnxruntime/contrib_ops/cpu/bert/group_query_attention.cc
@@ -50,7 +50,6 @@ Status GroupQueryAttention<T>::Compute(OpKernelContext* context) const {
const Tensor* sin_cache = context->Input<Tensor>(8);

GroupQueryAttentionParameters parameters = {};
constexpr float scale = 1.0f;
ORT_RETURN_IF_ERROR(group_query_attention_helper::CheckInputs(query,
key,
value,
@@ -63,7 +62,8 @@ Status GroupQueryAttention<T>::Compute(OpKernelContext* context) const {
kv_num_heads_,
seqlens_k,
total_seqlen,
scale));
scale_,
softcap_));

const int batch_size = parameters.batch_size;
const int sequence_length = parameters.sequence_length;
@@ -23,7 +23,8 @@ Status CheckInputs(const Tensor* query,
int kv_num_heads,
const Tensor* seqlens_k,
const Tensor* total_seqlen,
float scale) {
float scale,
float softcap) {
// Note: Here S* is seqlen_past_kv_cache, S+ is seqlen_present_kv_cache
// past_key : (B, N_k, S*, H) or (B, N_k, S+, H) or nullptr
// past_value : (B, N_k, S*, H) or (B, N_k, S+, H) or nullptr
@@ -236,6 +237,7 @@ Status CheckInputs(const Tensor* query,
output_parameters->is_unidirectional = true;
output_parameters->is_prompt = is_prompt;
output_parameters->scale = scale;
output_parameters->softcap = softcap;
output_parameters->qkv_format = qkv_format;
output_parameters->past_kv_format = past_kv_format;
}
@@ -256,12 +258,13 @@ Status CheckInputs(const Tensor* query,
const Tensor* seqlens_k,
const Tensor* total_seqlen,
float scale,
float softcap,
int max_threads_per_block) {
if (max_threads_per_block > 0 && num_heads > max_threads_per_block) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "num_heads should be no larger than ", max_threads_per_block);
}

return CheckInputs(query, key, value, past_key, past_value, cos_cache, sin_cache, parameters, num_heads, kv_num_heads, seqlens_k, total_seqlen, scale);
return CheckInputs(query, key, value, past_key, past_value, cos_cache, sin_cache, parameters, num_heads, kv_num_heads, seqlens_k, total_seqlen, scale, softcap);
}
} // namespace group_query_attention_helper
} // namespace contrib
6 changes: 3 additions & 3 deletions onnxruntime/contrib_ops/cuda/bert/attention_impl.cu
@@ -303,9 +303,9 @@ Status FlashAttention(
ORT_RETURN_IF_ERROR(onnxruntime::flash::mha_fwd(
device_prop, stream, data.q, data.k, data.v, data.output, reinterpret_cast<void*>(data.scratch),
parameters.batch_size, parameters.num_heads, parameters.num_heads, parameters.head_size,
parameters.sequence_length, parameters.total_sequence_length, scale, parameters.is_unidirectional, is_bf16, false,
parameters.num_splits, reinterpret_cast<void*>(data.softmax_lse_accum), reinterpret_cast<void*>(data.out_accum),
data.qkv_format == AttentionQkvFormat::Q_K_V_BSNH));
parameters.sequence_length, parameters.total_sequence_length, scale, 0.0, parameters.is_unidirectional, is_bf16,
false, parameters.num_splits, reinterpret_cast<void*>(data.softmax_lse_accum),
reinterpret_cast<void*>(data.out_accum), data.qkv_format == AttentionQkvFormat::Q_K_V_BSNH));

return Status::OK();
}
@@ -169,6 +169,7 @@ void LaunchCutlassFmha(const MemoryEfficientAttentionParams& params) {
p.head_dim_value = params.v_head_size;

p.scale = params.scale;
p.softcap = params.softcap;

// When params.cu_seqlens_q is provided, num_queries is max_seq_q and num_keys will be set inside the kernel
p.num_queries = params.sequence_length;
@@ -28,6 +28,7 @@ struct MemoryEfficientAttentionParams {
bool use_smooth_softmax;

float scale;
float softcap = 0.0;

int32_t* seqstart_q_ptr;
int32_t* seqstart_k_ptr;
1 change: 1 addition & 0 deletions onnxruntime/contrib_ops/cuda/bert/flash_attention/flash.h
@@ -74,6 +74,7 @@ struct Flash_fwd_params : public Qkv_params {
// The scaling factors for the kernel.
float scale_softmax = 0.0;
float scale_softmax_log2 = 0.0;
float softcap = 0.0;

// array of length b+1 holding starting offset of each sequence.
int* __restrict__ cu_seqlens_q = nullptr;
21 changes: 18 additions & 3 deletions onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_api.cc
@@ -35,6 +35,7 @@ void set_params_fprop(Flash_fwd_params& params,
void* p_d,
void* softmax_lse_d,
float softmax_scale,
float softcap,
bool is_causal,
bool is_bf16,
bool use_smooth_softmax,
@@ -111,8 +112,16 @@ void set_params_fprop(Flash_fwd_params& params,
params.d_rounded = head_size_rounded;

// Set the different scale values.
params.scale_softmax = softmax_scale;
params.scale_softmax_log2 = softmax_scale * M_LOG2E;
if (softcap > 0.0) {
params.softcap = softmax_scale / softcap;
params.scale_softmax = softcap;
params.scale_softmax_log2 = softcap * M_LOG2E;
} else {
// Remove potential NaN
params.softcap = 0.0;
params.scale_softmax = softmax_scale;
params.scale_softmax_log2 = softmax_scale * M_LOG2E;
}

// In our API, causal/unidirectional determines if we only look at prior tokens. However, the flash API separates
// local and causal, meaning when we have local window size
@@ -267,6 +276,7 @@ Status mha_fwd(const cudaDeviceProp& dprops,
int seqlen_q,
int seqlen_k,
float softmax_scale,
const float softcap,
bool is_causal,
bool is_bf16,
bool use_smooth_softmax,
@@ -294,6 +304,7 @@ Status mha_fwd(const cudaDeviceProp& dprops,
/*p_ptr=*/nullptr,
softmax_lse,
softmax_scale,
softcap,
is_causal,
is_bf16,
use_smooth_softmax,
@@ -343,6 +354,7 @@ Status mha_varlen_fwd(const cudaDeviceProp& dprops,
int max_seqlen_q,
int max_seqlen_k,
float softmax_scale,
const float softcap,
bool is_causal,
bool is_bf16,
int max_num_blocks_per_seq,
@@ -367,6 +379,7 @@ Status mha_varlen_fwd(const cudaDeviceProp& dprops,
/*p_ptr=*/nullptr,
softmax_lse,
softmax_scale,
softcap,
is_causal,
is_bf16,
false,
@@ -427,6 +440,7 @@ Status mha_fwd_kvcache(const cudaDeviceProp& dprops,
int seqlen_k_new,
int rotary_dim,
const float softmax_scale,
const float softcap,
bool is_causal,
bool is_bf16,
bool use_smooth_softmax,
@@ -440,7 +454,7 @@ Status mha_fwd_kvcache(const cudaDeviceProp& dprops,
int max_num_blocks_per_seq,
int page_block_size) {
auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; };
const int head_size_rounded = round_multiple(head_size, 32);
const int head_size_rounded = head_size <= 192 ? round_multiple(head_size, 32) : 256;
const int seqlen_q_rounded = round_multiple(seqlen_q, 128);
const int seqlen_k_rounded = round_multiple(seqlen_k, 128);
const bool paged_KV = block_table != nullptr;
Expand All @@ -460,6 +474,7 @@ Status mha_fwd_kvcache(const cudaDeviceProp& dprops,
/*p_ptr=*/nullptr,
softmax_lse,
softmax_scale,
softcap,
is_causal,
is_bf16,
use_smooth_softmax,
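
A note on the scale folding in set_params_fprop above: when softcap > 0, the factor applied before tanh is softmax_scale / softcap (stored in params.softcap) and the factor applied after tanh is softcap itself (stored in params.scale_softmax), which yields the same logit as softcap * tanh(softmax_scale * qk / softcap). A small sketch of the equivalence, using illustrative names rather than the kernel code:

```cpp
#include <cmath>

// Reference form: scale the raw dot product, then softcap it.
float ReferenceLogit(float qk, float softmax_scale, float softcap) {
  return softcap * std::tanh((softmax_scale * qk) / softcap);
}

// Folded form suggested by set_params_fprop above: the pre-tanh factor is
// softmax_scale / softcap (params.softcap) and the post-tanh factor is
// softcap (params.scale_softmax). Both forms give the same logit for softcap > 0.
float FoldedLogit(float qk, float softmax_scale, float softcap) {
  const float pre_tanh = softmax_scale / softcap;  // params.softcap
  const float post_tanh = softcap;                 // params.scale_softmax
  return post_tanh * std::tanh(pre_tanh * qk);
}
```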
3 changes: 3 additions & 0 deletions onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_api.h
@@ -50,6 +50,7 @@ Status mha_fwd(const cudaDeviceProp& dprops,
int seqlen_q,
int seqlen_k,
float softmax_scale,
const float softcap,
bool is_causal,
bool is_bf16,
bool use_smooth_softmax,
@@ -77,6 +78,7 @@ Status mha_varlen_fwd(const cudaDeviceProp& dprops,
int max_seqlen_q,
int max_seqlen_k,
float softmax_scale,
const float softcap,
bool is_causal,
bool is_bf16,
int max_num_blocks_per_seq = 0,
@@ -104,6 +106,7 @@ Status mha_fwd_kvcache(const cudaDeviceProp& dprops,
int seqlen_k_new,
int rotary_dim,
const float softmax_scale,
const float softcap,
bool is_causal,
bool is_bf16,
bool use_smooth_softmax,
(The remaining changed files are not shown here.)
