diff --git a/tools/fastllm_pytools/llm.py b/tools/fastllm_pytools/llm.py
index d7bbb271..7ebdc297 100644
--- a/tools/fastllm_pytools/llm.py
+++ b/tools/fastllm_pytools/llm.py
@@ -1,5 +1,7 @@
 import ctypes;
+import math
 import os;
+import threading
 from typing import Optional, Tuple, Union, List, Callable, Dict, Any;
 import platform
 
@@ -108,6 +110,16 @@ def __init__ (self, path : str,
         self.model = fastllm_lib.create_llm_model(path.encode());
         self.direct_query = False;
 
+        # Thread-local storage used as an object pool, to avoid repeatedly allocating and freeing buffer objects
+        self.thread_local_obj = threading.local()
+        self.thread_local_obj.tokenizer_encode_string__output_buffer = None
+        self.thread_local_obj.tokenizer_decode_token__output_buffer = None
+
+        # Static cache of tokenizer_decode_token results, built manually on demand.
+        # The number of tokens is finite and fairly small, so caching the results to reduce native calls is a good fit.
+        # It is not built automatically, to avoid locking the cache dict under multi-threaded calls and to leave the choice open for different scenarios.
+        self.tokenizer_decode_token_cache = None
+
     def get_prompt(self,
                    query: str,
                    history: List[Tuple[str, str]] = None) -> str:
@@ -125,20 +137,67 @@ def save(self, path : str):
     def eval(self):
         pass;
 
-    def tokenizer_encode_string(self, content):
-        buffer_len = len(content.encode())
-        buffer = (ctypes.c_int * buffer_len)()
+    def build_tokenizer_decode_token_cache(self):
+        if self.tokenizer_decode_token_cache is not None:
+            return
+
+        cache_dict = dict()
+        vocab_size = fastllm_lib.get_tokenizer_vocab_size(self.model)
+        for token_id in range(vocab_size):
+            cache_dict[token_id] = self.tokenizer_decode_token(token_id)
+
+        self.tokenizer_decode_token_cache = cache_dict
+
+    def tokenizer_encode_string(self, content: str) -> List[int]:
+        output_buffer_init_len = 1024
+        if self.thread_local_obj.tokenizer_encode_string__output_buffer is None:
+            self.thread_local_obj.tokenizer_encode_string__output_buffer = (ctypes.c_int * output_buffer_init_len)()
+
+        buffer = self.thread_local_obj.tokenizer_encode_string__output_buffer
+        buffer_len = len(buffer)
         result_len = fastllm_lib.token_encode_string(self.model, content.encode(), buffer_len, buffer)
-        assert 0 < result_len < buffer_len
-        return [i for i in buffer][:result_len]
-
-    def tokenizer_decode_token(self, token_id):
-        assert isinstance(token_id, int)
-        buffer_len = 256
-        buffer = ctypes.create_string_buffer(buffer_len)
-        ret = fastllm_lib.token_decode(self.model, token_id, buffer_len, buffer)
-        assert ret == 0
-        return buffer.raw
+        if result_len > buffer_len:
+            if result_len > 10240:
+                # The input is very long; use a one-off buffer instead of growing the pooled one
+                temp_buffer = (ctypes.c_int * result_len)()
+                ret = fastllm_lib.token_encode_string(self.model, content.encode(), result_len, temp_buffer)
+                return [i for i in temp_buffer]
+            else:
+                # Grow the pooled buffer (rounded up to a multiple of 1024) and retry
+                new_buffer_len = round(math.ceil(result_len / 1024.0)) * 1024
+                buffer = (ctypes.c_int * new_buffer_len)()
+                self.thread_local_obj.tokenizer_encode_string__output_buffer = buffer
+                result_len = fastllm_lib.token_encode_string(self.model, content.encode(), new_buffer_len, buffer)
+
+        return [buffer[i] for i in range(result_len)]
+
+    def tokenizer_decode_token(self, token_id: int) -> bytes:
+        if self.tokenizer_decode_token_cache is not None:
+            cache_result = self.tokenizer_decode_token_cache.get(token_id)
+            if cache_result is not None:
+                return cache_result
+
+        output_buffer_init_len = 256
+        if self.thread_local_obj.tokenizer_decode_token__output_buffer is None:
+            self.thread_local_obj.tokenizer_decode_token__output_buffer = ctypes.create_string_buffer(output_buffer_init_len)
+
+        buffer = self.thread_local_obj.tokenizer_decode_token__output_buffer
+        ret = fastllm_lib.token_decode(self.model, token_id, len(buffer), buffer)
+        if ret > 0:
+            # The buffer is too small; grow it (rounded up to a multiple of 16) and retry
+            new_buffer_len = round(math.ceil(ret / 16.0)) * 16
+            buffer = ctypes.create_string_buffer(new_buffer_len)
+            self.thread_local_obj.tokenizer_decode_token__output_buffer = buffer
+            ret = fastllm_lib.token_decode(self.model, token_id, len(buffer), buffer)
+            assert ret == 0
+
+        buffer_bytes = buffer.raw
+        result_len = len(buffer_bytes)
+        for i in range(len(buffer_bytes)):
+            if buffer_bytes[i] == 0:
+                result_len = i
+                break
+        return buffer_bytes[:result_len]
 
     def response_logits(self,
                         query: str,
diff --git a/tools/src/pytools.cpp b/tools/src/pytools.cpp
index 39391cce..c490e94f 100644
--- a/tools/src/pytools.cpp
+++ b/tools/src/pytools.cpp
@@ -118,29 +118,31 @@ extern "C" {
     }
 
     DLL_EXPORT int token_decode(int modelId, int tokenId, int output_buffer_len, char *output_buffer) {
-        if(tokenId == -1)
-            return -1;
+        // Returns 0 on success; if the output buffer is too small, returns the number of bytes needed, including the trailing '\0'
+        if(tokenId == -1) {
+            output_buffer[0] = '\0';
+            return 0;
+        }
         auto model = models.GetModel(modelId);
         std::string s = model->weight.tokenizer.DecodeTokens(std::vector <int> {tokenId});
         if(s.length() + 1 > output_buffer_len) {
-            return 1;
+            return (int)s.length() + 1;
         }
         memcpy(output_buffer, s.c_str(), s.length() + 1);
         return 0;
     }
 
     DLL_EXPORT int token_encode_string(int modelId, char *content, int output_buffer_len, int *output_buffer) {
-        // Returns the number of tokens written to output_buffer
+        // Returns the total token count; when the output buffer is too small, only the tokens that fit are written
         auto model = models.GetModel(modelId);
         auto v = model->weight.tokenizer.Encode(content);
-        int i = 0;
-        for (; i < v.Count(0); i++) {
+        for (int i = 0; i < v.Count(0); i++) {
             if(i >= output_buffer_len) {
                 break;
            }
            output_buffer[i] = (int)((float*)v.cpuData)[i];
        }
-        return i;
+        return (int)v.Count(0);
    }

    DLL_EXPORT void add_dict_llm_model(int modelId, char *key, char *value) {
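
For reference, a minimal usage sketch of the new tokenizer helpers (not part of the patch). It assumes the class in this diff is exposed as llm.model and that "model.flm" is a locally available model file; both are placeholders for however the caller actually loads a model.

# Usage sketch; assumptions: the class above is exposed as llm.model,
# and "model.flm" is a real local model file.
from fastllm_pytools import llm

model = llm.model("model.flm")

# Optionally pre-build the decode cache once (manually triggered, as described
# in the diff), so later per-token decodes become plain dict lookups without locking.
model.build_tokenizer_decode_token_cache()

token_ids = model.tokenizer_encode_string("hello world")            # -> List[int]
raw = b"".join(model.tokenizer_decode_token(t) for t in token_ids)  # bytes per token
print(raw.decode("utf-8", errors="ignore"))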