diff --git a/include/flexflow/ops/kernels/inc_multihead_self_attention_kernels.h b/include/flexflow/ops/kernels/inc_multihead_self_attention_kernels.h
index 9ca48f69f6..2c73da124a 100644
--- a/include/flexflow/ops/kernels/inc_multihead_self_attention_kernels.h
+++ b/include/flexflow/ops/kernels/inc_multihead_self_attention_kernels.h
@@ -113,9 +113,9 @@ void update_qkv_in_batch(IncMultiHeadSelfAttentionMeta const *m,
                          cudaStream_t stream);
 
 template <typename DT>
-void update_qkv_in_batch_verify(IncMultiHeadSelfAttentionMeta const *m,
+void update_qkv_in_batch(IncMultiHeadSelfAttentionMeta const *m,
                                 BatchConfig const *bc,
-                                cudaStream_t stream, bool is_spec = true);
+                                cudaStream_t stream, bool is_spec);
 
 // [For the tokens in streaming cache]
 // Convert the out-of-order cache to in-order relative position.
diff --git a/src/ops/inc_multihead_self_attention.cu b/src/ops/inc_multihead_self_attention.cu
index dfe0ad7ec6..dfa3e140e5 100644
--- a/src/ops/inc_multihead_self_attention.cu
+++ b/src/ops/inc_multihead_self_attention.cu
@@ -275,7 +275,7 @@ void inference_kernel(IncMultiHeadSelfAttentionMeta *m,
   apply_pos_encoding_to_tokens_in_batch(
       m, bc, static_cast<DT *>(m->devQKVProjArray), stream);
   // Move the batch qkv values to where took by attention
-  update_qkv_in_batch_verify<DT>(m, bc, stream, false);
+  update_qkv_in_batch<DT>(m, bc, stream, false);
 }
 
 // phase 4: Attention computation
diff --git a/src/ops/kernels/inc_multihead_self_attention_kernels.cu b/src/ops/kernels/inc_multihead_self_attention_kernels.cu
index dcd2bf8e98..63472bcb0c 100644
--- a/src/ops/kernels/inc_multihead_self_attention_kernels.cu
+++ b/src/ops/kernels/inc_multihead_self_attention_kernels.cu
@@ -514,7 +514,7 @@ void update_qkv_in_batch(IncMultiHeadSelfAttentionMeta const *m,
 }
 
 template <typename DT>
-__global__ void update_qkv_in_batch_verify_kernel(
+__global__ void update_qkv_in_batch_paged_kernel(
     DT *qkv_proj_array,
     half *qTmp_ptr,
     half *kvCache_ptr,
@@ -580,7 +580,7 @@ __global__ void update_qkv_in_batch_verify_kernel(
 }
 
 template <typename DT>
-void update_qkv_in_batch_verify(IncMultiHeadSelfAttentionMeta const *m,
+void update_qkv_in_batch(IncMultiHeadSelfAttentionMeta const *m,
                                 BatchConfig const *bc,
                                 cudaStream_t stream, bool is_spec) {
   // printf("entered update_qkv_in_batch_verify\n");
@@ -593,7 +593,7 @@ void update_qkv_in_batch_verify(IncMultiHeadSelfAttentionMeta const *m,
                          : m->handle.incr_attention_metadata->kv_indptr;
   int32_t *kv_indices = is_spec
                             ? m->handle.tree_verify_attention_metadata->kv_indices
                             : m->handle.incr_attention_metadata->kv_indices;
-  update_qkv_in_batch_verify_kernel<<<...>>>(
+  update_qkv_in_batch_paged_kernel<<<...>>>(
@@ -1040,12 +1040,12 @@ template void Kernels::IncMultiHeadAttention::update_qkv_in_batch(
     BatchConfig const *bc,
     cudaStream_t stream);
 
-template void Kernels::IncMultiHeadAttention::update_qkv_in_batch_verify<float>(
+template void Kernels::IncMultiHeadAttention::update_qkv_in_batch<float>(
     IncMultiHeadSelfAttentionMeta const *m,
     BatchConfig const *bc,
     cudaStream_t stream, bool is_spec);
-template void Kernels::IncMultiHeadAttention::update_qkv_in_batch_verify<half>(
+template void Kernels::IncMultiHeadAttention::update_qkv_in_batch<half>(
    IncMultiHeadSelfAttentionMeta const *m,
     BatchConfig const *bc,
     cudaStream_t stream, bool is_spec);
diff --git a/src/ops/tree_inc_multihead_self_attention.cu b/src/ops/tree_inc_multihead_self_attention.cu
index 905df573a7..6846c048a5 100644
--- a/src/ops/tree_inc_multihead_self_attention.cu
+++ b/src/ops/tree_inc_multihead_self_attention.cu
@@ -433,7 +433,7 @@ void inference_kernel(TreeIncMultiHeadSelfAttentionMeta *m,
   // cudaEventRecord(t_start, stream);
 
   // Update key-val cache, compact q array
-  update_qkv_in_batch_verify<DT>(m, bc, stream);
+  update_qkv_in_batch<DT>(m, bc, stream, true);
 
   // cudaEventRecord(t_end, stream);
   // checkCUDA(cudaEventSynchronize(t_end));
diff --git a/src/runtime/request_manager.cu b/src/runtime/request_manager.cu
index f6935dc0a5..ed44e8944a 100644
--- a/src/runtime/request_manager.cu
+++ b/src/runtime/request_manager.cu
@@ -460,7 +460,6 @@ void RequestManager::load_batch_config_task(
       round_up_pages(BatchConfig::max_sequence_length() +
                      BatchConfig::max_spec_tree_token_num());
 
-  // int parallelism = batch_size;
   prepare_inference_params_kernel_h(batch_config,
                                     pm,
                                     handle.incr_attention_metadata,
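
Note: with this patch the paged-KV update is an overload of update_qkv_in_batch rather than the separate update_qkv_in_batch_verify name, and the is_spec = true default is dropped, so every caller states is_spec explicitly. A minimal sketch of the resulting call pattern; run_qkv_update is a hypothetical wrapper that is not part of the patch, and the include/namespace lines assume the usual FlexFlow layout:

#include "flexflow/ops/kernels/inc_multihead_self_attention_kernels.h"

namespace FlexFlow {

// Hypothetical helper (not in the patch): both decode paths now call the same
// entry point and select the KV metadata through is_spec.
template <typename DT>
void run_qkv_update(IncMultiHeadSelfAttentionMeta const *m,
                    BatchConfig const *bc,
                    cudaStream_t stream,
                    bool tree_verify) {
  // is_spec == true  -> kv_indptr/kv_indices read from tree_verify_attention_metadata
  // is_spec == false -> kv_indptr/kv_indices read from incr_attention_metadata
  Kernels::IncMultiHeadAttention::update_qkv_in_batch<DT>(
      m, bc, stream, /*is_spec=*/tree_verify);
}

} // namespace FlexFlow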