Skip to content

Commit

Permalink
clean up code in functool and add decorator to quickly handle the gra…
Browse files Browse the repository at this point in the history
…dient pass through grad component correction, the code follows the 3 cases of gradcomponent design now, still need to further optimize the agent
  • Loading branch information
liyin2015 committed Dec 28, 2024
1 parent 04cd65b commit 2f19a14
Show file tree
Hide file tree
Showing 7 changed files with 224 additions and 95 deletions.
17 changes: 6 additions & 11 deletions adalflow/adalflow/components/agent/react.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ def call(
self,
action_str: FunctionExpression,
step: int,
result: Union[FunctionOutput, Parameter],
result: FunctionOutput,
func: Function,
id: Optional[str] = None,
) -> StepOutput:
Expand All @@ -143,14 +143,8 @@ def call(
raise ValueError(f"Expected FunctionExpression, but got {type(action_str)}")
step_output.action = action_str
step_output.function = func
# printc(f"result: {result}", color="blue")
result = result.data if isinstance(result, Parameter) else result
if isinstance(result, FunctionOutput):
step_output.observation = (
result.output.data
if isinstance(result.output, Parameter)
else result.output
)

step_output.observation = result.output

return step_output

Expand Down Expand Up @@ -365,8 +359,6 @@ def _execute_action(
if isinstance(func, Parameter):
func.data.kwargs["id"] = id

func.add_successor_map_fn(self.tool_manager, lambda x: x.data)

result: Parameter = self.tool_manager(expr_or_fun=func, step="execute")
# printc(f"result: {result}", color="red")
result.add_successor_map_fn(
Expand All @@ -375,6 +367,9 @@ def _execute_action(
response.add_successor_map_fn(
successor=function_output_to_step_output, map_fn=lambda x: x.data
)
func.add_successor_map_fn(
successor=function_output_to_step_output, map_fn=lambda x: x.data
)
action_step = function_output_to_step_output.forward(
action_str=response,
step=action_step.step,
Expand Down
8 changes: 8 additions & 0 deletions adalflow/adalflow/core/func_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,7 @@ def retriever_as_tool(input: str) -> str:
- via sandboxed execute directionly using ``sandbox_exec``.
A FunctionTool allows other GradComponent(as a tool) to pass through correctly.
"""

def __init__(
Expand Down Expand Up @@ -237,6 +238,9 @@ def sync_function_1():
if self._is_async:
raise ValueError("FunctionTool is asynchronous, use acall instead")
output, error = None, None

# NOTE: special case:
# self.fn can have both train and eval mode or untrainable as a function.
try:
output = self.fn(*args, **kwargs)
except Exception as e:
Expand All @@ -247,6 +251,10 @@ def sync_function_1():
print(f"typeof output: {type(output)}")

if isinstance(output, Parameter):
if not self.training:
raise ValueError(
f"FunctionTool {self.definition.func_name} is in eval mode, but the output is Parameter"
)
print("output is Parameter")
output.data = FunctionOutput(
name=self.definition.func_name,
Expand Down
146 changes: 76 additions & 70 deletions adalflow/adalflow/core/tool_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
import logging
from copy import deepcopy
import asyncio
from adalflow.optim.parameter import Parameter, ParameterType, OutputParameter
from adalflow.optim.parameter import Parameter, ParameterType
import nest_asyncio
import warnings

Expand Down Expand Up @@ -57,6 +57,71 @@ def run_async_in_new_loop(coro):
asyncio.set_event_loop(None)


class CallFunctionTool(GradComponent):
__doc__ = """Contains other unit gradcomponent such as calling
a FunctionTool"""

def __init__(self):
super().__init__()

def forward(self, func: Parameter, context: Dict[str, object]):
return self.bicall(func, context=context)

def call(self, func: Function, context: Dict[str, object]) -> FunctionOutput:
return self.bicall(func, context=context)

def bicall(
self,
func: Union[Function, Parameter],
context: Dict[str, object] = {},
):
if isinstance(func, Parameter):
# data = func.successor_map_fn(func)
printc(f"context: {context}", color="yellow")
tool: FunctionTool = context[func.data.name]
print(f"tool training: {tool.training}")
output = tool.forward(*func.data.args, **func.data.kwargs)

from adalflow.optim.grad_component import fun_to_grad_component

# this will automatically create the outputparam, and connect output, func to the outputParam
@fun_to_grad_component
def dummy_pass_through_for_untrainable_fn(output, func):
return output

# NOTE: special case: handle the function which is not a grad_component
# here we have to specifically converts it to a parameter and handles the predecessors
# there is no trainable parameters inside of the tool but the tool response itself can be optimized by response optimizer
if not isinstance(output, Parameter):
return dummy_pass_through_for_untrainable_fn.forward(output, func)
else:
# reconnect the predecessor for tracing as it is not done in tool.forward
output.predecessors.add(func)
return output
else:
tool: FunctionTool = context[func.name]
output = tool.call(*func.args, **func.kwargs)
return output


class FunctionExperssionToFunction(GradComponent):
def __init__(self):
super().__init__()

def call(self, expr: FunctionExpression, context: Dict[str, object]):
print("DummpyGradComponent call")
print(expr)

expr_str = expr.action
func_name, args, kwargs = parse_function_call_expr(expr_str, context)
return Function(
name=func_name,
args=args,
kwargs=kwargs,
thought=expr.thought,
)


# TODO: good to track all the failed function calls
# Tool manager is a task component
class ToolManager(GradComponent):
Expand Down Expand Up @@ -149,35 +214,10 @@ def parse_func_expr(
if isinstance(expr, Parameter):
try:

class FunctionExperssionToFunction(GradComponent):
def __init__(self):
super().__init__()

def call(
self, expr: FunctionExpression, context: Dict[str, object]
):
print("DummpyGradComponent call")
print(expr)

expr_str = expr.action
func_name, args, kwargs = parse_function_call_expr(
expr_str, context
)
return Function(
name=func_name,
args=args,
kwargs=kwargs,
thought=expr.thought,
)

dummy = FunctionExperssionToFunction()
print("FunctionExperssionToFunction")
# expr.add_successor_map_fn(dummy, map_fn=lambda x: x.data)
return dummy.forward(expr, context=self.context)

# expr_str = expr.action
# func_name, args, kwargs = parse_function_call_expr(expr_str, self.context)
# return Function(name=func_name, args=args, kwargs=kwargs)
except Exception as e:
log.error(f"Error {e} parsing function call expression: {expr}")
raise ValueError(f"Error {e} parsing function call expression: {expr}")
Expand Down Expand Up @@ -218,6 +258,10 @@ def call(
expr_or_fun: Union[FunctionExpression, Function],
step: Literal["execute"] = "execute",
) -> Union[FunctionOutput, Function, Parameter]:
if not isinstance(expr_or_fun, (Function, FunctionExpression)):
raise ValueError(
f"expr_or_fun should be either a Function or FunctionExpression. Got {expr_or_fun}"
)
if step == "parse":
if isinstance(expr_or_fun, Function):
return expr_or_fun
Expand All @@ -229,8 +273,9 @@ def call(

def forward(
self,
*,
expr_or_fun: Union[FunctionExpression, Function, Parameter],
step: str = "execute",
step: Literal["parse", "execute"] = "execute",
) -> Union[FunctionOutput, Function, Parameter]:
if isinstance(expr_or_fun, Parameter):
if step == "execute":
Expand All @@ -248,7 +293,8 @@ def forward(
f"Only function call expressions are supported for now. Got {expr_or_fun.data}"
)
else:
return self.call(expr_or_fun=expr_or_fun, step=step)
raise ValueError(f"expr_or_fun should be a Parameter. Got {expr_or_fun}")
# return self.call(expr_or_fun=expr_or_fun, step=step)

def execute_func(
self, func: Union[Function, Parameter]
Expand All @@ -257,48 +303,8 @@ def execute_func(

if isinstance(func, Parameter):

class GetFunctionTool(GradComponent):
def __init__(self):
super().__init__()

def forward(self, func: Parameter, context: Dict[str, object]):
return self.bicall(func, context=context)

def call(self, func: FunctionOutput, context: Dict[str, object]):
return self.bicall(func, context=context)

def bicall(
self,
func: Union[FunctionOutput, Parameter],
context: Dict[str, object] = {},
):
if isinstance(func, Parameter):
printc(f"context: {context}", color="yellow")
tool: FunctionTool = context[func.data.name]
print(f"tool training: {tool.training}")
output = tool.forward(*func.data.args, **func.data.kwargs)
# handle the untainable function
if not isinstance(output, Parameter):
# warnings.info(
# f"Error executing function: {output}", UserWarning
# )
output = OutputParameter(
name=func.data.name,
data=output,
requires_opt=False,
param_type=ParameterType.OUTPUT,
)
return output

output.predecessors.add(func)
return output
else:
tool: FunctionTool = context[func.name]
output = tool.call(*func.args, **func.kwargs)
return output

tool = GetFunctionTool()
return tool.forward(func, context=self.context)
call_func_tool = CallFunctionTool()
return call_func_tool.forward(func, context=self.context)
else:
try:
tool: FunctionTool = self.context[func.name]
Expand Down
Loading

0 comments on commit 2f19a14

Please sign in to comment.