Integrate Yi Chat wrapper into AgentScope (#343)

Co-authored-by: DavdGao <[email protected]>
Showing 8 changed files with 353 additions and 43 deletions.
@@ -0,0 +1,11 @@
[
    {
        "config_name": "yi_yi-large",
        "model_type": "yi_chat",
        "model_name": "yi-large",
        "api_key": "{your_api_key}",
        "temperature": 0.3,
        "top_p": 0.9,
        "max_tokens": 1000
    }
]
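For context, a hedged sketch of how such a template is typically consumed. The file path here is illustrative, and `load_model_by_config_name` is assumed to be exported from `agentscope.models` as in AgentScope's other model examples:

```python
import agentscope
from agentscope.models import load_model_by_config_name

# Register the model configs with AgentScope (the path is illustrative).
agentscope.init(model_configs="./yi_chat_template.json")

# Retrieve the wrapper instance registered under "config_name" above.
model = load_model_by_config_name("yi_yi-large")
```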
@@ -0,0 +1,292 @@
# -*- coding: utf-8 -*-
"""Model wrapper for Yi models"""
import json
from typing import (
    List,
    Union,
    Sequence,
    Optional,
    Generator,
)

import requests

from ._model_utils import (
    _verify_text_content_in_openai_message_response,
    _verify_text_content_in_openai_delta_response,
)
from .model import ModelWrapperBase, ModelResponse
from ..message import Msg

class YiChatWrapper(ModelWrapperBase):
    """The model wrapper for Yi Chat API.

    Response:
        - From https://platform.lingyiwanwu.com/docs

        ```json
        {
            "id": "cmpl-ea89ae83",
            "object": "chat.completion",
            "created": 5785971,
            "model": "yi-large-rag",
            "usage": {
                "completion_tokens": 113,
                "prompt_tokens": 896,
                "total_tokens": 1009
            },
            "choices": [
                {
                    "index": 0,
                    "message": {
                        "role": "assistant",
                        "content": "Today in Los Angeles, the weather ...",
                    },
                    "finish_reason": "stop"
                }
            ]
        }
        ```
    """
    model_type: str = "yi_chat"

    def __init__(
        self,
        config_name: str,
        model_name: str,
        api_key: str,
        max_tokens: Optional[int] = None,
        top_p: float = 0.9,
        temperature: float = 0.3,
        stream: bool = False,
    ) -> None:
        """Initialize the Yi chat model wrapper.

        Args:
            config_name (`str`):
                The name of the configuration to use.
            model_name (`str`):
                The name of the model to use, e.g. yi-large, yi-medium, etc.
            api_key (`str`):
                The API key for the Yi API.
            max_tokens (`Optional[int]`, defaults to `None`):
                The maximum number of tokens to generate.
            top_p (`float`, defaults to `0.9`):
                The randomness parameter in the range [0, 1].
            temperature (`float`, defaults to `0.3`):
                The temperature parameter in the range [0, 2].
            stream (`bool`, defaults to `False`):
                Whether to stream the response or not.
        """

        super().__init__(config_name, model_name)

        if top_p > 1 or top_p < 0:
            raise ValueError(
                f"The `top_p` parameter must be in the range [0, 1], but got "
                f"{top_p} instead.",
            )

        if temperature < 0 or temperature > 2:
            raise ValueError(
                f"The `temperature` parameter must be in the range [0, 2], "
                f"but got {temperature} instead.",
            )

        self.api_key = api_key
        self.max_tokens = max_tokens
        self.top_p = top_p
        self.temperature = temperature
        self.stream = stream

    def __call__(
        self,
        messages: List[dict],
        stream: Optional[bool] = None,
    ) -> ModelResponse:
        """Invoke the Yi Chat API by sending a list of messages."""

        # Checking messages
        if not isinstance(messages, list):
            raise ValueError(
                f"Yi `messages` field expected type `list`, "
                f"got `{type(messages)}` instead.",
            )

        if not all("role" in msg and "content" in msg for msg in messages):
            raise ValueError(
                "Each message in the 'messages' list must contain a 'role' "
                "and 'content' key for Yi API.",
            )

        if stream is None:
            stream = self.stream

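        # The request payload follows the OpenAI-compatible chat-completions
        # schema that the Yi API exposes, which is also why the OpenAI
        # response-checking helpers are reused further down.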
        # Forward to generate response
        kwargs = {
            "url": "https://api.lingyiwanwu.com/v1/chat/completions",
            "json": {
                "model": self.model_name,
                "messages": messages,
                "temperature": self.temperature,
                "max_tokens": self.max_tokens,
                "top_p": self.top_p,
                "stream": stream,
            },
            "headers": {
                "Authorization": f"Bearer {self.api_key}",
                "Content-Type": "application/json",
            },
        }

        # Pass `stream` to `requests.post` as well so that, when streaming,
        # the body is consumed incrementally by `iter_lines` below instead
        # of being buffered in full first.
        response = requests.post(**kwargs, stream=stream)
        response.raise_for_status()

        if stream:

            def generator() -> Generator[str, None, None]:
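                # Yield the accumulated text so far rather than only the
                # newest delta, so consumers always receive the full
                # partial message.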
text = "" | ||
last_chunk = {} | ||
for line in response.iter_lines(): | ||
if line: | ||
line_str = line.decode("utf-8").strip() | ||
|
||
# Remove prefix "data: " if exists | ||
json_str = line_str.removeprefix("data: ") | ||
|
||
# The last response is "data: [DONE]" | ||
if json_str == "[DONE]": | ||
continue | ||
|
||
try: | ||
chunk = json.loads(json_str) | ||
if _verify_text_content_in_openai_delta_response( | ||
chunk, | ||
): | ||
text += chunk["choices"][0]["delta"]["content"] | ||
yield text | ||
last_chunk = chunk | ||
|
||
except json.decoder.JSONDecodeError as e: | ||
raise json.decoder.JSONDecodeError( | ||
f"Invalid JSON: {json_str}", | ||
e.doc, | ||
e.pos, | ||
) from e | ||
|
||
# In Yi Chat API, the last valid chunk will save all the text | ||
# in this message | ||
self._save_model_invocation_and_update_monitor( | ||
kwargs, | ||
last_chunk, | ||
) | ||
|
||
return ModelResponse( | ||
stream=generator(), | ||
) | ||
        else:
            response = response.json()
            self._save_model_invocation_and_update_monitor(
                kwargs,
                response,
            )

            # Re-use the OpenAI response checking function
            if _verify_text_content_in_openai_message_response(response):
                return ModelResponse(
                    text=response["choices"][0]["message"]["content"],
                    raw=response,
                )
            else:
                raise RuntimeError(
                    f"Invalid response from Yi Chat API: {response}",
                )

    def format(
        self,
        *args: Union[Msg, Sequence[Msg]],
    ) -> List[dict]:
        """Format the messages into the required format of Yi Chat API.

        Note this strategy may not be suitable for all scenarios,
        and developers are encouraged to implement their own prompt
        engineering strategies.

        The following is an example:

        .. code-block:: python

            prompt1 = model.format(
                Msg("system", "You're a helpful assistant", role="system"),
                Msg("Bob", "Hi, how can I help you?", role="assistant"),
                Msg("user", "What's the date today?", role="user")
            )

        The prompt will be as follows:

        .. code-block:: python

            # prompt1
            [
                {
                    "role": "user",
                    "content": (
                        "You're a helpful assistant\\n"
                        "\\n"
                        "## Conversation History\\n"
                        "Bob: Hi, how can I help you?\\n"
                        "user: What's the date today?"
                    )
                }
            ]

        Args:
            args (`Union[Msg, Sequence[Msg]]`):
                The input arguments to be formatted, where each argument
                should be a `Msg` object, or a list of `Msg` objects.
                In distributed scenarios, placeholder messages are also
                allowed.

        Returns:
            `List[dict]`:
                The formatted messages.
        """

        # TODO: Support Vision model
        if self.model_name == "yi-vision":
            raise NotImplementedError(
                "Yi Vision model is not supported in the current version, "
                "please format the messages manually.",
            )

        return ModelWrapperBase.format_for_common_chat_models(*args)

    def _save_model_invocation_and_update_monitor(
        self,
        kwargs: dict,
        response: dict,
    ) -> None:
        """Save the model invocation and update the monitor accordingly.

        Args:
            kwargs (`dict`):
                The keyword arguments used in the model invocation.
            response (`dict`):
                The response from the model API.
        """
        self._save_model_invocation(
            arguments=kwargs,
            response=response,
        )

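        # Update token counters on the monitor when the API reported usage;
        # for streaming calls this information comes from the final chunk
        # passed in above.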
        usage = response.get("usage", None)
        if usage is not None:
            prompt_tokens = usage.get("prompt_tokens", 0)
            completion_tokens = usage.get("completion_tokens", 0)

            self.monitor.update_text_and_embedding_tokens(
                model_name=self.model_name,
                prompt_tokens=prompt_tokens,
                completion_tokens=completion_tokens,
            )
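
Putting it together, here is a minimal end-to-end sketch of driving the wrapper directly. It assumes `YiChatWrapper` is exported from `agentscope.models` alongside the other wrappers, and the API key is a placeholder:

```python
from agentscope.message import Msg
from agentscope.models import YiChatWrapper

model = YiChatWrapper(
    config_name="yi_yi-large",
    model_name="yi-large",
    api_key="{your_api_key}",  # placeholder, not a real key
    temperature=0.3,
    top_p=0.9,
    max_tokens=1000,
    stream=False,
)

# `format` flattens Msg objects into the role/content dicts the API expects.
messages = model.format(
    Msg("system", "You're a helpful assistant", role="system"),
    Msg("user", "What's the date today?", role="user"),
)

response = model(messages)
print(response.text)
```

With `stream=True`, the returned `ModelResponse` would instead carry a stream that yields progressively longer partial strings, as the generator in the wrapper implements.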