Skip to content

Commit

Permalink
Merge pull request #675 from mmikita95/fix-no-arguments-tool-calls
Browse files Browse the repository at this point in the history
fix: properly process tool calls with zero arguments
  • Loading branch information
ramedina86 authored Dec 4, 2024
2 parents 8724ea5 + 9930c13 commit 83a5153
Showing 1 changed file with 72 additions and 6 deletions.
78 changes: 72 additions & 6 deletions src/writer/ai.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ class FunctionTool(Tool):
def create_function_tool(
callable: Callable,
name: str,
parameters: Optional[Dict[str, Dict[str, str]]],
parameters: Optional[Dict[str, Dict[str, str]]] = None,
description: Optional[str] = None
) -> FunctionTool:
parameters = parameters or {}
Expand Down Expand Up @@ -1240,7 +1240,17 @@ def _tool_calls_ready(self):

def _gather_tool_calls_messages(self):
return {
index: ongoing_tool_call.get("res")
index: ongoing_tool_call.get(
"res",
{
"role": "tool",
"content": "ERROR: Failed to get function call result – " +
"the function was never called. The most likely reason " +
"is LLM never issuing a `finish_reason: 'tool_calls'`. " +
"Please DO NOT RETRY the function call and inform " +
"the user about the error."
}
)
for index, ongoing_tool_call
in self._ongoing_tool_calls.items()
}
Expand Down Expand Up @@ -1534,6 +1544,19 @@ def _convert_argument_to_type(self, value: Any, target_type: str) -> Any:
else:
raise ValueError(f"Unsupported target type: {target_type}")

def _check_if_arguments_are_required(self, function_name: str) -> bool:
callable_entry = self._callable_registry.get(function_name)
if not callable_entry:
raise RuntimeError(
f"Tried to check arguments of function {function_name} " +
"which is not present in the conversation's callable registry."
)
callable_parameters = callable_entry.get("parameters")
return \
callable_parameters is not None \
and \
callable_parameters != {}

def _execute_function_tool_call(self, index: int) -> dict:
"""
Executes the function call for the specified tool call index.
Expand All @@ -1547,7 +1570,14 @@ def _execute_function_tool_call(self, index: int) -> dict:

# Parse arguments and execute callable
try:
parsed_arguments = json.loads(arguments)
if (
not arguments
and
not self._check_if_arguments_are_required(function_name)
):
parsed_arguments = {}
else:
parsed_arguments = json.loads(arguments)
callable_entry = self._callable_registry.get(function_name)

if callable_entry:
Expand Down Expand Up @@ -1657,12 +1687,48 @@ def _process_tool_call(
tool_call_arguments

# Check if we have all necessary data to execute the function
tool_call_id, tool_call_name, tool_call_arguments = \
self._ongoing_tool_calls[index]["tool_call_id"], \
self._ongoing_tool_calls[index]["name"], \
self._ongoing_tool_calls[index]["arguments"]

tool_call_id_ready = tool_call_id is not None
tool_call_name_ready = tool_call_name is not None

# Check whether the arguments are prepared properly -
# either present in correct format
# or should not be used due to not being required for the function
if tool_call_name_ready:
# Function name is needed to check the function for params
tool_call_arguments_not_required = \
(
not tool_call_arguments
and
not self._check_if_arguments_are_required(
tool_call_name
)
)
tool_call_arguments_formatted_properly = \
(
isinstance(
tool_call_arguments, str
)
and
tool_call_arguments.endswith("}")
)
tool_call_arguments_ready = \
tool_call_arguments_not_required \
or \
tool_call_arguments_formatted_properly
else:
tool_call_arguments_ready = False

if (
self._ongoing_tool_calls[index]["tool_call_id"] is not None
tool_call_id_ready
and
self._ongoing_tool_calls[index]["name"] is not None
tool_call_name_ready
and
self._ongoing_tool_calls[index]["arguments"].endswith("}")
tool_call_arguments_ready
):
follow_up_message = self._execute_function_tool_call(index)
if follow_up_message:
Expand Down

0 comments on commit 83a5153

Please sign in to comment.