diff --git a/tests/test_acq.py b/tests/test_acq.py index f34ee31..98dbe58 100644 --- a/tests/test_acq.py +++ b/tests/test_acq.py @@ -12,6 +12,7 @@ from gpax.utils import get_keys from gpax.acquisition.base_acq import ei, ucb, poi from gpax.acquisition import EI, UCB, UE, Thompson +from gpax.acquisition import qEI, qPOI, qUCB from gpax.acquisition.penalties import compute_penalty @@ -110,6 +111,19 @@ def test_acq_dkl(acq): assert_equal(obj.squeeze().shape, (len(X_new),)) +@pytest.mark.parametrize("q", [1, 3]) +@pytest.mark.parametrize("acq", [qEI, qPOI, qUCB]) +def test_batched_acq(acq, q): + rng_key = get_keys()[0] + X = onp.random.randn(8,) + X_new = onp.random.randn(12,) + y = 10 * X**2 + m = ExactGP(1, 'RBF') + m.fit(rng_key, X, y, num_warmup=100, num_samples=100) + obj = acq(m, X_new, subsample_size=q) + assert_equal(obj.shape, (q, len(X_new))) + + @pytest.mark.parametrize('pen', ['delta', 'inverse_distance']) @pytest.mark.parametrize("acq", [EI, UCB, UE]) def test_acq_penalty_indices(acq, pen):