Skip to content

Commit

Permalink
refactor: make load_serializer function public
Browse files Browse the repository at this point in the history
  • Loading branch information
hollandjg committed Nov 29, 2023
1 parent a26b5bc commit 0f139c5
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 45 deletions.
78 changes: 41 additions & 37 deletions src/autora/serializer/__init__.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,30 @@
"""A submodule which handles importing supported serializers."""

import importlib
import logging
import pathlib
from collections import namedtuple
from enum import Enum
from typing import Callable, Dict, Literal, Optional, Tuple, Union
from typing import Dict, Optional, Union

from autora.state import State

_logger = logging.getLogger(__name__)

# Developer notes:
# Add a new serializer by:
# - including its name in the SerializersSupported Enum. This {name} will be available on the
# command line: `python -m autora.workflow --serializer {name}`
# - Adding the basic data about it to the _SERIALIZER_INFO dictionary. This should include the
# fully qualified name under which it can be imported, and which file mode it expects:
# "" for regular [i.e. files will be opened as `open(filename, "w")` for writing]
# "b" for binary [i.e. files will be opened as `open(filename, "wb")` for writing]

_SERIALIZER_INFO_ENTRY = namedtuple(
"_SERIALIZER_INFO_ENTRY", ["fully_qualified_module_name", "file_mode"]
)
LOADED_SERIALIZER = namedtuple("LOADED_SERIALIZER", ["module", "file_mode"])


class SerializersSupported(str, Enum):
"""Listing of allowed serializers."""
Expand All @@ -19,55 +35,44 @@ class SerializersSupported(str, Enum):


# Dictionary of details about each serializer
_SERIALIZER_INFO_ENTRY = namedtuple(
"_SERIALIZER_INFO_ENTRY", ["module_path", "file_mode"]
)
_SERIALIZER_INFO: Dict[SerializersSupported, _SERIALIZER_INFO_ENTRY] = {
SerializersSupported.pickle: _SERIALIZER_INFO_ENTRY("pickle", "b"),
SerializersSupported.dill: _SERIALIZER_INFO_ENTRY("dill", "b"),
SerializersSupported.yaml: _SERIALIZER_INFO_ENTRY("autora.serializer.yaml_", ""),
}

# Import those serializers which are actually importable
_AVAILABLE_SERIALIZER_INFO = dict()
_LOADED_SERIALIZER_DEF = namedtuple(
"_LOADED_SERIALIZER_DEF", ["name", "module", "file_mode"]
)
for serializer_enum in SerializersSupported:
serializer_info = _SERIALIZER_INFO[serializer_enum]
try:
module = importlib.import_module(serializer_info.module_path)
except ImportError:
_logger.info(f"serializer {serializer_info.module_path} not available")
continue
_AVAILABLE_SERIALIZER_INFO[serializer_enum] = _LOADED_SERIALIZER_DEF(
serializer_info.module_path, module, serializer_info.file_mode
)

# Set the default serializer for the package
default_serializer = SerializersSupported.pickle

# A dictionary to handle lazy loading of the serializers
_LOADED_SERIALIZERS: Dict[SerializersSupported, LOADED_SERIALIZER] = dict()


def load_serializer(serializer: SerializersSupported) -> LOADED_SERIALIZER:
"""Load"""

try:
serializer_def = _LOADED_SERIALIZERS[serializer]

except KeyError:
serializer_info = _SERIALIZER_INFO[serializer]
module_ = importlib.import_module(serializer_info.fully_qualified_module_name)
serializer_def = LOADED_SERIALIZER(module_, serializer_info.file_mode)
_LOADED_SERIALIZERS[serializer] = serializer_def

def _get_serializer_mode(
serializer: SerializersSupported,
interface: Literal["load", "dump", "loads", "dumps"],
) -> Tuple[Callable, str]:
serializer_def = _AVAILABLE_SERIALIZER_INFO[serializer]
module = serializer_def.module
function = getattr(module, interface)
file_mode = serializer_def.file_mode
return function, file_mode
return serializer_def


def load_state(
path: Optional[pathlib.Path],
loader: SerializersSupported = default_serializer,
) -> Union[State, None]:
"""Load a State object from a path."""
serializer = load_serializer(loader)
if path is not None:
load, file_mode = _get_serializer_mode(loader, "load")
_logger.debug(f"load_state: loading from {path=}")
with open(path, f"r{file_mode}") as f:
state_ = load(f)
with open(path, f"r{serializer.file_mode}") as f:
state_ = serializer.module.load(f)
else:
_logger.debug(f"load_state: {path=} -> returning None")
state_ = None
Expand All @@ -80,14 +85,13 @@ def dump_state(
dumper: SerializersSupported = default_serializer,
) -> None:
"""Write a State object to a path."""
serializer = load_serializer(dumper)
if path is not None:
dump, file_mode = _get_serializer_mode(dumper, "dump")
_logger.debug(f"dump_state: dumping to {path=}")
path.parent.mkdir(parents=True, exist_ok=True)
with open(path, f"w{file_mode}") as f:
dump(state_, f)
with open(path, f"w{serializer.file_mode}") as f:
serializer.module.dump(state_, f)
else:
dumps, _ = _get_serializer_mode(dumper, "dumps")
_logger.debug(f"dump_state: {path=} so writing to stdout")
print(dumps(state_))
print(serializer.module.dumps(state_))
return
10 changes: 2 additions & 8 deletions tests/test_serializer.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,23 +5,17 @@

import hypothesis.strategies as st

from autora.serializer import _AVAILABLE_SERIALIZER_INFO, SerializersSupported
from autora.serializer import SerializersSupported, load_serializer

logger = logging.getLogger(__name__)


# Define an ordered list of serializers we're going to test.
# We use the same order as the SerializersSupported Enum.
AVAILABLE_SERIALIZERS = st.sampled_from(
[
_AVAILABLE_SERIALIZER_INFO[k]
for k in SerializersSupported
if k in _AVAILABLE_SERIALIZER_INFO
]
[load_serializer(s) for s in SerializersSupported]
)



@st.composite
def serializer_dump_load_string_strategy(draw):
"""Strategy returns a function which dumps an object and reloads it via a bytestream."""
Expand Down

0 comments on commit 0f139c5

Please sign in to comment.