From 996bc942742fdfebd6253eac0dd8f05e22abbeb4 Mon Sep 17 00:00:00 2001
From: qbc
Date: Mon, 4 Mar 2024 20:08:59 +0800
Subject: [PATCH] Add tongyi model wrapper (#46)

---
 docs/sphinx_doc/source/tutorial/203-model.md |   4 +-
 examples/conversation/conversation.py        |   2 +-
 setup.py                                     |   1 +
 src/agentscope/models/__init__.py            |   6 +
 src/agentscope/models/openai_model.py        |   1 +
 src/agentscope/models/post_model.py          |   1 -
 src/agentscope/models/tongyi_model.py        | 217 +++++++++++++++++++
 tests/model_test.py                          |   2 +-
 8 files changed, 229 insertions(+), 5 deletions(-)
 create mode 100644 src/agentscope/models/tongyi_model.py

diff --git a/docs/sphinx_doc/source/tutorial/203-model.md b/docs/sphinx_doc/source/tutorial/203-model.md
index 3bf012fbf..dc83ff758 100644
--- a/docs/sphinx_doc/source/tutorial/203-model.md
+++ b/docs/sphinx_doc/source/tutorial/203-model.md
@@ -17,7 +17,7 @@ where the model configs could be a list of dict:
     {
         "config_name": "gpt-4-temperature-0.0",
         "model_type": "openai",
-        "model": "gpt-4",
+        "model_name": "gpt-4",
         "api_key": "xxx",
         "organization": "xxx",
         "generate_args": {
@@ -27,7 +27,7 @@ where the model configs could be a list of dict:
     {
         "config_name": "dall-e-3-size-1024x1024",
         "model_type": "openai_dall_e",
-        "model": "dall-e-3",
+        "model_name": "dall-e-3",
         "api_key": "xxx",
         "organization": "xxx",
         "generate_args": {
diff --git a/examples/conversation/conversation.py b/examples/conversation/conversation.py
index 5cd721a25..ff926ae76 100644
--- a/examples/conversation/conversation.py
+++ b/examples/conversation/conversation.py
@@ -10,7 +10,7 @@
     {
         "model_type": "openai",
         "config_name": "gpt-3.5-turbo",
-        "model": "gpt-3.5-turbo",
+        "model_name": "gpt-3.5-turbo",
         "api_key": "xxx",  # Load from env if not provided
         "organization": "xxx",  # Load from env if not provided
         "generate_args": {
diff --git a/setup.py b/setup.py
index 441011fb8..7f3a74ec9 100644
--- a/setup.py
+++ b/setup.py
@@ -47,6 +47,7 @@
     "Flask==3.0.0",
     "Flask-Cors==4.0.0",
     "Flask-SocketIO==5.3.6",
+    "dashscope",
 ]
 
 distribute_requires = minimal_requires + rpc_requires
diff --git a/src/agentscope/models/__init__.py b/src/agentscope/models/__init__.py
index 66bd69b44..e7f6e7b5d 100644
--- a/src/agentscope/models/__init__.py
+++ b/src/agentscope/models/__init__.py
@@ -17,6 +17,10 @@
     OpenAIDALLEWrapper,
     OpenAIEmbeddingWrapper,
 )
+from .tongyi_model import (
+    TongyiWrapper,
+    TongyiChatWrapper,
+)
 
 
 __all__ = [
@@ -31,6 +35,8 @@
     "load_model_by_config_name",
     "read_model_configs",
     "clear_model_configs",
+    "TongyiWrapper",
+    "TongyiChatWrapper",
 ]
 
 _MODEL_CONFIGS: dict[str, dict] = {}
diff --git a/src/agentscope/models/openai_model.py b/src/agentscope/models/openai_model.py
index 9c47e886c..0d8455a17 100644
--- a/src/agentscope/models/openai_model.py
+++ b/src/agentscope/models/openai_model.py
@@ -57,6 +57,7 @@ def __init__(
         """
         if model_name is None:
             model_name = config_name
+            logger.warning("model_name is not set, using config_name instead.")
         super().__init__(
             config_name=config_name,
             model_name=model_name,
diff --git a/src/agentscope/models/post_model.py b/src/agentscope/models/post_model.py
index 0addce7c6..663fd137d 100644
--- a/src/agentscope/models/post_model.py
+++ b/src/agentscope/models/post_model.py
@@ -143,7 +143,6 @@ def __call__(self, input_: str, **kwargs: Any) -> ModelResponse:
                 break
 
             if i < self.max_retries:
-                # av
                 logger.warning(
                     f"Failed to call the model with "
                     f"requests.codes == {response.status_code}, retry "
diff --git a/src/agentscope/models/tongyi_model.py b/src/agentscope/models/tongyi_model.py
new file mode 100644
index 000000000..74712fc13
--- /dev/null
+++ b/src/agentscope/models/tongyi_model.py
@@ -0,0 +1,217 @@
+# -*- coding: utf-8 -*-
+"""Model wrapper for Tongyi models"""
+from typing import Any
+
+try:
+    import dashscope
+except ImportError:
+    dashscope = None
+
+from loguru import logger
+
+from .model import ModelWrapperBase, ModelResponse
+
+from ..utils.monitor import MonitorFactory
+from ..utils.monitor import get_full_name
+from ..utils import QuotaExceededError
+from ..constants import _DEFAULT_API_BUDGET
+
+
+class TongyiWrapper(ModelWrapperBase):
+    """The model wrapper for the Tongyi API."""
+
+    def __init__(
+        self,
+        config_name: str,
+        model_name: str = None,
+        api_key: str = None,
+        generate_args: dict = None,
+        budget: float = _DEFAULT_API_BUDGET,
+        **kwargs: Any,
+    ) -> None:
+        """Initialize the Tongyi wrapper.
+
+        Args:
+            config_name (`str`):
+                The name of the model config.
+            model_name (`str`, default `None`):
+                The name of the model to use in the Tongyi API.
+            api_key (`str`, default `None`):
+                The API key for the Tongyi API.
+            generate_args (`dict`, default `None`):
+                The extra keyword arguments used in Tongyi API generation,
+                e.g. `temperature`, `seed`.
+            budget (`float`, default `_DEFAULT_API_BUDGET`):
+                The total budget for using this model. Set to `None` for no
+                limit.
+        """
+        if model_name is None:
+            model_name = config_name
+            logger.warning("model_name is not set, using config_name instead.")
+        super().__init__(
+            config_name=config_name,
+            model_name=model_name,
+            generate_args=generate_args,
+            budget=budget,
+            **kwargs,
+        )
+        if dashscope is None:
+            raise ImportError(
+                "Cannot find dashscope package in current python environment.",
+            )
+
+        self.model = model_name
+        self.generate_args = generate_args or {}
+
+        self.api_key = api_key
+        dashscope.api_key = self.api_key
+        self.max_length = None
+
+        # Set monitor accordingly
+        self.monitor = None
+        self.budget = budget
+        self._register_budget()
+        self._register_default_metrics()
+
+    def _register_budget(self) -> None:
+        self.monitor = MonitorFactory.get_monitor()
+        self.monitor.register_budget(
+            model_name=self.model,
+            value=self.budget,
+            prefix=self.model,
+        )
+
+    def _register_default_metrics(self) -> None:
+        """Register metrics to the monitor."""
+        raise NotImplementedError(
+            "The _register_default_metrics function is not implemented.",
+        )
+
+    def _metric(self, metric_name: str) -> str:
+        """Add the class name and model name as prefix to the metric name.
+
+        Args:
+            metric_name (`str`):
+                The metric name.
+
+        Returns:
+            `str`: Metric name of this wrapper.
+        """
+        return get_full_name(name=metric_name, prefix=self.model)
+
+
+class TongyiChatWrapper(TongyiWrapper):
+    """The model wrapper for Tongyi's chat API."""
+
+    model_type: str = "tongyi_chat"
+
+    def _register_default_metrics(self) -> None:
+        # Set monitor accordingly
+        # TODO: set quota to the following metrics
+        self.monitor = MonitorFactory.get_monitor()
+        self.monitor.register(
+            self._metric("prompt_tokens"),
+            metric_unit="token",
+        )
+        self.monitor.register(
+            self._metric("completion_tokens"),
+            metric_unit="token",
+        )
+        self.monitor.register(
+            self._metric("total_tokens"),
+            metric_unit="token",
+        )
+
+    def __call__(
+        self,
+        messages: list,
+        **kwargs: Any,
+    ) -> ModelResponse:
+        """Processes a list of messages to construct a payload for the Tongyi
+        API call. It then makes a request to the Tongyi API and returns the
+        response. This method also updates monitoring metrics based on the
+        API response.
+
+        Each message in the 'messages' list can contain text content and
+        optionally an 'image_urls' key. If 'image_urls' is provided,
+        it is expected to be a list of strings representing URLs to images.
+        These URLs will be transformed to a format suitable for the Tongyi
+        API, which might involve converting local file paths to data URIs.
+
+        Args:
+            messages (`list`):
+                A list of messages to process.
+            **kwargs (`Any`):
+                The keyword arguments to the Tongyi chat completions API,
+                e.g. `temperature`, `max_tokens`, `top_p`, etc. Please refer to
+
+                for more detailed arguments.
+
+        Returns:
+            `ModelResponse`:
+                The response text in the text field, and the raw response in
+                the raw field.
+
+        Note:
+            `parse_func`, `fault_handler` and `max_retries` are reserved for
+            `_response_parse_decorator` to parse and check the response
+            generated by the model wrapper. Their usages are listed as follows:
+                - `parse_func` is a callable function used to parse and check
+                the response generated by the model, which takes the response
+                as input.
+                - `max_retries` is the maximum number of retries when
+                `parse_func` raises an exception.
+                - `fault_handler` is a callable function which is called
+                when the response generated by the model is invalid after
+                `max_retries` retries.
+        """
+
+        # step1: prepare keyword arguments
+        kwargs = {**self.generate_args, **kwargs}
+
+        # step2: check messages
+        if not all("role" in msg and "content" in msg for msg in messages):
+            raise ValueError(
+                "Each message in the 'messages' list must contain a 'role' "
+                "and 'content' key for Tongyi API.",
+            )
+
+        # For the Tongyi model, the "role" value of the first and the last
+        # message must be "user"
+        if len(messages) > 0:
+            messages[0]["role"] = "user"
+            messages[-1]["role"] = "user"
+
+        # step3: forward to generate response
+        response = dashscope.Generation.call(
+            model=self.model,
+            messages=messages,
+            result_format="message",  # set the result to be "message" format
+            **kwargs,
+        )
+
+        # step4: record the api invocation if needed
+        self._save_model_invocation(
+            arguments={
+                "model": self.model,
+                "messages": messages,
+                **kwargs,
+            },
+            json_response=response,
+        )
+
+        # step5: update monitor accordingly
+        try:
+            self.monitor.update(
+                response.usage,
+                prefix=self.model,
+            )
+        except QuotaExceededError as e:
+            # TODO: optimize quota exceeded error handling process
+            logger.error(e.message)
+
+        # step6: return response
+        return ModelResponse(
+            text=response.output["choices"][0]["message"]["content"],
+            raw=response,
+        )
diff --git a/tests/model_test.py b/tests/model_test.py
index 4c247daec..54eeb2c00 100644
--- a/tests/model_test.py
+++ b/tests/model_test.py
@@ -52,7 +52,7 @@ def test_load_model_configs(self) -> None:
             {
                 "model_type": "openai",
                 "config_name": "gpt-4",
-                "model": "gpt-4",
+                "model_name": "gpt-4",
                 "api_key": "xxx",
                 "organization": "xxx",
                 "generate_args": {"temperature": 0.5},
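For reference, a minimal usage sketch of the new wrapper (not part of the patch). It relies only on names visible in the diff: "tongyi_chat" is the model_type registered by TongyiChatWrapper, and read_model_configs / load_model_by_config_name are exported from agentscope.models in the __init__.py hunk above and assumed to behave as for the OpenAI wrappers. The config name "qwen-chat" and the DashScope model name "qwen-max" are placeholders, not something this patch pins down:

    from agentscope.models import load_model_by_config_name, read_model_configs

    # Register a config for the new wrapper; the keys mirror the updated
    # examples above (model_name instead of the old "model" key).
    read_model_configs(
        [
            {
                "model_type": "tongyi_chat",
                "config_name": "qwen-chat",      # placeholder config name
                "model_name": "qwen-max",        # placeholder DashScope model
                "api_key": "xxx",                # DashScope API key
                "generate_args": {
                    "temperature": 0.5,
                },
            },
        ],
    )

    model = load_model_by_config_name("qwen-chat")

    # Each message must carry "role" and "content" keys, as checked in
    # TongyiChatWrapper.__call__; the wrapper forces the first and last
    # roles to "user" before calling dashscope.Generation.call.
    response = model(messages=[{"role": "user", "content": "Hi, who are you?"}])
    print(response.text)  # ModelResponse.text holds the generated reply

Note that the wrapper raises ImportError at construction time if dashscope is not installed, which is why the setup.py hunk adds "dashscope" to the requirements.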