From fe9a38ad9755d80239a3bc3976c03b31e7661edc Mon Sep 17 00:00:00 2001 From: Maxim Ziatdinov <34245227+ziatdinovmax@users.noreply.github.com> Date: Thu, 31 Aug 2023 18:21:15 -0400 Subject: [PATCH] Update tests --- tests/test_acq.py | 12 ------------ tests/test_utils.py | 2 +- 2 files changed, 1 insertion(+), 13 deletions(-) diff --git a/tests/test_acq.py b/tests/test_acq.py index f6a297e..3757c9e 100644 --- a/tests/test_acq.py +++ b/tests/test_acq.py @@ -131,18 +131,6 @@ def test_UCB_beta(): assert_(onp.array_equal(obj1, obj3)) -def test_KG_gp(): - rng_keys = get_keys() - X = onp.random.randn(8,) - X_new = onp.random.randn(12,) - y = 10 * X**2 - m = ExactGP(1, 'RBF') - m.fit(rng_keys[0], X, y, num_warmup=100, num_samples=100) - obj = KG(m, X_new) - assert_(isinstance(obj, jnp.ndarray)) - assert_equal(obj.squeeze().shape, (len(X_new),)) - - def test_EI_gp_penalty_inv_distance(): rng_keys = get_keys() X = onp.random.randn(8,) diff --git a/tests/test_utils.py b/tests/test_utils.py index ff58537..65559f5 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -90,4 +90,4 @@ def test_random_sample_difference(): sampled_data2 = random_sample_dict(data, num_samples, rng_key2) for key in sampled_data1: - assert_(jnp.array_equal(sampled_data1[key], sampled_data2[key])) + assert_(not jnp.array_equal(sampled_data1[key], sampled_data2[key]))