Skip to content

Commit

Permalink
Update tests
Browse files Browse the repository at this point in the history
  • Loading branch information
ziatdinovmax committed Aug 15, 2023
1 parent 742ea51 commit 06fa254
Showing 1 changed file with 18 additions and 15 deletions.
33 changes: 18 additions & 15 deletions tests/test_vidkl.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@ def get_dummy_image_data(jax_ndarray=True):

def get_dummy_vector_data(jax_ndarray=True):
X, y = get_dummy_data(jax_ndarray)
X = X[None].repeat(3, axis=0)
y = y[None].repeat(3, axis=0)
return X, y

Expand Down Expand Up @@ -87,7 +86,9 @@ def test_get_mvn_posterior():
"k_scale": jnp.array(1.0),
"noise": jnp.array(0.1)}
m = viDKL(X.shape[-1])
mean, cov = m.get_mvn_posterior(X, y, X_test, nn_params, kernel_params)
m.X_train = X
m.y_train = y
mean, cov = m.get_mvn_posterior(X_test, nn_params, kernel_params)
assert isinstance(mean, jnp.ndarray)
assert isinstance(cov, jnp.ndarray)
assert_equal(mean.shape, (X_test.shape[0],))
Expand All @@ -104,9 +105,11 @@ def test_get_mvn_posterior_noiseless():
"k_scale": jnp.array(1.0),
"noise": jnp.array(0.1)}
m = viDKL(X.shape[-1])
mean1, cov1 = m.get_mvn_posterior(X, y, X_test, nn_params, kernel_params, noiseless=False)
mean1_, cov1_ = m.get_mvn_posterior(X, y, X_test, nn_params, kernel_params, noiseless=False)
mean2, cov2 = m.get_mvn_posterior(X, y, X_test, nn_params, kernel_params, noiseless=True)
m.X_train = X
m.y_train = y
mean1, cov1 = m.get_mvn_posterior(X_test, nn_params, kernel_params, noiseless=False)
mean1_, cov1_ = m.get_mvn_posterior(X_test, nn_params, kernel_params, noiseless=False)
mean2, cov2 = m.get_mvn_posterior(X_test, nn_params, kernel_params, noiseless=True)
assert_array_equal(mean1, mean1_)
assert_array_equal(cov1, cov1_)
assert_array_equal(mean1, mean2)
Expand Down Expand Up @@ -165,7 +168,7 @@ def test_predict_vector():
X_test, _ = get_dummy_vector_data()
net = hk.transform(lambda x: MLP()(x))
clone = lambda x: net.init(rng_key, x)
nn_params = jax.vmap(clone)(X)
nn_params = jax.vmap(clone)(X[None].repeat(len(y), 0))
kernel_params = {"k_length": jnp.array([[1.0], [1.0], [1.0]]),
"k_scale": jnp.array([1.0, 1.0, 1.0]),
"noise": jnp.array([0.1, 0.1, 0.1])}
Expand All @@ -177,8 +180,8 @@ def test_predict_vector():
mean, var = m.predict(rng_key, X_test)
assert isinstance(mean, jnp.ndarray)
assert isinstance(var, jnp.ndarray)
assert_equal(mean.shape, X_test.shape[:-1])
assert_equal(var.shape, X_test.shape[:-1])
assert_equal(mean.shape, y.shape)
assert_equal(var.shape, y.shape)


def test_predict_in_batches_scalar():
Expand Down Expand Up @@ -208,7 +211,7 @@ def test_predict_in_batches_vector():
X_test, _ = get_dummy_vector_data()
net = hk.transform(lambda x: MLP()(x))
clone = lambda x: net.init(rng_key, x)
nn_params = jax.vmap(clone)(X)
nn_params = jax.vmap(clone)(X[None].repeat(len(y), 0))
kernel_params = {"k_length": jnp.array([[1.0], [1.0], [1.0]]),
"k_scale": jnp.array([1.0, 1.0, 1.0]),
"noise": jnp.array([0.1, 0.1, 0.1])}
Expand All @@ -220,8 +223,8 @@ def test_predict_in_batches_vector():
mean, var = m.predict_in_batches(rng_key, X_test, batch_size=10)
assert isinstance(mean, jnp.ndarray)
assert isinstance(var, jnp.ndarray)
assert_equal(mean.shape, X_test.shape[:-1])
assert_equal(var.shape, X_test.shape[:-1])
assert_equal(mean.shape, y.shape)
assert_equal(var.shape, y.shape)


def test_fit_predict_scalar():
Expand All @@ -246,8 +249,8 @@ def test_fit_predict_vector():
rng_key, X, y, X_test, num_steps=100, step_size=0.05, batch_size=10)
assert isinstance(mean, jnp.ndarray)
assert isinstance(var, jnp.ndarray)
assert_equal(mean.shape, X_test.shape[:-1])
assert_equal(var.shape, X_test.shape[:-1])
assert_equal(mean.shape, y.shape)
assert_equal(var.shape, y.shape)


def test_fit_predict_scalar_ensemble():
Expand All @@ -274,8 +277,8 @@ def test_fit_predict_vector_ensemble():
num_steps=100, step_size=0.05, batch_size=10)
assert isinstance(mean, jnp.ndarray)
assert isinstance(var, jnp.ndarray)
assert_equal(mean.shape, (2, *X_test.shape[:-1]))
assert_equal(var.shape, (2, *X_test.shape[:-1]))
assert_equal(mean.shape, (2, *y.shape))
assert_equal(var.shape, (2, *y.shape))


def test_fit_predict_scalar_ensemble_custom_net():
Expand Down

0 comments on commit 06fa254

Please sign in to comment.