workable incrdecoding!
Bob-Chen222 committed Nov 7, 2024
1 parent 8c203ec commit 3c158f8
Showing 2 changed files with 8 additions and 10 deletions.
2 changes: 0 additions & 2 deletions src/runtime/request_manager.cc
@@ -1311,8 +1311,6 @@ BatchConfig RequestManager::prepare_decoding_batch() {
   bc.requestsInfo[request_index].num_kv_pages = get_num_blocks_allocated(request);
   bc.requestsInfo[request_index].kv_last_page_len = get_len_last_block(request);
   bc.requestsInfo[request_index].request_guid = request.guid;
-  printf("Request %d, token %d, idx_to_physical %d\n", request.guid, request.tokens.back(), idx_to_physical);
-  printf("Request %d, num_kv_pages %d, kv_last_page_len %d\n", request.guid, bc.requestsInfo[request_index].num_kv_pages, bc.requestsInfo[request_index].kv_last_page_len);
 
   bc.num_tokens++;
 
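The two deleted lines are per-token printf tracing: they fire once for every decoded token, so under batched decoding they flood stdout on the decode hot path, and the commit simply drops them. If the trace is worth keeping, one common alternative (a sketch only, not part of this commit; RM_DEBUG_TRACE and RM_TRACE are invented names) is a compile-time switch:

    // Gate per-token tracing behind a build flag instead of deleting it.
    #ifdef RM_DEBUG_TRACE
    #define RM_TRACE(...) printf(__VA_ARGS__)  // tracing build: forward to printf
    #else
    #define RM_TRACE(...) ((void)0)            // release build: compiles away
    #endif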
16 changes: 8 additions & 8 deletions src/runtime/request_manager.cu
@@ -79,7 +79,7 @@ void RequestManager::load_tokens_task(
 
 void prepare_inference_params_kernel_h(BatchConfig const *batch_config,
                                        PageManager *pm,
-                                       FFHandler handle,
+                                       AttentionMetaData *attention_metadata,
                                        cudaStream_t stream,
                                        uint32_t const max_num_pages,
                                        int32_t *q_indptr_h,
@@ -130,28 +130,28 @@ void prepare_inference_params_kernel_h(BatchConfig const *batch_config,
   }
 
   // do the copy
-  checkCUDA(cudaMemcpyAsync(handle.tree_verify_attention_metadata->kv_indices,
+  checkCUDA(cudaMemcpyAsync(attention_metadata->kv_indices,
                             kv_indices_h,
                             sizeof(int32_t) * batch_size * max_num_pages,
                             cudaMemcpyHostToDevice,
                             stream));
   checkCUDA(
-      cudaMemcpyAsync(handle.tree_verify_attention_metadata->kv_last_page_len,
+      cudaMemcpyAsync(attention_metadata->kv_last_page_len,
                       kv_last_page_len_h,
                       sizeof(int32_t) * batch_size,
                       cudaMemcpyHostToDevice,
                       stream));
-  checkCUDA(cudaMemcpyAsync(handle.tree_verify_attention_metadata->q_indptr,
+  checkCUDA(cudaMemcpyAsync(attention_metadata->q_indptr,
                             q_indptr_h,
                             sizeof(int32_t) * (batch_size + 1),
                             cudaMemcpyHostToDevice,
                             stream));
-  checkCUDA(cudaMemcpyAsync(handle.tree_verify_attention_metadata->kv_indptr,
+  checkCUDA(cudaMemcpyAsync(attention_metadata->kv_indptr,
                             kv_indptr_h,
                             sizeof(int32_t) * (batch_size + 1),
                             cudaMemcpyHostToDevice,
                             stream));
-  checkCUDA(cudaMemcpyAsync(handle.tree_verify_attention_metadata->qk_indptr,
+  checkCUDA(cudaMemcpyAsync(attention_metadata->qk_indptr,
                             qk_indptr_h,
                             sizeof(int32_t) * (batch_size + 1),
                             cudaMemcpyHostToDevice,
@@ -463,7 +463,7 @@ void RequestManager::load_batch_config_task(
   // int parallelism = batch_size;
   prepare_inference_params_kernel_h(batch_config,
                                     pm,
-                                    handle,
+                                    handle.incr_attention_metadata,
                                     stream,
                                     max_num_pages,
                                     q_indptr_h,
@@ -726,7 +726,7 @@ void RequestManager::load_batch_config_task(
   // int parallelism = batch_size;
   prepare_inference_params_kernel_h(batch_config,
                                     pm,
-                                    handle,
+                                    handle.tree_verify_attention_metadata,
                                     stream,
                                     max_num_pages,
                                     q_indptr_h,
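Taken together, the .cu changes are a small dependency-injection refactor: prepare_inference_params_kernel_h previously took the whole FFHandler and always wrote into handle.tree_verify_attention_metadata, so the incremental-decoding path could not reuse it. It now writes to whichever AttentionMetaData block the caller passes, and the two call sites in load_batch_config_task select handle.incr_attention_metadata (incremental decoding, around line 463) or handle.tree_verify_attention_metadata (speculative tree verification, around line 726). Below is a minimal, self-contained sketch of the pattern; the struct layouts are assumptions based on the identifiers visible in this diff, not the repository's real headers, and plain host writes stand in for the cudaMemcpyAsync calls:

    #include <cstdint>
    #include <cstdio>

    struct AttentionMetaData {      // per-mode paged-attention layout (assumed fields)
      int32_t *q_indptr;            // per-request query offsets, length batch_size + 1
      int32_t *kv_indptr;           // per-request KV-page offsets, length batch_size + 1
      int32_t *kv_indices;          // flattened physical page ids
      int32_t *kv_last_page_len;    // valid tokens in each request's last page
    };

    struct FFHandler {              // one metadata block per inference mode (assumed)
      AttentionMetaData *incr_attention_metadata;
      AttentionMetaData *tree_verify_attention_metadata;
    };

    // After the refactor the helper fills whichever block it is handed,
    // instead of reaching into the handle for the tree-verify block.
    void prepare_inference_params_sketch(AttentionMetaData *attention_metadata,
                                         int32_t const *q_indptr_h,
                                         int batch_size) {
      for (int i = 0; i <= batch_size; ++i) {   // stand-in for cudaMemcpyAsync
        attention_metadata->q_indptr[i] = q_indptr_h[i];
      }
    }

    int main() {
      int32_t q_incr[3] = {0};
      int32_t q_tree[3] = {0};
      AttentionMetaData incr{q_incr, nullptr, nullptr, nullptr};
      AttentionMetaData tree{q_tree, nullptr, nullptr, nullptr};
      FFHandler handle{&incr, &tree};

      int32_t q_indptr_h[3] = {0, 4, 9};        // two requests: 4 and 5 tokens
      // The same helper now serves both paths:
      prepare_inference_params_sketch(handle.incr_attention_metadata, q_indptr_h, 2);
      prepare_inference_params_sketch(handle.tree_verify_attention_metadata, q_indptr_h, 2);
      std::printf("incr q_indptr[2]=%d, tree q_indptr[2]=%d\n", q_incr[2], q_tree[2]);
      return 0;
    }

Passing the metadata pointer rather than the whole handle also makes the choice of decoding mode explicit at each call site and keeps the helper testable in isolation.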
