diff --git a/python/flexflow/serve/serve.py b/python/flexflow/serve/serve.py index 2b5d307d5c..bfbc2c060d 100644 --- a/python/flexflow/serve/serve.py +++ b/python/flexflow/serve/serve.py @@ -350,7 +350,7 @@ def download_hf_tokenizer_if_needed(self): f"'{self.model_name}' tokenizer needs updating! Downloading tokenizer now..." ) # Load/download the tokenizer files - target_tokenizer_files = ["tokenizer.json", "tokenizer_config.json", "special_tokens_map.json"] + target_tokenizer_files = ["tokenizer.json", "tokenizer_config.json", "special_tokens_map.json", "vocab.json", "merges.txt"] if os.path.exists(self.model_name): hf_tokenizer_path = self.model_name else: diff --git a/src/runtime/request_manager.cc b/src/runtime/request_manager.cc index ec1ea6c7eb..fcc936daa7 100644 --- a/src/runtime/request_manager.cc +++ b/src/runtime/request_manager.cc @@ -186,19 +186,27 @@ void RequestManager::register_tokenizer(ModelType type, std::filesystem::path tokenizer_folder(path); if (model_type == ModelType::LLAMA) { + // try with tokenizer.json first std::filesystem::path tokenizer_json_path; if (std::filesystem::is_directory(tokenizer_folder)) { - tokenizer_json_path = - std::filesystem::path(tokenizer_folder) / "tokenizer.json"; + tokenizer_json_path = std::filesystem::path(tokenizer_folder) / "tokenizer.json"; } else { tokenizer_json_path = tokenizer_folder; } - if (!std::filesystem::exists(tokenizer_json_path)) { - std::cerr << "Failed to open file: " << tokenizer_json_path << std::endl; - assert(false); + if (std::filesystem::exists(tokenizer_json_path)) { + // load from tokenizer.json + this->tokenizer_ = Tokenizer::FromBlobJSON(LoadBytesFromFile(tokenizer_json_path.string())); + } else { + // load from tokenizer.model + std::filesystem::path tokenizer_model_path = + tokenizer_folder / "tokenizer.model"; + if (!std::filesystem::exists(tokenizer_model_path)) { + std::cerr << "Failed to open file: " << tokenizer_model_path + << std::endl; + assert(false); + } + this->tokenizer_ = Tokenizer::FromBlobSentencePiece(LoadBytesFromFile(tokenizer_model_path.string())); } - this->tokenizer_ = Tokenizer::FromBlobJSON( - LoadBytesFromFile(tokenizer_json_path.string())); } else if (model_type == ModelType::OPT) { std::filesystem::path vocab_file = tokenizer_folder / "vocab.json"; std::filesystem::path merges_file = tokenizer_folder / "merges.txt"; @@ -648,6 +656,12 @@ BatchConfig RequestManager::prepare_next_batch(BatchConfig const &old_bc, bool request_completed = check_inf_req_completion(old_bc, i); if (request_completed) { std::string output = this->tokenizer_->Decode(request.tokens); + // Unlike Huggingface, the sentencepiece C++ library automatically + // removes the BOS token + if (model_type == ModelType::LLAMA && + request.tokens.at(0) == bos_token_id) { + output = " " + output; + } { // update generation result GenerationResult &gr = request_generation_results[request.guid]; @@ -1103,6 +1117,12 @@ BeamSearchBatchConfig request.guid, request.tokens.size()); std::string output = this->tokenizer_->Decode(request.tokens); + // Unlike Huggingface, the sentencepiece C++ library automatically + // removes the BOS token + if (model_type == ModelType::LLAMA && + request.tokens.at(0) == bos_token_id) { + output = " " + output; + } { // update generation result GenerationResult &gr = request_generation_results[request.guid]; @@ -1240,6 +1260,12 @@ BeamSearchBatchConfig } std::string output = this->tokenizer_->Decode(request.tokens); + // Unlike Huggingface, the sentencepiece C++ library automatically + // removes the BOS token + if (model_type == ModelType::LLAMA && + request.tokens.at(0) == bos_token_id) { + output = " " + output; + } log_req_mgr.print("Output: %s", output.c_str()); } @@ -1282,6 +1308,12 @@ BeamSearchBatchConfig // Token Info std::string output = this->tokenizer_->Decode(request.tokens); + // Unlike Huggingface, the sentencepiece C++ library automatically removes + // the BOS token + if (model_type == ModelType::LLAMA && + request.tokens.at(0) == bos_token_id) { + output = " " + output; + } log_req_mgr.print("Output: %s", output.c_str()); } else { assert(false);