Skip to content

Commit

Permalink
Added code to create and reverse transforms for more flexibility
Browse files Browse the repository at this point in the history
  • Loading branch information
ThibeauWouters committed Aug 17, 2024
1 parent 1b79a3a commit 450772f
Show file tree
Hide file tree
Showing 5 changed files with 271 additions and 148 deletions.
202 changes: 67 additions & 135 deletions src/jimgw/single_event/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,12 @@
from astropy.time import Time

from jimgw.single_event.detector import GroundBased2G
from jimgw.transforms import BijectiveTransform, NtoNTransform
from jimgw.transforms import (
BijectiveTransform,
NtoNTransform,
reverse_bijective_transform,
create_bijective_transform,
)
from jimgw.single_event.utils import (
m1_m2_to_Mc_q,
Mc_q_to_m1_m2,
Expand All @@ -20,100 +25,65 @@


@jaxtyped(typechecker=typechecker)
class ComponentMassesToChirpMassMassRatioTransform(BijectiveTransform):
class SpinToCartesianSpinTransform(NtoNTransform):
"""
Transform chirp mass and mass ratio to component masses
Parameters
----------
name_mapping : tuple[list[str], list[str]]
The name mapping between the input and output dictionary.
Spin to Cartesian spin transformation
"""

freq_ref: Float

def __init__(
self,
name_mapping: tuple[list[str], list[str]],
freq_ref: Float,
):
super().__init__(name_mapping)
assert (
"m_1" in name_mapping[0]
and "m_2" in name_mapping[0]
and "M_c" in name_mapping[1]
and "q" in name_mapping[1]
)

def named_transform(x):
Mc, q = m1_m2_to_Mc_q(x["m_1"], x["m_2"])
return {"M_c": Mc, "q": q}

self.transform_func = named_transform

def named_inverse_transform(x):
m1, m2 = Mc_q_to_m1_m2(x["M_c"], x["q"])
return {"m_1": m1, "m_2": m2}

self.inverse_transform_func = named_inverse_transform


@jaxtyped(typechecker=typechecker)
class ComponentMassesToChirpMassSymmetricMassRatioTransform(BijectiveTransform):
"""
Transform mass ratio to symmetric mass ratio
Parameters
----------
name_mapping : tuple[list[str], list[str]]
The name mapping between the input and output dictionary.

"""
self.freq_ref = freq_ref

def __init__(
self,
name_mapping: tuple[list[str], list[str]],
):
super().__init__(name_mapping)
assert (
"m_1" in name_mapping[0]
and "m_2" in name_mapping[0]
and "M_c" in name_mapping[1]
and "eta" in name_mapping[1]
"theta_jn" in name_mapping[0]
and "phi_jl" in name_mapping[0]
and "theta_1" in name_mapping[0]
and "theta_2" in name_mapping[0]
and "phi_12" in name_mapping[0]
and "a_1" in name_mapping[0]
and "a_2" in name_mapping[0]
and "iota" in name_mapping[1]
and "s1_x" in name_mapping[1]
and "s1_y" in name_mapping[1]
and "s1_z" in name_mapping[1]
and "s2_x" in name_mapping[1]
and "s2_y" in name_mapping[1]
and "s2_z" in name_mapping[1]
)

def named_transform(x):
Mc, eta = m1_m2_to_Mc_eta(x["m_1"], x["m_2"])
return {"M_c": Mc, "eta": eta}
iota, s1x, s1y, s1z, s2x, s2y, s2z = spin_to_cartesian_spin(
x["theta_jn"],
x["phi_jl"],
x["theta_1"],
x["theta_2"],
x["phi_12"],
x["a_1"],
x["a_2"],
x["M_c"],
x["q"],
self.freq_ref,
x["phase_c"],
)
return {
"iota": iota,
"s1_x": s1x,
"s1_y": s1y,
"s1_z": s1z,
"s2_x": s2x,
"s2_y": s2y,
"s2_z": s2z,
}

self.transform_func = named_transform

def named_inverse_transform(x):
m1, m2 = Mc_eta_to_m1_m2(x["M_c"], x["q"])
return {"m_1": m1, "m_2": m2}

self.inverse_transform_func = named_inverse_transform


@jaxtyped(typechecker=typechecker)
class MassRatioToSymmetricMassRatioTransform(BijectiveTransform):
"""
Transform mass ratio to symmetric mass ratio
Parameters
----------
name_mapping : tuple[list[str], list[str]]
The name mapping between the input and output dictionary.
"""

def __init__(
self,
name_mapping: tuple[list[str], list[str]],
):
super().__init__(name_mapping)
assert "q" == name_mapping[0][0] and "eta" == name_mapping[1][0]

self.transform_func = lambda x: {"eta": q_to_eta(x["q"])}
self.inverse_transform_func = lambda x: {"q": eta_to_q(x["eta"])}


@jaxtyped(typechecker=typechecker)
class SkyFrameToDetectorFrameSkyPositionTransform(BijectiveTransform):
Expand Down Expand Up @@ -170,62 +140,24 @@ def named_inverse_transform(x):
self.inverse_transform_func = named_inverse_transform


@jaxtyped(typechecker=typechecker)
class SpinToCartesianSpinTransform(NtoNTransform):
"""
Spin to Cartesian spin transformation
"""

freq_ref: Float

def __init__(
self,
name_mapping: tuple[list[str], list[str]],
freq_ref: Float,
):
super().__init__(name_mapping)

self.freq_ref = freq_ref

assert (
"theta_jn" in name_mapping[0]
and "phi_jl" in name_mapping[0]
and "theta_1" in name_mapping[0]
and "theta_2" in name_mapping[0]
and "phi_12" in name_mapping[0]
and "a_1" in name_mapping[0]
and "a_2" in name_mapping[0]
and "iota" in name_mapping[1]
and "s1_x" in name_mapping[1]
and "s1_y" in name_mapping[1]
and "s1_z" in name_mapping[1]
and "s2_x" in name_mapping[1]
and "s2_y" in name_mapping[1]
and "s2_z" in name_mapping[1]
)
# Pre-made bijective transforms:
ComponentMassesToChirpMassMassRatioTransform = create_bijective_transform(
(["m_1", "m_2"], ["M_c", "q"]), m1_m2_to_Mc_q, Mc_q_to_m1_m2
)
ChirpMassMassRatioToComponentMassesTransform = reverse_bijective_transform(
ComponentMassesToChirpMassMassRatioTransform
)

def named_transform(x):
iota, s1x, s1y, s1z, s2x, s2y, s2z = spin_to_cartesian_spin(
x["theta_jn"],
x["phi_jl"],
x["theta_1"],
x["theta_2"],
x["phi_12"],
x["a_1"],
x["a_2"],
x["M_c"],
x["q"],
self.freq_ref,
x["phase_c"],
)
return {
"iota": iota,
"s1_x": s1x,
"s1_y": s1y,
"s1_z": s1z,
"s2_x": s2x,
"s2_y": s2y,
"s2_z": s2z,
}
ComponentMassesToChirpMassSymmetricMassRatioTransform = create_bijective_transform(
(["m_1", "m_2"], ["M_c", "eta"]), m1_m2_to_Mc_eta, Mc_eta_to_m1_m2
)
ChirpMassSymmetricMassRatioToComponentMassesTransform = reverse_bijective_transform(
ComponentMassesToChirpMassSymmetricMassRatioTransform
)

self.transform_func = named_transform
MassRatioToSymmetricMassRatioTransform = create_bijective_transform(
(["q"], ["eta"]), q_to_eta, eta_to_q
)
SymmetricMassRatioToMassRatioTransform = reverse_bijective_transform(
MassRatioToSymmetricMassRatioTransform
)
51 changes: 51 additions & 0 deletions src/jimgw/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -445,3 +445,54 @@ def __init__(
)
for i in range(len(name_mapping[1]))
}


