diff --git a/pyproject.toml b/pyproject.toml index 2a63a4fe..13ca10cb 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -11,12 +11,7 @@ license = {file = "LICENSE"} dependencies = [ "autora-core>=4.0.0", - "scikit-learn", - "matplotlib", - "pandas", "typer[all]", - "dill", - "pyyaml", ] [project.optional-dependencies] @@ -26,16 +21,26 @@ dev = [ "autora-workflow[test]", ] docs = [ - "autora-core[docs]>=4.0.0" + "autora-core[docs]>=4.0.0", + "scikit-learn", + "matplotlib", + "pandas", ] test = [ "autora-core[test]>=4.0.0", - "hypothesis" + "autora-workflow[serializers]", + "hypothesis", + "scikit-learn", + "pandas", ] cylc = [ "cylc-flow", "cylc-uiserver" ] +serializers = [ + "dill", + "pyyaml" +] [project.urls] homepage = "http://www.empiricalresearch.ai/" diff --git a/src/autora/serializer/__init__.py b/src/autora/serializer/__init__.py new file mode 100644 index 00000000..137aa1de --- /dev/null +++ b/src/autora/serializer/__init__.py @@ -0,0 +1,82 @@ +import importlib +import logging +import pathlib +from collections import namedtuple +from enum import Enum +from typing import Callable, Dict, Literal, Optional, Tuple, Union + +from autora.state import State + +_logger = logging.getLogger(__name__) + + +class SerializersSupported(str, Enum): + """Listing of allowed serializers.""" + + pickle = "pickle" + dill = "dill" + yaml = "yaml" + + +_SerializerDef = namedtuple( + "_SerializerDef", ["module", "load", "dump", "dumps", "file_mode"] +) +_serializer_dict: Dict[SerializersSupported, _SerializerDef] = { + SerializersSupported.pickle: _SerializerDef("pickle", "load", "dump", "dumps", "b"), + SerializersSupported.yaml: _SerializerDef( + "autora.serializer._yaml", "load", "dump", "dumps", "" + ), + SerializersSupported.dill: _SerializerDef("dill", "load", "dump", "dumps", "b"), +} + +default_serializer = SerializersSupported.pickle + + +def _get_serializer_mode( + serializer: SerializersSupported, interface: Literal["load", "dump", "dumps"] +) -> Tuple[Callable, str]: + serializer_def = _serializer_dict[serializer] + module = serializer_def.module + interface_function_name = getattr(serializer_def, interface) + _logger.debug( + f"_get_serializer_mode: loading {interface_function_name=} from" f" {module=}" + ) + module = importlib.import_module(module) + function = getattr(module, interface_function_name) + file_mode = serializer_def.file_mode + return function, file_mode + + +def load_state( + path: Optional[pathlib.Path], + loader: SerializersSupported = default_serializer, +) -> Union[State, None]: + """Load a State object from a path.""" + if path is not None: + load, file_mode = _get_serializer_mode(loader, "load") + _logger.debug(f"load_state: loading from {path=}") + with open(path, f"r{file_mode}") as f: + state_ = load(f) + else: + _logger.debug(f"load_state: {path=} -> returning None") + state_ = None + return state_ + + +def dump_state( + state_: State, + path: Optional[pathlib.Path], + dumper: SerializersSupported = default_serializer, +) -> None: + """Write a State object to a path.""" + if path is not None: + dump, file_mode = _get_serializer_mode(dumper, "dump") + _logger.debug(f"dump_state: dumping to {path=}") + path.parent.mkdir(parents=True, exist_ok=True) + with open(path, f"w{file_mode}") as f: + dump(state_, f) + else: + dumps, _ = _get_serializer_mode(dumper, "dumps") + _logger.debug(f"dump_state: {path=} so writing to stdout") + print(dumps(state_)) + return diff --git a/src/autora/workflow/serializer/yaml_.py b/src/autora/serializer/_yaml.py similarity index 71% rename from src/autora/workflow/serializer/yaml_.py rename to src/autora/serializer/_yaml.py index 569cf38e..0d47e47d 100644 --- a/src/autora/workflow/serializer/yaml_.py +++ b/src/autora/serializer/_yaml.py @@ -6,6 +6,11 @@ def dump(data, file): return +def dumps(data): + yaml.dumps(data, Dumper=yaml.Dumper) + return + + def load(file): result = yaml.load(file, Loader=yaml.Loader) return result diff --git a/src/autora/workflow/__main__.py b/src/autora/workflow/__main__.py index c04c8506..18062f17 100644 --- a/src/autora/workflow/__main__.py +++ b/src/autora/workflow/__main__.py @@ -1,39 +1,59 @@ import importlib import logging import pathlib -from typing import Optional, Union +from typing import Optional -import dill import typer from typing_extensions import Annotated -from autora.state import State +from autora.serializer import ( + SerializersSupported, + default_serializer, + dump_state, + load_state, +) _logger = logging.getLogger(__name__) def main( fully_qualified_function_name: Annotated[ - str, typer.Argument(help="Function to load") + str, + typer.Argument( + help="Fully qualified name of the function to load, like `module.function`" + ), ], in_path: Annotated[ Optional[pathlib.Path], - typer.Option(help="Path to a .dill file with the initial state"), + typer.Option(help="Path to a file with the initial state"), ] = None, + in_loader: Annotated[ + SerializersSupported, + typer.Option( + help="(de)serializer to load the data", + ), + ] = default_serializer, out_path: Annotated[ Optional[pathlib.Path], - typer.Option(help="Path to output the final state as a .dill file"), + typer.Option(help="Path to output the final state"), ] = None, + out_dumper: Annotated[ + SerializersSupported, + typer.Option( + help="serializer to save the data", + ), + ] = default_serializer, verbose: Annotated[bool, typer.Option(help="Turns on info logging level.")] = False, debug: Annotated[bool, typer.Option(help="Turns on debug logging level.")] = False, ): + """Run an arbitrary function on an optional input State object and save the output.""" _configure_logger(debug, verbose) - starting_state = _load_state(in_path) + starting_state = load_state(in_path, in_loader) _logger.info(f"Starting State: {starting_state}") function = _load_function(fully_qualified_function_name) ending_state = function(starting_state) _logger.info(f"Ending State: {ending_state}") - _dump_state(ending_state, out_path) + dump_state(ending_state, out_path, out_dumper) return @@ -47,18 +67,8 @@ def _configure_logger(debug, verbose): _logger.info("using INFO logging level") -def _load_state(path: Optional[pathlib.Path]) -> Union[State, None]: - if path is not None: - _logger.debug(f"_load_state: loading from {path=}") - with open(path, "rb") as f: - state_ = dill.load(f) - else: - _logger.debug(f"_load_state: {path=} -> returning None") - state_ = None - return state_ - - def _load_function(fully_qualified_function_name: str): + """Load a function by its fully qualified name, `module.function_name`""" _logger.debug(f"_load_function: Loading function {fully_qualified_function_name}") module_name, function_name = fully_qualified_function_name.rsplit(".", 1) module = importlib.import_module(module_name) @@ -67,17 +77,5 @@ def _load_function(fully_qualified_function_name: str): return function -def _dump_state(state_: State, path: Optional[pathlib.Path]) -> None: - if path is not None: - _logger.debug(f"_dump_state: dumping to {path=}") - path.parent.mkdir(parents=True, exist_ok=True) - with open(path, "wb") as f: - dill.dump(state_, f) - else: - _logger.debug(f"_dump_state: {path=} so writing to stdout") - print(dill.dumps(state_)) - return - - if __name__ == "__main__": typer.run(main) diff --git a/tests/test_load_dump_state.py b/tests/test_load_dump_state.py deleted file mode 100644 index 6c2b19c8..00000000 --- a/tests/test_load_dump_state.py +++ /dev/null @@ -1,20 +0,0 @@ -import pathlib -import tempfile - -from hypothesis import Verbosity, given, settings -from hypothesis import strategies as st - -from autora.state import StandardState -from autora.workflow.__main__ import _dump_state, _load_state - - -@given( - st.builds(StandardState, st.text(), st.text(), st.text(), st.lists(st.integers())) -) -@settings(verbosity=Verbosity.verbose) -def test_load_inverts_dump(s): - with tempfile.TemporaryDirectory() as dir: - path = pathlib.Path(dir, "x.dill") - print(path, s) - _dump_state(s, path) - assert _load_state(path) == s diff --git a/tests/test_serializer.py b/tests/test_serializer.py new file mode 100644 index 00000000..5ae34935 --- /dev/null +++ b/tests/test_serializer.py @@ -0,0 +1,24 @@ +import pathlib +import tempfile +import uuid + +from hypothesis import Verbosity, given, settings +from hypothesis import strategies as st + +from autora.serializer import SerializersSupported, dump_state, load_state +from autora.state import StandardState + + +@given( + st.builds(StandardState, st.text(), st.text(), st.text(), st.lists(st.integers())), + st.sampled_from(SerializersSupported), +) +@settings(verbosity=Verbosity.verbose) +def test_load_inverts_dump(s, serializer): + """Test that each serializer can be used to serialize and deserialize a state object.""" + with tempfile.TemporaryDirectory() as dir: + path = pathlib.Path(dir, f"{str(uuid.uuid4())}") + print(path, s) + + dump_state(s, path, dumper=serializer) + assert load_state(path, loader=serializer) == s diff --git a/tests/test_workflow.py b/tests/test_workflow.py new file mode 100644 index 00000000..60e27a4f --- /dev/null +++ b/tests/test_workflow.py @@ -0,0 +1,190 @@ +import logging +import pathlib +import tempfile +from typing import Optional + +import numpy as np +import pandas as pd +from hypothesis import Verbosity, given, settings +from hypothesis import strategies as st +from sklearn.linear_model import LinearRegression + +from autora.experimentalist.grid import grid_pool +from autora.serializer import SerializersSupported, load_state +from autora.state import StandardState, State, estimator_on_state, on_state +from autora.variable import Variable, VariableCollection +from autora.workflow.__main__ import main + +_logger = logging.getLogger(__name__) + + +def initial_state(_): + state = StandardState( + variables=VariableCollection( + independent_variables=[Variable(name="x", allowed_values=range(100))], + dependent_variables=[Variable(name="y")], + covariates=[], + ), + conditions=None, + experiment_data=pd.DataFrame({"x": [], "y": []}), + models=[], + ) + return state + + +experimentalist = on_state(grid_pool, output=["conditions"]) + +experiment_runner = on_state( + lambda conditions: conditions.assign(y=2 * conditions["x"] + 0.5), + output=["experiment_data"], +) + +theorist = estimator_on_state(LinearRegression(fit_intercept=True)) + + +def validate_model(state: Optional[State]): + assert state is not None + + assert state.conditions is not None + assert len(state.conditions) == 100 + + assert state.experiment_data is not None + assert len(state.experiment_data) == 100 + + assert state.model is not None + assert np.allclose(state.model.coef_, [[2.0]]) + assert np.allclose(state.model.intercept_, [[0.5]]) + + +def test_e2e_nominal(): + """Test a basic standard chain of CLI calls using the default serializer. + + Equivalent to: + $ python -m autora.workflow test_workflow.initial_state --out-path start + $ python -m autora.workflow test_workflow.experimentalist --in-path start --out-path conditions + $ python -m autora.workflow test_workflow.experiment_runner --in-path conditions --out-path data + $ python -m autora.workflow test_workflow.theorist --in-path data --out-path theory + """ + + with tempfile.TemporaryDirectory() as d: + main( + "test_workflow.initial_state", + out_path=pathlib.Path(d, "start"), + ) + main( + "test_workflow.experimentalist", + in_path=pathlib.Path(d, "start"), + out_path=pathlib.Path(d, "conditions"), + ) + main( + "test_workflow.experiment_runner", + in_path=pathlib.Path(d, "conditions"), + out_path=pathlib.Path(d, "data"), + ) + main( + "test_workflow.theorist", + in_path=pathlib.Path(d, "data"), + out_path=pathlib.Path(d, "theory"), + ) + + final_state = load_state(pathlib.Path(d, "theory")) + validate_model(final_state) + + +@given(st.sampled_from(SerializersSupported), st.booleans(), st.booleans()) +@settings(verbosity=Verbosity.verbose, deadline=500) +def test_e2e_serializers(serializer, verbose, debug): + """Test a basic standard chain of CLI calls using a single serializer.""" + + common_settings = dict( + in_loader=serializer, out_dumper=serializer, verbose=verbose, debug=debug + ) + + with tempfile.TemporaryDirectory() as d: + main( + "test_workflow.initial_state", + out_path=pathlib.Path(d, "start"), + **common_settings + ) + main( + "test_workflow.experimentalist", + in_path=pathlib.Path(d, "start"), + out_path=pathlib.Path(d, "conditions"), + **common_settings + ) + main( + "test_workflow.experiment_runner", + in_path=pathlib.Path(d, "conditions"), + out_path=pathlib.Path(d, "data"), + **common_settings + ) + main( + "test_workflow.theorist", + in_path=pathlib.Path(d, "data"), + out_path=pathlib.Path(d, "theory"), + **common_settings + ) + + final_state: StandardState = load_state( + pathlib.Path(d, "theory"), loader=serializer + ) + validate_model(final_state) + + +@given( + st.sampled_from(SerializersSupported), + st.sampled_from(SerializersSupported), + st.sampled_from(SerializersSupported), + st.sampled_from(SerializersSupported), + st.booleans(), + st.booleans(), +) +@settings(verbosity=Verbosity.verbose, deadline=500) +def test_e2e_valid_serializer_mix( + initial_serializer, + experimental_serializer, + experiment_runner_serializer, + theorist_serializer, + verbose, + debug, +): + """Test a basic standard chain of CLI calls using a mix of serializers.""" + + common_settings = dict(verbose=verbose, debug=debug) + + with tempfile.TemporaryDirectory() as d: + main( + "test_workflow.initial_state", + out_path=pathlib.Path(d, "start"), + out_dumper=initial_serializer, + **common_settings + ) + main( + "test_workflow.experimentalist", + in_path=pathlib.Path(d, "start"), + out_path=pathlib.Path(d, "conditions"), + in_loader=initial_serializer, + out_dumper=experimental_serializer, + **common_settings + ) + main( + "test_workflow.experiment_runner", + in_path=pathlib.Path(d, "conditions"), + out_path=pathlib.Path(d, "data"), + in_loader=experimental_serializer, + out_dumper=experiment_runner_serializer, + **common_settings + ) + main( + "test_workflow.theorist", + in_path=pathlib.Path(d, "data"), + out_path=pathlib.Path(d, "theory"), + in_loader=experiment_runner_serializer, + out_dumper=theorist_serializer, + **common_settings + ) + + final_state: StandardState = load_state( + pathlib.Path(d, "theory"), loader=theorist_serializer + ) + validate_model(final_state)