diff --git a/tests/test_acq.py b/tests/test_acq.py index 3757c9e..4763b1b 100644 --- a/tests/test_acq.py +++ b/tests/test_acq.py @@ -11,6 +11,7 @@ from gpax.models.gp import ExactGP from gpax.models.vidkl import viDKL +from gpax.models import DKL from gpax.utils import get_keys from gpax.acquisition.base_acq import ei, ucb, poi, ue, kg from gpax.acquisition.acquisition import _compute_mean_and_var @@ -105,7 +106,7 @@ def test_acq_gp(acq): @pytest.mark.parametrize("acq", [EI, UCB, UE, Thompson, POI, KG]) -def test_acq_dkl(acq): +def test_acq_vidkl(acq): rng_keys = get_keys() X = onp.random.randn(8, 10) X_new = onp.random.randn(12, 10) @@ -117,6 +118,18 @@ def test_acq_dkl(acq): assert_equal(obj.shape, (len(X_new),)) +@pytest.mark.parametrize("acq", [EI, POI, UCB]) +def test_acq_dkl(acq): + rng_keys = get_keys() + X = onp.random.randn(12, 8) + y = onp.random.randn(12,) + X_new = onp.random.randn(10, 8)[None] + m = DKL(X.shape[-1], 2, 'RBF') + m.fit(rng_keys[0], X, y, num_samples=5, num_warmup=5) + obj = acq(rng_keys[1], m, X_new, subsample_size=4) + assert_equal(obj.shape, (X_new.shape[1],)) + + def test_UCB_beta(): rng_keys = get_keys() X = onp.random.randn(8,) @@ -173,19 +186,6 @@ def test_UE_gp_penalty_inv_distance(): assert_(obj2[-2] < obj1[-2]) -@pytest.mark.parametrize("acq", [EI, UCB, UE, Thompson]) -def test_acq_dkl(acq): - rng_keys = get_keys() - X = onp.random.randn(32, 36) - y = onp.random.randn(32,) - X_new = onp.random.randn(10, 36) - m = viDKL(X.shape[-1]) - m.fit(rng_keys[0], X, y, num_steps=20, step_size=0.05) - obj = acq(rng_keys[1], m, X_new) - assert_(isinstance(obj, jnp.ndarray)) - assert_equal(obj.squeeze().shape, (len(X_new),)) - - @pytest.mark.parametrize("maximize_distance", [False, True]) def test_compute_batch_acquisition(maximize_distance): def mock_acq_fn(*args): @@ -201,7 +201,7 @@ def mock_acq_fn(*args): @pytest.mark.parametrize("q", [1, 3]) @pytest.mark.parametrize("acq", [qEI, qPOI, qUCB, qKG]) -def test_batched_acq(acq, q): +def test_batched_acq_gp(acq, q): rng_key = get_keys() X = onp.random.randn(8,) X_new = onp.random.randn(12,)