Refactor Attention cuda kernel (#17578)
* Break QkvToContext into small functions, so that each fused and unfused kernel has its own function (an illustrative sketch of the resulting structure follows below).
* Move the DecoderAttention kernel to a separate file.
* Move the KV cache related kernels to attention_kv_cache.cu.

### Motivation and Context
To make the code easier to maintain.
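
For illustration only (not code from this commit): after the split, QkvToContext reads as a thin dispatcher over small per-kernel helpers. The names, flags, and selection order below are placeholders, not the actual ONNX Runtime identifiers.

```cpp
// Hedged sketch of the dispatch structure described above; all identifiers are illustrative.
#include <cstdio>

struct ToyAttentionData {
  bool use_flash_attention = false;             // placeholder flag for a fused path
  bool use_memory_efficient_attention = false;  // placeholder flag for a fused path
  bool use_fused_runner = false;                // placeholder flag for a fused path
};

// Each kernel family gets its own small function instead of one monolithic body.
int RunFlashAttention(const ToyAttentionData&)     { std::puts("flash");            return 0; }
int RunEfficientAttention(const ToyAttentionData&) { std::puts("memory efficient"); return 0; }
int RunTrtFusedAttention(const ToyAttentionData&)  { std::puts("trt fused");        return 0; }
int RunUnfusedAttention(const ToyAttentionData&)   { std::puts("unfused");          return 0; }

// The top-level entry point prepares QKV, updates the KV cache, then dispatches.
int ToyQkvToContext(const ToyAttentionData& data) {
  // PrepareQkv(...) and ConcatPastToPresent(...) would run here in the real code.
  if (data.use_flash_attention)            return RunFlashAttention(data);
  if (data.use_memory_efficient_attention) return RunEfficientAttention(data);
  if (data.use_fused_runner)               return RunTrtFusedAttention(data);
  return RunUnfusedAttention(data);
}

int main() {
  ToyAttentionData data;
  data.use_fused_runner = true;
  return ToyQkvToContext(data);
}
```

The real functions take contrib::AttentionParameters and AttentionData<T>, as the header diff below shows; the booleans here only stand in for whatever selection logic the real dispatcher uses.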
tianleiwu authored Sep 19, 2023
1 parent 068300d commit 730fab3
Showing 12 changed files with 904 additions and 728 deletions.
2 changes: 2 additions & 0 deletions cmake/onnxruntime_rocm_hipify.cmake
@@ -11,6 +11,8 @@ set(contrib_ops_excluded_files
"bert/attention_softmax.h"
"bert/attention_softmax.cu"
"bert/attention_prepare_qkv.cu"
"bert/decoder_attention_impl.h"
"bert/decoder_attention_impl.cu"
"bert/decoder_masked_multihead_attention.h"
"bert/decoder_masked_multihead_attention.cc"
"bert/decoder_masked_self_attention.h"
931 changes: 332 additions & 599 deletions onnxruntime/contrib_ops/cuda/bert/attention_impl.cu

Large diffs are not rendered by default.

56 changes: 19 additions & 37 deletions onnxruntime/contrib_ops/cuda/bert/attention_impl.h
@@ -81,24 +81,20 @@ struct AttentionData {

mutable CumulatedSequenceLengthCache* cumulated_sequence_length_q_cache = nullptr;
mutable CumulatedSequenceLengthCache* cumulated_sequence_length_kv_cache = nullptr;
};

// Intermediate data pointers available after PrepareQKV
template <typename T>
struct QkvData {
// Intermediate data
T* q = nullptr;
T* k = nullptr;
T* v = nullptr;
T* after_v = nullptr; // pointer right after v
AttentionQkvFormat format = AttentionQkvFormat::Q_K_V_BSNH;
T* scratch = nullptr;
AttentionQkvFormat qkv_format = AttentionQkvFormat::Q_K_V_BSNH;
};

template <typename T>
Status PrepareQkv(contrib::AttentionParameters& parameters,
AttentionData<T>& data,
cudaStream_t stream,
int max_threads_per_block,
QkvData<T>& qkv_data);
int max_threads_per_block);

template <typename T>
Status QkvToContext(
@@ -108,33 +104,6 @@ Status QkvToContext(
contrib::AttentionParameters& parameters,
AttentionData<T>& data);

Status LaunchDecoderAttentionKernel(
const cudaDeviceProp& prop, // Device Properties
Stream* stream, // ORT Stream
cublasHandle_t& cublas, // Cublas handle
const size_t element_size, // Element size of input tensor
const int batch_size, // Batch size (B)
const int sequence_length, // Sequence length (S)
const int kv_sequence_length, // Key/Value/Cache sequence length
const int num_heads, // Number of attention heads (N)
const int head_size, // Hidden size per head (H)
const bool static_kv, // Whether cross attention or not
const bool use_past, // Whether use cache or not
const bool has_layer_state, // Whether output cache or not
const bool has_key_padding_mask, // Whether use key_padding_mask or not
const float mask_filter_value, // Mask filter value
const void* gemm_query_buffer, // Query buffer
const void* gemm_kv_buffer, // Key and value buffer
const bool* key_padding_mask, // Key padding mask
const void* key_cache, // Input key cache
const void* value_cache, // Input value cache
void* qkv_buffer, // Temporary buffer
void* workspace_buffer, // Temporary buffer
void* output, // Output tensor
void* new_key_cache, // New_key_cache tensor
void* new_value_cache // New_value_cache tensor
);

// BxNxSxH => BxSxNxH or SxBxNxH (reversed_bs is true)
Status LaunchTransCtx(cudaStream_t stream,
const int sequence_length, const int batch_size, const int head_size, const int num_heads,
@@ -184,14 +153,27 @@ Status ConcatPastToPresent(int batch_size, int num_heads, int qk_head_size, int
int sequence_length, int total_sequence_length, bool pass_past_in_kv,
cudaStream_t stream,
int max_threads_per_block,
AttentionData<T>& data,
QkvData<T>& qkv);
AttentionData<T>& data);

template <typename T>
Status LaunchAddBiasTransAppendKvToPresent(cudaStream_t stream,
const int max_sequence_length,
const int past_sequence_length,
const int sequence_length,
const int batch_size,
const int head_size,
const int num_heads,
const int max_threads_per_block,
const T* biases,
const T* qkv_buffer,
T* present);

template <typename T>
Status LaunchStridedCopy(cudaStream_t stream,
const T* in, int4 in_shape, longlong4 in_strides, // coord (b,n,s,h)
T* out, longlong4 out_strides, // coord (b,n,s,h)
int max_threads_per_block);

} // namespace cuda
} // namespace contrib
} // namespace onnxruntime
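
Editorial aside, not part of the diff: LaunchStridedCopy above takes a shape and per-dimension strides in (b, n, s, h) coordinates. A minimal CPU sketch of that addressing scheme, assuming the four components map to batch, head, sequence, and head size in that order (the exact mapping onto int4/longlong4 fields is an assumption):

```cpp
// CPU reference for a strided (B, N, S, H) copy; the field-to-dimension mapping is assumed.
#include <cassert>
#include <vector>

struct Shape4  { int b, n, s, h; };        // stands in for int4
struct Stride4 { long long b, n, s, h; };  // stands in for longlong4

void StridedCopyReference(const float* in, Shape4 shape, Stride4 in_strides,
                          float* out, Stride4 out_strides) {
  for (int b = 0; b < shape.b; ++b)
    for (int n = 0; n < shape.n; ++n)
      for (int s = 0; s < shape.s; ++s)
        for (int h = 0; h < shape.h; ++h) {
          long long src = b * in_strides.b + n * in_strides.n + s * in_strides.s + h * in_strides.h;
          long long dst = b * out_strides.b + n * out_strides.n + s * out_strides.s + h * out_strides.h;
          out[dst] = in[src];
        }
}

int main() {
  // Copy a (B, N, S, H) view out of a buffer whose sequence axis is padded to max_s.
  const int B = 1, N = 2, S = 3, H = 4, max_s = 8;
  std::vector<float> in(static_cast<size_t>(B) * N * max_s * H, 1.0f);
  std::vector<float> out(static_cast<size_t>(B) * N * S * H, 0.0f);
  Shape4 shape{B, N, S, H};
  Stride4 in_strides{static_cast<long long>(N) * max_s * H, static_cast<long long>(max_s) * H, H, 1};
  Stride4 out_strides{static_cast<long long>(N) * S * H, static_cast<long long>(S) * H, H, 1};
  StridedCopyReference(in.data(), shape, in_strides, out.data(), out_strides);
  assert(out.front() == 1.0f);
  return 0;
}
```

LaunchTransCtx and the other kernels declared above move data between such layouts (for example BxNxSxH to BxSxNxH) on the GPU; the loops here only spell out the index arithmetic.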
@@ -1,8 +1,9 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#include "core/providers/cuda/cuda_common.h"
#include "contrib_ops/cuda/bert/attention_impl.h"
#include "core/providers/cuda/cuda_common.h"
#include "core/providers/cuda/cu_inc/common.cuh"

using namespace onnxruntime::cuda;

@@ -244,48 +245,48 @@ Status LaunchConcatPastToPresent(cudaStream_t stream,
present);
}

#ifndef USE_ROCM // exclude from hipify
#ifndef USE_ROCM // exclude the following from hipify since they are not used in ROCM EP

template <typename T>
Status ConcatPastToPresent(int batch_size, int num_heads, int qk_head_size, int v_head_size,
int sequence_length, int total_sequence_length, bool pass_past_in_kv,
cudaStream_t stream,
int max_threads_per_block,
AttentionData<T>& data,
QkvData<T>& qkv) {
AttentionData<T>& data) {
// Concat past key value to present (2xBxNxTxH), where L is kv_sequence_length and T is total_sequence_length.
// past_k (BxNxPxH) + k (BxNxLxH) => present_k (BxNxTxH)
// past_v (BxNxPxH) + v (BxNxLxH) => present_v (BxNxTxH)
// When there is past state, the head sizes for Q/K and V shall be the same: H == H_v.

if (nullptr != data.present) {
assert(qkv.format == AttentionQkvFormat::Q_K_V_BNSH || qkv.format == AttentionQkvFormat::Q_K_V_BNSH_QKV_BS3NH);
assert(data.qkv_format == AttentionQkvFormat::Q_K_V_BNSH ||
data.qkv_format == AttentionQkvFormat::Q_K_V_BNSH_QKV_BS3NH);

ORT_RETURN_IF_ERROR(
LaunchConcatPastToPresent(
stream, total_sequence_length, sequence_length, batch_size, qk_head_size, num_heads,
max_threads_per_block, data.past, qkv.k, data.present));
max_threads_per_block, data.past, data.k, data.present));

// Update pointers to present_k and present_v.
qkv.k = data.present;
qkv.v = data.present + batch_size * num_heads * total_sequence_length * qk_head_size;
}

if (nullptr != data.past_key || nullptr != data.present_key) {
data.k = data.present;
data.v = data.present + batch_size * num_heads * total_sequence_length * qk_head_size;
} else if (nullptr != data.past_key || nullptr != data.present_key) {
if (nullptr != data.past_key && nullptr == data.present_key) {
qkv.k = const_cast<T*>(data.past_key);
qkv.v = const_cast<T*>(data.past_value);
data.k = const_cast<T*>(data.past_key);
data.v = const_cast<T*>(data.past_value);
} else if (nullptr == data.past_key && nullptr != data.present_key) {
if (qkv.format == AttentionQkvFormat::Q_K_V_BNSH) {
qkv.k = data.present_key;
qkv.v = data.present_value;
if (data.qkv_format == AttentionQkvFormat::Q_K_V_BNSH) {
data.k = data.present_key;
data.v = data.present_value;
} else {
assert(qkv.format == AttentionQkvFormat::Q_K_V_BSNH);
qkv.k = data.temp_k_workspace;
qkv.v = data.temp_v_workspace;
assert(data.qkv_format == AttentionQkvFormat::Q_K_V_BSNH);
data.k = data.temp_k_workspace;
data.v = data.temp_v_workspace;
}
} else if (pass_past_in_kv) {
// past_key and past_value are used directly as key and value in attention computations
qkv.k = const_cast<T*>(data.past_key);
qkv.v = const_cast<T*>(data.past_value);
data.k = const_cast<T*>(data.past_key);
data.v = const_cast<T*>(data.past_value);

// This path has a memory copy from past_key and past_value to present_key and present_value
// Avoid this path since the memory copy is unnecessary because past_key == present_key and
@@ -298,14 +299,14 @@ Status ConcatPastToPresent(int batch_size, int num_heads, int qk_head_size, int
ORT_RETURN_IF_ERROR(
LaunchConcatTensorToTensor(stream, total_sequence_length, sequence_length,
batch_size, qk_head_size, num_heads,
max_threads_per_block, 1, data.past_key, qkv.k, data.present_key));
max_threads_per_block, 1, data.past_key, data.k, data.present_key));
ORT_RETURN_IF_ERROR(
LaunchConcatTensorToTensor(stream, total_sequence_length, sequence_length,
batch_size, v_head_size, num_heads,
max_threads_per_block, 1, data.past_value, qkv.v, data.present_value));
max_threads_per_block, 1, data.past_value, data.v, data.present_value));
// Update pointers to present_k and present_v.
qkv.k = data.present_key;
qkv.v = data.present_value;
data.k = data.present_key;
data.v = data.present_value;
}
}

@@ -317,15 +318,147 @@ template Status ConcatPastToPresent<float>(int batch_size, int num_heads, int qk
int sequence_length, int total_sequence_length, bool pass_past_in_kv,
cudaStream_t stream,
int max_threads_per_block,
AttentionData<float>& data,
QkvData<float>& qkv);
AttentionData<float>& data);

template Status ConcatPastToPresent<half>(int batch_size, int num_heads, int qk_head_size, int v_head_size,
int sequence_length, int total_sequence_length, bool pass_past_in_kv,
cudaStream_t stream,
int max_threads_per_block,
AttentionData<half>& data,
QkvData<half>& qkv);
AttentionData<half>& data);

// ----------------------------------------------------------------------------------
// The kernels below are used when past and present share the same buffer
// ----------------------------------------------------------------------------------

template <typename T>
__global__ void AddBiasTransAppendKvToPresentSmall(
const T* qkv, const T* biases, T* present,
const int head_size, const int past_sequence_length, const int max_sequence_length) {
// Input: BxSxMxNxH (Format 1)
// Output: (2, B, N, [P..P+S) of MaxS, H),
// B is batch_size, S is sequence_length, M is number of matrices, N is num_heads, H is head_size
const int n = threadIdx.y;
const int s = blockIdx.x;
const int b = blockIdx.y;
const int N = blockDim.y;
const int S = gridDim.x;
const int B = gridDim.y;

constexpr int M = 3; // Matrix count in qkv
const int m = blockIdx.z + 1; // k = 1, v = 2

const int NH = N * head_size;
const int NHS = NH * S;

qkv += (n * head_size + (s * M + m) * NH + b * M * NHS);
if (biases) {
biases += (m * NH + n * head_size);
}

const int MsH = max_sequence_length * head_size;
const int NMsH = N * MsH;
const int BNMsH = B * NMsH;
present += ((past_sequence_length + s) * head_size + n * MsH + b * NMsH + (m - 1) * BNMsH);

for (int h = threadIdx.x; h < head_size; h += blockDim.x) {
T bias = (biases ? biases[h] : (T)0.0f);
present[h] = qkv[h] + bias;
}
}

template <typename T>
__global__ void AddBiasTransAppendKvToPresent(
const T* qkv, const T* biases, T* present,
const int head_size, const int past_sequence_length, const int max_sequence_length) {
// Input: BxSxMxNxH (Format 1)
// Output: (2, B, N, [P..P+S) of MaxS, H),
// B is batch_size, S is sequence_length, M is number of matrices, N is num_heads, H is head_size
const int n = blockIdx.x;
const int s = blockIdx.y;
const int b = (blockIdx.z >> 1);
const int N = gridDim.x;
const int S = gridDim.y;
const int B = (gridDim.z >> 1);

constexpr int M = 3; // Matrix count in qkv
const int m = (blockIdx.z & 0x1) + 1; // k = 1, v = 2

const int NH = N * head_size;
const int NHS = NH * S;

qkv += (n * head_size + (s * M + m) * NH + b * M * NHS);
if (biases) {
biases += (m * NH + n * head_size);
}

const int MsH = max_sequence_length * head_size;
const int NMsH = N * MsH;
const int BNMsH = B * NMsH;
present += ((past_sequence_length + s) * head_size + n * MsH + b * NMsH + (m - 1) * BNMsH);

for (int h = threadIdx.x; h < head_size; h += blockDim.x) {
T bias = (biases ? biases[h] : (T)0.0f);
present[h] = qkv[h] + bias;
}
}

// qkv buffer is a merged tensor of shape (B,S,3,N,H); k and v are the second and third of the 3 matrices.
// bias is of shape (3, NxH) or nullptr.
// Appends k and v to present of shape (2, B, N, [P..T) of max_sequence_length, H).
template <typename T>
Status LaunchAddBiasTransAppendKvToPresent(cudaStream_t stream,
const int max_sequence_length,
const int past_sequence_length,
const int sequence_length,
const int batch_size,
const int head_size,
const int num_heads,
const int max_threads_per_block,
const T* biases,
const T* qkv_buffer,
T* present) {
assert(head_size <= (1 << 30));

int64_t nh = (int64_t)head_size * num_heads;
if (nh <= max_threads_per_block) {
const dim3 grid(sequence_length, batch_size, 2); // 2 for k and v
const dim3 block(max_threads_per_block / num_heads, num_heads, 1);

AddBiasTransAppendKvToPresentSmall<T><<<grid, block, 0, stream>>>(
qkv_buffer, biases, present, head_size, past_sequence_length, max_sequence_length);
} else {
const dim3 grid(num_heads, sequence_length, batch_size * 2); // 2 for k and v
const dim3 block(std::min(head_size, max_threads_per_block), 1, 1);
AddBiasTransAppendKvToPresent<T><<<grid, block, 0, stream>>>(
qkv_buffer, biases, present, head_size, past_sequence_length, max_sequence_length);
}

return CUDA_CALL(cudaGetLastError());
}

template Status LaunchAddBiasTransAppendKvToPresent(cudaStream_t stream,
const int max_sequence_length,
const int total_sequence_length,
const int sequence_length,
const int batch_size,
const int head_size,
const int num_heads,
const int max_threads_per_block,
const float* bias,
const float* qkv_buffer,
float* present);

template Status LaunchAddBiasTransAppendKvToPresent(cudaStream_t stream,
const int max_sequence_length,
const int total_sequence_length,
const int sequence_length,
const int batch_size,
const int head_size,
const int num_heads,
const int max_threads_per_block,
const half* bias,
const half* qkv_buffer,
half* present);
#endif

} // namespace cuda
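
Editorial aside, not part of the commit: LaunchAddBiasTransAppendKvToPresent above chooses between the two kernels based on whether one thread block can cover all heads at once (num_heads * head_size <= max_threads_per_block). A host-side sketch of that selection rule with arbitrary example sizes:

```cpp
// Host-side sketch of the launch-configuration choice made by
// LaunchAddBiasTransAppendKvToPresent above; not part of the commit.
#include <algorithm>
#include <cstdio>
#include <cstdint>

struct Dim3 { unsigned x, y, z; };  // stands in for dim3 on the host

void ChooseLaunchConfig(int batch_size, int sequence_length, int num_heads, int head_size,
                        int max_threads_per_block) {
  const int64_t nh = static_cast<int64_t>(head_size) * num_heads;
  if (nh <= max_threads_per_block) {
    // "Small" kernel: one block per (s, b), k/v selected by blockIdx.z, heads on threadIdx.y.
    Dim3 grid{static_cast<unsigned>(sequence_length), static_cast<unsigned>(batch_size), 2u};
    Dim3 block{static_cast<unsigned>(max_threads_per_block / num_heads),
               static_cast<unsigned>(num_heads), 1u};
    std::printf("small kernel: grid=(%u,%u,%u) block=(%u,%u,%u)\n",
                grid.x, grid.y, grid.z, block.x, block.y, block.z);
  } else {
    // General kernel: one block per (n, s, b*2), head_size covered by a 1-D block.
    Dim3 grid{static_cast<unsigned>(num_heads), static_cast<unsigned>(sequence_length),
              static_cast<unsigned>(batch_size) * 2u};
    Dim3 block{static_cast<unsigned>(std::min(head_size, max_threads_per_block)), 1u, 1u};
    std::printf("general kernel: grid=(%u,%u,%u) block=(%u,%u,%u)\n",
                grid.x, grid.y, grid.z, block.x, block.y, block.z);
  }
}

int main() {
  ChooseLaunchConfig(/*batch_size=*/4, /*sequence_length=*/128, /*num_heads=*/12, /*head_size=*/64,
                     /*max_threads_per_block=*/1024);  // 12 * 64 = 768 <= 1024 -> small kernel
  ChooseLaunchConfig(4, 128, 16, 128, 1024);           // 16 * 128 = 2048 > 1024 -> general kernel
  return 0;
}
```

In the small variant, heads map to threadIdx.y so one block handles every head of an (s, b) position and k/v are split over blockIdx.z; the general variant puts heads on blockIdx.x and covers head_size with a 1-D block.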