
Commit

api server: include token count information
黄宇扬 committed Jul 23, 2024
1 parent 2c1d2f6 commit 3b101e7
Showing 2 changed files with 5 additions and 6 deletions.
5 changes: 3 additions & 2 deletions tools/fastllm_pytools/llm.py
@@ -610,11 +610,12 @@ def stop_token_ctypes(self, stop_token_ids):
else:
return ctypes.c_int(len(stop_token_ids)), (ctypes.c_int * len(stop_token_ids))(*stop_token_ids)

- def get_input_token_len(self, query: str, history: List[Tuple[str, str]] = None) -> int:
- prompt = query if self.direct_query else self.get_prompt(query, history);
+ def get_input_token_len(self, conversation: List[Dict[str, str]], add_generation_prompt = True) -> int:
if (self.hf_tokenizer != None and hasattr(self.hf_tokenizer, "chat_template") and self.hf_tokenizer.chat_template != ""):
+ prompt = self.hf_tokenizer.apply_chat_template(conversation, add_generation_prompt = add_generation_prompt, tokenize = False)
return len(self.hf_tokenizer.encode(prompt))
else:
+ prompt = self.apply_chat_template(conversation)
return len(self.encode(prompt))

def response_logits(self,
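
With this change, get_input_token_len no longer takes a (query, history) pair; it accepts an OpenAI-style message list and applies the chat template before counting. A minimal usage sketch, assuming a model loaded through fastllm_pytools.llm (the model path below is illustrative):

from fastllm_pytools import llm

# Illustrative model file; use the path to a real fastllm model.
model = llm.model("qwen2-7b-int4.flm")

conversation = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "How many tokens does this prompt use?"},
]

# Counts prompt tokens after the chat template is applied; uses the
# HuggingFace tokenizer when a chat_template is available, otherwise
# falls back to the model's own encoder.
print(model.get_input_token_len(conversation, add_generation_prompt = True))
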
6 changes: 2 additions & 4 deletions tools/fastllm_pytools/openai_server/fastllm_completion.py
@@ -90,7 +90,6 @@ async def create_chat_completion(
return error_check_ret

query:str = ""
- history:List[Tuple[str, str]] = []
if request.prompt:
request.messages.append({"role": "user", "content": request.prompt})
try:
@@ -101,7 +100,6 @@

conversation.extend(messages)

- # In the fastllm examples, history can only be single question-answer turns; system prompt is not supported yet
if len(conversation) == 0:
raise Exception("Empty msg")
messages = []
@@ -120,11 +118,11 @@
frequency_penalty = request.frequency_penalty

max_length = request.max_tokens if request.max_tokens else 8192
- input_token_len = 0; # self.model.get_input_token_len(query, history)
+ input_token_len = self.model.get_input_token_len(messages)
+ #logging.info(request)
logging.info(f"fastllm input message: {messages}")
#logging.info(f"input tokens: {input_token_len}")
- # Results from stream_response do not include token count statistics

result_generator = self.model.stream_response_async(messages,
max_length = max_length, do_sample = True,
top_p = request.top_p, top_k = request.top_k, temperature = request.temperature,
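
The server now computes input_token_len from the request messages before generation starts. This diff does not show how the count reaches the client; the sketch below is a hypothetical illustration of assembling an OpenAI-style usage block from that counter (build_usage and completion_token_len are illustrative names, not part of fastllm):

def build_usage(input_token_len: int, completion_token_len: int) -> dict:
    # prompt_tokens comes from model.get_input_token_len(messages);
    # completion_tokens would be accumulated while streaming the reply.
    return {
        "prompt_tokens": input_token_len,
        "completion_tokens": completion_token_len,
        "total_tokens": input_token_len + completion_token_len,
    }

# Example: 57 prompt tokens and 128 generated tokens.
print(build_usage(57, 128))
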
