Config and serialization enhancement (#86)
This PR significantly enhances our serialization machinery:
* Search spaces can now be deserialized using a selected classmethod.
Currently, this is enabled by manually registering hooks for our search
space classes. However, if we add more constructors to other classes in
the future, we can consider automating the process using hook predicate
functions / factories. Perhaps, there will even be a dedicated
[strategy](python-attrs/cattrs#489) available
for this mechanism.
* Due to this change, there is no more separate "config converter"
required, since creation from config can now happen via the regular
converter. In fact, in future we can even think about deprecating our
"from_config" approach, because there really is no more such thing as a
"config" – it's just a "regular" JSON string that goes through the
default converter.
* The serialization functionality now sits in its own subpackage.
* Missing serialization mixins have been added to `Interval` and the two
subspace classes.
* Added a basic serialization roundtrip test for dataframes and a
corresponding hypothesis strategy.
* Changed binarization of dataframes to use regular pickle due to some
edge cases with the previous `parquet` approach that were detected
through the above test.
* Added a `SearchSpace.from_dataframe` convenience constructor for
consistency, which also helps to simplify campaign configs.
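
The idea behind the first bullet (deserializing through a selected classmethod) can be illustrated with a minimal, standard-library-only sketch. It is not the hook registered in this PR: `structure_with_constructor` and `ToySpace` are hypothetical names, and the real implementation goes through the `cattrs` converter. The sketch only shows the dispatch on an optional `constructor` entry.

```python
from dataclasses import dataclass
from itertools import product
from typing import Any, Dict, List, Type, TypeVar

T = TypeVar("T")


def structure_with_constructor(cls: Type[T], data: Dict[str, Any]) -> T:
    """Build ``cls`` from a dict, optionally via a named classmethod.

    Hypothetical helper: if the dict has a "constructor" entry, the
    classmethod of that name is called with the remaining entries;
    otherwise the regular initializer is used.
    """
    data = dict(data)  # copy so the caller's dict is not mutated
    name = data.pop("constructor", None)
    if name is None:
        return cls(**data)
    constructor = getattr(cls, name, None)
    if not callable(constructor):
        raise ValueError(f"'{cls.__name__}' has no constructor named '{name}'.")
    return constructor(**data)


@dataclass
class ToySpace:
    """Hypothetical stand-in for a search space class."""

    points: List[tuple]

    @classmethod
    def from_product(cls, parameters: Dict[str, List[Any]]) -> "ToySpace":
        # Cartesian product of all parameter values
        return cls(points=list(product(*parameters.values())))


space = structure_with_constructor(
    ToySpace,
    {"constructor": "from_product", "parameters": {"x": [1, 2], "y": ["a", "b"]}},
)
print(space.points)
```

Without a `constructor` entry, the same helper falls back to the plain initializer, which is why a separate "config converter" is no longer needed in this scheme.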
AdrianSosic authored Jan 24, 2024
2 parents a76d63a + abbc468 commit 4699927
Showing 37 changed files with 750 additions and 303 deletions.
10 changes: 9 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -9,10 +9,13 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Target enums
- `mypy` for targets and intervals
- Tests for code blocks in README and user guides
- `hypothesis` strategies and tests for targets and intervals
- `hypothesis` strategies and roundtrip tests for targets, intervals, and dataframes
- De-/serialization of target subclasses via base class
- Docs building check now part of CI
- Automatic formatting checks for code examples in documentation
- Deserialization of classes with classmethod constructors can now be customized
by providing an optional `constructor` field.
- `SearchSpace.from_dataframe` convenience constructor

### Changed
- Renamed `bounds_transform_func` target attribute to `transformation`
@@ -23,6 +26,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- `DiscreteCustomConstraint` validator now expects data frame instead of series
- `ignore_example` flag builds but does not execute examples when building documentation
- New user guide versions for campaigns, targets and objectives
- Binarization of dataframes now happens via pickling

### Fixed
- Wrong use of `tolerance` argument in constraints user guide
@@ -31,6 +35,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Use pydoclint as flake8 plugin and not as a stand-alone linter
- Margins in documentation for desktop and mobile version
- `Interval`s can now also be deserialized from a bounds iterable
- `SubspaceDiscrete` and `SubspaceContinuous` now have de-/serialization methods

### Removed
- Conda install instructions and version badge
@@ -39,6 +44,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
### Deprecations
- `Interval.is_finite` replaced with `Interval.is_bounded`
- Specifying target configs without explicit type information is deprecated
- Specifying parameters/constraints at the top level of a campaign configuration JSON is
deprecated. Instead, an explicit `searchspace` field must be provided with an optional
`constructor` entry.

## [0.7.1] - 2023-12-07
### Added
46 changes: 12 additions & 34 deletions baybe/campaign.py
@@ -13,9 +13,9 @@
from baybe.parameters.base import Parameter
from baybe.searchspace.core import (
SearchSpace,
structure_searchspace_from_config,
validate_searchspace_from_config,
)
from baybe.serialization import SerialMixin, converter
from baybe.strategies import TwoPhaseStrategy
from baybe.strategies.base import Strategy
from baybe.targets import NumericalTarget
@@ -25,13 +25,6 @@
telemetry_record_value,
)
from baybe.utils import eq_dataframe
from baybe.utils.serialization import SerialMixin, converter

# Converter for config deserialization
_config_converter = converter.copy()
_config_converter.register_structure_hook(
SearchSpace, structure_searchspace_from_config
)

# Converter for config validation
_validation_converter = converter.copy()
@@ -119,31 +112,14 @@ def from_config(cls, config_json: str) -> Campaign:
Returns:
The constructed campaign.
"""
config = json.loads(config_json)
config["searchspace"] = {
"parameters": config.pop("parameters"),
"constraints": config.pop("constraints", None),
}
return _config_converter.structure(config, Campaign)
from baybe.deprecation import compatibilize_config

@classmethod
def to_config(cls) -> str:
"""Extract the configuration of the campaign as JSON string.
Note: This is not yet implemented. Use
:func:`baybe.utils.serialization.SerialMixin.to_json` instead
config = json.loads(config_json)

Returns:
The configuration as JSON string.
# Temporarily enable backward compatibility
config = compatibilize_config(config)

Raises:
NotImplementedError: When trying to use this function.
"""
# TODO: Ideally, this should extract a "minimal" configuration, that is,
# default values should not be exported, which cattrs supports via the
# 'omit_if_default' option. Can be Implemented once the converter structure
# has been cleaned up.
raise NotImplementedError()
return converter.structure(config, Campaign)

@classmethod
def validate_config(cls, config_json: str) -> None:
@@ -152,11 +128,13 @@ def validate_config(cls, config_json: str) -> None:
Args:
config_json: The JSON that should be validated.
"""
from baybe.deprecation import compatibilize_config

config = json.loads(config_json)
config["searchspace"] = {
"parameters": config.pop("parameters"),
"constraints": config.pop("constraints", None),
}

# Temporarily enable backward compatibility
config = compatibilize_config(config)

_validation_converter.structure(config, Campaign)

def add_measurements(self, data: pd.DataFrame) -> None:
8 changes: 5 additions & 3 deletions baybe/constraints/base.py
@@ -11,13 +11,15 @@

from baybe.constraints.conditions import Condition
from baybe.parameters import NumericalContinuousParameter
from baybe.utils import (
DTypeFloatTorch,
from baybe.serialization import (
SerialMixin,
converter,
get_base_structure_hook,
unstructure_base,
)
from baybe.utils.serialization import converter
from baybe.utils import (
DTypeFloatTorch,
)


@define
2 changes: 1 addition & 1 deletion baybe/constraints/conditions.py
@@ -11,7 +11,7 @@
from funcy import rpartial
from numpy.typing import ArrayLike

from baybe.utils import SerialMixin
from baybe.serialization import SerialMixin


def _is_not_close(x: ArrayLike, y: ArrayLike, rtol: float, atol: float) -> np.ndarray:
4 changes: 2 additions & 2 deletions baybe/constraints/discrete.py
@@ -13,12 +13,12 @@
ThresholdCondition,
_valid_logic_combiners,
)
from baybe.utils import Dummy
from baybe.utils.serialization import (
from baybe.serialization import (
block_deserialization_hook,
block_serialization_hook,
converter,
)
from baybe.utils import Dummy


@define
46 changes: 46 additions & 0 deletions baybe/deprecation.py
@@ -17,3 +17,49 @@ def __attrs_pre_init__(self):
"Please use the 'Campaign' class instead.",
DeprecationWarning,
)


def compatibilize_config(config: dict) -> dict:
    """Turn a legacy-format config into the new format."""
    if "parameters" not in config:
        return config

    if "searchspace" in config:
        raise ValueError(
            "Something is wrong with your campaign config. "
            "It neither adheres to the deprecated nor the new format."
        )

    warnings.warn(
        '''
        Specifying parameters/constraints at the top level of the
        campaign configuration JSON is deprecated and will not be
        supported in future releases.
        Instead, use a dedicated "searchspace" field that can be
        used to customize the creation of the search space,
        offering the possibility to specify a desired constructor.
        To replicate the old behavior, use
        """
        ...
        "searchspace": {
            "constructor": "from_product",
            "parameters": <your parameter configuration>,
            "constraints": <your constraints configuration>
        }
        ...
        """
        For the available constructors and the parameters they expect,
        see `baybe.searchspace.core.SearchSpace`.''',
        UserWarning,
    )

    config = config.copy()
    config["searchspace"] = {
        "constructor": "from_product",
        "parameters": config.pop("parameters"),
        "constraints": config.pop("constraints", None),
    }

    return config
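
To make the effect of this conversion concrete, here is a small self-contained sketch of just the dictionary reshaping step, without the warning. The name `reshape_legacy_config` is hypothetical, and the parameter payload is a placeholder rather than a valid parameter specification.

```python
import json
from typing import Any, Dict


def reshape_legacy_config(config: Dict[str, Any]) -> Dict[str, Any]:
    """Move top-level parameters/constraints into a "searchspace" entry."""
    if "parameters" not in config:
        return config  # already in the new format, nothing to do
    if "searchspace" in config:
        raise ValueError("Config mixes the legacy and the new format.")

    config = config.copy()  # shallow copy: the caller's dict keeps its keys
    config["searchspace"] = {
        "constructor": "from_product",
        "parameters": config.pop("parameters"),
        "constraints": config.pop("constraints", None),
    }
    return config


# Placeholder legacy config; the entries are illustrative only
legacy = json.loads('{"parameters": [{"name": "Temperature"}], "objective": {}}')
upgraded = reshape_legacy_config(legacy)
print(json.dumps(upgraded, indent=2))
```

Because a config already in the new format is returned unchanged, running the conversion on every input is harmless, which is presumably why both `from_config` and `validate_config` can call it unconditionally.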
3 changes: 2 additions & 1 deletion baybe/objective.py
@@ -10,9 +10,10 @@
from attr import define, field
from attr.validators import deep_iterable, in_, instance_of, min_len

from baybe.serialization import SerialMixin
from baybe.targets.base import Target
from baybe.targets.numerical import NumericalTarget
from baybe.utils import SerialMixin, geom_mean
from baybe.utils import geom_mean


def _normalize_weights(weights: List[float]) -> List[float]:
8 changes: 6 additions & 2 deletions baybe/parameters/base.py
@@ -10,8 +10,12 @@
from cattrs.gen import override

from baybe.parameters.enum import ParameterEncoding
from baybe.utils import SerialMixin, get_base_structure_hook, unstructure_base
from baybe.utils.serialization import converter
from baybe.serialization import (
SerialMixin,
converter,
get_base_structure_hook,
unstructure_base,
)

# TODO: Reactive slots in all classes once cached_property is supported:
# https://github.com/python-attrs/attrs/issues/164
83 changes: 83 additions & 0 deletions baybe/parameters/utils.py
@@ -0,0 +1,83 @@
"""Parameter utilities."""

from typing import Any, Callable, Collection, Dict, List, Optional, TypeVar

import pandas as pd

from baybe.parameters.base import Parameter

_TParameter = TypeVar("_TParameter", bound=Parameter)


def get_parameters_from_dataframe(
    df: pd.DataFrame,
    factory: Callable[[str, Collection[Any]], _TParameter],
    parameters: Optional[List[_TParameter]] = None,
) -> List[_TParameter]:
    """Create a list of parameters from a dataframe.

    Returns one parameter for each column of the given dataframe. By default,
    the parameters are created using the provided factory, which takes the name
    of the column and its unique values as arguments. However, there is also
    the possibility to provide explicit parameter objects with names matching specific
    columns of the dataframe, to bypass the parameter factory creation for those
    columns. This allows finer control, for example, to specify custom parameter
    attributes (e.g. specific optional arguments) compared to what would be provided
    by the factory. Still, the pre-specified parameters are validated to ensure that
    they are compatible with the contents of the dataframe.

    Args:
        df: The dataframe from which to create the parameters.
        factory: A parameter factory, creating parameter objects for the columns
            from the column name and the unique column values.
        parameters: An optional list of parameter objects to bypass the factory
            creation for columns whose names match with the parameter names.

    Returns:
        The combined parameter list, containing both the (validated) pre-specified
        parameters and the parameters inferred from the dataframe.

    Raises:
        ValueError: If several parameters with identical names are provided.
        ValueError: If a parameter was specified for which no match was found.
    """
    # Turn the pre-specified parameters into a dict and check for duplicate names
    specified_params: Dict[str, _TParameter] = {}
    if parameters is not None:
        for param in parameters:
            if param.name in specified_params:
                raise ValueError(
                    f"You provided several parameters with the name '{param.name}'."
                )
            specified_params[param.name] = param

    # Try to find a parameter match for each dataframe column
    parameters = []
    for name, series in df.items():
        assert isinstance(
            name, str
        ), "The given dataframe must only contain string-valued column names."
        unique_values = series.unique()

        # If a match is found, assert that the values are in range
        if match := specified_params.pop(name, None):
            if not all(match.is_in_range(x) for x in unique_values):
                raise ValueError(
                    f"The dataframe column '{name}' contains the values "
                    f"{unique_values}, which are outside the range of {match}."
                )
            parameters.append(match)

        # Otherwise, create a new parameter using the factory
        else:
            param = factory(name, unique_values)
            parameters.append(param)

    # By now, all pre-specified parameters must have been used
    if specified_params:
        raise ValueError(
            f"For the parameter(s) {list(specified_params.keys())}, "
            f"no match could be found in the given dataframe."
        )

    return parameters
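
The matching logic described in the docstring can be tried out without pandas or baybe using the simplified stand-in below. It is a sketch under stated assumptions: `ToyParameter` and `parameters_from_columns` are hypothetical names, columns are passed as a plain dict of lists instead of a dataframe, and only the `name` attribute and `is_in_range` method of the real parameter interface are mimicked.

```python
from dataclasses import dataclass
from typing import Any, Callable, Dict, Iterable, List, Optional, Sequence


@dataclass(frozen=True)
class ToyParameter:
    """Hypothetical stand-in for a parameter class (not part of baybe)."""

    name: str
    values: tuple

    def is_in_range(self, item: Any) -> bool:
        return item in self.values


def parameters_from_columns(
    columns: Dict[str, Sequence[Any]],
    factory: Callable[[str, List[Any]], ToyParameter],
    parameters: Optional[Iterable[ToyParameter]] = None,
) -> List[ToyParameter]:
    """Return one parameter per column, preferring pre-specified ones."""
    # Index the pre-specified parameters by name, rejecting duplicates
    specified: Dict[str, ToyParameter] = {}
    for param in parameters or []:
        if param.name in specified:
            raise ValueError(f"Duplicate parameter name '{param.name}'.")
        specified[param.name] = param

    result: List[ToyParameter] = []
    for name, column in columns.items():
        unique_values = list(dict.fromkeys(column))  # order-preserving unique
        match = specified.pop(name, None)
        if match is None:
            # No pre-specified parameter: infer one via the factory
            result.append(factory(name, unique_values))
        elif all(match.is_in_range(x) for x in unique_values):
            result.append(match)
        else:
            raise ValueError(f"Column '{name}' has values outside {match}.")

    # Every pre-specified parameter must correspond to some column
    if specified:
        raise ValueError(f"No matching column for {list(specified)}.")
    return result


columns = {"solvent": ["water", "ethanol", "water"], "temp": [20, 40, 20]}
preset = ToyParameter("temp", (20, 40, 60))  # wider range than observed
params = parameters_from_columns(
    columns,
    factory=lambda name, values: ToyParameter(name, tuple(values)),
    parameters=[preset],
)
print(params)
```

Here the `temp` column reuses the pre-specified parameter, including the value 60 that never occurs in the data, while `solvent` is inferred by the factory from the observed values.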
2 changes: 1 addition & 1 deletion baybe/recommenders/base.py
@@ -8,7 +8,7 @@

from baybe.exceptions import NotEnoughPointsLeftError
from baybe.searchspace import SearchSpace, SearchSpaceType
from baybe.utils.serialization import (
from baybe.serialization import (
converter,
get_base_structure_hook,
unstructure_base,
2 changes: 0 additions & 2 deletions baybe/searchspace/__init__.py
@@ -4,13 +4,11 @@
from baybe.searchspace.core import (
SearchSpace,
SearchSpaceType,
structure_searchspace_from_config,
validate_searchspace_from_config,
)
from baybe.searchspace.discrete import SubspaceDiscrete

__all__ = [
"structure_searchspace_from_config",
"validate_searchspace_from_config",
"SearchSpace",
"SearchSpaceType",