Skip to content

Commit

Permalink
Merge pull request #45 from AutoResearch/feat-declare-input-in-wrapper
Browse files Browse the repository at this point in the history
Feat declare input in wrapper
  • Loading branch information
younesStrittmatter authored Sep 6, 2023
2 parents 743cb4a + f4c8ea2 commit c8ed1af
Showing 1 changed file with 65 additions and 3 deletions.
68 changes: 65 additions & 3 deletions src/autora/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -722,7 +722,7 @@ def _append(a: List[T], b: T) -> List[T]:
return a + [b]


def inputs_from_state(f):
def inputs_from_state(f, input_mapping: Dict = {}):
"""Decorator to make target `f` into a function on a `State` and `**kwargs`.
This wrapper makes it easier to pass arguments to a function from a State.
Expand All @@ -732,6 +732,7 @@ def inputs_from_state(f):
Args:
f: a function with arguments that could be fields on a `State`
and that returns a `Delta`.
input_mapping: a dict that maps the input arguments of the function to the state fields
Returns: a version of `f` which takes and returns `State` objects.
Expand All @@ -758,6 +759,22 @@ def inputs_from_state(f):
>>> experimentalist(U(conditions=[101,102,103,104]))
[111, 112, 113, 114]
If our function uses a different keyword argument than the state field, we can use
the input mapping:
>>> def experimentalist_(X):
... new_conditions = [x + 10 for x in X]
... return new_conditions
>>> experimentalist_on_state = inputs_from_state(experimentalist_, {'X': 'conditions'})
>>> experimentalist_on_state(U(conditions=[1,2,3,4]))
[11, 12, 13, 14]
Both also work with the `State` as UserDict. Here, we use the StandardState
>>> experimentalist(StandardState(conditions=[1, 2, 3, 4]))
[11, 12, 13, 14]
>>> experimentalist_on_state(StandardState(conditions=[1, 2, 3, 4]))
[11, 12, 13, 14]
A dictionary can be returned and used:
>>> @inputs_from_state
... def returns_a_dictionary(conditions):
Expand Down Expand Up @@ -831,6 +848,14 @@ def inputs_from_state(f):
>>> experimentalist(U(conditions=[1,2,3,4]), offset=2)
[3, 4, 5, 6]
The same is true, if we don't provide a mapping for arguments:
>>> def experimentalist_(X, offset):
... new_conditions = [x + offset for x in X]
... return new_conditions
>>> experimentalist_on_state = inputs_from_state(experimentalist_, {'X': 'conditions'})
>>> experimentalist_on_state(StandardState(conditions=[1,2,3,4]), offset=2)
[3, 4, 5, 6]
The state itself is passed through if the inner function requests the `state`:
>>> @inputs_from_state
... def function_which_needs_whole_state(state, conditions):
Expand All @@ -843,7 +868,16 @@ def inputs_from_state(f):
"""
# Get the set of parameter names from function f's signature

reversed_mapping = {v: k for k, v in input_mapping.items()}

parameters_ = set(inspect.signature(f).parameters.keys())
missing_func_params = set(input_mapping.keys()).difference(parameters_)
if missing_func_params:
raise ValueError(
f"The following keys in input_state_mapping are not parameters of the function: "
f"{missing_func_params}"
)

@wraps(f)
def _f(state_: S, /, **kwargs) -> S:
Expand All @@ -853,9 +887,21 @@ def _f(state_: S, /, **kwargs) -> S:
if is_dataclass(state_):
from_state = parameters_.intersection({i.name for i in fields(state_)})
arguments_from_state = {k: getattr(state_, k) for k in from_state}
from_state_input_mapping = {
reversed_mapping.get(f.name, f.name): getattr(state_, f.name)
for f in fields(state_)
if reversed_mapping.get(f.name, f.name) in parameters_
}
arguments_from_state.update(from_state_input_mapping)
elif isinstance(state_, UserDict):
from_state = parameters_.intersection(set(state_.keys()))
arguments_from_state = {k: state_[k] for k in from_state}
from_state_input_mapping = {
reversed_mapping.get(key, key): state_[key]
for key in state_.keys()
if reversed_mapping.get(key, key) in parameters_
}
arguments_from_state.update(from_state_input_mapping)
if "state" in parameters_:
arguments_from_state["state"] = state_
arguments = dict(arguments_from_state, **kwargs)
Expand Down Expand Up @@ -1134,7 +1180,9 @@ def _f(state_: S, **kwargs) -> S:


def on_state(
function: Optional[Callable] = None, output: Optional[Sequence[str]] = None
function: Optional[Callable] = None,
input_mapping: Dict = {},
output: Optional[Sequence[str]] = None,
):
"""Decorator (factory) to make target `function` into a function on a `State` and `**kwargs`.
Expand All @@ -1143,6 +1191,7 @@ def on_state(
Args:
function: the function to be wrapped
output: list specifying State field names for the return values of `function`
input_mapping: a dict that maps the keywords of the functions to the state fields
Returns:
Expand Down Expand Up @@ -1193,13 +1242,26 @@ def on_state(
>>> add_six(W(conditions=[1, 2, 3, 4]))
W(conditions=[7, 8, 9, 10])
You can also declare an input-to-output mapping if the keyword arguments of the functions
don't match the state fields:
>>> @on_state(input_mapping={'X': 'conditions'}, output=["conditions"])
... def add_six(X):
... return [x + 6 for x in X]
>>> add_six(W(conditions=[1, 2, 3, 4]))
W(conditions=[7, 8, 9, 10])
This also works on the StandardState or other States that are defined as UserDicts:
>>> add_six(StandardState(conditions=[1, 2, 3,4])).conditions
[7, 8, 9, 10]
"""

def decorator(f):
f_ = f
if output is not None:
f_ = outputs_to_delta(*output)(f_)
f_ = inputs_from_state(f_)
f_ = inputs_from_state(f_, input_mapping)
f_ = delta_to_state(f_)
return f_

Expand Down

0 comments on commit c8ed1af

Please sign in to comment.