Skip to content

Commit

Permalink
Merge pull request #55 from AutoResearch/test/saving-loading-states
Browse files Browse the repository at this point in the history
test: saving and loading `StateDataClass` objects
  • Loading branch information
hollandjg authored Nov 30, 2023
2 parents ecd15a4 + 2c7a8f8 commit 24478cf
Show file tree
Hide file tree
Showing 11 changed files with 636 additions and 2 deletions.
3 changes: 2 additions & 1 deletion .github/workflows/test-pytest.yml
Original file line number Diff line number Diff line change
Expand Up @@ -26,4 +26,5 @@ jobs:
python-version: ${{ matrix.python-version }}
cache: "pip"
- run: pip install ".[test]"
- run: pytest --doctest-modules --import-mode importlib
- run: pytest --hypothesis-profile ci

2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ repos:
rev: "v1.5.1"
hooks:
- id: mypy
additional_dependencies: [types-requests,scipy,pytest]
additional_dependencies: [types-requests,scipy,pytest,hypothesis,types-pyyaml]
language_version: python3.8
args:
- "--namespace-packages"
Expand Down
4 changes: 4 additions & 0 deletions conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
from hypothesis import Verbosity, settings

# Thorough profile, selected with `pytest --hypothesis-profile ci`.
settings.register_profile(
    "ci",
    max_examples=1000,
)