def create_bijective_transform(
name_mapping: tuple[list[str], list[str]],
transform_func_array: Callable[[Float], Float],
inverse_transform_func_array: Callable[[Float], Float],
) -> BijectiveTransform:
"""
Utility function to create a BijectiveTransform object given a name_mapping and the forward and backward transform functions which take arrays as input, e.g. coming from the utils module.
Args:
name_mapping (tuple[list[str], list[str]]): The name_mapping to be used in the named transforms.
transform_func_array (Callable[[Float], Float]): The forward function method taking an array as input.
inverse_transform_func_array (Callable[[Float], Float]): The inverse function method taking an array as input.
Returns:
BijectiveTransform: The BijectiveTransform object.
"""

def named_transform_func(x_named: dict[str, Float]) -> dict[str, Float]:
x_array = jnp.array([x_named[key] for key in name_mapping[0]])
y_array = transform_func_array(*x_array)
y_named = dict(zip(name_mapping[1], y_array))
return y_named

def named_inverse_transform_func(y_named: dict[str, Float]) -> dict[str, Float]:
y_array = jnp.array([y_named[key] for key in name_mapping[1]])
x_array = inverse_transform_func_array(*y_array)
x_named = dict(zip(name_mapping[0], x_array))
return x_named

new_transform = BijectiveTransform(name_mapping)
new_transform.transform_func = named_transform_func
new_transform.inverse_transform_func = named_inverse_transform_func

