From 2f19a14d5726b900d0d3d76a681d3423324ffed9 Mon Sep 17 00:00:00 2001
From: Li Yin
Date: Fri, 27 Dec 2024 19:39:46 -0800
Subject: [PATCH] clean up code in func_tool and add a decorator to quickly
 handle gradient pass-through via GradComponent; the code now follows the three
 cases of the GradComponent design; the agent still needs further optimization
---
adalflow/adalflow/components/agent/react.py | 17 +-
adalflow/adalflow/core/func_tool.py | 8 +
adalflow/adalflow/core/tool_manager.py | 146 +++++++++---------
adalflow/adalflow/optim/grad_component.py | 136 ++++++++++++++--
adalflow/adalflow/optim/parameter.py | 4 +-
adalflow/adalflow/optim/trainer/trainer.py | 2 +-
.../hotpot_qa/adal_exp/train_agent_rag.py | 6 +-
7 files changed, 224 insertions(+), 95 deletions(-)
diff --git a/adalflow/adalflow/components/agent/react.py b/adalflow/adalflow/components/agent/react.py
index e004332e..e9af5b99 100644
--- a/adalflow/adalflow/components/agent/react.py
+++ b/adalflow/adalflow/components/agent/react.py
@@ -133,7 +133,7 @@ def call(
self,
action_str: FunctionExpression,
step: int,
- result: Union[FunctionOutput, Parameter],
+ result: FunctionOutput,
func: Function,
id: Optional[str] = None,
) -> StepOutput:
@@ -143,14 +143,8 @@ def call(
raise ValueError(f"Expected FunctionExpression, but got {type(action_str)}")
step_output.action = action_str
step_output.function = func
- # printc(f"result: {result}", color="blue")
- result = result.data if isinstance(result, Parameter) else result
- if isinstance(result, FunctionOutput):
- step_output.observation = (
- result.output.data
- if isinstance(result.output, Parameter)
- else result.output
- )
+
+ step_output.observation = result.output
return step_output
@@ -365,8 +359,6 @@ def _execute_action(
if isinstance(func, Parameter):
func.data.kwargs["id"] = id
- func.add_successor_map_fn(self.tool_manager, lambda x: x.data)
-
result: Parameter = self.tool_manager(expr_or_fun=func, step="execute")
# printc(f"result: {result}", color="red")
result.add_successor_map_fn(
@@ -375,6 +367,9 @@ def _execute_action(
response.add_successor_map_fn(
successor=function_output_to_step_output, map_fn=lambda x: x.data
)
+ func.add_successor_map_fn(
+ successor=function_output_to_step_output, map_fn=lambda x: x.data
+ )
action_step = function_output_to_step_output.forward(
action_str=response,
step=action_step.step,
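
Note (illustrative sketch, not part of this patch): the react.py hunk above registers function_output_to_step_output as a successor of both response and func with map_fn=lambda x: x.data, so the step-output converter receives unwrapped data instead of Parameter objects during training. A minimal standalone sketch of that per-successor mapping idea, using a hypothetical MiniParam class rather than adalflow's Parameter:

    # Hypothetical MiniParam; only the per-successor map_fn idea mirrors the patch.
    from typing import Any, Callable, Dict


    class MiniParam:
        def __init__(self, data: Any):
            self.data = data
            self._map_fns: Dict[int, Callable[["MiniParam"], Any]] = {}

        def add_successor_map_fn(self, successor: Any, map_fn: Callable) -> None:
            # keyed by the successor's identity, one map_fn per successor
            self._map_fns[id(successor)] = map_fn

        def map_to_successor(self, successor: Any) -> Any:
            # fall back to the parameter itself when no map_fn was registered
            return self._map_fns.get(id(successor), lambda p: p)(self)


    converter = object()  # stands in for function_output_to_step_output
    p = MiniParam(data={"output": 42})
    p.add_successor_map_fn(successor=converter, map_fn=lambda x: x.data)
    assert p.map_to_successor(converter) == {"output": 42}
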
diff --git a/adalflow/adalflow/core/func_tool.py b/adalflow/adalflow/core/func_tool.py
index 66a33c6b..ad135b74 100644
--- a/adalflow/adalflow/core/func_tool.py
+++ b/adalflow/adalflow/core/func_tool.py
@@ -107,6 +107,7 @@ def retriever_as_tool(input: str) -> str:
- via sandboxed execute directly using ``sandbox_exec``.
+    A FunctionTool allows another GradComponent (used as a tool) to pass through correctly.
"""
def __init__(
@@ -237,6 +238,9 @@ def sync_function_1():
if self._is_async:
raise ValueError("FunctionTool is asynchronous, use acall instead")
output, error = None, None
+
+    # NOTE: special case:
+    # self.fn can be a GradComponent with both train and eval modes, or a plain, untrainable function.
try:
output = self.fn(*args, **kwargs)
except Exception as e:
@@ -247,6 +251,10 @@ def sync_function_1():
print(f"typeof output: {type(output)}")
if isinstance(output, Parameter):
+ if not self.training:
+ raise ValueError(
+ f"FunctionTool {self.definition.func_name} is in eval mode, but the output is Parameter"
+ )
print("output is Parameter")
output.data = FunctionOutput(
name=self.definition.func_name,
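
Note (illustrative sketch, not part of this patch): the new check in FunctionTool.call enforces the invariant that the wrapped fn may only return a Parameter when the tool is in training mode. A minimal standalone version of the same guard, with stand-in classes instead of adalflow's:

    # Stand-in classes; only the eval-mode guard mirrors the patched FunctionTool.call.
    class Param:  # stands in for adalflow.optim.parameter.Parameter
        def __init__(self, data):
            self.data = data


    class Tool:
        def __init__(self, fn, training: bool = False):
            self.fn = fn
            self.training = training

        def call(self, *args, **kwargs):
            output = self.fn(*args, **kwargs)
            # A Parameter output implies fn ran in training mode; reject it in eval mode.
            if isinstance(output, Param) and not self.training:
                raise ValueError("tool is in eval mode, but the output is a Parameter")
            return output


    assert Tool(lambda x: x + 1).call(1) == 2
    try:
        Tool(lambda x: Param(x)).call(1)
    except ValueError as e:
        print(e)  # tool is in eval mode, but the output is a Parameter
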
diff --git a/adalflow/adalflow/core/tool_manager.py b/adalflow/adalflow/core/tool_manager.py
index 634dfb1d..1db303e4 100644
--- a/adalflow/adalflow/core/tool_manager.py
+++ b/adalflow/adalflow/core/tool_manager.py
@@ -16,7 +16,7 @@
import logging
from copy import deepcopy
import asyncio
-from adalflow.optim.parameter import Parameter, ParameterType, OutputParameter
+from adalflow.optim.parameter import Parameter, ParameterType
import nest_asyncio
import warnings
@@ -57,6 +57,71 @@ def run_async_in_new_loop(coro):
asyncio.set_event_loop(None)
+class CallFunctionTool(GradComponent):
+    __doc__ = """Contains other unit GradComponents, such as calling
+    a FunctionTool."""
+
+ def __init__(self):
+ super().__init__()
+
+ def forward(self, func: Parameter, context: Dict[str, object]):
+ return self.bicall(func, context=context)
+
+ def call(self, func: Function, context: Dict[str, object]) -> FunctionOutput:
+ return self.bicall(func, context=context)
+
+ def bicall(
+ self,
+ func: Union[Function, Parameter],
+ context: Dict[str, object] = {},
+ ):
+ if isinstance(func, Parameter):
+ # data = func.successor_map_fn(func)
+ printc(f"context: {context}", color="yellow")
+ tool: FunctionTool = context[func.data.name]
+ print(f"tool training: {tool.training}")
+ output = tool.forward(*func.data.args, **func.data.kwargs)
+
+ from adalflow.optim.grad_component import fun_to_grad_component
+
+            # this will automatically create the OutputParameter and connect both output and func to it as predecessors
+ @fun_to_grad_component
+ def dummy_pass_through_for_untrainable_fn(output, func):
+ return output
+
+            # NOTE: special case: handle a function that is not a GradComponent.
+            # Here we have to explicitly convert the output to a Parameter and wire up the predecessors.
+            # There are no trainable parameters inside the tool, but the tool response itself can still be optimized by the response optimizer.
+ if not isinstance(output, Parameter):
+ return dummy_pass_through_for_untrainable_fn.forward(output, func)
+ else:
+ # reconnect the predecessor for tracing as it is not done in tool.forward
+ output.predecessors.add(func)
+ return output
+ else:
+ tool: FunctionTool = context[func.name]
+ output = tool.call(*func.args, **func.kwargs)
+ return output
+
+
+class FunctionExperssionToFunction(GradComponent):
+ def __init__(self):
+ super().__init__()
+
+ def call(self, expr: FunctionExpression, context: Dict[str, object]):
+        print("FunctionExperssionToFunction call")
+ print(expr)
+
+ expr_str = expr.action
+ func_name, args, kwargs = parse_function_call_expr(expr_str, context)
+ return Function(
+ name=func_name,
+ args=args,
+ kwargs=kwargs,
+ thought=expr.thought,
+ )
+
+
# TODO: good to track all the failed function calls
# Tool manager is a task component
class ToolManager(GradComponent):
@@ -149,35 +214,10 @@ def parse_func_expr(
if isinstance(expr, Parameter):
try:
- class FunctionExperssionToFunction(GradComponent):
- def __init__(self):
- super().__init__()
-
- def call(
- self, expr: FunctionExpression, context: Dict[str, object]
- ):
- print("DummpyGradComponent call")
- print(expr)
-
- expr_str = expr.action
- func_name, args, kwargs = parse_function_call_expr(
- expr_str, context
- )
- return Function(
- name=func_name,
- args=args,
- kwargs=kwargs,
- thought=expr.thought,
- )
-
dummy = FunctionExperssionToFunction()
print("FunctionExperssionToFunction")
- # expr.add_successor_map_fn(dummy, map_fn=lambda x: x.data)
return dummy.forward(expr, context=self.context)
- # expr_str = expr.action
- # func_name, args, kwargs = parse_function_call_expr(expr_str, self.context)
- # return Function(name=func_name, args=args, kwargs=kwargs)
except Exception as e:
log.error(f"Error {e} parsing function call expression: {expr}")
raise ValueError(f"Error {e} parsing function call expression: {expr}")
@@ -218,6 +258,10 @@ def call(
expr_or_fun: Union[FunctionExpression, Function],
step: Literal["execute"] = "execute",
) -> Union[FunctionOutput, Function, Parameter]:
+ if not isinstance(expr_or_fun, (Function, FunctionExpression)):
+ raise ValueError(
+ f"expr_or_fun should be either a Function or FunctionExpression. Got {expr_or_fun}"
+ )
if step == "parse":
if isinstance(expr_or_fun, Function):
return expr_or_fun
@@ -229,8 +273,9 @@ def call(
def forward(
self,
+ *,
expr_or_fun: Union[FunctionExpression, Function, Parameter],
- step: str = "execute",
+ step: Literal["parse", "execute"] = "execute",
) -> Union[FunctionOutput, Function, Parameter]:
if isinstance(expr_or_fun, Parameter):
if step == "execute":
@@ -248,7 +293,8 @@ def forward(
f"Only function call expressions are supported for now. Got {expr_or_fun.data}"
)
else:
- return self.call(expr_or_fun=expr_or_fun, step=step)
+ raise ValueError(f"expr_or_fun should be a Parameter. Got {expr_or_fun}")
+ # return self.call(expr_or_fun=expr_or_fun, step=step)
def execute_func(
self, func: Union[Function, Parameter]
@@ -257,48 +303,8 @@ def execute_func(
if isinstance(func, Parameter):
- class GetFunctionTool(GradComponent):
- def __init__(self):
- super().__init__()
-
- def forward(self, func: Parameter, context: Dict[str, object]):
- return self.bicall(func, context=context)
-
- def call(self, func: FunctionOutput, context: Dict[str, object]):
- return self.bicall(func, context=context)
-
- def bicall(
- self,
- func: Union[FunctionOutput, Parameter],
- context: Dict[str, object] = {},
- ):
- if isinstance(func, Parameter):
- printc(f"context: {context}", color="yellow")
- tool: FunctionTool = context[func.data.name]
- print(f"tool training: {tool.training}")
- output = tool.forward(*func.data.args, **func.data.kwargs)
- # handle the untainable function
- if not isinstance(output, Parameter):
- # warnings.info(
- # f"Error executing function: {output}", UserWarning
- # )
- output = OutputParameter(
- name=func.data.name,
- data=output,
- requires_opt=False,
- param_type=ParameterType.OUTPUT,
- )
- return output
-
- output.predecessors.add(func)
- return output
- else:
- tool: FunctionTool = context[func.name]
- output = tool.call(*func.args, **func.kwargs)
- return output
-
- tool = GetFunctionTool()
- return tool.forward(func, context=self.context)
+ call_func_tool = CallFunctionTool()
+ return call_func_tool.forward(func, context=self.context)
else:
try:
tool: FunctionTool = self.context[func.name]
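
Note (illustrative sketch, not part of this patch): CallFunctionTool.bicall above covers the three GradComponent cases the commit message refers to: (1) eval mode with a plain Function, (2) training mode where the tool already returns a Parameter, and (3) training mode where the tool's fn is untrainable, so its raw output is lifted into the graph via the fun_to_grad_component pass-through. A condensed sketch of that dispatch, where MiniParam and lift_untrainable are hypothetical stand-ins rather than adalflow APIs:

    # Condensed sketch of the three cases handled by CallFunctionTool.bicall.
    from typing import Any


    class MiniParam:
        def __init__(self, data: Any):
            self.data = data
            self.predecessors = set()


    def lift_untrainable(output: Any, func: MiniParam) -> MiniParam:
        # plays the role of the fun_to_grad_component pass-through: wrap the raw
        # output as a parameter and record func as its predecessor
        lifted = MiniParam(output)
        lifted.predecessors.add(func)
        return lifted


    def dispatch(tool_out: Any, func: MiniParam, training: bool):
        if not training:                       # case 1: eval mode, plain output
            return tool_out
        if isinstance(tool_out, MiniParam):    # case 2: trainable tool, already a parameter
            tool_out.predecessors.add(func)
            return tool_out
        return lift_untrainable(tool_out, func)  # case 3: untrainable fn


    func = MiniParam({"name": "search", "kwargs": {}})
    print(dispatch("raw text", func, training=False))                             # case 1
    print(len(dispatch(MiniParam("traced"), func, training=True).predecessors))   # case 2 -> 1
    print(dispatch("raw text", func, training=True).data)                         # case 3 -> raw text
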
diff --git a/adalflow/adalflow/optim/grad_component.py b/adalflow/adalflow/optim/grad_component.py
index e6f5d110..0602f411 100644
--- a/adalflow/adalflow/optim/grad_component.py
+++ b/adalflow/adalflow/optim/grad_component.py
@@ -1,9 +1,10 @@
"""Base class for Autograd Components that can be called and backpropagated through."""
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, Callable, Optional
from collections import OrderedDict
import uuid
import logging
+from copy import deepcopy
if TYPE_CHECKING:
from adalflow.core.generator import BackwardEngine
@@ -13,9 +14,10 @@
from adalflow.core.component import Component
from adalflow.optim.function import BackwardContext
+from adalflow.utils.registry import EntityMapping
-__all__ = ["GradComponent"]
+__all__ = ["GradComponent", "FunGradComponent", "fun_to_grad_component"]
log = logging.getLogger(__name__)
@@ -84,6 +86,9 @@ def forward(self, *args, **kwargs) -> "Parameter":
for idx, arg in enumerate(args):
input_args[f"arg_{idx}"] = arg
+ # Get data id from the kwargs
+ data_id = kwargs.get("id", None)
+
# Add keyword args to the ordered dict, preserving order
predecessors = []
for v in input_args.values():
@@ -91,11 +96,15 @@ def forward(self, *args, **kwargs) -> "Parameter":
predecessors.append(v)
if v.param_type == ParameterType.INPUT:
v.data_id = kwargs.get("id", None)
+ if data_id is None:
+ data_id = v.data_id
for v in kwargs.values():
if isinstance(v, Parameter):
predecessors.append(v)
if v.param_type == ParameterType.INPUT:
v.data_id = kwargs.get("id", None)
+ if data_id is None:
+ data_id = v.data_id
# 2. unwrap the parameter object to take only the data, successor_map_fn: lambda x: x.data in default
# unwrap args
@@ -123,7 +132,9 @@ def forward(self, *args, **kwargs) -> "Parameter":
call_response = self.call(*unwrapped_args, **unwrapped_kwargs)
if isinstance(call_response, Parameter):
- raise ValueError("A GradComponent call should not return Parameter")
+ raise ValueError(
+ f"A GradComponent call should not return Parameter, got {call_response.name}"
+ )
predecessors.append(call_response)
return call_response
@@ -138,20 +149,20 @@ def forward(self, *args, **kwargs) -> "Parameter":
name=self.name + "_output",
role_desc=self.name + " response",
param_type=ParameterType.OUTPUT,
- data_id=kwargs.get("id", None),
+ data_id=data_id,
)
response.set_predecessors(predecessors)
response.trace_forward_pass(
input_args=tracing_args,
full_response=call_response,
- id=self.id,
+ id=self.id, # this is component id
name=self.name,
)
response.set_grad_fn(
BackwardContext(
backward_fn=self.backward,
response=response,
- id=kwargs.get("id", None),
+ id=data_id,
)
)
return response
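
Note (illustrative sketch, not part of this patch): the forward() hunk above resolves data_id by taking kwargs["id"] first and otherwise inheriting the data_id of the first INPUT-type predecessor Parameter. A minimal sketch of that resolution order, using a hypothetical InputParam stand-in:

    # Sketch of the data_id resolution order used in forward() above.
    from dataclasses import dataclass
    from typing import Optional


    @dataclass
    class InputParam:  # stand-in for an INPUT-type Parameter
        data_id: Optional[str] = None


    def resolve_data_id(predecessors, **kwargs) -> Optional[str]:
        data_id = kwargs.get("id", None)           # explicit id wins
        for p in predecessors:
            if data_id is None and p.data_id is not None:
                data_id = p.data_id                # otherwise inherit from an input predecessor
        return data_id


    print(resolve_data_id([InputParam("sample-1")]))               # sample-1
    print(resolve_data_id([InputParam("sample-1")], id="kw-id"))   # kw-id
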
@@ -190,11 +201,13 @@ def backward(self, *, response: "Parameter", id: str = None, **kwargs):
# passing the successor's gradient.data to the current.
for grad in response.gradients:
- # make a copy of the gradient
- # grad = deepcopy(grad)
+            # NOTE: make a copy of the gradient; we should not modify the original gradient
+ grad = deepcopy(grad)
# update the gradient context and from and to
- grad.update_from_to(response, pred)
- grad.is_default_copy = True
+ # grad.update_from_to(response, pred)
+ grad.is_default_copy = (
+ True # response and pred will keep the original gradient
+ )
grad.add_context(
GradientContext(
variable_desc=pred.role_desc,
@@ -204,3 +217,106 @@ def backward(self, *, response: "Parameter", id: str = None, **kwargs):
)
pred.add_gradient(grad)
+
+
+class FunGradComponent(GradComponent):
+ r"""Wraps a function as a GradComponent.
+
+ Args:
+ fun (Callable): The function to be wrapped.
+
+ Examples:
+
+ function = lambda x: x + 1
+    fun_component = FunGradComponent(function)
+ print(fun_component(1)) # 2
+ """
+
+ def __init__(self, fun: Optional[Callable] = None, afun: Optional[Callable] = None):
+ super().__init__()
+ self.fun_name = fun.__name__
+ EntityMapping.register(self.fun_name, fun)
+
+ def call(self, *args, **kwargs):
+ fun = EntityMapping.get(self.fun_name)
+ return fun(*args, **kwargs)
+
+ def _extra_repr(self) -> str:
+ return super()._extra_repr() + f"fun_name={self.fun_name}"
+
+
+def fun_to_grad_component(fun) -> FunGradComponent:
+    r"""Helper function to convert a function into a GradComponent with
+    its own class name.
+
+ Can be used as both a decorator and a function.
+
+ Args:
+ fun (Callable): The function to be wrapped.
+ Returns:
+        FunGradComponent: The component that wraps the function.
+
+ Examples:
+ 1. As a decorator:
+    >>> @fun_to_grad_component
+ >>> def my_function(x):
+ >>> return x + 1
+ >>> # is equivalent to
+    >>> class MyFunctionGradComponent(FunGradComponent):
+ >>> def __init__(self):
+ >>> super().__init__(my_function)
+
+ 2. As a function:
+    >>> my_function_component = fun_to_grad_component(my_function)
+ """
+
+ # Split the function name by underscores, capitalize each part, and join them back together
+ class_name = (
+ "".join(part.capitalize() for part in fun.__name__.split("_")) + "GradComponent"
+ )
+ # register the function
+ EntityMapping.register(fun.__name__, fun)
+ # Define a new component class dynamically
+ component_class = type(
+ class_name,
+ (FunGradComponent,),
+ {"__init__": lambda self: FunGradComponent.__init__(self, fun)},
+ )
+ # register the component
+ EntityMapping.register(class_name, component_class)
+
+ return component_class()
+
+
+if __name__ == "__main__":
+ # Test FunGradComponent
+ from adalflow.optim.parameter import Parameter
+
+ def my_function(x):
+ return x + 1
+
+ my_function_component = fun_to_grad_component(my_function)
+    print(my_function_component)  # prints the component repr
+ # eval mode
+ output = my_function_component(1)
+ print(output)
+ # training mode
+ my_function_component.train()
+ output = my_function_component(Parameter(data=1, name="input"))
+ print(output)
+
+ # now test the decorator
+ @fun_to_grad_component
+ def my_function(x):
+ return x + 1
+
+ print(my_function(1))
+ # eval mode
+ output = my_function(1)
+ print(output)
+ assert output == 2
+
+ # training mode
+ my_function.train()
+ output = my_function(Parameter(data=1, name="input"))
+ print(output)
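
Note (illustrative sketch, not part of this patch): fun_to_grad_component builds a named subclass dynamically with type() and registers both the function and the generated class. The same pattern in isolation, with a toy registry and base class instead of adalflow's EntityMapping and GradComponent:

    # Toy version of the dynamic-class pattern used by fun_to_grad_component.
    from typing import Callable, Dict

    _REGISTRY: Dict[str, object] = {}


    class Base:
        def __init__(self, fun: Callable):
            self.fun_name = fun.__name__
            _REGISTRY[self.fun_name] = fun

        def __call__(self, *args, **kwargs):
            return _REGISTRY[self.fun_name](*args, **kwargs)


    def toy_fun_to_component(fun: Callable) -> Base:
        # e.g. my_function -> MyFunctionComponent
        class_name = "".join(p.capitalize() for p in fun.__name__.split("_")) + "Component"
        cls = type(class_name, (Base,), {"__init__": lambda self: Base.__init__(self, fun)})
        _REGISTRY[class_name] = cls
        return cls()  # as a decorator, the function becomes a component instance


    @toy_fun_to_component
    def add_one(x):
        return x + 1


    print(type(add_one).__name__)  # AddOneComponent
    print(add_one(1))              # 2
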
diff --git a/adalflow/adalflow/optim/parameter.py b/adalflow/adalflow/optim/parameter.py
index 29e59b93..161772c2 100644
--- a/adalflow/adalflow/optim/parameter.py
+++ b/adalflow/adalflow/optim/parameter.py
@@ -787,6 +787,8 @@ def generate_node_html(node: "Parameter", output_dir="node_pages"):
) # Use to_json_obj for proper JSON object structure
data_json = None
+    node_data_type = str(type(node.data)).replace("<", "&lt;").replace(">", "&gt;")
+ printc(f"Node data type: {node_data_type}")
if isinstance(node.data, dict):
data_json = data_json
elif isinstance(node.data, DataClass):
@@ -824,7 +826,7 @@ def generate_node_html(node: "Parameter", output_dir="node_pages"):
Details for Node: {node.name}
ID: {node.id}
Role: {node.role_desc}
- DataType: {str(type(node.data))}
+ DataType: {node_data_type}
Data: \n{json.dumps(data_json, indent=4)}
Data ID: {node.data_id}
Previous Value: {node.previous_data}
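
Note (suggestion, not part of this patch): the manual replace chain above escapes only < and >; the standard-library html.escape also handles & and quotes, e.g.:

    # Possible alternative to the manual .replace() chain (suggestion only).
    import html

    node_data = {"answer": "42"}
    node_data_type = html.escape(str(type(node_data)))
    print(node_data_type)  # &lt;class &#x27;dict&#x27;&gt;
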
diff --git a/adalflow/adalflow/optim/trainer/trainer.py b/adalflow/adalflow/optim/trainer/trainer.py
index bd3e42c0..bf35a282 100644
--- a/adalflow/adalflow/optim/trainer/trainer.py
+++ b/adalflow/adalflow/optim/trainer/trainer.py
@@ -407,7 +407,7 @@ def fit(
train_loader = DataLoader(
train_dataset,
batch_size=batch_size,
- shuffle=True, # if not debug else False,
+            shuffle=not debug,
)
val_dataset = val_dataset or self.val_dataset
test_dataset = test_dataset or self.test_dataset
diff --git a/benchmarks/hotpot_qa/adal_exp/train_agent_rag.py b/benchmarks/hotpot_qa/adal_exp/train_agent_rag.py
index 4a8d0019..f8afbfa6 100644
--- a/benchmarks/hotpot_qa/adal_exp/train_agent_rag.py
+++ b/benchmarks/hotpot_qa/adal_exp/train_agent_rag.py
@@ -161,10 +161,12 @@ def train(
# )
train(
- debug=True,
- max_steps=12,
+ debug=False,
+ max_steps=8,
+ resume_from_ckpt="/Users/liyin/.adalflow/ckpt/AgenticRAGAdal/constrained_max_steps_4_dca7e_run_1.json",
)
# 0.68 on val without training, 0.74 on the second step, 0.84 on test
# /Users/liyin/.adalflow/ckpt/AgenticRAGAdal/constrained_max_steps_2_029cb_run_1.json
# 0.7, 0.72 /Users/liyin/.adalflow/ckpt/AgenticRAGAdal/constrained_max_steps_2_b7523_run_1.json
# 208.085706949234s, 2 steps, maximum 4 steps allow for an agent.
+ # 0.72->0.74, 4 steps, 366s, /Users/liyin/.adalflow/ckpt/AgenticRAGAdal/constrained_max_steps_4_dca7e_run_1.json [Already faster, still lots to optimize]