From fec2543022403c1f5c8d3c1cb21a10c5aca17aae Mon Sep 17 00:00:00 2001 From: DavdGao Date: Sun, 4 Feb 2024 15:13:35 +0800 Subject: [PATCH] 1. format correction; 2. modify unit test for web search --- src/agentscope/pipelines/functional.py | 2 +- src/agentscope/service/__init__.py | 2 +- src/agentscope/service/service_factory.py | 85 +++-- src/agentscope/service/web_search/search.py | 20 +- tests/service_factory_test.py | 394 +++++++++++--------- tests/web_search_test.py | 11 +- 6 files changed, 293 insertions(+), 221 deletions(-) diff --git a/src/agentscope/pipelines/functional.py b/src/agentscope/pipelines/functional.py index 2cb09b608..d04a4596c 100644 --- a/src/agentscope/pipelines/functional.py +++ b/src/agentscope/pipelines/functional.py @@ -6,7 +6,7 @@ Optional, Union, Any, - Mapping + Mapping, ) from ..agents.operator import Operator diff --git a/src/agentscope/service/__init__.py b/src/agentscope/service/__init__.py index 836032a19..e5a07b7e8 100644 --- a/src/agentscope/service/__init__.py +++ b/src/agentscope/service/__init__.py @@ -43,7 +43,7 @@ def get_help() -> None: "read_json_file", "write_json_file", "bing_search", - "google_search" + "google_search", "query_mysql", "query_sqlite", "query_mongodb", diff --git a/src/agentscope/service/service_factory.py b/src/agentscope/service/service_factory.py index cd81f9347..b1b9e9d03 100644 --- a/src/agentscope/service/service_factory.py +++ b/src/agentscope/service/service_factory.py @@ -8,42 +8,43 @@ Any, Tuple, Union, + Optional, Literal, get_args, - get_origin + get_origin, ) from docstring_parser import parse from loguru import logger -from agentscope.service import bing_search - -def _get_type_str(cls): +def _get_type_str(cls: Any) -> Optional[Union[str, list]]: """Get the type string.""" + type_str = None if hasattr(cls, "__origin__"): # Typing class if cls.__origin__ is Union: - return [_get_type_str(_) for _ in get_args(cls)] + type_str = [_get_type_str(_) for _ in get_args(cls)] elif cls.__origin__ is collections.abc.Sequence: - return "array" + type_str = "array" else: - return str(cls.__origin__) + type_str = str(cls.__origin__) else: # Normal class if cls is str: - return "string" + type_str = "string" elif cls in [float, int, complex]: - return "number" + type_str = "number" elif cls is bool: - return "boolean" + type_str = "boolean" elif cls is collections.abc.Sequence: - return "array" + type_str = "array" elif cls is None.__class__: - return "null" + type_str = "null" else: - return cls.__name__ + type_str = cls.__name__ + return type_str # type: ignore[return-value] class ServiceFactory: @@ -51,8 +52,11 @@ class ServiceFactory: prompt format.""" @classmethod - def get(self, service_func: Callable[..., Any], **kwargs: Any) -> Tuple[ - Callable[..., Any], dict]: + def get( + cls, + service_func: Callable[..., Any], + **kwargs: Any, + ) -> Tuple[Callable[..., Any], dict]: """Covnert a service function into a tool function that agent can use, and generate a dictionary in JSON Schema format that can be used in OpenAI API directly. While for open-source model, developers @@ -94,52 +98,61 @@ def get(self, service_func: Callable[..., Any], **kwargs: Any) -> Tuple[ docstring = parse(service_func.__doc__) # Function description - func_description = (docstring.short_description or - docstring.long_description) + func_description = ( + docstring.short_description or docstring.long_description + ) # The arguments that requires the agent to specify args_agent = set(argsspec.args) - set(kwargs.keys()) # Check if the arguments from agent have descriptions in docstring - args_description = {_.arg_name: _.description for _ in - docstring.params} + args_description = { + _.arg_name: _.description for _ in docstring.params + } # Prepare default values - args_defaults = {k: v for k, v in zip(reversed(argsspec.args), - reversed(argsspec.defaults))} + args_defaults = dict( + zip( + reversed(argsspec.args), + reversed(argsspec.defaults), # type: ignore + ), + ) + args_required = list(set(args_agent) - set(args_defaults.keys())) # Prepare types of the arguments, remove the return type - args_types = {k: v for k, v in argsspec.annotations.items() if k != - "return"} + args_types = { + k: v for k, v in argsspec.annotations.items() if k != "return" + } # Prepare argument dictionary - properties_field = dict() + properties_field = {} for key in args_agent: - property = dict() + arg_property = {} # type if key in args_types: try: required_type = _get_type_str(args_types[key]) - property["type"] = required_type + arg_property["type"] = required_type except Exception: - logger.warning(f"Fail and skip to get the type of the " - f"argument `{key}`.") - + logger.warning( + f"Fail and skip to get the type of the " + f"argument `{key}`.", + ) # For Literal type, add enum field if get_origin(args_types[key]) is Literal: - property["enum"] = list(args_types[key].__args__) + arg_property["enum"] = list(args_types[key].__args__) # description if key in args_description: - property["description"] = args_description[key] + arg_property["description"] = args_description[key] # default if key in args_defaults and args_defaults[key] is not None: - property["default"] = args_defaults[key] + arg_property["default"] = args_defaults[key] - properties_field[key] = property + properties_field[key] = arg_property # Construct the JSON Schema for the service function func_dict = { @@ -150,9 +163,9 @@ def get(self, service_func: Callable[..., Any], **kwargs: Any) -> Tuple[ "parameters": { "type": "object", "properties": properties_field, - "required": args_required - } - } + "required": args_required, + }, + }, } return tool_func, func_dict diff --git a/src/agentscope/service/web_search/search.py b/src/agentscope/service/web_search/search.py index 828041d01..9d26837db 100644 --- a/src/agentscope/service/web_search/search.py +++ b/src/agentscope/service/web_search/search.py @@ -1,6 +1,6 @@ # -*- coding: utf-8 -*- """Search question in the web""" -from typing import Optional, Any +from typing import Any from agentscope.service.service_response import ServiceResponse from agentscope.utils.common import requests_get @@ -9,7 +9,7 @@ def bing_search( question: str, - bing_api_key: str, + api_key: str, num_results: int = 10, **kwargs: Any, ) -> ServiceResponse: @@ -19,7 +19,7 @@ def bing_search( Args: question (`str`): The search query string. - bing_api_key (`str`): + api_key (`str`): The API key provided for authenticating with the Bing Search API. num_results (`int`, defaults to `10`): The number of search results to return. @@ -84,7 +84,7 @@ def bing_search( if kwargs: params.update(**kwargs) - headers = {"Ocp-Apim-Subscription-Key": bing_api_key} + headers = {"Ocp-Apim-Subscription-Key": api_key} search_results = requests_get( bing_search_url, @@ -116,8 +116,8 @@ def bing_search( def google_search( question: str, - google_api_key: str, - google_cse_id: str, + api_key: str, + cse_id: str, num_results: int = 10, **kwargs: Any, ) -> ServiceResponse: @@ -127,10 +127,10 @@ def google_search( Args: question (`str`): The search query string. - google_api_key (`str`): + api_key (`str`): The API key provided for authenticating with the Google Custom Search JSON API. - google_cse_id (`str`): + cse_id (`str`): The unique identifier of a programmable search engine to use. num_results (`int`, defaults to `10`): The number of search results to return. @@ -167,8 +167,8 @@ def google_search( # Define the query parameters params = { "q": question, - "key": google_api_key, - "cx": google_cse_id, + "key": api_key, + "cx": cse_id, "num": num_results, } if kwargs: diff --git a/tests/service_factory_test.py b/tests/service_factory_test.py index 231f0a83a..9df195de9 100644 --- a/tests/service_factory_test.py +++ b/tests/service_factory_test.py @@ -5,8 +5,13 @@ from typing import Literal from agentscope.models import ModelWrapperBase -from agentscope.service import bing_search, execute_python_code, \ - retrieve_from_list, query_mysql, summarization +from agentscope.service import ( + bing_search, + execute_python_code, + retrieve_from_list, + query_mysql, + summarization, +) from agentscope.service.service_factory import ServiceFactory @@ -17,7 +22,6 @@ class ServiceFactoryTest(unittest.TestCase): def setUp(self) -> None: """Init for ExampleTest.""" - pass def test_bing_search(self) -> None: """Test bing_search.""" @@ -25,187 +29,245 @@ def test_bing_search(self) -> None: # are specified by model _, doc_dict = ServiceFactory.get(bing_search, bing_api_key="xxx") - self.assertDictEqual(doc_dict, { - "type": "function", - "function": { - "name": "bing_search", - "description": "Search question in Bing Search API and return the searching results", - "parameters": { - "type": "object", - "properties": { - "num_results": { - "type": "number", - "description": "The number of search results to return.", - "default": 10 + self.assertDictEqual( + doc_dict, + { + "type": "function", + "function": { + "name": "bing_search", + "description": ( + "Search question in Bing Search API and " + "return the searching results" + ), + "parameters": { + "type": "object", + "properties": { + "num_results": { + "type": "number", + "description": ( + "The number of search " + "results to return." + ), + "default": 10, + }, + "question": { + "type": "string", + "description": "The search query string.", + }, }, - "question": { - "type": "string", - "description": "The search query string." - } + "required": ["question"], }, - "required": ["question"] - } - } - }) + }, + }, + ) # Set num_results by developer rather than model - _, doc_dict = ServiceFactory.get(bing_search, - num_results=3, bing_api_key="xxx") - - self.assertEquals(doc_dict, { - "type": "function", - "function": { - "name": "bing_search", - "description": "Search question in Bing Search API and return the searching results", - "parameters": { - "type": "object", - "properties": { - "question": { - "type": "string", - "description": "The search query string." - } + _, doc_dict = ServiceFactory.get( + bing_search, + num_results=3, + bing_api_key="xxx", + ) + + self.assertDictEqual( + doc_dict, + { + "type": "function", + "function": { + "name": "bing_search", + "description": ( + "Search question in Bing Search API and " + "return the searching results" + ), + "parameters": { + "type": "object", + "properties": { + "question": { + "type": "string", + "description": "The search query string.", + }, + }, + "required": ["question"], }, - "required": ["question"] - } - } - }) + }, + }, + ) - def test_enum(self): - def func(a: str, b, c="test", d: Literal[1, "abc", "d"] = 1) -> int: - pass + def test_enum(self) -> None: + """Test enum in service factory.""" + + def func( # type: ignore + a: str, + b, + c="test", + d: Literal[1, "abc", "d"] = 1, + ) -> None: + print(a, b, c, d) _, doc_dict = ServiceFactory.get(func) - self.assertDictEqual(doc_dict, { - "type": "function", - "function": { - "name": "func", - "description": None, - "parameters": { - "type": "object", - "properties": { - "c": {"default": "test"}, - "d": { - "type": "typing.Literal", - "enum": [1, "abc", "d"], - "default": 1 + self.assertDictEqual( + doc_dict, + { + "type": "function", + "function": { + "name": "func", + "description": None, + "parameters": { + "type": "object", + "properties": { + "c": {"default": "test"}, + "d": { + "type": "typing.Literal", + "enum": [1, "abc", "d"], + "default": 1, + }, + "b": {}, + "a": {"type": "string"}, }, - "b": {}, - "a": {"type": "string"} + "required": ["b", "a"], }, - "required": ["b", "a"] - } - } - }) - - def test_exec_python_code(self): - _, doc_dict = ServiceFactory.get(execute_python_code, - timeout=300, - use_docker=True, - maximum_memory_bytes=None) - - self.assertDictEqual(doc_dict, { - "type": "function", - "function": { - "name": "execute_python_code", - "description": "Execute a piece of python code.", - "parameters": { - "type": "object", - "properties": { - "code": { - "type": "string", - "description": "The Python code to be executed." - } + }, + }, + ) + + def test_exec_python_code(self) -> None: + """Test execute_python_code in service factory.""" + _, doc_dict = ServiceFactory.get( + execute_python_code, + timeout=300, + use_docker=True, + maximum_memory_bytes=None, + ) + + self.assertDictEqual( + doc_dict, + { + "type": "function", + "function": { + "name": "execute_python_code", + "description": "Execute a piece of python code.", + "parameters": { + "type": "object", + "properties": { + "code": { + "type": "string", + "description": ( + "The Python code to be " "executed." + ), + }, + }, + "required": ["code"], }, - "required": ["code"] - } - } - }) - - def test_retrieval(self): - _, doc_dict = ServiceFactory.get(retrieve_from_list, - knowledge=[1, 2, 3], - score_func=lambda x, y: 1.0, - top_k=10, - embedding_model=10, - preserve_order=True) - - self.assertDictEqual(doc_dict, { - "type": "function", - "function": { - "name": "retrieve_from_list", - "description": "Retrieve data in a list.", - "parameters": { - "type": "object", - "properties": { - "query": { - "description": "A message to be retrieved." - } + }, + }, + ) + + def test_retrieval(self) -> None: + """Test retrieval in service factory.""" + _, doc_dict = ServiceFactory.get( + retrieve_from_list, + knowledge=[1, 2, 3], + score_func=lambda x, y: 1.0, + top_k=10, + embedding_model=10, + preserve_order=True, + ) + + self.assertDictEqual( + doc_dict, + { + "type": "function", + "function": { + "name": "retrieve_from_list", + "description": "Retrieve data in a list.", + "parameters": { + "type": "object", + "properties": { + "query": { + "description": "A message to be retrieved.", + }, + }, + "required": [ + "query", + ], }, - "required": [ - "query" - ] - } - } - }) - - def test_sql_query(self): - _, doc_dict = ServiceFactory.get(query_mysql, database="test", - host="localhost", - user="root", password="xxx", - port=3306, - allow_change_data=False, - maxcount_results=None) - - self.assertDictEqual(doc_dict, { - "type": "function", - "function": { - "name": "query_mysql", - "description": "Execute query within MySQL database.", - "parameters": { - "type": "object", - "properties": { - "query": { - "type": "string", - "description": "SQL query to execute." - } + }, + }, + ) + + def test_sql_query(self) -> None: + """Test sql_query in service factory.""" + _, doc_dict = ServiceFactory.get( + query_mysql, + database="test", + host="localhost", + user="root", + password="xxx", + port=3306, + allow_change_data=False, + maxcount_results=None, + ) + + self.assertDictEqual( + doc_dict, + { + "type": "function", + "function": { + "name": "query_mysql", + "description": "Execute query within MySQL database.", + "parameters": { + "type": "object", + "properties": { + "query": { + "type": "string", + "description": "SQL query to execute.", + }, + }, + "required": [ + "query", + ], }, - "required": [ - "query" - ] - } - } - }) - - def test_summary(self): - _, doc_dict = ServiceFactory.get(summarization, - model=ModelWrapperBase("abc"), - system_prompt="", - summarization_prompt="", - max_return_token=-1, - token_limit_prompt="") + }, + }, + ) + + def test_summary(self) -> None: + """Test summarization in service factory.""" + _, doc_dict = ServiceFactory.get( + summarization, + model=ModelWrapperBase("abc"), + system_prompt="", + summarization_prompt="", + max_return_token=-1, + token_limit_prompt="", + ) print(json.dumps(doc_dict, indent=4)) - self.assertDictEqual(doc_dict, { - "type": "function", - "function": { - "name": "summarization", - "description": "Summarize the input text.", - "parameters": { - "type": "object", - "properties": { - "text": { - "type": "string", - "description": "Text to be summarized by the model." - } + self.assertDictEqual( + doc_dict, + { + "type": "function", + "function": { + "name": "summarization", + "description": "Summarize the input text.", + "parameters": { + "type": "object", + "properties": { + "text": { + "type": "string", + "description": ( + "Text to be summarized by " "the model." + ), + }, + }, + "required": [ + "text", + ], }, - "required": [ - "text" - ] - } - } - }) + }, + }, + ) if __name__ == "__main__": diff --git a/tests/web_search_test.py b/tests/web_search_test.py index b5e3ca1e0..6f7e38860 100644 --- a/tests/web_search_test.py +++ b/tests/web_search_test.py @@ -4,7 +4,7 @@ from unittest.mock import Mock, patch, MagicMock from agentscope.service import ServiceResponse -from agentscope.service import web_search +from agentscope.service import bing_search, google_search from agentscope.service.service_status import ServiceExecStatus @@ -43,7 +43,6 @@ def test_search_bing(self, mock_get: MagicMock) -> None: mock_get.return_value = mock_response # set parameters - engine = "Bing" bing_api_key = "fake-bing-api-key" test_question = "test test_question" num_results = 1 @@ -51,8 +50,7 @@ def test_search_bing(self, mock_get: MagicMock) -> None: headers = {"Ocp-Apim-Subscription-Key": bing_api_key} # Call the function - results = web_search( - engine, + results = bing_search( test_question, api_key=bing_api_key, num_results=num_results, @@ -99,7 +97,6 @@ def test_search_google(self, mock_get: MagicMock) -> None: mock_get.return_value = mock_response # set parameter - engine = "Google" test_question = "test test_question" google_api_key = "fake-google-api-key" google_cse_id = "fake-google-cse-id" @@ -113,13 +110,13 @@ def test_search_google(self, mock_get: MagicMock) -> None: } # Call the function - results = web_search( - engine, + results = google_search( test_question, api_key=google_api_key, cse_id=google_cse_id, num_results=num_results, ) + # Assertions mock_get.assert_called_once_with( "https://www.googleapis.com/customsearch/v1",