Skip to content

Commit

Permalink
docstring updates
Browse files Browse the repository at this point in the history
  • Loading branch information
mmahsereci committed May 7, 2024
1 parent be501a2 commit addb62d
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 8 deletions.
2 changes: 1 addition & 1 deletion emukit/quadrature/loop/bq_outer_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,6 @@ def __init__(
super().__init__(candidate_point_calculator, model_updaters, loop_state)

def _update_loop_state(self) -> None:
model = self.model_updaters[0].model # only works if there is a model, but for BQ nothing else makes sense
model = self.model_updaters[0].model # only works if there is one model, but for BQ nothing else makes sense
integral_mean, integral_var = model.integrate()
self.loop_state.update_integral_stats(integral_mean, integral_var)
16 changes: 9 additions & 7 deletions emukit/quadrature/loop/bq_stopping_conditions.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,22 +20,28 @@ class CoefficientOfVariationStoppingCondition(StoppingCondition):
.. math::
COV = \frac{\sigma}{\mu}
where :math:`\mu` and :math:`\sigma^2` are the current mean and standard deviation of integral according to the
BQ posterior model.
where :math:`\mu` and :math:`\sigma^2` are the current mean and variance respectively of the integral according to
the BQ posterior model.
:param eps: Threshold under which the COV must fall.
:param delay: Number of times the stopping condition needs to be true in a row in order to stop. Defaults to 1.
:raises ValueError: If `delay` is smaller than 1.
:raises ValueError: If `eps` is non-negative.
"""

def __init__(self, eps: float, delay: int = 1) -> None:

if delay < 1:
raise ValueError(f"delay ({delay}) must be and integer greater than zero.")

if eps <= 0.0:
raise ValueError(f"eps ({eps}) must be positive.")

self.eps = eps
self.delay = delay
self.times_true = 0 # counts how many times stopping had been triggered in a row
self.times_true = 0 # counts how many times stopping has been triggered in a row

def should_stop(self, loop_state: QuadratureLoopState) -> bool:
if len(loop_state.integral_means) < 1:
Expand All @@ -50,11 +56,7 @@ def should_stop(self, loop_state: QuadratureLoopState) -> bool:
else:
self.times_true = 0

print(np.sqrt(v) / m)
print(should_stop, self.times_true)
should_stop = should_stop and (self.times_true >= self.delay)
print(should_stop)
print()

if should_stop:
_log.info(f"Stopped as coefficient of variation is below threshold of {self.eps}.")
Expand Down

0 comments on commit addb62d

Please sign in to comment.