Skip to content

Commit

Permalink
Update tests
Browse files Browse the repository at this point in the history
  • Loading branch information
ziatdinovmax committed Sep 4, 2023
1 parent e0af518 commit 609e12a
Showing 1 changed file with 15 additions and 15 deletions.
30 changes: 15 additions & 15 deletions tests/test_acq.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

from gpax.models.gp import ExactGP
from gpax.models.vidkl import viDKL
from gpax.models import DKL
from gpax.utils import get_keys
from gpax.acquisition.base_acq import ei, ucb, poi, ue, kg
from gpax.acquisition.acquisition import _compute_mean_and_var
Expand Down Expand Up @@ -105,7 +106,7 @@ def test_acq_gp(acq):


@pytest.mark.parametrize("acq", [EI, UCB, UE, Thompson, POI, KG])
def test_acq_dkl(acq):
def test_acq_vidkl(acq):
rng_keys = get_keys()
X = onp.random.randn(8, 10)
X_new = onp.random.randn(12, 10)
Expand All @@ -117,6 +118,18 @@ def test_acq_dkl(acq):
assert_equal(obj.shape, (len(X_new),))


@pytest.mark.parametrize("acq", [EI, POI, UCB])
def test_acq_dkl(acq):
rng_keys = get_keys()
X = onp.random.randn(12, 8)
y = onp.random.randn(12,)
X_new = onp.random.randn(10, 8)[None]
m = DKL(X.shape[-1], 2, 'RBF')
m.fit(rng_keys[0], X, y, num_samples=5, num_warmup=5)
obj = acq(rng_keys[1], m, X_new, subsample_size=4)
assert_equal(obj.shape, (X_new.shape[1],))


def test_UCB_beta():
rng_keys = get_keys()
X = onp.random.randn(8,)
Expand Down Expand Up @@ -173,19 +186,6 @@ def test_UE_gp_penalty_inv_distance():
assert_(obj2[-2] < obj1[-2])


@pytest.mark.parametrize("acq", [EI, UCB, UE, Thompson])
def test_acq_dkl(acq):
rng_keys = get_keys()
X = onp.random.randn(32, 36)
y = onp.random.randn(32,)
X_new = onp.random.randn(10, 36)
m = viDKL(X.shape[-1])
m.fit(rng_keys[0], X, y, num_steps=20, step_size=0.05)
obj = acq(rng_keys[1], m, X_new)
assert_(isinstance(obj, jnp.ndarray))
assert_equal(obj.squeeze().shape, (len(X_new),))


@pytest.mark.parametrize("maximize_distance", [False, True])
def test_compute_batch_acquisition(maximize_distance):
def mock_acq_fn(*args):
Expand All @@ -201,7 +201,7 @@ def mock_acq_fn(*args):

@pytest.mark.parametrize("q", [1, 3])
@pytest.mark.parametrize("acq", [qEI, qPOI, qUCB, qKG])
def test_batched_acq(acq, q):
def test_batched_acq_gp(acq, q):
rng_key = get_keys()
X = onp.random.randn(8,)
X_new = onp.random.randn(12,)
Expand Down

0 comments on commit 609e12a

Please sign in to comment.