Merge pull request #121 from xuyuon/98-moving-naming-tracking-into-jim-class-from-prior-class

Fixing PowerLaw
kazewong authored Aug 2, 2024
2 parents 47af9cf + 9d191da commit 957335e
Showing 3 changed files with 85 additions and 73 deletions.
4 changes: 2 additions & 2 deletions src/jimgw/prior.py
@@ -12,8 +12,8 @@
    ScaleTransform,
    OffsetTransform,
    ArcSineTransform,
    # PowerLawTransform,
    # ParetoTransform,
    PowerLawTransform,
    ParetoTransform,
)


125 changes: 74 additions & 51 deletions src/jimgw/transforms.py
@@ -368,55 +368,78 @@ def __init__(
        }


class PowerLawTransform(BijectiveTransform):
    """
    PowerLaw transformation
    Parameters
    ----------
    name_mapping : tuple[list[str], list[str]]
        The name mapping between the input and output dictionary.
    """

    xmin: Float
    xmax: Float
    alpha: Float

    def __init__(
        self,
        name_mapping: tuple[list[str], list[str]],
        xmin: Float,
        xmax: Float,
        alpha: Float,
    ):
        super().__init__(name_mapping)
        self.xmin = xmin
        self.xmax = xmax
        self.alpha = alpha
        self.transform_func = lambda x: {
            name_mapping[1][i]: (
                self.xmin ** (1.0 + self.alpha)
                + x[name_mapping[0][i]]
                * (self.xmax ** (1.0 + self.alpha) - self.xmin ** (1.0 + self.alpha))
            )
            ** (1.0 / (1.0 + self.alpha))
            for i in range(len(name_mapping[0]))
        }
        self.inverse_transform_func = lambda x: {
            name_mapping[0][i]: (
                (
                    x[name_mapping[1][i]] ** (1.0 + self.alpha)
                    - self.xmin ** (1.0 + self.alpha)
                )
                / (self.xmax ** (1.0 + self.alpha) - self.xmin ** (1.0 + self.alpha))
            )
            for i in range(len(name_mapping[1]))
        }
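
Not part of the diff: a minimal usage sketch of the new PowerLawTransform, assuming only the constructor signature and the transform_func / inverse_transform_func attributes defined above. The name pair ("x_base", "x") and the parameter values are illustrative, not taken from the commit.

import jax.numpy as jnp
from jimgw.transforms import PowerLawTransform

t = PowerLawTransform((["x_base"], ["x"]), xmin=1.0, xmax=100.0, alpha=-2.0)
u = jnp.linspace(0.0, 1.0, 5)                           # samples on the unit interval
x = t.transform_func({"x_base": u})["x"]                # mapped onto [xmin, xmax] with a power-law profile
u_back = t.inverse_transform_func({"x": x})["x_base"]   # inverse map recovers the inputs
assert jnp.allclose(u, u_back, atol=1e-5)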

# class PowerLawTransform(UnivariateTransform):
# """
# PowerLaw transformation
# Parameters
# ----------
# name_mapping : tuple[list[str], list[str]]
# The name mapping between the input and output dictionary.
# """

# xmin: Float
# xmax: Float
# alpha: Float

# def __init__(
# self,
# name_mapping: tuple[list[str], list[str]],
# xmin: Float,
# xmax: Float,
# alpha: Float,
# ):
# super().__init__(name_mapping)
# self.xmin = xmin
# self.xmax = xmax
# self.alpha = alpha
# self.transform_func = lambda x: (
# self.xmin ** (1.0 + self.alpha)
# + x * (self.xmax ** (1.0 + self.alpha) - self.xmin ** (1.0 + self.alpha))
# ) ** (1.0 / (1.0 + self.alpha))


# class ParetoTransform(UnivariateTransform):
# """
# Pareto transformation: Power law when alpha = -1
# Parameters
# ----------
# name_mapping : tuple[list[str], list[str]]
# The name mapping between the input and output dictionary.
# """

# def __init__(
# self,
# name_mapping: tuple[list[str], list[str]],
# xmin: Float,
# xmax: Float,
# ):
# super().__init__(name_mapping)
# self.xmin = xmin
# self.xmax = xmax
# self.transform_func = lambda x: self.xmin * jnp.exp(
# x * jnp.log(self.xmax / self.xmin)
# )

class ParetoTransform(BijectiveTransform):
    """
    Pareto transformation: Power law when alpha = -1
    Parameters
    ----------
    name_mapping : tuple[list[str], list[str]]
        The name mapping between the input and output dictionary.
    """

    def __init__(
        self,
        name_mapping: tuple[list[str], list[str]],
        xmin: Float,
        xmax: Float,
    ):
        super().__init__(name_mapping)
        self.xmin = xmin
        self.xmax = xmax
        self.transform_func = lambda x: {
            name_mapping[1][i]: self.xmin
            * jnp.exp(x[name_mapping[0][i]] * jnp.log(self.xmax / self.xmin))
            for i in range(len(name_mapping[0]))
        }
        self.inverse_transform_func = lambda x: {
            name_mapping[0][i]: (
                jnp.log(x[name_mapping[1][i]] / self.xmin)
                / jnp.log(self.xmax / self.xmin)
            )
            for i in range(len(name_mapping[1]))
        }
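
Likewise not part of the diff: the same kind of sketch for the new ParetoTransform, which is the alpha = -1 limit of the power law, mapping u to xmin * exp(u * log(xmax / xmin)). The names and values are again illustrative.

import jax.numpy as jnp
from jimgw.transforms import ParetoTransform

t = ParetoTransform((["x_base"], ["x"]), xmin=1.0, xmax=100.0)
u = jnp.linspace(0.0, 1.0, 5)
x = t.transform_func({"x_base": u})["x"]                # xmin * exp(u * log(xmax / xmin))
u_back = t.inverse_transform_func({"x": x})["x_base"]   # log(x / xmin) / log(xmax / xmin)
assert jnp.allclose(u, u_back, atol=1e-5)
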
29 changes: 9 additions & 20 deletions test/unit/test_prior.py
@@ -30,7 +30,6 @@ def test_uniform(self):
        samples = p.sample(jax.random.PRNGKey(0), 10000)
        assert jnp.all(jnp.isfinite(samples['x']))
        # Check that the log_prob is correct in the support
        samples = trace_prior_parent(p, [])[0].sample(jax.random.PRNGKey(0), 10000)
        log_prob = jax.vmap(p.log_prob)(samples)
        assert jnp.allclose(log_prob, jnp.log(1.0 / (xmax - xmin)))

@@ -40,30 +39,28 @@ def test_sine(self):
        samples = p.sample(jax.random.PRNGKey(0), 10000)
        assert jnp.all(jnp.isfinite(samples['x']))
        # Check that the log_prob is finite
        samples = trace_prior_parent(p, [])[0].sample(jax.random.PRNGKey(0), 10000)
        log_prob = jax.vmap(p.log_prob)(samples)
        assert jnp.all(jnp.isfinite(log_prob))
        # Check that the log_prob is correct in the support
        x = trace_prior_parent(p, [])[0].add_name(jnp.linspace(-10.0, 10.0, 1000)[None])
        y = jax.vmap(p.base_prior.base_prior.transform)(x)
        y = jax.vmap(p.base_prior.transform)(y)
        y = jax.vmap(p.transform)(y)
        assert jnp.allclose(jax.vmap(p.log_prob)(x), jnp.log(jnp.sin(y['x'])/2.0))
        assert jnp.allclose(jax.vmap(p.log_prob)(y), jnp.log(jnp.sin(y['x'])/2.0))

    def test_cosine(self):
        p = CosinePrior(["x"])
        # Check that all the samples are finite
        samples = p.sample(jax.random.PRNGKey(0), 10000)
        assert jnp.all(jnp.isfinite(samples['x']))
        # Check that the log_prob is finite
        samples = trace_prior_parent(p, [])[0].sample(jax.random.PRNGKey(0), 10000)
        log_prob = jax.vmap(p.log_prob)(samples)
        assert jnp.all(jnp.isfinite(log_prob))
        # Check that the log_prob is correct in the support
        x = trace_prior_parent(p, [])[0].add_name(jnp.linspace(-10.0, 10.0, 1000)[None])
        y = jax.vmap(p.base_prior.transform)(x)
        y = jax.vmap(p.transform)(y)
        assert jnp.allclose(jax.vmap(p.log_prob)(x), jnp.log(jnp.cos(y['x'])/2.0))
        assert jnp.allclose(jax.vmap(p.log_prob)(y), jnp.log(jnp.cos(y['x'])/2.0))

    def test_uniform_sphere(self):
        p = UniformSpherePrior(["x"])
@@ -73,12 +70,10 @@
        assert jnp.all(jnp.isfinite(samples['x_theta']))
        assert jnp.all(jnp.isfinite(samples['x_phi']))
        # Check that the log_prob is finite
        samples = {}
        for i in range(3):
            samples.update(trace_prior_parent(p, [])[i].sample(jax.random.PRNGKey(0), 10000))
        log_prob = jax.vmap(p.log_prob)(samples)
        assert jnp.all(jnp.isfinite(log_prob))


    def test_power_law(self):
        def powerlaw_log_pdf(x, alpha, xmin, xmax):
            if alpha == -1.0:
@@ -96,20 +91,14 @@ def func(alpha):
            assert jnp.all(jnp.isfinite(powerlaw_samples['x']))

            # Check that all the log_probs are finite
            samples = (trace_prior_parent(p, [])[0].sample(jax.random.PRNGKey(0), 10000))['x_base']
            base_log_p = jax.vmap(p.log_prob, [0])({'x_base':samples})
            assert jnp.all(jnp.isfinite(base_log_p))
            log_p = jax.vmap(p.log_prob, [0])(powerlaw_samples)
            assert jnp.all(jnp.isfinite(log_p))

            # Check that the log_prob is correct in the support
            samples = jnp.linspace(-10.0, 10.0, 1000)
            transformed_samples = jax.vmap(p.transform)({'x_base': samples})['x']
            # cut off the samples that are outside the support
            samples = samples[transformed_samples >= xmin]
            transformed_samples = transformed_samples[transformed_samples >= xmin]
            samples = samples[transformed_samples <= xmax]
            transformed_samples = transformed_samples[transformed_samples <= xmax]
            log_prob = jax.vmap(p.log_prob)(powerlaw_samples)
            standard_log_prob = powerlaw_log_pdf(powerlaw_samples['x'], alpha, xmin, xmax)
            # log pdf of powerlaw
            assert jnp.allclose(jax.vmap(p.log_prob)({'x_base':samples}), powerlaw_log_pdf(transformed_samples, alpha, xmin, xmax), atol=1e-4)
            assert jnp.allclose(log_prob, standard_log_prob, atol=1e-4)

        # Test Pareto Transform
        func(-1.0)
@@ -120,4 +109,4 @@ def func(alpha):
            func(alpha_val)
        negative_alpha = [-0.5, -1.5, -2.0, -2.5, -3.0, -3.5, -4.0, -4.5, -5.0]
        for alpha_val in negative_alpha:
            func(alpha_val)
            func(alpha_val)
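
For reference, and not part of the commit: the truncated power-law density that a reference like powerlaw_log_pdf presumably encodes can be written as a small standalone helper (the name truncated_powerlaw_log_pdf is hypothetical). The alpha == -1.0 branch visible above corresponds to the logarithmic normalization.

import jax.numpy as jnp

def truncated_powerlaw_log_pdf(x, alpha, xmin, xmax):
    # p(x) proportional to x**alpha on [xmin, xmax], zero outside
    if alpha == -1.0:
        log_norm = jnp.log(jnp.log(xmax / xmin))
    else:
        log_norm = jnp.log((xmax ** (1.0 + alpha) - xmin ** (1.0 + alpha)) / (1.0 + alpha))
    return alpha * jnp.log(x) - log_norm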
