Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

update issues THUDM#618 使用tools时无法stream流式输出的问题 #619

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 40 additions & 5 deletions basic_demo/glm_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@

EventSourceResponse.DEFAULT_PING_INTERVAL = 1000

MAX_MODEL_LENGTH = 8192
MAX_MODEL_LENGTH = 8192

@asynccontextmanager
async def lifespan(app: FastAPI):
Expand Down Expand Up @@ -444,23 +444,35 @@ async def predict_stream(model_id, gen_params):
system_fingerprint = generate_id('fp_', 9)
tools = {tool['function']['name'] for tool in gen_params['tools']} if gen_params['tools'] else {}
delta_text = ""
delta_confirming_texts = []
confirm_tool_state = 'un_confirm' if tools else 'none'
# 带有tools时可以确认是否调用工具的最大字符长度 = 工具名最大长度 + 可能的前面有“\n”、后面“\n{”共3个字符。
max_confirm_tool_length = len(max(tools, key=len)) + 3 if tools else 0
async for new_response in generate_stream_glm4(gen_params):
decoded_unicode = new_response["text"]
delta_text += decoded_unicode[len(output):]
if confirm_tool_state == 'un_confirm':
delta_confirming_texts.append(decoded_unicode[len(output):])

output = decoded_unicode
lines = output.strip().split("\n")

# 检查是否为工具
# 这是一个简单的工具比较函数,不能保证拦截所有非工具输出的结果,比如参数未对齐等特殊情况。
##TODO 如果你希望做更多处理,可以在这里进行逻辑完善。

if not is_function_call and len(lines) >= 2:
if confirm_tool_state == 'un_confirm' and len(lines) >= 2 and lines[1].startswith("{"):
first_line = lines[0].strip()
if first_line in tools:
is_function_call = True
function_name = first_line
delta_text = lines[1]
confirm_tool_state = 'confirmed'
else:
confirm_tool_state = 'none'

# 当传入tools时,经过大模型输出几轮后,已经可以确认不需要调用工具了
if confirm_tool_state == 'un_confirm' and max_confirm_tool_length < len(delta_text):
confirm_tool_state = 'none'
# 工具调用返回
if is_function_call:
if not has_send_first_chunk:
Expand Down Expand Up @@ -524,7 +536,7 @@ async def predict_stream(model_id, gen_params):
yield chunk.model_dump_json(exclude_unset=True)

# 用户请求了 Function Call 但是框架还没确定是否为Function Call
elif (gen_params["tools"] and gen_params["tool_choice"] != "none") or is_function_call:
elif confirm_tool_state == 'un_confirm':
continue

# 常规返回
Expand Down Expand Up @@ -552,6 +564,29 @@ async def predict_stream(model_id, gen_params):
yield chunk.model_dump_json(exclude_unset=True)
has_send_first_chunk = True

for text in delta_confirming_texts:
message = DeltaMessage(
content=text,
role="assistant",
function_call=None,
)
choice_data = ChatCompletionResponseStreamChoice(
index=0,
delta=message,
finish_reason=finish_reason
)
chunk = ChatCompletionResponse(
model=model_id,
id=response_id,
choices=[choice_data],
created=created_time,
system_fingerprint=system_fingerprint,
object="chat.completion.chunk"
)
yield chunk.model_dump_json(exclude_unset=True)
delta_confirming_texts = []
delta_text = ""

message = DeltaMessage(
content=delta_text,
role="assistant",
Expand Down Expand Up @@ -613,7 +648,7 @@ async def predict_stream(model_id, gen_params):
object="chat.completion.chunk"
)
yield chunk.model_dump_json(exclude_unset=True)

finish_reason = 'stop'
message = DeltaMessage(
content=delta_text,
Expand Down