Skip to content

Commit

Permalink
Fix Packed MultiHead Attention (#17996)
Browse files Browse the repository at this point in the history
### Description
Initialize previously unitialized parameters that were causing Op to
crash.



### Motivation and Context
Solves Cuda Memory Misalignment / Illegal Memory Access error when
FlashAttention was used in Packed Multi-Head Attention.
  • Loading branch information
aciddelgado authored and tianleiwu committed Oct 31, 2023
1 parent 00f85d3 commit 8ecd9a5
Show file tree
Hide file tree
Showing 2 changed files with 56 additions and 44 deletions.
88 changes: 48 additions & 40 deletions onnxruntime/contrib_ops/cuda/bert/flash_attention/flash.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,81 +18,89 @@ constexpr int D_DIM = 2;
struct Qkv_params {
using index_t = uint32_t;
// The QKV matrices.
void* __restrict__ q_ptr;
void* __restrict__ k_ptr;
void* __restrict__ v_ptr;
void* __restrict__ q_ptr = nullptr;
void* __restrict__ k_ptr = nullptr;
void* __restrict__ v_ptr = nullptr;

// The stride between rows of the Q, K and V matrices.
index_t q_batch_stride;
index_t k_batch_stride;
index_t v_batch_stride;
index_t q_row_stride;
index_t k_row_stride;
index_t v_row_stride;
index_t q_head_stride;
index_t k_head_stride;
index_t v_head_stride;
index_t q_batch_stride = 0;
index_t k_batch_stride = 0;
index_t v_batch_stride = 0;
index_t q_row_stride = 0;
index_t k_row_stride = 0;
index_t v_row_stride = 0;
index_t q_head_stride = 0;
index_t k_head_stride = 0;
index_t v_head_stride = 0;

// The number of heads.
int h, h_k;
int h = 0;
int h_k = 0;
// In the case of multi-query and grouped-query attention (MQA/GQA), nheads_k could be
// different from nheads (query).
int h_h_k_ratio; // precompute h / h_k,
int h_h_k_ratio = 0; // precompute h / h_k,
};

////////////////////////////////////////////////////////////////////////////////////////////////////

struct Flash_fwd_params : public Qkv_params {
// The O matrix (output).
void* __restrict__ o_ptr;
void* __restrict__ oaccum_ptr;
void* __restrict__ o_ptr = nullptr;
void* __restrict__ oaccum_ptr = nullptr;

// The stride between rows of O.
index_t o_batch_stride;
index_t o_row_stride;
index_t o_head_stride;
index_t o_batch_stride = 0;
index_t o_row_stride = 0;
index_t o_head_stride = 0;

// The pointer to the P matrix.
void* __restrict__ p_ptr;
void* __restrict__ p_ptr = nullptr;

// The pointer to the softmax sum.
void* __restrict__ softmax_lse_ptr;
void* __restrict__ softmax_lseaccum_ptr;
void* __restrict__ softmax_lse_ptr = nullptr;
void* __restrict__ softmax_lseaccum_ptr = nullptr;

// The dimensions.
int b, seqlen_q, seqlen_k, seqlen_knew, d, seqlen_q_rounded, seqlen_k_rounded, d_rounded;
int b = 0;
int seqlen_q = 0;
int seqlen_k = 0;
int seqlen_knew = 0;
int d = 0;
int seqlen_q_rounded = 0;
int seqlen_k_rounded = 0;
int d_rounded = 0;

// The scaling factors for the kernel.
float scale_softmax;
float scale_softmax_log2;
float scale_softmax = 0.0;
float scale_softmax_log2 = 0.0;

// array of length b+1 holding starting offset of each sequence.
int* __restrict__ cu_seqlens_q;
int* __restrict__ cu_seqlens_k;
int* __restrict__ cu_seqlens_q = nullptr;
int* __restrict__ cu_seqlens_k = nullptr;

int* __restrict__ blockmask;
int* __restrict__ blockmask = nullptr;

// The K_new and V_new matrices.
void* __restrict__ knew_ptr;
void* __restrict__ vnew_ptr;
void* __restrict__ knew_ptr = nullptr;
void* __restrict__ vnew_ptr = nullptr;

// The stride between rows of the Q, K and V matrices.
index_t knew_batch_stride;
index_t vnew_batch_stride;
index_t knew_row_stride;
index_t vnew_row_stride;
index_t knew_head_stride;
index_t vnew_head_stride;
index_t knew_batch_stride = 0;
index_t vnew_batch_stride = 0;
index_t knew_row_stride = 0;
index_t vnew_row_stride = 0;
index_t knew_head_stride = 0;
index_t vnew_head_stride = 0;

bool is_bf16 = false;
bool is_causal;
bool is_causal = false;

// If is_seqlens_k_cumulative, then seqlen_k is cu_seqlens_k[bidb + 1] - cu_seqlens_k[bidb].
// Otherwise it's cu_seqlens_k[bidb], i.e., we use cu_seqlens_k to store the sequence lengths of K.
bool is_seqlens_k_cumulative;
int num_splits; // For split-KV version
bool is_seqlens_k_cumulative = true;
int num_splits = 0; // For split-KV version

const cudaDeviceProp* dprops;
const cudaDeviceProp* dprops = nullptr;
};

////////////////////////////////////////////////////////////////////////////////////////////////////
Expand Down
12 changes: 8 additions & 4 deletions onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -215,7 +215,6 @@ Status mha_fwd(const cudaDeviceProp& dprops,
const int seqlen_k_rounded = round_multiple(seqlen_k, 128);

Flash_fwd_params params;
params.dprops = &dprops;
set_params_fprop(params,
batch_size,
seqlen_q, seqlen_k,
Expand All @@ -230,7 +229,7 @@ Status mha_fwd(const cudaDeviceProp& dprops,
softmax_scale,
is_causal,
kv_bsnh);

params.dprops = &dprops;
params.knew_ptr = nullptr;
params.vnew_ptr = nullptr;
params.knew_batch_stride = 0;
Expand Down Expand Up @@ -276,7 +275,6 @@ Status mha_varlen_fwd(const cudaDeviceProp& dprops,
const int seqlen_k_rounded = round_multiple(max_seqlen_k, 128);

Flash_fwd_params params;
params.dprops = &dprops;
set_params_fprop(params,
batch_size,
max_seqlen_q, max_seqlen_k,
Expand All @@ -290,6 +288,12 @@ Status mha_varlen_fwd(const cudaDeviceProp& dprops,
softmax_lse,
softmax_scale,
is_causal);
params.dprops = &dprops;
params.num_splits = 0;
params.softmax_lseaccum_ptr = nullptr;
params.oaccum_ptr = nullptr;
params.knew_ptr = nullptr;
params.vnew_ptr = nullptr;
run_mha_fwd(params, stream);
return Status::OK();
}
Expand Down Expand Up @@ -336,7 +340,6 @@ Status mha_fwd_kvcache(const cudaDeviceProp& dprops,
const int seqlen_k_rounded = round_multiple(seqlen_k, 128);

Flash_fwd_params params;
params.dprops = &dprops;
set_params_fprop(params,
batch_size,
seqlen_q, seqlen_k,
Expand All @@ -351,6 +354,7 @@ Status mha_fwd_kvcache(const cudaDeviceProp& dprops,
softmax_scale,
is_causal,
past_bsnh);
params.dprops = &dprops;

if (k != nullptr && v != nullptr) {
params.seqlen_knew = seqlen_k_new;
Expand Down

0 comments on commit 8ecd9a5

Please sign in to comment.