Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fixing PowerLaw #121

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions src/jimgw/prior.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,8 @@
ScaleTransform,
OffsetTransform,
ArcSineTransform,
# PowerLawTransform,
# ParetoTransform,
PowerLawTransform,
ParetoTransform,
)


Expand Down
125 changes: 74 additions & 51 deletions src/jimgw/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -368,55 +368,78 @@ def __init__(
}


class PowerLawTransform(BijectiveTransform):
"""
PowerLaw transformation
Parameters
----------
name_mapping : tuple[list[str], list[str]]
The name mapping between the input and output dictionary.
"""

xmin: Float
xmax: Float
alpha: Float

def __init__(
self,
name_mapping: tuple[list[str], list[str]],
xmin: Float,
xmax: Float,
alpha: Float,
):
super().__init__(name_mapping)
self.xmin = xmin
self.xmax = xmax
self.alpha = alpha
self.transform_func = lambda x: {
name_mapping[1][i]: (
self.xmin ** (1.0 + self.alpha)
+ x[name_mapping[0][i]]
* (self.xmax ** (1.0 + self.alpha) - self.xmin ** (1.0 + self.alpha))
)
** (1.0 / (1.0 + self.alpha))
for i in range(len(name_mapping[0]))
}
self.inverse_transform_func = lambda x: {
name_mapping[0][i]: (
(
x[name_mapping[1][i]] ** (1.0 + self.alpha)
- self.xmin ** (1.0 + self.alpha)
)
/ (self.xmax ** (1.0 + self.alpha) - self.xmin ** (1.0 + self.alpha))
)
for i in range(len(name_mapping[1]))
}

# class PowerLawTransform(UnivariateTransform):
# """
# PowerLaw transformation
# Parameters
# ----------
# name_mapping : tuple[list[str], list[str]]
# The name mapping between the input and output dictionary.
# """

# xmin: Float
# xmax: Float
# alpha: Float

# def __init__(
# self,
# name_mapping: tuple[list[str], list[str]],
# xmin: Float,
# xmax: Float,
# alpha: Float,
# ):
# super().__init__(name_mapping)
# self.xmin = xmin
# self.xmax = xmax
# self.alpha = alpha
# self.transform_func = lambda x: (
# self.xmin ** (1.0 + self.alpha)
# + x * (self.xmax ** (1.0 + self.alpha) - self.xmin ** (1.0 + self.alpha))
# ) ** (1.0 / (1.0 + self.alpha))


# class ParetoTransform(UnivariateTransform):
# """
# Pareto transformation: Power law when alpha = -1
# Parameters
# ----------
# name_mapping : tuple[list[str], list[str]]
# The name mapping between the input and output dictionary.
# """

# def __init__(
# self,
# name_mapping: tuple[list[str], list[str]],
# xmin: Float,
# xmax: Float,
# ):
# super().__init__(name_mapping)
# self.xmin = xmin
# self.xmax = xmax
# self.transform_func = lambda x: self.xmin * jnp.exp(
# x * jnp.log(self.xmax / self.xmin)
# )

class ParetoTransform(BijectiveTransform):
"""
Pareto transformation: Power law when alpha = -1
Parameters
----------
name_mapping : tuple[list[str], list[str]]
The name mapping between the input and output dictionary.
"""

def __init__(
self,
name_mapping: tuple[list[str], list[str]],
xmin: Float,
xmax: Float,
):
super().__init__(name_mapping)
self.xmin = xmin
self.xmax = xmax
self.transform_func = lambda x: {
name_mapping[1][i]: self.xmin
* jnp.exp(x[name_mapping[0][i]] * jnp.log(self.xmax / self.xmin))
for i in range(len(name_mapping[0]))
}
self.inverse_transform_func = lambda x: {
name_mapping[0][i]: (
jnp.log(x[name_mapping[1][i]] / self.xmin)
/ jnp.log(self.xmax / self.xmin)
)
for i in range(len(name_mapping[1]))
}
29 changes: 9 additions & 20 deletions test/unit/test_prior.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@ def test_uniform(self):
samples = p.sample(jax.random.PRNGKey(0), 10000)
assert jnp.all(jnp.isfinite(samples['x']))
# Check that the log_prob is correct in the support
samples = trace_prior_parent(p, [])[0].sample(jax.random.PRNGKey(0), 10000)
log_prob = jax.vmap(p.log_prob)(samples)
assert jnp.allclose(log_prob, jnp.log(1.0 / (xmax - xmin)))

