From 7bd5491f977aa3b407aad91c72b1a06efcb4f49c Mon Sep 17 00:00:00 2001 From: Zipeng Wang Date: Fri, 2 Feb 2024 10:38:29 -0500 Subject: [PATCH] Update prior.py --- src/jimgw/prior.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/src/jimgw/prior.py b/src/jimgw/prior.py index 384eb8e5..c42104db 100644 --- a/src/jimgw/prior.py +++ b/src/jimgw/prior.py @@ -65,8 +65,11 @@ def transform(self, x: dict[str, Float]) -> dict[str, Float]: A dictionary of parameters with the transforms applied. """ output = {} + #print("transform input:", x) for value in self.transforms.values(): output[value[0]] = value[1](x) + #print("transform output:", output) + return output def add_name(self, x: Float[Array, " n_dim"]) -> dict[str, Float]: @@ -264,11 +267,17 @@ def sample( def log_prob(self, x: dict[str, Float]) -> Float: mag = x[self.naming[2]] + phi = x[self.naming[1]] output = jnp.where( (mag > 1) | (mag < 0), jnp.zeros_like(0) - jnp.inf, jnp.log(mag**2 * jnp.sin(x[self.naming[0]])), ) + output = jnp.where( + (phi > 2* jnp.pi) | (phi < 0), + jnp.zeros_like(0) - jnp.inf, + output, + ) return output