-
Notifications
You must be signed in to change notification settings - Fork 19
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Adding Examples of single event runs #133
Closed
xuyuon
wants to merge
22
commits into
kazewong:jim-dev
from
xuyuon:98-moving-naming-tracking-into-jim-class-from-prior-class
Closed
Changes from all commits
Commits
Show all changes
22 commits
Select commit
Hold shift + click to select a range
5e2792c
Updated Pv2 scripts
xuyuon 420f663
Updated Pv2 scripts
xuyuon aaded8e
Updated GW150914.py
xuyuon 316e7c8
Updated GW150914_heterodyne.py
xuyuon d06de0d
Updated test_GW150914_Pv2.py
xuyuon b09ed0c
Updated GW150914_Pv2.py
xuyuon f85787c
Fixed PowerLaw Lowerbound error
xuyuon d34fd52
Updated GW150914_Pv2.py
xuyuon 250a344
Rename GW150914.py to GW150914_D.py
xuyuon f3f462d
Rename GW150914_PV2.py to GW150914_Pv2.py
xuyuon 40339d9
Rename GW150914_heterodyne.py to GW150914_D_heterodyne.py
xuyuon db768ea
Updated GW150914_Pv2.py
xuyuon b452856
Updated GW150914_Pv2.py
xuyuon 814c204
Fixed chain processing error
xuyuon 645dea1
Updated GW150914_D_heterodyne.py
xuyuon 8e014ee
Updated prior.py
xuyuon bc8e7dc
Update transforms.py
xuyuon b1133d3
Update runManager.py
xuyuon 2a9d696
Update runManager.py
xuyuon 871fb21
Merge pull request #7 from kazewong/98-moving-naming-tracking-into-ji…
xuyuon 2adcabf
Updated test_prior.py
xuyuon bb26fae
Merge pull request #8 from xuyuon/fix-test
xuyuon File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,152 @@ | ||
import jax | ||
import jax.numpy as jnp | ||
|
||
from jimgw.jim import Jim | ||
from jimgw.prior import CombinePrior, UniformPrior, CosinePrior, SinePrior, PowerLawPrior | ||
from jimgw.single_event.detector import H1, L1 | ||
from jimgw.single_event.likelihood import TransientLikelihoodFD | ||
from jimgw.single_event.waveform import RippleIMRPhenomD | ||
from jimgw.transforms import BoundToUnbound | ||
from jimgw.single_event.transforms import ComponentMassesToChirpMassSymmetricMassRatioTransform, SkyFrameToDetectorFrameSkyPositionTransform, ComponentMassesToChirpMassMassRatioTransform | ||
from jimgw.single_event.utils import Mc_q_to_m1_m2 | ||
from flowMC.strategy.optimization import optimization_Adam | ||
|
||
jax.config.update("jax_enable_x64", True) | ||
|
||
########################################### | ||
########## First we grab data ############# | ||
########################################### | ||
|
||
# first, fetch a 4s segment centered on GW150914 | ||
gps = 1126259462.4 | ||
duration = 4 | ||
post_trigger_duration = 2 | ||
start_pad = duration - post_trigger_duration | ||
end_pad = post_trigger_duration | ||
fmin = 20.0 | ||
fmax = 1024.0 | ||
|
||
ifos = [H1, L1] | ||
|
||
for ifo in ifos: | ||
ifo.load_data(gps, start_pad, end_pad, fmin, fmax, psd_pad=16, tukey_alpha=0.2) | ||
|
||
M_c_min, M_c_max = 10.0, 80.0 | ||
q_min, q_max = 0.125, 1.0 | ||
m_1_prior = UniformPrior(Mc_q_to_m1_m2(M_c_min, q_max)[0], Mc_q_to_m1_m2(M_c_max, q_min)[0], parameter_names=["m_1"]) | ||
m_2_prior = UniformPrior(Mc_q_to_m1_m2(M_c_min, q_min)[1], Mc_q_to_m1_m2(M_c_max, q_max)[1], parameter_names=["m_2"]) | ||
s1z_prior = UniformPrior(-1.0, 1.0, parameter_names=["s1_z"]) | ||
s2z_prior = UniformPrior(-1.0, 1.0, parameter_names=["s2_z"]) | ||
dL_prior = PowerLawPrior(1.0, 2000.0, 2.0, parameter_names=["d_L"]) | ||
t_c_prior = UniformPrior(-0.05, 0.05, parameter_names=["t_c"]) | ||
phase_c_prior = UniformPrior(0.0, 2 * jnp.pi, parameter_names=["phase_c"]) | ||
iota_prior = SinePrior(parameter_names=["iota"]) | ||
psi_prior = UniformPrior(0.0, jnp.pi, parameter_names=["psi"]) | ||
ra_prior = UniformPrior(0.0, 2 * jnp.pi, parameter_names=["ra"]) | ||
dec_prior = CosinePrior(parameter_names=["dec"]) | ||
|
||
prior = CombinePrior( | ||
[ | ||
m_1_prior, | ||
m_2_prior, | ||
s1z_prior, | ||
s2z_prior, | ||
dL_prior, | ||
t_c_prior, | ||
phase_c_prior, | ||
iota_prior, | ||
psi_prior, | ||
ra_prior, | ||
dec_prior, | ||
] | ||
) | ||
|
||
sample_transforms = [ | ||
ComponentMassesToChirpMassMassRatioTransform(name_mapping=[["m_1", "m_2"], ["M_c", "q"]]), | ||
BoundToUnbound(name_mapping = [["M_c"], ["M_c_unbounded"]], original_lower_bound=M_c_min, original_upper_bound=M_c_max), | ||
BoundToUnbound(name_mapping = [["q"], ["q_unbounded"]], original_lower_bound=q_min, original_upper_bound=q_max), | ||
BoundToUnbound(name_mapping = [["s1_z"], ["s1_z_unbounded"]] , original_lower_bound=-1.0, original_upper_bound=1.0), | ||
BoundToUnbound(name_mapping = [["s2_z"], ["s2_z_unbounded"]] , original_lower_bound=-1.0, original_upper_bound=1.0), | ||
BoundToUnbound(name_mapping = [["d_L"], ["d_L_unbounded"]] , original_lower_bound=1.0, original_upper_bound=2000.0), | ||
BoundToUnbound(name_mapping = [["t_c"], ["t_c_unbounded"]] , original_lower_bound=-0.05, original_upper_bound=0.05), | ||
BoundToUnbound(name_mapping = [["phase_c"], ["phase_c_unbounded"]] , original_lower_bound=0.0, original_upper_bound=2 * jnp.pi), | ||
BoundToUnbound(name_mapping = [["iota"], ["iota_unbounded"]], original_lower_bound=0., original_upper_bound=jnp.pi), | ||
BoundToUnbound(name_mapping = [["psi"], ["psi_unbounded"]], original_lower_bound=0.0, original_upper_bound=jnp.pi), | ||
SkyFrameToDetectorFrameSkyPositionTransform(name_mapping = [["ra", "dec"], ["zenith", "azimuth"]], gps_time=gps, ifos=ifos), | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This does not work with the new predefined transform. Please change accordingly |
||
BoundToUnbound(name_mapping = [["zenith"], ["zenith_unbounded"]], original_lower_bound=0.0, original_upper_bound=jnp.pi), | ||
BoundToUnbound(name_mapping = [["azimuth"], ["azimuth_unbounded"]], original_lower_bound=0.0, original_upper_bound=2 * jnp.pi), | ||
] | ||
|
||
likelihood_transforms = [ | ||
ComponentMassesToChirpMassSymmetricMassRatioTransform(name_mapping=[["m_1", "m_2"], ["M_c", "eta"]]), | ||
] | ||
|
||
likelihood = TransientLikelihoodFD( | ||
ifos, | ||
waveform=RippleIMRPhenomD(), | ||
trigger_time=gps, | ||
duration=4, | ||
post_trigger_duration=2, | ||
) | ||
|
||
|
||
mass_matrix = jnp.eye(11) | ||
mass_matrix = mass_matrix.at[1, 1].set(1e-3) | ||
mass_matrix = mass_matrix.at[5, 5].set(1e-3) | ||
local_sampler_arg = {"step_size": mass_matrix * 3e-3} | ||
|
||
Adam_optimizer = optimization_Adam(n_steps=3000, learning_rate=0.01, noise_level=1) | ||
|
||
n_epochs = 30 | ||
n_loop_training = 100 | ||
learning_rate = 1e-4 | ||
|
||
|
||
jim = Jim( | ||
likelihood, | ||
prior, | ||
sample_transforms=sample_transforms, | ||
likelihood_transforms=likelihood_transforms, | ||
n_loop_training=n_loop_training, | ||
n_loop_production=20, | ||
n_local_steps=10, | ||
n_global_steps=1000, | ||
n_chains=500, | ||
n_epochs=n_epochs, | ||
learning_rate=learning_rate, | ||
n_max_examples=30000, | ||
n_flow_samples=100000, | ||
momentum=0.9, | ||
batch_size=30000, | ||
use_global=True, | ||
train_thinning=1, | ||
output_thinning=10, | ||
local_sampler_arg=local_sampler_arg, | ||
strategies=[Adam_optimizer, "default"], | ||
) | ||
|
||
jim.sample(jax.random.PRNGKey(42)) | ||
jim.get_samples() | ||
jim.print_summary() | ||
|
||
|
||
########################################### | ||
########## Visualize the Data ############# | ||
########################################### | ||
import corner | ||
import matplotlib.pyplot as plt | ||
import numpy as np | ||
|
||
production_summary = jim.sampler.get_sampler_state(training=False) | ||
production_chain = production_summary["chains"].reshape(-1, len(jim.parameter_names)).T | ||
if jim.sample_transforms: | ||
transformed_chain = jim.add_name(production_chain) | ||
for transform in reversed(jim.sample_transforms): | ||
transformed_chain = transform.backward(transformed_chain) | ||
result = transformed_chain | ||
labels = list(transformed_chain.keys()) | ||
|
||
samples = np.array(list(result.values())).reshape(int(len(labels)), -1) # flatten the array | ||
transposed_array = samples.T # transpose the array | ||
figure = corner.corner(transposed_array, labels=labels, plot_datapoints=False, title_quantiles=[0.16, 0.5, 0.84], show_titles=True, title_fmt='g', use_math_text=True) | ||
plt.savefig("GW1500914_D.jpeg") |
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This does not corresponds to Mc_max =80.0, please check for accuracy
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think this is related to the fact that when you sample from
UniformPrior
, it just randomly gives you a value within the range you have defined. Now, without any constraint, the range ofm_1
andm_2
overlaps with each other, som_2
can be bigger thanm_1
when you try to sample on it. Whenm_2
is bigger thanm_1
, it gives strange value ofM_c
andq
.I think its better to build a special prior class for
m_1
andm_2
, in which we add the constraintm_1
>m_2
in sampling.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think the value looks fine to me. I can test it with @xuyuon tomorrow.