Skip to content

Commit

Permalink
Update tests
Browse files Browse the repository at this point in the history
  • Loading branch information
ziatdinovmax committed Aug 20, 2023
1 parent cf52a6d commit 8fe9e24
Showing 1 changed file with 14 additions and 0 deletions.
14 changes: 14 additions & 0 deletions tests/test_acq.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from gpax.utils import get_keys
from gpax.acquisition.base_acq import ei, ucb, poi
from gpax.acquisition import EI, UCB, UE, Thompson
from gpax.acquisition import qEI, qPOI, qUCB
from gpax.acquisition.penalties import compute_penalty


Expand Down Expand Up @@ -110,6 +111,19 @@ def test_acq_dkl(acq):
assert_equal(obj.squeeze().shape, (len(X_new),))


@pytest.mark.parametrize("q", [1, 3])
@pytest.mark.parametrize("acq", [qEI, qPOI, qUCB])
def test_batched_acq(acq, q):
rng_key = get_keys()[0]
X = onp.random.randn(8,)
X_new = onp.random.randn(12,)
y = 10 * X**2
m = ExactGP(1, 'RBF')
m.fit(rng_key, X, y, num_warmup=100, num_samples=100)
obj = acq(m, X_new, subsample_size=q)
assert_equal(obj.shape, (q, len(X_new)))


@pytest.mark.parametrize('pen', ['delta', 'inverse_distance'])
@pytest.mark.parametrize("acq", [EI, UCB, UE])
def test_acq_penalty_indices(acq, pen):
Expand Down

0 comments on commit 8fe9e24

Please sign in to comment.