Commit 7d612f7: refactor

Bob-Chen222 committed Nov 8, 2024 (parent 3b34a5b)

Showing 5 changed files with 9 additions and 10 deletions.
4 changes: 2 additions & 2 deletions (kernel header; filename not captured in this page extract)

```diff
@@ -113,9 +113,9 @@ void update_qkv_in_batch(IncMultiHeadSelfAttentionMeta const *m,
                          cudaStream_t stream);
 
 template <typename DT>
-void update_qkv_in_batch_verify(IncMultiHeadSelfAttentionMeta const *m,
+void update_qkv_in_batch(IncMultiHeadSelfAttentionMeta const *m,
                          BatchConfig const *bc,
-                         cudaStream_t stream, bool is_spec = true);
+                         cudaStream_t stream, bool is_spec);
 
 // [For the tokens in streaming cache]
 // Convert the out-of-order cache to in-order relative position.
```
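With the `= true` default removed, the former `update_qkv_in_batch_verify` becomes an overload of `update_qkv_in_batch` whose mode every caller must spell out, so neither path can silently fall into the speculative branch. The two call sites, as they read after this commit (both taken from the diffs below):

```cpp
// Incremental decoding (src/ops/inc_multihead_self_attention.cu):
// uses the plain incremental attention metadata.
update_qkv_in_batch<DT>(m, bc, stream, false);

// Tree verification (src/ops/tree_inc_multihead_self_attention.cu):
// uses the tree-verify attention metadata.
update_qkv_in_batch<DT>(m, bc, stream, true);
```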
2 changes: 1 addition & 1 deletion src/ops/inc_multihead_self_attention.cu

```diff
@@ -275,7 +275,7 @@ void inference_kernel(IncMultiHeadSelfAttentionMeta *m,
     apply_pos_encoding_to_tokens_in_batch(
         m, bc, static_cast<DT *>(m->devQKVProjArray), stream);
     // Move the batched qkv values to where the attention kernel reads them
-    update_qkv_in_batch_verify<DT>(m, bc, stream, false);
+    update_qkv_in_batch<DT>(m, bc, stream, false);
   }
 
   // phase 4: Attention computation
```
10 changes: 5 additions & 5 deletions src/ops/kernels/inc_multihead_self_attention_kernels.cu

```diff
@@ -514,7 +514,7 @@ void update_qkv_in_batch(IncMultiHeadSelfAttentionMeta const *m,
 }
 
 template <typename DT>
-__global__ void update_qkv_in_batch_verify_kernel(
+__global__ void update_qkv_in_batch_paged_kernel(
     DT *qkv_proj_array,
     half *qTmp_ptr,
    half *kvCache_ptr,
@@ -580,7 +580,7 @@ __global__ void update_qkv_in_batch_verify_kernel(
 }
 
 template <typename DT>
-void update_qkv_in_batch_verify(IncMultiHeadSelfAttentionMeta const *m,
+void update_qkv_in_batch(IncMultiHeadSelfAttentionMeta const *m,
                          BatchConfig const *bc,
                          cudaStream_t stream, bool is_spec) {
   // printf("entered update_qkv_in_batch_verify\n");
@@ -593,7 +593,7 @@ void update_qkv_in_batch_verify(IncMultiHeadSelfAttentionMeta const *m,
                  : m->handle.incr_attention_metadata->kv_indptr;
   int32_t *kv_indices = is_spec ? m->handle.tree_verify_attention_metadata->kv_indices
                                 : m->handle.incr_attention_metadata->kv_indices;
-  update_qkv_in_batch_verify_kernel<<<GET_BLOCKS(parallelism),
+  update_qkv_in_batch_paged_kernel<<<GET_BLOCKS(parallelism),
                                      min(CUDA_NUM_THREADS, parallelism),
                                      0,
                                      stream>>>(
@@ -1040,12 +1040,12 @@ template void Kernels::IncMultiHeadAttention::update_qkv_in_batch<half>(
     BatchConfig const *bc,
     cudaStream_t stream);
 
-template void Kernels::IncMultiHeadAttention::update_qkv_in_batch_verify<float>(
+template void Kernels::IncMultiHeadAttention::update_qkv_in_batch<float>(
     IncMultiHeadSelfAttentionMeta const *m,
     BatchConfig const *bc,
     cudaStream_t stream, bool is_spec);
 
-template void Kernels::IncMultiHeadAttention::update_qkv_in_batch_verify<half>(
+template void Kernels::IncMultiHeadAttention::update_qkv_in_batch<half>(
     IncMultiHeadSelfAttentionMeta const *m,
     BatchConfig const *bc,
     cudaStream_t stream, bool is_spec);
```
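The `paged` in the renamed kernel matches how the launcher drives it: `is_spec` picks either the tree-verify or the incremental metadata's `kv_indptr`/`kv_indices` page tables, and the kernel writes each token's keys and values into whatever physical page those tables name. The kernel body is outside this diff, so the following is only an illustrative sketch of that indptr/indices translation under the usual paged-KV layout; every identifier, the `half` cache type, and the stride arithmetic are assumptions, not the repository's actual code.

```cuda
#include <cuda_fp16.h>

// Sketch of a paged KV-cache write (hypothetical; not the repo's kernel).
// kv_indices lists physical page ids; kv_indptr marks where each request's
// run of pages begins, so logical position -> page is a two-step lookup.
__global__ void paged_kv_write_sketch(half const *kv_src,    // packed K/V of new tokens
                                      half *kv_cache,        // paged cache storage
                                      int const *kv_indptr,  // per-request page-run start
                                      int const *kv_indices, // physical page ids
                                      int const *token_req,  // request id per new token
                                      int const *token_pos,  // absolute position per token
                                      int page_size,
                                      int token_stride, // elements per token (K + V)
                                      int num_tokens) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i >= num_tokens * token_stride) {
    return;
  }
  int tok = i / token_stride;
  int off = i % token_stride;
  // Two-step lookup: logical page index within the request, then the
  // physical page id recorded for it in the page table.
  int pos = token_pos[tok];
  int page = kv_indices[kv_indptr[token_req[tok]] + pos / page_size];
  int slot = pos % page_size;
  kv_cache[((long)page * page_size + slot) * token_stride + off] = kv_src[i];
}
```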
2 changes: 1 addition & 1 deletion src/ops/tree_inc_multihead_self_attention.cu

```diff
@@ -433,7 +433,7 @@ void inference_kernel(TreeIncMultiHeadSelfAttentionMeta *m,
   // cudaEventRecord(t_start, stream);
 
   // Update key-val cache, compact q array
-  update_qkv_in_batch_verify<DT>(m, bc, stream);
+  update_qkv_in_batch<DT>(m, bc, stream, true);
 
   // cudaEventRecord(t_end, stream);
   // checkCUDA(cudaEventSynchronize(t_end));
```
1 change: 0 additions & 1 deletion src/runtime/request_manager.cu

```diff
@@ -460,7 +460,6 @@ void RequestManager::load_batch_config_task(
       round_up_pages(BatchConfig::max_sequence_length() +
                      BatchConfig::max_spec_tree_token_num());
 
-  // int parallelism = batch_size;
   prepare_inference_params_kernel_h(batch_config,
                                     pm,
                                     handle.incr_attention_metadata,
```
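`round_up_pages` itself is not shown in this diff; assuming it has the conventional ceiling-division meaning (how many whole KV-cache pages are needed to hold that many tokens), it reduces to a sketch like the one below, with `kPagesize` standing in for whatever page-size constant the codebase actually uses.

```cpp
// Hypothetical sketch: the name kPagesize and its value are assumptions.
constexpr int kPagesize = 64; // assumed KV-cache page size

constexpr int round_up_pages(int num_tokens) {
  // Ceiling division: a partially filled page still occupies a whole page.
  return (num_tokens + kPagesize - 1) / kPagesize;
}
```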