return new_transform


def reverse_bijective_transform(
original_transform: BijectiveTransform,
) -> BijectiveTransform:

reversed_name_mapping = (
original_transform.name_mapping[1],
original_transform.name_mapping[0],
)
reversed_transform = BijectiveTransform(name_mapping=reversed_name_mapping)
reversed_transform.transform_func = original_transform.inverse_transform_func
reversed_transform.inverse_transform_func = original_transform.transform_func

return reversed_transform
2 changes: 2 additions & 0 deletions test/integration/.gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
outdir/
figures/
37 changes: 24 additions & 13 deletions test/integration/test_GW150914_D.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.10"

import jax
import jax.numpy as jnp

Expand All @@ -10,6 +14,8 @@
from jimgw.single_event.transforms import ComponentMassesToChirpMassSymmetricMassRatioTransform, SkyFrameToDetectorFrameSkyPositionTransform, ComponentMassesToChirpMassMassRatioTransform
from jimgw.single_event.utils import Mc_q_to_m1_m2
from flowMC.strategy.optimization import optimization_Adam
from flowMC.utils.postprocessing import plot_summary
import optax

jax.config.update("jax_enable_x64", True)

Expand Down Expand Up @@ -62,7 +68,7 @@
)

sample_transforms = [
ComponentMassesToChirpMassMassRatioTransform(name_mapping=[["m_1", "m_2"], ["M_c", "q"]]),
ComponentMassesToChirpMassMassRatioTransform,
BoundToUnbound(name_mapping = [["M_c"], ["M_c_unbounded"]], original_lower_bound=M_c_min, original_upper_bound=M_c_max),
BoundToUnbound(name_mapping = [["q"], ["q_unbounded"]], original_lower_bound=q_min, original_upper_bound=q_max),
BoundToUnbound(name_mapping = [["s1_z"], ["s1_z_unbounded"]] , original_lower_bound=-1.0, original_upper_bound=1.0),
Expand All @@ -78,7 +84,7 @@
]

likelihood_transforms = [
ComponentMassesToChirpMassSymmetricMassRatioTransform(name_mapping=[["m_1", "m_2"], ["M_c", "eta"]]),
ComponentMassesToChirpMassSymmetricMassRatioTransform,
]

likelihood = TransientLikelihoodFD(
Expand All @@ -97,34 +103,39 @@

Adam_optimizer = optimization_Adam(n_steps=5, learning_rate=0.01, noise_level=1)

n_epochs = 2
n_loop_training = 1
learning_rate = 1e-4

n_epochs = 20
n_loop_training = 10
total_epochs = n_epochs * n_loop_training
start = total_epochs//10
learning_rate = optax.polynomial_schedule(
1e-3, 5e-4, 4.0, total_epochs - start, transition_begin=start
)

jim = Jim(
likelihood,
prior,
sample_transforms=sample_transforms,
likelihood_transforms=likelihood_transforms,
n_loop_training=n_loop_training,
n_loop_production=1,
n_local_steps=5,
n_global_steps=5,
n_chains=4,
n_loop_production=4,
n_local_steps=10,
n_global_steps=1000,
n_chains=500,
n_epochs=n_epochs,
learning_rate=learning_rate,
n_max_examples=30,
n_flow_samples=100,
n_max_examples=30000,
n_flow_samples=100000,
momentum=0.9,
batch_size=100,
batch_size=30000,
use_global=True,
train_thinning=1,
output_thinning=1,
output_thinning=10,
local_sampler_arg=local_sampler_arg,
strategies=[Adam_optimizer, "default"],
)

jim.sample(jax.random.PRNGKey(42))
jim.get_samples()
jim.print_summary()
plot_summary(jim.sampler)
Loading

0 comments on commit 450772f

Please sign in to comment.