-
Notifications
You must be signed in to change notification settings - Fork 47
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Adds the hypothesis strategies for our objective classes.
- Loading branch information
Showing
13 changed files
with
271 additions
and
97 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,33 @@ | ||
"""Hypothesis strategies for objectives.""" | ||
|
||
import hypothesis.strategies as st | ||
|
||
from baybe.objectives.desirability import DesirabilityObjective | ||
from baybe.objectives.enum import Scalarizer | ||
from baybe.objectives.single import SingleTargetObjective | ||
|
||
from ..hypothesis_strategies.targets import numerical_targets | ||
from ..hypothesis_strategies.utils import intervals as st_intervals | ||
|
||
|
||
def single_target_objectives(): | ||
"""Generate :class:`baybe.objectives.single.SingleTargetObjective`.""" | ||
return st.builds(SingleTargetObjective, target=numerical_targets()) | ||
|
||
|
||
@st.composite | ||
def desirability_objectives(draw: st.DrawFn): | ||
"""Generate :class:`baybe.objectives.desirability.DesirabilityObjective`.""" | ||
intervals = st_intervals(exclude_fully_unbounded=True, exclude_half_bounded=True) | ||
targets = draw( | ||
st.lists(numerical_targets(intervals), min_size=2, unique_by=lambda t: t.name) | ||
) | ||
weights = draw( | ||
st.lists( | ||
st.floats(min_value=0.0, exclude_min=True), | ||
min_size=len(targets), | ||
max_size=len(targets), | ||
) | ||
) | ||
scalarizer = draw(st.sampled_from(Scalarizer)) | ||
return DesirabilityObjective(targets, weights, scalarizer) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,32 +1,45 @@ | ||
"""Hypothesis strategies for targets.""" | ||
|
||
from typing import Optional | ||
|
||
import hypothesis.strategies as st | ||
|
||
from baybe.targets.enum import TargetMode | ||
from baybe.targets.numerical import _VALID_TRANSFORMATIONS, NumericalTarget | ||
from baybe.utils.interval import Interval | ||
|
||
from .utils import interval | ||
from .utils import intervals as st_intervals | ||
|
||
target_name = st.text(min_size=1) | ||
"""A strategy that generates target names.""" | ||
|
||
|
||
@st.composite | ||
def numerical_target(draw: st.DrawFn): | ||
"""Generate :class:`baybe.targets.numerical.NumericalTarget`.""" | ||
def numerical_targets( | ||
draw: st.DrawFn, bounds_strategy: Optional[st.SearchStrategy[Interval]] = None | ||
): | ||
"""Generate :class:`baybe.targets.numerical.NumericalTarget`. | ||
Args: | ||
draw: Hypothesis draw object. | ||
bounds_strategy: An optional strategy for generating the target bounds. | ||
Returns: | ||
_type_: _description_ | ||
""" | ||
name = draw(target_name) | ||
mode = draw(st.sampled_from(TargetMode)) | ||
bounds = draw( | ||
interval( | ||
if bounds_strategy is None: | ||
bounds_strategy = st_intervals( | ||
exclude_half_bounded=True, exclude_fully_unbounded=mode is TargetMode.MATCH | ||
) | ||
) | ||
bounds = draw(bounds_strategy) | ||
transformation = draw(st.sampled_from(_VALID_TRANSFORMATIONS[mode])) | ||
|
||
return NumericalTarget( | ||
name=name, mode=mode, bounds=bounds, transformation=transformation | ||
) | ||
|
||
|
||
target = numerical_target() | ||
targets = numerical_targets() | ||
"""A strategy that generates targets.""" |
Oops, something went wrong.