Split KV on MHA and Attention ops (microsoft#18007)
### Description
Implement Split KV optimization for FlashAttention in MHA and Attention
operators.

### Motivation and Context
Split KV can further accelerate these ops, particularly for decoding-style workloads where the batch size and number of attention heads are too small to keep the GPU's SMs busy without splitting the KV sequence.
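For background, Split KV (the "flash decoding" style of parallelism) launches additional thread blocks that each attend over a slice of the key/value sequence and records, per split, a partial output plus the log-sum-exp (LSE) of its attention scores; a final reduction combines them. The sketch below shows only that combine step for a single query vector, as a rough illustration; the names (`CombineSplits`, `partial_out`, `partial_lse`) are made up for this example and are not identifiers from this change.

```cpp
#include <algorithm>
#include <cmath>
#include <vector>

// Minimal sketch of the split-KV reduction for one (batch, head, query) position.
// Each split s produced a partial attention output over its slice of the KV
// sequence plus the log-sum-exp (lse) of its attention scores. The exact result
// is the weighted sum of partial outputs with weights exp(lse_s - lse_total).
std::vector<float> CombineSplits(const std::vector<std::vector<float>>& partial_out,  // [num_splits][head_size]
                                 const std::vector<float>& partial_lse) {             // [num_splits]
  const size_t num_splits = partial_out.size();
  const size_t head_size = partial_out[0].size();

  // Global log-sum-exp over all splits, computed stably.
  float max_lse = partial_lse[0];
  for (float lse : partial_lse) max_lse = std::max(max_lse, lse);
  float sum_exp = 0.0f;
  for (float lse : partial_lse) sum_exp += std::exp(lse - max_lse);
  const float total_lse = max_lse + std::log(sum_exp);

  // Rescale each partial output by its share of the softmax mass and accumulate.
  std::vector<float> out(head_size, 0.0f);
  for (size_t s = 0; s < num_splits; ++s) {
    const float weight = std::exp(partial_lse[s] - total_lse);
    for (size_t d = 0; d < head_size; ++d) {
      out[d] += weight * partial_out[s][d];
    }
  }
  return out;
}
```

The `softmax_lse_accum` and `out_accum` scratch buffers introduced below are where the split-KV kernels hold these per-split partial LSEs and outputs until the reduction runs.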
aciddelgado authored Nov 1, 2023
1 parent c181159 commit 819b5a3
Showing 9 changed files with 90 additions and 33 deletions.
3 changes: 2 additions & 1 deletion onnxruntime/contrib_ops/cpu/bert/attention_common.h
@@ -55,6 +55,7 @@ struct AttentionParameters {
int v_hidden_size; // hidden size of V
int v_head_size; // hidden size per head of V
int num_heads;
int num_splits;
bool is_unidirectional;
bool past_present_share_buffer;
bool do_rotary;
@@ -95,9 +96,9 @@ struct GroupQueryAttentionParameters {
int head_size;
int kv_hidden_size;
int kv_num_heads;
+ int num_splits; // number of splits for splitkv
bool is_unidirectional; // causal
float scale;
- int num_splits; // number of splits for splitkv
AttentionQkvFormat qkv_format;
AttentionQkvFormat past_kv_format;
};
22 changes: 22 additions & 0 deletions onnxruntime/contrib_ops/cuda/bert/attention.cc
@@ -135,8 +135,24 @@ Status Attention<T>::ComputeInternal(OpKernelContext* context) const {
if (use_flash_attention && parameters.sequence_length < min_seq_len_for_flash_attention_packed_qkv_) {
use_flash_attention = false;
}
// Allocate buffers
size_t softmax_lse_accum_bytes = 0;
size_t out_accum_bytes = 0;
if (use_flash_attention) {
using namespace std;
auto [num_splits, slse_accum_bytes, o_accum_bytes] = onnxruntime::flash::get_num_splits_and_buffer_sizes(
parameters.batch_size, parameters.sequence_length, parameters.kv_sequence_length, parameters.num_heads,
parameters.head_size, device_prop.multiProcessorCount);
parameters.num_splits = num_splits;
softmax_lse_accum_bytes = slse_accum_bytes;
out_accum_bytes = o_accum_bytes;
}
auto softmax_lse_accum_buffer = GetScratchBuffer<void>(softmax_lse_accum_bytes, context->GetComputeStream());
auto out_accum_buffer = GetScratchBuffer<void>(out_accum_bytes, context->GetComputeStream());
#else
constexpr bool use_flash_attention = false;
auto softmax_lse_accum_buffer = GetScratchBuffer<void>(0, context->GetComputeStream()); // nullptr
auto out_accum_buffer = GetScratchBuffer<void>(0, context->GetComputeStream()); // nullptr
#endif

if (!use_flash_attention) {
@@ -279,6 +295,12 @@ Status Attention<T>::ComputeInternal(OpKernelContext* context) const {
data.fused_runner = reinterpret_cast<void*>(fused_runner);
data.use_flash_attention = use_flash_attention;
data.use_memory_efficient_attention = use_memory_efficient_attention;
if (softmax_lse_accum_buffer != nullptr) {
data.softmax_lse_accum = reinterpret_cast<CudaT*>(softmax_lse_accum_buffer.get());
}
if (out_accum_buffer != nullptr) {
data.out_accum = reinterpret_cast<CudaT*>(out_accum_buffer.get());
}

return QkvToContext<CudaT>(device_prop, cublas, context->GetComputeStream(), parameters, data);
}
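For a rough sense of the extra scratch these buffers cost when splitting kicks in, here is a hypothetical sizing example. It assumes the float32 accumulator layouts used by the FlashAttention split-KV kernels (one float per split × batch × head × query position for the LSE accumulator, times the rounded head size for the output accumulator); the shape numbers themselves are invented for illustration, not taken from this PR.

```cpp
#include <cstdio>

// Hypothetical decode-style shape; not taken from the PR.
int main() {
  const int num_splits = 8, batch = 1, num_heads = 32, seqlen_q = 1, head_size = 128;
  const int head_size_rounded = (head_size + 31) / 32 * 32;  // round up to a multiple of 32

  // softmax_lse_accum: one float per (split, batch, head, query position).
  const size_t lse_accum_bytes = sizeof(float) * num_splits * batch * num_heads * seqlen_q;
  // out_accum: additionally carries the rounded head dimension.
  const size_t out_accum_bytes = lse_accum_bytes * head_size_rounded;

  std::printf("softmax_lse_accum: %zu bytes\n", lse_accum_bytes);  // 1024 bytes
  std::printf("out_accum: %zu bytes\n", out_accum_bytes);          // 131072 bytes (~128 KiB)
}
```

When the helper reports a single split it returns zero byte counts, so the `GetScratchBuffer` calls above hand back empty buffers and `data.softmax_lse_accum` / `data.out_accum` simply stay unset.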
4 changes: 3 additions & 1 deletion onnxruntime/contrib_ops/cuda/bert/attention_impl.cu
@@ -316,7 +316,9 @@ Status FlashAttention(
ORT_RETURN_IF_ERROR(onnxruntime::flash::mha_fwd(
device_prop, stream, query, key, value, data.output, reinterpret_cast<void*>(data.scratch),
parameters.batch_size, parameters.num_heads, parameters.num_heads, parameters.head_size,
- parameters.sequence_length, parameters.total_sequence_length, scale, parameters.is_unidirectional));
+ parameters.sequence_length, parameters.total_sequence_length, scale, parameters.is_unidirectional,
+ parameters.num_splits, reinterpret_cast<void*>(data.softmax_lse_accum), reinterpret_cast<void*>(data.out_accum),
+ true));

DUMP_TENSOR("flash attention output", data.output,
parameters.batch_size, parameters.sequence_length, parameters.num_heads, parameters.v_head_size);
5 changes: 5 additions & 0 deletions onnxruntime/contrib_ops/cuda/bert/attention_impl.h
@@ -88,6 +88,11 @@ struct AttentionData {
T* v = nullptr;
T* scratch = nullptr;
AttentionQkvFormat qkv_format = AttentionQkvFormat::Q_K_V_BSNH;

// Flash buffers
T* softmax_lse = nullptr;
T* softmax_lse_accum = nullptr;
T* out_accum = nullptr;
};

template <typename T>
27 changes: 23 additions & 4 deletions onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_api.cc
@@ -140,11 +140,10 @@ void run_mha_fwd(Flash_fwd_params& params, cudaStream_t stream, bool force_split
// So we find the best efficiency, then find the smallest number of splits that gets 85%
// of the best efficiency.
int num_splits_heuristic(int batch_size, int seqlen_q, int seqlen_k, int num_heads, int head_size, int num_SMs,
- int max_splits, bool new_kv, bool is_sm8x) {
+ int max_splits) {
// This needs to match with run_mha_fwd_splitkv_dispatch
- const int block_n = is_sm8x ? (head_size <= 64 ? 256 : (head_size <= 128 ? 128 : 64))
- : (head_size <= 64 ? 256 : (head_size <= 160 ? 128 : 64));
- const int num_n_blocks = (seqlen_k + (!new_kv ? 0 : seqlen_q) + block_n - 1) / block_n;
+ const int block_n = head_size <= 64 ? 256 : (head_size <= 128 ? 128 : 64);
+ const int num_n_blocks = (seqlen_k + block_n - 1) / block_n;
// Technically kBlockM = 64 only for the splitKV kernels, not the standard kernel.
// In any case we don't expect seqlen_q to be larger than 64 for inference.
const int num_m_blocks = (seqlen_q + 64 - 1) / 64;
@@ -190,6 +189,26 @@ int num_splits_heuristic(int batch_size, int seqlen_q, int seqlen_k, int num_hea
return 1;
}

// Returns (num_splits, softmax_lse_accum bytes, out_accum bytes)
std::tuple<int, int, int> get_num_splits_and_buffer_sizes(int batch_size, int seqlen_q, int seqlen_k, int num_heads,
int head_size, int num_SMs) {
int max_splits = 128;
// split kv buffers
int num_splits = num_splits_heuristic(batch_size, seqlen_q, seqlen_k, num_heads, head_size,
num_SMs, max_splits);
if (num_splits > 1) {
// softmax_lse_accum buffer
int softmax_lse_accum_bytes = get_softmax_lse_accum_size(num_splits, batch_size, num_heads, seqlen_q);
// out_accum buffer
auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; };
const int head_size_rounded = round_multiple(head_size, 32);
int out_accum_bytes = get_out_accum_size(num_splits, batch_size, num_heads, seqlen_q, head_size_rounded);
return {num_splits, softmax_lse_accum_bytes, out_accum_bytes};
} else {
return {0, 0, 0};
}
}

Status mha_fwd(const cudaDeviceProp& dprops,
cudaStream_t stream,
void* q, // batch_size x seqlen_q x num_heads x head_size
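The 85%-of-best-efficiency rule described in the comment above `num_splits_heuristic` can be pictured with a small self-contained sketch: for each candidate split count, estimate how evenly the resulting thread blocks fill the GPU in whole waves, remember the best estimate, and return the smallest count that comes within 85% of it. This is a simplified rendering of the idea under stated assumptions; the shipped heuristic also derives block counts from the sequence lengths and head size and skips split counts that do not change the per-split work, which the sketch omits.

```cpp
#include <algorithm>
#include <cmath>

// Simplified sketch of the split-count selection. "Efficiency" here is how full the
// last wave of thread blocks is: waves / ceil(waves). We remember the best value
// over all candidates and return the smallest split count within 85% of it.
int PickNumSplits(int batch_size, int num_heads, int num_m_blocks, int num_SMs, int max_splits) {
  const int blocks_without_split = batch_size * num_heads * num_m_blocks;
  if (blocks_without_split >= num_SMs) return 1;  // already enough parallelism, don't split

  max_splits = std::min(max_splits, num_SMs);
  auto efficiency = [&](int splits) {
    const float waves = static_cast<float>(blocks_without_split * splits) / num_SMs;
    return waves / std::ceil(waves);
  };

  float best = 0.0f;
  for (int s = 1; s <= max_splits; ++s) best = std::max(best, efficiency(s));
  for (int s = 1; s <= max_splits; ++s) {
    if (efficiency(s) >= 0.85f * best) return s;
  }
  return 1;
}
```

`get_num_splits_and_buffer_sizes` then converts the chosen split count into the two scratch byte counts that each caller in this PR (Attention, MultiHeadAttention, GroupQueryAttention) allocates before launching the kernel.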
6 changes: 3 additions & 3 deletions onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_api.h
@@ -31,6 +31,7 @@
#if USE_FLASH_ATTENTION

#include "core/providers/cuda/cuda_common.h"
#include <tuple>

namespace onnxruntime {
namespace flash {
@@ -99,10 +100,9 @@ Status mha_fwd_kvcache(const cudaDeviceProp& dprops,
);

size_t get_softmax_lse_size(int max_seqlen_q, int batch_size, int num_heads);
- size_t get_softmax_lse_accum_size(int num_splits, int batch_size, int num_heads, int seqlen_q);
- size_t get_out_accum_size(int num_splits, int batch_size, int num_heads, int seqlen_q, int head_size_rounded);

- int num_splits_heuristic(int batch_size, int seqlen_q, int seqlen_k, int num_heads, int head_size, int num_SMs, int max_splits, bool new_kv, bool is_sm8x);
+ std::tuple<int, int, int> get_num_splits_and_buffer_sizes(int batch_size, int seqlen_q, int seqlen_k, int num_heads,
+ int head_size, int num_SMs);

bool is_supported(const cudaDeviceProp& dprops, int head_size, int num_heads, int num_heads_k);

12 changes: 2 additions & 10 deletions onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_launch_template.h
@@ -123,17 +123,9 @@ void run_flash_splitkv_fwd(Flash_fwd_params& params, cudaStream_t stream) {

template <typename T, int Headdim>
void run_mha_fwd_splitkv_dispatch(Flash_fwd_params& params, cudaStream_t stream) {
- bool is_sm8x = params.dprops->major == 8 && params.dprops->minor > 0;
constexpr int kBlockM = 64; // Fixed for all head dimensions
- if (!is_sm8x) { // A100, H100
- // TD [2023-08-28]: nvcc segfaults for headdim 96 with block size 64 x 256,
- // and for headdim 192 with block size 64 x 128.
- constexpr int kBlockN = Headdim <= 64 ? 256 : (Headdim <= 160 ? 128 : 64);
- run_flash_splitkv_fwd<Flash_fwd_kernel_traits<Headdim, kBlockM, kBlockN, 4, false, false, T>>(params, stream);
- } else { // Only 99KB of smem, so we have to set kBlockN smaller for Headdim 160 and above
- constexpr int kBlockN = Headdim <= 64 ? 256 : (Headdim <= 128 ? 128 : 64);
- run_flash_splitkv_fwd<Flash_fwd_kernel_traits<Headdim, kBlockM, kBlockN, 4, false, false, T>>(params, stream);
- }
+ constexpr int kBlockN = Headdim <= 64 ? 256 : (Headdim <= 128 ? 128 : 64);
+ run_flash_splitkv_fwd<Flash_fwd_kernel_traits<Headdim, kBlockM, kBlockN, 4, false, false, T>>(params, stream);
}

template <typename T>
22 changes: 8 additions & 14 deletions onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc
@@ -116,22 +116,16 @@ Status GroupQueryAttention<T>::ComputeInternal(OpKernelContext* context) const {
size_t out_accum_bytes = 0;
size_t seqlens_k_bytes = 0;
if (use_flash_attention) {
// softmax buffer
softmax_lse_bytes = onnxruntime::flash::get_softmax_lse_size(parameters.sequence_length, parameters.batch_size, parameters.num_heads);
- // split kv buffers
- parameters.num_splits = onnxruntime::flash::num_splits_heuristic(
+ // split kv buffer
+ using namespace std;
+ auto [num_splits, slse_accum_bytes, o_accum_bytes] = onnxruntime::flash::get_num_splits_and_buffer_sizes(
parameters.batch_size, parameters.sequence_length, parameters.kv_sequence_length, parameters.num_heads,
- parameters.head_size, device_prop.multiProcessorCount, 128, false,
- device_prop.major == 8 && device_prop.minor > 0);
- if (parameters.num_splits > 1) {
- // softmax_lse_accum buffer
- softmax_lse_accum_bytes = onnxruntime::flash::get_softmax_lse_accum_size(
- parameters.num_splits, parameters.batch_size, parameters.num_heads, parameters.sequence_length);
- // out_accum buffer
- auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; };
- const int head_size_rounded = round_multiple(parameters.head_size, 32);
- out_accum_bytes = onnxruntime::flash::get_out_accum_size(
- parameters.num_splits, parameters.batch_size, parameters.num_heads, parameters.sequence_length, head_size_rounded);
- }
+ parameters.head_size, device_prop.multiProcessorCount);
+ parameters.num_splits = num_splits;
+ softmax_lse_accum_bytes = slse_accum_bytes;
+ out_accum_bytes = o_accum_bytes;
// seqlens_k buffer
if (past_key != nullptr) {
seqlens_k_bytes = sizeof(int) * parameters.batch_size;
22 changes: 22 additions & 0 deletions onnxruntime/contrib_ops/cuda/bert/multihead_attention.cc
@@ -153,8 +153,24 @@ Status MultiHeadAttention<T>::ComputeInternal(OpKernelContext* context) const {
parameters.sequence_length < min_seq_len_for_flash_attention_packed_qkv_) {
use_flash_attention = false;
}
// Allocate buffers
size_t softmax_lse_accum_bytes = 0;
size_t out_accum_bytes = 0;
if (use_flash_attention) {
using namespace std;
auto [num_splits, slse_accum_bytes, o_accum_bytes] = onnxruntime::flash::get_num_splits_and_buffer_sizes(
parameters.batch_size, parameters.sequence_length, parameters.kv_sequence_length, parameters.num_heads,
parameters.head_size, device_prop.multiProcessorCount);
parameters.num_splits = num_splits;
softmax_lse_accum_bytes = slse_accum_bytes;
out_accum_bytes = o_accum_bytes;
}
auto softmax_lse_accum_buffer = GetScratchBuffer<void>(softmax_lse_accum_bytes, context->GetComputeStream());
auto out_accum_buffer = GetScratchBuffer<void>(out_accum_bytes, context->GetComputeStream());
#else
constexpr bool use_flash_attention = false;
auto softmax_lse_accum_buffer = GetScratchBuffer<void>(0, context->GetComputeStream()); // nullptr
auto out_accum_buffer = GetScratchBuffer<void>(0, context->GetComputeStream()); // nullptr
#endif

bool use_fused_cross_attention = !use_flash_attention &&
@@ -291,6 +307,12 @@ Status MultiHeadAttention<T>::ComputeInternal(OpKernelContext* context) const {
data.use_memory_efficient_attention = use_memory_efficient_attention;
data.cumulated_sequence_length_q_cache = &(this->cumulated_sequence_length_q_cache_);
data.cumulated_sequence_length_kv_cache = &(this->cumulated_sequence_length_kv_cache_);
if (softmax_lse_accum_buffer != nullptr) {
data.softmax_lse_accum = reinterpret_cast<CudaT*>(softmax_lse_accum_buffer.get());
}
if (out_accum_buffer != nullptr) {
data.out_accum = reinterpret_cast<CudaT*>(out_accum_buffer.get());
}

cublasHandle_t cublas = GetCublasHandle(context);
