Skip to content

Commit

Permalink
Merge pull request #63 from AutoResearch/chore/remove-aliases-from-state
Browse files Browse the repository at this point in the history
refactor!: remove aliases from state
  • Loading branch information
hollandjg authored Dec 1, 2023
2 parents c5b12a7 + b5ab07c commit 66d03e8
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 187 deletions.
190 changes: 6 additions & 184 deletions src/autora/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,47 +212,6 @@ class State:
We can define aliases which can transform between different potential field
names.
>>> @dataclass(frozen=True)
... class FieldAliasState(State):
... things: List[str] = field(
... default_factory=list,
... metadata={"delta": "extend",
... "aliases": {"thing": lambda m: [m]}}
... )
In the "normal" case, the Delta object is expected to include a list of data in the
correct format which is used to extend the object:
>>> FieldAliasState(things=["0"]) + Delta(things=["1", "2"])
FieldAliasState(things=['0', '1', '2'])
However, say the standard return from a step in AER is a single `thing`, rather than a
sequence of them:
>>> FieldAliasState(things=["0"]) + Delta(thing="1")
FieldAliasState(things=['0', '1'])
If a cycle function relies on the existence of `s.thing` as a property of your state
`s`, rather than accessing `s.things[-1]`, then you could additionally define a `property`:
>>> class FieldAliasStateWithProperty(FieldAliasState): # inherit from FieldAliasState
... @property
... def thing(self):
... return self.things[-1]
Now you can access both `s.things` and `s.thing` as required by your code. The State only
shows `things` in the string representation...
>>> u = FieldAliasStateWithProperty(things=["0"]) + Delta(thing="1")
>>> u
FieldAliasStateWithProperty(things=['0', '1'])
... and exposes `things` as an attribute:
>>> u.things
['0', '1']
... but also exposes `thing`, always returning the last value.
>>> u.thing
'1'
"""

def __add__(self, other: Union[Delta, Mapping]):
Expand Down Expand Up @@ -322,12 +281,7 @@ def _get_value(f, other: Union[Delta, Mapping]):
>>> from dataclasses import field, dataclass, fields
>>> @dataclass
... class Example:
... a: int = field() # base case
... b: List[int] = field(metadata={"aliases": {"ba": lambda b: [b]}}) # Single alias
... c: List[int] = field(metadata={"aliases": {
... "ca": lambda x: x, # pass the value unchanged
... "cb": lambda x: [x] # wrap the value in a list
... }}) # Multiple alias
... a: int = field()
For a field with no aliases, we retrieve values with the base name:
>>> f_a = fields(Example)[0]
Expand All @@ -342,104 +296,15 @@ def _get_value(f, other: Union[Delta, Mapping]):
>>> _get_value(f_a, Delta(b=2, a=1))
(1, 'a')
For fields with an alias, we retrieve values with the base name:
>>> f_b = fields(Example)[1]
>>> _get_value(f_b, Delta(b=[2]))
([2], 'b')
... or with the alias name, transformed by the alias lambda function:
>>> _get_value(f_b, Delta(ba=21))
([21], 'ba')
We preferentially get the base name, and then any aliases:
>>> _get_value(f_b, Delta(b=2, ba=21))
(2, 'b')
..., regardless of their order in the `Delta` object:
>>> _get_value(f_b, Delta(ba=21, b=2))
(2, 'b')
Other names are ignored:
>>> _get_value(f_b, Delta(a=1))
(None, None)
and the order of other names is unimportant:
>>> _get_value(f_b, Delta(a=1, b=2))
(2, 'b')
For fields with multiple aliases, we retrieve values with the base name:
>>> f_c = fields(Example)[2]
>>> _get_value(f_c, Delta(c=[3]))
([3], 'c')
... for any alias:
>>> _get_value(f_c, Delta(ca=31))
(31, 'ca')
... transformed by the alias lambda function:
>>> _get_value(f_c, Delta(cb=32))
([32], 'cb')
... and ignoring any other names:
>>> print(_get_value(f_c, Delta(a=1)))
(None, None)
... preferentially in the order base name, 1st alias, 2nd alias, ... nth alias:
>>> _get_value(f_c, Delta(c=3, ca=31, cb=32))
(3, 'c')
>>> _get_value(f_c, Delta(ca=31, cb=32))
(31, 'ca')
>>> _get_value(f_c, Delta(cb=32))
([32], 'cb')
>>> print(_get_value(f_c, Delta()))
(None, None)
This works with dict objects:
>>> _get_value(f_a, dict(a=13))
(13, 'a')
... with multiple keys:
>>> _get_value(f_b, dict(a=13, b=24, c=35))
(24, 'b')
... and with aliases:
>>> _get_value(f_b, dict(ba=222))
([222], 'ba')
This works with UserDicts:
>>> class MyDelta(UserDict):
... pass
>>> _get_value(f_a, MyDelta(a=14))
(14, 'a')
... with multiple keys:
>>> _get_value(f_b, MyDelta(a=1, b=4, c=9))
(4, 'b')
... and with aliases:
>>> _get_value(f_b, MyDelta(ba=234))
([234], 'ba')
"""