# Small, verbose profile, selected with `pytest --hypothesis-profile debug`.
settings.register_profile(
    "debug",
    verbosity=Verbosity.verbose,
    max_examples=10,
)
7 changes: 7 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,8 @@ dev = [
]
test = [
"pytest",
"hypothesis[pandas]",
"autora-core[serializers]"
]
build = [
"build",
Expand All @@ -68,6 +70,7 @@ docs = [
"mkdocs-jupyter",
"pymdown-extensions",
]
serializers = ["dill", "pyyaml"]

[tool.isort]
profile = "black"
Expand All @@ -90,3 +93,7 @@ requires = ["setuptools", "setuptools-scm"]
build-backend = "setuptools.build_meta"

[tool.setuptools_scm]

[tool.pytest.ini_options]
minversion = "6.0"
addopts = "--doctest-modules --import-mode importlib"
21 changes: 21 additions & 0 deletions src/autora/serializer/yaml_.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
import yaml


def dump(data, file):
    """Write `data` as YAML to the open file object `file`.

    Uses the full `yaml.Dumper`, so arbitrary Python objects are emitted with
    Python-specific tags. Returns None.
    """
    yaml.dump(data, stream=file, Dumper=yaml.Dumper)


def load(file):
    """Read a YAML document from the open file object `file` and return the object.

    NOTE: `yaml.Loader` can construct arbitrary Python objects, so this must
    only be used on files from a trusted source.
    """
    return yaml.load(file, Loader=yaml.Loader)


def dumps(data):
    """Return `data` serialized as a YAML document, using the full `yaml.Dumper`."""
    return yaml.dump(data, Dumper=yaml.Dumper)


def loads(string):
    """Parse the YAML document in `string` and return the resulting object.

    NOTE: `yaml.Loader` can construct arbitrary Python objects, so this must
    only be used on input from a trusted source.
    """
    return yaml.load(string, Loader=yaml.Loader)
2 changes: 2 additions & 0 deletions src/autora/variable/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
class ValueType(str, Enum):
"""Specifies supported value types supported by Variables."""

BOOLEAN = "boolean"
INTEGER = "integer"
REAL = "real"
SIGMOID = "sigmoid"
PROBABILITY = "probability" # single probability
Expand Down
Empty file added tests/__init__.py
Empty file.
86 changes: 86 additions & 0 deletions tests/test_serializer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
import importlib
import logging
import pathlib
import tempfile
import uuid
from collections import namedtuple

import hypothesis.strategies as st

logger = logging.getLogger(__name__)

# Each entry is (importable module name, file-mode suffix). The suffix is
# appended to "r"/"w" when opening files: "b" for binary, "" for text.
_SUPPORTED_SERIALIZERS = [
    ("pickle", "b"),
    ("dill", "b"),
    ("autora.serializer.yaml_", ""),
]
_SERIALIZER_DEF = namedtuple("_SERIALIZER_DEF", ["name", "module", "file_type"])

# Filled below with one _SERIALIZER_DEF per serializer that imports successfully.
_AVAILABLE_SERIALIZERS = []

for module_name, file_type in _SUPPORTED_SERIALIZERS:
    try:
        module = importlib.import_module(module_name)
    except ImportError:
        # Report the *name* we tried to import. `module` is either unbound
        # (first iteration) or still refers to the previously imported module.
        logger.info("serializer %s not available", module_name)
        continue
    _AVAILABLE_SERIALIZERS.append(_SERIALIZER_DEF(module_name, module, file_type))

# Hypothesis strategy which draws one entry from the serializers that imported successfully.
AVAILABLE_SERIALIZERS = st.sampled_from(_AVAILABLE_SERIALIZERS)


@st.composite
def serializer_loads_dumps_strategy(draw):
    """Strategy returns the `(loads, dumps)` functions of one available serializer."""
    chosen = draw(AVAILABLE_SERIALIZERS)
    backend = chosen.module
    return backend.loads, backend.dumps


@st.composite
def serializer_dump_load_string_strategy(draw):
    """Strategy returns a function which round-trips an object in memory.

    The returned function passes the object through the drawn serializer's
    `dumps` and feeds the result straight back into its `loads`.
    """
    serializer = draw(AVAILABLE_SERIALIZERS)
    deserialize = serializer.module.loads
    serialize = serializer.module.dumps

    def _load_dump_via_string(o):
        logger.info(f"load dump via string using {serializer.module=}")
        payload = serialize(o)
        return deserialize(payload)

    return _load_dump_via_string


@st.composite
def serializer_dump_load_binary_file_strategy(draw):
    """Strategy returns a function which dumps an object and reloads it via a temporary file.

    The file is opened in "w"/"r" mode with the drawn serializer's `file_type`
    suffix appended, inside a temporary directory removed afterwards.
    """
    serializer = draw(AVAILABLE_SERIALIZERS)
    reader = serializer.module.load
    writer = serializer.module.dump
    write_mode = "w" + serializer.file_type
    read_mode = "r" + serializer.file_type

    def _load_dump_via_disk(o):
        logger.info(f"load dump via disk using {serializer.module=}")
        with tempfile.TemporaryDirectory() as tempdir:
            # A unique file name inside the throwaway directory.
            path = pathlib.Path(tempdir, str(uuid.uuid1()))
            with open(path, write_mode) as sink:
                writer(o, sink)
            with open(path, read_mode) as source:
                restored = reader(source)
        return restored

    return _load_dump_via_disk


@st.composite
def serializer_dump_load_strategy(draw):
    """Strategy returns a dump-then-reload function, in-memory or via a temporary file."""
    round_trip_strategies = (
        serializer_dump_load_string_strategy(),
        serializer_dump_load_binary_file_strategy(),
    )
    return draw(st.one_of(*round_trip_strategies))


if __name__ == "__main__":
    # Manual smoke check: round-trip a small list through one example
    # dump/load function and print the value before and after.
    original = list("abcde")
    round_trip = serializer_dump_load_strategy().example()
    print(original, round_trip(original))
20 changes: 20 additions & 0 deletions tests/test_state.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
import logging

import pandas as pd
from hypothesis import HealthCheck, given, settings

from autora.state import StandardStateDataClass

from .test_serializer import serializer_dump_load_strategy
from .test_strategies import standard_state_dataclass_strategy

logger = logging.getLogger(__name__)


@given(standard_state_dataclass_strategy(), serializer_dump_load_strategy())
@settings(suppress_health_check={HealthCheck.too_slow}, deadline=1000)
def test_state_serialize_deserialize(o: StandardStateDataClass, dump_load):
    """A state dumped and reloaded keeps its variables, conditions and experiment data."""
    restored = dump_load(o)
    assert o.variables == restored.variables
    # The frame fields are compared with DataFrame.equals rather than `==`.
    for frame_name in ("conditions", "experiment_data"):
        assert pd.DataFrame.equals(getattr(o, frame_name), getattr(restored, frame_name))
Loading

0 comments on commit 24478cf

Please sign in to comment.