From 81c3f3e43a31ec5de115714dfab8be06d6a0e325 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=BB=84=E5=AE=87=E6=89=AC?= Date: Fri, 17 May 2024 11:34:43 +0800 Subject: [PATCH] =?UTF-8?q?openai=20server=E4=B8=AD=E5=8A=A0=E5=A6=82await?= =?UTF-8?q?=20sleep(0)=EF=BC=8C=E4=BF=AE=E5=A4=8D=E6=B5=81=E5=BC=8F?= =?UTF-8?q?=E8=BE=93=E5=87=BA?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- example/openai_server/fastllm_completion.py | 8 ++++++++ example/openai_server/protocal/openai_protocol.py | 5 +++-- 2 files changed, 11 insertions(+), 2 deletions(-) diff --git a/example/openai_server/fastllm_completion.py b/example/openai_server/fastllm_completion.py index edc4c9ad..1922ef01 100644 --- a/example/openai_server/fastllm_completion.py +++ b/example/openai_server/fastllm_completion.py @@ -1,3 +1,4 @@ +import asyncio import logging import json import traceback @@ -88,6 +89,8 @@ async def create_chat_completion( query:str = "" history:List[Tuple[str, str]] = [] + if request.prompt: + request.messages.append({"role": "user", "content": request.prompt}) try: conversation: List[ConversationMessage] = [] for m in request.messages: @@ -214,6 +217,7 @@ async def chat_completion_stream_generator( model=model_name) data = chunk.model_dump_json(exclude_unset=True) yield f"data: {data}\n\n" + await asyncio.sleep(0) first_iteration = False # 2. content部分 @@ -233,6 +237,7 @@ async def chat_completion_stream_generator( model=model_name) data = chunk.model_dump_json(exclude_unset=True) yield f"data: {data}\n\n" + await asyncio.sleep(0) # 3. 结束标志 choice_data = ChatCompletionResponseStreamChoice( @@ -249,10 +254,13 @@ async def chat_completion_stream_generator( data = chunk.model_dump_json(exclude_unset=True, exclude_none=True) yield f"data: {data}\n\n" + await asyncio.sleep(0) except ValueError as e: data = self.create_streaming_error_response(str(e)) yield f"data: {data}\n\n" + await asyncio.sleep(0) yield "data: [DONE]\n\n" + await asyncio.sleep(0) def create_streaming_error_response( self, diff --git a/example/openai_server/protocal/openai_protocol.py b/example/openai_server/protocal/openai_protocol.py index 5e7a94ab..5631a258 100644 --- a/example/openai_server/protocal/openai_protocol.py +++ b/example/openai_server/protocal/openai_protocol.py @@ -59,11 +59,12 @@ class LogProbs(BaseModel): class ChatCompletionRequest(BaseModel): model: str - messages: Union[ + messages: Optional[Union[ str, List[Dict[str, str]], List[Dict[str, Union[str, List[Dict[str, Union[str, Dict[str, str]]]]]]], - ] + ]] = [] + prompt: Optional[str] = "" temperature: Optional[float] = 0.7 top_p: Optional[float] = 1.0 top_k: Optional[int] = -1