key = f.name
aliases = f.metadata.get("aliases", {})

value, used_key = None, None

if key in other.keys():
value = other[key]
used_key = key
elif aliases: # ... is not an empty dict
for alias_key, wrapping_function in aliases.items():
if alias_key in other:
value = wrapping_function(other[alias_key])
used_key = alias_key
break # we only evaluate the first match

return value, used_key

Expand All @@ -462,23 +327,8 @@ def _get_field_names_and_aliases(s: State):
>>> _get_field_names_and_aliases(SomeState())
['l', 'm']
>>> @dataclass(frozen=True)
... class SomeStateWithAliases(State):
... l: List = field(default_factory=list, metadata={"aliases": {"l1": None, "l2": None}})
... m: List = field(default_factory=list, metadata={"aliases": {"m1": None}})
>>> _get_field_names_and_aliases(SomeStateWithAliases())
['l', 'l1', 'l2', 'm', 'm1']
"""
result = []

for f in fields(s):
name = f.name
result.append(name)

aliases = f.metadata.get("aliases", {})
result.extend(aliases)

result = [f.name for f in fields(s)]
return result


Expand Down Expand Up @@ -1322,27 +1172,6 @@ class StandardState(State):
>>> (s + dm1 + dm2).models
[DummyClassifier(constant=1), DummyClassifier(constant=2), DummyClassifier(constant=3)]
The last model is available under the `model` property:
>>> (s + dm1 + dm2).model
DummyClassifier(constant=3)
If there is no model, `None` is returned:
>>> print(s.model)
None
`models` can also be updated using a Delta with a single `model`:
>>> dm3 = Delta(model=DummyClassifier(constant=4))
>>> (s + dm1 + dm3).model
DummyClassifier(constant=4)
As before, the `models` list is extended:
>>> (s + dm1 + dm3).models
[DummyClassifier(constant=1), DummyClassifier(constant=4)]
No coercion or validation occurs with `models` or `model`:
>>> (s + dm1 + Delta(model="not a model")).models
[DummyClassifier(constant=1), 'not a model']
"""

variables: Optional[VariableCollection] = field(
Expand All @@ -1356,17 +1185,9 @@ class StandardState(State):
)
models: List[BaseEstimator] = field(
default_factory=list,
metadata={"delta": "extend", "aliases": {"model": lambda model: [model]}},
metadata={"delta": "extend"},
)

@property
def model(self):
"""Alias for the last model in the `models`."""
try:
return self.models[-1]
except IndexError:
return None


X = TypeVar("X")
Y = TypeVar("Y")
Expand Down Expand Up @@ -1396,8 +1217,9 @@ def estimator_on_state(estimator: BaseEstimator) -> StateFunction:
... experiment_data=pd.DataFrame({"x": [1,2,3], "y":[3,6,9]})
... )
Run the function, which fits the model and adds the result to the `StandardState`
>>> state_fn(s).model.coef_
Run the function, which fits the model and adds the result to the `StandardState` as the
last entry in the .models list.
>>> state_fn(s).models[-1].coef_
array([[3.]])
"""
Expand Down
6 changes: 3 additions & 3 deletions tests/test_workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,9 @@ def validate_model(state: Optional[StandardState]):
assert state.experiment_data is not None
assert len(state.experiment_data) == 100

assert state.model is not None
assert np.allclose(state.model.coef_, [[2.0]])
assert np.allclose(state.model.intercept_, [[0.5]])
assert state.models[-1] is not None
assert np.allclose(state.models[-1].coef_, [[2.0]])
assert np.allclose(state.models[-1].intercept_, [[0.5]])


@given(example_workflow_library_module)
Expand Down

0 comments on commit 66d03e8

Please sign in to comment.