Skip to content

Commit

Permalink
add post api wrapper for tongyi
Browse files Browse the repository at this point in the history
  • Loading branch information
qbc2016 committed Mar 4, 2024
1 parent 9449114 commit 91250d0
Showing 1 changed file with 172 additions and 0 deletions.
172 changes: 172 additions & 0 deletions src/agentscope/models/post_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,3 +192,175 @@ class PostAPIDALLEWrapper(PostAPIModelWrapperBase):
def _parse_response(self, response: dict) -> ModelResponse:
    """Build a `ModelResponse` holding the image urls found in the response.

    Args:
        response (`dict`):
            The decoded json response; the image entries are read from
            `response["data"]["response"]["data"]`, each providing a
            `"url"` field.

    Returns:
        `ModelResponse`: A response whose `image_urls` lists every url in
        the order the images appear.
    """
    images = response["data"]["response"]["data"]
    image_urls = []
    for image in images:
        image_urls.append(image["url"])
    return ModelResponse(image_urls=image_urls)


class PostAPITongyiWrapper(ModelWrapperBase):
    """The base model wrapper for Tongyi models called through a post API.

    Calling an instance sends the messages with `requests.post`, retries on
    non-OK status codes, records the invocation and returns the parsed
    response. Subclasses customize the result by overriding
    `_parse_response`.
    """

    model_type: str = "post_api_tongyi"

    def __init__(
        self,
        config_name: str,
        api_url: str,
        headers: dict = None,
        max_length: int = 2048,
        timeout: int = 30,
        json_args: dict = None,
        post_args: dict = None,
        max_retries: int = _DEFAULT_MAX_RETRIES,
        messages_key: str = _DEFAULT_MESSAGES_KEY,
        retry_interval: int = _DEFAULT_RETRY_INTERVAL,
        **kwargs: Any,
    ) -> None:
        """Initialize the Tongyi model wrapper.

        Args:
            config_name (`str`):
                The id of the model.
            api_url (`str`):
                The url of the post request api.
            headers (`dict`, defaults to `None`):
                The headers of the api. An empty dict is sent when `None`.
            max_length (`int`, defaults to `2048`):
                The maximum length of the model. Stored on the instance;
                not used by `__call__` in this class.
            timeout (`int`, defaults to `30`):
                The timeout of the api. Stored on the instance only.
                NOTE(review): `__call__` does not forward this value to
                `requests.post`; pass `timeout` through `post_args` to
                enforce one.
            json_args (`dict`, defaults to `None`):
                Extra entries merged into the json payload.
            post_args (`dict`, defaults to `None`):
                Extra keyword arguments passed to `requests.post`.
            max_retries (`int`, defaults to `_DEFAULT_MAX_RETRIES`):
                The maximum number of requests sent while the api keeps
                answering with a non-OK status code. Values below 1 are
                treated as 1.
            messages_key (`str`, defaults to `_DEFAULT_MESSAGES_KEY`):
                The key of the input messages in the json payload.
            retry_interval (`int`, defaults to `_DEFAULT_RETRY_INTERVAL`):
                The base interval between retries; the wait before the
                next attempt is `attempt_number * retry_interval`.

        Note:
            When an object of `PostAPITongyiWrapper` is called, the
            arguments of the post request are used as follows:

            .. code-block:: python

                requests.post(
                    url=api_url,
                    headers=headers,
                    json={
                        messages_key: {"messages": messages},
                        **json_args
                    },
                    **post_args
                )
        """
        super().__init__(
            config_name=config_name,
            api_url=api_url,
            headers=headers,
            max_length=max_length,
            timeout=timeout,
            json_args=json_args,
            post_args=post_args,
            max_retries=max_retries,
            messages_key=messages_key,
            retry_interval=retry_interval,
            **kwargs,
        )
        self.api_url = api_url
        self.headers = headers
        self.max_length = max_length
        self.timeout = timeout
        self.json_args = json_args or {}
        self.post_args = post_args or {}
        self.max_retries = max_retries
        self.messages_key = messages_key
        self.retry_interval = retry_interval

    def _parse_response(self, response: dict) -> ModelResponse:
        """Parse the response json data into ModelResponse"""
        return ModelResponse(raw=response)

    def __call__(self, input_: list, **kwargs: Any) -> ModelResponse:
        """Call the Tongyi model with `requests.post`.

        Args:
            input_ (`list`):
                The non-empty list of message dicts to send. The role of
                the first and of the last message is sent as `"user"`;
                the caller's messages themselves are left untouched.
            **kwargs (`Any`):
                Extra keyword arguments for `requests.post`; they override
                the `post_args` given at construction time.

        Returns:
            `ModelResponse`: The result of `_parse_response` applied to the
            decoded json response.

        Raises:
            `IndexError`: If `input_` is empty.
            `RuntimeError`: If the api still answers with a non-OK status
                code after the last attempt.
        """
        # step1: prepare keyword arguments
        post_args = {**self.post_args, **kwargs}

        # Rewrite the roles on shallow copies so that the caller's message
        # objects are not modified as a side effect of the call.
        messages = [dict(message) for message in input_]
        messages[0]["role"] = "user"
        messages[-1]["role"] = "user"

        request_kwargs = {
            "url": self.api_url,
            "json": {
                self.messages_key: {"messages": messages},
                **self.json_args,
            },
            "headers": self.headers or {},
            **post_args,
        }

        # step2: send the post request, retrying on non-OK status codes.
        # At least one request is always sent, so `response` is bound even
        # when `max_retries` is configured as 0 or a negative number.
        attempts = max(1, self.max_retries)
        for attempt in range(1, attempts + 1):
            response = requests.post(**request_kwargs)

            if response.status_code == requests.codes.ok:
                break

            if attempt < attempts:
                logger.warning(
                    f"Failed to call the model with "
                    f"requests.codes == {response.status_code}, retry "
                    f"{attempt + 1}/{attempts} times",
                )
                # wait longer after each failed attempt
                time.sleep(attempt * self.retry_interval)

        succeeded = response.status_code == requests.codes.ok

        # Decode the body once. A failed request may carry a non-json body
        # (e.g. an html error page); that must not hide the RuntimeError
        # raised below, so fall back to recording the raw text.
        try:
            json_response = response.json()
        except ValueError:
            if succeeded:
                raise
            json_response = {"text": response.text}

        # step3: record model invocation
        # record the model api invocation, which will be skipped if
        # `FileManager.save_api_invocation` is `False`
        self._save_model_invocation(
            arguments=request_kwargs,
            json_response=json_response,
        )

        # step4: parse the response
        if succeeded:
            return self._parse_response(json_response)

        # NOTE(review): this logs the request headers as well, which may
        # contain credentials - confirm this is acceptable.
        logger.error(json.dumps(request_kwargs, indent=4))
        raise RuntimeError(
            f"Failed to call the model with "
            f"requests.codes == {response.status_code}",
        )


class PostAPITongyiChatWrapper(PostAPITongyiWrapper):
    """A post api model wrapper compatible with Tongyi chat"""

    model_type: str = "post_api_tongyi_chat"

    def _parse_response(self, response: dict) -> ModelResponse:
        """Extract the generated text from a Tongyi chat response.

        Args:
            response (`dict`):
                The decoded json response; the text is read from
                `response["output"]["text"]`.

        Returns:
            `ModelResponse`: A response whose `text` is the generated text.
        """
        output = response["output"]
        generated_text = output["text"]
        return ModelResponse(text=generated_text)

0 comments on commit 91250d0

Please sign in to comment.