optimize pytools tokenizer_*
lockmatrix committed Sep 24, 2023
1 parent 800885e commit e8b1313
Showing 2 changed files with 81 additions and 20 deletions.
85 changes: 72 additions & 13 deletions tools/fastllm_pytools/llm.py
@@ -1,5 +1,7 @@
import ctypes;
import math
import os;
import threading
from typing import Optional, Tuple, Union, List, Callable, Dict, Any;

import platform
@@ -108,6 +110,16 @@ def __init__ (self, path : str,
self.model = fastllm_lib.create_llm_model(path.encode());
self.direct_query = False;

# Thread-local pool of buffer objects, used to avoid repeatedly allocating and freeing buffers
self.thread_local_obj = threading.local()
self.thread_local_obj.tokenizer_encode_string__output_buffer = None
self.thread_local_obj.tokenizer_decode_token__output_buffer = None

# Static cache for tokenizer_decode_token results, built by an explicit manual trigger.
# Since the number of tokens is finite and not too large, caching these results to reduce calls is a good fit.
# The cache is not built automatically, to avoid locking the cache dict under multi-threaded calls and to leave the choice open for different scenarios.
self.tokenizer_decode_token_cache = None

def get_prompt(self,
query: str,
history: List[Tuple[str, str]] = None) -> str:
@@ -125,20 +137,67 @@ def save(self, path : str):
def eval(self):
pass;

def tokenizer_encode_string(self, content):
buffer_len = len(content.encode())
buffer = (ctypes.c_int * buffer_len)()
def build_tokenizer_decode_token_cache(self):
if self.tokenizer_decode_token_cache is not None:
return

cache_dict = dict()
vocab_size = fastllm_lib.get_tokenizer_vocab_size(self.model)
for token_id in range(vocab_size):
cache_dict[token_id] = self.tokenizer_decode_token(token_id)

self.tokenizer_decode_token_cache = cache_dict

def tokenizer_encode_string(self, content: str) -> List[int]:
output_buffer_init_len = 1024
if self.thread_local_obj.tokenizer_encode_string__output_buffer is None:
self.thread_local_obj.tokenizer_encode_string__output_buffer = (ctypes.c_int * output_buffer_init_len)()

buffer = self.thread_local_obj.tokenizer_encode_string__output_buffer
buffer_len = len(buffer)
result_len = fastllm_lib.token_encode_string(self.model, content.encode(), buffer_len, buffer)
assert 0 < result_len < buffer_len
return [i for i in buffer][:result_len]

def tokenizer_decode_token(self, token_id):
assert isinstance(token_id, int)
buffer_len = 256
buffer = ctypes.create_string_buffer(buffer_len)
ret = fastllm_lib.token_decode(self.model, token_id, buffer_len, buffer)
assert ret == 0
return buffer.raw
if result_len > buffer_len:
if result_len > 10240:
# The input is very long; use a one-off buffer instead of growing the pooled one
temp_buffer = (ctypes.c_int * result_len)()
ret = fastllm_lib.token_encode_string(self.model, content.encode(), result_len, temp_buffer)
return [i for i in temp_buffer]
else:
# Grow the pooled buffer
new_buffer_len = round(math.ceil(result_len / 1024.0)) * 1024
buffer = (ctypes.c_int * new_buffer_len)()
self.thread_local_obj.tokenizer_encode_string__output_buffer = buffer
result_len = fastllm_lib.token_encode_string(self.model, content.encode(), new_buffer_len, buffer)

return [buffer[i] for i in range(result_len)]

def tokenizer_decode_token(self, token_id: int) -> bytes:
if self.tokenizer_decode_token_cache is not None:
cache_result = self.tokenizer_decode_token_cache.get(token_id)
if cache_result is not None:
return cache_result

output_buffer_init_len = 256
if self.thread_local_obj.tokenizer_decode_token__output_buffer is None:
self.thread_local_obj.tokenizer_decode_token__output_buffer = ctypes.create_string_buffer(output_buffer_init_len)

buffer = self.thread_local_obj.tokenizer_decode_token__output_buffer
ret = fastllm_lib.token_decode(self.model, token_id, len(buffer), buffer)
if ret > 0:
# Buffer too small; grow it to the required size
new_buffer_len = round(math.ceil(ret / 16.0)) * 16
buffer = ctypes.create_string_buffer(new_buffer_len)
self.thread_local_obj.tokenizer_decode_token__output_buffer = buffer
ret = fastllm_lib.token_decode(self.model, token_id, len(buffer), buffer)
assert ret == 0

buffer_bytes = buffer.raw
result_len = len(buffer_bytes)
for i in range(len(buffer_bytes)):
if buffer_bytes[i] == 0:
result_len = i
break
return buffer_bytes[:result_len]

def response_logits(self,
query: str,
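
For context, a minimal usage sketch of the optimized tokenizer helpers (not part of this commit; the import path, the llm.model constructor, and the model path are assumptions based on how llm.py is typically used):

from fastllm_pytools import llm

m = llm.model("model.flm")  # placeholder path to a converted fastllm model

# Optionally pre-build the token-id -> bytes cache once (e.g. at startup),
# so later decode calls for already-seen token ids avoid crossing the C API.
m.build_tokenizer_decode_token_cache()

ids = m.tokenizer_encode_string("hello world")          # -> List[int]
raw = b"".join(m.tokenizer_decode_token(i) for i in ids)
text = raw.decode("utf-8", errors="replace")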
16 changes: 9 additions & 7 deletions tools/src/pytools.cpp
@@ -118,29 +118,31 @@ extern "C" {
}

DLL_EXPORT int token_decode(int modelId, int tokenId, int output_buffer_len, char *output_buffer) {
if(tokenId == -1)
return -1;
// Returns 0 on success; if the output buffer is too small, returns the number of bytes required for the output, including the trailing \0
if(tokenId == -1) {
output_buffer[0] = '\0';
return 0;
}
auto model = models.GetModel(modelId);
std::string s = model->weight.tokenizer.DecodeTokens(std::vector <int> {tokenId});
if(s.length() + 1 > output_buffer_len) {
return 1;
return (int)s.length() + 1;
}
memcpy(output_buffer, s.c_str(), s.length() + 1);
return 0;
}

DLL_EXPORT int token_encode_string(int modelId, char *content, int output_buffer_len, int *output_buffer) {
// Returns the number of tokens written into output_buffer
// Returns the total number of encoded tokens; if output_buffer is too small, only the part that fits is written
auto model = models.GetModel(modelId);
auto v = model->weight.tokenizer.Encode(content);
int i = 0;
for (; i < v.Count(0); i++) {
for (int i = 0; i < v.Count(0); i++) {
if(i >= output_buffer_len) {
break;
}
output_buffer[i] = (int)((float*)v.cpuData)[i];
}
return i;
return (int)v.Count(0);
}

DLL_EXPORT void add_dict_llm_model(int modelId, char *key, char *value) {
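
As an illustration of the updated token_decode contract (0 on success, otherwise the required byte count including the trailing \0), a minimal grow-and-retry sketch from the Python side; this mirrors what llm.py above does and assumes fastllm_lib and the model id are already set up:

import ctypes

def decode_token(fastllm_lib, model_id: int, token_id: int, initial_len: int = 16) -> bytes:
    # First attempt with a small buffer.
    buf = ctypes.create_string_buffer(initial_len)
    ret = fastllm_lib.token_decode(model_id, token_id, len(buf), buf)
    if ret > 0:
        # ret is the number of bytes needed, including the trailing '\0',
        # so a buffer of exactly that size fits the whole result.
        buf = ctypes.create_string_buffer(ret)
        ret = fastllm_lib.token_decode(model_id, token_id, len(buf), buf)
    assert ret == 0
    return buf.value  # bytes up to (not including) the first '\0'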
