From 450772f72f914e9c5b49c2ef9e66d6b56e0b3a62 Mon Sep 17 00:00:00 2001 From: Thibeau Wouters Date: Sat, 17 Aug 2024 01:36:45 -0700 Subject: [PATCH] Added code to create and reverse transforms for more flexibility --- src/jimgw/single_event/transforms.py | 202 ++++++++--------------- src/jimgw/transforms.py | 51 ++++++ test/integration/.gitignore | 2 + test/integration/test_GW150914_D.py | 37 +++-- test/integration/test_mass_transforms.py | 127 ++++++++++++++ 5 files changed, 271 insertions(+), 148 deletions(-) create mode 100644 test/integration/.gitignore create mode 100644 test/integration/test_mass_transforms.py diff --git a/src/jimgw/single_event/transforms.py b/src/jimgw/single_event/transforms.py index c3e77846..49f01457 100644 --- a/src/jimgw/single_event/transforms.py +++ b/src/jimgw/single_event/transforms.py @@ -4,7 +4,12 @@ from astropy.time import Time from jimgw.single_event.detector import GroundBased2G -from jimgw.transforms import BijectiveTransform, NtoNTransform +from jimgw.transforms import ( + BijectiveTransform, + NtoNTransform, + reverse_bijective_transform, + create_bijective_transform, +) from jimgw.single_event.utils import ( m1_m2_to_Mc_q, Mc_q_to_m1_m2, @@ -20,100 +25,65 @@ @jaxtyped(typechecker=typechecker) -class ComponentMassesToChirpMassMassRatioTransform(BijectiveTransform): +class SpinToCartesianSpinTransform(NtoNTransform): """ - Transform chirp mass and mass ratio to component masses - - Parameters - ---------- - name_mapping : tuple[list[str], list[str]] - The name mapping between the input and output dictionary. + Spin to Cartesian spin transformation """ + freq_ref: Float + def __init__( self, name_mapping: tuple[list[str], list[str]], + freq_ref: Float, ): super().__init__(name_mapping) - assert ( - "m_1" in name_mapping[0] - and "m_2" in name_mapping[0] - and "M_c" in name_mapping[1] - and "q" in name_mapping[1] - ) - - def named_transform(x): - Mc, q = m1_m2_to_Mc_q(x["m_1"], x["m_2"]) - return {"M_c": Mc, "q": q} - - self.transform_func = named_transform - - def named_inverse_transform(x): - m1, m2 = Mc_q_to_m1_m2(x["M_c"], x["q"]) - return {"m_1": m1, "m_2": m2} - - self.inverse_transform_func = named_inverse_transform - - -@jaxtyped(typechecker=typechecker) -class ComponentMassesToChirpMassSymmetricMassRatioTransform(BijectiveTransform): - """ - Transform mass ratio to symmetric mass ratio - - Parameters - ---------- - name_mapping : tuple[list[str], list[str]] - The name mapping between the input and output dictionary. - """ + self.freq_ref = freq_ref - def __init__( - self, - name_mapping: tuple[list[str], list[str]], - ): - super().__init__(name_mapping) assert ( - "m_1" in name_mapping[0] - and "m_2" in name_mapping[0] - and "M_c" in name_mapping[1] - and "eta" in name_mapping[1] + "theta_jn" in name_mapping[0] + and "phi_jl" in name_mapping[0] + and "theta_1" in name_mapping[0] + and "theta_2" in name_mapping[0] + and "phi_12" in name_mapping[0] + and "a_1" in name_mapping[0] + and "a_2" in name_mapping[0] + and "iota" in name_mapping[1] + and "s1_x" in name_mapping[1] + and "s1_y" in name_mapping[1] + and "s1_z" in name_mapping[1] + and "s2_x" in name_mapping[1] + and "s2_y" in name_mapping[1] + and "s2_z" in name_mapping[1] ) def named_transform(x): - Mc, eta = m1_m2_to_Mc_eta(x["m_1"], x["m_2"]) - return {"M_c": Mc, "eta": eta} + iota, s1x, s1y, s1z, s2x, s2y, s2z = spin_to_cartesian_spin( + x["theta_jn"], + x["phi_jl"], + x["theta_1"], + x["theta_2"], + x["phi_12"], + x["a_1"], + x["a_2"], + x["M_c"], + x["q"], + self.freq_ref, + x["phase_c"], + ) + return { + "iota": iota, + "s1_x": s1x, + "s1_y": s1y, + "s1_z": s1z, + "s2_x": s2x, + "s2_y": s2y, + "s2_z": s2z, + } self.transform_func = named_transform - def named_inverse_transform(x): - m1, m2 = Mc_eta_to_m1_m2(x["M_c"], x["q"]) - return {"m_1": m1, "m_2": m2} - - self.inverse_transform_func = named_inverse_transform - - -@jaxtyped(typechecker=typechecker) -class MassRatioToSymmetricMassRatioTransform(BijectiveTransform): - """ - Transform mass ratio to symmetric mass ratio - - Parameters - ---------- - name_mapping : tuple[list[str], list[str]] - The name mapping between the input and output dictionary. - - """ - - def __init__( - self, - name_mapping: tuple[list[str], list[str]], - ): - super().__init__(name_mapping) - assert "q" == name_mapping[0][0] and "eta" == name_mapping[1][0] - - self.transform_func = lambda x: {"eta": q_to_eta(x["q"])} - self.inverse_transform_func = lambda x: {"q": eta_to_q(x["eta"])} - @jaxtyped(typechecker=typechecker) class SkyFrameToDetectorFrameSkyPositionTransform(BijectiveTransform): @@ -170,62 +140,24 @@ def named_inverse_transform(x): self.inverse_transform_func = named_inverse_transform -@jaxtyped(typechecker=typechecker) -class SpinToCartesianSpinTransform(NtoNTransform): - """ - Spin to Cartesian spin transformation - """ - - freq_ref: Float - - def __init__( - self, - name_mapping: tuple[list[str], list[str]], - freq_ref: Float, - ): - super().__init__(name_mapping) - - self.freq_ref = freq_ref - - assert ( - "theta_jn" in name_mapping[0] - and "phi_jl" in name_mapping[0] - and "theta_1" in name_mapping[0] - and "theta_2" in name_mapping[0] - and "phi_12" in name_mapping[0] - and "a_1" in name_mapping[0] - and "a_2" in name_mapping[0] - and "iota" in name_mapping[1] - and "s1_x" in name_mapping[1] - and "s1_y" in name_mapping[1] - and "s1_z" in name_mapping[1] - and "s2_x" in name_mapping[1] - and "s2_y" in name_mapping[1] - and "s2_z" in name_mapping[1] - ) +# Pre-made bijective transforms: +ComponentMassesToChirpMassMassRatioTransform = create_bijective_transform( + (["m_1", "m_2"], ["M_c", "q"]), m1_m2_to_Mc_q, Mc_q_to_m1_m2 +) +ChirpMassMassRatioToComponentMassesTransform = reverse_bijective_transform( + ComponentMassesToChirpMassMassRatioTransform +) - def named_transform(x): - iota, s1x, s1y, s1z, s2x, s2y, s2z = spin_to_cartesian_spin( - x["theta_jn"], - x["phi_jl"], - x["theta_1"], - x["theta_2"], - x["phi_12"], - x["a_1"], - x["a_2"], - x["M_c"], - x["q"], - self.freq_ref, - x["phase_c"], - ) - return { - "iota": iota, - "s1_x": s1x, - "s1_y": s1y, - "s1_z": s1z, - "s2_x": s2x, - "s2_y": s2y, - "s2_z": s2z, - } +ComponentMassesToChirpMassSymmetricMassRatioTransform = create_bijective_transform( + (["m_1", "m_2"], ["M_c", "eta"]), m1_m2_to_Mc_eta, Mc_eta_to_m1_m2 +) +ChirpMassSymmetricMassRatioToComponentMassesTransform = reverse_bijective_transform( + ComponentMassesToChirpMassSymmetricMassRatioTransform +) - self.transform_func = named_transform +MassRatioToSymmetricMassRatioTransform = create_bijective_transform( + (["q"], ["eta"]), q_to_eta, eta_to_q +) +SymmetricMassRatioToMassRatioTransform = reverse_bijective_transform( + MassRatioToSymmetricMassRatioTransform +) diff --git a/src/jimgw/transforms.py b/src/jimgw/transforms.py index 715d49de..44f2837c 100644 --- a/src/jimgw/transforms.py +++ b/src/jimgw/transforms.py @@ -445,3 +445,54 @@ def __init__( ) for i in range(len(name_mapping[1])) } + + +def create_bijective_transform( + name_mapping: tuple[list[str], list[str]], + transform_func_array: Callable[[Float], Float], + inverse_transform_func_array: Callable[[Float], Float], +) -> BijectiveTransform: + """ + Utility function to create a BijectiveTransform object given a name_mapping and the forward and backward transform functions which take arrays as input, e.g. coming from the utils module. + + Args: + name_mapping (tuple[list[str], list[str]]): The name_mapping to be used in the named transforms. + transform_func_array (Callable[[Float], Float]): The forward function method taking an array as input. + inverse_transform_func_array (Callable[[Float], Float]): The inverse function method taking an array as input. + + Returns: + BijectiveTransform: The BijectiveTransform object. + """ + + def named_transform_func(x_named: dict[str, Float]) -> dict[str, Float]: + x_array = jnp.array([x_named[key] for key in name_mapping[0]]) + y_array = transform_func_array(*x_array) + y_named = dict(zip(name_mapping[1], y_array)) + return y_named + + def named_inverse_transform_func(y_named: dict[str, Float]) -> dict[str, Float]: + y_array = jnp.array([y_named[key] for key in name_mapping[1]]) + x_array = inverse_transform_func_array(*y_array) + x_named = dict(zip(name_mapping[0], x_array)) + return x_named + + new_transform = BijectiveTransform(name_mapping) + new_transform.transform_func = named_transform_func + new_transform.inverse_transform_func = named_inverse_transform_func + + return new_transform + + +def reverse_bijective_transform( + original_transform: BijectiveTransform, +) -> BijectiveTransform: + + reversed_name_mapping = ( + original_transform.name_mapping[1], + original_transform.name_mapping[0], + ) + reversed_transform = BijectiveTransform(name_mapping=reversed_name_mapping) + reversed_transform.transform_func = original_transform.inverse_transform_func + reversed_transform.inverse_transform_func = original_transform.transform_func + + return reversed_transform diff --git a/test/integration/.gitignore b/test/integration/.gitignore new file mode 100644 index 00000000..a7f7ef0e --- /dev/null +++ b/test/integration/.gitignore @@ -0,0 +1,2 @@ +outdir/ +figures/ diff --git a/test/integration/test_GW150914_D.py b/test/integration/test_GW150914_D.py index e1eee9ac..d3c5c0e8 100644 --- a/test/integration/test_GW150914_D.py +++ b/test/integration/test_GW150914_D.py @@ -1,3 +1,7 @@ +import os +os.environ["CUDA_VISIBLE_DEVICES"] = "0" +os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.10" + import jax import jax.numpy as jnp @@ -10,6 +14,8 @@ from jimgw.single_event.transforms import ComponentMassesToChirpMassSymmetricMassRatioTransform, SkyFrameToDetectorFrameSkyPositionTransform, ComponentMassesToChirpMassMassRatioTransform from jimgw.single_event.utils import Mc_q_to_m1_m2 from flowMC.strategy.optimization import optimization_Adam +from flowMC.utils.postprocessing import plot_summary +import optax jax.config.update("jax_enable_x64", True) @@ -62,7 +68,7 @@ ) sample_transforms = [ - ComponentMassesToChirpMassMassRatioTransform(name_mapping=[["m_1", "m_2"], ["M_c", "q"]]), + ComponentMassesToChirpMassMassRatioTransform, BoundToUnbound(name_mapping = [["M_c"], ["M_c_unbounded"]], original_lower_bound=M_c_min, original_upper_bound=M_c_max), BoundToUnbound(name_mapping = [["q"], ["q_unbounded"]], original_lower_bound=q_min, original_upper_bound=q_max), BoundToUnbound(name_mapping = [["s1_z"], ["s1_z_unbounded"]] , original_lower_bound=-1.0, original_upper_bound=1.0), @@ -78,7 +84,7 @@ ] likelihood_transforms = [ - ComponentMassesToChirpMassSymmetricMassRatioTransform(name_mapping=[["m_1", "m_2"], ["M_c", "eta"]]), + ComponentMassesToChirpMassSymmetricMassRatioTransform, ] likelihood = TransientLikelihoodFD( @@ -97,10 +103,14 @@ Adam_optimizer = optimization_Adam(n_steps=5, learning_rate=0.01, noise_level=1) -n_epochs = 2 -n_loop_training = 1 -learning_rate = 1e-4 +n_epochs = 20 +n_loop_training = 10 +total_epochs = n_epochs * n_loop_training +start = total_epochs//10 +learning_rate = optax.polynomial_schedule( + 1e-3, 5e-4, 4.0, total_epochs - start, transition_begin=start +) jim = Jim( likelihood, @@ -108,19 +118,19 @@ sample_transforms=sample_transforms, likelihood_transforms=likelihood_transforms, n_loop_training=n_loop_training, - n_loop_production=1, - n_local_steps=5, - n_global_steps=5, - n_chains=4, + n_loop_production=4, + n_local_steps=10, + n_global_steps=1000, + n_chains=500, n_epochs=n_epochs, learning_rate=learning_rate, - n_max_examples=30, - n_flow_samples=100, + n_max_examples=30000, + n_flow_samples=100000, momentum=0.9, - batch_size=100, + batch_size=30000, use_global=True, train_thinning=1, - output_thinning=1, + output_thinning=10, local_sampler_arg=local_sampler_arg, strategies=[Adam_optimizer, "default"], ) @@ -128,3 +138,4 @@ jim.sample(jax.random.PRNGKey(42)) jim.get_samples() jim.print_summary() +plot_summary(jim.sampler) \ No newline at end of file diff --git a/test/integration/test_mass_transforms.py b/test/integration/test_mass_transforms.py new file mode 100644 index 00000000..ec7631cb --- /dev/null +++ b/test/integration/test_mass_transforms.py @@ -0,0 +1,127 @@ +import os +os.environ["CUDA_VISIBLE_DEVICES"] = "2" +os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.10" + +import numpy as np +import matplotlib.pyplot as plt +import corner +import jax +import jax.numpy as jnp +from jaxtyping import Float + +from jimgw.prior import UniformPrior, CombinePrior +from jimgw.single_event.transforms import ChirpMassMassRatioToComponentMassesTransform +from jimgw.base import LikelihoodBase +from jimgw.jim import Jim + +params = {"axes.grid": True, + "text.usetex" : True, + "font.family" : "serif", + "ytick.color" : "black", + "xtick.color" : "black", + "axes.labelcolor" : "black", + "axes.edgecolor" : "black", + "font.serif" : ["Computer Modern Serif"], + "xtick.labelsize": 16, + "ytick.labelsize": 16, + "axes.labelsize": 16, + "legend.fontsize": 16, + "legend.title_fontsize": 16, + "figure.titlesize": 16} + +plt.rcParams.update(params) + +# Improved corner kwargs +default_corner_kwargs = dict(bins=40, + smooth=1., + show_titles=False, + label_kwargs=dict(fontsize=16), + title_kwargs=dict(fontsize=16), + color="blue", + # quantiles=[], + # levels=[0.9], + plot_density=True, + plot_datapoints=False, + fill_contours=True, + max_n_ticks=4, + min_n_ticks=3, + truth_color = "red", + save=False) + +# Likelihood for this test: + +class MyLikelihood(LikelihoodBase): + """Simple toy likelihood: Gaussian centered on the true component masses""" + + true_m1: Float + true_m2: Float + + def __init__(self, + true_m1: Float, + true_m2: Float): + + self.true_m1 = true_m1 + self.true_m2 = true_m2 + + def evaluate(self, params: dict[str, Float], data: dict) -> Float: + m1, m2 = params['m_1'], params['m_2'] + m1_std = 0.1 + m2_std = 0.1 + return -0.5 * (((m1 - self.true_m1) / m1_std)**2 + ((m2 - self.true_m2) / m2_std)**2) + +# Setup +true_m1 = 1.6 +true_m2 = 1.4 +true_mc = (true_m1 * true_m2)**(3/5) / (true_m1 + true_m2)**(1/5) +true_q = true_m2 / true_m1 + +# Priors +eps = 0.5 # half of width of the chirp mass prior +mc_prior = UniformPrior(true_mc - eps, true_mc + eps, parameter_names=['M_c']) +q_prior = UniformPrior(0.125, 1.0, parameter_names=['q']) +combine_prior = CombinePrior([mc_prior, q_prior]) + +# Likelihood and transform +likelihood = MyLikelihood(true_m1, true_m2) +mass_transform = ChirpMassMassRatioToComponentMassesTransform + +print(mass_transform.name_mapping) + +# Other stuff we have to give to Jim to make it work +step = 5e-3 +local_sampler_arg = {"step_size": step * jnp.eye(2)} + +# Jim: +jim = Jim(likelihood, + combine_prior, + likelihood_transforms=[mass_transform], + n_chains = 200, + parameter_names=['M_c', 'q'], + n_loop_training=20, + n_loop_production=5, + local_sampler_arg=local_sampler_arg) + +jim.sample(jax.random.PRNGKey(0)) +jim.print_summary() + +# Go from Mc, q samples to m1, m2 samples +chains_named = jim.get_samples() +m1m2_named = mass_transform.forward(chains_named) +m1, m2 = m1m2_named['m_1'], m1m2_named['m_2'] + +### Prior space: +chains = np.array([chains_named['M_c'], chains_named['q']]).T +chains = np.reshape(chains, (-1, 2)) +corner.corner(chains, truths = np.array([true_mc, true_q]), **default_corner_kwargs) + +plt.savefig("./figures/test_mass_transform_before.png", bbox_inches = 'tight') +plt.close() + +### Transformed space: +chains = np.array([m1, m2]).T +chains = np.reshape(chains, (-1, 2)) + +corner.corner(chains, truths = np.array([true_m1, true_m2]), **default_corner_kwargs) + +plt.savefig("./figures/test_mass_transform_after.png", bbox_inches = 'tight') +plt.close() \ No newline at end of file