Commit c46ddc0

fix tokenizer conversion

zhihao committed Oct 15, 2024
1 parent dbd4cf1 commit c46ddc0
Showing 2 changed files with 22 additions and 49 deletions.
19 changes: 13 additions & 6 deletions python/flexflow/serve/serve.py
@@ -27,12 +27,12 @@
MPTConfig,
)
from flexflow.core import *
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer, LlamaTokenizer
from transformers import AutoConfig, AutoModelForCausalLM
from peft import PeftModel, PeftConfig, LoraConfig
from huggingface_hub import HfApi
import torch, shutil, hashlib, json, gc
from typing import Union, List

from huggingface_hub import snapshot_download

class _SupportedModels:
def __init__(self,):
@@ -349,10 +349,17 @@ def download_hf_tokenizer_if_needed(self):
print(
f"'{self.model_name}' tokenizer needs updating! Downloading tokenizer now..."
)
# Download tokenizer from HuggingFace, or load it from the local folder
hf_tokenizer = AutoTokenizer.from_pretrained(self.model_name, trust_remote_code=True)
# Save tokenizer
hf_tokenizer.save_pretrained(self.tokenizer_path)
# Load/download the tokenizer files
target_tokenizer_files = ["tokenizer.json", "tokenizer_config.json", "special_tokens_map.json"]
if os.path.exists(self.model_name):
hf_tokenizer_path = self.model_name
else:
hf_tokenizer_path = snapshot_download(repo_id=self.model_name, allow_patterns=target_tokenizer_files)
for file in target_tokenizer_files:
src_path = os.path.join(hf_tokenizer_path, file)
dst_path = os.path.join(self.tokenizer_path, file)
if os.path.exists(src_path):
shutil.copy(src_path, dst_path)
print("Done updating HF tokenizer.")
# Save new revision hash to file
with open(ff_revision_file, "w+") as f:
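Note on the serve.py change: instead of instantiating an AutoTokenizer and re-serializing it, the commit copies the three tokenizer files verbatim, from a local model folder if one exists, otherwise from a patterns-filtered Hub download. A minimal standalone sketch of the same pattern (the function name and destination path are illustrative, not FlexFlow API):

    import os
    import shutil

    from huggingface_hub import snapshot_download

    def fetch_tokenizer_files(model_name: str, dst_dir: str) -> None:
        # Only these three files are needed to rebuild an HF tokenizer.
        targets = ["tokenizer.json", "tokenizer_config.json",
                   "special_tokens_map.json"]
        # A local checkout is used as-is; otherwise download only the
        # files matching the allow_patterns filter.
        if os.path.exists(model_name):
            src_dir = model_name
        else:
            src_dir = snapshot_download(repo_id=model_name,
                                        allow_patterns=targets)
        os.makedirs(dst_dir, exist_ok=True)
        for name in targets:
            src = os.path.join(src_dir, name)
            if os.path.exists(src):  # some models ship only a subset
                shutil.copy(src, os.path.join(dst_dir, name))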
52 changes: 9 additions & 43 deletions src/runtime/request_manager.cc
@@ -186,29 +186,19 @@ void RequestManager::register_tokenizer(ModelType type,
std::filesystem::path tokenizer_folder(path);

if (model_type == ModelType::LLAMA) {
std::filesystem::path tokenizer_model_path;
std::filesystem::path tokenizer_json_path;
if (std::filesystem::is_directory(tokenizer_folder)) {
tokenizer_model_path =
std::filesystem::path(tokenizer_folder) / "tokenizer.model";
tokenizer_json_path =
std::filesystem::path(tokenizer_folder) / "tokenizer.json";
} else {
tokenizer_model_path = tokenizer_folder;
tokenizer_json_path = tokenizer_folder;
}
if (std::filesystem::exists(tokenizer_model_path)) {
// load from tokenizer.model
this->tokenizer_ = Tokenizer::FromBlobSentencePiece(
LoadBytesFromFile(tokenizer_model_path.string()));
} else {
// load from tokenizer.json
std::filesystem::path tokenizer_json_path =
tokenizer_folder / "tokenizer.json";
if (!std::filesystem::exists(tokenizer_json_path)) {
std::cerr << "Failed to open file: " << tokenizer_json_path
<< std::endl;
assert(false);
}
this->tokenizer_ = Tokenizer::FromBlobJSON(
LoadBytesFromFile(tokenizer_json_path.string()));
if (!std::filesystem::exists(tokenizer_json_path)) {
std::cerr << "Failed to open file: " << tokenizer_json_path << std::endl;
assert(false);
}
this->tokenizer_ = Tokenizer::FromBlobJSON(
LoadBytesFromFile(tokenizer_json_path.string()));
} else if (model_type == ModelType::OPT) {
std::filesystem::path vocab_file = tokenizer_folder / "vocab.json";
std::filesystem::path merges_file = tokenizer_folder / "merges.txt";
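Note on the register_tokenizer change: the SentencePiece branch (tokenizer.model via FromBlobSentencePiece) is gone, so LLAMA models now always load tokenizer.json, which the serve.py change above guarantees has been copied into the tokenizer folder. FromBlobJSON reads the HuggingFace tokenizers JSON format, so the equivalent check-and-load looks roughly like this in Python (the folder path is illustrative):

    import os

    from tokenizers import Tokenizer  # same tokenizer.json format FromBlobJSON reads

    tokenizer_json_path = os.path.join("/path/to/tokenizer", "tokenizer.json")
    assert os.path.exists(tokenizer_json_path), \
        f"Failed to open file: {tokenizer_json_path}"
    tok = Tokenizer.from_file(tokenizer_json_path)
    print(tok.decode(tok.encode("hello world").ids))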
@@ -658,12 +648,6 @@ BatchConfig RequestManager::prepare_next_batch(BatchConfig const &old_bc,
bool request_completed = check_inf_req_completion(old_bc, i);
if (request_completed) {
std::string output = this->tokenizer_->Decode(request.tokens);
// Unlike Huggingface, the sentencepiece C++ library automatically
// removes the BOS token
if (model_type == ModelType::LLAMA &&
request.tokens.at(0) == bos_token_id) {
output = "<s> " + output;
}
{
// update generation result
GenerationResult &gr = request_generation_results[request.guid];
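Note on the removed BOS workaround (this hunk and the three below it): the "<s> " re-insertion compensated for the SentencePiece C++ decoder stripping the BOS token on decode; with the tokenizer now always built from tokenizer.json, special-token handling is left to the decode call itself, so all four copies of the workaround are dropped. A sketch of the decode-side behavior, assuming a public LLaMA-style test tokenizer purely for illustration:

    from huggingface_hub import hf_hub_download
    from tokenizers import Tokenizer

    # Assumption: hf-internal-testing/llama-tokenizer stands in for a real model.
    path = hf_hub_download("hf-internal-testing/llama-tokenizer", "tokenizer.json")
    tok = Tokenizer.from_file(path)

    ids = tok.encode("hello world").ids
    if ids[0] != tok.token_to_id("<s>"):
        ids = [tok.token_to_id("<s>")] + ids  # ensure a BOS for the demo
    print(tok.decode(ids))                             # special tokens dropped by default
    print(tok.decode(ids, skip_special_tokens=False))  # keeps "<s>" when desired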
@@ -1119,12 +1103,6 @@ BeamSearchBatchConfig
request.guid,
request.tokens.size());
std::string output = this->tokenizer_->Decode(request.tokens);
// Unlike Huggingface, the sentencepiece C++ library automatically
// removes the BOS token
if (model_type == ModelType::LLAMA &&
request.tokens.at(0) == bos_token_id) {
output = "<s> " + output;
}
{
// update generation result
GenerationResult &gr = request_generation_results[request.guid];
@@ -1262,12 +1240,6 @@ BeamSearchBatchConfig
}

std::string output = this->tokenizer_->Decode(request.tokens);
// Unlike Huggingface, the sentencepiece C++ library automatically
// removes the BOS token
if (model_type == ModelType::LLAMA &&
request.tokens.at(0) == bos_token_id) {
output = "<s> " + output;
}
log_req_mgr.print("Output: %s", output.c_str());
}

@@ -1310,12 +1282,6 @@ BeamSearchBatchConfig

// Token Info
std::string output = this->tokenizer_->Decode(request.tokens);
// Unlike Huggingface, the sentencepiece C++ library automatically removes
// the BOS token
if (model_type == ModelType::LLAMA &&
request.tokens.at(0) == bos_token_id) {
output = "<s> " + output;
}
log_req_mgr.print("Output: %s", output.c_str());
} else {
assert(false);
