Skip to content

Commit

Permalink
fix: switch to create_function_call function due to lack of default…
Browse files Browse the repository at this point in the history
… `type`
  • Loading branch information
mmikita95 committed Oct 16, 2024
1 parent 25b6c2d commit d7ffed1
Show file tree
Hide file tree
Showing 2 changed files with 52 additions and 11 deletions.
10 changes: 8 additions & 2 deletions docs/framework/ai-module.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -173,13 +173,13 @@ Framework allows you to register Python functions that can be called automatical

Function tools are defined using either a Python class or a JSON configuration.
```python
from writer.ai import FunctionTool
from writer.ai import create_function_tool

# Define a function tool with Python callable
def calculate_interest(principal: float, rate: float, time: float):
return principal * rate * time

tool = FunctionTool(
tool = create_function_tool(
name="calculate_interest",
callable=calculate_interest,
parameters={
Expand Down Expand Up @@ -217,6 +217,12 @@ Function tools require the following properties:

When a conversation involves a tool (either a graph or a function), Framework automatically handles the requests from LLM to use the tools during interactions. If the tool needs multiple steps (for example, querying data and processing it), Framework will handle those steps recursively, calling functions as needed until the final result is returned.

By default, to prevent endless recursion, Framework will only handle 3 consecutive tool calls. You can expand it in case it doesn't suit your case – both `complete()` and `stream_complete()` accept a `max_tool_depth` parameter, which configures the maximum allowed recursion depth:

```python
response = conversation.complete(tools=tool, max_tool_depth=7)
```

### Providing a Tool or a List of Tools

You can pass either a single tool or a list of tools to the `complete()` or `stream_complete()` methods. The tools can be a combination of FunctionTool, Graph, or JSON-defined tools.
Expand Down
53 changes: 44 additions & 9 deletions src/writer/ai.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,9 +90,26 @@ class GraphTool(Tool):
class FunctionTool(Tool):
callable: Callable
name: str
description: Optional[str]
parameters: Dict[str, Dict[str, str]]


def create_function_tool(
callable: Callable,
name: str,
parameters: Optional[Dict[str, Dict[str, str]]],
description: Optional[str] = None
) -> FunctionTool:
parameters = parameters or {}
return FunctionTool(
type="function",
callable=callable,
name=name,
description=description,
parameters=parameters
)


logger = logging.Logger(__name__, level=logging.DEBUG)


Expand Down Expand Up @@ -1190,7 +1207,7 @@ def _execute_function_tool_call(self, index: int) -> dict:
"role": "tool",
"name": function_name,
"tool_call_id": tool_call_id,
"content": f"{function_name}: {func_result}"
"content": f"{func_result}"
}

return follow_up_message
Expand Down Expand Up @@ -1223,7 +1240,18 @@ def _process_tool_call(self, index, tool_call_id, tool_call_name, tool_call_argu

# Accumulate arguments across chunks
if tool_call_arguments is not None:
self._ongoing_tool_calls[index]["arguments"] += tool_call_arguments
if (
tool_call_arguments.startswith("{")
and tool_call_arguments.endswith("}")
):
# For cases when LLM "bugs" and returns
# the whole arguments string as a last chunk
fixed_chunk = tool_call_arguments.rsplit("{")[-1]
self._ongoing_tool_calls[index]["arguments"] = \
"{" + fixed_chunk
else:
# Process normally
self._ongoing_tool_calls[index]["arguments"] += tool_call_arguments

# Check if we have all necessary data to execute the function
if (
Expand Down Expand Up @@ -1271,9 +1299,10 @@ def _process_response_data(
passed_messages: List[WriterAIMessage],
request_model: str,
request_data: ChatOptions,
depth=1
depth=1,
max_depth=3
) -> 'Conversation.Message':
if depth > 3:
if depth > max_depth:
raise RuntimeError("Reached maximum depth when processing response data tool calls.")
for entry in response_data.choices:
message = entry.message
Expand Down Expand Up @@ -1322,9 +1351,10 @@ def _process_stream_response(
request_model: str,
request_data: ChatOptions,
depth=1,
max_depth=3,
flag_chunks=False
) -> Generator[dict, None, None]:
if depth > 3:
if depth > max_depth:
raise RuntimeError("Reached maximum depth when processing response data tool calls.")
# We avoid flagging first chunk
# to trigger creating a message
Expand Down Expand Up @@ -1361,6 +1391,7 @@ def _process_stream_response(
request_model=request_model,
request_data=request_data,
depth=depth+1,
max_depth=max_depth,
flag_chunks=True
)
finally:
Expand All @@ -1384,7 +1415,8 @@ def complete(
FunctionTool,
List[Union[Graph, GraphTool, FunctionTool]]
] # can be an instance of tool or a list of instances
] = None
] = None,
max_tool_depth: int = 3,
) -> 'Conversation.Message':
"""
Processes the conversation with the current messages and additional data to generate a response.
Expand Down Expand Up @@ -1421,7 +1453,8 @@ def complete(
response_data,
passed_messages=passed_messages,
request_model=request_model,
request_data=request_data
request_data=request_data,
max_depth=max_tool_depth
)

def stream_complete(
Expand All @@ -1434,7 +1467,8 @@ def stream_complete(
FunctionTool,
List[Union[Graph, GraphTool, FunctionTool]]
] # can be an instance of tool or a list of instances
] = None
] = None,
max_tool_depth: int = 3
) -> Generator[dict, None, None]:
"""
Initiates a stream to receive chunks of the model's reply.
Expand Down Expand Up @@ -1474,7 +1508,8 @@ def stream_complete(
response=response,
passed_messages=passed_messages,
request_model=request_model,
request_data=request_data
request_data=request_data,
max_depth=max_tool_depth
)

response.close()
Expand Down

0 comments on commit d7ffed1

Please sign in to comment.