diff --git a/src/autora/serializer/__init__.py b/src/autora/serializer/__init__.py index d43e9725..a36f7687 100644 --- a/src/autora/serializer/__init__.py +++ b/src/autora/serializer/__init__.py @@ -1,14 +1,30 @@ +"""A submodule which handles importing supported serializers.""" + import importlib import logging import pathlib from collections import namedtuple from enum import Enum -from typing import Callable, Dict, Literal, Optional, Tuple, Union +from typing import Dict, Optional, Union from autora.state import State _logger = logging.getLogger(__name__) +# Developer notes: +# Add a new serializer by: +# - Including its name in the SerializersSupported Enum. This {name} will be available on the +# command line: `python -m autora.workflow --serializer {name}` +# - Adding the basic data about it to the _SERIALIZER_INFO dictionary. This should include the +# fully qualified name under which it can be imported, and which file mode it expects: +# "" for regular [i.e. files will be opened as `open(filename, "w")` for writing] +# "b" for binary [i.e. 
files will be opened as `open(filename, "wb")` for writing] + +_SERIALIZER_INFO_ENTRY = namedtuple( + "_SERIALIZER_INFO_ENTRY", ["fully_qualified_module_name", "file_mode"] +) +LOADED_SERIALIZER = namedtuple("LOADED_SERIALIZER", ["module", "file_mode"]) + class SerializersSupported(str, Enum): """Listing of allowed serializers.""" @@ -19,43 +35,32 @@ class SerializersSupported(str, Enum): # Dictionary of details about each serializer -_SERIALIZER_INFO_ENTRY = namedtuple( - "_SERIALIZER_INFO_ENTRY", ["module_path", "file_mode"] -) _SERIALIZER_INFO: Dict[SerializersSupported, _SERIALIZER_INFO_ENTRY] = { SerializersSupported.pickle: _SERIALIZER_INFO_ENTRY("pickle", "b"), SerializersSupported.dill: _SERIALIZER_INFO_ENTRY("dill", "b"), SerializersSupported.yaml: _SERIALIZER_INFO_ENTRY("autora.serializer.yaml_", ""), } -# Import those serializers which are actually importable -_AVAILABLE_SERIALIZER_INFO = dict() -_LOADED_SERIALIZER_DEF = namedtuple( - "_LOADED_SERIALIZER_DEF", ["name", "module", "file_mode"] -) -for serializer_enum in SerializersSupported: - serializer_info = _SERIALIZER_INFO[serializer_enum] - try: - module = importlib.import_module(serializer_info.module_path) - except ImportError: - _logger.info(f"serializer {serializer_info.module_path} not available") - continue - _AVAILABLE_SERIALIZER_INFO[serializer_enum] = _LOADED_SERIALIZER_DEF( - serializer_info.module_path, module, serializer_info.file_mode - ) - +# Set the default serializer for the package default_serializer = SerializersSupported.pickle +# A dictionary to handle lazy loading of the serializers +_LOADED_SERIALIZERS: Dict[SerializersSupported, LOADED_SERIALIZER] = dict() + + +def load_serializer(serializer: SerializersSupported) -> LOADED_SERIALIZER: + """Lazily import a serializer, cache it, and return its module and file mode.""" + + try: + serializer_def = _LOADED_SERIALIZERS[serializer] + + except KeyError: + serializer_info = _SERIALIZER_INFO[serializer] + module_ = importlib.import_module(serializer_info.fully_qualified_module_name) + serializer_def = 
LOADED_SERIALIZER(module_, serializer_info.file_mode) + _LOADED_SERIALIZERS[serializer] = serializer_def -def _get_serializer_mode( - serializer: SerializersSupported, - interface: Literal["load", "dump", "loads", "dumps"], -) -> Tuple[Callable, str]: - serializer_def = _AVAILABLE_SERIALIZER_INFO[serializer] - module = serializer_def.module - function = getattr(module, interface) - file_mode = serializer_def.file_mode - return function, file_mode + return serializer_def def load_state( @@ -63,11 +68,11 @@ def load_state( loader: SerializersSupported = default_serializer, ) -> Union[State, None]: """Load a State object from a path.""" + serializer = load_serializer(loader) if path is not None: - load, file_mode = _get_serializer_mode(loader, "load") _logger.debug(f"load_state: loading from {path=}") - with open(path, f"r{file_mode}") as f: - state_ = load(f) + with open(path, f"r{serializer.file_mode}") as f: + state_ = serializer.module.load(f) else: _logger.debug(f"load_state: {path=} -> returning None") state_ = None @@ -80,14 +85,13 @@ def dump_state( dumper: SerializersSupported = default_serializer, ) -> None: """Write a State object to a path.""" + serializer = load_serializer(dumper) if path is not None: - dump, file_mode = _get_serializer_mode(dumper, "dump") _logger.debug(f"dump_state: dumping to {path=}") path.parent.mkdir(parents=True, exist_ok=True) - with open(path, f"w{file_mode}") as f: - dump(state_, f) + with open(path, f"w{serializer.file_mode}") as f: + serializer.module.dump(state_, f) else: - dumps, _ = _get_serializer_mode(dumper, "dumps") _logger.debug(f"dump_state: {path=} so writing to stdout") - print(dumps(state_)) + print(serializer.module.dumps(state_)) return diff --git a/tests/test_serializer.py b/tests/test_serializer.py index 247f3f11..52b758f4 100644 --- a/tests/test_serializer.py +++ b/tests/test_serializer.py @@ -5,23 +5,17 @@ import hypothesis.strategies as st -from autora.serializer import _AVAILABLE_SERIALIZER_INFO, 
SerializersSupported +from autora.serializer import SerializersSupported, load_serializer logger = logging.getLogger(__name__) # Define an ordered list of serializers we're going to test. -# We use the same order as the SerializersSupported Enum. AVAILABLE_SERIALIZERS = st.sampled_from( - [ - _AVAILABLE_SERIALIZER_INFO[k] - for k in SerializersSupported - if k in _AVAILABLE_SERIALIZER_INFO - ] + [load_serializer(s) for s in SerializersSupported] ) - @st.composite def serializer_dump_load_string_strategy(draw): """Strategy returns a function which dumps an object and reloads it via a bytestream."""