From addb62d33cdeee9c10e9baa6fddf9b5c8f97f43b Mon Sep 17 00:00:00 2001 From: Maren Mahsereci Date: Tue, 7 May 2024 17:55:09 +0200 Subject: [PATCH] docstring updates --- emukit/quadrature/loop/bq_outer_loop.py | 2 +- emukit/quadrature/loop/bq_stopping_conditions.py | 16 +++++++++------- 2 files changed, 10 insertions(+), 8 deletions(-) diff --git a/emukit/quadrature/loop/bq_outer_loop.py b/emukit/quadrature/loop/bq_outer_loop.py index 02ab2191..49b30707 100644 --- a/emukit/quadrature/loop/bq_outer_loop.py +++ b/emukit/quadrature/loop/bq_outer_loop.py @@ -34,6 +34,6 @@ def __init__( super().__init__(candidate_point_calculator, model_updaters, loop_state) def _update_loop_state(self) -> None: - model = self.model_updaters[0].model # only works if there is a model, but for BQ nothing else makes sense + model = self.model_updaters[0].model # only works if there is one model, but for BQ nothing else makes sense integral_mean, integral_var = model.integrate() self.loop_state.update_integral_stats(integral_mean, integral_var) diff --git a/emukit/quadrature/loop/bq_stopping_conditions.py b/emukit/quadrature/loop/bq_stopping_conditions.py index 9cb49253..e4798a59 100644 --- a/emukit/quadrature/loop/bq_stopping_conditions.py +++ b/emukit/quadrature/loop/bq_stopping_conditions.py @@ -20,12 +20,15 @@ class CoefficientOfVariationStoppingCondition(StoppingCondition): .. math:: COV = \frac{\sigma}{\mu} - where :math:`\mu` and :math:`\sigma^2` are the current mean and standard deviation of integral according to the - BQ posterior model. + where :math:`\mu` and :math:`\sigma^2` are the current mean and variance respectively of the integral according to + the BQ posterior model. :param eps: Threshold under which the COV must fall. :param delay: Number of times the stopping condition needs to be true in a row in order to stop. Defaults to 1. + :raises ValueError: If `delay` is smaller than 1. + :raises ValueError: If `eps` is non-negative. + """ def __init__(self, eps: float, delay: int = 1) -> None: @@ -33,9 +36,12 @@ def __init__(self, eps: float, delay: int = 1) -> None: if delay < 1: raise ValueError(f"delay ({delay}) must be and integer greater than zero.") + if eps <= 0.0: + raise ValueError(f"eps ({eps}) must be positive.") + self.eps = eps self.delay = delay - self.times_true = 0 # counts how many times stopping had been triggered in a row + self.times_true = 0 # counts how many times stopping has been triggered in a row def should_stop(self, loop_state: QuadratureLoopState) -> bool: if len(loop_state.integral_means) < 1: @@ -50,11 +56,7 @@ def should_stop(self, loop_state: QuadratureLoopState) -> bool: else: self.times_true = 0 - print(np.sqrt(v) / m) - print(should_stop, self.times_true) should_stop = should_stop and (self.times_true >= self.delay) - print(should_stop) - print() if should_stop: _log.info(f"Stopped as coefficient of variation is below threshold of {self.eps}.")