Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
zhihao committed Oct 15, 2024
1 parent 2b5a023 commit 98588f2
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 9 deletions.
1 change: 1 addition & 0 deletions include/flexflow/request_manager.h
Original file line number Diff line number Diff line change
Expand Up @@ -302,6 +302,7 @@ class RequestManager {
ModelType model_type;
int bos_token_id;
int eos_token_id;
bool old_llama_tokenizer = false;
std::string output_filepath;
std::queue<Request> pending_infr_request_queue;
std::queue<Request> pending_peft_request_queue;
Expand Down
27 changes: 18 additions & 9 deletions src/runtime/request_manager.cc
Original file line number Diff line number Diff line change
Expand Up @@ -189,23 +189,32 @@ void RequestManager::register_tokenizer(ModelType type,
// try with tokenizer.json first
std::filesystem::path tokenizer_json_path;
if (std::filesystem::is_directory(tokenizer_folder)) {
tokenizer_json_path = std::filesystem::path(tokenizer_folder) / "tokenizer.json";
tokenizer_json_path =
std::filesystem::path(tokenizer_folder) / "tokenizer.json";
} else {
tokenizer_json_path = tokenizer_folder;
}
if (std::filesystem::exists(tokenizer_json_path)) {
old_llama_tokenizer = true;
// load from tokenizer.json
this->tokenizer_ = Tokenizer::FromBlobJSON(LoadBytesFromFile(tokenizer_json_path.string()));
this->tokenizer_ = Tokenizer::FromBlobJSON(
LoadBytesFromFile(tokenizer_json_path.string()));
} else {
// load from tokenizer.model
std::filesystem::path tokenizer_model_path =
tokenizer_folder / "tokenizer.model";
std::filesystem::path tokenizer_model_path;
if (std::filesystem::is_directory(tokenizer_folder)) {
tokenizer_model_path =
std::filesystem::path(tokenizer_folder) / "tokenizer.model";
} else {
tokenizer_model_path = tokenizer_folder;
}
if (!std::filesystem::exists(tokenizer_model_path)) {
std::cerr << "Failed to open file: " << tokenizer_model_path
<< std::endl;
assert(false);
}
this->tokenizer_ = Tokenizer::FromBlobSentencePiece(LoadBytesFromFile(tokenizer_model_path.string()));
this->tokenizer_ = Tokenizer::FromBlobSentencePiece(
LoadBytesFromFile(tokenizer_model_path.string()));
}
} else if (model_type == ModelType::OPT) {
std::filesystem::path vocab_file = tokenizer_folder / "vocab.json";
Expand Down Expand Up @@ -658,7 +667,7 @@ BatchConfig RequestManager::prepare_next_batch(BatchConfig const &old_bc,
std::string output = this->tokenizer_->Decode(request.tokens);
// Unlike Huggingface, the sentencepiece C++ library automatically
// removes the BOS token
if (model_type == ModelType::LLAMA &&
if (model_type == ModelType::LLAMA && old_llama_tokenizer &&
request.tokens.at(0) == bos_token_id) {
output = "<s> " + output;
}
Expand Down Expand Up @@ -1119,7 +1128,7 @@ BeamSearchBatchConfig
std::string output = this->tokenizer_->Decode(request.tokens);
// Unlike Huggingface, the sentencepiece C++ library automatically
// removes the BOS token
if (model_type == ModelType::LLAMA &&
if (model_type == ModelType::LLAMA && old_llama_tokenizer &&
request.tokens.at(0) == bos_token_id) {
output = "<s> " + output;
}
Expand Down Expand Up @@ -1262,7 +1271,7 @@ BeamSearchBatchConfig
std::string output = this->tokenizer_->Decode(request.tokens);
// Unlike Huggingface, the sentencepiece C++ library automatically
// removes the BOS token
if (model_type == ModelType::LLAMA &&
if (model_type == ModelType::LLAMA && old_llama_tokenizer &&
request.tokens.at(0) == bos_token_id) {
output = "<s> " + output;
}
Expand Down Expand Up @@ -1310,7 +1319,7 @@ BeamSearchBatchConfig
std::string output = this->tokenizer_->Decode(request.tokens);
// Unlike Huggingface, the sentencepiece C++ library automatically removes
// the BOS token
if (model_type == ModelType::LLAMA &&
if (model_type == ModelType::LLAMA && old_llama_tokenizer &&
request.tokens.at(0) == bos_token_id) {
output = "<s> " + output;
}
Expand Down

0 comments on commit 98588f2

Please sign in to comment.