Skip to content

Commit

Permalink
adding changes to docs
Browse files Browse the repository at this point in the history
  • Loading branch information
mmahsereci committed Apr 26, 2024
1 parent 695fa22 commit e24fb06
Show file tree
Hide file tree
Showing 7 changed files with 26 additions and 19 deletions.
6 changes: 6 additions & 0 deletions emukit/quadrature/loop/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,18 @@


from .bayesian_monte_carlo_loop import BayesianMonteCarlo # noqa: F401
from .bq_loop_state import QuadratureLoopState
from .bq_outer_loop import QuadratureOuterLoop
from .bq_stopping_conditions import CoefficientOfVariationStoppingCondition
from .vanilla_bq_loop import VanillaBayesianQuadratureLoop # noqa: F401
from .wsabil_loop import WSABILLoop # noqa: F401

__all__ = [
"QuadratureOuterLoop",
"BayesianMonteCarlo",
"VanillaBayesianQuadratureLoop",
"WSABILLoop",
"QuadratureLoopState",
"point_calculators",
"CoefficientOfVariationStoppingCondition",
]
9 changes: 5 additions & 4 deletions emukit/quadrature/loop/bayesian_monte_carlo_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,15 @@
# Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0

from ...core.loop import FixedIntervalUpdater, ModelUpdater, OuterLoop
from ...core.loop.loop_state import create_loop_state
from ...core.loop import FixedIntervalUpdater, ModelUpdater
from ...core.parameter_space import ParameterSpace
from ..loop.point_calculators import BayesianMonteCarloPointCalculator
from ..methods import WarpedBayesianQuadratureModel
from .bq_loop_state import create_bq_loop_state
from .bq_outer_loop import QuadratureOuterLoop


class BayesianMonteCarlo(OuterLoop):
class BayesianMonteCarlo(QuadratureOuterLoop):
"""The loop for Bayesian Monte Carlo (BMC).
Expand Down Expand Up @@ -61,7 +62,7 @@ def __init__(self, model: WarpedBayesianQuadratureModel, model_updater: ModelUpd

space = ParameterSpace(model.reasonable_box_bounds.convert_to_list_of_continuous_parameters())
candidate_point_calculator = BayesianMonteCarloPointCalculator(model, space)
loop_state = create_loop_state(model.X, model.Y)
loop_state = create_bq_loop_state(model.X, model.Y)

super().__init__(candidate_point_calculator, model_updater, loop_state)

Expand Down
8 changes: 4 additions & 4 deletions emukit/quadrature/loop/bq_loop_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@
from ...core.loop.user_function_result import UserFunctionResult


class BQLoopState(LoopState):
"""Contains the state of the BQ loop, which includes a history of all function evaluations and integral mean and
class QuadratureLoopState(LoopState):
"""Contains the state of the BQ loop, which includes a history of all integrand evaluations and integral mean and
variance estimates.
:param initial_results: The results from previous integrand evaluations.
Expand All @@ -35,7 +35,7 @@ def update_integral_stats(self, integral_mean: float, integral_var: float) -> No
self.integral_vars.append(integral_var)


def create_bq_loop_state(x_init: np.ndarray, y_init: np.ndarray, **kwargs) -> BQLoopState:
def create_bq_loop_state(x_init: np.ndarray, y_init: np.ndarray, **kwargs) -> QuadratureLoopState:
"""Creates a BQ loop state object using the provided data.
:param x_init: x values for initial function evaluations. Shape: (n_initial_points x n_input_dims)
Expand All @@ -45,4 +45,4 @@ def create_bq_loop_state(x_init: np.ndarray, y_init: np.ndarray, **kwargs) -> BQ
"""

loop_state = create_loop_state(x_init, y_init, **kwargs)
return BQLoopState(loop_state.results)
return QuadratureLoopState(loop_state.results)
8 changes: 4 additions & 4 deletions emukit/quadrature/loop/bq_outer_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,11 @@
from ...core.loop import OuterLoop
from ...core.loop.candidate_point_calculators import CandidatePointCalculator
from ...core.loop.model_updaters import ModelUpdater
from .bq_loop_state import BQLoopState
from .bq_loop_state import QuadratureLoopState


class BQOuterLoop(OuterLoop):
"""Base class for a Bayesian quadrature outer loop.
class QuadratureOuterLoop(OuterLoop):
"""Base class for a Bayesian quadrature loop.
:param candidate_point_calculator: Finds next point(s) to evaluate.
:param model_updaters: Updates the model with the new data and fits the model hyper-parameters.
Expand All @@ -26,7 +26,7 @@ def __init__(
self,
candidate_point_calculator: CandidatePointCalculator,
model_updaters: Union[ModelUpdater, List[ModelUpdater]],
loop_state: BQLoopState = None,
loop_state: QuadratureLoopState = None,
):
if isinstance(model_updaters, list):
raise ValueError("The BQ loop only supports a single model.")
Expand Down
6 changes: 3 additions & 3 deletions emukit/quadrature/loop/bq_stopping_conditions.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import numpy as np

from ...core.loop.stopping_conditions import StoppingCondition
from .bq_loop_state import BQLoopState
from .bq_loop_state import QuadratureLoopState

_log = logging.getLogger(__name__)

Expand All @@ -20,7 +20,7 @@ class CoefficientOfVariationStoppingCondition(StoppingCondition):
.. math::
COV = \frac{\sigma}{\mu}
with :math:`\mu` and :math:`\sigma^2` the current mean and standard deviation of integral according to the
where :math:`\mu` and :math:`\sigma^2` are the current mean and standard deviation of integral according to the
BQ posterior model.
:param eps: Threshold under which the COV must fall.
Expand All @@ -37,7 +37,7 @@ def __init__(self, eps: float, delay: int = 1) -> None:
self.delay = delay
self.times_true = 0 # counts how many times stopping had been triggered in a row

def should_stop(self, loop_state: BQLoopState) -> bool:
def should_stop(self, loop_state: QuadratureLoopState) -> bool:
if len(loop_state.integral_means) < 1:
return False

Expand Down
4 changes: 2 additions & 2 deletions emukit/quadrature/loop/vanilla_bq_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,10 @@
from ..acquisitions import IntegralVarianceReduction
from ..methods import VanillaBayesianQuadrature
from .bq_loop_state import create_bq_loop_state
from .bq_outer_loop import BQOuterLoop
from .bq_outer_loop import QuadratureOuterLoop


class VanillaBayesianQuadratureLoop(BQOuterLoop):
class VanillaBayesianQuadratureLoop(QuadratureOuterLoop):
"""The loop for standard ('vanilla') Bayesian Quadrature.
.. seealso::
Expand Down
4 changes: 2 additions & 2 deletions emukit/quadrature/loop/wsabil_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,10 @@
from ..acquisitions import UncertaintySampling
from ..methods import WSABIL
from .bq_loop_state import create_bq_loop_state
from .bq_outer_loop import BQOuterLoop
from .bq_outer_loop import QuadratureOuterLoop


class WSABILLoop(BQOuterLoop):
class WSABILLoop(QuadratureOuterLoop):
"""The loop for WSABI-L.
.. rubric:: References
Expand Down

0 comments on commit e24fb06

Please sign in to comment.