Skip to content

Commit

Permalink
Update prior.py
Browse files Browse the repository at this point in the history
  • Loading branch information
kazewong authored Feb 2, 2024
1 parent 7bd5491 commit 79f0cb4
Showing 1 changed file with 3 additions and 10 deletions.
13 changes: 3 additions & 10 deletions src/jimgw/prior.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,11 +65,8 @@ def transform(self, x: dict[str, Float]) -> dict[str, Float]:
A dictionary of parameters with the transforms applied.
"""
output = {}
#print("transform input:", x)
for value in self.transforms.values():
output[value[0]] = value[1](x)
#print("transform output:", output)

return output

def add_name(self, x: Float[Array, " n_dim"]) -> dict[str, Float]:
Expand Down Expand Up @@ -266,18 +263,14 @@ def sample(
return self.add_name(jnp.stack([theta, phi, mag], axis=1).T)

def log_prob(self, x: dict[str, Float]) -> Float:
mag = x[self.naming[2]]
theta = x[self.naming[0]]
phi = x[self.naming[1]]
mag = x[self.naming[2]]
output = jnp.where(
(mag > 1) | (mag < 0),
(mag > 1) | (mag < 0) | (phi > 2* jnp.pi) | (phi < 0) | (theta > 1) | (theta < -1),
jnp.zeros_like(0) - jnp.inf,
jnp.log(mag**2 * jnp.sin(x[self.naming[0]])),
)
output = jnp.where(
(phi > 2* jnp.pi) | (phi < 0),
jnp.zeros_like(0) - jnp.inf,
output,
)
return output


Expand Down

0 comments on commit 79f0cb4

Please sign in to comment.