diff --git a/include/flexflow/batch_config.h b/include/flexflow/batch_config.h index fccf2b6f07..d4c2e38e24 100644 --- a/include/flexflow/batch_config.h +++ b/include/flexflow/batch_config.h @@ -114,9 +114,9 @@ class BatchConfig { int first_token_index_in_request = -1; int first_token_offset_in_batch = -1; int num_tokens_in_batch = 0; - int padding = 0; // Padding for memory pointer alignment - int num_kv_pages; //number of kv pages used - int kv_last_page_len; //last page length of kv + int padding = 0; // Padding for memory pointer alignment + int num_kv_pages; // number of kv pages used + int kv_last_page_len; // last page length of kv RequestGuid request_guid; }; diff --git a/include/flexflow/model.h b/include/flexflow/model.h index 32177a3837..bff5d28026 100644 --- a/include/flexflow/model.h +++ b/include/flexflow/model.h @@ -1081,6 +1081,10 @@ class FFModel { CompMode comp_mode = COMP_MODE_TRAINING); void compile_inference(); void set_transformer_layer_id(int id); + void set_num_transformer_layers(int num_layers); + void set_num_kv_heads(int num_heads); + void set_qkv_dim(int qkv_dim); + void set_size_dt(int size_dt); void set_position_offset(int offset); void graph_optimize(size_t budget, bool only_data_parallel, @@ -1142,6 +1146,8 @@ class FFModel { size_t tensor_global_guid, parallel_tensor_global_guid, node_global_guid; size_t current_transformer_layer_id; // positional embedding start offset + int num_transformer_layers; + int num_kv_heads, qkv_dim, size_dt; int position_offset; FFConfig config; FFIterationConfig iter_config; diff --git a/include/flexflow/ops/kernels/inc_multihead_self_attention_kernels.h b/include/flexflow/ops/kernels/inc_multihead_self_attention_kernels.h index fe8d32387b..919393985d 100644 --- a/include/flexflow/ops/kernels/inc_multihead_self_attention_kernels.h +++ b/include/flexflow/ops/kernels/inc_multihead_self_attention_kernels.h @@ -15,20 +15,25 @@ namespace Kernels { namespace IncMultiHeadAttention { // kv layout: [num_pages, 2, page_size, num_kv_heads, head_dim] -__device__ __forceinline__ size_t get_k_entry_offset_verify(int const token_idx, - int const page_idx, - int const num_heads, - int const head_dim) { - size_t index = ((page_idx) * kPagesize * 2 + (token_idx % kPagesize)) * head_dim * num_heads; +__device__ __forceinline__ size_t + get_k_entry_offset_verify(int const token_idx, + int const page_idx, + int const num_heads, + int const head_dim) { + size_t index = ((page_idx)*kPagesize * 2 + (token_idx % kPagesize)) * + head_dim * num_heads; return index; } // kv layout: [num_pages, 2, page_size, num_kv_heads, head_dim] -__device__ __forceinline__ size_t get_v_entry_offset_verify(int const token_idx, - int const page_idx, - int const num_heads, - int const head_dim) { - size_t index = ((page_idx) * kPagesize * 2 + kPagesize + (token_idx % kPagesize)) * head_dim * num_heads; +__device__ __forceinline__ size_t + get_v_entry_offset_verify(int const token_idx, + int const page_idx, + int const num_heads, + int const head_dim) { + size_t index = + ((page_idx)*kPagesize * 2 + kPagesize + (token_idx % kPagesize)) * + head_dim * num_heads; return index; } diff --git a/include/flexflow/page_manager.h b/include/flexflow/page_manager.h index 54b661e026..87a544d819 100644 --- a/include/flexflow/page_manager.h +++ b/include/flexflow/page_manager.h @@ -1,14 +1,14 @@ #pragma once #include "flexflow/batch_config.h" +#include "flexflow/config.h" #include "flexflow/inference.h" #include "flexflow/model.h" -#include "flexflow/config.h" #include 
"flexflow/utils/file_loader.h" +#include #include #include #include -#include namespace FlexFlow { @@ -20,118 +20,143 @@ using TokenId = BatchConfig::TokenId; */ class LogicalTokenBlock { public: - using TokenId = BatchConfig::TokenId; + using TokenId = BatchConfig::TokenId; - // Constructor - LogicalTokenBlock(int block_number, uint32_t block_size); + // Constructor + LogicalTokenBlock(int block_number, uint32_t block_size); - // Method to check if the block is empty - bool is_empty() const; + // Method to check if the block is empty + bool is_empty() const; - // Method to check if the block is full - bool is_full() const; + // Method to check if the block is full + bool is_full() const; - // Method to get the number of empty slots - int get_num_empty_slots() const; + // Method to get the number of empty slots + int get_num_empty_slots() const; - // Method to get the number of allocated slots - int get_num_alloc_slots() const; + // Method to get the number of allocated slots + int get_num_alloc_slots() const; - // Used to clean up the spec tokens in a block since these spec tokens may not be committed after use - void reset_num_spec_tokens(); + // Used to clean up the spec tokens in a block since these spec tokens may not + // be committed after use + void reset_num_spec_tokens(); - // Method to append tokens - void append_tokens(const std::vector& token_ids_to_append, bool committed); + // Method to append tokens + void append_tokens(std::vector const &token_ids_to_append, + bool committed); - int get_num_tokens() const { return num_tokens; } - int get_num_commit_tokens() const { return num_commit_tokens; } - int get_num_spec_tokens() const { return num_spec_tokens; } + int get_num_tokens() const { + return num_tokens; + } + int get_num_commit_tokens() const { + return num_commit_tokens; + } + int get_num_spec_tokens() const { + return num_spec_tokens; + } - std::vector get_token_ids() const; + std::vector get_token_ids() const; private: - int block_number; // the index of the logical token block - int block_size; // the size of the block - int num_tokens; // the number of tokens currently stored in the block - int num_commit_tokens; // the number of tokens inside this block that are already committed - int num_spec_tokens; // the number of tokens inside this block that are speculative tokens, which is stored temporarily - std::vector token_ids; //store the token ids in a order that corresponds to the inference sequence + int block_number; // the index of the logical token block + int block_size; // the size of the block + int num_tokens; // the number of tokens currently stored in the block + int num_commit_tokens; // the number of tokens inside this block that are + // already committed + int num_spec_tokens; // the number of tokens inside this block that are + // speculative tokens, which is stored temporarily + std::vector token_ids; // store the token ids in a order that + // corresponds to the inference sequence }; /** * @class PhysicalTokenBlock - * @brief A class to represent a physical block of tokens similar to physical memory address - * It keeps track of the location of the tokens stored on GPU memory + * @brief A class to represent a physical block of tokens similar to physical + * memory address It keeps track of the location of the tokens stored on GPU + * memory */ class PhysicalTokenBlock { public: - // Constructor - PhysicalTokenBlock(int block_number, int block_size); - - // Method to get the block number - int get_block_number() const { return block_number; } - void 
incr_ref_count() { ref_count++; } - void decr_ref_count() { ref_count--; } - int ref_count; // reference count, TODO: move to private + // Constructor + PhysicalTokenBlock(int block_number, int block_size); + + // Method to get the block number + int get_block_number() const { + return block_number; + } + void incr_ref_count() { + ref_count++; + } + void decr_ref_count() { + ref_count--; + } + int ref_count; // reference count, TODO: move to private private: - int block_number; // the index of the physical token block - int block_size; // the size of the block + int block_number; // the index of the physical token block + int block_size; // the size of the block }; /** * @class BlockAllocator - * @brief A Block Manager that is reponsible for maintaining a pool of free blocks + * @brief A Block Manager that is reponsible for maintaining a pool of free + * blocks */ class BlockAllocator { public: - // Constructor - BlockAllocator(int block_size, int num_total_blocks); + // Constructor + BlockAllocator(int block_size, int num_total_blocks); - // Allocate a block - PhysicalTokenBlock allocate(); + // Allocate a block + PhysicalTokenBlock allocate(); - // Free a block - void free(PhysicalTokenBlock& block); + // Free a block + void free(PhysicalTokenBlock &block); - // Get the number of free blocks - int get_num_free_blocks() const; + // Get the number of free blocks + int get_num_free_blocks() const; private: - int block_size; - int num_total_blocks; - std::deque free_blocks; + int block_size; + int num_total_blocks; + std::deque free_blocks; }; /* -* @class PageManager -* @brief A wrapper class that manages the kv cache allocation status -* notice that all the layers of model will share the same page manager because the position of kv cache will be the same -*/ + * @class PageManager + * @brief A wrapper class that manages the kv cache allocation status + * notice that all the layers of model will share the same page manager because + * the position of kv cache will be the same + */ class PageManager { public: - // Get the singleton instance of the PageManager as it will be shared in multiple places - static PageManager *get_page_manager(); - using BlockTable = std::vector; - using RequestGuid = BatchConfig::RequestGuid; - PageManager(int block_size, int num_total_blocks); - - int allocate_one_block(const RequestGuid& request_guid); - void free_request(const RequestGuid& request_guid); - //used for the case that we want to free the last num_blocks that stores spec tokens(which are the tokens are not yet committed) - void free_multiple_blocks(const RequestGuid& request_guid, int num_blocks); - std::vector get_block_table_indices(const RequestGuid& request_guid) const; - - - void free_block_table(BlockTable& block_table); -private: - int block_size; // the size of the block - int num_total_blocks; // the total number of blocks - BlockAllocator block_allocator; - std::unordered_map block_tables; + // Get the singleton instance of the PageManager as it will be shared in + // multiple places + static PageManager *get_page_manager(); + static PageManager *get_page_manager(FFModel *ff, int kv_cache_size); + using BlockTable = std::vector; + using RequestGuid = BatchConfig::RequestGuid; + PageManager(int block_size, int num_total_blocks); + int allocate_one_block(RequestGuid const &request_guid); + void free_request(RequestGuid const &request_guid); + // used for the case that we want to free the last num_blocks that stores spec + // tokens(which are the tokens are not yet committed) + void 
free_multiple_blocks(RequestGuid const &request_guid, int num_blocks); + std::vector + get_block_table_indices(RequestGuid const &request_guid) const; + + void free_block_table(BlockTable &block_table); - int get_num_total_free_blocks() const; - int get_num_allocated_blocks(const RequestGuid& request_guid) const; +private: + int num_transformer_layers; + int total_kv_cache_size; + int block_size; // the size of the block + int num_total_blocks; // the total number of blocks + BlockAllocator block_allocator; + std::unordered_map block_tables; + + int get_num_total_free_blocks() const; + int get_num_allocated_blocks(RequestGuid const &request_guid) const; }; }; // namespace FlexFlow \ No newline at end of file diff --git a/include/flexflow/request_manager.h b/include/flexflow/request_manager.h index 8403fb8891..f7fe3c8725 100644 --- a/include/flexflow/request_manager.h +++ b/include/flexflow/request_manager.h @@ -18,8 +18,8 @@ #include "flexflow/batch_config.h" #include "flexflow/inference.h" #include "flexflow/model.h" -#include "flexflow/utils/file_loader.h" #include "flexflow/page_manager.h" +#include "flexflow/utils/file_loader.h" #include #include #include @@ -149,7 +149,8 @@ struct Request { Status status = PENDING; std::vector tokens; - //page attention, page_last_committed should be -1 because there are no blocks at the beginning + // page attention, page_last_committed should be -1 because there are no + // blocks at the beginning int page_last_committed = -1; std::vector blocks; @@ -539,8 +540,7 @@ class RequestManager { int get_len_last_block(Request &request) const; int get_idx_last_logical_token(Request &request) const; int idx_logical_to_physical(Request &request, int idx_logical); - void _append_block_to_request( - Request &request, bool is_commit); + void _append_block_to_request(Request &request, bool is_commit); int append_token_to_block(Request &request, TokenId token, bool is_commit); void reset_block_table(Request &request); void print_num_tokens(Request &request); diff --git a/inference/incr_decoding/incr_decoding.cc b/inference/incr_decoding/incr_decoding.cc index 1f7947a8cf..5a18daab47 100644 --- a/inference/incr_decoding/incr_decoding.cc +++ b/inference/incr_decoding/incr_decoding.cc @@ -209,7 +209,7 @@ void FlexFlow::top_level_task(Task const *task, int max_tokens_per_prefilling_batch = -1; int max_sequence_length = 256; int max_output_length = 512; - int max_kv_cache_size = -1; //if -1, then use the default value + int max_kv_cache_size = -1; // if -1, then use the default value RequestManager::DecodingMode decoding_mode = RequestManager::INCREMENTAL_DECODING; int sampling_seed = 0; diff --git a/inference/models/falcon.cc b/inference/models/falcon.cc index 28bd7d5740..9049b3885c 100644 --- a/inference/models/falcon.cc +++ b/inference/models/falcon.cc @@ -63,6 +63,11 @@ void FALCON::create_falcon_model(FFModel &ff, Tensor mha = nullptr, mlp_output = nullptr; Tensor res_ln_outputs[2] = {nullptr, nullptr}; + ff.set_num_transformer_layers(falcon_config.n_layer); + ff.set_num_kv_heads(falcon_config.n_head_kv); + ff.set_qkv_dim(falcon_config.hidden_size / falcon_config.n_head * 2); + ff.set_size_dt(data_type_size(input->data_type)); + for (int i = 0; i < falcon_config.n_layer; i++) { // set transformer layer id ff.set_transformer_layer_id(i); diff --git a/inference/models/falcon.h b/inference/models/falcon.h index 3934626337..a15c289917 100644 --- a/inference/models/falcon.h +++ b/inference/models/falcon.h @@ -16,6 +16,7 @@ // #include "file_loader.h" #include 
"flexflow/batch_config.h" +#include "flexflow/ffconst_utils.h" #include "flexflow/inference.h" #include "flexflow/request_manager.h" #include diff --git a/inference/models/llama.cc b/inference/models/llama.cc index 988f8f4b53..92f1cdf763 100644 --- a/inference/models/llama.cc +++ b/inference/models/llama.cc @@ -64,6 +64,11 @@ void LLAMA::create_llama_model(FFModel &ff, Tensor w2 = nullptr; + ff.set_num_transformer_layers(llama_config.num_hidden_layers); + ff.set_num_kv_heads(llama_config.num_key_value_heads); + ff.set_qkv_dim(llama_config.hidden_size / llama_config.num_attention_heads * + 2); + ff.set_size_dt(data_type_size(input->data_type)); for (int i = 0; i < llama_config.num_hidden_layers; i++) { // set transformer layer id ff.set_transformer_layer_id(i); diff --git a/inference/models/llama.h b/inference/models/llama.h index 3f11ca96d1..cd6f9c5cc8 100644 --- a/inference/models/llama.h +++ b/inference/models/llama.h @@ -16,6 +16,7 @@ // #include "file_loader.h" #include "flexflow/batch_config.h" +#include "flexflow/ffconst_utils.h" #include "flexflow/inference.h" #include "flexflow/request_manager.h" #include diff --git a/inference/models/mpt.cc b/inference/models/mpt.cc index a7bf79f6b8..b95cb5c91a 100644 --- a/inference/models/mpt.cc +++ b/inference/models/mpt.cc @@ -64,6 +64,10 @@ void MPT::create_mpt_model(FFModel &ff, Tensor intermediate_output = nullptr, layernorm_output = nullptr; Tensor res_ln_outputs[2] = {nullptr, nullptr}; + ff.set_num_transformer_layers(mpt_config.n_layers); + ff.set_num_kv_heads(mpt_config.n_heads); + ff.set_qkv_dim(mpt_config.hidden_size / mpt_config.n_heads * 2); + ff.set_size_dt(data_type_size(input->data_type)); for (int i = 0; i < mpt_config.n_layers; i++) { ff.set_transformer_layer_id(i); diff --git a/inference/models/mpt.h b/inference/models/mpt.h index bd7a9410f6..8466ea1cb2 100644 --- a/inference/models/mpt.h +++ b/inference/models/mpt.h @@ -16,6 +16,7 @@ // #include "file_loader.h" #include "flexflow/batch_config.h" +#include "flexflow/ffconst_utils.h" #include "flexflow/inference.h" #include "flexflow/request_manager.h" #include diff --git a/inference/models/opt.cc b/inference/models/opt.cc index 25f9833a1d..352809ede5 100644 --- a/inference/models/opt.cc +++ b/inference/models/opt.cc @@ -77,6 +77,10 @@ void OPT::create_opt_model(FFModel &ff, Tensor fc2 = nullptr, added = nullptr; Tensor res_ln_outputs[2] = {nullptr, nullptr}; + ff.set_num_transformer_layers(opt_config.num_hidden_layers); + ff.set_num_kv_heads(opt_config.num_attention_heads); + ff.set_qkv_dim(opt_config.hidden_size / opt_config.num_attention_heads * 2); + ff.set_size_dt(data_type_size(input->data_type)); for (int i = 0; i < opt_config.num_hidden_layers; i++) { // set transformer layer id ff.set_transformer_layer_id(i); diff --git a/inference/models/opt.h b/inference/models/opt.h index 90443e872b..23ba8888bb 100644 --- a/inference/models/opt.h +++ b/inference/models/opt.h @@ -16,6 +16,7 @@ // #include "file_loader.h" #include "flexflow/batch_config.h" +#include "flexflow/ffconst_utils.h" #include "flexflow/inference.h" #include "flexflow/request_manager.h" #include diff --git a/inference/models/starcoder.cc b/inference/models/starcoder.cc index 31505b0ba8..401a754d03 100644 --- a/inference/models/starcoder.cc +++ b/inference/models/starcoder.cc @@ -82,6 +82,11 @@ void STARCODER::create_starcoder_model( Tensor residual = nullptr, c_proj = nullptr; Tensor res_ln_outputs[2] = {nullptr, nullptr}; + ff.set_num_transformer_layers(startcoder_config.num_hidden_layers); + 
ff.set_num_kv_heads(startcoder_config.num_attention_heads); + ff.set_qkv_dim(startcoder_config.hidden_size / + startcoder_config.num_attention_heads * 2); + ff.set_size_dt(data_type_size(input->data_type)); for (int i = 0; i < startcoder_config.num_hidden_layers; i++) { // set transformer layer id ff.set_transformer_layer_id(i); diff --git a/inference/models/starcoder.h b/inference/models/starcoder.h index 7241acde3a..57e1229f1a 100644 --- a/inference/models/starcoder.h +++ b/inference/models/starcoder.h @@ -16,6 +16,7 @@ // #include "file_loader.h" #include "flexflow/batch_config.h" +#include "flexflow/ffconst_utils.h" #include "flexflow/inference.h" #include "flexflow/request_manager.h" #include diff --git a/src/ops/tree_inc_multihead_self_attention.cu b/src/ops/tree_inc_multihead_self_attention.cu index 5ec7185863..d1e9164611 100644 --- a/src/ops/tree_inc_multihead_self_attention.cu +++ b/src/ops/tree_inc_multihead_self_attention.cu @@ -48,19 +48,19 @@ using flashinfer::PosEncodingMode; using flashinfer::QKVLayout; template -__global__ void - update_qkv_in_batch_verify_kernel(DT *qkv_proj_array, - half *qTmp_ptr, - half *kvCache_ptr, - int32_t *kv_indptr, - int32_t *kv_page_indices, - bool const *request_available, - BatchConfig::PerTokenInfo const *tokenInfos, - int const max_num_pages, - int num_q_heads, - int num_kv_heads, - int head_dim, - int num_new_tokens) { +__global__ void update_qkv_in_batch_verify_kernel( + DT *qkv_proj_array, + half *qTmp_ptr, + half *kvCache_ptr, + int32_t *kv_indptr, + int32_t *kv_page_indices, + bool const *request_available, + BatchConfig::PerTokenInfo const *tokenInfos, + int const max_num_pages, + int num_q_heads, + int num_kv_heads, + int head_dim, + int num_new_tokens) { int const q_hidden_size = num_q_heads * head_dim; int const temp_kv_hidden_size = num_q_heads * head_dim; // temporary hard code int const kv_hidden_size = num_kv_heads * head_dim; @@ -68,7 +68,6 @@ __global__ void int const token_idx = thread_idx / q_hidden_size; int const offset = thread_idx % q_hidden_size; - if (token_idx >= num_new_tokens) { return; } @@ -76,7 +75,6 @@ __global__ void int const req_idx = tokenInfos[token_idx].request_index; int token_abs_idx = tokenInfos[token_idx].abs_index_in_request; - // calculate the compact request index in the easiest way // TODO: recheck int req_idx_compact = -1; @@ -98,13 +96,12 @@ __global__ void int start = kv_indptr[req_idx_compact]; int end = kv_indptr[req_idx_compact + 1] - 1; assert(start <= end && "Invalid kv_indptr"); - assert(start + (token_abs_idx / kPagesize) <= end && - "Invalid page index"); + assert(start + (token_abs_idx / kPagesize) <= end && "Invalid page index"); int page_idx = kv_page_indices[start + (token_abs_idx / kPagesize)]; size_t to_k_idx = get_k_entry_offset_verify( - token_abs_idx, page_idx, num_kv_heads, head_dim), + token_abs_idx, page_idx, num_kv_heads, head_dim), to_v_idx = get_v_entry_offset_verify( - token_abs_idx, page_idx, num_kv_heads, head_dim); + token_abs_idx, page_idx, num_kv_heads, head_dim); // key and value cache should be stored interleaved int const stride = num_q_heads / num_kv_heads; int const kv_offset = @@ -119,8 +116,8 @@ __global__ void template void update_qkv_in_batch_verify(IncMultiHeadSelfAttentionMeta const *m, - BatchConfig const *bc, - cudaStream_t stream) { + BatchConfig const *bc, + cudaStream_t stream) { // printf("entered update_qkv_in_batch_verify\n"); int num_new_tokens = bc->num_active_tokens(); if (num_new_tokens == 0) { @@ -131,20 +128,21 @@ void 
update_qkv_in_batch_verify(IncMultiHeadSelfAttentionMeta const *m,
       round_up_pages(BatchConfig::max_sequence_length() +
                      BatchConfig::max_spec_tree_token_num());
   update_qkv_in_batch_verify_kernel<<<GET_BLOCKS(parallelism),
-                                      min(CUDA_NUM_THREADS, parallelism),
-                                      0,
-                                      stream>>>(static_cast<DT *>(m->devQKVProjArray),
-                                      static_cast<half *>(m->queryTmp),
-                                      static_cast<half *>(m->kvCache),
-                                      m->handle.tree_verify_attention_metadata->kv_indptr,
-                                      m->handle.tree_verify_attention_metadata->kv_indices,
-                                      m->request_available,
-                                      m->token_infos,
-                                      max_num_pages,
-                                      m->num_q_heads,
-                                      m->num_kv_heads,
-                                      m->qk_dim,
-                                      num_new_tokens);
+                                      min(CUDA_NUM_THREADS, parallelism),
+                                      0,
+                                      stream>>>(
+      static_cast<DT *>
(m->devQKVProjArray), + static_cast(m->queryTmp), + static_cast(m->kvCache), + m->handle.tree_verify_attention_metadata->kv_indptr, + m->handle.tree_verify_attention_metadata->kv_indices, + m->request_available, + m->token_infos, + max_num_pages, + m->num_q_heads, + m->num_kv_heads, + m->qk_dim, + num_new_tokens); // cudaStreamSynchronize(stream); // printf("exited update_qkv_in_batch_verify\n"); } @@ -187,16 +185,29 @@ __global__ void commit_tokens_kernel( // int const req_id = committedTokenInfos[i].request_index; // int const tok_id = committedTokenInfos[i].token_depth; int const page_to_idx = committedTokenInfos[i].token_depth / kPagesize; - int const page_from_idx = committedTokenInfos[i].index_in_kv_cache / kPagesize; + int const page_from_idx = + committedTokenInfos[i].index_in_kv_cache / kPagesize; size_t from_k_idx = get_k_entry_offset_verify( - committedTokenInfos[i].index_in_kv_cache, page_from_idx, num_kv_heads, head_dim), + committedTokenInfos[i].index_in_kv_cache, + page_from_idx, + num_kv_heads, + head_dim), from_v_idx = get_v_entry_offset_verify( - committedTokenInfos[i].index_in_kv_cache, page_from_idx, num_kv_heads, head_dim); - size_t to_k_idx = get_k_entry_offset_verify( - committedTokenInfos[i].token_depth, page_to_idx, num_kv_heads, head_dim), - to_v_idx = get_v_entry_offset_verify( - committedTokenInfos[i].token_depth, page_to_idx, num_kv_heads, head_dim); + committedTokenInfos[i].index_in_kv_cache, + page_from_idx, + num_kv_heads, + head_dim); + size_t to_k_idx = + get_k_entry_offset_verify(committedTokenInfos[i].token_depth, + page_to_idx, + num_kv_heads, + head_dim), + to_v_idx = + get_v_entry_offset_verify(committedTokenInfos[i].token_depth, + page_to_idx, + num_kv_heads, + head_dim); kCache_ptr[to_k_idx + offset] = kCache_ptr[from_k_idx + offset]; kCache_ptr[to_v_idx + offset] = kCache_ptr[from_v_idx + offset]; @@ -220,16 +231,17 @@ void commit_tokens(TreeIncMultiHeadSelfAttentionMeta const *m, commit_tokens_kernel<<>>(static_cast(m->kvCache), - m->handle.tree_verify_attention_metadata->kv_indptr, - m->handle.tree_verify_attention_metadata->kv_indices, - m->committed_token_infos, - m->request_available, - num_requests, - m->num_kv_heads, - m->qk_dim, - m->num_tokens_to_commit, - max_num_pages); + stream>>>( + static_cast(m->kvCache), + m->handle.tree_verify_attention_metadata->kv_indptr, + m->handle.tree_verify_attention_metadata->kv_indices, + m->committed_token_infos, + m->request_available, + num_requests, + m->num_kv_heads, + m->qk_dim, + m->num_tokens_to_commit, + max_num_pages); // cudaEventRecord(t_end, stream); // checkCUDA(cudaEventSynchronize(t_end)); // float elapsed = 0; @@ -611,7 +623,8 @@ void inference_kernel(TreeIncMultiHeadSelfAttentionMeta *m, // } // cudaError_t err = cudaDeviceSynchronize(); // if (err != cudaSuccess) { - // std::cerr << "Kernel launch failed with error: " << cudaGetErrorString(err) << std::endl; + // std::cerr << "Kernel launch failed with error: " << + // cudaGetErrorString(err) << std::endl; // } } diff --git a/src/runtime/batch_config.cc b/src/runtime/batch_config.cc index ca5d08e986..426f848d96 100644 --- a/src/runtime/batch_config.cc +++ b/src/runtime/batch_config.cc @@ -311,7 +311,7 @@ void StreamingCacheInfo::reset_cache() { total_len = 0; } -//page attention: TODO: I think we just need to change the index +// page attention: TODO: I think we just need to change the index int StreamingCacheInfo::global_2_cache_index(int global_index) { if (global_index < sink_cache_size) { diff --git a/src/runtime/inference_manager.cc 
b/src/runtime/inference_manager.cc index dd13bb2e05..31c2a51cd0 100644 --- a/src/runtime/inference_manager.cc +++ b/src/runtime/inference_manager.cc @@ -493,6 +493,22 @@ void FFModel::set_transformer_layer_id(int id) { assert(id < MAX_NUM_TRANSFORMER_LAYERS); } +void FFModel::set_num_transformer_layers(int num_layers) { + num_transformer_layers = num_layers; +} + +void FFModel::set_num_kv_heads(int num_heads) { + num_kv_heads = num_heads; +} + +void FFModel::set_qkv_dim(int qkv_dim) { + qkv_dim = qkv_dim; +} + +void FFModel::set_size_dt(int size_dt) { + size_dt = size_dt; +} + void FFModel::set_position_offset(int offset) { assert(offset == 0 || offset == 2); position_offset = offset; diff --git a/src/runtime/page_manager.cc b/src/runtime/page_manager.cc index f59837f30e..c88c4f6cec 100644 --- a/src/runtime/page_manager.cc +++ b/src/runtime/page_manager.cc @@ -17,202 +17,234 @@ namespace FlexFlow { -// For all runtime functions, they share a single page manager for pages information +// For all runtime functions, they share a single page manager for pages +// information PageManager *page_manager_singleton = nullptr; // the interface of logicaltokenblock LogicalTokenBlock::LogicalTokenBlock(int block_number, uint32_t block_size) - : block_number(block_number), block_size(block_size), num_tokens(0), num_commit_tokens(0), num_spec_tokens(0) { - } + : block_number(block_number), block_size(block_size), num_tokens(0), + num_commit_tokens(0), num_spec_tokens(0) {} bool LogicalTokenBlock::is_empty() const { - assert(num_spec_tokens == 0 && num_commit_tokens == 0); - assert(num_tokens <= block_size); - return num_tokens == 0; + assert(num_spec_tokens == 0 && num_commit_tokens == 0); + assert(num_tokens <= block_size); + return num_tokens == 0; } bool LogicalTokenBlock::is_full() const { - assert(num_spec_tokens + num_commit_tokens == num_tokens); - assert(num_tokens <= block_size); - return num_tokens == block_size; + assert(num_spec_tokens + num_commit_tokens == num_tokens); + assert(num_tokens <= block_size); + return num_tokens == block_size; } int LogicalTokenBlock::get_num_empty_slots() const { - assert(num_spec_tokens + num_commit_tokens == num_tokens); - assert(num_tokens <= block_size); - return block_size - num_tokens; + assert(num_spec_tokens + num_commit_tokens == num_tokens); + assert(num_tokens <= block_size); + return block_size - num_tokens; } int LogicalTokenBlock::get_num_alloc_slots() const { - assert(num_spec_tokens + num_commit_tokens == num_tokens); - assert(num_tokens <= block_size); - return num_tokens; + assert(num_spec_tokens + num_commit_tokens == num_tokens); + assert(num_tokens <= block_size); + return num_tokens; } -void LogicalTokenBlock::reset_num_spec_tokens(){ - assert(num_spec_tokens + num_commit_tokens == num_tokens); - assert(num_tokens <= block_size); +void LogicalTokenBlock::reset_num_spec_tokens() { + assert(num_spec_tokens + num_commit_tokens == num_tokens); + assert(num_tokens <= block_size); - num_tokens -= num_spec_tokens; - num_spec_tokens = 0; + num_tokens -= num_spec_tokens; + num_spec_tokens = 0; - assert(num_spec_tokens + num_commit_tokens == num_tokens); - assert(num_tokens <= block_size); + assert(num_spec_tokens + num_commit_tokens == num_tokens); + assert(num_tokens <= block_size); } -void LogicalTokenBlock::append_tokens(const std::vector& token_ids_to_append, bool committed) { - assert(num_spec_tokens + num_commit_tokens == num_tokens); - assert(num_tokens <= block_size); - if (num_tokens + token_ids_to_append.size() > block_size) { - 
printf("block is full! Cannot append more tokens\n"); - throw std::runtime_error("Block is full! Cannot append more tokens."); - } - token_ids.insert(token_ids.end(), token_ids_to_append.begin(), token_ids_to_append.end()); - num_tokens += token_ids_to_append.size(); - if (committed) { - num_commit_tokens += token_ids_to_append.size(); - }else{ - num_spec_tokens += token_ids_to_append.size(); - } - assert(num_spec_tokens + num_commit_tokens == num_tokens); - assert(num_tokens <= block_size); +void LogicalTokenBlock::append_tokens( + std::vector const &token_ids_to_append, bool committed) { + assert(num_spec_tokens + num_commit_tokens == num_tokens); + assert(num_tokens <= block_size); + if (num_tokens + token_ids_to_append.size() > block_size) { + printf("block is full! Cannot append more tokens\n"); + throw std::runtime_error("Block is full! Cannot append more tokens."); + } + token_ids.insert( + token_ids.end(), token_ids_to_append.begin(), token_ids_to_append.end()); + num_tokens += token_ids_to_append.size(); + if (committed) { + num_commit_tokens += token_ids_to_append.size(); + } else { + num_spec_tokens += token_ids_to_append.size(); + } + assert(num_spec_tokens + num_commit_tokens == num_tokens); + assert(num_tokens <= block_size); } std::vector LogicalTokenBlock::get_token_ids() const { - return token_ids; + return token_ids; } PhysicalTokenBlock::PhysicalTokenBlock(int block_number, int block_size) : block_number(block_number), block_size(block_size), ref_count(0) {} BlockAllocator::BlockAllocator(int block_size, int num_total_blocks) { - for (int block_number = 0; block_number < num_total_blocks; ++block_number) { - free_blocks.push_back(PhysicalTokenBlock(block_number, block_size)); - } - num_total_blocks = num_total_blocks; + for (int block_number = 0; block_number < num_total_blocks; ++block_number) { + free_blocks.push_back(PhysicalTokenBlock(block_number, block_size)); + } + num_total_blocks = num_total_blocks; } // Allocate a block PhysicalTokenBlock BlockAllocator::allocate() { - if (free_blocks.empty()) { - printf("no free blocks are available\n"); - throw std::runtime_error("Out of memory! No free blocks are available."); - } - PhysicalTokenBlock block = free_blocks.front(); - free_blocks.pop_front(); - block.incr_ref_count(); - return block; + if (free_blocks.empty()) { + printf("no free blocks are available\n"); + throw std::runtime_error("Out of memory! No free blocks are available."); + } + PhysicalTokenBlock block = free_blocks.front(); + free_blocks.pop_front(); + block.incr_ref_count(); + return block; } // Free a block -void BlockAllocator::free(PhysicalTokenBlock& block) { - if (block.ref_count == 0) { - printf("block is already freed\n"); - throw std::runtime_error("Double free! Block is already freed."); - } - block.decr_ref_count(); - if (block.ref_count == 0) { - // printf("put block number: %d back to free_blocks\n", block.get_block_number()); - free_blocks.push_back(block); - }else{ - // in current implementation this should not be the case - printf("block is not freed. Ref count: %d\n", block.ref_count); - throw std::runtime_error("Block is not freed. Ref count: " + std::to_string(block.ref_count)); - } +void BlockAllocator::free(PhysicalTokenBlock &block) { + if (block.ref_count == 0) { + printf("block is already freed\n"); + throw std::runtime_error("Double free! 
Block is already freed."); + } + block.decr_ref_count(); + if (block.ref_count == 0) { + // printf("put block number: %d back to free_blocks\n", + // block.get_block_number()); + free_blocks.push_back(block); + } else { + // in current implementation this should not be the case + printf("block is not freed. Ref count: %d\n", block.ref_count); + throw std::runtime_error("Block is not freed. Ref count: " + + std::to_string(block.ref_count)); + } } int BlockAllocator::get_num_free_blocks() const { - return free_blocks.size(); + return free_blocks.size(); } PageManager::PageManager(int block_size, int num_total_blocks) : block_size(block_size), num_total_blocks(num_total_blocks), - block_allocator(block_size, num_total_blocks) { - } + block_allocator(block_size, num_total_blocks) {} -//return the physical number of this block -int PageManager::allocate_one_block(const RequestGuid& request_guid) { - BlockTable& block_table = block_tables[request_guid]; +// return the physical number of this block +int PageManager::allocate_one_block(RequestGuid const &request_guid) { + BlockTable &block_table = block_tables[request_guid]; - PhysicalTokenBlock block = block_allocator.allocate(); - block_table.push_back(block); - block_tables[request_guid] = block_table; - return block.get_block_number(); + PhysicalTokenBlock block = block_allocator.allocate(); + block_table.push_back(block); + block_tables[request_guid] = block_table; + return block.get_block_number(); } -void PageManager::free_block_table(BlockTable& block_table) { - // make it reverse order to free the last allocated block first - BlockTable::reverse_iterator rit = block_table.rbegin(); - for (; rit != block_table.rend(); ++rit) { - block_allocator.free(*rit); - } - return; +void PageManager::free_block_table(BlockTable &block_table) { + // make it reverse order to free the last allocated block first + BlockTable::reverse_iterator rit = block_table.rbegin(); + for (; rit != block_table.rend(); ++rit) { + block_allocator.free(*rit); + } + return; } -void PageManager::free_request(const RequestGuid& request_guid) { - //we only free the blocks that are already used - assert(block_tables.find(request_guid) != block_tables.end()); - BlockTable block_table = block_tables[request_guid]; - free_block_table(block_table); - block_tables.erase(request_guid); - return; +void PageManager::free_request(RequestGuid const &request_guid) { + // we only free the blocks that are already used + assert(block_tables.find(request_guid) != block_tables.end()); + BlockTable block_table = block_tables[request_guid]; + free_block_table(block_table); + block_tables.erase(request_guid); + return; } // delete the last num_blocks in the request_guid -void PageManager::free_multiple_blocks(const RequestGuid& request_guid, int num_blocks) { - assert(block_tables.find(request_guid) != block_tables.end()); - auto& block_table = block_tables[request_guid]; - assert(num_blocks <= block_table.size()); - int num_blocks_allocated = block_table.size(); - for (int i = 0; i < num_blocks; i++) { - block_allocator.free(block_table[num_blocks_allocated - i - 1]); - } - // only keep the first num_blocks_allocated - num_blocks blocks - block_table.erase(block_table.begin() + num_blocks_allocated - num_blocks, block_table.end()); - block_tables[request_guid] = block_table; - return; +void PageManager::free_multiple_blocks(RequestGuid const &request_guid, + int num_blocks) { + assert(block_tables.find(request_guid) != block_tables.end()); + auto &block_table = block_tables[request_guid]; + 
assert(num_blocks <= block_table.size()); + int num_blocks_allocated = block_table.size(); + for (int i = 0; i < num_blocks; i++) { + block_allocator.free(block_table[num_blocks_allocated - i - 1]); + } + // only keep the first num_blocks_allocated - num_blocks blocks + block_table.erase(block_table.begin() + num_blocks_allocated - num_blocks, + block_table.end()); + block_tables[request_guid] = block_table; + return; } -// int PageManager::get_index_last_block(const RequestGuid& request_guid) const { +// int PageManager::get_index_last_block(const RequestGuid& request_guid) const +// { // const auto& block_table = block_tables.at(request_guid); // return block_table.back.get_block_number(); // } -std::vector PageManager::get_block_table_indices(const RequestGuid& request_guid) const { - std::vector indices; - const auto& it = block_tables.find(request_guid); - if (it == block_tables.end()) { - return indices; - } - const auto& block_table = it->second; - for (const auto& block : block_table) { - indices.push_back(block.get_block_number()); - } +std::vector PageManager::get_block_table_indices( + RequestGuid const &request_guid) const { + std::vector indices; + auto const &it = block_tables.find(request_guid); + if (it == block_tables.end()) { return indices; + } + auto const &block_table = it->second; + for (auto const &block : block_table) { + indices.push_back(block.get_block_number()); + } + return indices; } int PageManager::get_num_total_free_blocks() const { - return block_allocator.get_num_free_blocks(); + return block_allocator.get_num_free_blocks(); +} + +int PageManager::get_num_allocated_blocks( + RequestGuid const &request_guid) const { + auto it = block_tables.find(request_guid); + if (it == block_tables.end()) { + return 0; + } else { + return it->second.size(); + } } -int PageManager::get_num_allocated_blocks(const RequestGuid& request_guid) const { - auto it = block_tables.find(request_guid); - if (it == block_tables.end()) { - return 0; - }else{ - return it->second.size(); +PageManager *PageManager::get_page_manager(FFModel *ff, + int total_kv_cache_size) { + int num_kv_heads = ff->num_kv_heads; + int size_dt = ff->size_dt; + int qkv_dim = ff->qkv_dim; + int num_transformer_layers = ff->num_transformer_layers; + int pipeline_parallelism_degree = ff->config.pipeline_parallelism_degree; + if (page_manager_singleton == nullptr) { + int num_total_blocks = 0; + if (total_kv_cache_size == -1) { + num_total_blocks = (BatchConfig::max_spec_tree_token_num() + + BatchConfig::max_sequence_length() + kPagesize - 1) / + kPagesize * BatchConfig::max_requests_per_batch(); + } else { + num_total_blocks = + total_kv_cache_size * 1024 * 1024 / size_dt / qkv_dim / num_kv_heads / + (num_transformer_layers / pipeline_parallelism_degree) / 2; } + page_manager_singleton = new PageManager(kPagesize, num_total_blocks); + } + return page_manager_singleton; } PageManager *PageManager::get_page_manager() { if (page_manager_singleton == nullptr) { - int num_total_blocks = (BatchConfig::max_spec_tree_token_num() + - BatchConfig::max_sequence_length() + kPagesize - 1) / + int num_total_blocks = + (BatchConfig::max_spec_tree_token_num() + + BatchConfig::max_sequence_length() + kPagesize - 1) / kPagesize * BatchConfig::max_requests_per_batch(); page_manager_singleton = new PageManager(kPagesize, num_total_blocks); } return page_manager_singleton; } - -}; //FlexFlow \ No newline at end of file +}; // namespace FlexFlow \ No newline at end of file diff --git a/src/runtime/request_manager.cc 
b/src/runtime/request_manager.cc index 83fdc5f1ff..301d41a2a1 100644 --- a/src/runtime/request_manager.cc +++ b/src/runtime/request_manager.cc @@ -20,6 +20,9 @@ #include #include #include +#include +#include +#include #include #include #include @@ -30,9 +33,6 @@ #include #include #include -#include -#include -#include namespace FlexFlow { @@ -42,9 +42,10 @@ using tokenizers::Tokenizer; Legion::Logger log_req_mgr("RequestManager"); void printStackTrace() { - void *array[10]; - size_t size = backtrace(array, 10); // Get stack frames - backtrace_symbols_fd(array, size, STDERR_FILENO); // Print stack trace to stderr + void *array[10]; + size_t size = backtrace(array, 10); // Get stack frames + backtrace_symbols_fd( + array, size, STDERR_FILENO); // Print stack trace to stderr } bool operator<(std::shared_ptr const &lhs, @@ -220,7 +221,6 @@ int RequestManager::get_max_kv_cache_size() { return max_kv_cache_size; } - void RequestManager::set_decoding_mode(DecodingMode mode) { assert(mode == INCREMENTAL_DECODING || mode == SPECULATIVE_DECODING); decoding_mode = mode; @@ -1175,19 +1175,22 @@ BatchConfig RequestManager::prepare_llm_prefilling_batch() { request->tokens[request->llm_prefill_len + idx]; assert(request->llm_prefill_len + idx < request->tokens.size()); - append_token_to_block(*request, request->tokens[request->llm_prefill_len + idx], true); + append_token_to_block( + *request, request->tokens[request->llm_prefill_len + idx], true); } num_tokens += num_tokens_in_batch; if (num_tokens_in_batch > 0) { bc.num_available_requests++; } - //update related page info in batch config - bc.requestsInfo[request_index].num_kv_pages = get_num_blocks_allocated(*request); + // update related page info in batch config + bc.requestsInfo[request_index].num_kv_pages = + get_num_blocks_allocated(*request); if (bc.requestsInfo[request_index].num_kv_pages == 0) { // turn this request into not available for one round bc.request_available[request_index] = false; } - bc.requestsInfo[request_index].kv_last_page_len = get_len_last_block(*request); + bc.requestsInfo[request_index].kv_last_page_len = + get_len_last_block(*request); bc.requestsInfo[request_index].request_guid = request->guid; } bc.num_tokens = num_tokens; @@ -1577,11 +1580,12 @@ BatchConfig RequestManager::prepare_verify_batch_config() { Request &request = all_requests[guid]; assert(request.status == Request::RUNNING); - //before commit token, reset the pages assigned by cleaning all the tokens - std::vector block_table_before_commit = page_manager->get_block_table_indices(guid); + // before commit token, reset the pages assigned by cleaning all the tokens + std::vector block_table_before_commit = + page_manager->get_block_table_indices(guid); // also need to reset the pages reset_block_table(request); - + int token_offset = request.first_token_offset_in_batch; // 1. 
Maintain requestsInfo @@ -1602,21 +1606,25 @@ BatchConfig RequestManager::prepare_verify_batch_config() { committed_token_index++) { Request::CommittedToken &committed_token = committed_tokens.at(committed_token_index); - - int idx_to_physical = append_token_to_block(request, committed_token.token_id, true); + + int idx_to_physical = + append_token_to_block(request, committed_token.token_id, true); int idx_from_logical = committed_token.from_index; assert(idx_from_logical >= 0); assert(idx_from_logical / kPagesize < block_table_before_commit.size()); - int idx_from_physical = block_table_before_commit[idx_from_logical / kPagesize] * kPagesize + committed_token.from_index % kPagesize; - + int idx_from_physical = + block_table_before_commit[idx_from_logical / kPagesize] * kPagesize + + committed_token.from_index % kPagesize; new_bc.committed_tokens[new_bc.num_tokens_to_commit].request_index = request_index; - new_bc.committed_tokens[new_bc.num_tokens_to_commit].index_in_kv_cache = idx_from_physical; - new_bc.committed_tokens[new_bc.num_tokens_to_commit].token_depth = idx_to_physical; + new_bc.committed_tokens[new_bc.num_tokens_to_commit].index_in_kv_cache = + idx_from_physical; + new_bc.committed_tokens[new_bc.num_tokens_to_commit].token_depth = + idx_to_physical; new_bc.num_tokens_to_commit++; // also append the token to the block - } + } // Load the tokens on the token tree that are not yet pruned to // BatchConfig.tokensInfo. @@ -1654,9 +1662,11 @@ BatchConfig RequestManager::prepare_verify_batch_config() { new_bc.streamingCacheInfo[request_index] = request.streaming_cache_info; // page attention information - new_bc.requestsInfo[request_index].num_kv_pages = get_num_blocks_allocated(request); + new_bc.requestsInfo[request_index].num_kv_pages = + get_num_blocks_allocated(request); assert(new_bc.requestsInfo[request_index].num_kv_pages > 0); - new_bc.requestsInfo[request_index].kv_last_page_len = get_len_last_block(request); + new_bc.requestsInfo[request_index].kv_last_page_len = + get_len_last_block(request); assert(new_bc.requestsInfo[request_index].kv_last_page_len > 0); new_bc.requestsInfo[request_index].request_guid = request.guid; } @@ -1978,7 +1988,9 @@ BatchConfig::BitMask RequestManager::create_llm_bitmask(RequestGuid guid) { /* --------- Page Attention Related Functions --------- */ int RequestManager::get_num_blocks_allocated(Request &request) const { // needs some assertion - assert(request.blocks.size() == PageManager::get_page_manager()->get_block_table_indices(request.guid).size()); + assert(request.blocks.size() == PageManager::get_page_manager() + ->get_block_table_indices(request.guid) + .size()); return request.blocks.size(); } @@ -1995,57 +2007,68 @@ int RequestManager::get_idx_last_logical_token(Request &request) const { if (request.blocks.empty()) { printf("Error: request.blocks is empty\n"); return -1; - }else{ - return (request.blocks.size() - 1) * kPagesize + request.blocks.back().get_num_tokens() - 1; + } else { + return (request.blocks.size() - 1) * kPagesize + + request.blocks.back().get_num_tokens() - 1; } } int RequestManager::idx_logical_to_physical(Request &request, int idx_logical) { // get physical indices PageManager *page_manager = PageManager::get_page_manager(); - std::vector block_table_indices = page_manager->get_block_table_indices(request.guid); + std::vector block_table_indices = + page_manager->get_block_table_indices(request.guid); if (request.blocks.size() != block_table_indices.size()) { assert(request.blocks.size() == 
block_table_indices.size()); } - return block_table_indices[idx_logical / kPagesize] * kPagesize + idx_logical % kPagesize; + return block_table_indices[idx_logical / kPagesize] * kPagesize + + idx_logical % kPagesize; } // this will allocate one logical block and one physical block to the request -void RequestManager::_append_block_to_request( - Request &request, bool is_commit) { +void RequestManager::_append_block_to_request(Request &request, + bool is_commit) { PageManager *page_manager = PageManager::get_page_manager(); assert(request.page_last_committed < static_cast(request.blocks.size())); - assert(request.blocks.size() == page_manager->get_block_table_indices(request.guid).size()); + assert(request.blocks.size() == + page_manager->get_block_table_indices(request.guid).size()); // Append the logical block to the request - // page attention: in this function we need to remember the last logical block number that still contains committed tokens - LogicalTokenBlock block(request.blocks.size(), - kPagesize); + // page attention: in this function we need to remember the last logical block + // number that still contains committed tokens + LogicalTokenBlock block(request.blocks.size(), kPagesize); request.blocks.push_back(block); page_manager->allocate_one_block(request.guid); - std::vector block_table_indices = page_manager->get_block_table_indices(request.guid); - assert(request.blocks.size() == page_manager->get_block_table_indices(request.guid).size()); + std::vector block_table_indices = + page_manager->get_block_table_indices(request.guid); + assert(request.blocks.size() == + page_manager->get_block_table_indices(request.guid).size()); // update page_id_commit if (is_commit) { request.page_last_committed++; int size_blocks = request.blocks.size(); - assert(request.page_last_committed < static_cast(request.blocks.size())); + assert(request.page_last_committed < + static_cast(request.blocks.size())); } } -//this function is used for appending a token to the last logical block and also the last physical block -//it will return the physical position of this token -int RequestManager::append_token_to_block(Request &request, TokenId token, bool is_commit) { +// this function is used for appending a token to the last logical block and +// also the last physical block it will return the physical position of this +// token +int RequestManager::append_token_to_block(Request &request, + TokenId token, + bool is_commit) { PageManager *page_manager = PageManager::get_page_manager(); - if (request.blocks.empty() || - request.blocks.back().is_full()) { + if (request.blocks.empty() || request.blocks.back().is_full()) { // Append a new logical block _append_block_to_request(request, is_commit); - assert(request.blocks.size() == page_manager->get_block_table_indices(request.guid).size()); + assert(request.blocks.size() == + page_manager->get_block_table_indices(request.guid).size()); // also allocate one physical page } // insert token to both logical block and physical block request.blocks.back().append_tokens({token}, is_commit); - assert(request.blocks.size() == page_manager->get_block_table_indices(request.guid).size()); + assert(request.blocks.size() == + page_manager->get_block_table_indices(request.guid).size()); int idx_logical = get_idx_last_logical_token(request); assert(idx_logical >= 0); int idx_physical = idx_logical_to_physical(request, idx_logical); @@ -2053,36 +2076,47 @@ int RequestManager::append_token_to_block(Request &request, TokenId token, bool return idx_physical; } -void 
RequestManager::reset_block_table(Request &request){ +void RequestManager::reset_block_table(Request &request) { // get the indices of original physical block table for request PageManager *page_manager = PageManager::get_page_manager(); assert(request.page_last_committed < static_cast(request.blocks.size())); - assert(request.blocks.size() == page_manager->get_block_table_indices(request.guid).size()); - std::vector block_table_indices = page_manager->get_block_table_indices(request.guid); + assert(request.blocks.size() == + page_manager->get_block_table_indices(request.guid).size()); + std::vector block_table_indices = + page_manager->get_block_table_indices(request.guid); // reset the block table according to the request's page_last_commit assert(block_table_indices.size() > request.page_last_committed); - page_manager->free_multiple_blocks(request.guid, block_table_indices.size() - request.page_last_committed - 1); + page_manager->free_multiple_blocks(request.guid, + block_table_indices.size() - + request.page_last_committed - 1); // reset this request's logical block table if (request.page_last_committed < static_cast(request.blocks.size())) { - request.blocks.erase(request.blocks.begin() + request.page_last_committed + 1, request.blocks.end()); + request.blocks.erase(request.blocks.begin() + request.page_last_committed + + 1, + request.blocks.end()); } request.blocks.back().reset_num_spec_tokens(); // the indices of block table should be the same as the number of blocks - std::vector block_table = page_manager->get_block_table_indices(request.guid); + std::vector block_table = + page_manager->get_block_table_indices(request.guid); - assert(request.blocks.size() == page_manager->get_block_table_indices(request.guid).size()); + assert(request.blocks.size() == + page_manager->get_block_table_indices(request.guid).size()); return; } // debug function void RequestManager::print_num_tokens(Request &request) { PageManager *page_manager = PageManager::get_page_manager(); - std::vector block_table_indices = page_manager->get_block_table_indices(request.guid); + std::vector block_table_indices = + page_manager->get_block_table_indices(request.guid); printf("number of blocks: %d", request.blocks.size()); printf(" number of pages allocated: %d", block_table_indices.size()); printf(" last page length: %d", request.blocks.back().get_num_tokens()); - printf(" last page spec tokens: %d", request.blocks.back().get_num_spec_tokens()); - printf(" last page commit tokens: %d\n", request.blocks.back().get_num_commit_tokens()); + printf(" last page spec tokens: %d", + request.blocks.back().get_num_spec_tokens()); + printf(" last page commit tokens: %d\n", + request.blocks.back().get_num_commit_tokens()); } /* --------- Bitmask Related Functions --------- */ @@ -2490,6 +2524,8 @@ void RequestManager::background_serving_task( Runtime *runtime) { RequestManager *rm = RequestManager::get_request_manager(); FFModel *llm = *(FFModel **)task->args; + printf("start background serving task and llm has %d num_transfor_layers\n", + llm->num_transformer_layers); { // Update FFModel's lg_hlr and lg_ctx to the current // task's runtime and ctx, since all future legion tasks are @@ -2504,6 +2540,8 @@ void RequestManager::background_serving_task( ssm->config.lg_ctx = ctx; } } + // page attention: initalize the page manager here + PageManager::get_page_manager(llm, rm->get_max_kv_cache_size()); if (rm->decoding_mode == INCREMENTAL_DECODING) { // No SSMs: perform incremental decoding rm->serve_decoding(llm); @@ -2759,9 
+2797,6 @@ void RequestManager::terminate_background_server() { } } - - - std::string latency_per_request_ms = "\n latency_per_request_ms( "; for (auto const &profiling_info : profiling_requests) { double latency_ms = (profiling_info.second.finish_time - diff --git a/src/runtime/request_manager.cu b/src/runtime/request_manager.cu index db2718b9b6..b762fedd2c 100644 --- a/src/runtime/request_manager.cu +++ b/src/runtime/request_manager.cu @@ -94,29 +94,37 @@ void prepare_inference_params_kernel_h(BatchConfig const *batch_config, qk_indptr_h[0] = 0; int q_lens = 0, qk_lens = 0; int indices_offset = 0, indices_lens = 0; - for (int req_idx = 0, indptr_idx = 0; req_idx < batch_config->max_requests_per_batch(); req_idx++) { + for (int req_idx = 0, indptr_idx = 0; + req_idx < batch_config->max_requests_per_batch(); + req_idx++) { if (batch_config->request_available[req_idx]) { int q_len = batch_config->requestsInfo[req_idx].num_tokens_in_batch; - int kv_len = batch_config->requestsInfo[req_idx].num_tokens_in_batch + - batch_config->requestsInfo[req_idx].first_token_index_in_request; - + int kv_len = + batch_config->requestsInfo[req_idx].num_tokens_in_batch + + batch_config->requestsInfo[req_idx].first_token_index_in_request; + q_lens += q_len; qk_lens += (q_len * kv_len + 7) / 8; indices_offset = indices_lens; indices_lens += (kv_len + kPagesize - 1) / kPagesize; q_indptr_h[indptr_idx + 1] = q_indptr_h[indptr_idx] + q_len; - kv_indptr_h[indptr_idx + 1] = batch_config->requestsInfo[req_idx].num_kv_pages + kv_indptr_h[indptr_idx]; + kv_indptr_h[indptr_idx + 1] = + batch_config->requestsInfo[req_idx].num_kv_pages + + kv_indptr_h[indptr_idx]; assert(kv_indptr_h[indptr_idx] >= 0); - assert(batch_config->requestsInfo[req_idx].num_kv_pages == (kv_len + kPagesize - 1) / kPagesize); + assert(batch_config->requestsInfo[req_idx].num_kv_pages == + (kv_len + kPagesize - 1) / kPagesize); assert(batch_config->requestsInfo[req_idx].kv_last_page_len <= kPagesize); - std::vector kv_indices = pm -> get_block_table_indices(batch_config->requestsInfo[req_idx].request_guid); + std::vector kv_indices = pm->get_block_table_indices( + batch_config->requestsInfo[req_idx].request_guid); assert(kv_indices.size() == (kv_len + kPagesize - 1) / kPagesize); for (int i = indices_offset; i < indices_lens; i++) { kv_indices_h[i] = kv_indices[i - indices_offset]; } qk_indptr_h[indptr_idx + 1] = qk_lens; - kv_last_page_len_h[indptr_idx] = batch_config->requestsInfo[req_idx].kv_last_page_len; + kv_last_page_len_h[indptr_idx] = + batch_config->requestsInfo[req_idx].kv_last_page_len; indptr_idx++; } } @@ -127,11 +135,12 @@ void prepare_inference_params_kernel_h(BatchConfig const *batch_config, sizeof(int32_t) * batch_size * max_num_pages, cudaMemcpyHostToDevice, stream)); - checkCUDA(cudaMemcpyAsync(handle.tree_verify_attention_metadata->kv_last_page_len, - kv_last_page_len_h, - sizeof(int32_t) * batch_size, - cudaMemcpyHostToDevice, - stream)); + checkCUDA( + cudaMemcpyAsync(handle.tree_verify_attention_metadata->kv_last_page_len, + kv_last_page_len_h, + sizeof(int32_t) * batch_size, + cudaMemcpyHostToDevice, + stream)); checkCUDA(cudaMemcpyAsync(handle.tree_verify_attention_metadata->q_indptr, q_indptr_h, sizeof(int32_t) * (batch_size + 1), @@ -675,8 +684,10 @@ void RequestManager::load_batch_config_task( } } else if (batch_config->get_mode() == TREE_VERIFY_MODE) { PageManager *pm = PageManager::get_page_manager(); - static int32_t q_indptr_h[BatchConfig::MAX_NUM_REQUESTS + 1], kv_indptr_h[BatchConfig::MAX_NUM_REQUESTS + 1]; - static 
int32_t kv_indices_h[BatchConfig::MAX_NUM_REQUESTS * BatchConfig::MAX_NUM_TOKENS];
+    static int32_t q_indptr_h[BatchConfig::MAX_NUM_REQUESTS + 1],
+        kv_indptr_h[BatchConfig::MAX_NUM_REQUESTS + 1];
+    static int32_t kv_indices_h[BatchConfig::MAX_NUM_REQUESTS *
+                                BatchConfig::MAX_NUM_TOKENS];
     static int32_t qk_indptr_h[BatchConfig::MAX_NUM_REQUESTS + 1];
     static int32_t kv_last_page_len_h[BatchConfig::MAX_NUM_REQUESTS];
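
Note on the arrays above: prepare_inference_params_kernel_h in src/runtime/request_manager.cu packs each available request's paging state into FlashInfer-style CSR arrays. kv_indptr gives the request's range into kv_indices, kv_indices holds the physical page ids returned by PageManager::get_block_table_indices, and kv_last_page_len records how many tokens occupy the request's last page. What follows is a minimal host-side sketch of that bookkeeping, not code from this patch; RequestView and the concrete page ids are hypothetical, and kPagesize is assumed to be 64.

#include <cstdio>
#include <vector>

constexpr int kPagesize = 64; // assumed page granularity, matching kPagesize in the patch

// Hypothetical, simplified view of the per-request fields BatchConfig carries.
struct RequestView {
  int first_token_index_in_request; // tokens already resident in the kv cache
  int num_tokens_in_batch;          // new tokens added this step
  std::vector<int> block_table;     // physical page ids from the PageManager
};

int main() {
  // Two toy requests; the page ids are made up.
  std::vector<RequestView> reqs = {
      {60, 4, {7}},       // 64 tokens total -> 1 page, last page exactly full
      {128, 3, {9, 2, 5}} // 131 tokens total -> 3 pages, 3 tokens in the last page
  };

  std::vector<int> kv_indptr{0}, kv_indices, kv_last_page_len;
  for (auto const &r : reqs) {
    int kv_len = r.first_token_index_in_request + r.num_tokens_in_batch;
    int num_pages = (kv_len + kPagesize - 1) / kPagesize; // same ceil-divide as the asserts
    kv_indptr.push_back(kv_indptr.back() + num_pages);
    kv_indices.insert(kv_indices.end(), r.block_table.begin(), r.block_table.end());
    int last = kv_len % kPagesize;
    kv_last_page_len.push_back(last == 0 ? kPagesize : last);
  }

  for (size_t i = 0; i < kv_last_page_len.size(); ++i) {
    printf("request %zu: kv_indices[%d, %d), last page holds %d tokens\n",
           i, kv_indptr[i], kv_indptr[i + 1], kv_last_page_len[i]);
  }
  return 0;
}

The ceil-divide by kPagesize mirrors the patch's assertion that num_kv_pages == (kv_len + kPagesize - 1) / kPagesize.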
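
The kv cache itself is laid out as [num_pages, 2, page_size, num_kv_heads, head_dim], with a page's K rows followed by its V rows. get_k_entry_offset_verify and get_v_entry_offset_verify in inc_multihead_self_attention_kernels.h map a token's absolute index and its physical page id to element offsets in that buffer, and commit_tokens_kernel uses the same offsets twice per committed token: once for the speculative slot it copies from (index_in_kv_cache) and once for the committed slot it copies to (token_depth). A standalone sketch of the same arithmetic, again assuming kPagesize is 64, follows.

#include <cstddef>
#include <cstdio>

// Assumed page granularity; the real value is kPagesize in the FlexFlow headers.
constexpr int kPagesize = 64;

// Element offset of token `token_idx`'s K row in physical page `page_idx`,
// for a cache laid out as [num_pages, 2, page_size, num_kv_heads, head_dim].
size_t k_entry_offset(int token_idx, int page_idx, int num_heads, int head_dim) {
  return (static_cast<size_t>(page_idx) * kPagesize * 2 + token_idx % kPagesize) *
         head_dim * num_heads;
}

// Same token, but in the V half of the page, which starts kPagesize rows later.
size_t v_entry_offset(int token_idx, int page_idx, int num_heads, int head_dim) {
  return (static_cast<size_t>(page_idx) * kPagesize * 2 + kPagesize +
          token_idx % kPagesize) *
         head_dim * num_heads;
}

int main() {
  // Example: absolute token 70 of a request whose second logical page maps to
  // physical page 5, so 70 % 64 == 6 is its slot within that page.
  int token_idx = 70, page_idx = 5, num_kv_heads = 8, head_dim = 128;
  printf("K offset: %zu\n", k_entry_offset(token_idx, page_idx, num_kv_heads, head_dim));
  printf("V offset: %zu\n", v_entry_offset(token_idx, page_idx, num_kv_heads, head_dim));
  return 0;
}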