Skip to content

Commit

Permalink
1. format correction; 2. modify unit test for web search
Browse files Browse the repository at this point in the history
  • Loading branch information
DavdGao committed Feb 4, 2024
1 parent 307f5fa commit fec2543
Show file tree
Hide file tree
Showing 6 changed files with 293 additions and 221 deletions.
2 changes: 1 addition & 1 deletion src/agentscope/pipelines/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
Optional,
Union,
Any,
Mapping
Mapping,
)
from ..agents.operator import Operator

Expand Down
2 changes: 1 addition & 1 deletion src/agentscope/service/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ def get_help() -> None:
"read_json_file",
"write_json_file",
"bing_search",
"google_search"
"google_search",
"query_mysql",
"query_sqlite",
"query_mongodb",
Expand Down
85 changes: 49 additions & 36 deletions src/agentscope/service/service_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,51 +8,55 @@
Any,
Tuple,
Union,
Optional,
Literal,
get_args,
get_origin
get_origin,
)

from docstring_parser import parse
from loguru import logger

from agentscope.service import bing_search


def _get_type_str(cls):
def _get_type_str(cls: Any) -> Optional[Union[str, list]]:
"""Get the type string."""
type_str = None
if hasattr(cls, "__origin__"):
# Typing class
if cls.__origin__ is Union:
return [_get_type_str(_) for _ in get_args(cls)]
type_str = [_get_type_str(_) for _ in get_args(cls)]
elif cls.__origin__ is collections.abc.Sequence:
return "array"
type_str = "array"
else:
return str(cls.__origin__)
type_str = str(cls.__origin__)
else:
# Normal class
if cls is str:
return "string"
type_str = "string"
elif cls in [float, int, complex]:
return "number"
type_str = "number"
elif cls is bool:
return "boolean"
type_str = "boolean"
elif cls is collections.abc.Sequence:
return "array"
type_str = "array"
elif cls is None.__class__:
return "null"
type_str = "null"
else:
return cls.__name__
type_str = cls.__name__

return type_str # type: ignore[return-value]


class ServiceFactory:
"""A service factory class that turns service function into string
prompt format."""

@classmethod
def get(self, service_func: Callable[..., Any], **kwargs: Any) -> Tuple[
Callable[..., Any], dict]:
def get(
cls,
service_func: Callable[..., Any],
**kwargs: Any,
) -> Tuple[Callable[..., Any], dict]:
"""Covnert a service function into a tool function that agent can
use, and generate a dictionary in JSON Schema format that can be
used in OpenAI API directly. While for open-source model, developers
Expand Down Expand Up @@ -94,52 +98,61 @@ def get(self, service_func: Callable[..., Any], **kwargs: Any) -> Tuple[
docstring = parse(service_func.__doc__)

# Function description
func_description = (docstring.short_description or
docstring.long_description)
func_description = (
docstring.short_description or docstring.long_description
)

# The arguments that requires the agent to specify
args_agent = set(argsspec.args) - set(kwargs.keys())

# Check if the arguments from agent have descriptions in docstring
args_description = {_.arg_name: _.description for _ in
docstring.params}
args_description = {
_.arg_name: _.description for _ in docstring.params
}

# Prepare default values
args_defaults = {k: v for k, v in zip(reversed(argsspec.args),
reversed(argsspec.defaults))}
args_defaults = dict(
zip(
reversed(argsspec.args),
reversed(argsspec.defaults), # type: ignore
),
)

args_required = list(set(args_agent) - set(args_defaults.keys()))

# Prepare types of the arguments, remove the return type
args_types = {k: v for k, v in argsspec.annotations.items() if k !=
"return"}
args_types = {
k: v for k, v in argsspec.annotations.items() if k != "return"
}

# Prepare argument dictionary
properties_field = dict()
properties_field = {}
for key in args_agent:
property = dict()
arg_property = {}
# type
if key in args_types:
try:
required_type = _get_type_str(args_types[key])
property["type"] = required_type
arg_property["type"] = required_type
except Exception:
logger.warning(f"Fail and skip to get the type of the "
f"argument `{key}`.")

logger.warning(
f"Fail and skip to get the type of the "
f"argument `{key}`.",
)

# For Literal type, add enum field
if get_origin(args_types[key]) is Literal:
property["enum"] = list(args_types[key].__args__)
arg_property["enum"] = list(args_types[key].__args__)

# description
if key in args_description:
property["description"] = args_description[key]
arg_property["description"] = args_description[key]

# default
if key in args_defaults and args_defaults[key] is not None:
property["default"] = args_defaults[key]
arg_property["default"] = args_defaults[key]

properties_field[key] = property
properties_field[key] = arg_property

# Construct the JSON Schema for the service function
func_dict = {
Expand All @@ -150,9 +163,9 @@ def get(self, service_func: Callable[..., Any], **kwargs: Any) -> Tuple[
"parameters": {
"type": "object",
"properties": properties_field,
"required": args_required
}
}
"required": args_required,
},
},
}

return tool_func, func_dict
20 changes: 10 additions & 10 deletions src/agentscope/service/web_search/search.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# -*- coding: utf-8 -*-
"""Search question in the web"""
from typing import Optional, Any
from typing import Any

from agentscope.service.service_response import ServiceResponse
from agentscope.utils.common import requests_get
Expand All @@ -9,7 +9,7 @@

def bing_search(
question: str,
bing_api_key: str,
api_key: str,
num_results: int = 10,
**kwargs: Any,
) -> ServiceResponse:
Expand All @@ -19,7 +19,7 @@ def bing_search(
Args:
question (`str`):
The search query string.
bing_api_key (`str`):
api_key (`str`):
The API key provided for authenticating with the Bing Search API.
num_results (`int`, defaults to `10`):
The number of search results to return.
Expand Down Expand Up @@ -84,7 +84,7 @@ def bing_search(
if kwargs:
params.update(**kwargs)

headers = {"Ocp-Apim-Subscription-Key": bing_api_key}
headers = {"Ocp-Apim-Subscription-Key": api_key}

search_results = requests_get(
bing_search_url,
Expand Down Expand Up @@ -116,8 +116,8 @@ def bing_search(

def google_search(
question: str,
google_api_key: str,
google_cse_id: str,
api_key: str,
cse_id: str,
num_results: int = 10,
**kwargs: Any,
) -> ServiceResponse:
Expand All @@ -127,10 +127,10 @@ def google_search(
Args:
question (`str`):
The search query string.
google_api_key (`str`):
api_key (`str`):
The API key provided for authenticating with the Google Custom
Search JSON API.
google_cse_id (`str`):
cse_id (`str`):
The unique identifier of a programmable search engine to use.
num_results (`int`, defaults to `10`):
The number of search results to return.
Expand Down Expand Up @@ -167,8 +167,8 @@ def google_search(
# Define the query parameters
params = {
"q": question,
"key": google_api_key,
"cx": google_cse_id,
"key": api_key,
"cx": cse_id,
"num": num_results,
}
if kwargs:
Expand Down
Loading

0 comments on commit fec2543

Please sign in to comment.