Skip to content

Commit

Permalink
Make interval hypothesis strategy more efficient
Browse files Browse the repository at this point in the history
  • Loading branch information
AdrianSosic committed Apr 10, 2024
1 parent c56436c commit 689dff5
Showing 1 changed file with 48 additions and 16 deletions.
64 changes: 48 additions & 16 deletions tests/hypothesis_strategies/utils.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,21 @@
"""Hypothesis strategies for generating utility objects."""

from enum import Enum, auto

import hypothesis.extra.numpy as hnp
import hypothesis.strategies as st
from hypothesis import assume

from baybe.utils.interval import Interval


class IntervalType(Enum):
"""The possible types of an interval on the real number line."""

FULLY_UNBOUNDED = auto()
HALF_BOUNDED = auto()
BOUNDED = auto()


@st.composite
def intervals(
draw: st.DrawFn,
Expand All @@ -19,18 +29,40 @@ def intervals(
(exclude_bounded, exclude_half_bounded, exclude_fully_unbounded)
), "At least one Interval type must be allowed."

# Create interval from ordered pair of floats
bounds = (
st.tuples(st.floats(), st.floats()).map(sorted).filter(lambda x: x[0] < x[1])
)
interval = Interval.create(draw(bounds))

# Filter excluded intervals
if exclude_bounded:
assume(not interval.is_bounded)
if exclude_half_bounded:
assume(not interval.is_half_bounded)
if exclude_fully_unbounded:
assume(not interval.is_fully_unbounded)

return interval
# Draw the interval type from the allowed types
type_gate = {
IntervalType.FULLY_UNBOUNDED: not exclude_fully_unbounded,
IntervalType.HALF_BOUNDED: not exclude_half_bounded,
IntervalType.BOUNDED: not exclude_bounded,
}
allowed_types = [t for t, b in type_gate.items() if b]
interval_type = draw(st.sampled_from(allowed_types))

# A strategy producing finite floats
ffloats = st.floats(allow_infinity=False, allow_nan=False)

# Draw the bounds depending on the interval type
if interval_type is IntervalType.FULLY_UNBOUNDED:
bounds = (None, None)
elif interval_type is IntervalType.HALF_BOUNDED:
bounds = draw(
st.sampled_from(
[
(None, draw(ffloats)),
(draw(ffloats), None),
]
)
)
elif interval_type is IntervalType.BOUNDED:
bounds = draw(
hnp.arrays(
dtype=float,
shape=(2,),
elements=ffloats,
unique=True,
).map(sorted)
)
else:
raise RuntimeError("This line should be unreachable.")

return Interval.create(bounds)

0 comments on commit 689dff5

Please sign in to comment.