From 0edc0fe176113d1fa1c0b59beb17a7964c3c3ef3 Mon Sep 17 00:00:00 2001 From: Peter Pang Date: Sun, 4 Feb 2024 13:08:53 +0100 Subject: [PATCH 1/2] Fixing missing bound for AlignedSpin prior --- src/jimgw/prior.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/jimgw/prior.py b/src/jimgw/prior.py index 941cf933..2ab0bc1a 100644 --- a/src/jimgw/prior.py +++ b/src/jimgw/prior.py @@ -292,6 +292,8 @@ class AlignedSpin(Prior): """ amax: Float = 0.99 + xmax: Float = 0.99 + xmin: Float = -0.99 chi_axis: Array = field(default_factory=lambda: jnp.linspace(0, 1, num=1000)) cdf_vals: Array = field(default_factory=lambda: jnp.linspace(0, 1, num=1000)) @@ -308,6 +310,8 @@ def __init__( super().__init__(naming, transforms) assert self.n_dim == 1, "Alignedspin needs to be 1D distributions" self.amax = amax + self.xmax = amax + self.xmin = -amax # build the interpolation table for the ppf of the one-sided distribution chi_axis = jnp.linspace(1e-31, self.amax, num=1000) From 397b2e31444bcf6d67066841b608ea9e161ccc69 Mon Sep 17 00:00:00 2001 From: Peter Pang Date: Sun, 4 Feb 2024 15:35:17 +0100 Subject: [PATCH 2/2] Setting xmin and xmax as property for AlignedSpin --- src/jimgw/prior.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/src/jimgw/prior.py b/src/jimgw/prior.py index 2ab0bc1a..433b385a 100644 --- a/src/jimgw/prior.py +++ b/src/jimgw/prior.py @@ -292,8 +292,6 @@ class AlignedSpin(Prior): """ amax: Float = 0.99 - xmax: Float = 0.99 - xmin: Float = -0.99 chi_axis: Array = field(default_factory=lambda: jnp.linspace(0, 1, num=1000)) cdf_vals: Array = field(default_factory=lambda: jnp.linspace(0, 1, num=1000)) @@ -310,8 +308,6 @@ def __init__( super().__init__(naming, transforms) assert self.n_dim == 1, "Alignedspin needs to be 1D distributions" self.amax = amax - self.xmax = amax - self.xmin = -amax # build the interpolation table for the ppf of the one-sided distribution chi_axis = jnp.linspace(1e-31, self.amax, num=1000) @@ -319,6 +315,14 @@ def __init__( self.chi_axis = chi_axis self.cdf_vals = cdf_vals + @property + def xmin(self): + return -self.amax + + @property + def xmax(self): + return self.amax + def sample( self, rng_key: PRNGKeyArray, n_samples: int ) -> dict[str, Float[Array, " n_samples"]]: