From 91250d09708f4adfc26e86554cfb2debbb80b83b Mon Sep 17 00:00:00 2001 From: qbc Date: Mon, 4 Mar 2024 15:02:07 +0800 Subject: [PATCH] add post api wrapper for tongyi --- src/agentscope/models/post_model.py | 172 ++++++++++++++++++++++++++++ 1 file changed, 172 insertions(+) diff --git a/src/agentscope/models/post_model.py b/src/agentscope/models/post_model.py index 0addce7c6..19feab44e 100644 --- a/src/agentscope/models/post_model.py +++ b/src/agentscope/models/post_model.py @@ -192,3 +192,175 @@ class PostAPIDALLEWrapper(PostAPIModelWrapperBase): def _parse_response(self, response: dict) -> ModelResponse: urls = [img["url"] for img in response["data"]["response"]["data"]] return ModelResponse(image_urls=urls) + + +class PostAPITongyiWrapper(ModelWrapperBase): + """The base model wrapper for Tongyi""" + + model_type: str = "post_api_tongyi" + + def __init__( + self, + config_name: str, + api_url: str, + headers: dict = None, + max_length: int = 2048, + timeout: int = 30, + json_args: dict = None, + post_args: dict = None, + max_retries: int = _DEFAULT_MAX_RETRIES, + messages_key: str = _DEFAULT_MESSAGES_KEY, + retry_interval: int = _DEFAULT_RETRY_INTERVAL, + **kwargs: Any, + ) -> None: + """Initialize the Tongyi model wrapper. + + Args: + config_name (`str`): + The id of the model. + api_url (`str`): + The url of the post request api. + headers (`dict`, defaults to `None`): + The headers of the api. Defaults to None. + max_length (`int`, defaults to `2048`): + The maximum length of the model. + timeout (`int`, defaults to `30`): + The timeout of the api. Defaults to 30. + json_args (`dict`, defaults to `None`): + The json arguments of the api. Defaults to None. + post_args (`dict`, defaults to `None`): + The post arguments of the api. Defaults to None. + max_retries (`int`, defaults to `3`): + The maximum number of retries when the `parse_func` raise an + exception. + messages_key (`str`, defaults to `inputs`): + The key of the input messages in the json argument. + retry_interval (`int`, defaults to `1`): + The interval between retries when a request fails. + + Note: + When an object of `PostApiModelWrapper` is called, the arguments + will of post requests will be used as follows: + + .. code-block:: python + + request.post( + url=api_url, + headers=headers, + json={ + messages_key: messages, + **json_args + }, + **post_args + ) + """ + super().__init__( + config_name=config_name, + api_url=api_url, + headers=headers, + max_length=max_length, + timeout=timeout, + json_args=json_args, + post_args=post_args, + max_retries=max_retries, + messages_key=messages_key, + retry_interval=retry_interval, + **kwargs, + ) + self.api_url = api_url + self.headers = headers + self.max_length = max_length + self.timeout = timeout + self.json_args = json_args or {} + self.post_args = post_args or {} + self.max_retries = max_retries + self.messages_key = messages_key + self.retry_interval = retry_interval + + def _parse_response(self, response: dict) -> ModelResponse: + """Parse the response json data into ModelResponse""" + return ModelResponse(raw=response) + + def __call__(self, input_: dict, **kwargs: Any) -> ModelResponse: + """Calling the Tongyi model with requests.post. + + Args: + input_ (`str`): + The input string to the model. + + Returns: + `dict`: A dictionary that contains the response of the model and + related + information (e.g. cost, time, the number of tokens, etc.). + + Note: + `parse_func`, `fault_handler` and `max_retries` are reserved for + `_response_parse_decorator` to parse and check the response + generated by model wrapper. Their usages are listed as follows: + - `parse_func` is a callable function used to parse and check + the response generated by the model, which takes the response + as input. + - `max_retries` is the maximum number of retries when the + `parse_func` raise an exception. + - `fault_handler` is a callable function which is called + when the response generated by the model is invalid after + `max_retries` retries. + """ + # step1: prepare keyword arguments + post_args = {**self.post_args, **kwargs} + + input_[0]["role"] = "user" + input_[-1]["role"] = "user" + input_ = {"messages": input_} + + request_kwargs = { + "url": self.api_url, + "json": {self.messages_key: input_, **self.json_args}, + "headers": self.headers or {}, + **post_args, + } + + # step2: prepare post requests + for i in range(1, self.max_retries + 1): + response = requests.post(**request_kwargs) + + if response.status_code == requests.codes.ok: + break + + if i < self.max_retries: + # av + logger.warning( + f"Failed to call the model with " + f"requests.codes == {response.status_code}, retry " + f"{i + 1}/{self.max_retries} times", + ) + time.sleep(i * self.retry_interval) + + # step3: record model invocation + # record the model api invocation, which will be skipped if + # `FileManager.save_api_invocation` is `False` + self._save_model_invocation( + arguments=request_kwargs, + json_response=response.json(), + ) + + # step4: parse the response + if response.status_code == requests.codes.ok: + return self._parse_response(response.json()) + else: + logger.error(json.dumps(request_kwargs, indent=4)) + raise RuntimeError( + f"Failed to call the model with " + f"requests.codes == {response.status_code}", + ) + + +class PostAPITongyiChatWrapper(PostAPITongyiWrapper): + """A post api model wrapper compatible with Tongyi chat""" + + model_type: str = "post_api_tongyi_chat" + + def _parse_response(self, response: dict) -> ModelResponse: + return ModelResponse( + text=response["output"]["text"], + )