From 4c808a82a3052fcf0bcb015f676a97cef2ce3b9c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=BC=A8=E7=BC=A8?= Date: Wed, 11 Sep 2024 14:30:32 +0800 Subject: [PATCH] feat: track token usage when stream chat (#372) * feat: track token usage when stream chat * chore: update test gemini bot token --- assistant/src/Assistant/index.md | 2 +- server/agent/base.py | 14 ++++++++ server/agent/llm/clients/openai.py | 52 ++++++++++++++++-------------- 3 files changed, 42 insertions(+), 26 deletions(-) diff --git a/assistant/src/Assistant/index.md b/assistant/src/Assistant/index.md index 075a133c..7baffbbb 100644 --- a/assistant/src/Assistant/index.md +++ b/assistant/src/Assistant/index.md @@ -22,7 +22,7 @@ import { Assistant } from '@petercatai/assistant'; export default () => ( ); diff --git a/server/agent/base.py b/server/agent/base.py index 8698259a..08aa18ae 100644 --- a/server/agent/base.py +++ b/server/agent/base.py @@ -143,6 +143,20 @@ async def run_stream_chat(self, input_data: ChatData) -> AsyncIterator[str]: ensure_ascii=False, ) yield f"data: {json_output}\n\n" + elif kind == "on_chat_model_end": + content = event["data"]["output"]["generations"][0][0][ + "message" + ].usage_metadata + if content: + json_output = json.dumps( + { + "id": event["run_id"], + "type": "usage", + **content, + }, + ensure_ascii=False, + ) + yield f"data: {json_output}\n\n" elif kind == "on_tool_start": children_value = event["data"].get("input", {}) json_output = json.dumps( diff --git a/server/agent/llm/clients/openai.py b/server/agent/llm/clients/openai.py index 49397920..08d4ca09 100644 --- a/server/agent/llm/clients/openai.py +++ b/server/agent/llm/clients/openai.py @@ -10,30 +10,32 @@ OPEN_API_KEY = get_env_variable("OPENAI_API_KEY") + @register_llm_client("openai") class OpenAIClient(BaseLLMClient): - _client: ChatOpenAI - - def __init__(self, - temperature: Optional[int] = 0.2, - max_tokens: Optional[int] = 1500, - streaming: Optional[bool] = False, - api_key: Optional[str] = OPEN_API_KEY - ): - self._client = ChatOpenAI( - model_name="gpt-4o", - temperature=temperature, - streaming=streaming, - max_tokens=max_tokens, - openai_api_key=api_key, - ) - - def get_client(self): - return self._client - - def get_tools(self, tools: List[Any]): - return [convert_to_openai_tool(tool) for tool in tools] - - def parse_content(self, content: List[MessageContent]): - print(f"parse_content: {content}") - return content \ No newline at end of file + _client: ChatOpenAI + + def __init__( + self, + temperature: Optional[int] = 0.2, + max_tokens: Optional[int] = 1500, + streaming: Optional[bool] = False, + api_key: Optional[str] = OPEN_API_KEY, + ): + self._client = ChatOpenAI( + model_name="gpt-4o", + temperature=temperature, + streaming=streaming, + max_tokens=max_tokens, + openai_api_key=api_key, + stream_usage=True, + ) + + def get_client(self): + return self._client + + def get_tools(self, tools: List[Any]): + return [convert_to_openai_tool(tool) for tool in tools] + + def parse_content(self, content: List[MessageContent]): + return content