diff --git a/src/autora/state.py b/src/autora/state.py index 3c7f0f79..a267c436 100644 --- a/src/autora/state.py +++ b/src/autora/state.py @@ -31,6 +31,8 @@ from autora.variable import VariableCollection _logger = logging.getLogger(__name__) + + T = TypeVar("T") C = TypeVar("C", covariant=True) @@ -45,396 +47,8 @@ def __add__(self: C, other: Union[Delta, Mapping]) -> C: S = TypeVar("S", bound=DeltaAddable) -class StateDict(UserDict): - """ - Base object for UserDict which uses the Delta mechanism. - - Examples: - We first define an empty state - >>> s_0 = StateDict() - - Then we can add different fields with different Delta behaviours - >>> s_0.add_field("l", "extend", list("abc")) - >>> s_0.add_field("m", "replace", list("xyz")) - >>> s_0.l - ['a', 'b', 'c'] - >>> s_0.m - ['x', 'y', 'z'] - - We can add Deltas to it. Here, 'l' will be extended: - >>> s_1 = s_0 + Delta(l=list("def")) - >>> s_1.l - ['a', 'b', 'c', 'd', 'e', 'f'] - - ... whereas here, 'm' will be replaced: - >>> s_2 = s_1 + Delta(m=list("uvw")) - >>> s_2.m - ['u', 'v', 'w'] - - We can also chain Deltas: - >>> s_3 = s_2 + Delta(l=list("ghi")) + Delta(m=list("rst")) - >>> s_3.l - ['a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i'] - - >>> s_3.m - ['r', 's', 't'] - - ... or update multiple fields with one Delta: - >>> s_4 = s_3 + Delta(l=list("jkl"), m=list("opq")) - >>> s_4.l - ['a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l'] - - >>> s_4.m - ['o', 'p', 'q'] - - If we try to add a nonexistent field, nothing happens: - >>> s_5 = s_4 + Delta(n="not a field") - >>> 'n' in s_5 - False - - The update function replaces the entry: - >>> s_5.update(l=list("mno")) - >>> s_5.l - ['m', 'n', 'o'] - - We can also define fields which `append` the last result: - >>> s_5.add_field('n', 'append', list('abc')) - >>> s_6 = s_5 + Delta(n='d') - >>> s_6.n - ['a', 'b', 'c', 'd'] - - The metadata key "converter" is used to coerce types (inspired by - [PEP 712](https://peps.python.org/pep-0712/)): - >>> s_coerce = StateDict() - >>> s_coerce.add_field('o') - >>> s_coerce.add_field('p', converter=list) - >>> (s_coerce + Delta(o="not a list")).o - 'not a list' - - >>> (s_coerce + Delta(p='not a list')).p - ['n', 'o', 't', ' ', 'a', ' ', 'l', 'i', 's', 't'] - - If the input data are of the correct type, they are returned unaltered: - >>> (s_coerce + Delta(p=["a", "list"])).p - ['a', 'list'] - - With a converter, inputs are converted to the type that is output by the converter: - >>> s_coerce.add_field("q", converter=pd.DataFrame) - - If the type is already correct, the object is passed to the converter, - but should be returned unchanged: - >>> (s_coerce + Delta(q=pd.DataFrame([("a",1,"alpha"), ("b",2,"beta")],\ -columns=list("xyz")))).q - x y z - 0 a 1 alpha - 1 b 2 beta - - If the type is not correct, the object is converted if possible. For a DataFrame, - we can convert records: - >>> (s_coerce + Delta(q=[("a",1,"alpha"), ("b",2,"beta")])).q - 0 1 2 - 0 a 1 alpha - 1 b 2 beta - - ... or an array: - >>> (s_coerce + Delta(q=np.linspace([1, 2], [10, 15], 3))).q - 0 1 - 0 1.0 2.0 - 1 5.5 8.5 - 2 10.0 15.0 - - ... or a dictionary: - >>> (s_coerce + Delta(q={"a": [1,2,3], "b": [4,5,6]})).q - a b - 0 1 4 - 1 2 5 - 2 3 6 - - ... or a list: - >>> (s_coerce + Delta(q=[11, 12, 13])).q - 0 - 0 11 - 1 12 - 2 13 - - ... but not, for instance, a string: - >>> (s_coerce + Delta(q="not compatible with pd.DataFrame")).q - Traceback (most recent call last): - ... - ValueError: DataFrame constructor not properly called! - - We can define aliases for different potential field names: - >>> s_alias = StateDict() - >>> s_alias.add_field("things", "extend", aliases={"thing": lambda m: [m]}) - - - In the "normal" case, the Delta object is expected to include a list of data in the - format which is used to extend the object: - >>> s_alias = s_alias + Delta(things=["1", "2"]) - >>> s_alias.things - ['1', '2'] - - However, say the standard return from a step in AER is a single `thing`, rather than a - sequence: - >>> (s_alias + Delta(thing="3")).things - ['1', '2', '3'] - - If a cycle function relies on the existence of `s.thing` as a property of your state - `s`, rather than accessing `s.things[-1]`, you could additionally define a `getter`. - If you define such getters, the second argument must be a callable, in which case the input - to said callable will be interpreted as the state itself. - >>> s_alias.set_alias_getter("thing", lambda x: x["things"][-1]) - - At this point, you can access both `s.things` and `s.thing` as required by your code. - The State only shows `things` in the string representation. It exposes `things` as an - attribute: - >>> s_alias.things - ['1', '2'] - - ... but also exposes `thing`, which always returns the last value. - >>> s_alias.thing - '2' - - """ - - def __init__(self, data: Optional[Dict] = None): - super().__init__(data) - - def add_field( - self, name, delta="replace", default=None, aliases=None, converter=None - ): - self.data[name] = default - if "_metadata" not in self.data.keys(): - self.data["_metadata"] = {} - self.data["_metadata"][name] = {} - self.data["_metadata"][name]["default"] = default - self.data["_metadata"][name]["delta"] = delta - self.data["_metadata"][name]["aliases"] = aliases - self.data["_metadata"][name]["converter"] = converter - - def set_delta(self, name, delta): - if "_metadata" not in self.data.keys(): - self.data["_metadata"] = {} - if name not in self.data["_metadata"].keys(): - self.data["_metadata"][name] = {} - self.data["_metadata"][name]["default"] = None - self.data["_metadata"][name]["aliases"] = None - self.data["_metadata"][name]["delta"] = delta - - def set_converter(self, name, converter): - if "_metadata" not in self.data.keys(): - self.data["_metadata"] = {} - if name not in self.data["_metadata"].keys(): - self.data["_metadata"][name] = {} - self.data["_metadata"][name]["default"] = None - self.data["_metadata"][name]["aliases"] = None - self.data["_metadata"][name]["converter"] = converter - - def set_alias(self, name, setter, getter): - if "_metadata" not in self.data.keys(): - self.data["_metadata"] = {} - if name not in self.data["_metadata"].keys(): - self.data["_metadata"][name] = {} - self.data["_metadata"][name]["default"] = None - self.data["_metadata"][name]["aliases"] = setter - self.data[f"_alias_getter_{name}"] = lambda: getter(self) - - def set_alias_getter(self, name, getter): - self.data[f"_alias_getter_{name}"] = lambda: getter(self) - - def __setitem__(self, key, value): - if ( - key != "_metadata" - and not key.startswith("_alias_getter") - and ( - "_metadata" not in self.data.keys() - or key not in self.data["_metadata"].keys() - ) - ): - self.add_field(key) - super().__setitem__(key, value) - - def __getattr__(self, key): - if f"_alias_getter_{key}" in self.data and isinstance( - self.data[f"_alias_getter_{key}"], Callable - ): - return self.data[f"_alias_getter_{key}"]() - if key in self.data: - return self.data[key] - raise AttributeError(f"'StateDict' object has no attribute '{key}'") - - def __add__(self, other: Union[Delta, Mapping]): - updates = dict() - other_fields_unused = list(other.keys()) - for self_key in self.data: # Access the data dictionary within UserDict - if self_key == "_metadata" or self_key.startswith("_alias_getter"): - continue - other_value, other_key = self._get_value(self_key, other) - if other_value is None: - continue - other_fields_unused.remove(other_key) - - self_field_key = self_key - self_value = self.data[ - self_field_key - ] # Access the value from the data dictionary - delta_behavior = self.data["_metadata"][self_field_key]["delta"] - - if ( - constructor := self.data["_metadata"][self_field_key]["converter"] - ) is not None: - coerced_other_value = constructor(other_value) - else: - coerced_other_value = other_value - - if delta_behavior == "extend": - extended_value = _extend(self_value, coerced_other_value) - updates[self_field_key] = extended_value - elif delta_behavior == "append": - appended_value = _append(self_value, coerced_other_value) - updates[self_field_key] = appended_value - elif delta_behavior == "replace": - updates[self_field_key] = coerced_other_value - else: - raise NotImplementedError( - "delta_behaviour=`%s` not implemented" % delta_behavior - ) - - new_data = self.data.copy() - new_data.update(updates) - new = self.__class__( - new_data - ) # Create a new instance of the same class with updated data - return new - - def _get_value(self, k, other: Union[Delta, Mapping]): - """ - Given a `StateDicts`'s `key` k, get a value from `other` and report its name. - - Returns: a tuple (the value, the key associated with that value) - - Examples: - >>> s = StateDict() - >>> s.add_field('a') - >>> s.add_field('b', aliases={"ba": lambda b: [b]}) - >>> s.add_field('c', aliases={"ca": lambda x: x, "cb": lambda x: [x]}) - - For a field with no aliases, we retrieve values with the base name: - >>> s._get_value('a', Delta(a=1)) - (1, 'a') - - ... and only the base name: - >>> s._get_value('a', Delta(b=2)) # no match for b - (None, None) - - Any other names._get_valueare unimportant: - >>> s._get_value('a', Delta(b=2, a=1)) - (1, 'a') - - For fields with an alias, we retrieve values with the base name: - >>> s._get_value('b', Delta(b=[2])) - ([2], 'b') - - ... or for the alias name, transformed by the alias lambda function: - >>> s._get_value('b', Delta(ba=21)) - ([21], 'ba') - - We preferentially get the base name, and then any aliases: - >>> s._get_value('b', Delta(b=2, ba=21)) - (2, 'b') - - ... regardless of their order in the `Delta` object: - >>> s._get_value('b', Delta(ba=21, b=2)) - (2, 'b') - - Other names are ignored: - >>> s._get_value('b', Delta(a=1)) - (None, None) - - and the order of other names is unimportant: - >>> s._get_value('b', Delta(a=1, b=2)) - (2, 'b') - - For fields with multiple aliases, we retrieve values with the base name: - >>> s._get_value('c', Delta(c=[3])) - ([3], 'c') - - ... for any alias: - >>> s._get_value('c', Delta(ca=31)) - (31, 'ca') - - ... transformed by the alias lambda function : - >>> s._get_value('c', Delta(cb=32)) - ([32], 'cb') - - ... and ignoring any other names: - >>> s._get_value('c', Delta(a=1)) - (None, None) - - ... preferentially in the order base name, 1st alias, 2nd alias, ... nth alias: - >>> s._get_value('c', Delta(c=3, ca=31, cb=32)) - (3, 'c') - - >>> s._get_value('c', Delta(ca=31, cb=32)) - (31, 'ca') - - >>> s._get_value('c', Delta(cb=32)) - ([32], 'cb') - - >>> s._get_value('c', Delta()) - (None, None) - - This works with dict objects: - >>> s._get_value('a', dict(a=13)) - (13, 'a') - - ... with multiple keys: - >>> s._get_value('b', dict(a=13, b=24, c=35)) - (24, 'b') - - ... and with aliases: - >>> s._get_value('b', dict(ba=222)) - ([222], 'ba') - - This works with UserDicts: - >>> class MyDelta(UserDict): - ... pass - - >>> s._get_value('a', MyDelta(a=14)) - (14, 'a') - - ... with multiple keys: - >>> s._get_value('b', MyDelta(a=1, b=4, c=9)) - (4, 'b') - - ... and with aliases: - >>> s._get_value('b', MyDelta(ba=234)) - ([234], 'ba') - - """ - - aliases = self.data["_metadata"][k].get("aliases", {}) - - value, used_key = None, None - - if k in other.keys(): - value = other[k] - used_key = k - elif aliases: # ... is not an empty dict - for alias_key, wrapping_function in aliases.items(): - if alias_key in other: - value = wrapping_function(other[alias_key]) - used_key = alias_key - break # we only evaluate the first match - - return value, used_key - - -State = StateDict - - @dataclass(frozen=True) -class StateDataClass: +class State: """ Base object for dataclasses which use the Delta mechanism. @@ -445,7 +59,7 @@ class StateDataClass: We define a dataclass where each field (which is going to be delta-ed) has additional metadata "delta" which describes its delta behaviour. >>> @dataclass(frozen=True) - ... class ListState(StateDataClass): + ... class ListState(State): ... l: List = field(default_factory=list, metadata={"delta": "extend"}) ... m: List = field(default_factory=list, metadata={"delta": "replace"}) @@ -487,7 +101,7 @@ class StateDataClass: We can also define fields which `append` the last result: >>> @dataclass(frozen=True) - ... class AppendState(StateDataClass): + ... class AppendState(State): ... n: List = field(default_factory=list, metadata={"delta": "append"}) >>> m = AppendState(n=list("ɑβɣ")) @@ -501,7 +115,7 @@ class StateDataClass: The metadata key "converter" is used to coerce types (inspired by [PEP 712](https://peps.python.org/pep-0712/)): >>> @dataclass(frozen=True) - ... class CoerceStateList(StateDataClass): + ... class CoerceStateList(State): ... o: Optional[List] = field(default=None, metadata={"delta": "replace"}) ... p: List = field(default_factory=list, metadata={"delta": "replace", ... "converter": list}) @@ -522,7 +136,7 @@ class StateDataClass: With a converter, inputs are converted to the type output by the converter: >>> @dataclass(frozen=True) - ... class CoerceStateDataFrame(StateDataClass): + ... class CoerceStateDataFrame(State): ... q: pd.DataFrame = field(default_factory=pd.DataFrame, ... metadata={"delta": "replace", ... "converter": pd.DataFrame}) @@ -571,7 +185,7 @@ class StateDataClass: Without a converter: >>> @dataclass(frozen=True) - ... class CoerceStateDataFrameNoConverter(StateDataClass): + ... class CoerceStateDataFrameNoConverter(State): ... r: pd.DataFrame = field(default_factory=pd.DataFrame, metadata={"delta": "replace"}) ... there is no coercion – the object is passed unchanged @@ -585,7 +199,7 @@ class StateDataClass: A converter can cast from a DataFrame to a np.ndarray (with a single datatype), for instance: >>> @dataclass(frozen=True) - ... class CoerceStateArray(StateDataClass): + ... class CoerceStateArray(State): ... r: Optional[np.ndarray] = field(default=None, ... metadata={"delta": "replace", ... "converter": np.asarray}) @@ -599,7 +213,7 @@ class StateDataClass: names. >>> @dataclass(frozen=True) - ... class FieldAliasState(StateDataClass): + ... class FieldAliasState(State): ... things: List[str] = field( ... default_factory=list, ... metadata={"delta": "extend", @@ -830,7 +444,7 @@ def _get_value(f, other: Union[Delta, Mapping]): return value, used_key -def _get_field_names_and_aliases(s: StateDataClass): +def _get_field_names_and_aliases(s: State): """ Get a list of field names and their aliases from a State object @@ -842,14 +456,14 @@ def _get_field_names_and_aliases(s: StateDataClass): Examples: >>> from dataclasses import field >>> @dataclass(frozen=True) - ... class SomeState(StateDataClass): + ... class SomeState(State): ... l: List = field(default_factory=list) ... m: List = field(default_factory=list) >>> _get_field_names_and_aliases(SomeState()) ['l', 'm'] >>> @dataclass(frozen=True) - ... class SomeStateWithAliases(StateDataClass): + ... class SomeStateWithAliases(State): ... l: List = field(default_factory=list, metadata={"aliases": {"l1": None, "l2": None}}) ... m: List = field(default_factory=list, metadata={"aliases": {"m1": None}}) >>> _get_field_names_and_aliases(SomeStateWithAliases()) @@ -1044,7 +658,7 @@ def inputs_from_state(f, input_mapping: Dict = {}): The `State` it operates on needs to have the metadata described in the state module: >>> @dataclass(frozen=True) - ... class U(StateDataClass): + ... class U(State): ... conditions: List[int] = field(metadata={"delta": "replace"}) We indicate the inputs required by the parameter names. @@ -1097,7 +711,7 @@ def inputs_from_state(f, input_mapping: Dict = {}): ... return model >>> @dataclass(frozen=True) - ... class V(StateDataClass): + ... class V(State): ... variables: VariableCollection # field(metadata={"delta":... }) omitted ∴ immutable ... experiment_data: pd.DataFrame = field(metadata={"delta": "extend"}) ... model: Optional[BaseEstimator] = field(metadata={"delta": "replace"}, default=None) @@ -1335,7 +949,7 @@ def delta_to_state(f): The `State` it operates on needs to have the metadata described in the state module: >>> @dataclass(frozen=True) - ... class U(StateDataClass): + ... class U(State): ... conditions: List[int] = field(metadata={"delta": "replace"}) We indicate the inputs required by the parameter names. @@ -1403,7 +1017,7 @@ def delta_to_state(f): ... return Delta(model=new_model) >>> @dataclass(frozen=True) - ... class V(StateDataClass): + ... class V(State): ... variables: VariableCollection # field(metadata={"delta":... }) omitted ∴ immutable ... experiment_data: pd.DataFrame = field(metadata={"delta": "extend"}) ... model: Optional[BaseEstimator] = field(metadata={"delta": "replace"}, default=None) @@ -1505,7 +1119,7 @@ def on_state( The `State` it operates on needs to have the metadata described in the state module: >>> @dataclass(frozen=True) - ... class W(StateDataClass): + ... class W(State): ... conditions: List[int] = field(metadata={"delta": "replace"}) We indicate the inputs required by the parameter names. @@ -1577,7 +1191,7 @@ def decorator(f): return decorator(function) -StateFunction = Callable[[StateDataClass], StateDataClass] +StateFunction = Callable[[State], State] class StandardStateVariables(Enum): @@ -1588,24 +1202,24 @@ class StandardStateVariables(Enum): @dataclass(frozen=True) -class StandardStateDataClass(StateDataClass): +class StandardState(State): """ Examples: The state can be initialized emtpy >>> from autora.variable import VariableCollection, Variable - >>> s = StandardStateDataClass() + >>> s = StandardState() >>> s - StandardStateDataClass(variables=None, conditions=None, experiment_data=None, models=[]) + StandardState(variables=None, conditions=None, experiment_data=None, models=[]) The `variables` can be updated using a `Delta`: >>> dv1 = Delta(variables=VariableCollection(independent_variables=[Variable("1")])) >>> s + dv1 # doctest: +NORMALIZE_WHITESPACE +ELLIPSIS - StandardStateDataClass(variables=VariableCollection(independent_variables=[Variable(name='1',...) + StandardState(variables=VariableCollection(independent_variables=[Variable(name='1',...) ... and are replaced by each `Delta`: >>> dv2 = Delta(variables=VariableCollection(independent_variables=[Variable("2")])) >>> s + dv1 + dv2 # doctest: +NORMALIZE_WHITESPACE +ELLIPSIS - StandardStateDataClass(variables=VariableCollection(independent_variables=[Variable(name='2',...) + StandardState(variables=VariableCollection(independent_variables=[Variable(name='2',...) The `conditions` can be updated using a `Delta`: >>> dc1 = Delta(conditions=pd.DataFrame({"x": [1, 2, 3]})) @@ -1754,48 +1368,6 @@ def model(self): return None -class StandardStateDict(StateDict): - def __init__(self, data: Optional[Dict] = None, **kwargs): - if data is None: - data = { - "_metadata": { - "variables": { - "default": None, - "delta": "replace", - "converter": VariableCollection, - }, - "conditions": { - "default": None, - "delta": "replace", - "converter": pd.DataFrame, - }, - "experiment_data": { - "default": None, - "delta": "extend", - "converter": pd.DataFrame, - }, - "models": {"default": None, "delta": "extend", "converter": list}, - }, - "variables": None, - "conditions": None, - "experiment_data": None, - "models": None, - } - super().__init__(data) - for key in kwargs: - self.data[key] = kwargs[key] - - @property - def model(self): - """Alias for the last model in the `models`.""" - try: - return self.data["models"][-1] - except IndexError: - return None - - -StandardState = StandardStateDict - X = TypeVar("X") Y = TypeVar("Y") XY = TypeVar("XY") @@ -1813,18 +1385,18 @@ def estimator_on_state(estimator: BaseEstimator) -> StateFunction: >>> from sklearn.linear_model import LinearRegression >>> state_fn = estimator_on_state(LinearRegression()) - Define the state on which to operate (here an instance of the `StandardStateDataClass`): - >>> from autora.state import StandardStateDataClass + Define the state on which to operate (here an instance of the `StandardState`): + >>> from autora.state import StandardState >>> from autora.variable import Variable, VariableCollection >>> import pandas as pd - >>> s = StandardStateDataClass( + >>> s = StandardState( ... variables=VariableCollection( ... independent_variables=[Variable("x")], ... dependent_variables=[Variable("y")]), ... experiment_data=pd.DataFrame({"x": [1,2,3], "y":[3,6,9]}) ... ) - Run the function, which fits the model and adds the result to the `StandardStateDataClass` + Run the function, which fits the model and adds the result to the `StandardState` >>> state_fn(s).model.coef_ array([[3.]]) @@ -1848,9 +1420,9 @@ def experiment_runner_on_state(f: Callable[[X], XY]) -> StateFunction: returns both $x$ and $y$ values in a complete dataframe. Examples: - The conditions are some x-values in a StandardStateDataClass object: - >>> from autora.state import StandardStateDataClass - >>> s = StandardStateDataClass(conditions=pd.DataFrame({"x": [1, 2, 3]})) + The conditions are some x-values in a StandardState object: + >>> from autora.state import StandardState + >>> s = StandardState(conditions=pd.DataFrame({"x": [1, 2, 3]})) The function can be defined on a DataFrame, allowing the explicit inclusion of metadata like column names. @@ -1871,8 +1443,7 @@ def experiment_runner_on_state(f: Callable[[X], XY]) -> StateFunction: ... return result With the relevant variables as conditions: - >>> t = StandardStateDataClass( \ -conditions=pd.DataFrame({"x0": [1, 2, 3], "x1": [10, 20, 30]})) + >>> t = StandardState(conditions=pd.DataFrame({"x0": [1, 2, 3], "x1": [10, 20, 30]})) >>> experiment_runner_on_state(xs_to_xy_fn)(t).experiment_data x0 x1 y 0 1 10 11 @@ -1904,7 +1475,7 @@ def combined_functions_on_state( Examples: >>> @dataclass(frozen=True) - ... class U(StateDataClass): + ... class U(State): ... conditions: List[int] = field(metadata={"delta": "replace"}) >>> identity = lambda conditions : conditions >>> double_conditions = combined_functions_on_state( @@ -1943,7 +1514,7 @@ def combined_functions_on_state( """ - def f_(_state: StateDataClass, params: Optional[Dict] = None): + def f_(_state: State, params: Optional[Dict] = None): result_delta = None for name, function in functions: _f_input_from_state = inputs_from_state(function) diff --git a/tests/_example_workflow_library.py b/tests/_example_workflow_library.py index 13cb4677..6e31df1e 100644 --- a/tests/_example_workflow_library.py +++ b/tests/_example_workflow_library.py @@ -2,12 +2,12 @@ from sklearn.linear_model import LinearRegression from autora.experimentalist.grid import grid_pool -from autora.state import StandardStateDataClass, estimator_on_state, on_state +from autora.state import StandardState, estimator_on_state, on_state from autora.variable import Variable, VariableCollection def initial_state(_): - state = StandardStateDataClass( + state = StandardState( variables=VariableCollection( independent_variables=[Variable(name="x", allowed_values=range(100))], dependent_variables=[Variable(name="y")], diff --git a/tests/test_state.py b/tests/test_state.py index b8f196fc..f89b52dc 100644 --- a/tests/test_state.py +++ b/tests/test_state.py @@ -3,17 +3,17 @@ import pandas as pd from hypothesis import HealthCheck, given, settings -from autora.state import StandardStateDataClass +from autora.state import StandardState from .test_serializer import serializer_dump_load_strategy -from .test_strategies import standard_state_dataclass_strategy +from .test_strategies import standard_state_strategy logger = logging.getLogger(__name__) -@given(standard_state_dataclass_strategy(), serializer_dump_load_strategy) +@given(standard_state_strategy(), serializer_dump_load_strategy) @settings(suppress_health_check={HealthCheck.too_slow}, deadline=1000) -def test_state_serialize_deserialize(o: StandardStateDataClass, dump_load): +def test_state_serialize_deserialize(o: StandardState, dump_load): o_loaded = dump_load(o) assert o.variables == o_loaded.variables assert pd.DataFrame.equals(o.conditions, o_loaded.conditions) diff --git a/tests/test_strategies.py b/tests/test_strategies.py index 62e95771..8683762b 100644 --- a/tests/test_strategies.py +++ b/tests/test_strategies.py @@ -10,7 +10,7 @@ from hypothesis.extra import numpy as st_np from hypothesis.extra import pandas as st_pd -from autora.state import StandardStateDataClass +from autora.state import StandardState from autora.variable import ValueType, Variable, VariableCollection VALUE_TYPE_DTYPE_MAPPING = { @@ -437,7 +437,7 @@ def test_model_strategy_creation(o): @st.composite -def standard_state_dataclass_strategy(draw): +def standard_state_strategy(draw): variable_collection: VariableCollection = draw(variablecollection_strategy()) conditions = draw( dataframe_strategy(variables=variable_collection.independent_variables) @@ -452,7 +452,7 @@ def standard_state_dataclass_strategy(draw): ) ) models = draw(st.lists(model_strategy(), min_size=0, max_size=5)) - s = StandardStateDataClass( + s = StandardState( variables=variable_collection, conditions=conditions, experiment_data=experiment_data, @@ -462,8 +462,8 @@ def standard_state_dataclass_strategy(draw): @settings(suppress_health_check={HealthCheck.too_slow}) -@given(standard_state_dataclass_strategy()) -def test_standard_state_dataclass_strategy_creation(o): +@given(standard_state_strategy()) +def test_standard_state_strategy_creation(o): assert o diff --git a/tests/test_workflow.py b/tests/test_workflow.py index f2b30224..a1417c7b 100644 --- a/tests/test_workflow.py +++ b/tests/test_workflow.py @@ -8,7 +8,7 @@ from hypothesis import strategies as st from autora.serializer import SupportedSerializer, load_state -from autora.state import StandardState, State +from autora.state import StandardState from autora.workflow.__main__ import main _logger = logging.getLogger(__name__) @@ -16,7 +16,7 @@ example_workflow_library_module = st.sampled_from(["_example_workflow_library"]) -def validate_model(state: Optional[State]): +def validate_model(state: Optional[StandardState]): assert state is not None assert state.conditions is not None