Skip to content

Commit

Permalink
no nans now, but the code can be consolidate
Browse files Browse the repository at this point in the history
  • Loading branch information
kazewong committed Aug 1, 2024
1 parent ff35a82 commit 47af9cf
Show file tree
Hide file tree
Showing 3 changed files with 31 additions and 4 deletions.
13 changes: 11 additions & 2 deletions src/jimgw/jim.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ class Jim(object):
# Name of parameters to sample from
sample_transforms: list[BijectiveTransform]
likelihood_transforms: list[NtoMTransform]
parameter_names: list[str]
sampler: Sampler

def __init__(
Expand All @@ -37,9 +38,14 @@ def __init__(

self.sample_transforms = sample_transforms
self.likelihood_transforms = likelihood_transforms
self.parameter_names = prior.parameter_names

if len(sample_transforms) == 0:
print("No sample transforms provided. Using prior parameters as sampling parameters")
else:
print("Using sample transforms")
for transform in sample_transforms:
self.parameter_names = transform.propagate_name(self.parameter_names)

if len(likelihood_transforms) == 0:
print("No likelihood transforms provided. Using prior parameters as likelihood parameters")
Expand Down Expand Up @@ -91,12 +97,15 @@ def posterior(self, params: Float[Array, " n_dim"], data: dict):
prior = self.prior.log_prob(named_params) + transform_jacobian
for transform in self.likelihood_transforms:
named_params = transform.forward(named_params)
return self.likelihood.evaluate(named_params, data) + prior
named_params = jax.tree.map(lambda x:x[0], named_params) # This [0] should be consolidate
return self.likelihood.evaluate(named_params, data) + prior[0] # This prior [0] should be consolidate

def sample(self, key: PRNGKeyArray, initial_guess: Array = jnp.array([])):
if initial_guess.size == 0:
initial_guess_named = self.prior.sample(key, self.Sampler.n_chains)
initial_guess = jnp.stack([i for i in initial_guess_named.values()]).T
for transform in self.sample_transforms:
initial_guess_named = jax.vmap(transform.forward)(initial_guess_named)
initial_guess = jnp.stack([i for i in initial_guess_named.values()]).T[0] # This [0] should be consolidate
self.Sampler.sample(initial_guess, None) # type: ignore

def maximize_likelihood(
Expand Down
5 changes: 3 additions & 2 deletions src/jimgw/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -300,6 +300,7 @@ def __init__(
for i in range(len(name_mapping[1]))
}

@jaxtyped(typechecker=typechecker)
class BoundToUnbound(BijectiveTransform):
"""
Bound to unbound transformation
Expand All @@ -319,8 +320,8 @@ def logit(x):
return jnp.log(x / (1 - x))

super().__init__(name_mapping)
self.original_lower_bound = original_lower_bound
self.original_upper_bound = original_upper_bound
self.original_lower_bound = jnp.atleast_1d(original_lower_bound)
self.original_upper_bound = jnp.atleast_1d(original_upper_bound)

self.transform_func = lambda x: {
name_mapping[1][i]: logit(
Expand Down
17 changes: 17 additions & 0 deletions test/integration/test_GW150914.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from jimgw.single_event.detector import H1, L1
from jimgw.single_event.likelihood import TransientLikelihoodFD
from jimgw.single_event.waveform import RippleIMRPhenomD
from jimgw.transforms import BoundToUnbound
from flowMC.strategy.optimization import optimization_Adam

jax.config.update("jax_enable_x64", True)
Expand Down Expand Up @@ -64,6 +65,21 @@
dec_prior,
]
)

sample_transforms = [
BoundToUnbound(name_mapping = [["M_c"], ["M_c_unbounded"]], original_lower_bound=10.0, original_upper_bound=80.0),
BoundToUnbound(name_mapping = [["eta"], ["eta_unbounded"]], original_lower_bound=0.125, original_upper_bound=0.25),
BoundToUnbound(name_mapping = [["s1_z"], ["s1_z_unbounded"]] , original_lower_bound=-1.0, original_upper_bound=1.0),
BoundToUnbound(name_mapping = [["s2_z"], ["s2_z_unbounded"]] , original_lower_bound=-1.0, original_upper_bound=1.0),
BoundToUnbound(name_mapping = [["d_L"], ["d_L_unbounded"]] , original_lower_bound=0.0, original_upper_bound=2000.0),
BoundToUnbound(name_mapping = [["t_c"], ["t_c_unbounded"]] , original_lower_bound=-0.05, original_upper_bound=0.05),
BoundToUnbound(name_mapping = [["phase_c"], ["phase_c_unbounded"]] , original_lower_bound=0.0, original_upper_bound=2 * jnp.pi),
BoundToUnbound(name_mapping = [["iota"], ["iota_unbounded"]], original_lower_bound=-jnp.pi/2, original_upper_bound=jnp.pi/2),
BoundToUnbound(name_mapping = [["psi"], ["psi_unbounded"]], original_lower_bound=0.0, original_upper_bound=jnp.pi),
BoundToUnbound(name_mapping = [["ra"], ["ra_unbounded"]], original_lower_bound=0.0, original_upper_bound=2 * jnp.pi),
BoundToUnbound(name_mapping = [["dec"], ["dec_unbounded"]],original_lower_bound=0.0, original_upper_bound=jnp.pi)
]

likelihood = TransientLikelihoodFD(
[H1, L1],
waveform=RippleIMRPhenomD(),
Expand All @@ -88,6 +104,7 @@
jim = Jim(
likelihood,
prior,
sample_transforms=sample_transforms,
n_loop_training=n_loop_training,
n_loop_production=1,
n_local_steps=5,
Expand Down

0 comments on commit 47af9cf

Please sign in to comment.