Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feat declare input in wrapper #45

Merged
merged 6 commits into from
Sep 6, 2023
Merged
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
68 changes: 65 additions & 3 deletions src/autora/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -722,7 +722,7 @@ def _append(a: List[T], b: T) -> List[T]:
return a + [b]


def inputs_from_state(f):
def inputs_from_state(f, input_mapping: Dict = {}):
"""Decorator to make target `f` into a function on a `State` and `**kwargs`.

This wrapper makes it easier to pass arguments to a function from a State.
Expand All @@ -732,6 +732,7 @@ def inputs_from_state(f):
Args:
f: a function with arguments that could be fields on a `State`
and that returns a `Delta`.
input_mapping: a dict that maps the input arguments of the function to the state fields

Returns: a version of `f` which takes and returns `State` objects.

Expand All @@ -758,6 +759,22 @@ def inputs_from_state(f):
>>> experimentalist(U(conditions=[101,102,103,104]))
[111, 112, 113, 114]

If our function uses a different keyword argument then the state field, we can use
younesStrittmatter marked this conversation as resolved.
Show resolved Hide resolved
the input mapping:
>>> def experimentalist_(X):
... new_conditions = [x + 10 for x in X]
... return new_conditions
>>> experimentalist_on_state = inputs_from_state(experimentalist_, {'X': 'conditions'})
>>> experimentalist(U(conditions=[1,2,3,4]))
benwandrew marked this conversation as resolved.
Show resolved Hide resolved
[11, 12, 13, 14]

Both also work with the `State` as UserDict. Here, we use the StandardState
>>> experimentalist(StandardState(conditions=[1, 2, 3, 4]))
[11, 12, 13, 14]

>>> experimentalist_on_state(StandardState(conditions=[1, 2, 3, 4]))
[11, 12, 13, 14]

A dictionary can be returned and used:
>>> @inputs_from_state
... def returns_a_dictionary(conditions):
Expand Down Expand Up @@ -831,6 +848,14 @@ def inputs_from_state(f):
>>> experimentalist(U(conditions=[1,2,3,4]), offset=2)
[3, 4, 5, 6]

The same is true, if we don't provide a mapping for arguments:
>>> def experimentalist_(X, offset):
... new_conditions = [x + offset for x in X]
... return new_conditions
>>> experimentalist_on_state = inputs_from_state(experimentalist_, {'X': 'conditions'})
>>> experimentalist_on_state(StandardState(conditions=[1,2,3,4]), offset=2)
[3, 4, 5, 6]

The state itself is passed through if the inner function requests the `state`:
>>> @inputs_from_state
... def function_which_needs_whole_state(state, conditions):
Expand All @@ -843,7 +868,16 @@ def inputs_from_state(f):

"""
# Get the set of parameter names from function f's signature

reversed_mapping = {v: k for k, v in input_mapping.items()}

parameters_ = set(inspect.signature(f).parameters.keys())
missing_func_params = set(input_mapping.keys()).difference(parameters_)
if missing_func_params:
raise ValueError(
f"The following keys in input_state_mapping are not parameters of the function: "
f"{missing_func_params}"
)

@wraps(f)
def _f(state_: S, /, **kwargs) -> S:
Expand All @@ -853,9 +887,21 @@ def _f(state_: S, /, **kwargs) -> S:
if is_dataclass(state_):
from_state = parameters_.intersection({i.name for i in fields(state_)})
arguments_from_state = {k: getattr(state_, k) for k in from_state}
from_state_input_mapping = {
reversed_mapping.get(f.name, f.name): getattr(state_, f.name)
for f in fields(state_)
if reversed_mapping.get(f.name, f.name) in parameters_
}
arguments_from_state.update(from_state_input_mapping)
elif isinstance(state_, UserDict):
from_state = parameters_.intersection(set(state_.keys()))
arguments_from_state = {k: state_[k] for k in from_state}
from_state_input_mapping = {
reversed_mapping.get(key, key): state_[key]
for key in state_.keys()
if reversed_mapping.get(key, key) in parameters_
}
arguments_from_state.update(from_state_input_mapping)
if "state" in parameters_:
arguments_from_state["state"] = state_
arguments = dict(arguments_from_state, **kwargs)
Expand Down Expand Up @@ -1134,7 +1180,9 @@ def _f(state_: S, **kwargs) -> S:


def on_state(
function: Optional[Callable] = None, output: Optional[Sequence[str]] = None
function: Optional[Callable] = None,
input_mapping: Dict = {},
output: Optional[Sequence[str]] = None,
):
"""Decorator (factory) to make target `function` into a function on a `State` and `**kwargs`.

Expand All @@ -1143,6 +1191,7 @@ def on_state(
Args:
function: the function to be wrapped
output: list specifying State field names for the return values of `function`
input_mapping: a dict that maps the keywords of the functions to the state fields

Returns:

Expand Down Expand Up @@ -1193,13 +1242,26 @@ def on_state(
>>> add_six(W(conditions=[1, 2, 3, 4]))
W(conditions=[7, 8, 9, 10])

You can also declare input to output mopping, if the keyword arguments of the functions
younesStrittmatter marked this conversation as resolved.
Show resolved Hide resolved
don't match the state fields:
>>> @on_state(input_mapping={'X': 'conditions'}, output=["conditions"])
... def add_six(X):
... return [x + 6 for x in X]

>>> add_six(W(conditions=[1, 2, 3, 4]))
W(conditions=[7, 8, 9, 10])

This also works on the StandardState or other States that are defined as UserDicts:
>>> add_six(StandardState(conditions=[1, 2, 3,4])).conditions
[7, 8, 9, 10]

"""

def decorator(f):
f_ = f
if output is not None:
f_ = outputs_to_delta(*output)(f_)
f_ = inputs_from_state(f_)
f_ = inputs_from_state(f_, input_mapping)
f_ = delta_to_state(f_)
return f_

Expand Down