Merge pull request #328 from lockmatrix/pr-pytools0924
pytools: add tokenizer interface and stream_response_raw
ztxz16 authored Sep 27, 2023
2 parents 435583e + e8b1313 commit e9b2f90
Showing 2 changed files with 135 additions and 0 deletions.
107 changes: 107 additions & 0 deletions tools/fastllm_pytools/llm.py
@@ -1,5 +1,7 @@
import ctypes;
import math
import os;
import threading
from typing import Optional, Tuple, Union, List, Callable, Dict, Any;

import platform
@@ -11,6 +13,12 @@
fastllm_lib.create_llm_model.argtypes = [ctypes.c_char_p]
fastllm_lib.create_llm_model.restype = ctypes.c_int

fastllm_lib.token_decode.argtypes = [ctypes.c_int, ctypes.c_int, ctypes.c_int, ctypes.c_char_p]
fastllm_lib.token_decode.restype = ctypes.c_int

fastllm_lib.token_encode_string.argtypes = [ctypes.c_int, ctypes.c_char_p, ctypes.c_int, ctypes.POINTER(ctypes.c_int)]
fastllm_lib.token_encode_string.restype = ctypes.c_int

fastllm_lib.launch_response_llm_model.argtypes = [ctypes.c_int, ctypes.c_int, ctypes.c_void_p,
ctypes.c_int, ctypes.c_bool, ctypes.c_float, ctypes.c_int,
ctypes.c_float, ctypes.c_float, ctypes.c_bool]
@@ -102,6 +110,16 @@ def __init__ (self, path : str,
        self.model = fastllm_lib.create_llm_model(path.encode());
        self.direct_query = False;

        # Thread-local object pool, used to avoid repeatedly allocating and freeing buffer objects
        self.thread_local_obj = threading.local()
        self.thread_local_obj.tokenizer_encode_string__output_buffer = None
        self.thread_local_obj.tokenizer_decode_token__output_buffer = None

        # Static cache for tokenizer_decode_token results, built by an explicit call.
        # Since the number of tokens is limited and not too large, caching the results to cut down on calls is a good fit.
        # The cache is not populated automatically, to avoid locking the cache dict under multi-threaded calls and to leave the choice to the caller for different scenarios.
        self.tokenizer_decode_token_cache = None

    def get_prompt(self,
                   query: str,
                   history: List[Tuple[str, str]] = None) -> str:
Expand All @@ -119,6 +137,68 @@ def save(self, path : str):
    def eval(self):
        pass;

    def build_tokenizer_decode_token_cache(self):
        if self.tokenizer_decode_token_cache is not None:
            return

        cache_dict = dict()
        vocab_size = fastllm_lib.get_tokenizer_vocab_size(self.model)
        for token_id in range(vocab_size):
            cache_dict[token_id] = self.tokenizer_decode_token(token_id)

        self.tokenizer_decode_token_cache = cache_dict

    def tokenizer_encode_string(self, content: str) -> List[int]:
        output_buffer_init_len = 1024
        if self.thread_local_obj.tokenizer_encode_string__output_buffer is None:
            self.thread_local_obj.tokenizer_encode_string__output_buffer = (ctypes.c_int * output_buffer_init_len)()

        buffer = self.thread_local_obj.tokenizer_encode_string__output_buffer
        buffer_len = len(buffer)
        result_len = fastllm_lib.token_encode_string(self.model, content.encode(), buffer_len, buffer)
        if result_len > buffer_len:
            if result_len > 10240:
                # The input is too long; use a one-off buffer instead of growing the pooled one
                temp_buffer = (ctypes.c_int * result_len)()
                ret = fastllm_lib.token_encode_string(self.model, content.encode(), result_len, temp_buffer)
                return [i for i in temp_buffer]
            else:
                # Grow the pooled buffer
                new_buffer_len = round(math.ceil(result_len / 1024.0)) * 1024
                buffer = (ctypes.c_int * new_buffer_len)()
                self.thread_local_obj.tokenizer_encode_string__output_buffer = buffer
                result_len = fastllm_lib.token_encode_string(self.model, content.encode(), new_buffer_len, buffer)

        return [buffer[i] for i in range(result_len)]

    def tokenizer_decode_token(self, token_id: int) -> bytes:
        if self.tokenizer_decode_token_cache is not None:
            cache_result = self.tokenizer_decode_token_cache.get(token_id)
            if cache_result is not None:
                return cache_result

        output_buffer_init_len = 256
        if self.thread_local_obj.tokenizer_decode_token__output_buffer is None:
            self.thread_local_obj.tokenizer_decode_token__output_buffer = ctypes.create_string_buffer(output_buffer_init_len)

        buffer = self.thread_local_obj.tokenizer_decode_token__output_buffer
        ret = fastllm_lib.token_decode(self.model, token_id, len(buffer), buffer)
        if ret > 0:
            # The buffer is too small; grow it and retry
            new_buffer_len = round(math.ceil(ret / 16.0)) * 16
            buffer = ctypes.create_string_buffer(new_buffer_len)
            self.thread_local_obj.tokenizer_decode_token__output_buffer = buffer
            ret = fastllm_lib.token_decode(self.model, token_id, len(buffer), buffer)
            assert ret == 0

        buffer_bytes = buffer.raw
        result_len = len(buffer_bytes)
        for i in range(len(buffer_bytes)):
            if buffer_bytes[i] == 0:
                result_len = i
                break
        return buffer_bytes[:result_len]
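
For illustration only (not part of this commit), a minimal sketch of how the new tokenizer helpers could be called from user code; the path "model.flm" is a hypothetical converted model, and the wrapper is assumed to be the llm.model class defined in this file:

    from fastllm_pytools import llm

    m = llm.model("model.flm")  # hypothetical path to a converted fastllm model

    # Encode a prompt into token ids, then decode each id back to raw bytes.
    token_ids = m.tokenizer_encode_string("hello world")
    pieces = [m.tokenizer_decode_token(tid) for tid in token_ids]
    print(token_ids, b"".join(pieces).decode("utf-8", errors = "ignore"))

    # Optionally pre-build the decode cache so later decodes skip the C call.
    m.build_tokenizer_decode_token_cache()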

    def response_logits(self,
                        query: str,
                        history: List[Tuple[str, str]] = None,
@@ -190,6 +270,33 @@ def stream_response(self,
                res += cur;
                yield res;

    def stream_response_raw(self,
                            input_tokens: List[int],
                            max_length: int = 8192, do_sample = True, top_p = 0.8, top_k = 1, temperature = 1.0, repeat_penalty = 1.0,
                            one_by_one = True
                            ):
        handle = fastllm_lib.launch_response_llm_model(self.model, len(input_tokens),
                                                       (ctypes.c_int * len(input_tokens))(*input_tokens),
                                                       ctypes.c_int(max_length), ctypes.c_bool(do_sample), ctypes.c_float(top_p), ctypes.c_int(top_k),
                                                       ctypes.c_float(temperature), ctypes.c_float(repeat_penalty), ctypes.c_bool(False))

        # A rare character may need multiple tokens to be generated, so only bytes are returned and the string-decode strategy is left to the caller.
        # This makes it easy to count output tokens and to control how decoding is handled while the UTF-8 sequence is still incomplete.

        total_bytes = b''
        while True:
            cur_token = fastllm_lib.fetch_response_llm_model(self.model, handle)
            if cur_token == -1:
                break

            cur_bytes = self.tokenizer_decode_token(cur_token)

            if one_by_one:
                yield cur_bytes
            else:
                total_bytes += cur_bytes
                yield total_bytes
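
As an illustrative sketch (not part of this commit) of the byte-level contract described above, a caller can stitch the raw chunks back into text with an incremental UTF-8 decoder; the model object m is the hypothetical instance from the earlier sketch:

    import codecs

    input_ids = m.tokenizer_encode_string(m.get_prompt("Hello"))
    decoder = codecs.getincrementaldecoder("utf-8")(errors = "replace")
    for chunk in m.stream_response_raw(input_ids, max_length = 512, one_by_one = True):
        # chunk holds the raw bytes of one generated token; the decoder buffers incomplete sequences.
        piece = decoder.decode(chunk)
        if piece:
            print(piece, end = "", flush = True)
    print(decoder.decode(b"", final = True))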

    def chat(self, tokenizer, query: str, history: List[Tuple[str, str]] = None, max_length: int = 8192,
             do_sample = True, top_p = 0.8, top_k = 1, temperature = 1.0, repeat_penalty = 1.0, **kwargs):
        if (not(history)):
28 changes: 28 additions & 0 deletions tools/src/pytools.cpp
@@ -117,6 +117,34 @@ extern "C" {
        return;
    }

    DLL_EXPORT int token_decode(int modelId, int tokenId, int output_buffer_len, char *output_buffer) {
        // Returns 0 on success; when the output buffer is too small, returns the number of bytes required, including the trailing '\0'
        if(tokenId == -1) {
            output_buffer[0] = '\0';
            return 0;
        }
        auto model = models.GetModel(modelId);
        std::string s = model->weight.tokenizer.DecodeTokens(std::vector <int> {tokenId});
        if(s.length() + 1 > output_buffer_len) {
            return (int)s.length() + 1;
        }
        memcpy(output_buffer, s.c_str(), s.length() + 1);
        return 0;
    }

    DLL_EXPORT int token_encode_string(int modelId, char *content, int output_buffer_len, int *output_buffer) {
        // Returns the total number of encoded tokens; when the output buffer is too small, only the part that fits is written
        auto model = models.GetModel(modelId);
        auto v = model->weight.tokenizer.Encode(content);
        for (int i = 0; i < v.Count(0); i++) {
            if(i >= output_buffer_len) {
                break;
            }
            output_buffer[i] = (int)((float*)v.cpuData)[i];
        }
        return (int)v.Count(0);
    }

    DLL_EXPORT void add_dict_llm_model(int modelId, char *key, char *value) {
        auto model = models.GetModel(modelId);
        model->weight.AddDict(key, value);
