diff --git a/test/unit/test_prior.py b/test/unit/test_prior.py index c68b4dc4..c42d76be 100644 --- a/test/unit/test_prior.py +++ b/test/unit/test_prior.py @@ -30,7 +30,6 @@ def test_uniform(self): samples = p.sample(jax.random.PRNGKey(0), 10000) assert jnp.all(jnp.isfinite(samples['x'])) # Check that the log_prob is correct in the support - samples = p.sample(jax.random.PRNGKey(0), 10000) log_prob = jax.vmap(p.log_prob)(samples) assert jnp.allclose(log_prob, jnp.log(1.0 / (xmax - xmin))) @@ -40,7 +39,6 @@ def test_sine(self): samples = p.sample(jax.random.PRNGKey(0), 10000) assert jnp.all(jnp.isfinite(samples['x'])) # Check that the log_prob is finite - samples = p.sample(jax.random.PRNGKey(0), 10000) log_prob = jax.vmap(p.log_prob)(samples) assert jnp.all(jnp.isfinite(log_prob)) # Check that the log_prob is correct in the support @@ -56,7 +54,6 @@ def test_cosine(self): samples = p.sample(jax.random.PRNGKey(0), 10000) assert jnp.all(jnp.isfinite(samples['x'])) # Check that the log_prob is finite - samples = p.sample(jax.random.PRNGKey(0), 10000) log_prob = jax.vmap(p.log_prob)(samples) assert jnp.all(jnp.isfinite(log_prob)) # Check that the log_prob is correct in the support @@ -73,9 +70,6 @@ def test_uniform_sphere(self): assert jnp.all(jnp.isfinite(samples['x_theta'])) assert jnp.all(jnp.isfinite(samples['x_phi'])) # Check that the log_prob is finite - samples = {} - for i in range(3): - samples.update(trace_prior_parent(p, [])[i].sample(jax.random.PRNGKey(0), 10000)) log_prob = jax.vmap(p.log_prob)(samples) assert jnp.all(jnp.isfinite(log_prob)) @@ -97,14 +91,12 @@ def func(alpha): assert jnp.all(jnp.isfinite(powerlaw_samples['x'])) # Check that all the log_probs are finite - samples = p.sample(jax.random.PRNGKey(0), 10000) - log_p = jax.vmap(p.log_prob, [0])(samples) + log_p = jax.vmap(p.log_prob, [0])(powerlaw_samples) assert jnp.all(jnp.isfinite(log_p)) # Check that the log_prob is correct in the support - samples = p.sample(jax.random.PRNGKey(0), 10000) - log_prob = jax.vmap(p.log_prob)(samples) - standard_log_prob = powerlaw_log_pdf(samples['x'], alpha, xmin, xmax) + log_prob = jax.vmap(p.log_prob)(powerlaw_samples) + standard_log_prob = powerlaw_log_pdf(powerlaw_samples['x'], alpha, xmin, xmax) # log pdf of powerlaw assert jnp.allclose(log_prob, standard_log_prob, atol=1e-4)