Skip to content

Commit

Permalink
Updated test_prior.py
Browse files Browse the repository at this point in the history
  • Loading branch information
xuyuon committed Aug 2, 2024
1 parent ed727af commit 9d191da
Showing 1 changed file with 3 additions and 11 deletions.
14 changes: 3 additions & 11 deletions test/unit/test_prior.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@ def test_uniform(self):
samples = p.sample(jax.random.PRNGKey(0), 10000)
assert jnp.all(jnp.isfinite(samples['x']))
# Check that the log_prob is correct in the support
samples = p.sample(jax.random.PRNGKey(0), 10000)
log_prob = jax.vmap(p.log_prob)(samples)
assert jnp.allclose(log_prob, jnp.log(1.0 / (xmax - xmin)))

Expand All @@ -40,7 +39,6 @@ def test_sine(self):
samples = p.sample(jax.random.PRNGKey(0), 10000)
assert jnp.all(jnp.isfinite(samples['x']))
# Check that the log_prob is finite
samples = p.sample(jax.random.PRNGKey(0), 10000)
log_prob = jax.vmap(p.log_prob)(samples)
assert jnp.all(jnp.isfinite(log_prob))
# Check that the log_prob is correct in the support
Expand All @@ -56,7 +54,6 @@ def test_cosine(self):
samples = p.sample(jax.random.PRNGKey(0), 10000)
assert jnp.all(jnp.isfinite(samples['x']))
# Check that the log_prob is finite
samples = p.sample(jax.random.PRNGKey(0), 10000)
log_prob = jax.vmap(p.log_prob)(samples)
assert jnp.all(jnp.isfinite(log_prob))
# Check that the log_prob is correct in the support
Expand All @@ -73,9 +70,6 @@ def test_uniform_sphere(self):
assert jnp.all(jnp.isfinite(samples['x_theta']))
assert jnp.all(jnp.isfinite(samples['x_phi']))
# Check that the log_prob is finite
samples = {}
for i in range(3):
samples.update(trace_prior_parent(p, [])[i].sample(jax.random.PRNGKey(0), 10000))
log_prob = jax.vmap(p.log_prob)(samples)
assert jnp.all(jnp.isfinite(log_prob))

Expand All @@ -97,14 +91,12 @@ def func(alpha):
assert jnp.all(jnp.isfinite(powerlaw_samples['x']))

# Check that all the log_probs are finite
samples = p.sample(jax.random.PRNGKey(0), 10000)
log_p = jax.vmap(p.log_prob, [0])(samples)
log_p = jax.vmap(p.log_prob, [0])(powerlaw_samples)
assert jnp.all(jnp.isfinite(log_p))

# Check that the log_prob is correct in the support
samples = p.sample(jax.random.PRNGKey(0), 10000)
log_prob = jax.vmap(p.log_prob)(samples)
standard_log_prob = powerlaw_log_pdf(samples['x'], alpha, xmin, xmax)
log_prob = jax.vmap(p.log_prob)(powerlaw_samples)
standard_log_prob = powerlaw_log_pdf(powerlaw_samples['x'], alpha, xmin, xmax)
# log pdf of powerlaw
assert jnp.allclose(log_prob, standard_log_prob, atol=1e-4)

Expand Down

0 comments on commit 9d191da

Please sign in to comment.