diff --git a/tests/test_acq.py b/tests/test_acq.py index d205710..4d0c680 100644 --- a/tests/test_acq.py +++ b/tests/test_acq.py @@ -3,8 +3,6 @@ import numpy as onp import jax import jax.numpy as jnp -import numpyro -import numpyro.distributions as dist from numpy.testing import assert_equal, assert_ sys.path.insert(0, "../gpax/") diff --git a/tests/test_optimize_acq.py b/tests/test_optimize_acq.py new file mode 100644 index 0000000..078fbaa --- /dev/null +++ b/tests/test_optimize_acq.py @@ -0,0 +1,36 @@ +import sys +import pytest +import numpy as onp +import jax.numpy as jnp +from numpy.testing import assert_ + +sys.path.insert(0, "../gpax/") + +from gpax.models.gp import ExactGP +from gpax.acquisition.optimize import optimize_acq +from gpax.acquisition.acquisition import UCB, EI +from gpax.utils import get_keys + + +def get_inputs(): + X = onp.random.uniform(-2, 2, size=(4,)) + y = X**3 + return X, y + + +@pytest.mark.parametrize("acq_fn", [UCB, EI]) +def test_optimize_acq(acq_fn): + lower_bound = -2.0 + upper_bound = 2.0 + num_initial_guesses = 3 + key1, key2 = get_keys() + X, y = get_inputs() + model = ExactGP(1, 'RBF') + model.fit(key1, X, y, num_warmup=50, num_samples=50) + x_next = optimize_acq( + key2, model, acq_fn, num_initial_guesses, lower_bound, upper_bound) + assert_(isinstance(x_next, jnp.ndarray)) + + + + \ No newline at end of file