Add a tokenizer class to the Python interface
黄宇扬 committed Jun 21, 2024
1 parent 6c1f680 commit fd7eb1a
Showing 4 changed files with 184 additions and 53 deletions.
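The headline change is a Python-side tokenizer class that loads only the tokenizer from a Hugging Face style model folder. A minimal usage sketch, assuming the package is importable as fastllm_pytools and using a placeholder model directory (class and method names are taken from the llm.py diff below):

from fastllm_pytools import llm

# Load only the tokenizer from a HF-style folder (placeholder path)
tok = llm.tokenizer("/path/to/hf_model_dir")

# Render the chat template into a prompt string, then encode it to token ids
prompt = tok.apply_chat_template([
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "Hello"},
])
ids = tok.encode(prompt)
print(prompt)
print(ids)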
2 changes: 2 additions & 0 deletions include/model.h
@@ -18,6 +18,8 @@ namespace fastllm {
std::unique_ptr<basellm> CreateLLMModelFromHF(const std::string &modelPath,
DataType linearDataType,
int groupCnt = -1);

std::unique_ptr<basellm> CreateLLMTokenizerFromHF(const std::string &modelPath);
}

#endif //FASTLLM_MODEL_H
117 changes: 67 additions & 50 deletions src/model.cpp
@@ -373,6 +373,72 @@ namespace fastllm {
}
}

void LoadLLMTokenizerFromHFToModel(const std::string &path, basellm *model) {
std::string error;
std::string tokenizerConfigFile = path + "tokenizer_config.json";
auto tokenizerConfig = json11::Json::parse(ReadAllFile(tokenizerConfigFile), error);
model->weight.tokenizer.SetTokenizerConfig(tokenizerConfig);
std::string tokenizerClass = tokenizerConfig["tokenizer_class"].string_value();
if (tokenizerClass == "PreTrainedTokenizerFast" || tokenizerClass == "Qwen2Tokenizer") {
// PreTrainedTokenizerFast
std::string tokenizerFile = path + "tokenizer.json";
auto tokenizer = json11::Json::parse(ReadAllFile(tokenizerFile), error);
for (auto &it : tokenizer["model"]["vocab"].object_items()) {
model->weight.AddTokenizerWord(it.first, it.second.int_value(), 1.0f);
}
std::map<std::string, int> spTokens;
for (auto &it : tokenizer["added_tokens"].array_items()) {
spTokens[it["content"].string_value()] = it["id"].int_value();
}
model->weight.tokenizer.SetSpecialTokens(spTokens);

if (!tokenizer["decoder"].is_null() && !tokenizer["decoder"]["type"].is_null() &&
tokenizer["decoder"]["type"].string_value() == "ByteLevel") {
model->weight.tokenizer.byteAsChar = true;
}
} else if (tokenizerClass == "ChatGLM4Tokenizer") {
// GLM4's dedicated tokenizer
model->bot_role = " ";
std::vector <std::string> lines, line;
SplitString(ReadAllFile(path + "tokenizer.model"), {'\r', '\n'}, lines);
for (int i = 0; i < lines.size(); i++) {
SplitString(lines[i], {' '}, line);
model->weight.AddTokenizerWord(Base64Decode(line[0]), atoi(line[1].c_str()), 1.0f);
}
std::map<std::string, int> spTokens;
for (auto &it : tokenizerConfig["added_tokens_decoder"].object_items()) {
spTokens[it.second["content"].string_value()] = atoi(it.first.c_str());
}
model->weight.tokenizer.SetSpecialTokens(spTokens);
((ChatGLMModel*)model)->gmask_token_id = model->weight.tokenizer.GetTokenId("[gMASK]");
((ChatGLMModel*)model)->bos_token_id = model->weight.tokenizer.GetTokenId("<sop>");
((ChatGLMModel*)model)->tokenizerClass = tokenizerClass;

// ChatGLM builds prompts by concatenating tokens, so the separator token IDs must be set explicitly
model->pre_prompt = "";
model->user_role = ("<FLM_FIX_TOKEN_" + std::to_string(model->weight.tokenizer.GetTokenId("<|user|>")) + ">\n");
model->bot_role = ("<FLM_FIX_TOKEN_" + std::to_string(model->weight.tokenizer.GetTokenId("<|assistant|>")) + ">\n");
model->history_sep = "";
model->weight.tokenizer.type = Tokenizer::TokenizerType::QWEN;
} else {
ErrorInFastLLM("Unsupport tokenizer_class: " + tokenizerClass);
}
}

// Load the tokenizer from an HF folder
std::unique_ptr<basellm> CreateLLMTokenizerFromHF(const std::string &modelPath) {
std::string error;
std::string path = modelPath;
if (path.back() != '/' && path.back() != '\\') {
path += "/";
}
std::string configFile = path + "config.json";
auto config = json11::Json::parse(ReadAllFile(configFile), error);
basellm *model = CreateModelWithType(config["model_type"].string_value());
LoadLLMTokenizerFromHFToModel(path, model);
return std::unique_ptr<fastllm::basellm> (model);
}

// Load from an HF folder; only models in safetensors format are supported
std::unique_ptr <basellm> CreateLLMModelFromHF(const std::string &modelPath,
DataType linearDataType, int groupCnt) {
@@ -428,56 +494,7 @@ namespace fastllm {
}

// 3. Load the tokenizer
std::string tokenizerConfigFile = path + "tokenizer_config.json";
auto tokenizerConfig = json11::Json::parse(ReadAllFile(tokenizerConfigFile), error);
model->weight.tokenizer.SetTokenizerConfig(tokenizerConfig);
std::string tokenizerClass = tokenizerConfig["tokenizer_class"].string_value();
if (tokenizerClass == "PreTrainedTokenizerFast" || tokenizerClass == "Qwen2Tokenizer") {
// PreTrainedTokenizerFast
std::string tokenizerFile = path + "tokenizer.json";
auto tokenizer = json11::Json::parse(ReadAllFile(tokenizerFile), error);
auto tokenizerModel = tokenizer["model"];
auto vocab = tokenizerModel["vocab"];
for (auto &it : vocab.object_items()) {
model->weight.AddTokenizerWord(it.first, it.second.int_value(), 1.0f);
}
std::map<std::string, int> spTokens;
for (auto &it : tokenizer["added_tokens"].array_items()) {
spTokens[it["content"].string_value()] = it["id"].int_value();
}
model->weight.tokenizer.SetSpecialTokens(spTokens);

if (!tokenizer["decoder"].is_null() && !tokenizer["decoder"]["type"].is_null() &&
tokenizer["decoder"]["type"].string_value() == "ByteLevel") {
model->weight.tokenizer.byteAsChar = true;
}
} else if (tokenizerClass == "ChatGLM4Tokenizer") {
// GLM4's dedicated tokenizer
model->bot_role = " ";
std::vector <std::string> lines, line;
SplitString(ReadAllFile(path + "tokenizer.model"), {'\r', '\n'}, lines);
for (int i = 0; i < lines.size(); i++) {
SplitString(lines[i], {' '}, line);
model->weight.AddTokenizerWord(Base64Decode(line[0]), atoi(line[1].c_str()), 1.0f);
}
std::map<std::string, int> spTokens;
for (auto &it : tokenizerConfig["added_tokens_decoder"].object_items()) {
spTokens[it.second["content"].string_value()] = atoi(it.first.c_str());
}
model->weight.tokenizer.SetSpecialTokens(spTokens);
((ChatGLMModel*)model)->gmask_token_id = model->weight.tokenizer.GetTokenId("[gMASK]");
((ChatGLMModel*)model)->bos_token_id = model->weight.tokenizer.GetTokenId("<sop>");
((ChatGLMModel*)model)->tokenizerClass = tokenizerClass;

// ChatGLM builds prompts by concatenating tokens, so the separator token IDs must be set explicitly
model->pre_prompt = "";
model->user_role = ("<FLM_FIX_TOKEN_" + std::to_string(model->weight.tokenizer.GetTokenId("<|user|>")) + ">\n");
model->bot_role = ("<FLM_FIX_TOKEN_" + std::to_string(model->weight.tokenizer.GetTokenId("<|assistant|>")) + ">\n");
model->history_sep = "";
model->weight.tokenizer.type = Tokenizer::TokenizerType::QWEN;
} else {
ErrorInFastLLM("Unsupport tokenizer_class: " + tokenizerClass);
}
LoadLLMTokenizerFromHFToModel(path, model);

// 4. Load the weights
auto tensors = safeTensors.GetSortedItemNames();
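For reference, the new LoadLLMTokenizerFromHFToModel only touches a few files inside the model folder. A rough Python sketch of the same decision logic (not part of this commit; file names and JSON keys mirror the C++ above, error handling omitted):

import base64
import json
import os

def inspect_hf_tokenizer(path):
    # tokenizer_config.json selects the branch via tokenizer_class
    with open(os.path.join(path, "tokenizer_config.json"), encoding="utf-8") as f:
        cfg = json.load(f)
    cls = cfg.get("tokenizer_class", "")
    if cls in ("PreTrainedTokenizerFast", "Qwen2Tokenizer"):
        # vocab and added_tokens come from tokenizer.json
        with open(os.path.join(path, "tokenizer.json"), encoding="utf-8") as f:
            tok = json.load(f)
        return cls, len(tok["model"]["vocab"]), len(tok.get("added_tokens", []))
    if cls == "ChatGLM4Tokenizer":
        # tokenizer.model holds one base64-encoded token plus its id per line
        with open(os.path.join(path, "tokenizer.model"), "rb") as f:
            vocab = [base64.b64decode(line.split()[0]) for line in f if line.strip()]
        return cls, len(vocab), len(cfg.get("added_tokens_decoder", {}))
    raise ValueError("Unsupported tokenizer_class: " + cls)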
110 changes: 107 additions & 3 deletions tools/fastllm_pytools/llm.py
@@ -18,6 +18,9 @@
fastllm_lib.create_llm_model_fromhf.argtypes = [ctypes.c_char_p, ctypes.c_int, ctypes.c_int]
fastllm_lib.create_llm_model_fromhf.restype = ctypes.c_int

fastllm_lib.create_llm_tokenizer_fromhf.argtypes = [ctypes.c_char_p]
fastllm_lib.create_llm_tokenizer_fromhf.restype = ctypes.c_int

fastllm_lib.add_eos_token.argtypes = [ctypes.c_int, ctypes.c_char_p, ctypes.c_int]

fastllm_lib.token_decode.argtypes = [ctypes.c_int, ctypes.c_int, ctypes.c_int, ctypes.c_char_p]
@@ -132,6 +135,104 @@ def from_hf(model,
"float32": 0,
}

class tokenizer:
def __init__ (self, path : str,
id : int = -99999,
system_prompt : str = ""):
self.system_prompt = system_prompt
if (id != -99999):
self.model = id
else:
if os.path.isfile(path):
self.model = fastllm_lib.create_llm_tokenizer(path.encode());
elif os.path.isdir(path):
self.model = fastllm_lib.create_llm_tokenizer_fromhf(path.encode());
else:
print("path error: ", path);
exit(0)
self.thread_local_obj = threading.local()
self.tokenizer_decode_token_cache = None

def apply_chat_template(
self,
conversation: Union[List[Dict[str, str]], List[List[Dict[str, str]]], "Conversation"],
chat_template: Optional[str] = None,
add_generation_prompt: bool = False,
tokenize: bool = True,
#padding: bool = False,
#truncation: bool = False,
#max_length: Optional[int] = None,
#return_tensors: Optional[Union[str, TensorType]] = None,
#return_dict: bool = False,
#tokenizer_kwargs: Optional[Dict[str, Any]] = None,
**kwargs,
) -> Union[str, List[int], List[str], List[List[int]]]:
if isinstance(conversation, (list, tuple)) and (
isinstance(conversation[0], (list, tuple)) or hasattr(conversation[0], "messages")
):
conversations = conversation
is_batched = True
else:
conversations = [conversation]
is_batched = False
strs = []
for conversation in conversations:
messages = []
for it in conversation:
if it["role"] == "system":
messages += ["system", it["content"]]
for it in conversation:
if it["role"] != "system":
messages += [it["role"], it["content"]]
poss = []
lens = []
all = b''
for i in range(len(messages)):
messages[i] = messages[i].encode()
all += messages[i]
poss.append(0 if i == 0 else poss[-1] + lens[-1])
lens.append(len(messages[i]))
strs.append(fastllm_lib.apply_chat_template(self.model, all, len(messages), (ctypes.c_int * len(poss))(*poss), (ctypes.c_int * len(lens))(*lens)).decode())
if (is_batched):
return strs
else:
return strs[0]

def encode(
self,
text: str,
#text_pair: Optional[Union[TextInput, PreTokenizedInput, EncodedInput]] = None,
#add_special_tokens: bool = True,
#padding: Union[bool, str, PaddingStrategy] = False,
#truncation: Union[bool, str, TruncationStrategy] = None,
#max_length: Optional[int] = None,
#stride: int = 0,
#return_tensors: Optional[Union[str, TensorType]] = None,
**kwargs,
) -> List[int]:
content = text
output_buffer_init_len = 1024
if "tokenizer_encode_string__output_buffer" not in dir(self.thread_local_obj) or self.thread_local_obj.tokenizer_encode_string__output_buffer is None:
self.thread_local_obj.tokenizer_encode_string__output_buffer = (ctypes.c_int * output_buffer_init_len)()

buffer = self.thread_local_obj.tokenizer_encode_string__output_buffer
buffer_len = len(buffer)
result_len = fastllm_lib.token_encode_string(self.model, content.encode(), buffer_len, buffer)
if result_len > buffer_len:
if result_len > 10240:
# The input is too long; use a one-off buffer
temp_buffer = (ctypes.c_int * result_len)()
ret = fastllm_lib.token_encode_string(self.model, content.encode(), result_len, temp_buffer)
return [i for i in temp_buffer]
else:
# Grow the buffer
new_buffer_len = round(math.ceil(result_len / 1024.0)) * 1024
buffer = (ctypes.c_int * new_buffer_len)()
self.thread_local_obj.tokenizer_encode_string__output_buffer = buffer
result_len = fastllm_lib.token_encode_string(self.model, content.encode(), new_buffer_len, buffer)

return [buffer[i] for i in range(result_len)]

class model:
def __init__ (self, path : str,
id : int = -99999,
@@ -214,7 +315,7 @@ def build_tokenizer_decode_token_cache(self):
cache_dict[token_id] = self.tokenizer_decode_token(token_id)

self.tokenizer_decode_token_cache = cache_dict

def tokenizer_encode_string(self, content: str) -> List[int]:
output_buffer_init_len = 1024
if "tokenizer_encode_string__output_buffer" not in dir(self.thread_local_obj) or self.thread_local_obj.tokenizer_encode_string__output_buffer is None:
@@ -237,7 +338,10 @@ def tokenizer_encode_string(self, content: str) -> List[int]:
result_len = fastllm_lib.token_encode_string(self.model, content.encode(), new_buffer_len, buffer)

return [buffer[i] for i in range(result_len)]


def encode(self, content: str) -> List[int]:
return self.tokenizer_encode_string(content)

def tokenizer_decode_token(self, token_id: int) -> bytes:
if self.tokenizer_decode_token_cache is not None:
cache_result = self.tokenizer_decode_token_cache.get(token_id)
@@ -454,4 +558,4 @@ def set_save_history(self, save: bool):
fastllm_lib.set_save_history(self.model, save);

def set_atype(self, atype: str):
fastllm_lib.set_model_atype(self.model, str(atype).encode());
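One detail worth noting in the new tokenizer.apply_chat_template above: the conversation is flattened into a single byte buffer plus parallel offset/length arrays before crossing the ctypes boundary. A standalone sketch of that packing step (variable names loosely mirror the Python code; the actual library call is omitted):

messages = ["system", "You are a helpful assistant.", "user", "Hello"]
encoded = [m.encode() for m in messages]
flat = b"".join(encoded)                          # all message bytes, back to back
lens = [len(m) for m in encoded]                  # length of each message in bytes
poss = [sum(lens[:i]) for i in range(len(lens))]  # byte offset where each message starts
# flat, poss and lens (plus the message count) are what
# fastllm_lib.apply_chat_template receives on the C side.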
8 changes: 8 additions & 0 deletions tools/src/pytools.cpp
@@ -109,6 +109,14 @@ extern "C" {
return id;
}

DLL_EXPORT int create_llm_tokenizer_fromhf(char *path) {
models.locker.lock();
int id = models.models.size();
models.models[id] = fastllm::CreateLLMTokenizerFromHF(path);
models.locker.unlock();
return id;
}

DLL_EXPORT int create_empty_llm_model(char *type) {
models.locker.lock();
int id = models.models.size();
