Skip to content

Commit

Permalink
Integrate Yi Chat wrapper into AgentScope (#343)
Browse files Browse the repository at this point in the history
---------

Co-authored-by: DavdGao <[email protected]>
  • Loading branch information
Haijian06 and DavdGao authored Aug 28, 2024
1 parent 79bf906 commit f40aaf4
Show file tree
Hide file tree
Showing 8 changed files with 353 additions and 43 deletions.
37 changes: 19 additions & 18 deletions README.md

Large diffs are not rendered by default.

37 changes: 19 additions & 18 deletions README_ZH.md

Large diffs are not rendered by default.

9 changes: 5 additions & 4 deletions docs/sphinx_doc/en/source/tutorial/203-model.md
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ In the current AgentScope, the supported `model_type` types, the corresponding

| API | Task | Model Wrapper | `model_type` | Some Supported Models |
|------------------------|-----------------|---------------------------------------------------------------------------------------------------------------------------------|-------------------------------|--------------------------------------------------|
| OpenAI API | Chat | [`OpenAIChatWrapper`](https://github.com/modelscope/agentscope/blob/main/src/agentscope/models/openai_model.py) | `"openai_chat"` | gpt-4, gpt-3.5-turbo, ... |
| OpenAI API | Chat | [`OpenAIChatWrapper`](https://github.com/modelscope/agentscope/blob/main/src/agentscope/models/openai_model.py) | `"openai_chat"` | gpt-4, gpt-3.5-turbo, ... |
| | Embedding | [`OpenAIEmbeddingWrapper`](https://github.com/modelscope/agentscope/blob/main/src/agentscope/models/openai_model.py) | `"openai_embedding"` | text-embedding-ada-002, ... |
| | DALL·E | [`OpenAIDALLEWrapper`](https://github.com/modelscope/agentscope/blob/main/src/agentscope/models/openai_model.py) | `"openai_dall_e"` | dall-e-2, dall-e-3 |
| DashScope API | Chat | [`DashScopeChatWrapper`](https://github.com/modelscope/agentscope/blob/main/src/agentscope/models/dashscope_model.py) | `"dashscope_chat"` | qwen-plus, qwen-max, ... |
Expand All @@ -83,12 +83,13 @@ In the current AgentScope, the supported `model_type` types, the corresponding
| | Multimodal | [`DashScopeMultiModalWrapper`](https://github.com/modelscope/agentscope/blob/main/src/agentscope/models/dashscope_model.py) | `"dashscope_multimodal"` | qwen-vl-plus, qwen-vl-max, qwen-audio-turbo, ... |
| Gemini API | Chat | [`GeminiChatWrapper`](https://github.com/modelscope/agentscope/blob/main/src/agentscope/models/gemini_model.py) | `"gemini_chat"` | gemini-pro, ... |
| | Embedding | [`GeminiEmbeddingWrapper`](https://github.com/modelscope/agentscope/blob/main/src/agentscope/models/gemini_model.py) | `"gemini_embedding"` | models/embedding-001, ... |
| ZhipuAI API | Chat | [`ZhipuAIChatWrapper`](https://github.com/modelscope/agentscope/blob/main/src/agentscope/models/zhipu_model.py) | `"zhipuai_chat"` | glm4, ... |
| | Embedding | [`ZhipuAIEmbeddingWrapper`](https://github.com/modelscope/agentscope/blob/main/src/agentscope/models/zhipu_model.py) | `"zhipuai_embedding"` | embedding-2, ... |
| ZhipuAI API | Chat | [`ZhipuAIChatWrapper`](https://github.com/modelscope/agentscope/blob/main/src/agentscope/models/zhipu_model.py) | `"zhipuai_chat"` | glm4, ... |
| | Embedding | [`ZhipuAIEmbeddingWrapper`](https://github.com/modelscope/agentscope/blob/main/src/agentscope/models/zhipu_model.py) | `"zhipuai_embedding"` | embedding-2, ... |
| ollama | Chat | [`OllamaChatWrapper`](https://github.com/modelscope/agentscope/blob/main/src/agentscope/models/ollama_model.py) | `"ollama_chat"` | llama2, ... |
| | Embedding | [`OllamaEmbeddingWrapper`](https://github.com/modelscope/agentscope/blob/main/src/agentscope/models/ollama_model.py) | `"ollama_embedding"` | llama2, ... |
| | Generation | [`OllamaGenerationWrapper`](https://github.com/modelscope/agentscope/blob/main/src/agentscope/models/ollama_model.py) | `"ollama_generate"` | llama2, ... |
| LiteLLM API | Chat | [`LiteLLMChatWrapper`](https://github.com/modelscope/agentscope/blob/main/src/agentscope/models/litellm_model.py) | `"litellm_chat"` | - |
| LiteLLM API | Chat | [`LiteLLMChatWrapper`](https://github.com/modelscope/agentscope/blob/main/src/agentscope/models/litellm_model.py) | `"litellm_chat"` | - |
| Yi API | Chat | [`YiChatWrapper`](https://github.com/modelscope/agentscope/blob/main/src/agentscope/models/yi_model.py) | `"yi_chat"` | yi-large, yi-medium, ... |
| Post Request based API | - | [`PostAPIModelWrapperBase`](https://github.com/modelscope/agentscope/blob/main/src/agentscope/models/post_model.py) | `"post_api"` | - |
| | Chat | [`PostAPIChatWrapper`](https://github.com/modelscope/agentscope/blob/main/src/agentscope/models/post_model.py) | `"post_api_chat"` | meta-llama/Meta-Llama-3-8B-Instruct, ... |
| | Image Synthesis | [`PostAPIDALLEWrapper`](https://github.com/modelscope/agentscope/blob/main/src/agentscope/models/post_model.py) | `post_api_dall_e` | - | |
Expand Down
1 change: 1 addition & 0 deletions docs/sphinx_doc/zh_CN/source/tutorial/203-model.md
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,7 @@ API如下:
| | Embedding | [`OllamaEmbeddingWrapper`](https://github.com/modelscope/agentscope/blob/main/src/agentscope/models/ollama_model.py) | `"ollama_embedding"` | llama2, ... |
| | Generation | [`OllamaGenerationWrapper`](https://github.com/modelscope/agentscope/blob/main/src/agentscope/models/ollama_model.py) | `"ollama_generate"` | llama2, ... |
| LiteLLM API | Chat | [`LiteLLMChatWrapper`](https://github.com/modelscope/agentscope/blob/main/src/agentscope/models/litellm_model.py) | `"litellm_chat"` | - |
| Yi API | Chat | [`YiChatWrapper`](https://github.com/modelscope/agentscope/blob/main/src/agentscope/models/yi_model.py) | `"yi_chat"` | yi-large, yi-medium, ... |
| Post Request based API | - | [`PostAPIModelWrapperBase`](https://github.com/modelscope/agentscope/blob/main/src/agentscope/models/post_model.py) | `"post_api"` | - |
| | Chat | [`PostAPIChatWrapper`](https://github.com/modelscope/agentscope/blob/main/src/agentscope/models/post_model.py) | `"post_api_chat"` | meta-llama/Meta-Llama-3-8B-Instruct, ... |
| | Image Synthesis | [`PostAPIDALLEWrapper`](https://github.com/modelscope/agentscope/blob/main/src/agentscope/models/post_model.py) | `post_api_dall_e` | - | |
Expand Down
11 changes: 11 additions & 0 deletions examples/model_configs_template/yi_chat_template.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
[
{
"config_name": "yi_yi-large",
"model_type": "yi_chat",
"model_name": "yi-large",
"api_key": "{your_api_key}",
"temperature": 0.3,
"top_p": 0.9,
"max_tokens": 1000
}
]
5 changes: 4 additions & 1 deletion src/agentscope/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,9 @@
from .litellm_model import (
LiteLLMChatWrapper,
)

from .yi_model import (
YiChatWrapper,
)

__all__ = [
"ModelWrapperBase",
Expand All @@ -61,6 +63,7 @@
"ZhipuAIChatWrapper",
"ZhipuAIEmbeddingWrapper",
"LiteLLMChatWrapper",
"YiChatWrapper",
]


Expand Down
4 changes: 2 additions & 2 deletions src/agentscope/models/openai_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,7 +188,7 @@ def __init__(

def __call__(
self,
messages: list,
messages: list[dict],
stream: Optional[bool] = None,
**kwargs: Any,
) -> ModelResponse:
Expand Down Expand Up @@ -331,7 +331,7 @@ def _save_model_invocation_and_update_monitor(
response=response,
)

usage = response.get("usage")
usage = response.get("usage", None)
if usage is not None:
self.monitor.update_text_and_embedding_tokens(
model_name=self.model_name,
Expand Down
292 changes: 292 additions & 0 deletions src/agentscope/models/yi_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,292 @@
# -*- coding: utf-8 -*-
"""Model wrapper for Yi models"""
import json
from typing import (
List,
Union,
Sequence,
Optional,
Generator,
)

import requests

from ._model_utils import (
_verify_text_content_in_openai_message_response,
_verify_text_content_in_openai_delta_response,
)
from .model import ModelWrapperBase, ModelResponse
from ..message import Msg


class YiChatWrapper(ModelWrapperBase):
"""The model wrapper for Yi Chat API.
Response:
- From https://platform.lingyiwanwu.com/docs
```json
{
"id": "cmpl-ea89ae83",
"object": "chat.completion",
"created": 5785971,
"model": "yi-large-rag",
"usage": {
"completion_tokens": 113,
"prompt_tokens": 896,
"total_tokens": 1009
},
"choices": [
{
"index": 0,
"message": {
"role": "assistant",
"content": "Today in Los Angeles, the weather ...",
},
"finish_reason": "stop"
}
]
}
```
"""

model_type: str = "yi_chat"

def __init__(
self,
config_name: str,
model_name: str,
api_key: str,
max_tokens: Optional[int] = None,
top_p: float = 0.9,
temperature: float = 0.3,
stream: bool = False,
) -> None:
"""Initialize the Yi chat model wrapper.
Args:
config_name (`str`):
The name of the configuration to use.
model_name (`str`):
The name of the model to use, e.g. yi-large, yi-medium, etc.
api_key (`str`):
The API key for the Yi API.
max_tokens (`Optional[int]`, defaults to `None`):
The maximum number of tokens to generate, defaults to `None`.
top_p (`float`, defaults to `0.9`):
The randomness parameters in the range [0, 1].
temperature (`float`, defaults to `0.3`):
The temperature parameter in the range [0, 2].
stream (`bool`, defaults to `False`):
Whether to stream the response or not.
"""

super().__init__(config_name, model_name)

if top_p > 1 or top_p < 0:
raise ValueError(
f"The `top_p` parameter must be in the range [0, 1], but got "
f"{top_p} instead.",
)

if temperature < 0 or temperature > 2:
raise ValueError(
f"The `temperature` parameter must be in the range [0, 2], "
f"but got {temperature} instead.",
)

self.api_key = api_key
self.max_tokens = max_tokens
self.top_p = top_p
self.temperature = temperature
self.stream = stream

def __call__(
self,
messages: list[dict],
stream: Optional[bool] = None,
) -> ModelResponse:
"""Invoke the Yi Chat API by sending a list of messages."""

# Checking messages
if not isinstance(messages, list):
raise ValueError(
f"Yi `messages` field expected type `list`, "
f"got `{type(messages)}` instead.",
)

if not all("role" in msg and "content" in msg for msg in messages):
raise ValueError(
"Each message in the 'messages' list must contain a 'role' "
"and 'content' key for Yi API.",
)

if stream is None:
stream = self.stream

# Forward to generate response
kwargs = {
"url": "https://api.lingyiwanwu.com/v1/chat/completions",
"json": {
"model": self.model_name,
"messages": messages,
"temperature": self.temperature,
"max_tokens": self.max_tokens,
"top_p": self.top_p,
"stream": stream,
},
"headers": {
"Authorization": f"Bearer {self.api_key}",
"Content-Type": "application/json",
},
}

response = requests.post(**kwargs)
response.raise_for_status()

if stream:

def generator() -> Generator[str, None, None]:
text = ""
last_chunk = {}
for line in response.iter_lines():
if line:
line_str = line.decode("utf-8").strip()

# Remove prefix "data: " if exists
json_str = line_str.removeprefix("data: ")

# The last response is "data: [DONE]"
if json_str == "[DONE]":
continue

try:
chunk = json.loads(json_str)
if _verify_text_content_in_openai_delta_response(
chunk,
):
text += chunk["choices"][0]["delta"]["content"]
yield text
last_chunk = chunk

except json.decoder.JSONDecodeError as e:
raise json.decoder.JSONDecodeError(
f"Invalid JSON: {json_str}",
e.doc,
e.pos,
) from e

# In Yi Chat API, the last valid chunk will save all the text
# in this message
self._save_model_invocation_and_update_monitor(
kwargs,
last_chunk,
)

return ModelResponse(
stream=generator(),
)
else:
response = response.json()
self._save_model_invocation_and_update_monitor(
kwargs,
response,
)

# Re-use the openai response checking function
if _verify_text_content_in_openai_message_response(response):
return ModelResponse(
text=response["choices"][0]["message"]["content"],
raw=response,
)
else:
raise RuntimeError(
f"Invalid response from Yi Chat API: {response}",
)

def format(
self,
*args: Union[Msg, Sequence[Msg]],
) -> List[dict]:
"""Format the messages into the required format of Yi Chat API.
Note this strategy maybe not suitable for all scenarios,
and developers are encouraged to implement their own prompt
engineering strategies.
The following is an example:
.. code-block:: python
prompt1 = model.format(
Msg("system", "You're a helpful assistant", role="system"),
Msg("Bob", "Hi, how can I help you?", role="assistant"),
Msg("user", "What's the date today?", role="user")
)
The prompt will be as follows:
.. code-block:: python
# prompt1
[
{
"role": "user",
"content": (
"You're a helpful assistant\\n"
"\\n"
"## Conversation History\\n"
"Bob: Hi, how can I help you?\\n"
"user: What's the date today?"
)
}
]
Args:
args (`Union[Msg, Sequence[Msg]]`):
The input arguments to be formatted, where each argument
should be a `Msg` object, or a list of `Msg` objects.
In distribution, placeholder is also allowed.
Returns:
`List[dict]`:
The formatted messages.
"""

# TODO: Support Vision model
if self.model_name == "yi-vision":
raise NotImplementedError(
"Yi Vision model is not supported in the current version, "
"please format the messages manually.",
)

return ModelWrapperBase.format_for_common_chat_models(*args)

def _save_model_invocation_and_update_monitor(
self,
kwargs: dict,
response: dict,
) -> None:
"""Save model invocation and update the monitor accordingly.
Args:
kwargs (`dict`):
The keyword arguments used in model invocation
response (`dict`):
The response from model API
"""
self._save_model_invocation(
arguments=kwargs,
response=response,
)

usage = response.get("usage", None)
if usage is not None:
prompt_tokens = usage.get("prompt_tokens", 0)
completion_tokens = usage.get("completion_tokens", 0)

self.monitor.update_text_and_embedding_tokens(
model_name=self.model_name,
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
)

0 comments on commit f40aaf4

Please sign in to comment.