Skip to content

Commit

Permalink
Merge pull request #87 from AutoResearch/83-add-alias-for-x-and-y-in-…
Browse files Browse the repository at this point in the history
…standardstate-as-pandas-dataframe

feat: make wrappers work on added properties
  • Loading branch information
younesStrittmatter authored Aug 10, 2024
2 parents f107389 + 082180f commit f565ae3
Showing 1 changed file with 70 additions and 7 deletions.
77 changes: 70 additions & 7 deletions src/autora/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,6 @@

_logger = logging.getLogger(__name__)


T = TypeVar("T")
C = TypeVar("C", covariant=True)

Expand Down Expand Up @@ -356,6 +355,38 @@ def _get_value(f, other: Union[Delta, Mapping]):
return value, used_key


def _get_field_names_and_properties(s: State):
"""
Get a list of field names and their aliases from a State object
Args:
s: a State object
Returns: a list of field names and their aliases on `s`
Examples:
>>> from dataclasses import field
>>> @dataclass(frozen=True)
... class SomeState(State):
... l: List = field(default_factory=list)
... m: List = field(default_factory=list)
... @property
... def both(self):
... return self.l + self.m
>>> _get_field_names_and_properties(SomeState())
['both', 'l', 'm']
"""
result = _get_field_names_and_aliases(s)
property_names = [
attr
for attr in dir(s)
if isinstance(getattr(type(s), attr, None), property)
and attr not in dir(object)
and attr not in result
]
return property_names + result


def _get_field_names_and_aliases(s: State):
"""
Get a list of field names and their aliases from a State object
Expand Down Expand Up @@ -692,19 +723,21 @@ def inputs_from_state(f, input_mapping: Dict = {}):
)

@wraps(f)
def _f(state_: S, /, **kwargs) -> S:
def _f(state_: State, /, **kwargs) -> State:
# Get the parameters needed which are available from the state_.
# All others must be provided as kwargs or default values on f.
assert is_dataclass(state_) or isinstance(state_, UserDict)
if is_dataclass(state_):
from_state = parameters_.intersection({i.name for i in fields(state_)})
from_state = parameters_.intersection(
_get_field_names_and_properties(state_)
)
arguments_from_state = {k: getattr(state_, k) for k in from_state}
from_state_input_mapping = {
reversed_mapping.get(field.name, field.name): getattr(
state_, field.name
reversed_mapping.get(field_name, field_name): getattr(
state_, field_name
)
for field in fields(state_)
if reversed_mapping.get(field.name, field.name) in parameters_
for field_name in _get_field_names_and_properties(state_)
if reversed_mapping.get(field_name, field_name) in parameters_
}
arguments_from_state.update(from_state_input_mapping)
elif isinstance(state_, UserDict):
Expand Down Expand Up @@ -1219,6 +1252,36 @@ class StandardState(State):
>>> (s + dm1 + dm2).models
[DummyClassifier(constant=1), DummyClassifier(constant=2), DummyClassifier(constant=3)]
We can use properties X, y, iv_names and dv_names as 'getters' ...
>>> x_v = Variable('x')
>>> y_v = Variable('y')
>>> variables = VariableCollection(independent_variables=[x_v], dependent_variables=[y_v])
>>> e_data = pd.DataFrame({'x': [1, 2, 3], 'y': [2, 4, 6]})
>>> s = StandardState(variables=variables, experiment_data=e_data)
>>> @inputs_from_state
... def show_X(X):
... return X
>>> show_X(s)
x
0 1
1 2
2 3
... but nothing happens if we use them as `setters`:
>>> @on_state
... def add_to_X(X):
... res = X.copy()
... res['x'] += 1
... return Delta(X=res)
>>> s = add_to_X(s)
>>> s.X
x
0 1
1 2
2 3
"""

variables: Optional[VariableCollection] = field(
Expand Down

0 comments on commit f565ae3

Please sign in to comment.