Skip to content

Commit

Permalink
Hypothesis for objectives (#197)
Browse files Browse the repository at this point in the history
Adds the hypothesis strategies for our objective classes.
  • Loading branch information
AdrianSosic authored Apr 11, 2024
2 parents 2ce161f + 05248e1 commit 70f50b9
Show file tree
Hide file tree
Showing 13 changed files with 271 additions and 97 deletions.
4 changes: 2 additions & 2 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Hypothesis strategies for acquisition functions
- `Kernel` base class allowing to specify kernels
- `MaternKernel` class can be chosen for GP surrogates
- `hypothesis` strategies and roundtrip test for kernels
- `hypothesis` strategies and roundtrip test for kernels and objectives

### Changed
- `torch` numeric types are now loaded lazily
Expand Down Expand Up @@ -171,7 +171,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
### Fixed
- Wrong use of `tolerance` argument in constraints user guide
- Errors with generics and type aliases in documentation
- Deduplication bug in substance_data hypothesis
- Deduplication bug in substance_data `hypothesis` strategy
- Use pydoclint as flake8 plugin and not as a stand-alone linter
- Margins in documentation for desktop and mobile version
- `Interval`s can now also be deserialized from a bounds iterable
Expand Down
5 changes: 4 additions & 1 deletion baybe/objectives/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,11 @@
from baybe.serialization.mixin import SerialMixin
from baybe.targets.base import Target

# TODO: Reactivate slots in all classes once cached_property is supported:
# https://github.com/python-attrs/attrs/issues/164

@define(frozen=True)

@define(frozen=True, slots=False)
class Objective(ABC, SerialMixin):
"""Abstract base class for all objectives."""

Expand Down
49 changes: 23 additions & 26 deletions baybe/objectives/desirability.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,14 @@
"""Functionality for desirability objectives."""

from collections.abc import Sequence
from functools import partial
from functools import cached_property, partial
from typing import Callable

import cattrs
import numpy as np
import numpy.typing as npt
import pandas as pd
from attrs import define, field
from attrs.validators import deep_iterable, instance_of, min_len
from attrs.validators import deep_iterable, gt, instance_of, min_len
from typing_extensions import TypeGuard

from baybe.objectives.base import Objective
Expand All @@ -20,24 +19,6 @@
from baybe.utils.numerical import geom_mean


def _normalize_weights(weights: Sequence[float]) -> tuple[float, ...]:
    """Rescale a collection of weights so that they add up to 1.

    Args:
        weights: The un-normalized weights.

    Raises:
        ValueError: If any of the weights is non-positive.

    Returns:
        The normalized weights.
    """
    # Structure to a float tuple first so that the array has a numeric dtype.
    structured = cattrs.structure(weights, tuple[float, ...])
    values = np.asarray(structured)

    # Phrased as "not all positive" so that NaN entries are rejected as well.
    all_positive = np.all(values > 0.0)
    if not all_positive:
        raise ValueError("All weights must be strictly positive.")

    total = values.sum()
    return tuple(values / total)


def _is_all_numerical_targets(
x: tuple[Target, ...], /
) -> TypeGuard[tuple[NumericalTarget, ...]]:
Expand Down Expand Up @@ -79,17 +60,21 @@ def scalarize(
return func(values, weights=weights)


@define(frozen=True)
@define(frozen=True, slots=False)
class DesirabilityObjective(Objective):
"""An objective scalarizing multiple targets using desirability values."""

targets: tuple[Target, ...] = field(
_targets: tuple[Target, ...] = field(
converter=to_tuple,
validator=[min_len(2), deep_iterable(member_validator=instance_of(Target))], # type: ignore[type-abstract]
alias="targets",
)
"The targets considered by the objective."

weights: tuple[float, ...] = field(converter=_normalize_weights)
weights: tuple[float, ...] = field(
converter=lambda w: cattrs.structure(w, tuple[float, ...]),
validator=deep_iterable(member_validator=gt(0.0)),
)
"""The weights to balance the different targets.
By default, all targets are considered equally important."""

Expand All @@ -101,13 +86,15 @@ def _default_weights(self) -> tuple[float, ...]:
"""Create unit weights for all targets."""
return tuple(1.0 for _ in range(len(self.targets)))

@targets.validator
@_targets.validator
def _validate_targets(self, _, targets) -> None: # noqa: DOC101, DOC103
if not _is_all_numerical_targets(targets):
raise TypeError(
f"'{self.__class__.__name__}' currently only supports targets "
f"of type '{NumericalTarget.__name__}'."
)
if len({t.name for t in targets}) != len(targets):
raise ValueError("All target names must be unique.")
if not all(target._is_transform_normalized for target in targets):
raise ValueError(
"All targets must have normalized computational representations to "
Expand All @@ -123,6 +110,16 @@ def _validate_weights(self, _, weights) -> None: # noqa: DOC101, DOC103
f"Specified number of targets: {lt}. Specified number of weights: {lw}."
)

@property
def targets(self) -> tuple[Target, ...]: # noqa: D102
# See base class.
return self._targets

@cached_property
def _normalized_weights(self) -> np.ndarray:
"""The normalized target weights."""
return np.asarray(self.weights) / np.sum(self.weights)

def __str__(self) -> str:
start_bold = "\033[1m"
end_bold = "\033[0m"
Expand All @@ -147,7 +144,7 @@ def transform(self, data: pd.DataFrame) -> pd.DataFrame: # noqa: D102
transformed[target.name] = target.transform(data[[target.name]])

# Scalarize the transformed targets into desirability values
vals = scalarize(transformed.values, self.scalarizer, self.weights)
vals = scalarize(transformed.values, self.scalarizer, self._normalized_weights)

# Store the total desirability in a dataframe column
transformed = pd.DataFrame({"Desirability": vals}, index=transformed.index)
Expand Down
2 changes: 1 addition & 1 deletion baybe/objectives/single.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from baybe.targets.base import Target


@define(frozen=True)
@define(frozen=True, slots=False)
class SingleTargetObjective(Objective):
"""An objective focusing on a single target."""

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from baybe.parameters.categorical import TaskParameter
from baybe.searchspace import SearchSpace, SubspaceContinuous
from baybe.searchspace.discrete import SubspaceDiscrete
from tests.hypothesis_strategies.parameters import numerical_discrete_parameter
from tests.hypothesis_strategies.parameters import numerical_discrete_parameters

# Discrete inputs for testing
s_x = pd.Series([1, 2, 3], name="x")
Expand Down Expand Up @@ -107,7 +107,7 @@ def test_searchspace_creation_from_dataframe(df, parameters, expected):
@pytest.mark.parametrize("boundary_only", (False, True))
@given(
parameters=st.lists(
numerical_discrete_parameter(min_value=0.0, max_value=1.0),
numerical_discrete_parameters(min_value=0.0, max_value=1.0),
min_size=1,
max_size=5,
unique_by=lambda x: x.name,
Expand Down
33 changes: 33 additions & 0 deletions tests/hypothesis_strategies/objectives.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
"""Hypothesis strategies for objectives."""

import hypothesis.strategies as st

from baybe.objectives.desirability import DesirabilityObjective
from baybe.objectives.enum import Scalarizer
from baybe.objectives.single import SingleTargetObjective

from ..hypothesis_strategies.targets import numerical_targets
from ..hypothesis_strategies.utils import intervals as st_intervals


def single_target_objectives():
    """Generate :class:`baybe.objectives.single.SingleTargetObjective`."""
    # The objective's only constructor input here is a numerical target.
    target_strategy = numerical_targets()
    return st.builds(SingleTargetObjective, target=target_strategy)


@st.composite
def desirability_objectives(draw: st.DrawFn):
    """Generate :class:`baybe.objectives.desirability.DesirabilityObjective`.

    Args:
        draw: Hypothesis draw object.

    Returns:
        A desirability objective built from at least two uniquely named numerical
        targets (drawn with bounds that exclude half-bounded and fully unbounded
        intervals), one finite, strictly positive weight per target, and a
        scalarizer sampled from :class:`baybe.objectives.enum.Scalarizer`.
    """
    intervals = st_intervals(exclude_fully_unbounded=True, exclude_half_bounded=True)
    targets = draw(
        st.lists(numerical_targets(intervals), min_size=2, unique_by=lambda t: t.name)
    )
    # Infinite weights would satisfy a "greater than zero" check but become NaN
    # as soon as they are divided by their (infinite) sum during normalization,
    # so only finite values are drawn.
    weights = draw(
        st.lists(
            st.floats(min_value=0.0, exclude_min=True, allow_infinity=False),
            min_size=len(targets),
            max_size=len(targets),
        )
    )
    scalarizer = draw(st.sampled_from(Scalarizer))
    return DesirabilityObjective(targets, weights, scalarizer)
50 changes: 25 additions & 25 deletions tests/hypothesis_strategies/parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,15 +19,15 @@
from baybe.parameters.substance import SubstanceEncoding, SubstanceParameter
from baybe.utils.numerical import DTypeFloatNumpy

from .utils import interval
from .utils import intervals

decorrelation = st.one_of(
decorrelations = st.one_of(
st.booleans(),
st.floats(min_value=0.0, max_value=1.0, exclude_min=True, exclude_max=True),
)
"""A strategy that generates decorrelation settings."""

parameter_name = st.text(min_size=1)
parameter_names = st.text(min_size=1)
"""A strategy that generates parameter names."""

categories = st.lists(st.text(min_size=1), min_size=2, unique=True)
Expand Down Expand Up @@ -76,13 +76,13 @@ def custom_descriptors(draw: st.DrawFn):


@st.composite
def numerical_discrete_parameter(
def numerical_discrete_parameters(
draw: st.DrawFn,
min_value: Optional[float] = None,
max_value: Optional[float] = None,
):
"""Generate :class:`baybe.parameters.numerical.NumericalDiscreteParameter`."""
name = draw(parameter_name)
name = draw(parameter_names)
values = draw(
st.lists(
st.floats(
Expand Down Expand Up @@ -111,26 +111,26 @@ def numerical_discrete_parameter(


@st.composite
def numerical_continuous_parameter(draw: st.DrawFn):
def numerical_continuous_parameters(draw: st.DrawFn):
"""Generate :class:`baybe.parameters.numerical.NumericalContinuousParameter`."""
name = draw(parameter_name)
bounds = draw(interval(exclude_half_bounded=True, exclude_fully_unbounded=True))
name = draw(parameter_names)
bounds = draw(intervals(exclude_half_bounded=True, exclude_fully_unbounded=True))
return NumericalContinuousParameter(name=name, bounds=bounds)


@st.composite
def categorical_parameter(draw: st.DrawFn):
def categorical_parameters(draw: st.DrawFn):
"""Generate :class:`baybe.parameters.categorical.CategoricalParameter`."""
name = draw(parameter_name)
name = draw(parameter_names)
values = draw(categories)
encoding = draw(st.sampled_from(CategoricalEncoding))
return CategoricalParameter(name=name, values=values, encoding=encoding)


@st.composite
def task_parameter(draw: st.DrawFn):
def task_parameters(draw: st.DrawFn):
"""Generate :class:`baybe.parameters.categorical.TaskParameter`."""
name = draw(parameter_name)
name = draw(parameter_names)
values = draw(categories)
active_values = draw(
st.lists(st.sampled_from(values), min_size=1, max_size=len(values), unique=True)
Expand All @@ -139,34 +139,34 @@ def task_parameter(draw: st.DrawFn):


@st.composite
def substance_parameter(draw: st.DrawFn):
def substance_parameters(draw: st.DrawFn):
"""Generate :class:`baybe.parameters.substance.SubstanceParameter`."""
name = draw(parameter_name)
name = draw(parameter_names)
data = draw(substance_data())
decorrelate = draw(decorrelation)
decorrelate = draw(decorrelations)
encoding = draw(st.sampled_from(SubstanceEncoding))
return SubstanceParameter(
name=name, data=data, decorrelate=decorrelate, encoding=encoding
)


@st.composite
def custom_parameter(draw: st.DrawFn):
def custom_parameters(draw: st.DrawFn):
"""Generate :class:`baybe.parameters.custom.CustomDiscreteParameter`."""
name = draw(parameter_name)
name = draw(parameter_names)
data = draw(custom_descriptors())
decorrelate = draw(decorrelation)
decorrelate = draw(decorrelations)
return CustomDiscreteParameter(name=name, data=data, decorrelate=decorrelate)


parameter = st.one_of(
parameters = st.one_of(
[
numerical_discrete_parameter(),
numerical_continuous_parameter(),
categorical_parameter(),
task_parameter(),
substance_parameter(),
custom_parameter(),
numerical_discrete_parameters(),
numerical_continuous_parameters(),
categorical_parameters(),
task_parameters(),
substance_parameters(),
custom_parameters(),
]
)
"""A strategy that generates parameters."""
27 changes: 20 additions & 7 deletions tests/hypothesis_strategies/targets.py
Original file line number Diff line number Diff line change
@@ -1,32 +1,45 @@
"""Hypothesis strategies for targets."""

from typing import Optional

import hypothesis.strategies as st

from baybe.targets.enum import TargetMode
from baybe.targets.numerical import _VALID_TRANSFORMATIONS, NumericalTarget
from baybe.utils.interval import Interval

from .utils import interval
from .utils import intervals as st_intervals

target_name = st.text(min_size=1)
"""A strategy that generates target names."""


@st.composite
def numerical_target(draw: st.DrawFn):
"""Generate :class:`baybe.targets.numerical.NumericalTarget`."""
def numerical_targets(
draw: st.DrawFn, bounds_strategy: Optional[st.SearchStrategy[Interval]] = None
):
"""Generate :class:`baybe.targets.numerical.NumericalTarget`.
Args:
draw: Hypothesis draw object.
bounds_strategy: An optional strategy for generating the target bounds.
Returns:
The generated numerical target.
"""
name = draw(target_name)
mode = draw(st.sampled_from(TargetMode))
bounds = draw(
interval(
if bounds_strategy is None:
bounds_strategy = st_intervals(
exclude_half_bounded=True, exclude_fully_unbounded=mode is TargetMode.MATCH
)
)
bounds = draw(bounds_strategy)
transformation = draw(st.sampled_from(_VALID_TRANSFORMATIONS[mode]))

return NumericalTarget(
name=name, mode=mode, bounds=bounds, transformation=transformation
)


target = numerical_target()
targets = numerical_targets()
"""A strategy that generates targets."""
Loading

0 comments on commit 70f50b9

Please sign in to comment.