diff --git a/docs/framework/ai-module.mdx b/docs/framework/ai-module.mdx
index 857d6147a..4754f5725 100644
--- a/docs/framework/ai-module.mdx
+++ b/docs/framework/ai-module.mdx
@@ -173,13 +173,13 @@ Framework allows you to register Python functions that can be called automatical
 Function tools are defined using either a Python class or a JSON configuration.
 
 ```python
-from writer.ai import FunctionTool
+from writer.ai import create_function_tool
 
 # Define a function tool with Python callable
 def calculate_interest(principal: float, rate: float, time: float):
     return principal * rate * time
 
-tool = FunctionTool(
+tool = create_function_tool(
     name="calculate_interest",
     callable=calculate_interest,
     parameters={
@@ -217,6 +217,12 @@ Function tools require the following properties:
 
 When a conversation involves a tool (either a graph or a function), Framework automatically handles the requests from LLM to use the tools during interactions. If the tool needs multiple steps (for example, querying data and processing it), Framework will handle those steps recursively, calling functions as needed until the final result is returned.
 
+By default, to prevent endless recursion, Framework handles at most 3 consecutive tool calls. If that limit doesn't suit your use case, you can raise it: both `complete()` and `stream_complete()` accept a `max_tool_depth` parameter, which sets the maximum allowed recursion depth:
+
+```python
+response = conversation.complete(tools=tool, max_tool_depth=7)
+```
+
 ### Providing a Tool or a List of Tools
 
 You can pass either a single tool or a list of tools to the `complete()` or `stream_complete()` methods. The tools can be a combination of FunctionTool, Graph, or JSON-defined tools.
diff --git a/src/writer/ai.py b/src/writer/ai.py
index f304fa676..eba5d9768 100644
--- a/src/writer/ai.py
+++ b/src/writer/ai.py
@@ -90,9 +90,26 @@ class GraphTool(Tool):
 
 
 class FunctionTool(Tool):
     callable: Callable
     name: str
+    description: Optional[str]
     parameters: Dict[str, Dict[str, str]]
 
 
+def create_function_tool(
+    callable: Callable,
+    name: str,
+    parameters: Optional[Dict[str, Dict[str, str]]],
+    description: Optional[str] = None
+) -> FunctionTool:
+    parameters = parameters or {}
+    return FunctionTool(
+        type="function",
+        callable=callable,
+        name=name,
+        description=description,
+        parameters=parameters
+    )
+
+
 logger = logging.Logger(__name__, level=logging.DEBUG)
@@ -879,6 +896,9 @@ def _clear_callable_registry(self):
         self._callable_registry = {}
 
     def _clear_ongoing_tool_calls(self):
+        """
+        Clear ongoing tool calls after they've been processed
+        """
         self._ongoing_tool_calls = {}
 
     def _clear_tool_calls_helpers(self):
@@ -1190,7 +1210,7 @@ def _execute_function_tool_call(self, index: int) -> dict:
             "role": "tool",
             "name": function_name,
             "tool_call_id": tool_call_id,
-            "content": f"{function_name}: {func_result}"
+            "content": f"{func_result}"
         }
 
         return follow_up_message
@@ -1223,7 +1243,18 @@ def _process_tool_call(self, index, tool_call_id, tool_call_name, tool_call_argu
 
         # Accumulate arguments across chunks
         if tool_call_arguments is not None:
-            self._ongoing_tool_calls[index]["arguments"] += tool_call_arguments
+            if (
+                tool_call_arguments.startswith("{")
+                and tool_call_arguments.endswith("}")
+            ):
+                # Handle cases when the LLM bugs out and returns
+                # the whole arguments string as the final chunk
+                fixed_chunk = tool_call_arguments.rsplit("{")[-1]
+                self._ongoing_tool_calls[index]["arguments"] = \
+                    "{" + fixed_chunk
+            else:
+                # Process normally
+                self._ongoing_tool_calls[index]["arguments"] += tool_call_arguments
 
         # Check if we have all necessary data to execute the function
         if (
@@ -1271,9 +1302,10 @@ def _process_response_data(
         passed_messages: List[WriterAIMessage],
         request_model: str,
         request_data: ChatOptions,
-        depth=1
+        depth=1,
+        max_depth=3
     ) -> 'Conversation.Message':
-        if depth > 3:
+        if depth > max_depth:
             raise RuntimeError("Reached maximum depth when processing response data tool calls.")
         for entry in response_data.choices:
             message = entry.message
@@ -1299,11 +1331,10 @@
             )
             logger.debug(f"Received response – {follow_up_response}")
 
-            # Clear buffer and callable registry for the completed tool call
-            self._clear_tool_calls_helpers()
-
             # Call the function recursively to either process a new tool call
             # or return the message if no tool calls are requested
+
+            self._clear_ongoing_tool_calls()
             return self._process_response_data(
                 follow_up_response,
                 passed_messages=passed_messages,
@@ -1322,9 +1353,10 @@ def _process_stream_response(
         request_model: str,
         request_data: ChatOptions,
         depth=1,
+        max_depth=3,
         flag_chunks=False
     ) -> Generator[dict, None, None]:
-        if depth > 3:
+        if depth > max_depth:
             raise RuntimeError("Reached maximum depth when processing response data tool calls.")
         # We avoid flagging first chunk
         # to trigger creating a message
@@ -1352,15 +1384,15 @@
                 )
             )
 
-            # Clear buffer and callable registry for the completed tool call
             try:
-                self._clear_tool_calls_helpers()
+                self._clear_ongoing_tool_calls()
                 yield from self._process_stream_response(
                     response=follow_up_response,
                     passed_messages=passed_messages,
                     request_model=request_model,
                     request_data=request_data,
                     depth=depth+1,
+                    max_depth=max_depth,
                     flag_chunks=True
                 )
             finally:
@@ -1384,7 +1416,8 @@ def complete(
                 FunctionTool,
                 List[Union[Graph, GraphTool, FunctionTool]]
             ]  # can be an instance of tool or a list of instances
-        ] = None
+        ] = None,
+        max_tool_depth: int = 3,
     ) -> 'Conversation.Message':
         """
         Processes the conversation with the current messages and additional data to generate a response.
@@ -1417,13 +1450,19 @@
             )
         )
 
-        return self._process_response_data(
+        response = self._process_response_data(
             response_data,
             passed_messages=passed_messages,
             request_model=request_model,
-            request_data=request_data
+            request_data=request_data,
+            max_depth=max_tool_depth
         )
 
+        # Clear buffer and callable registry for the completed tool call
+        self._clear_tool_calls_helpers()
+
+        return response
+
     def stream_complete(
         self,
         config: Optional['ChatOptions'] = None,
@@ -1434,7 +1473,8 @@
                 FunctionTool,
                 List[Union[Graph, GraphTool, FunctionTool]]
             ]  # can be an instance of tool or a list of instances
-        ] = None
+        ] = None,
+        max_tool_depth: int = 3
     ) -> Generator[dict, None, None]:
         """
         Initiates a stream to receive chunks of the model's reply.
@@ -1474,9 +1514,12 @@
             response=response,
             passed_messages=passed_messages,
             request_model=request_model,
-            request_data=request_data
+            request_data=request_data,
+            max_depth=max_tool_depth
         )
 
+        # Clear buffer and callable registry for the completed tool call
+        self._clear_tool_calls_helpers()
         response.close()
 
     @property
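
Taken together, the changes above introduce two user-facing pieces: the `create_function_tool()` helper and the `max_tool_depth` parameter. Below is a minimal end-to-end sketch of how they combine. The no-argument `Conversation()` constructor and the dict-style message append are assumed setup not shown in this diff, and the parameter schema fields are guesses, since the docs example is truncated after `parameters={` in the hunk above.

```python
from writer.ai import Conversation, create_function_tool

def calculate_interest(principal: float, rate: float, time: float):
    return principal * rate * time

# The new helper builds a FunctionTool with type="function" filled in;
# description is the optional field added in this diff.
tool = create_function_tool(
    callable=calculate_interest,
    name="calculate_interest",
    parameters={
        # Schema values are assumptions; the docs example is cut off.
        "principal": {"type": "float", "description": "Principal amount"},
        "rate": {"type": "float", "description": "Interest rate per period"},
        "time": {"type": "float", "description": "Number of periods"},
    },
    description="Calculate simple interest",
)

conversation = Conversation()  # assumed setup, not part of this diff
conversation += {
    "role": "user",
    "content": "What is the interest on 1000 at 5% over 2 years?",
}

# max_tool_depth lifts the default cap of 3 consecutive tool calls, so up
# to 7 chained tool invocations are allowed before the RuntimeError in
# _process_response_data is raised.
response = conversation.complete(tools=tool, max_tool_depth=7)
```

As the docs section notes, `tools` also accepts a list mixing `FunctionTool`, `Graph`, and JSON-defined tools, e.g. `conversation.complete(tools=[tool, graph], max_tool_depth=7)` with a `graph` handle obtained elsewhere.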
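The new branch in `_process_tool_call` is easiest to follow in isolation. The sketch below mirrors its accumulation logic on a plain dict, with hypothetical streamed fragments standing in for the chunks the LLM emits:

```python
# Mirror of the accumulation logic added to _process_tool_call,
# using a plain dict in place of self._ongoing_tool_calls[index].
ongoing = {"arguments": ""}

def accumulate(chunk: str) -> None:
    if chunk.startswith("{") and chunk.endswith("}"):
        # Workaround path: the LLM occasionally returns the whole
        # arguments string as a single final chunk, so the buffer is
        # replaced rather than appended to.
        ongoing["arguments"] = "{" + chunk.rsplit("{")[-1]
    else:
        # Normal path: partial fragments are concatenated.
        ongoing["arguments"] += chunk

# Hypothetical fragments of '{"principal": 1000}' as a model might
# stream them. The first fragment starts with "{" but does not end
# with "}", so it takes the normal path.
for fragment in ['{"princ', 'ipal": 10', '00}']:
    accumulate(fragment)
assert ongoing["arguments"] == '{"principal": 1000}'

# The buggy case: the complete string arrives again as one chunk.
accumulate('{"principal": 1000}')
assert ongoing["arguments"] == '{"principal": 1000}'
```

One observation for review: `rsplit("{")[-1]` keeps only the text after the last `{`, so a final chunk whose arguments contain a nested JSON object would be truncated; it may be worth confirming that case cannot occur here.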