Add system_prompt support to ftllm.server and ftllm.webui
黄宇扬 committed Jul 18, 2024
1 parent d38ae99 commit 2a0e9d0
Showing 3 changed files with 121 additions and 77 deletions.
149 changes: 105 additions & 44 deletions tools/fastllm_pytools/llm.py
@@ -415,6 +415,31 @@ def __init__ (self, path : str,
        # Since the number of tokens is limited and not very large, caching this result is a good way to reduce calls.
        # We do not cache automatically, to avoid locking the cache dict in multi-threaded calls and to leave room for choice in different scenarios.
self.tokenizer_decode_token_cache = None

def apply_chat_template(
self,
conversation: List[Dict[str, str]],
chat_template: Optional[str] = None,
add_generation_prompt: bool = False,
**kwargs,
) -> str:
messages = []
for it in conversation:
if it["role"] == "system":
messages += ["system", it["content"]]
for it in conversation:
if it["role"] != "system":
messages += [it["role"], it["content"]]
poss = []
lens = []
all = b''
for i in range(len(messages)):
messages[i] = messages[i].encode()
all += messages[i]
poss.append(0 if i == 0 else poss[-1] + lens[-1])
lens.append(len(messages[i]))
str = fastllm_lib.apply_chat_template(self.model, all, len(messages), (ctypes.c_int * len(poss))(*poss), (ctypes.c_int * len(lens))(*lens)).decode()
return str
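
For context, a minimal usage sketch of the new apply_chat_template method (not part of this diff; the import path and model file below are assumptions and may differ depending on how the tools are installed):

    from fastllm_pytools import llm

    model = llm.model("path/to/model.flm")  # hypothetical model file
    conversation = [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "Hello"},
    ]
    # Returns the fully rendered prompt string produced by the native library
    prompt = model.apply_chat_template(conversation)
    print(prompt)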

def generate(
self,
@@ -611,32 +636,57 @@ def response(self,
return ret;

def stream_response(self,
query: str,
query: Union[str, List[Dict[str, str]]],
history: List[Tuple[str, str]] = None,
max_length: int = 8192, do_sample = True, top_p = 0.8, top_k = 1, temperature = 1.0, repeat_penalty = 1.0,
one_by_one = True, stop_token_ids: List[int] = None):
one_by_one = True, stop_token_ids: List[int] = None, add_generation_prompt = True):
conversation = None
if (isinstance(query, List)):
conversation = query
if (self.hf_tokenizer != None and hasattr(self.hf_tokenizer, "chat_template") and self.hf_tokenizer.chat_template != ""):
lastlen = 0
for cur in self.stream_chat(tokenizer = self.hf_tokenizer,
query = query,
history = history,
max_length = max_length,
do_sample = do_sample,
top_p = top_p, top_k = top_k,
temperature = temperature,
repeat_penalty = repeat_penalty,
stop_token_ids = stop_token_ids):
if one_by_one:
ret = cur[0][lastlen:]
if (ret.encode().find(b'\xef\xbf\xbd') == -1):
lastlen = len(cur[0])
yield ret
else:
yield ""
tokenizer = self.hf_tokenizer
type = None
if (hasattr(tokenizer, "name")
and tokenizer.name == "GLMTokenizer"
and hasattr(tokenizer, "build_chat_input")):
type = "ChatGLM3"
if (not(history)):
history = [];
if (type == "ChatGLM3"):
input = tokenizer.build_chat_input(query, history=history)["input_ids"].reshape(-1).tolist()
else:
prompt = ""
if (conversation != None and len(conversation) != 0):
prompt = tokenizer.apply_chat_template(conversation, add_generation_prompt = add_generation_prompt, tokenize = False)
else:
prompt = query if self.direct_query else self.get_prompt(query, history)
input = tokenizer.encode(prompt)
stop_token_len, stop_token_list = self.stop_token_ctypes(stop_token_ids)
handle = fastllm_lib.launch_response_llm_model(self.model, len(input), (ctypes.c_int * len(input))(*input),
max_length, do_sample, top_p, top_k, temperature, repeat_penalty,
False, stop_token_len, stop_token_list)
tokens = [];
while True:
if not(fastllm_lib.can_fetch_response_llm_model(self.model, handle)):
continue
cur = fastllm_lib.fetch_response_llm_model(self.model, handle)
if (cur == -1):
break
tokens.append(cur)
ret = tokenizer.decode(tokens)
if (ret.encode().find(b'\xef\xbf\xbd') == -1):
tokens.clear()
yield ret
else:
yield cur[0]
yield ""
if len(tokens) > 0:
yield tokenizer.decode(tokens)
else:
prompt = query if self.direct_query else self.get_prompt(query, history);
prompt = ""
if (conversation != None and len(conversation) != 0):
prompt = self.apply_chat_template(conversation)
else:
prompt = query if self.direct_query else self.get_prompt(query, history)
stop_token_len, stop_token_list = self.stop_token_ctypes(stop_token_ids);
handle = fastllm_lib.launch_response_str_llm_model(self.model, prompt.encode(),
ctypes.c_int(max_length), ctypes.c_bool(do_sample), ctypes.c_float(top_p), ctypes.c_int(top_k),
@@ -677,10 +727,13 @@ def add_cache(self,
exit(0)

async def stream_response_async(self,
query: str,
query: Union[str, List[Dict[str, str]]],
history: List[Tuple[str, str]] = None,
max_length: int = 8192, do_sample = True, top_p = 0.8, top_k = 1, temperature = 1.0, repeat_penalty = 1.0,
one_by_one = True, stop_token_ids: List[int] = None):
one_by_one = True, stop_token_ids: List[int] = None, add_generation_prompt = True):
conversation = None
if (isinstance(query, List)):
conversation = query
if (self.hf_tokenizer != None and hasattr(self.hf_tokenizer, "chat_template") and self.hf_tokenizer.chat_template != ""):
tokenizer = self.hf_tokenizer
type = None
@@ -693,12 +746,16 @@ async def stream_response_async(self,
if (type == "ChatGLM3"):
input = tokenizer.build_chat_input(query, history=history)["input_ids"].reshape(-1).tolist()
else:
prompt = query if self.direct_query else self.get_prompt(query, history);
input = tokenizer.encode(prompt);
prompt = ""
if (conversation != None and len(conversation) != 0):
prompt = tokenizer.apply_chat_template(conversation, add_generation_prompt = add_generation_prompt, tokenize = False)
else:
prompt = query if self.direct_query else self.get_prompt(query, history)
input = tokenizer.encode(prompt)
stop_token_len, stop_token_list = self.stop_token_ctypes(stop_token_ids)
handle = fastllm_lib.launch_response_llm_model(self.model, len(input), (ctypes.c_int * len(input))(*input),
max_length, do_sample, top_p, top_k, temperature, repeat_penalty,
False, stop_token_len, stop_token_list);
False, stop_token_len, stop_token_list)
tokens = [];
while True:
if not(fastllm_lib.can_fetch_response_llm_model(self.model, handle)):
@@ -717,38 +774,42 @@ async def stream_response_async(self,
if len(tokens) > 0:
yield tokenizer.decode(tokens)
else:
prompt = query if self.direct_query else self.get_prompt(query, history);
stop_token_len, stop_token_list = self.stop_token_ctypes(stop_token_ids);
prompt = ""
if (conversation != None and len(conversation) != 0):
prompt = self.apply_chat_template(conversation)
else:
prompt = query if self.direct_query else self.get_prompt(query, history)
stop_token_len, stop_token_list = self.stop_token_ctypes(stop_token_ids)
handle = fastllm_lib.launch_response_str_llm_model(self.model, prompt.encode(),
ctypes.c_int(max_length), ctypes.c_bool(do_sample), ctypes.c_float(top_p), ctypes.c_int(top_k),
ctypes.c_float(temperature), ctypes.c_float(repeat_penalty), ctypes.c_bool(False),
stop_token_len, stop_token_list);
res = "";
ret = b'';
fail_cnt = 0;
stop_token_len, stop_token_list)
res = ""
ret = b''
fail_cnt = 0
while True:
if not(fastllm_lib.can_fetch_response_llm_model(self.model, handle)):
await asyncio.sleep(0)
continue
ret += fastllm_lib.fetch_response_str_llm_model(self.model, handle);
cur = "";
ret += fastllm_lib.fetch_response_str_llm_model(self.model, handle)
cur = ""
try:
cur = ret.decode();
ret = b'';
cur = ret.decode()
ret = b''
except:
fail_cnt += 1;
fail_cnt += 1
if (fail_cnt == 20):
break;
break
else:
continue;
fail_cnt = 0;
continue
fail_cnt = 0
if (cur == "<flmeos>"):
break;
break
if one_by_one:
yield cur;
yield cur
else:
res += cur;
yield res;
res += cur
yield res

def stream_response_raw(self,
input_tokens: List[int],
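With this change, stream_response and stream_response_async accept either a plain query string or an OpenAI-style message list, so a system prompt can be passed directly. A sketch under the same assumptions as above:

    messages = [
        {"role": "system", "content": "Answer as briefly as possible."},
        {"role": "user", "content": "What does fastllm do?"},
    ]
    # Streams incremental text chunks; history is ignored when a message list is given
    for chunk in model.stream_response(messages, max_length = 512, one_by_one = True):
        print(chunk, end = "", flush = True)
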
36 changes: 5 additions & 31 deletions tools/fastllm_pytools/openai_server/fastllm_completion.py
@@ -104,34 +104,9 @@ async def create_chat_completion(
        # In the fastllm examples, history can only be alternating question/answer pairs; system prompt was not supported before
if len(conversation) == 0:
raise Exception("Empty msg")

for i in range(len(conversation)):
msg = conversation[i]
if msg.role == "system":
                # fastllm does not support system prompt yet
continue
elif msg.role == "user":
if i + 1 < len(conversation):
next_msg = conversation[i + 1]
if next_msg.role == "assistant":
history.append((msg.content, next_msg.content))
else:
                        # Roles must alternate as user, assistant, user, assistant
raise Exception("fastllm requires that the prompt words must appear alternately in the roles of user and assistant.")
elif msg.role == "assistant":
if i - 1 < 0:
raise Exception("fastllm Not Support assistant prompt in first message")
else:
pre_msg = conversation[i - 1]
if pre_msg.role != "user":
raise Exception("In FastLLM, The message role before the assistant msg must be user")
else:
raise NotImplementedError(f"prompt role {msg.role } not supported yet")

last_msg = conversation[-1]
if last_msg.role != "user":
raise Exception("last msg role must be user")
query = last_msg.content
messages = []
for msg in conversation:
messages.append({"role": msg.role, "content": msg.content})

except Exception as e:
logging.error("Error in applying chat template from request: %s", e)
@@ -147,11 +122,10 @@ async def create_chat_completion(
max_length = request.max_tokens if request.max_tokens else 8192
input_token_len = 0; # self.model.get_input_token_len(query, history)
#logging.info(request)
logging.info(f"fastllm input: {query}")
logging.info(f"fastllm history: {history}")
logging.info(f"fastllm input message: {messages}")
#logging.info(f"input tokens: {input_token_len}")
        # Results from stream_response do not include token statistics
result_generator = self.model.stream_response_async(query, history,
result_generator = self.model.stream_response_async(messages,
max_length = max_length, do_sample = True,
top_p = request.top_p, top_k = request.top_k, temperature = request.temperature,
repeat_penalty = frequency_penalty, one_by_one = True)
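On the server side, the system message now flows through to the model unchanged, so an OpenAI-style request like the following sketch should work against a running ftllm.server instance (host, port, and model name here are assumptions, not taken from this diff):

    import requests

    resp = requests.post(
        "http://127.0.0.1:8080/v1/chat/completions",  # assumed server address
        json = {
            "model": "fastllm",  # hypothetical model name
            "messages": [
                {"role": "system", "content": "You are a helpful assistant."},
                {"role": "user", "content": "Hello"},
            ],
            "stream": False,
        },
    )
    print(resp.json()["choices"][0]["message"]["content"])
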
13 changes: 11 additions & 2 deletions tools/fastllm_pytools/web_demo.py
@@ -25,11 +25,13 @@ def parse_args():
def get_model():
args = parse_args()
model = make_normal_llm_model(args)
model.set_verbose(True)
return model

if "messages" not in st.session_state:
st.session_state.messages = []

system_prompt = st.sidebar.text_input("system_prompt", "")
max_new_tokens = st.sidebar.slider("max_new_tokens", 0, 8192, 512, step = 1)
top_p = st.sidebar.slider("top_p", 0.0, 1.0, 0.8, step = 0.01)
top_k = st.sidebar.slider("top_k", 1, 50, 1, step = 1)
@@ -55,8 +57,15 @@ def get_model():
with st.chat_message("assistant"):
message_placeholder = st.empty()
full_response = ""
for chunk in model.stream_response(prompt,
st.session_state.messages,
messages = []
if system_prompt != "":
messages.append({"role": "system", "content": system_prompt})
for his in st.session_state.messages:
messages.append({"role": "user", "content": his[0]})
messages.append({"role": "assistant", "content": his[1]})
messages.append({"role": "user", "content": prompt})

for chunk in model.stream_response(messages,
max_length = max_new_tokens,
top_k = top_k,
top_p = top_p,
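The webui assembles the message list from the stored (user, assistant) pairs in the same way; as a standalone sketch of that pattern (the function name is illustrative only):

    def build_messages(system_prompt, history, prompt):
        # history: list of (user, assistant) tuples, as kept in st.session_state.messages
        messages = []
        if system_prompt != "":
            messages.append({"role": "system", "content": system_prompt})
        for user_msg, assistant_msg in history:
            messages.append({"role": "user", "content": user_msg})
            messages.append({"role": "assistant", "content": assistant_msg})
        messages.append({"role": "user", "content": prompt})
        return messages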
