diff --git a/tests/test_acq.py b/tests/test_acq.py index f6a297e..3757c9e 100644 --- a/tests/test_acq.py +++ b/tests/test_acq.py @@ -131,18 +131,6 @@ def test_UCB_beta(): assert_(onp.array_equal(obj1, obj3)) -def test_KG_gp(): - rng_keys = get_keys() - X = onp.random.randn(8,) - X_new = onp.random.randn(12,) - y = 10 * X**2 - m = ExactGP(1, 'RBF') - m.fit(rng_keys[0], X, y, num_warmup=100, num_samples=100) - obj = KG(m, X_new) - assert_(isinstance(obj, jnp.ndarray)) - assert_equal(obj.squeeze().shape, (len(X_new),)) - - def test_EI_gp_penalty_inv_distance(): rng_keys = get_keys() X = onp.random.randn(8,) diff --git a/tests/test_utils.py b/tests/test_utils.py index ff58537..65559f5 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -90,4 +90,4 @@ def test_random_sample_difference(): sampled_data2 = random_sample_dict(data, num_samples, rng_key2) for key in sampled_data1: - assert_(jnp.array_equal(sampled_data1[key], sampled_data2[key])) + assert_(not jnp.array_equal(sampled_data1[key], sampled_data2[key]))