diff --git a/conda/flexflow.yml b/conda/flexflow.yml
index 091ba929e4..771b40ecd5 100644
--- a/conda/flexflow.yml
+++ b/conda/flexflow.yml
@@ -16,9 +16,9 @@ dependencies:
     - qualname>=0.1.0
     - keras_preprocessing>=1.1.2
     - numpy>=1.16.0
-    - torch>=1.13.1 --index-url https://download.pytorch.org/whl/cpu
-    - torchaudio>=0.13.1 --index-url https://download.pytorch.org/whl/cpu
-    - torchvision>=0.14.1 --index-url https://download.pytorch.org/whl/cpu
+    - torch>=1.13.1
+    - torchaudio>=0.13.1
+    - torchvision>=0.14.1
     - regex
     - onnx
     - transformers>=4.31.0
diff --git a/include/flexflow/flexflow_c.h b/include/flexflow/flexflow_c.h
index 52f67d8efb..6501b0658c 100644
--- a/include/flexflow/flexflow_c.h
+++ b/include/flexflow/flexflow_c.h
@@ -653,6 +653,7 @@ void flexflow_model_generate(flexflow_model_t handle_,
                              char **output_texts,
                              int *max_lengths,
                              int *max_new_tokens_,
+                             bool *add_special_tokens_,
                              flexflow_peft_model_id_t *peft_model_ids,
                              char const **dataset_filepaths,
                              int *training_steps,
@@ -1019,6 +1020,9 @@ void flexflow_request_manager_set_max_spec_tree_token_num(
 void flexflow_request_manager_set_max_sequence_length(
     flexflow_request_manager_t handle_, int max_seq_length);
 
+int flexflow_request_manager_get_max_sequence_length(
+    flexflow_request_manager_t handle_);
+
 void flexflow_request_manager_set_enable_peft_finetuning(
     flexflow_request_manager_t handle_, bool enable_peft_finetuning_);
 
@@ -1026,7 +1030,8 @@ void flexflow_request_manager_register_tokenizer(
     flexflow_request_manager_t handle_,
     enum ModelType model_type,
     int bos_token_id,
-    int eos_token_id,
+    int num_eos_token_ids,
+    int *eos_token_ids,
     char const *tokenizer_filepath);
 
 void flexflow_request_manager_register_output_filepath(
diff --git a/include/flexflow/request_manager.h b/include/flexflow/request_manager.h
index 94bfc74244..d62b610f3d 100644
--- a/include/flexflow/request_manager.h
+++ b/include/flexflow/request_manager.h
@@ -69,6 +69,7 @@ struct Request {
   PEFTModelID peft_model_id = PEFTModelID::NO_ID;
   int max_length = -1;
   int max_new_tokens = -1;
+  bool add_special_tokens = true;
   int initial_len;
   int ssm_cache_size = 0;
   int llm_cache_size = 0;
@@ -146,7 +147,7 @@ class RequestManager {
   int register_ssm_model(FFModel *model);
   void register_tokenizer(ModelType model_type,
                           int bos_token_id,
-                          int eos_token_id,
+                          std::vector<int> eos_token_ids,
                           std::string const &path);
   void register_output_filepath(std::string const &);
   void initBitMask(BatchConfig::BitMask &bitmask, int initLength);
@@ -178,6 +179,7 @@ class RequestManager {
   bool is_request_completed(RequestGuid const &guid);
   void trigger_request_completion_future(RequestGuid const &guid);
   // Methods for preparing next batches
+  bool is_eos_token(int token_id);
   bool check_inf_req_completion(BatchConfig const &old_bc, int i);
   void check_batch(BatchConfig const &old_bc, BatchConfig const &new_bc);
   BatchConfig prepare_next_batch(BatchConfig const &bc,
@@ -301,7 +303,7 @@ class RequestManager {
   bool verbose;
   ModelType model_type;
   int bos_token_id;
-  int eos_token_id;
+  std::vector<int> eos_token_ids;
   bool old_llama_tokenizer = false;
   std::string output_filepath;
   std::queue<Request> pending_infr_request_queue;
diff --git a/inference/incr_decoding/incr_decoding.cc b/inference/incr_decoding/incr_decoding.cc
index f8e16f24fa..f148d440e2 100644
--- a/inference/incr_decoding/incr_decoding.cc
+++ b/inference/incr_decoding/incr_decoding.cc
@@ -199,9 +199,20 @@ void FlexFlow::top_level_task(Task const *task,
   int bos_token_id = model_config.find("bos_token_id") == model_config.end()
                          ? -1
                          : (int)model_config.at("bos_token_id");
-  int eos_token_id = model_config.find("eos_token_id") == model_config.end()
-                         ? -1
-                         : (int)model_config.at("eos_token_id");
+  // parse eos token id, which can be either a single integer or an array of
+  // integers. Convert to std::vector<int>
+  std::vector<int> eos_token_ids;
+  if (model_config.find("eos_token_id") != model_config.end()) {
+    if (model_config["eos_token_id"].is_array()) {
+      for (auto &eos_token_id : model_config["eos_token_id"]) {
+        eos_token_ids.push_back(eos_token_id);
+      }
+    } else {
+      eos_token_ids.push_back(model_config["eos_token_id"]);
+    }
+  } else {
+    eos_token_ids.push_back(-1);
+  }
 
   assert(model_type != ModelType::UNKNOWN &&
          "Invalid LLM model type passed (or no type was passed).");
@@ -212,7 +223,7 @@ void FlexFlow::top_level_task(Task const *task,
   rm->set_max_tokens_per_batch(max_tokens_per_batch);
   rm->set_max_sequence_length(max_sequence_length);
   rm->register_tokenizer(
-      model_type, bos_token_id, eos_token_id, tokenizer_filepath);
+      model_type, bos_token_id, eos_token_ids, tokenizer_filepath);
   rm->register_output_filepath(file_paths.output_file_path);
 
   FFModel model(ffconfig, ffconfig.cpu_offload);
diff --git a/inference/peft/peft.cc b/inference/peft/peft.cc
index 14fc653eba..0ab0b62ee8 100644
--- a/inference/peft/peft.cc
+++ b/inference/peft/peft.cc
@@ -229,9 +229,20 @@ void FlexFlow::top_level_task(Task const *task,
   int bos_token_id = model_config.find("bos_token_id") == model_config.end()
                          ? -1
                          : (int)model_config.at("bos_token_id");
-  int eos_token_id = model_config.find("eos_token_id") == model_config.end()
-                         ? -1
-                         : (int)model_config.at("eos_token_id");
+  // parse eos token id, which can be either a single integer or an array of
+  // integers. Convert to std::vector<int>
+  std::vector<int> eos_token_ids;
+  if (model_config.find("eos_token_id") != model_config.end()) {
+    if (model_config["eos_token_id"].is_array()) {
+      for (auto &eos_token_id : model_config["eos_token_id"]) {
+        eos_token_ids.push_back(eos_token_id);
+      }
+    } else {
+      eos_token_ids.push_back(model_config["eos_token_id"]);
+    }
+  } else {
+    eos_token_ids.push_back(-1);
+  }
 
   assert(model_type != ModelType::UNKNOWN &&
          "Invalid LLM model type passed (or no type was passed).");
@@ -267,7 +278,7 @@ void FlexFlow::top_level_task(Task const *task,
   rm->set_max_tokens_per_batch(max_tokens_per_batch);
   rm->set_max_sequence_length(max_sequence_length);
   rm->register_tokenizer(
-      model_type, bos_token_id, eos_token_id, tokenizer_filepath);
+      model_type, bos_token_id, eos_token_ids, tokenizer_filepath);
   rm->register_output_filepath(file_paths.output_file_path);
 
   rm->set_enable_peft_finetuning(enable_peft_finetuning);
diff --git a/inference/peft/peft_bwd_benchmark.cc b/inference/peft/peft_bwd_benchmark.cc
index df9a1e35db..85e97ec4e8 100644
--- a/inference/peft/peft_bwd_benchmark.cc
+++ b/inference/peft/peft_bwd_benchmark.cc
@@ -230,9 +230,20 @@ void FlexFlow::top_level_task(Task const *task,
   int bos_token_id = model_config.find("bos_token_id") == model_config.end()
                          ? -1
                          : (int)model_config.at("bos_token_id");
-  int eos_token_id = model_config.find("eos_token_id") == model_config.end()
-                         ? -1
-                         : (int)model_config.at("eos_token_id");
+  // parse eos token id, which can be either a single integer or an array of
+  // integers. Convert to std::vector<int>
+  std::vector<int> eos_token_ids;
+  if (model_config.find("eos_token_id") != model_config.end()) {
+    if (model_config["eos_token_id"].is_array()) {
+      for (auto &eos_token_id : model_config["eos_token_id"]) {
+        eos_token_ids.push_back(eos_token_id);
+      }
+    } else {
+      eos_token_ids.push_back(model_config["eos_token_id"]);
+    }
+  } else {
+    eos_token_ids.push_back(-1);
+  }
 
   assert(model_type != ModelType::UNKNOWN &&
          "Invalid LLM model type passed (or no type was passed).");
@@ -251,7 +262,7 @@ void FlexFlow::top_level_task(Task const *task,
   rm->set_max_tokens_per_batch(max_tokens_per_batch);
   rm->set_max_sequence_length(max_sequence_length);
   rm->register_tokenizer(
-      model_type, bos_token_id, eos_token_id, tokenizer_filepath);
+      model_type, bos_token_id, eos_token_ids, tokenizer_filepath);
   rm->register_output_filepath(file_paths.output_file_path);
 
   rm->set_enable_peft_finetuning(enable_peft_finetuning);
diff --git a/inference/peft/peft_fwd_benchmark.cc b/inference/peft/peft_fwd_benchmark.cc
index 9b020f5954..87322a42dd 100644
--- a/inference/peft/peft_fwd_benchmark.cc
+++ b/inference/peft/peft_fwd_benchmark.cc
@@ -230,9 +230,20 @@ void FlexFlow::top_level_task(Task const *task,
   int bos_token_id = model_config.find("bos_token_id") == model_config.end()
                          ? -1
                          : (int)model_config.at("bos_token_id");
-  int eos_token_id = model_config.find("eos_token_id") == model_config.end()
-                         ? -1
-                         : (int)model_config.at("eos_token_id");
+  // parse eos token id, which can be either a single integer or an array of
+  // integers. Convert to std::vector<int>
+  std::vector<int> eos_token_ids;
+  if (model_config.find("eos_token_id") != model_config.end()) {
+    if (model_config["eos_token_id"].is_array()) {
+      for (auto &eos_token_id : model_config["eos_token_id"]) {
+        eos_token_ids.push_back(eos_token_id);
+      }
+    } else {
+      eos_token_ids.push_back(model_config["eos_token_id"]);
+    }
+  } else {
+    eos_token_ids.push_back(-1);
+  }
 
   assert(model_type != ModelType::UNKNOWN &&
          "Invalid LLM model type passed (or no type was passed).");
@@ -251,7 +262,7 @@ void FlexFlow::top_level_task(Task const *task,
   rm->set_max_tokens_per_batch(max_tokens_per_batch);
   rm->set_max_sequence_length(max_sequence_length);
   rm->register_tokenizer(
-      model_type, bos_token_id, eos_token_id, tokenizer_filepath);
+      model_type, bos_token_id, eos_token_ids, tokenizer_filepath);
   rm->register_output_filepath(file_paths.output_file_path);
 
   rm->set_enable_peft_finetuning(enable_peft_finetuning);
diff --git a/inference/peft/req_rate_benchmark.cc b/inference/peft/req_rate_benchmark.cc
index cde3b1c02e..ffa77478e1 100644
--- a/inference/peft/req_rate_benchmark.cc
+++ b/inference/peft/req_rate_benchmark.cc
@@ -292,9 +292,20 @@ void FlexFlow::top_level_task(Task const *task,
   int bos_token_id = model_config.find("bos_token_id") == model_config.end()
                          ? -1
                          : (int)model_config.at("bos_token_id");
-  int eos_token_id = model_config.find("eos_token_id") == model_config.end()
-                         ? -1
-                         : (int)model_config.at("eos_token_id");
+  // parse eos token id, which can be either a single integer or an array of
+  // integers. Convert to std::vector<int>
+  std::vector<int> eos_token_ids;
+  if (model_config.find("eos_token_id") != model_config.end()) {
+    if (model_config["eos_token_id"].is_array()) {
+      for (auto &eos_token_id : model_config["eos_token_id"]) {
+        eos_token_ids.push_back(eos_token_id);
+      }
+    } else {
+      eos_token_ids.push_back(model_config["eos_token_id"]);
+    }
+  } else {
+    eos_token_ids.push_back(-1);
+  }
 
   assert(model_type != ModelType::UNKNOWN &&
          "Invalid LLM model type passed (or no type was passed).");
@@ -313,7 +324,7 @@ void FlexFlow::top_level_task(Task const *task,
   rm->set_max_tokens_per_batch(max_tokens_per_batch);
   rm->set_max_sequence_length(max_sequence_length);
   rm->register_tokenizer(
-      model_type, bos_token_id, eos_token_id, tokenizer_filepath);
+      model_type, bos_token_id, eos_token_ids, tokenizer_filepath);
   rm->register_output_filepath(file_paths.output_file_path);
 
   rm->set_enable_peft_finetuning(enable_peft_finetuning);
diff --git a/inference/python/chat.py b/inference/python/chat.py
new file mode 100644
index 0000000000..13ece116a6
--- /dev/null
+++ b/inference/python/chat.py
@@ -0,0 +1,100 @@
+# Copyright 2023 CMU, Facebook, LANL, MIT, NVIDIA, and Stanford (alphabetical)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import flexflow.serve as ff
+import argparse, json, os
+from types import SimpleNamespace
+
+
+def get_configs():
+    # Define sample configs
+    ff_init_configs = {
+        # required parameters
+        "num_gpus": 1,
+        "memory_per_gpu": 30000,
+        "zero_copy_memory_per_node": 60000,
+        # optional parameters
+        "num_cpus": 4,
+        "legion_utility_processors": 4,
+        "data_parallelism_degree": 1,
+        "tensor_parallelism_degree": 1,
+        "pipeline_parallelism_degree": 1,
+        "offload": False,
+        "offload_reserve_space_size": 8 * 1024,  # 8GB
+        "use_4bit_quantization": False,
+        "use_8bit_quantization": False,
+        "enable_peft": False,
+        "peft_activation_reserve_space_size": 1024,  # 1GB
+        "peft_weight_reserve_space_size": 1024,  # 1GB
+        "profiling": False,
+        "benchmarking": False,
+        "inference_debugging": False,
+        "fusion": True,
+    }
+    llm_configs = {
+        # required parameters
+        "llm_model": "meta-llama/Meta-Llama-3-8B-Instruct",
+        # optional parameters
+        "cache_path": os.environ.get("FF_CACHE_PATH", ""),
+        "refresh_cache": False,
+        "full_precision": False,
+    }
+    # Merge dictionaries
+    ff_init_configs.update(llm_configs)
+    return ff_init_configs
+
+
+def main():
+    configs_dict = get_configs()
+    configs = SimpleNamespace(**configs_dict)
+
+    # Initialize the FlexFlow runtime. ff.init() takes a dictionary or the path to a JSON file with the configs
+    ff.init(configs_dict)
+
+    # Create the FlexFlow LLM
+    ff_data_type = (
+        ff.DataType.DT_FLOAT if configs.full_precision else ff.DataType.DT_HALF
+    )
+    llm = ff.LLM(
+        configs.llm_model,
+        data_type=ff_data_type,
+        cache_path=configs.cache_path,
+        refresh_cache=configs.refresh_cache,
+    )
+
+    # Compile the LLM for inference and load the weights into memory
+    generation_config = ff.GenerationConfig(
+        do_sample=False, temperature=0.9, topp=0.8, topk=1
+    )
+    llm.compile(
+        generation_config,
+        max_requests_per_batch=1,
+        max_seq_length=2048,
+        max_tokens_per_batch=256,
+    )
+
+    llm.start_server()
+
+    messages = [
+        {"role": "system", "content": "You are a helpful and honest programming assistant."},
+        {"role": "user", "content": "Is Rust better than Python?"},
+    ]
+    llm.generate(messages, max_new_tokens=256)
+
+    llm.stop_server()
+
+
+if __name__ == "__main__":
+    print("flexflow inference example (incremental decoding)")
+    main()
diff --git a/inference/spec_infer/spec_infer.cc b/inference/spec_infer/spec_infer.cc
index 134ae70c4a..7ec3cf61f5 100644
--- a/inference/spec_infer/spec_infer.cc
+++ b/inference/spec_infer/spec_infer.cc
@@ -47,7 +47,8 @@ struct ModelMeta {
   std::string llm_weights_path;
   std::string llm_model_config_path;
 
-  int bos_token_id, eos_token_id;
+  int bos_token_id;
+  std::vector<int> eos_token_ids;
 
   std::vector<ModelType> ssm_model_types;
   std::vector<std::string> ssm_model_config_paths;
@@ -191,10 +192,20 @@ void get_model_meta(FilePaths &file_paths,
       llm_model_config.find("bos_token_id") == llm_model_config.end()
           ? -1
           : (int)llm_model_config.at("bos_token_id");
-  model_metadata.eos_token_id =
-      llm_model_config.find("eos_token_id") == llm_model_config.end()
-          ? -1
-          : (int)llm_model_config.at("eos_token_id");
+  // parse eos token id, which can be either a single integer or an array of
+  // integers. Convert to std::vector<int>
+  std::vector<int> eos_token_ids;
+  if (llm_model_config.find("eos_token_id") != llm_model_config.end()) {
+    if (llm_model_config["eos_token_id"].is_array()) {
+      for (auto &eos_token_id : llm_model_config["eos_token_id"]) {
+        model_metadata.eos_token_ids.push_back(eos_token_id);
+      }
+    } else {
+      model_metadata.eos_token_ids.push_back(llm_model_config["eos_token_id"]);
+    }
+  } else {
+    model_metadata.eos_token_ids.push_back(-1);
+  }
 
   for (auto ssm_model_name : model_metadata.model_names.ssm_model_names) {
     std::string ssm_config_path = join_path({file_paths.cache_folder_path,
@@ -241,15 +252,15 @@ void get_model_meta(FilePaths &file_paths,
         ssm_model_config.find("bos_token_id") == ssm_model_config.end()
             ? -1
             : (int)ssm_model_config.at("bos_token_id");
-    int ssm_eos_id =
-        ssm_model_config.find("eos_token_id") == ssm_model_config.end()
-            ? -1
-            : (int)ssm_model_config.at("eos_token_id");
-    if (ssm_bos_id != model_metadata.bos_token_id ||
-        ssm_eos_id != model_metadata.eos_token_id) {
-      printf("Warning: bos/eos token id mismatch between LLM and one of the "
-             "SSMs!\n");
-    }
+    // int ssm_eos_id =
+    //     ssm_model_config.find("eos_token_id") == ssm_model_config.end()
+    //         ?
-1 + // : (int)ssm_model_config.at("eos_token_id"); + // if (ssm_bos_id != model_metadata.bos_token_id || + // ssm_eos_id != model_metadata.eos_token_id) { + // printf("Warning: bos/eos token id mismatch between LLM and one of the " + // "SSMs!\n"); + // } model_metadata.ssm_model_types.push_back(ssm_model_type); model_metadata.ssm_model_config_paths.push_back(ssm_config_path); model_metadata.ssm_model_weights_paths.push_back(ssm_weights_path); @@ -310,7 +321,7 @@ void FlexFlow::top_level_task(Task const *task, rm->set_max_sequence_length(max_sequence_length); rm->register_tokenizer(model_metadata.llm_model_type, model_metadata.bos_token_id, - model_metadata.eos_token_id, + model_metadata.eos_token_ids, model_metadata.llm_tokenizer_path); rm->register_output_filepath(file_paths.output_file_path); diff --git a/python/flexflow/core/flexflow_cffi.py b/python/flexflow/core/flexflow_cffi.py index e2240f0b4f..59e62ea023 100644 --- a/python/flexflow/core/flexflow_cffi.py +++ b/python/flexflow/core/flexflow_cffi.py @@ -1588,7 +1588,12 @@ def register_tokenizer( c_model_type = enum_to_int(ModelType, model_type) c_tokenizer_filepath = get_c_name(tokenizer_filepath) return ffc().flexflow_request_manager_register_tokenizer( - self.handle, c_model_type, bos_token_id, eos_token_id, c_tokenizer_filepath + self.handle, + c_model_type, + bos_token_id, + len(eos_token_id), + eos_token_id, + c_tokenizer_filepath, ) def register_output_filepath(self, output_filepath): @@ -1622,6 +1627,9 @@ def set_max_sequence_length(self, max_length): self.handle, max_length ) + def get_max_sequence_length(self): + return ffc().flexflow_request_manager_get_max_sequence_length(self.handle) + def set_enable_peft_finetuning(self, enable_peft_finetuning): return ffc().flexflow_request_manager_set_enable_peft_finetuning( self.handle, enable_peft_finetuning @@ -2060,6 +2068,7 @@ class Request: prompt: Optional[str] = None max_length: int = -1 max_new_tokens: int = -1 + add_special_tokens: bool = True peft_model_id: Optional[PEFTModelID] = None dataset_filepath: Optional[str] = None max_training_steps: int = 1 @@ -4652,91 +4661,6 @@ def get_output_tensor(self, ffmodel, data_type): assert ret_val == True return np_array - def _estimate_max_num_tokens( - max_length: int, max_new_tokens: int, prompt: Optional[str] - ): - if prompt is None: - assert max_new_tokens == -1 - return ( - math.ceil(max_new_tokens + len(prompt.split()) * 1.5) - if max_new_tokens != -1 - else max_length - ) - - def _estimate_max_num_chars( - max_length: int, max_new_tokens: int, prompt: Optional[str] - ): - return ( - 5 * FFModel._estimate_max_num_tokens(max_length, max_new_tokens, prompt) - + 100 - ) - - # deprecated - def generate_inf_only( - self, - prompt_list: List[str], - max_length: int, - max_new_tokens: int, - ): - if max_length != -1 and max_new_tokens != -1: - raise ValueError( - f"Both `max_new_tokens` (={max_new_tokens}) and `max_length`(={max_length}) seem to have been set." - ) - if max_length == -1 and max_new_tokens == -1: - raise ValueError( - f"Both `max_new_tokens` (={max_new_tokens}) and `max_length`(={max_length}) were left unset." 
- ) - assert isinstance(prompt_list, list) - c_input_texts = [get_c_name(prompt) for prompt in prompt_list] - c_output_texts = [ - ffi.new( - "char[]", - FFModel._estimate_max_num_chars(max_length, max_new_tokens, prompt), - ) - for prompt in prompt_list - ] - c_output_length_and_tokens = [ - ffi.new( - "int[]", - FFModel._estimate_max_num_tokens(max_length, max_new_tokens, prompt) - + 100, - ) - for prompt in prompt_list - ] - c_request_types = [ - enum_to_int(RequestType, RequestType.REQ_INFERENCE) for _ in prompt_list - ] - max_lengths = [max_length for _ in prompt_list] - max_new_tokens_ = [max_new_tokens for _ in prompt_list] - peft_model_ids = [PEFTModelID.no_id_handle() for _ in prompt_list] - dataset_filepaths = [ffi.NULL for _ in prompt_list] - training_steps = [0 for _ in prompt_list] - num_finetuning_losses = ffi.new("int *") - c_finetuning_losses = ffi.new("float[]", 0) - ffc().flexflow_model_generate( - self.handle, - len(prompt_list), - c_request_types, - c_input_texts, - c_output_texts, - max_lengths, - max_new_tokens_, - peft_model_ids, - dataset_filepaths, - training_steps, - c_output_length_and_tokens, - num_finetuning_losses, - c_finetuning_losses, - ) - from flexflow.serve import GenerationResult - - return [ - GenerationResult( - text=ffi.string(c_output_text), tokens=[], finetuning_losses=[] - ) - for c_output_text in c_output_texts - ] - def generate(self, requests_list: List[Request]): assert isinstance(requests_list, list) for request in requests_list: @@ -4756,37 +4680,27 @@ def generate(self, requests_list: List[Request]): raise ValueError( f"Finetuning requests should not have `max_new_tokens` set." ) + max_sequence_length = RequestManager().get_max_sequence_length() c_input_texts = [ get_c_name(request.prompt) for request in requests_list ] # entry will be None for finetuning requests c_output_texts = [ ( - ffi.new( - "char[]", - FFModel._estimate_max_num_chars( - request.max_length, request.max_new_tokens, request.prompt - ), - ) + ffi.new("char[]", max_sequence_length * 5) if request.req_type == RequestType.REQ_INFERENCE else ffi.NULL ) for request in requests_list ] c_output_length_and_tokens = [ - ffi.new( - "int[]", - FFModel._estimate_max_num_tokens( - request.max_length, request.max_new_tokens, request.prompt - ) - + 100, - ) - for request in requests_list + ffi.new("int[]", max_sequence_length + 100) for request in requests_list ] c_request_types = [ enum_to_int(RequestType, request.req_type) for request in requests_list ] max_lengths = [request.max_length for request in requests_list] max_new_tokens_ = [request.max_new_tokens for request in requests_list] + add_special_tokens_ = [request.add_special_tokens for request in requests_list] peft_model_ids = [ ( @@ -4813,6 +4727,7 @@ def generate(self, requests_list: List[Request]): c_output_texts, max_lengths, max_new_tokens_, + add_special_tokens_, peft_model_ids, dataset_filepaths, training_steps, diff --git a/python/flexflow/serve/serve.py b/python/flexflow/serve/serve.py index c8540a6ed3..e4248a2fc1 100644 --- a/python/flexflow/serve/serve.py +++ b/python/flexflow/serve/serve.py @@ -27,7 +27,7 @@ MPTConfig, ) from flexflow.core import * -from transformers import AutoConfig, AutoModelForCausalLM +from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer from peft import PeftModel, PeftConfig, LoraConfig from huggingface_hub import HfApi import torch, shutil, hashlib, json, gc @@ -104,6 +104,7 @@ def __init__( self.output_file = output_file self.rm = None self.pefts = {} + 
self.tokenizer=None def __del__(self): # Stop the background server before deleting the object @@ -499,6 +500,10 @@ def compile( eos_token_id = ( -1 if self.hf_config.eos_token_id is None else self.hf_config.eos_token_id ) + if type(eos_token_id) == int: + eos_token_id = [eos_token_id] + elif type(eos_token_id) != list: + raise ValueError("eos_token_id must be an integer or a list of integers") self.rm.register_tokenizer( self.model_type, bos_token_id, eos_token_id, self.tokenizer_path ) @@ -548,9 +553,29 @@ def _generate(self, requests: List[Request]): ) return self.model.ffmodel.generate(requests) + def __chat2prompt(self, messages: List[dict]): + """Convert a list of messages to a single prompt string + + :param messages: The list of messages to convert + :type messages: List[dict] + :return: The prompt string + :rtype: str + """ + # ensure that each element is a dictionary, containing the "role" and "content" keys + for message in messages: + if type(message) != dict or "role" not in message or "content" not in message: + raise ValueError( + "Each element in the list must be a dictionary with the keys 'role' and 'content'" + ) + if self.tokenizer is None: + self.tokenizer = AutoTokenizer.from_pretrained(self.model_name) + if self.tokenizer.chat_template is None: + raise ValueError(f"Model {self.model_name} does not support chat completion") + return self.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True) + def generate( self, - requests_or_prompts: Union[str, List[str], Request, List[Request]], + requests_or_prompts: Union[str, List[str], List[dict], Request, List[Request]], max_length: int = -1, max_new_tokens: int = -1, ): @@ -591,7 +616,30 @@ def generate( for req in requests_or_prompts ] return self._generate(requests) - else: + elif type(requests_or_prompts[0]) == dict: + prompt = self.__chat2prompt(requests_or_prompts) + request = Request( + req_type=RequestType.REQ_INFERENCE, + prompt=prompt, + max_length=max_length, + max_new_tokens=max_new_tokens, + add_special_tokens=False, + ) + return self._generate([request]) + elif type(requests_or_prompts[0]) == list: + prompts = [self.__chat2prompt(messages) for messages in requests_or_prompts] + requests = [ + Request( + req_type=RequestType.REQ_INFERENCE, + prompt=prompt, + max_length=max_length, + max_new_tokens=max_new_tokens, + add_special_tokens=False, + ) + for prompt in prompts + ] + return self._generate(requests) + elif type(requests_or_prompts[0]) == Request: print(requests_or_prompts) return self._generate(requests_or_prompts) else: diff --git a/src/c/flexflow_c.cc b/src/c/flexflow_c.cc index bfa60a6d54..da90c586e3 100644 --- a/src/c/flexflow_c.cc +++ b/src/c/flexflow_c.cc @@ -1685,6 +1685,7 @@ void flexflow_model_generate(flexflow_model_t handle_, char **output_texts, int *max_lengths, int *max_new_tokens_, + bool *add_special_tokens_, flexflow_peft_model_id_t *peft_model_ids, char const **dataset_filepaths, int *training_steps, @@ -1701,22 +1702,25 @@ void flexflow_model_generate(flexflow_model_t handle_, inference_req.prompt = text_str; inference_req.max_length = max_lengths[i]; inference_req.max_new_tokens = max_new_tokens_[i]; + inference_req.add_special_tokens = add_special_tokens_[i]; PEFTModelID *peft_model_id = FFCObjectWrapper::unwrap(peft_model_ids[i]); if (peft_model_id != nullptr) { inference_req.peft_model_id = *peft_model_id; } requests.push_back(inference_req); - DEBUG_PRINT("[Model] generate[%d] %p %s %i %i", + DEBUG_PRINT("[Model] generate[%d] %p %s %i %i %i", i, handle, 
                  text_str.c_str(),
                  max_lengths[i],
-                 max_new_tokens_[i]);
+                 max_new_tokens_[i],
+                 add_special_tokens_[i]);
     } else if (request_types[i] == RequestType::REQ_FINETUNING) {
       Request fine_tuning_req;
       fine_tuning_req.req_type = RequestType::REQ_FINETUNING;
       fine_tuning_req.max_length = max_lengths[i];
       fine_tuning_req.max_new_tokens = max_new_tokens_[i];
+      fine_tuning_req.add_special_tokens = add_special_tokens_[i];
       PEFTModelID *peft_model_id = FFCObjectWrapper::unwrap(peft_model_ids[i]);
       if (peft_model_id != nullptr) {
         fine_tuning_req.peft_model_id = *peft_model_id;
@@ -1725,12 +1729,13 @@ void flexflow_model_generate(flexflow_model_t handle_,
       fine_tuning_req.dataset_filepath = dataset_fp;
       fine_tuning_req.max_training_steps = training_steps[i];
       requests.push_back(fine_tuning_req);
-      DEBUG_PRINT("[Model] finetune[%d] %p %s %i %i %i",
+      DEBUG_PRINT("[Model] finetune[%d] %p %s %i %i %i %i",
                   i,
                   handle,
                   dataset_fp.c_str(),
                   max_lengths[i],
-                  max_new_tokens[i],
+                  max_new_tokens_[i],
+                  add_special_tokens_[i],
                   training_steps[i]);
     } else {
       assert(false && "Unknown request type");
@@ -2754,6 +2759,12 @@ void flexflow_request_manager_set_max_sequence_length(
   DEBUG_PRINT("[RequestManager] set max_sequence_length %d", max_seq_length);
 }
 
+int flexflow_request_manager_get_max_sequence_length(
+    flexflow_request_manager_t handle_) {
+  RequestManager *handle = FFCObjectWrapper::unwrap(handle_);
+  return handle->get_max_sequence_length();
+}
+
 void flexflow_request_manager_set_enable_peft_finetuning(
     flexflow_request_manager_t handle_, bool enable_peft_finetuning_) {
   RequestManager *handle = FFCObjectWrapper::unwrap(handle_);
@@ -2766,14 +2777,19 @@ void flexflow_request_manager_register_tokenizer(
     flexflow_request_manager_t handle_,
     enum ModelType model_type,
     int bos_token_id,
-    int eos_token_id,
+    int num_eos_token_ids,
+    int *eos_token_ids,
     char const *tokenizer_filepath) {
   RequestManager *handle = FFCObjectWrapper::unwrap(handle_);
   assert(tokenizer_filepath != nullptr &&
          "Cannot convert nullptr char * to std::string");
   std::string const tokenizer_filepath_str(tokenizer_filepath);
+  std::vector<int> eos_token_ids_vec;
+  for (int i = 0; i < num_eos_token_ids; i++) {
+    eos_token_ids_vec.push_back(eos_token_ids[i]);
+  }
   handle->register_tokenizer(
-      model_type, bos_token_id, eos_token_id, tokenizer_filepath_str);
+      model_type, bos_token_id, eos_token_ids_vec, tokenizer_filepath_str);
   DEBUG_PRINT(
       "[RequestManager] register tokenizer %p %s", handle, tokenizer_filepath);
 }
diff --git a/src/runtime/request_manager.cc b/src/runtime/request_manager.cc
index 5fbee65e6d..193abbb455 100644
--- a/src/runtime/request_manager.cc
+++ b/src/runtime/request_manager.cc
@@ -56,6 +56,7 @@ std::ostream &operator<<(std::ostream &os, Request const &req) {
   os << "  peft_model_id: " << req.peft_model_id << "\n";
   os << "  max_length: " << req.max_length << "\n";
   os << "  max_new_tokens: " << req.max_new_tokens << "\n";
+  os << "  add_special_tokens: " << req.add_special_tokens << "\n";
   os << "  initial_len: " << req.initial_len << "\n";
   os << "  ssm_cache_size: " << req.ssm_cache_size << "\n";
   os << "  llm_cache_size: " << req.llm_cache_size << "\n";
@@ -178,11 +179,11 @@ void RequestManager::set_inference_finished(bool finished) {
 
 void RequestManager::register_tokenizer(ModelType type,
                                         int bos_token_id,
-                                        int eos_token_id,
+                                        std::vector<int> eos_token_ids,
                                         std::string const &path) {
   this->model_type = type;
   this->bos_token_id = bos_token_id;
-  this->eos_token_id = eos_token_id;
+  this->eos_token_ids = eos_token_ids;
   std::filesystem::path tokenizer_folder(path);
 
   if (model_type == ModelType::LLAMA) {
@@ -271,6 +272,7 @@ RequestManager::RequestGuid
   request.guid = next_available_guid++;
   request.max_length = request_.max_length;
   request.max_new_tokens = request_.max_new_tokens;
+  request.add_special_tokens = request_.add_special_tokens;
   // both unset
   if (request.max_length == -1 && request.max_new_tokens == -1) {
     request.max_length = get_max_sequence_length() - 1;
@@ -285,7 +287,8 @@ RequestManager::RequestGuid
   }
   request.peft_model_id = request_.peft_model_id;
   request.warmup = request_.warmup;
-  if (bos_token_id >= 0 && model_type != ModelType::FALCON) {
+  if (bos_token_id >= 0 && model_type != ModelType::FALCON &&
+      request.add_special_tokens) {
     request.tokens.push_back(bos_token_id);
   }
   if (request_.benchmarking_tokens >= 0) {
@@ -378,6 +381,7 @@ RequestManager::RequestGuid
   request.initial_len = 0;
   request.max_length = request_.max_length;
   request.max_new_tokens = request_.max_new_tokens;
+  request.add_special_tokens = request_.add_special_tokens;
   if (request.max_new_tokens != -1) {
     std::cerr
         << "Error: max_new_tokens is not allowed for PEFT finetuning requests"
@@ -402,7 +406,8 @@ RequestManager::RequestGuid
     request.benchmarking_tokens = request_.benchmarking_tokens;
    std::vector<int32_t> input_tokens;
    std::vector<int32_t> output_tokens;
-    bool bos_added = (bos_token_id >= 0 && model_type != ModelType::FALCON);
+    bool bos_added = (bos_token_id >= 0 && request.add_special_tokens &&
+                      model_type != ModelType::FALCON);
     if (bos_added) {
       input_tokens.push_back(bos_token_id);
     }
@@ -424,7 +429,8 @@ RequestManager::RequestGuid
     std::string output_text("");
    std::vector<int32_t> input_tokens;
     input_tokens = this->tokenizer_->Encode(text);
-    if (bos_token_id >= 0 && model_type != ModelType::FALCON) {
+    if (bos_token_id >= 0 && model_type != ModelType::FALCON &&
+        request.add_special_tokens) {
       input_tokens.insert(input_tokens.begin(), bos_token_id);
     }
    std::vector<int32_t> output_tokens =
@@ -557,6 +563,15 @@ BatchConfig RequestManager::prepare_next_batch_task(
   return rm->prepare_next_batch(*bc, result);
 }
 
+bool RequestManager::is_eos_token(int token_id) {
+  for (int eos_token : eos_token_ids) {
+    if (token_id == eos_token) {
+      return true;
+    }
+  }
+  return false;
+}
+
 bool RequestManager::check_inf_req_completion(BatchConfig const &old_bc,
                                               int i) {
   Request &request = all_requests[old_bc.requestsInfo[i].request_guid];
@@ -564,7 +579,7 @@ bool RequestManager::check_inf_req_completion(BatchConfig const &old_bc,
   // printf("model_type = %d\n", this->model_type);
   if (request.tokens.size() >= old_bc.requestsInfo[i].max_length) {
     request_completed = true;
-  } else if (request.tokens.back() == eos_token_id) {
+  } else if (is_eos_token(request.tokens.back())) {
     // Encounter EOS token id
     request_completed = true;
   }
@@ -673,6 +688,7 @@ BatchConfig RequestManager::prepare_next_batch(BatchConfig const &old_bc,
         // Unlike Huggingface, the sentencepiece C++ library automatically
        // removes the BOS token
         if (model_type == ModelType::LLAMA && old_llama_tokenizer &&
+            request.add_special_tokens &&
             request.tokens.at(0) == bos_token_id) {
           output = " " + output;
         }
@@ -1134,6 +1150,7 @@ BeamSearchBatchConfig
         // Unlike Huggingface, the sentencepiece C++ library automatically
        // removes the BOS token
         if (model_type == ModelType::LLAMA && old_llama_tokenizer &&
+            request.add_special_tokens &&
             request.tokens.at(0) == bos_token_id) {
           output = " " + output;
         }
@@ -1277,6 +1294,7 @@ BeamSearchBatchConfig
         // Unlike Huggingface, the sentencepiece C++ library automatically
        // removes the BOS token
         if
(model_type == ModelType::LLAMA && old_llama_tokenizer && + request.add_special_tokens && request.tokens.at(0) == bos_token_id) { output = " " + output; } @@ -1325,7 +1343,7 @@ BeamSearchBatchConfig // Unlike Huggingface, the sentencepiece C++ library automatically removes // the BOS token if (model_type == ModelType::LLAMA && old_llama_tokenizer && - request.tokens.at(0) == bos_token_id) { + request.add_special_tokens && request.tokens.at(0) == bos_token_id) { output = " " + output; } log_req_mgr.print("Output: %s", output.c_str());
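
Note on the eos_token_id parsing repeated above in incr_decoding.cc, peft.cc, peft_bwd_benchmark.cc, peft_fwd_benchmark.cc, req_rate_benchmark.cc, and spec_infer.cc: HuggingFace config.json files may store "eos_token_id" either as a single integer or as an array of integers, and the patch normalizes both forms to a list (falling back to -1 when the field is absent) before passing it to register_tokenizer(). The sketch below restates that normalization in Python for illustration only; the helper name and the sample token ids are assumptions, not part of the patch.

import json

def parse_eos_token_ids(model_config: dict) -> list:
    # Missing field: fall back to the sentinel -1, as the C++ code above does.
    if "eos_token_id" not in model_config:
        return [-1]
    value = model_config["eos_token_id"]
    # A single integer becomes a one-element list; an array is taken as-is.
    if isinstance(value, int):
        return [value]
    if isinstance(value, list):
        return [int(v) for v in value]
    raise ValueError("eos_token_id must be an integer or a list of integers")

# Example: a Llama-3-style config that declares two EOS ids.
config = json.loads('{"bos_token_id": 128000, "eos_token_id": [128001, 128009]}')
print(parse_eos_token_ids(config))  # [128001, 128009]

With the EOS ids stored as a list, the RequestManager only needs the membership test added as is_eos_token() instead of a single equality check when deciding whether a request has finished.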
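The serve.py changes build chat prompts with the tokenizer's chat template and then submit them with add_special_tokens=False, because the rendered prompt already contains the model's special tokens (for example the BOS marker), so the runtime must not prepend another BOS while tokenizing. A minimal sketch of that flow using the public transformers API; any tokenizer that ships a chat template works, and the gated Meta-Llama-3-8B-Instruct checkpoint is shown only because the chat.py example above uses it.

from transformers import AutoTokenizer

messages = [
    {"role": "system", "content": "You are a helpful and honest programming assistant."},
    {"role": "user", "content": "Is Rust better than Python?"},
]

# Same call used by LLM.__chat2prompt: flatten the messages into one prompt string.
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B-Instruct")
prompt = tokenizer.apply_chat_template(
    messages, tokenize=False, add_generation_prompt=True
)

# The rendered prompt already starts with the template's special tokens, which is
# why the Request built from it sets add_special_tokens=False.
print(prompt[:120])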