From 9a84c470b9246fa5048b4823d7116c9053b30255 Mon Sep 17 00:00:00 2001 From: chenglj <453308407@qq.com> Date: Wed, 30 Oct 2024 16:22:50 +0800 Subject: [PATCH] =?UTF-8?q?update=20issues=20THUDM#618=20=E4=BD=BF?= =?UTF-8?q?=E7=94=A8tools=E6=97=B6=E6=97=A0=E6=B3=95stream=E6=B5=81?= =?UTF-8?q?=E5=BC=8F=E8=BE=93=E5=87=BA=E7=9A=84=E9=97=AE=E9=A2=98?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- basic_demo/glm_server.py | 45 +++++++++++++++++++++++++++++++++++----- 1 file changed, 40 insertions(+), 5 deletions(-) diff --git a/basic_demo/glm_server.py b/basic_demo/glm_server.py index 2ae8b22..fec07c3 100644 --- a/basic_demo/glm_server.py +++ b/basic_demo/glm_server.py @@ -19,7 +19,7 @@ EventSourceResponse.DEFAULT_PING_INTERVAL = 1000 -MAX_MODEL_LENGTH = 8192 +MAX_MODEL_LENGTH = 8192 @asynccontextmanager async def lifespan(app: FastAPI): @@ -444,23 +444,35 @@ async def predict_stream(model_id, gen_params): system_fingerprint = generate_id('fp_', 9) tools = {tool['function']['name'] for tool in gen_params['tools']} if gen_params['tools'] else {} delta_text = "" + delta_confirming_texts = [] + confirm_tool_state = 'un_confirm' if tools else 'none' + # With tools present, the longest prefix needed to decide whether a tool is being called = longest tool name + 3 characters (a possible leading "\n" plus the trailing "\n{"). + max_confirm_tool_length = len(max(tools, key=len)) + 3 if tools else 0 async for new_response in generate_stream_glm4(gen_params): decoded_unicode = new_response["text"] delta_text += decoded_unicode[len(output):] + if confirm_tool_state == 'un_confirm': + delta_confirming_texts.append(decoded_unicode[len(output):]) + output = decoded_unicode lines = output.strip().split("\n") # 检查是否为工具 # 这是一个简单的工具比较函数,不能保证拦截所有非工具输出的结果,比如参数未对齐等特殊情况。 ##TODO 如果你希望做更多处理,可以在这里进行逻辑完善。 - - if not is_function_call and len(lines) >= 2: + if confirm_tool_state == 'un_confirm' and len(lines) >= 2 and lines[1].startswith("{"): first_line = lines[0].strip() if first_line in tools: is_function_call = True function_name = first_line delta_text = 
lines[1] + confirm_tool_state = 'confirmed' + else: + confirm_tool_state = 'none' + # Tools were supplied, but enough text has been generated to be sure no tool call is coming. + if confirm_tool_state == 'un_confirm' and max_confirm_tool_length < len(delta_text): + confirm_tool_state = 'none' # 工具调用返回 if is_function_call: if not has_send_first_chunk: @@ -524,7 +536,7 @@ async def predict_stream(model_id, gen_params): yield chunk.model_dump_json(exclude_unset=True) # 用户请求了 Function Call 但是框架还没确定是否为Function Call - elif (gen_params["tools"] and gen_params["tool_choice"] != "none") or is_function_call: + elif confirm_tool_state == 'un_confirm': continue # 常规返回 @@ -552,6 +564,29 @@ async def predict_stream(model_id, gen_params): yield chunk.model_dump_json(exclude_unset=True) has_send_first_chunk = True + for text in delta_confirming_texts: + message = DeltaMessage( + content=text, + role="assistant", + function_call=None, + ) + choice_data = ChatCompletionResponseStreamChoice( + index=0, + delta=message, + finish_reason=finish_reason + ) + chunk = ChatCompletionResponse( + model=model_id, + id=response_id, + choices=[choice_data], + created=created_time, + system_fingerprint=system_fingerprint, + object="chat.completion.chunk" + ) + yield chunk.model_dump_json(exclude_unset=True) + delta_confirming_texts = [] + delta_text = "" + message = DeltaMessage( content=delta_text, role="assistant", @@ -613,7 +648,7 @@ async def predict_stream(model_id, gen_params): object="chat.completion.chunk" ) yield chunk.model_dump_json(exclude_unset=True) - + finish_reason = 'stop' message = DeltaMessage( content=delta_text,