diff --git a/src/writer/ai.py b/src/writer/ai.py index bbdbbfbb8..557f6003d 100644 --- a/src/writer/ai.py +++ b/src/writer/ai.py @@ -90,17 +90,31 @@ class GraphTool(Tool): subqueries: bool +class FunctionToolParameterMeta(TypedDict): + type: Union[ + Literal["string"], + Literal["number"], + Literal["integer"], + Literal["float"], + Literal["boolean"], + Literal["array"], + Literal["object"], + Literal["null"] + ] + description: str + + class FunctionTool(Tool): callable: Callable name: str description: Optional[str] - parameters: Dict[str, Dict[str, str]] + parameters: Dict[str, FunctionToolParameterMeta] def create_function_tool( callable: Callable, name: str, - parameters: Optional[Dict[str, Dict[str, str]]] = None, + parameters: Optional[Dict[str, FunctionToolParameterMeta]] = None, description: Optional[str] = None ) -> FunctionTool: parameters = parameters or {} @@ -1196,7 +1210,7 @@ def _register_callable( self, callable_to_register: Callable, name: str, - parameters: Dict[str, Dict[str, str]] + parameters: Dict[str, FunctionToolParameterMeta] ): """ Internal helper function to store a provided callable @@ -1266,7 +1280,9 @@ def _prepare_tool( Internal helper function to process a tool instance into the required format. """ - def validate_parameters(parameters: Dict[str, Dict[str, str]]) -> bool: + def validate_parameters( + parameters: Dict[str, FunctionToolParameterMeta] + ) -> bool: """ Validates the `parameters` dictionary to ensure that each key is a parameter name, and each value is a dictionary containing