lint fix
aciddelgado committed Mar 22, 2024
1 parent 7908046 commit 284030a
Showing 1 changed file with 28 additions and 16 deletions.
44 changes: 28 additions & 16 deletions onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.cu
@@ -476,7 +476,8 @@ Status LaunchGetSeqlenBuff(contrib::GroupQueryAttentionParameters& parameters, i
 // Kernel to unpack qkv from packed qkv
 template <typename T>
 __global__ void UnpackQKV(const T* packed_qkv, T* unpacked_q, T* unpacked_k, T* unpacked_v, const int num_heads,
-                          const int kv_num_heads, const int head_size, const int sequence_length, const int batch_size) {
+                          const int kv_num_heads, const int head_size, const int sequence_length,
+                          const int batch_size) {
   const int tid = threadIdx.x + blockIdx.x * blockDim.x;
   int d = (num_heads + 2 * kv_num_heads) * head_size;
   const int qkv_size = batch_size * sequence_length * d;
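
The kernel body is elided at this point in the diff. For orientation, a minimal sketch of the routing it has to do, assuming the packed input layout is [batch, seq, (num_heads + 2 * kv_num_heads) * head_size] with the Q, K, V slices in that order and BSNH outputs; UnpackQKVSketch and its exact slicing are illustrative guesses, not the verbatim kernel:

    template <typename T>
    __global__ void UnpackQKVSketch(const T* packed_qkv, T* unpacked_q, T* unpacked_k, T* unpacked_v,
                                    const int num_heads, const int kv_num_heads, const int head_size,
                                    const int sequence_length, const int batch_size) {
      const int tid = threadIdx.x + blockIdx.x * blockDim.x;
      const int d = (num_heads + 2 * kv_num_heads) * head_size;
      if (tid >= batch_size * sequence_length * d) return;
      const int bs = tid / d;  // flattened (batch, seq) token index
      int offset = tid % d;    // position inside one packed token
      const T val = packed_qkv[tid];
      if (offset < num_heads * head_size) {                          // query slice
        unpacked_q[bs * num_heads * head_size + offset] = val;
      } else if (offset < (num_heads + kv_num_heads) * head_size) {  // key slice
        offset -= num_heads * head_size;
        unpacked_k[bs * kv_num_heads * head_size + offset] = val;
      } else {                                                       // value slice
        offset -= (num_heads + kv_num_heads) * head_size;
        unpacked_v[bs * kv_num_heads * head_size + offset] = val;
      }
    }
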
@@ -506,12 +507,14 @@ Status LaunchUnpackQKV(const T* packed_qkv, T* unpacked_q, T* unpacked_k, T* unp
                        cudaStream_t stream, const int max_threads_per_block) {
   const int threads = max_threads_per_block;
   const int blocks = (batch_size * sequence_length * (num_heads + 2 * kv_num_heads) * head_size + threads - 1) / threads;

[GitHub Actions / Lint C++ annotation (reviewdog, cpplint) on line 509: Lines should be <= 120 characters long [whitespace/line_length] [2]]

-  UnpackQKV<<<blocks, threads, 0, stream>>>(packed_qkv, unpacked_q, unpacked_k, unpacked_v, num_heads, kv_num_heads, head_size, sequence_length, batch_size);
+  UnpackQKV<<<blocks, threads, 0, stream>>>(packed_qkv, unpacked_q, unpacked_k, unpacked_v, num_heads, kv_num_heads,
+                                            head_size, sequence_length, batch_size);
   return CUDA_CALL(cudaGetLastError());
 }
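
The grid size above uses the standard ceiling-division idiom so the launched thread count covers every packed element, with at most one partially idle tail block. A quick self-contained check with invented sizes (batch 2, seqlen 8, 8 query heads, 4 KV heads, head size 64):

    #include <cstdio>

    // Integer ceiling division, as in the `blocks` computation above.
    static int CeilDiv(int total, int threads) {
      return (total + threads - 1) / threads;
    }

    int main() {
      const int total = 2 * 8 * (8 + 2 * 4) * 64;  // 16384 packed elements
      printf("%d\n", CeilDiv(total, 256));         // 64 blocks: exact fit
      printf("%d\n", CeilDiv(total + 1, 256));     // 65 blocks: tail block mostly idle
      return 0;
    }
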

 // Kernel to convert seqlens_k to position_ids
-__global__ void SeqlensToPosIdsPrompt(int32_t* seqlens_k, int64_t* position_ids, const int seqlen, const int batch_size) {
+__global__ void SeqlensToPosIdsPrompt(int32_t* seqlens_k, int64_t* position_ids, const int seqlen,
+                                      const int batch_size) {
   int tid = blockDim.x * blockIdx.x + threadIdx.x;
   int b = tid / seqlen;
   int s = tid % seqlen;
@@ -533,8 +536,8 @@ __global__ void SeqlensToPosIdsToken(int32_t* seqlens_k, int64_t* position_ids,
 }
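
Both position-id kernels appear here only as signatures and closing braces. For context, seqlens_k holds, per batch entry, the index of the last valid token. A hedged reconstruction of what the two elided bodies plausibly compute (the padding handling is an assumption):

    // Prompt pass: position ids 0..seqlens_k[b] for valid tokens, a dummy id for padding.
    __global__ void SeqlensToPosIdsPromptSketch(const int32_t* seqlens_k, int64_t* position_ids,
                                                const int seqlen, const int batch_size) {
      int tid = blockDim.x * blockIdx.x + threadIdx.x;
      int b = tid / seqlen;
      int s = tid % seqlen;
      if (b < batch_size) {
        position_ids[tid] = (s < seqlens_k[b] + 1) ? s : 1;  // 1 as a dummy padding id
      }
    }

    // Token (decode) pass: the single new token's position is the current cache length.
    __global__ void SeqlensToPosIdsTokenSketch(const int32_t* seqlens_k, int64_t* position_ids,
                                               const int batch_size) {
      int tid = blockDim.x * blockIdx.x + threadIdx.x;
      if (tid < batch_size) {
        position_ids[tid] = seqlens_k[tid];
      }
    }
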

 // Convert seqlens_k to position_ids
-Status LaunchSeqlensToPosIds(contrib::GroupQueryAttentionParameters& parameters, int32_t* seqlens_k, int64_t* position_ids,
-                             cudaStream_t stream, const int max_threads_per_block) {
+Status LaunchSeqlensToPosIds(contrib::GroupQueryAttentionParameters& parameters, int32_t* seqlens_k,
+                             int64_t* position_ids, cudaStream_t stream, const int max_threads_per_block) {
   const int seqlen = parameters.sequence_length;
   const int batch_size = parameters.batch_size;
   const int threads = max_threads_per_block;
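
The launcher body is truncated after these three locals; presumably it picks the prompt or token kernel and sizes the grid with the same ceil-div idiom seen above. A sketch under that assumption (an is_prompt field on the parameters struct is a guess):

    // Hypothetical dispatch between the two sketch kernels above.
    if (parameters.is_prompt) {
      const int blocks = (batch_size * seqlen + threads - 1) / threads;
      SeqlensToPosIdsPromptSketch<<<blocks, threads, 0, stream>>>(seqlens_k, position_ids, seqlen, batch_size);
    } else {
      const int blocks = (batch_size + threads - 1) / threads;
      SeqlensToPosIdsTokenSketch<<<blocks, threads, 0, stream>>>(seqlens_k, position_ids, batch_size);
    }
    return CUDA_CALL(cudaGetLastError());
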
@@ -596,7 +599,8 @@ Status FlashAttention(
       seqlens_k = data.seqlens_k_total;
     }
   } else if (!parameters.kv_share_buffer) {  // copy past kv to present kv
-    ORT_RETURN_IF_ERROR(LaunchConcatNewToPastKV(parameters, data, nullptr, nullptr, stream, max_threads_per_block, true));
+    ORT_RETURN_IF_ERROR(LaunchConcatNewToPastKV(parameters, data, nullptr, nullptr, stream, max_threads_per_block,
+                                                true));
   }

   void* present_key = reinterpret_cast<void*>(const_cast<T*>(data.present_key));
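
On this non-shared-buffer path, the past KV cache must be copied into the separately allocated present buffer before the new tokens are appended; that is what LaunchConcatNewToPastKV is responsible for. A sketch of the copy semantics only, with BNSH layout and the argument meanings assumed rather than taken from this diff:

    // One thread per element of the present buffer; reads from past or new KV.
    template <typename T>
    __global__ void ConcatNewToPastKVSketch(const T* past_kv, const T* new_kv, T* present_kv,
                                            const int past_seqlen, const int new_seqlen,
                                            const int head_size, const int num_elements) {
      const int tid = threadIdx.x + blockIdx.x * blockDim.x;
      if (tid >= num_elements) return;
      const int present_seqlen = past_seqlen + new_seqlen;
      const int h = tid % head_size;      // channel within one head vector
      const int i = tid / head_size;      // flattened (batch * head * seq) index
      const int s = i % present_seqlen;   // sequence slot in the present buffer
      const int bn = i / present_seqlen;  // flattened (batch * head) index
      present_kv[tid] = (s < past_seqlen)
                            ? past_kv[(bn * past_seqlen + s) * head_size + h]
                            : new_kv[(bn * new_seqlen + (s - past_seqlen)) * head_size + h];
    }
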
@@ -656,8 +660,8 @@ Status EfficientAttention(
     auto q = reinterpret_cast<T*>(data.unpacked_qkv_buffer);
     auto k = reinterpret_cast<T*>(data.unpacked_qkv_buffer + q_size);
     auto v = reinterpret_cast<T*>(data.unpacked_qkv_buffer + q_size + k_size);
-    ORT_RETURN_IF_ERROR(LaunchUnpackQKV(reinterpret_cast<const T*>(data.query), q, k, v, num_heads, kv_num_heads, head_size,
-                                        sequence_length, batch_size, stream, max_threads_per_block));
+    ORT_RETURN_IF_ERROR(LaunchUnpackQKV(reinterpret_cast<const T*>(data.query), q, k, v, num_heads, kv_num_heads,
+                                        head_size, sequence_length, batch_size, stream, max_threads_per_block));
     query = reinterpret_cast<const void*>(q);
     key = reinterpret_cast<const void*>(k);
     value = reinterpret_cast<const void*>(v);
@@ -669,18 +673,25 @@
     auto q_buffer = reinterpret_cast<T*>(data.rotary_buffer);
     auto k_buffer = q_buffer + q_size;
     auto position_ids_buff = reinterpret_cast<int64_t*>(k_buffer + k_size);
-    ORT_RETURN_IF_ERROR(LaunchSeqlensToPosIds(parameters, data.seqlens_k, position_ids_buff, stream, max_threads_per_block));
+    ORT_RETURN_IF_ERROR(LaunchSeqlensToPosIds(parameters, data.seqlens_k, position_ids_buff, stream,
+                                              max_threads_per_block));
     DUMP_TENSOR_INIT();
     DUMP_TENSOR("position_ids", position_ids_buff, batch_size, sequence_length);
     // Launch rotary embedding kernel
     ORT_RETURN_IF_ERROR(LaunchRotaryEmbeddingKernel<T>(stream, q_buffer, reinterpret_cast<const T*>(query),
-        position_ids_buff, data.cos_cache, data.sin_cache, parameters.batch_size, parameters.sequence_length,
-        parameters.num_heads, parameters.head_size, parameters.rotary_dim, parameters.seqlen_present_kv_cache,
-        /*position_ids_format*/ 1, parameters.rotary_interleaved, device_prop.maxThreadsPerBlock, /*transposed*/ false));
+                                                       position_ids_buff, data.cos_cache, data.sin_cache,
+                                                       parameters.batch_size, parameters.sequence_length,
+                                                       parameters.num_heads, parameters.head_size,
+                                                       parameters.rotary_dim, parameters.seqlen_present_kv_cache,
+                                                       /*position_ids_format*/ 1, parameters.rotary_interleaved,
+                                                       device_prop.maxThreadsPerBlock, /*transposed*/ false));
     ORT_RETURN_IF_ERROR(LaunchRotaryEmbeddingKernel<T>(stream, k_buffer, reinterpret_cast<const T*>(key),
-        position_ids_buff, data.cos_cache, data.sin_cache, parameters.batch_size, parameters.sequence_length,
-        parameters.kv_num_heads, parameters.head_size, parameters.rotary_dim, parameters.seqlen_present_kv_cache,
-        /*position_ids_format*/ 1, parameters.rotary_interleaved, device_prop.maxThreadsPerBlock, /*transposed*/ false));
+                                                       position_ids_buff, data.cos_cache, data.sin_cache,
+                                                       parameters.batch_size, parameters.sequence_length,
+                                                       parameters.kv_num_heads, parameters.head_size,
+                                                       parameters.rotary_dim, parameters.seqlen_present_kv_cache,
+                                                       /*position_ids_format*/ 1, parameters.rotary_interleaved,
+                                                       device_prop.maxThreadsPerBlock, /*transposed*/ false));
     query = reinterpret_cast<const void*>(q_buffer);
     key = reinterpret_cast<const void*>(k_buffer);
   }
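
For reference, the two launches above apply the usual rotary transform: each channel pair inside rotary_dim is rotated by a position-dependent angle, with rotary_interleaved selecting whether pairs are adjacent channels or the two halves of rotary_dim. A per-pair sketch of that math (the [position, rotary_dim / 2] cache layout is an assumption, and this is not the kernel behind LaunchRotaryEmbeddingKernel):

    // Rotate one (x[i1], x[i2]) channel pair by the cached angle for this position.
    template <typename T>
    __device__ void ApplyRotaryPairSketch(T* x, const T* cos_cache, const T* sin_cache,
                                          const int64_t pos, const int j, const int rotary_dim,
                                          const bool interleaved) {
      const int half = rotary_dim / 2;
      const int i1 = interleaved ? 2 * j : j;             // interleaved: (2j, 2j+1)
      const int i2 = interleaved ? 2 * j + 1 : j + half;  // half-split: (j, j+half)
      const T c = cos_cache[pos * half + j];
      const T s = sin_cache[pos * half + j];
      const T x1 = x[i1];
      const T x2 = x[i2];
      x[i1] = x1 * c - x2 * s;
      x[i2] = x1 * s + x2 * c;
    }
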
@@ -689,7 +700,8 @@ Status EfficientAttention(
     // Launch kernel to copy seqlen
     constexpr int thr_per_blk = 256;
     int blk_in_grid = (batch_size + thr_per_blk - 1) / thr_per_blk;
-    repeat_seqlen<<<blk_in_grid, thr_per_blk, 0, stream>>>(data.seqlens_k_total, parameters.sequence_length, batch_size);
+    repeat_seqlen<<<blk_in_grid, thr_per_blk, 0, stream>>>(data.seqlens_k_total, parameters.sequence_length,
+                                                           batch_size);
   } else {
     ORT_RETURN_IF_ERROR(LaunchGetSeqlenBuff(parameters, data.seqlens_k, data.seqlens_k_total, true, stream, 256));
   }
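
repeat_seqlen is not defined in this diff; from the call site it simply broadcasts the uniform sequence length to every batch entry of seqlens_k_total, i.e. something like the following guess from the signature:

    __global__ void RepeatSeqlenSketch(int32_t* seqlens_k_total, const int32_t seqlen, const int batch_size) {
      const int id = blockDim.x * blockIdx.x + threadIdx.x;
      if (id < batch_size) seqlens_k_total[id] = seqlen;
    }
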
