From 15b19b65808596e6de362c679d0ae562aeb5e40b Mon Sep 17 00:00:00 2001 From: Maxim Ziatdinov <34245227+ziatdinovmax@users.noreply.github.com> Date: Sun, 27 Aug 2023 12:25:21 -0400 Subject: [PATCH] Fix KG for minimization problems --- gpax/acquisition/base_acq.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/gpax/acquisition/base_acq.py b/gpax/acquisition/base_acq.py index 1d60673..23e1529 100644 --- a/gpax/acquisition/base_acq.py +++ b/gpax/acquisition/base_acq.py @@ -165,7 +165,7 @@ def poi(model: Type[ExactGP], def kg(model: Type[ExactGP], X_new: jnp.ndarray, sample: Dict[str, jnp.ndarray], - n: int = 1, + n: int = 10, maximize: bool = True, noiseless: bool = True, rng_key: Optional[jnp.ndarray] = None, @@ -178,7 +178,7 @@ def kg(model: Type[ExactGP], model: trained model X: new inputs with shape (N, D), where D is a feature dimension sample: a single sample with model parameters - n: Number fo simulated samples (Defaults to 1) + n: Number fo simulated samples (Defaults to 10) maximize: If True, assumes that BO is solving maximization problem noiseless: Noise-free prediction. It is set to False by default as new/unseen data is assumed @@ -207,7 +207,10 @@ def kg_for_one_point(x_aug, y_aug, mean_o): y_fant = mean_aug.max() if maximize else mean_aug.min() # Compute adn return the improvement compared to the original maximum mean value mean_o_best = mean_o.max() if maximize else mean_o.min() - return y_fant - mean_o_best + u = y_fant - mean_o_best + if not maximize: + u = -u + return u # Get posterior distribution for candidate points mean, cov = model.get_mvn_posterior(X_new, *sample, noiseless=noiseless, **kwargs)