Merge pull request #598 from mmikita95/fix-ai-module-tool-type
fix: switch to `create_function_tool` function due to lack of default `type`
ramedina86 authored Oct 16, 2024
2 parents 25b6c2d + e827c2b commit d82f245
Showing 2 changed files with 66 additions and 17 deletions.
10 changes: 8 additions & 2 deletions docs/framework/ai-module.mdx
@@ -173,13 +173,13 @@ Framework allows you to register Python functions that can be called automatically

Function tools are defined either in Python, via the `create_function_tool` helper, or with a JSON configuration.
```python
-from writer.ai import FunctionTool
+from writer.ai import create_function_tool

# Define a function tool with Python callable
def calculate_interest(principal: float, rate: float, time: float):
return principal * rate * time

-tool = FunctionTool(
+tool = create_function_tool(
name="calculate_interest",
callable=calculate_interest,
parameters={
@@ -217,6 +217,12 @@ Function tools require the following properties:

When a conversation involves a tool (either a graph or a function), Framework automatically handles the LLM's requests to use tools during the interaction. If a tool needs multiple steps (for example, querying data and then processing it), Framework handles those steps recursively, calling functions as needed until the final result is returned.

By default, to prevent endless recursion, Framework handles at most 3 consecutive tool calls. If that limit doesn't suit your use case, you can raise it: both `complete()` and `stream_complete()` accept a `max_tool_depth` parameter, which configures the maximum allowed recursion depth:

```python
response = conversation.complete(tools=tool, max_tool_depth=7)
```
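
`stream_complete()` accepts the same parameter. A minimal sketch, assuming each yielded chunk is a dict carrying a `content` field, as in this module's streaming API:

```python
for chunk in conversation.stream_complete(tools=tool, max_tool_depth=7):
    print(chunk.get("content", ""), end="")
```

If the model keeps requesting tools past the configured depth, processing stops with a `RuntimeError` (see the `depth > max_depth` checks in the implementation below).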

### Providing a Tool or a List of Tools

You can pass either a single tool or a list of tools to the `complete()` or `stream_complete()` methods. The tools can be a combination of `FunctionTool`, `Graph`, or JSON-defined tools.
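
For example, a single call can mix tool kinds. A sketch, where the graph ID and the `get_time` function are hypothetical and `retrieve_graph` is assumed to be importable from `writer.ai`:

```python
from writer.ai import create_function_tool, retrieve_graph

# Hypothetical function tool
def get_time(timezone: str):
    from datetime import datetime
    from zoneinfo import ZoneInfo
    return datetime.now(ZoneInfo(timezone)).isoformat()

time_tool = create_function_tool(
    name="get_time",
    callable=get_time,
    parameters={
        "timezone": {"type": "string", "description": "IANA timezone name"}
    },
)

graph = retrieve_graph("your-graph-id")  # hypothetical graph ID

response = conversation.complete(tools=[graph, time_tool])
```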
73 changes: 58 additions & 15 deletions src/writer/ai.py
@@ -90,9 +90,26 @@ class GraphTool(Tool):
class FunctionTool(Tool):
callable: Callable
name: str
description: Optional[str]
parameters: Dict[str, Dict[str, str]]


def create_function_tool(
callable: Callable,
name: str,
parameters: Optional[Dict[str, Dict[str, str]]],
description: Optional[str] = None
) -> FunctionTool:
parameters = parameters or {}
return FunctionTool(
type="function",
callable=callable,
name=name,
description=description,
parameters=parameters
)


logger = logging.Logger(__name__, level=logging.DEBUG)
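
Per the commit message, the helper exists because `FunctionTool` declares a `type` field with no default, so constructing it directly requires passing `type="function"` by hand; `create_function_tool` supplies it. A minimal usage sketch (the `lookup_order` function and its parameter schema are illustrative):

```python
def lookup_order(order_id: str) -> str:
    # Stand-in for a real data lookup
    return f"Order {order_id}: shipped"

tool = create_function_tool(
    callable=lookup_order,
    name="lookup_order",
    parameters={
        "order_id": {"type": "string", "description": "Order identifier"}
    },
    description="Look up the status of an order by its ID",
)
```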


@@ -879,6 +896,9 @@ def _clear_callable_registry(self):
self._callable_registry = {}

def _clear_ongoing_tool_calls(self):
"""
Clear ongoing tool calls after they've been processed
"""
self._ongoing_tool_calls = {}

def _clear_tool_calls_helpers(self):
@@ -1190,7 +1210,7 @@ def _execute_function_tool_call(self, index: int) -> dict:
"role": "tool",
"name": function_name,
"tool_call_id": tool_call_id,
"content": f"{function_name}: {func_result}"
"content": f"{func_result}"
}

return follow_up_message
@@ -1223,7 +1243,18 @@ def _process_tool_call(self, index, tool_call_id, tool_call_name, tool_call_arguments):

# Accumulate arguments across chunks
if tool_call_arguments is not None:
self._ongoing_tool_calls[index]["arguments"] += tool_call_arguments
if (
tool_call_arguments.startswith("{")
and tool_call_arguments.endswith("}")
):
# For cases when LLM "bugs" and returns
# the whole arguments string as a last chunk
fixed_chunk = tool_call_arguments.rsplit("{")[-1]
self._ongoing_tool_calls[index]["arguments"] = \
"{" + fixed_chunk
else:
# Process normally
self._ongoing_tool_calls[index]["arguments"] += tool_call_arguments

# Check if we have all necessary data to execute the function
if (
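
The branching added above guards against a malformed stream in which the model re-sends the entire arguments object as the final chunk instead of a continuation. A standalone sketch of the accumulation behavior (hypothetical chunk sequences; note that `rsplit("{")[-1]` assumes flat, non-nested JSON arguments):

```python
def accumulate(chunks):
    buffer = ""
    for chunk in chunks:
        if chunk.startswith("{") and chunk.endswith("}"):
            # The model re-sent the full arguments object; keep only it
            buffer = "{" + chunk.rsplit("{")[-1]
        else:
            buffer += chunk  # normal continuation chunk
    return buffer

accumulate(['{"princ', 'ipal": 1000}'])                      # '{"principal": 1000}'
accumulate(['{"rate": 0.05', '{"rate": 0.05, "time": 2}'])   # '{"rate": 0.05, "time": 2}'
```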
@@ -1271,9 +1302,10 @@ def _process_response_data(
passed_messages: List[WriterAIMessage],
request_model: str,
request_data: ChatOptions,
-depth=1
+depth=1,
+max_depth=3
) -> 'Conversation.Message':
-if depth > 3:
+if depth > max_depth:
raise RuntimeError("Reached maximum depth when processing response data tool calls.")
for entry in response_data.choices:
message = entry.message
@@ -1299,11 +1331,10 @@
)
logger.debug(f"Received response – {follow_up_response}")

-# Clear buffer and callable registry for the completed tool call
-self._clear_tool_calls_helpers()

# Call the function recursively to either process a new tool call
# or return the message if no tool calls are requested

+self._clear_ongoing_tool_calls()
return self._process_response_data(
follow_up_response,
passed_messages=passed_messages,
@@ -1322,9 +1353,10 @@ def _process_stream_response(
request_model: str,
request_data: ChatOptions,
depth=1,
max_depth=3,
flag_chunks=False
) -> Generator[dict, None, None]:
-if depth > 3:
+if depth > max_depth:
raise RuntimeError("Reached maximum depth when processing response data tool calls.")
# We avoid flagging first chunk
# to trigger creating a message
@@ -1352,15 +1384,15 @@
)
)

-# Clear buffer and callable registry for the completed tool call
try:
-self._clear_tool_calls_helpers()
+self._clear_ongoing_tool_calls()
yield from self._process_stream_response(
response=follow_up_response,
passed_messages=passed_messages,
request_model=request_model,
request_data=request_data,
depth=depth+1,
max_depth=max_depth,
flag_chunks=True
)
finally:
@@ -1384,7 +1416,8 @@ def complete(
FunctionTool,
List[Union[Graph, GraphTool, FunctionTool]]
] # can be an instance of tool or a list of instances
-] = None
+] = None,
+max_tool_depth: int = 3,
) -> 'Conversation.Message':
"""
Processes the conversation with the current messages and additional data to generate a response.
@@ -1417,13 +1450,19 @@
)
)

-return self._process_response_data(
+response = self._process_response_data(
response_data,
passed_messages=passed_messages,
request_model=request_model,
-request_data=request_data
+request_data=request_data,
+max_depth=max_tool_depth
)

# Clear buffer and callable registry for the completed tool call
self._clear_tool_calls_helpers()

return response

def stream_complete(
self,
config: Optional['ChatOptions'] = None,
@@ -1434,7 +1473,8 @@ def stream_complete(
FunctionTool,
List[Union[Graph, GraphTool, FunctionTool]]
] # can be an instance of tool or a list of instances
-] = None
+] = None,
+max_tool_depth: int = 3
) -> Generator[dict, None, None]:
"""
Initiates a stream to receive chunks of the model's reply.
@@ -1474,9 +1514,12 @@
response=response,
passed_messages=passed_messages,
request_model=request_model,
-request_data=request_data
+request_data=request_data,
+max_depth=max_tool_depth
)

# Clear buffer and callable registry for the completed tool call
self._clear_tool_calls_helpers()
response.close()

@property
