From 8fe9e2405d565d37a9c47364fc8b5306629c9f1e Mon Sep 17 00:00:00 2001 From: Maxim Ziatdinov <34245227+ziatdinovmax@users.noreply.github.com> Date: Sun, 20 Aug 2023 19:11:32 -0400 Subject: [PATCH] Update tests --- tests/test_acq.py | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/tests/test_acq.py b/tests/test_acq.py index f34ee31..98dbe58 100644 --- a/tests/test_acq.py +++ b/tests/test_acq.py @@ -12,6 +12,7 @@ from gpax.utils import get_keys from gpax.acquisition.base_acq import ei, ucb, poi from gpax.acquisition import EI, UCB, UE, Thompson +from gpax.acquisition import qEI, qPOI, qUCB from gpax.acquisition.penalties import compute_penalty @@ -110,6 +111,19 @@ def test_acq_dkl(acq): assert_equal(obj.squeeze().shape, (len(X_new),)) +@pytest.mark.parametrize("q", [1, 3]) +@pytest.mark.parametrize("acq", [qEI, qPOI, qUCB]) +def test_batched_acq(acq, q): + rng_key = get_keys()[0] + X = onp.random.randn(8,) + X_new = onp.random.randn(12,) + y = 10 * X**2 + m = ExactGP(1, 'RBF') + m.fit(rng_key, X, y, num_warmup=100, num_samples=100) + obj = acq(m, X_new, subsample_size=q) + assert_equal(obj.shape, (q, len(X_new))) + + @pytest.mark.parametrize('pen', ['delta', 'inverse_distance']) @pytest.mark.parametrize("acq", [EI, UCB, UE]) def test_acq_penalty_indices(acq, pen):