From d2317838dadaa732736c49b76c47fd6db49b1b1d Mon Sep 17 00:00:00 2001 From: "raoha.rh" Date: Thu, 21 Nov 2024 16:17:19 +0800 Subject: [PATCH] feat: support config temperature in bot fix: lint --- server/agent/bot/__init__.py | 7 ++- server/agent/llm/__init__.py | 80 +++++++++++++++++++----------- server/agent/llm/base.py | 45 +++++++++-------- server/agent/llm/clients/gemini.py | 71 ++++++++++++++------------ server/agent/llm/clients/openai.py | 6 ++- server/bot/builder.py | 1 - server/core/models/bot.py | 3 ++ 7 files changed, 128 insertions(+), 85 deletions(-) diff --git a/server/agent/bot/__init__.py b/server/agent/bot/__init__.py index 2d958a05..474088ff 100644 --- a/server/agent/bot/__init__.py +++ b/server/agent/bot/__init__.py @@ -10,7 +10,12 @@ class Bot: def __init__(self, bot: BotModel, llm_token: LLMTokenLike): self._bot = bot self._llm_token = llm_token - self._llm = LLM(llm_token=llm_token) + self._llm = LLM( + llm_token=llm_token, + temperature=bot.temperature, + n=bot.n, + top_p=bot.top_p, + ) @property def id(self): diff --git a/server/agent/llm/__init__.py b/server/agent/llm/__init__.py index 1743a6d7..acc9c933 100644 --- a/server/agent/llm/__init__.py +++ b/server/agent/llm/__init__.py @@ -3,55 +3,77 @@ from typing import Dict, Optional, Type from agent.llm.base import BaseLLMClient -class LLMTokenLike(): + +class LLMTokenLike: token: str llm: str -llm_client_registry: Dict[str, Type['BaseLLMClient']] = {} + +llm_client_registry: Dict[str, Type["BaseLLMClient"]] = {} + def register_llm_client(name: str): """Decorator to register a new LLM client class.""" + def decorator(cls): if name in llm_client_registry: raise ValueError(f"Client '{name}' is already registered.") llm_client_registry[name] = cls return cls + return decorator + def get_registered_llm_client(): - return llm_client_registry + return llm_client_registry -def import_clients(directory: str = 'clients'): + +def import_clients(directory: str = "clients"): """Dynamically import all Python modules in the given directory.""" # 获取当前文件的绝对路径 current_dir = os.path.dirname(os.path.abspath(__file__)) clients_dir = os.path.join(current_dir, directory) - + for filename in os.listdir(clients_dir): - if filename.endswith('.py') and not filename.startswith('__'): + if filename.endswith(".py") and not filename.startswith("__"): module_name = f"agent.llm.{directory}.{filename[:-3]}" # 去掉 .py 后缀 importlib.import_module(module_name) -class LLM(): - llm_token: LLMTokenLike - client: Optional[BaseLLMClient] - - def __init__(self, llm_token: LLMTokenLike): - self._llm_token = llm_token - self._client = self.get_llm_client(llm_token.llm, api_key=llm_token.token) - - def get_llm_client( - self, - llm: str = 'openai', - api_key: Optional[str | None] = None, - temperature: Optional[int] = 0.2, - max_tokens: Optional[int] = 1500, - streaming: Optional[bool] = False - ) -> BaseLLMClient: - - """Get an LLM client based on the specified name.""" - if llm in llm_client_registry: - client_class = llm_client_registry[llm] - return client_class(temperature=temperature, api_key=api_key, streaming=streaming, max_tokens=max_tokens) - - return None \ No newline at end of file + +class LLM: + llm_token: LLMTokenLike + client: Optional[BaseLLMClient] + + def __init__( + self, + llm_token: LLMTokenLike, + temperature: Optional[float] = 0.2, + n: Optional[int] = 1, + top_p: Optional[float] = None + ): + self._llm_token = llm_token + self._client = self.get_llm_client(llm_token.llm, api_key=llm_token.token, temperature=temperature, n=n, top_p=top_p) + + def get_llm_client( + self, + llm: str = "openai", + api_key: Optional[str | None] = None, + temperature: Optional[float] = 0.2, + n: Optional[int] = 1, + top_p: Optional[float] = None, + max_tokens: Optional[int] = 1500, + streaming: Optional[bool] = False, + ) -> BaseLLMClient: + """Get an LLM client based on the specified name.""" + if llm in llm_client_registry: + client_class = llm_client_registry[llm] + return client_class( + temperature=temperature, + n=n, + top_p=top_p, + api_key=api_key, + streaming=streaming, + max_tokens=max_tokens, + ) + + return None diff --git a/server/agent/llm/base.py b/server/agent/llm/base.py index 0e6da881..b4e3d475 100644 --- a/server/agent/llm/base.py +++ b/server/agent/llm/base.py @@ -1,29 +1,30 @@ - from abc import abstractmethod from typing import Any, Dict, List, Optional from langchain_core.language_models import BaseChatModel from petercat_utils.data_class import MessageContent -class BaseLLMClient(): - def __init__(self, - temperature: Optional[int] = 0.2, - max_tokens: Optional[int] = 1500, - streaming: Optional[bool] = False, - api_key: Optional[str] = '' - ): - pass - - @abstractmethod - def get_client() -> BaseChatModel: - pass - - @abstractmethod - def get_tools(self, tool: List[Any]) -> list[Dict[str, Any]]: - pass - - @abstractmethod - def parse_content(self, content: List[MessageContent]) -> List[MessageContent]: - pass - +class BaseLLMClient: + def __init__( + self, + temperature: Optional[float] = 0.2, + n: Optional[int] = 1, + top_p: Optional[float] = None, + max_tokens: Optional[int] = 1500, + streaming: Optional[bool] = False, + api_key: Optional[str] = "", + ): + pass + + @abstractmethod + def get_client() -> BaseChatModel: + pass + + @abstractmethod + def get_tools(self, tool: List[Any]) -> list[Dict[str, Any]]: + pass + + @abstractmethod + def parse_content(self, content: List[MessageContent]) -> List[MessageContent]: + pass diff --git a/server/agent/llm/clients/gemini.py b/server/agent/llm/clients/gemini.py index 6635c807..cbcac752 100644 --- a/server/agent/llm/clients/gemini.py +++ b/server/agent/llm/clients/gemini.py @@ -10,38 +10,47 @@ GEMINI_API_KEY = get_env_variable("GEMINI_API_KEY") + def parse_gemini_input(message: MessageContent): - match message.type: - case "image_url": - return ImageRawURLContentBlock(image_url=message.image_url.url, type="image_url") - case _: - return message + match message.type: + case "image_url": + return ImageRawURLContentBlock( + image_url=message.image_url.url, type="image_url" + ) + case _: + return message + @register_llm_client("gemini") class GeminiClient(BaseLLMClient): - _client: ChatOpenAI - - def __init__(self, - temperature: Optional[int] = 0.2, - max_tokens: Optional[int] = 1500, - streaming: Optional[bool] = False, - api_key: Optional[str] = GEMINI_API_KEY, - ): - self._client = ChatGoogleGenerativeAI( - model="gemini-1.5-flash", - temperature=temperature, - streaming=streaming, - max_tokens=max_tokens, - google_api_key=api_key, - ) - - def get_client(self): - return self._client - - def get_tools(self, tools: List[Any]): - return [convert_to_genai_function_declarations(tool) for tool in tools] - - def parse_content(self, content: List[MessageContent]): - result = [parse_gemini_input(message=message) for message in content] - print(f"parse_content, content={content}, result={result}") - return result + _client: ChatOpenAI + + def __init__( + self, + temperature: Optional[float] = 0.2, + n: Optional[int] = 1, + top_p: Optional[float] = None, + max_tokens: Optional[int] = 1500, + streaming: Optional[bool] = False, + api_key: Optional[str] = GEMINI_API_KEY, + ): + self._client = ChatGoogleGenerativeAI( + model="gemini-1.5-flash", + temperature=temperature, + top_p=top_p, + n=n, + streaming=streaming, + max_tokens=max_tokens, + google_api_key=api_key, + ) + + def get_client(self): + return self._client + + def get_tools(self, tools: List[Any]): + return [convert_to_genai_function_declarations(tool) for tool in tools] + + def parse_content(self, content: List[MessageContent]): + result = [parse_gemini_input(message=message) for message in content] + print(f"parse_content, content={content}, result={result}") + return result diff --git a/server/agent/llm/clients/openai.py b/server/agent/llm/clients/openai.py index 08d4ca09..e1ea597a 100644 --- a/server/agent/llm/clients/openai.py +++ b/server/agent/llm/clients/openai.py @@ -17,7 +17,9 @@ class OpenAIClient(BaseLLMClient): def __init__( self, - temperature: Optional[int] = 0.2, + temperature: Optional[float] = 0.2, + n: Optional[int] = 1, + top_p: Optional[float] = None, max_tokens: Optional[int] = 1500, streaming: Optional[bool] = False, api_key: Optional[str] = OPEN_API_KEY, @@ -25,6 +27,8 @@ def __init__( self._client = ChatOpenAI( model_name="gpt-4o", temperature=temperature, + n=n, + top_p=top_p, streaming=streaming, max_tokens=max_tokens, openai_api_key=api_key, diff --git a/server/bot/builder.py b/server/bot/builder.py index 4ca13b54..6b96983a 100644 --- a/server/bot/builder.py +++ b/server/bot/builder.py @@ -1,4 +1,3 @@ -import json from typing import List, Optional from github import Github diff --git a/server/core/models/bot.py b/server/core/models/bot.py index 297edce6..040aff57 100644 --- a/server/core/models/bot.py +++ b/server/core/models/bot.py @@ -16,6 +16,9 @@ class BotModel(BaseModel): token_id: Optional[str] = "" created_at: datetime = datetime.now() domain_whitelist: Optional[list[str]] = [] + temperature: Optional[float] = 0.2 + n: Optional[int] = 1 + top_p: Optional[float] class RepoBindBotConfigVO(BaseModel):