Expand All @@ -40,30 +39,28 @@ def test_sine(self):
samples = p.sample(jax.random.PRNGKey(0), 10000)
assert jnp.all(jnp.isfinite(samples['x']))
# Check that the log_prob is finite
samples = trace_prior_parent(p, [])[0].sample(jax.random.PRNGKey(0), 10000)
log_prob = jax.vmap(p.log_prob)(samples)
assert jnp.all(jnp.isfinite(log_prob))
# Check that the log_prob is correct in the support
x = trace_prior_parent(p, [])[0].add_name(jnp.linspace(-10.0, 10.0, 1000)[None])
y = jax.vmap(p.base_prior.base_prior.transform)(x)
y = jax.vmap(p.base_prior.transform)(y)
y = jax.vmap(p.transform)(y)
assert jnp.allclose(jax.vmap(p.log_prob)(x), jnp.log(jnp.sin(y['x'])/2.0))
assert jnp.allclose(jax.vmap(p.log_prob)(y), jnp.log(jnp.sin(y['x'])/2.0))

def test_cosine(self):
p = CosinePrior(["x"])
# Check that all the samples are finite
samples = p.sample(jax.random.PRNGKey(0), 10000)
assert jnp.all(jnp.isfinite(samples['x']))
# Check that the log_prob is finite
samples = trace_prior_parent(p, [])[0].sample(jax.random.PRNGKey(0), 10000)
log_prob = jax.vmap(p.log_prob)(samples)
assert jnp.all(jnp.isfinite(log_prob))
# Check that the log_prob is correct in the support
x = trace_prior_parent(p, [])[0].add_name(jnp.linspace(-10.0, 10.0, 1000)[None])
y = jax.vmap(p.base_prior.transform)(x)
y = jax.vmap(p.transform)(y)
assert jnp.allclose(jax.vmap(p.log_prob)(x), jnp.log(jnp.cos(y['x'])/2.0))
assert jnp.allclose(jax.vmap(p.log_prob)(y), jnp.log(jnp.cos(y['x'])/2.0))

def test_uniform_sphere(self):
p = UniformSpherePrior(["x"])
Expand All @@ -73,12 +70,10 @@ def test_uniform_sphere(self):
assert jnp.all(jnp.isfinite(samples['x_theta']))
assert jnp.all(jnp.isfinite(samples['x_phi']))
# Check that the log_prob is finite
samples = {}
for i in range(3):
samples.update(trace_prior_parent(p, [])[i].sample(jax.random.PRNGKey(0), 10000))
log_prob = jax.vmap(p.log_prob)(samples)
assert jnp.all(jnp.isfinite(log_prob))


def test_power_law(self):
def powerlaw_log_pdf(x, alpha, xmin, xmax):
if alpha == -1.0:
Expand All @@ -96,20 +91,14 @@ def func(alpha):
assert jnp.all(jnp.isfinite(powerlaw_samples['x']))

# Check that all the log_probs are finite
samples = (trace_prior_parent(p, [])[0].sample(jax.random.PRNGKey(0), 10000))['x_base']
base_log_p = jax.vmap(p.log_prob, [0])({'x_base':samples})
assert jnp.all(jnp.isfinite(base_log_p))
log_p = jax.vmap(p.log_prob, [0])(powerlaw_samples)
assert jnp.all(jnp.isfinite(log_p))

# Check that the log_prob is correct in the support
samples = jnp.linspace(-10.0, 10.0, 1000)
transformed_samples = jax.vmap(p.transform)({'x_base': samples})['x']
# cut off the samples that are outside the support
samples = samples[transformed_samples >= xmin]
transformed_samples = transformed_samples[transformed_samples >= xmin]
samples = samples[transformed_samples <= xmax]
transformed_samples = transformed_samples[transformed_samples <= xmax]
log_prob = jax.vmap(p.log_prob)(powerlaw_samples)
standard_log_prob = powerlaw_log_pdf(powerlaw_samples['x'], alpha, xmin, xmax)
# log pdf of powerlaw
assert jnp.allclose(jax.vmap(p.log_prob)({'x_base':samples}), powerlaw_log_pdf(transformed_samples, alpha, xmin, xmax), atol=1e-4)
assert jnp.allclose(log_prob, standard_log_prob, atol=1e-4)

# Test Pareto Transform
func(-1.0)
Expand All @@ -120,4 +109,4 @@ def func(alpha):
func(alpha_val)
negative_alpha = [-0.5, -1.5, -2.0, -2.5, -3.0, -3.5, -4.0, -4.5, -5.0]
for alpha_val in negative_alpha:
func(alpha_val)
func(alpha_val)
Loading