Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add tongyi model wrapper #46

Merged
merged 5 commits into from
Mar 4, 2024
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions docs/sphinx_doc/source/tutorial/203-model.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ where the model configs could be a list of dict:
{
"config_name": "gpt-4-temperature-0.0",
"model_type": "openai",
"model": "gpt-4",
"model_name": "gpt-4",
"api_key": "xxx",
"organization": "xxx",
"generate_args": {
Expand All @@ -27,7 +27,7 @@ where the model configs could be a list of dict:
{
"config_name": "dall-e-3-size-1024x1024",
"model_type": "openai_dall_e",
"model": "dall-e-3",
"model_name": "dall-e-3",
"api_key": "xxx",
"organization": "xxx",
"generate_args": {
Expand Down
2 changes: 1 addition & 1 deletion examples/conversation/conversation.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
{
"model_type": "openai",
"config_name": "gpt-3.5-turbo",
"model": "gpt-3.5-turbo",
"model_name": "gpt-3.5-turbo",
"api_key": "xxx", # Load from env if not provided
"organization": "xxx", # Load from env if not provided
"generate_args": {
Expand Down
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@
"Flask==3.0.0",
"Flask-Cors==4.0.0",
"Flask-SocketIO==5.3.6",
"dashscope",
]

distribute_requires = minimal_requires + rpc_requires
Expand Down
6 changes: 6 additions & 0 deletions src/agentscope/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,10 @@
OpenAIDALLEWrapper,
OpenAIEmbeddingWrapper,
)
from .tongyi_model import (
TongyiWrapper,
TongyiChatWrapper,
)


__all__ = [
Expand All @@ -31,6 +35,8 @@
"load_model_by_config_name",
"read_model_configs",
"clear_model_configs",
"TongyiWrapper",
"TongyiChatWrapper",
]

_MODEL_CONFIGS: dict[str, dict] = {}
Expand Down
1 change: 0 additions & 1 deletion src/agentscope/models/post_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,6 @@ def __call__(self, input_: str, **kwargs: Any) -> ModelResponse:
break

if i < self.max_retries:
# av
logger.warning(
f"Failed to call the model with "
f"requests.codes == {response.status_code}, retry "
Expand Down
216 changes: 216 additions & 0 deletions src/agentscope/models/tongyi_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,216 @@
# -*- coding: utf-8 -*-
"""Model wrapper for Tongyi models"""
from typing import Any

try:
import dashscope
except ImportError:
dashscope = None

from loguru import logger

from .model import ModelWrapperBase, ModelResponse

from ..utils.monitor import MonitorFactory
from ..utils.monitor import get_full_name
from ..utils import QuotaExceededError
from ..constants import _DEFAULT_API_BUDGET


class TongyiWrapper(ModelWrapperBase):
    """The base model wrapper for the Tongyi (DashScope) API.

    This class stores the model name, generation arguments and API key, and
    registers the budget and the default metrics with the monitor.
    Subclasses must override `_register_default_metrics`; the implementation
    here raises `NotImplementedError`, so this class cannot be instantiated
    directly.
    """

    def __init__(
        self,
        config_name: str,
        model_name: str = None,
        api_key: str = None,
        generate_args: dict = None,
        budget: float = _DEFAULT_API_BUDGET,
        **kwargs: Any,
    ) -> None:
        """Initialize the Tongyi wrapper.

        Args:
            config_name (`str`):
                The name of the model config.
            model_name (`str`, default `None`):
                The name of the model to use in Tongyi API. Falls back to
                `config_name` when not provided.
            api_key (`str`, default `None`):
                The API key for Tongyi API. When `None`, the key currently
                configured on the `dashscope` module is left untouched.
            generate_args (`dict`, default `None`):
                The extra keyword arguments used in Tongyi api generation,
                e.g. `temperature`, `seed`.
            budget (`float`, default `_DEFAULT_API_BUDGET`):
                The total budget using this model. `None` presumably means
                no limit (confirm against the monitor implementation).

        Raises:
            `ImportError`: If the `dashscope` package is not installed.
        """
        if model_name is None:
            model_name = config_name
        super().__init__(
            config_name=config_name,
            model_name=model_name,
            generate_args=generate_args,
            budget=budget,
            **kwargs,
        )
        if dashscope is None:
            raise ImportError(
                "Cannot find dashscope package in current python environment.",
            )

        self.model = model_name
        self.generate_args = generate_args or {}

        self.api_key = api_key
        # `dashscope.api_key` is module-global state: only overwrite it when
        # a key is actually given, so that creating a wrapper without a key
        # does not wipe a key configured earlier (e.g. by another wrapper).
        if api_key is not None:
            dashscope.api_key = api_key
        self.max_length = None

        # Set monitor accordingly
        self.monitor = None
        self.budget = budget
        self._register_budget()
        self._register_default_metrics()

    def _register_budget(self) -> None:
        """Fetch the monitor and register this model's budget with it."""
        self.monitor = MonitorFactory.get_monitor()
        self.monitor.register_budget(
            model_name=self.model,
            value=self.budget,
            prefix=self.model,
        )

    def _register_default_metrics(self) -> None:
        """Register metrics to the monitor.

        Must be overridden by subclasses.

        Raises:
            `NotImplementedError`: Always, in this base class.
        """
        raise NotImplementedError(
            "The _register_default_metrics function is not Implemented.",
        )

    def _metric(self, metric_name: str) -> str:
        """Add the model name as prefix to the metric name.

        Args:
            metric_name (`str`):
                The metric name.

        Returns:
            `str`: Metric name of this wrapper.
        """
        return get_full_name(name=metric_name, prefix=self.model)


class TongyiChatWrapper(TongyiWrapper):
    """The model wrapper for Tongyi's chat API."""

    model_type: str = "tongyi_chat"

    def _register_default_metrics(self) -> None:
        """Register the token usage metrics of this model to the monitor."""
        # TODO: set quota to the following metrics
        self.monitor = MonitorFactory.get_monitor()
        for metric_name in (
            "prompt_tokens",
            "completion_tokens",
            "total_tokens",
        ):
            self.monitor.register(
                self._metric(metric_name),
                metric_unit="token",
            )

    def __call__(
        self,
        messages: list,
        **kwargs: Any,
    ) -> ModelResponse:
        """Send a list of messages to the Tongyi chat API and return the
        response. This method also records the invocation and updates the
        monitoring metrics based on the API response.

        The roles of the first and the last message are forced to `"user"`
        in the payload sent to the API. The caller's `messages` list and its
        message objects are not modified.

        Args:
            messages (`list`):
                A list of messages to process. Each message must contain a
                `role` and a `content` key.
            **kwargs (`Any`):
                The keyword arguments to Tongyi chat completions API,
                e.g. `temperature`, `max_tokens`, `top_p`, etc. They
                override the `generate_args` given at initialization.
                Please refer to the DashScope API documentation for more
                detailed arguments.

        Returns:
            `ModelResponse`:
                The response text in text field, and the raw response in
                raw field.

        Raises:
            `ValueError`: If a message lacks the `role` or `content` key.
            `RuntimeError`: If the Tongyi API reports a failed request.

        Note:
            `parse_func`, `fault_handler` and `max_retries` are reserved for
            `_response_parse_decorator` to parse and check the response
            generated by model wrapper. Their usages are listed as follows:
                - `parse_func` is a callable function used to parse and check
                the response generated by the model, which takes the response
                as input.
                - `max_retries` is the maximum number of retries when the
                `parse_func` raise an exception.
                - `fault_handler` is a callable function which is called
                when the response generated by the model is invalid after
                `max_retries` retries.
        """
        from http import HTTPStatus

        # step1: prepare keyword arguments
        kwargs = {**self.generate_args, **kwargs}

        # step2: checking messages
        if not all("role" in msg and "content" in msg for msg in messages):
            raise ValueError(
                "Each message in the 'messages' list must contain a 'role' "
                "and 'content' key for Tongyi API.",
            )

        # For Tongyi model, the "role" value of the first and the last message
        # must be "user". Work on a copy of the list and of the two affected
        # messages so that the caller's objects keep their original roles.
        if messages:
            messages = list(messages)
            messages[0] = {**messages[0], "role": "user"}
            messages[-1] = {**messages[-1], "role": "user"}

        # step3: forward to generate response
        response = dashscope.Generation.call(
            model=self.model,
            messages=messages,
            result_format="message",  # set the result to be "message" format.
            **kwargs,
        )

        # step4: record the api invocation if needed
        self._save_model_invocation(
            arguments={
                "model": self.model,
                "messages": messages,
                **kwargs,
            },
            json_response=response,
        )

        # A failed request carries no usable output; fail with the details
        # reported by the service instead of an opaque error further below.
        if response.status_code != HTTPStatus.OK:
            raise RuntimeError(
                f"Request id: {response.request_id}, "
                f"status code: {response.status_code}, "
                f"error code: {response.code}, "
                f"error message: {response.message}.",
            )

        # step5: update monitor accordingly
        try:
            self.monitor.update(
                response.usage,
                prefix=self.model,
            )
        except QuotaExceededError as e:
            # TODO: optimize quota exceeded error handling process
            logger.error(e.message)

        # step6: return response
        return ModelResponse(
            text=response.output["choices"][0]["message"]["content"],
            raw=response,
        )
2 changes: 1 addition & 1 deletion tests/model_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ def test_load_model_configs(self) -> None:
{
"model_type": "openai",
"config_name": "gpt-4",
"model": "gpt-4",
"model_name": "gpt-4",
pan-x-c marked this conversation as resolved.
Show resolved Hide resolved
"api_key": "xxx",
"organization": "xxx",
"generate_args": {"temperature": 0.5},
Expand Down
Loading