diff --git a/docs/framework/ai-module.mdx b/docs/framework/ai-module.mdx
index 857d6147a..4754f5725 100644
--- a/docs/framework/ai-module.mdx
+++ b/docs/framework/ai-module.mdx
@@ -173,13 +173,13 @@ Framework allows you to register Python functions that can be called automatical
 Function tools are defined using either a Python class or a JSON configuration.
 
 ```python
-from writer.ai import FunctionTool
+from writer.ai import create_function_tool
 
 # Define a function tool with Python callable
 def calculate_interest(principal: float, rate: float, time: float):
     return principal * rate * time
 
-tool = FunctionTool(
+tool = create_function_tool(
     name="calculate_interest",
     callable=calculate_interest,
     parameters={
@@ -217,6 +217,12 @@ Function tools require the following properties:
 
 When a conversation involves a tool (either a graph or a function), Framework automatically handles the requests from LLM to use the tools during interactions. If the tool needs multiple steps (for example, querying data and processing it), Framework will handle those steps recursively, calling functions as needed until the final result is returned.
 
+By default, to prevent endless recursion, Framework handles at most 3 consecutive tool calls. If that limit doesn't suit your use case, you can raise it: both `complete()` and `stream_complete()` accept a `max_tool_depth` parameter, which sets the maximum allowed recursion depth:
+
+```python
+response = conversation.complete(tools=tool, max_tool_depth=7)
+```
+
 ### Providing a Tool or a List of Tools
 
 You can pass either a single tool or a list of tools to the `complete()` or `stream_complete()` methods. The tools can be a combination of FunctionTool, Graph, or JSON-defined tools.
diff --git a/src/writer/ai.py b/src/writer/ai.py
index f304fa676..f01cd67c9 100644
--- a/src/writer/ai.py
+++ b/src/writer/ai.py
@@ -90,9 +90,26 @@ class GraphTool(Tool):
 class FunctionTool(Tool):
     callable: Callable
     name: str
+    description: Optional[str]
     parameters: Dict[str, Dict[str, str]]
 
 
+def create_function_tool(
+    callable: Callable,
+    name: str,
+    parameters: Optional[Dict[str, Dict[str, str]]] = None,
+    description: Optional[str] = None
+) -> FunctionTool:
+    parameters = parameters or {}
+    return FunctionTool(
+        type="function",
+        callable=callable,
+        name=name,
+        description=description,
+        parameters=parameters
+    )
+
+
 logger = logging.Logger(__name__, level=logging.DEBUG)
 
 
@@ -1190,7 +1207,7 @@ def _execute_function_tool_call(self, index: int) -> dict:
             "role": "tool",
             "name": function_name,
             "tool_call_id": tool_call_id,
-            "content": f"{function_name}: {func_result}"
+            "content": f"{func_result}"
         }
 
         return follow_up_message
@@ -1223,7 +1240,18 @@ def _process_tool_call(self, index, tool_call_id, tool_call_name, tool_call_argu
 
         # Accumulate arguments across chunks
         if tool_call_arguments is not None:
-            self._ongoing_tool_calls[index]["arguments"] += tool_call_arguments
+            if (
+                tool_call_arguments.startswith("{")
+                and tool_call_arguments.endswith("}")
+            ):
+                # For cases when the LLM "bugs" and returns
+                # the whole arguments string as a single chunk:
+                # replace the accumulated buffer instead of appending
+                self._ongoing_tool_calls[index]["arguments"] = \
+                    tool_call_arguments
+            else:
+                # Process normally
+                self._ongoing_tool_calls[index]["arguments"] += tool_call_arguments
 
         # Check if we have all necessary data to execute the function
         if (
@@ -1271,9 +1299,10 @@ def _process_response_data(
         passed_messages: List[WriterAIMessage],
         request_model: str,
         request_data: ChatOptions,
-        depth=1
+        depth=1,
+        max_depth=3
     ) -> 'Conversation.Message':
-        if depth > 3:
+        if depth > max_depth:
             raise RuntimeError("Reached maximum depth when processing response data tool calls.")
         for entry in response_data.choices:
             message = entry.message
@@ -1322,9 +1351,10 @@ def _process_stream_response(
         request_model: str,
         request_data: ChatOptions,
         depth=1,
+        max_depth=3,
         flag_chunks=False
     ) -> Generator[dict, None, None]:
-        if depth > 3:
+        if depth > max_depth:
             raise RuntimeError("Reached maximum depth when processing response data tool calls.")
         # We avoid flagging first chunk
         # to trigger creating a message
@@ -1361,6 +1391,7 @@ def _process_stream_response(
                 request_model=request_model,
                 request_data=request_data,
                 depth=depth+1,
+                max_depth=max_depth,
                 flag_chunks=True
             )
         finally:
@@ -1384,7 +1415,8 @@ def complete(
                 FunctionTool,
                 List[Union[Graph, GraphTool, FunctionTool]]
             ]  # can be an instance of tool or a list of instances
-        ] = None
+        ] = None,
+        max_tool_depth: int = 3,
     ) -> 'Conversation.Message':
         """
         Processes the conversation with the current messages and additional data to generate a response.
@@ -1421,7 +1453,8 @@ def complete(
             response_data,
             passed_messages=passed_messages,
             request_model=request_model,
-            request_data=request_data
+            request_data=request_data,
+            max_depth=max_tool_depth
         )
 
     def stream_complete(
@@ -1434,7 +1467,8 @@ def stream_complete(
                 FunctionTool,
                 List[Union[Graph, GraphTool, FunctionTool]]
             ]  # can be an instance of tool or a list of instances
-        ] = None
+        ] = None,
+        max_tool_depth: int = 3
     ) -> Generator[dict, None, None]:
         """
         Initiates a stream to receive chunks of the model's reply.
@@ -1474,7 +1508,8 @@ def stream_complete(
             response=response,
             passed_messages=passed_messages,
             request_model=request_model,
-            request_data=request_data
+            request_data=request_data,
+            max_depth=max_tool_depth
         )
 
         response.close()
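Taken together, the changes above add a small public surface: the `create_function_tool()` helper (with its new optional `description` field) and a `max_tool_depth` knob on `complete()` and `stream_complete()`. Below is a minimal usage sketch; it assumes the `Conversation` class exported by `writer.ai`, the `+=` message syntax from the Framework docs, and dict-shaped stream chunks, so treat the prompt and field names as illustrative rather than exact:

```python
from writer.ai import Conversation, create_function_tool

# Same example tool as in the docs hunk above
def calculate_interest(principal: float, rate: float, time: float):
    return principal * rate * time

tool = create_function_tool(
    name="calculate_interest",
    callable=calculate_interest,
    # `description` is the new optional field added by this diff
    description="Compute simple interest: principal * rate * time.",
    parameters={
        "principal": {"type": "float", "description": "Loan principal"},
        "rate": {"type": "float", "description": "Rate as a decimal, e.g. 0.05"},
        "time": {"type": "float", "description": "Term in years"},
    },
)

conversation = Conversation()
conversation += {
    "role": "user",
    "content": "What interest accrues on 1000 at 5% over 2 years?",
}

# Raise the tool-call recursion cap from its default of 3
response = conversation.complete(tools=tool, max_tool_depth=5)

# The streaming variant accepts the same parameter
for chunk in conversation.stream_complete(tools=tool, max_tool_depth=5):
    print(chunk.get("content", ""), end="")
```

If a tool chain exceeds `max_tool_depth`, the `RuntimeError` raised in `_process_response_data()` and `_process_stream_response()` propagates to the caller, so the cap is a hard failure rather than a silent truncation.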