diff --git a/docs/source/conf.py b/docs/source/conf.py index d3c8d4d3..205f01ba 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -15,7 +15,7 @@ project = "Semantic Router" copyright = "2024, Aurelio AI" author = "Aurelio AI" -release = "0.0.58" +release = "0.0.59" # -- General configuration --------------------------------------------------- # https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration diff --git a/pyproject.toml b/pyproject.toml index 82dca612..ebd6d952 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "semantic-router" -version = "0.0.58" +version = "0.0.59" description = "Super fast semantic router for AI decision making" authors = [ "James Briggs ", diff --git a/semantic_router/__init__.py b/semantic_router/__init__.py index cefeaa60..775a8d3a 100644 --- a/semantic_router/__init__.py +++ b/semantic_router/__init__.py @@ -4,4 +4,4 @@ __all__ = ["RouteLayer", "HybridRouteLayer", "Route", "LayerConfig"] -__version__ = "0.0.58" +__version__ = "0.0.59" diff --git a/semantic_router/utils/function_call.py b/semantic_router/utils/function_call.py index 99c9d385..3ae14043 100644 --- a/semantic_router/utils/function_call.py +++ b/semantic_router/utils/function_call.py @@ -1,11 +1,106 @@ import inspect -from typing import Any, Callable, Dict, List, Union +from typing import Any, Callable, Dict, List, Optional, Union from pydantic.v1 import BaseModel from semantic_router.llms import BaseLLM from semantic_router.schema import Message, RouteChoice from semantic_router.utils.logger import logger +from pydantic import Field + + +class Parameter(BaseModel): + class Config: + arbitrary_types_allowed = True + + name: str = Field(description="The name of the parameter") + description: Optional[str] = Field( + default=None, description="The description of the parameter" + ) + type: str = Field(description="The type of the parameter") + default: Any = Field(description="The default value of the parameter") + required: bool = Field(description="Whether the parameter is required") + + def to_ollama(self): + return { + self.name: { + "description": self.description, + "type": self.type, + } + } + + +class FunctionSchema: + """Class that consumes a function and can return a schema required by + different LLMs for function calling. + """ + + name: str = Field(description="The name of the function") + description: str = Field(description="The description of the function") + signature: str = Field(description="The signature of the function") + output: str = Field(description="The output of the function") + parameters: List[Parameter] = Field(description="The parameters of the function") + + def __init__(self, function: Union[Callable, BaseModel]): + self.function = function + if callable(function): + self._process_function(function) + elif isinstance(function, BaseModel): + raise NotImplementedError("Pydantic BaseModel not implemented yet.") + else: + raise TypeError("Function must be a Callable or BaseModel") + + def _process_function(self, function: Callable): + self.name = function.__name__ + self.description = str(inspect.getdoc(function)) + self.signature = str(inspect.signature(function)) + self.output = str(inspect.signature(function).return_annotation) + parameters = [] + for param in inspect.signature(function).parameters.values(): + parameters.append( + Parameter( + name=param.name, + type=param.annotation.__name__, + default=param.default, + required=False if param.default is param.empty else True, + ) + ) + self.parameters = parameters + + def to_ollama(self): + schema_dict = { + "type": "function", + "function": { + "name": self.name, + "description": self.description, + "parameters": { + "type": "object", + "properties": { + param.name: { + "description": param.description, + "type": self._ollama_type_mapping(param.type), + } + for param in self.parameters + }, + "required": [ + param.name for param in self.parameters if param.required + ], + }, + }, + } + return schema_dict + + def _ollama_type_mapping(self, param_type: str) -> str: + if param_type == "int": + return "number" + elif param_type == "float": + return "number" + elif param_type == "str": + return "string" + elif param_type == "bool": + return "boolean" + else: + return "object" def get_schema_list(items: List[Union[BaseModel, Callable]]) -> List[Dict[str, Any]]: diff --git a/tests/unit/test_function_schema.py b/tests/unit/test_function_schema.py new file mode 100644 index 00000000..94a3f8e8 --- /dev/null +++ b/tests/unit/test_function_schema.py @@ -0,0 +1,44 @@ +import inspect +from semantic_router.utils.function_call import FunctionSchema + + +def scrape_webpage(url: str, name: str = "test") -> str: + """Provides access to web scraping. You can use this tool to scrape a webpage. + Many webpages may return no information due to JS or adblock issues, if this + happens, you must use a different URL. + """ + return "hello there" + + +def test_function_schema(): + schema = FunctionSchema(scrape_webpage) + assert schema.name == scrape_webpage.__name__ + assert schema.description == str(inspect.getdoc(scrape_webpage)) + assert schema.signature == str(inspect.signature(scrape_webpage)) + assert schema.output == str(inspect.signature(scrape_webpage).return_annotation) + assert len(schema.parameters) == 2 + + +def test_ollama_function_schema(): + schema = FunctionSchema(scrape_webpage) + ollama_schema = schema.to_ollama() + assert ollama_schema["type"] == "function" + assert ollama_schema["function"]["name"] == schema.name + assert ollama_schema["function"]["description"] == schema.description + assert ollama_schema["function"]["parameters"]["type"] == "object" + assert ( + ollama_schema["function"]["parameters"]["properties"]["url"]["type"] == "string" + ) + assert ( + ollama_schema["function"]["parameters"]["properties"]["name"]["type"] + == "string" + ) + assert ( + ollama_schema["function"]["parameters"]["properties"]["url"]["description"] + is None + ) + assert ( + ollama_schema["function"]["parameters"]["properties"]["name"]["description"] + is None + ) + assert ollama_schema["function"]["parameters"]["required"] == ["name"]