diff --git a/src/writer/ai.py b/src/writer/ai.py index c86bf1753..bbdbbfbb8 100644 --- a/src/writer/ai.py +++ b/src/writer/ai.py @@ -100,7 +100,7 @@ class FunctionTool(Tool): def create_function_tool( callable: Callable, name: str, - parameters: Optional[Dict[str, Dict[str, str]]], + parameters: Optional[Dict[str, Dict[str, str]]] = None, description: Optional[str] = None ) -> FunctionTool: parameters = parameters or {} @@ -1240,7 +1240,17 @@ def _tool_calls_ready(self): def _gather_tool_calls_messages(self): return { - index: ongoing_tool_call.get("res") + index: ongoing_tool_call.get( + "res", + { + "role": "tool", + "content": "ERROR: Failed to get function call result – " + + "the function was never called. The most likely reason " + + "is LLM never issuing a `finish_reason: 'tool_calls'`. " + + "Please DO NOT RETRY the function call and inform " + + "the user about the error." + } + ) for index, ongoing_tool_call in self._ongoing_tool_calls.items() } @@ -1534,6 +1544,19 @@ def _convert_argument_to_type(self, value: Any, target_type: str) -> Any: else: raise ValueError(f"Unsupported target type: {target_type}") + def _check_if_arguments_are_required(self, function_name: str) -> bool: + callable_entry = self._callable_registry.get(function_name) + if not callable_entry: + raise RuntimeError( + f"Tried to check arguments of function {function_name} " + + "which is not present in the conversation's callable registry." + ) + callable_parameters = callable_entry.get("parameters") + return \ + callable_parameters is not None \ + and \ + callable_parameters != {} + def _execute_function_tool_call(self, index: int) -> dict: """ Executes the function call for the specified tool call index. @@ -1547,7 +1570,14 @@ def _execute_function_tool_call(self, index: int) -> dict: # Parse arguments and execute callable try: - parsed_arguments = json.loads(arguments) + if ( + not arguments + and + not self._check_if_arguments_are_required(function_name) + ): + parsed_arguments = {} + else: + parsed_arguments = json.loads(arguments) callable_entry = self._callable_registry.get(function_name) if callable_entry: @@ -1657,12 +1687,48 @@ def _process_tool_call( tool_call_arguments # Check if we have all necessary data to execute the function + tool_call_id, tool_call_name, tool_call_arguments = \ + self._ongoing_tool_calls[index]["tool_call_id"], \ + self._ongoing_tool_calls[index]["name"], \ + self._ongoing_tool_calls[index]["arguments"] + + tool_call_id_ready = tool_call_id is not None + tool_call_name_ready = tool_call_name is not None + + # Check whether the arguments are prepared properly - + # either present in correct format + # or should not be used due to not being required for the function + if tool_call_name_ready: + # Function name is needed to check the function for params + tool_call_arguments_not_required = \ + ( + not tool_call_arguments + and + not self._check_if_arguments_are_required( + tool_call_name + ) + ) + tool_call_arguments_formatted_properly = \ + ( + isinstance( + tool_call_arguments, str + ) + and + tool_call_arguments.endswith("}") + ) + tool_call_arguments_ready = \ + tool_call_arguments_not_required \ + or \ + tool_call_arguments_formatted_properly + else: + tool_call_arguments_ready = False + if ( - self._ongoing_tool_calls[index]["tool_call_id"] is not None + tool_call_id_ready and - self._ongoing_tool_calls[index]["name"] is not None + tool_call_name_ready and - self._ongoing_tool_calls[index]["arguments"].endswith("}") + tool_call_arguments_ready ): follow_up_message = self._execute_function_tool_call(index) if follow_up_message: