diff --git a/src/autora/state.py b/src/autora/state.py index a267c436..99740384 100644 --- a/src/autora/state.py +++ b/src/autora/state.py @@ -212,47 +212,6 @@ class State: We can define aliases which can transform between different potential field names. - >>> @dataclass(frozen=True) - ... class FieldAliasState(State): - ... things: List[str] = field( - ... default_factory=list, - ... metadata={"delta": "extend", - ... "aliases": {"thing": lambda m: [m]}} - ... ) - - In the "normal" case, the Delta object is expected to include a list of data in the - correct format which is used to extend the object: - >>> FieldAliasState(things=["0"]) + Delta(things=["1", "2"]) - FieldAliasState(things=['0', '1', '2']) - - However, say the standard return from a step in AER is a single `thing`, rather than a - sequence of them: - >>> FieldAliasState(things=["0"]) + Delta(thing="1") - FieldAliasState(things=['0', '1']) - - - If a cycle function relies on the existence of the `s.thing` as a property of your state - `s`, rather than accessing `s.things[-1]`, then you could additionally define a `property`: - - >>> class FieldAliasStateWithProperty(FieldAliasState): # inherit from FieldAliasState - ... @property - ... def thing(self): - ... return self.things[-1] - - Now you can access both `s.things` and `s.thing` as required by your code. The State only - shows `things` in the string representation... - >>> u = FieldAliasStateWithProperty(things=["0"]) + Delta(thing="1") - >>> u - FieldAliasStateWithProperty(things=['0', '1']) - - ... and exposes `things` as an attribute: - >>> u.things - ['0', '1'] - - ... but also exposes `thing`, always returning the last value. - >>> u.thing - '1' - """ def __add__(self, other: Union[Delta, Mapping]): @@ -322,12 +281,7 @@ def _get_value(f, other: Union[Delta, Mapping]): >>> from dataclasses import field, dataclass, fields >>> @dataclass ... class Example: - ... a: int = field() # base case - ... b: List[int] = field(metadata={"aliases": {"ba": lambda b: [b]}}) # Single alias - ... c: List[int] = field(metadata={"aliases": { - ... "ca": lambda x: x, # pass the value unchanged - ... "cb": lambda x: [x] # wrap the value in a list - ... }}) # Multiple alias + ... a: int = field() For a field with no aliases, we retrieve values with the base name: >>> f_a = fields(Example)[0] @@ -342,104 +296,15 @@ def _get_value(f, other: Union[Delta, Mapping]): >>> _get_value(f_a, Delta(b=2, a=1)) (1, 'a') - For fields with an alias, we retrieve values with the base name: - >>> f_b = fields(Example)[1] - >>> _get_value(f_b, Delta(b=[2])) - ([2], 'b') - - ... or for the alias name, transformed by the alias lambda function: - >>> _get_value(f_b, Delta(ba=21)) - ([21], 'ba') - - We preferentially get the base name, and then any aliases: - >>> _get_value(f_b, Delta(b=2, ba=21)) - (2, 'b') - - ... , regardless of their order in the `Delta` object: - >>> _get_value(f_b, Delta(ba=21, b=2)) - (2, 'b') - - Other names are ignored: - >>> _get_value(f_b, Delta(a=1)) - (None, None) - - and the order of other names is unimportant: - >>> _get_value(f_b, Delta(a=1, b=2)) - (2, 'b') - - For fields with multiple aliases, we retrieve values with the base name: - >>> f_c = fields(Example)[2] - >>> _get_value(f_c, Delta(c=[3])) - ([3], 'c') - - ... for any alias: - >>> _get_value(f_c, Delta(ca=31)) - (31, 'ca') - - ... transformed by the alias lambda function : - >>> _get_value(f_c, Delta(cb=32)) - ([32], 'cb') - - ... and ignoring any other names: - >>> print(_get_value(f_c, Delta(a=1))) - (None, None) - - ... preferentially in the order base name, 1st alias, 2nd alias, ... nth alias: - >>> _get_value(f_c, Delta(c=3, ca=31, cb=32)) - (3, 'c') - - >>> _get_value(f_c, Delta(ca=31, cb=32)) - (31, 'ca') - - >>> _get_value(f_c, Delta(cb=32)) - ([32], 'cb') - - >>> print(_get_value(f_c, Delta())) - (None, None) - - This works with dict objects: - >>> _get_value(f_a, dict(a=13)) - (13, 'a') - - ... with multiple keys: - >>> _get_value(f_b, dict(a=13, b=24, c=35)) - (24, 'b') - - ... and with aliases: - >>> _get_value(f_b, dict(ba=222)) - ([222], 'ba') - - This works with UserDicts: - >>> class MyDelta(UserDict): - ... pass - - >>> _get_value(f_a, MyDelta(a=14)) - (14, 'a') - - ... with multiple keys: - >>> _get_value(f_b, MyDelta(a=1, b=4, c=9)) - (4, 'b') - - ... and with aliases: - >>> _get_value(f_b, MyDelta(ba=234)) - ([234], 'ba') - """ key = f.name - aliases = f.metadata.get("aliases", {}) value, used_key = None, None if key in other.keys(): value = other[key] used_key = key - elif aliases: # ... is not an empty dict - for alias_key, wrapping_function in aliases.items(): - if alias_key in other: - value = wrapping_function(other[alias_key]) - used_key = alias_key - break # we only evaluate the first match return value, used_key @@ -462,23 +327,8 @@ def _get_field_names_and_aliases(s: State): >>> _get_field_names_and_aliases(SomeState()) ['l', 'm'] - >>> @dataclass(frozen=True) - ... class SomeStateWithAliases(State): - ... l: List = field(default_factory=list, metadata={"aliases": {"l1": None, "l2": None}}) - ... m: List = field(default_factory=list, metadata={"aliases": {"m1": None}}) - >>> _get_field_names_and_aliases(SomeStateWithAliases()) - ['l', 'l1', 'l2', 'm', 'm1'] - """ - result = [] - - for f in fields(s): - name = f.name - result.append(name) - - aliases = f.metadata.get("aliases", {}) - result.extend(aliases) - + result = [f.name for f in fields(s)] return result @@ -1322,27 +1172,6 @@ class StandardState(State): >>> (s + dm1 + dm2).models [DummyClassifier(constant=1), DummyClassifier(constant=2), DummyClassifier(constant=3)] - The last model is available under the `model` property: - >>> (s + dm1 + dm2).model - DummyClassifier(constant=3) - - If there is no model, `None` is returned: - >>> print(s.model) - None - - `models` can also be updated using a Delta with a single `model`: - >>> dm3 = Delta(model=DummyClassifier(constant=4)) - >>> (s + dm1 + dm3).model - DummyClassifier(constant=4) - - As before, the `models` list is extended: - >>> (s + dm1 + dm3).models - [DummyClassifier(constant=1), DummyClassifier(constant=4)] - - No coercion or validation occurs with `models` or `model`: - >>> (s + dm1 + Delta(model="not a model")).models - [DummyClassifier(constant=1), 'not a model'] - """ variables: Optional[VariableCollection] = field( @@ -1356,17 +1185,9 @@ class StandardState(State): ) models: List[BaseEstimator] = field( default_factory=list, - metadata={"delta": "extend", "aliases": {"model": lambda model: [model]}}, + metadata={"delta": "extend"}, ) - @property - def model(self): - """Alias for the last model in the `models`.""" - try: - return self.models[-1] - except IndexError: - return None - X = TypeVar("X") Y = TypeVar("Y") @@ -1396,8 +1217,9 @@ def estimator_on_state(estimator: BaseEstimator) -> StateFunction: ... experiment_data=pd.DataFrame({"x": [1,2,3], "y":[3,6,9]}) ... ) - Run the function, which fits the model and adds the result to the `StandardState` - >>> state_fn(s).model.coef_ + Run the function, which fits the model and adds the result to the `StandardState` as the + last entry in the .models list. + >>> state_fn(s).models[-1].coef_ array([[3.]]) """ diff --git a/tests/test_workflow.py b/tests/test_workflow.py index a1417c7b..ba922639 100644 --- a/tests/test_workflow.py +++ b/tests/test_workflow.py @@ -25,9 +25,9 @@ def validate_model(state: Optional[StandardState]): assert state.experiment_data is not None assert len(state.experiment_data) == 100 - assert state.model is not None - assert np.allclose(state.model.coef_, [[2.0]]) - assert np.allclose(state.model.intercept_, [[0.5]]) + assert state.models[-1] is not None + assert np.allclose(state.models[-1].coef_, [[2.0]]) + assert np.allclose(state.models[-1].intercept_, [[0.5]]) @given(example_workflow_library_module)