Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Detector and data changes #172

Draft
wants to merge 14 commits into
base: main
Choose a base branch
from
85 changes: 61 additions & 24 deletions example/GW150914_IMRPhenomPV2.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
import optax
import time

import jax
import jax.numpy as jnp

from jimgw.jim import Jim
from jimgw.jim import Jim
from jimgw.prior import (
CombinePrior,
Expand All @@ -25,8 +25,8 @@
GeocentricArrivalTimeToDetectorArrivalTimeTransform,
GeocentricArrivalPhaseToDetectorArrivalPhaseTransform,
)
from jimgw.single_event.utils import Mc_q_to_m1_m2
from flowMC.strategy.optimization import optimization_Adam
from jimgw.single_event import data as jd

jax.config.update("jax_enable_x64", True)

Expand All @@ -37,17 +37,38 @@
total_time_start = time.time()

# first, fetch a 4s segment centered on GW150914
# for the analysis
gps = 1126259462.4
start = gps - 2
end = gps + 2

# fetch 4096s of data to estimate the PSD (to be
# careful we should avoid the on-source segment,
# but we don't do this in this example)
psd_start = gps - 2048
psd_end = gps + 2048

# define frequency integration bounds for the likelihood
# we set fmax to 87.5% of the Nyquist frequency to avoid
# data corrupted by the GWOSC antialiasing filter
# (Note that Data.from_gwosc will pull data sampled at
# 4096 Hz by default)
fmin = 20.0
fmax = 1024.0
fmax = 896.0

ifos = [H1, L1]

H1.load_data(gps, 2, 2, fmin, fmax, psd_pad=16, tukey_alpha=0.2)
L1.load_data(gps, 2, 2, fmin, fmax, psd_pad=16, tukey_alpha=0.2)
for ifo in ifos:
# set analysis data
data = jd.Data.from_gwosc(ifo.name, start, end)
ifo.set_data(data)

# set PSD (Welch estimate)
psd_data = jd.Data.from_gwosc(ifo.name, psd_start, psd_end)
psd_fftlength = data.duration * data.sampling_frequency
ifo.set_psd(psd_data.to_psd(nperseg=psd_fftlength))

# define the approximant to use
waveform = RippleIMRPhenomPv2(f_ref=20)

###########################################
Expand Down Expand Up @@ -97,23 +118,39 @@
# Defining Transforms

sample_transforms = [
DistanceToSNRWeightedDistanceTransform(gps_time=gps, ifos=ifos, dL_min=dL_prior.xmin, dL_max=dL_prior.xmax),
GeocentricArrivalPhaseToDetectorArrivalPhaseTransform(gps_time=gps, ifo=ifos[0]),
GeocentricArrivalTimeToDetectorArrivalTimeTransform(tc_min=t_c_prior.xmin, tc_max=t_c_prior.xmax, gps_time=gps, ifo=ifos[0]),
DistanceToSNRWeightedDistanceTransform(
gps_time=gps, ifos=ifos, dL_min=dL_prior.xmin, dL_max=dL_prior.xmax),
GeocentricArrivalPhaseToDetectorArrivalPhaseTransform(
gps_time=gps, ifo=ifos[0]),
GeocentricArrivalTimeToDetectorArrivalTimeTransform(
tc_min=t_c_prior.xmin, tc_max=t_c_prior.xmax, gps_time=gps, ifo=ifos[0]),
SkyFrameToDetectorFrameSkyPositionTransform(gps_time=gps, ifos=ifos),
BoundToUnbound(name_mapping = (["M_c"], ["M_c_unbounded"]), original_lower_bound=M_c_min, original_upper_bound=M_c_max),
BoundToUnbound(name_mapping = (["q"], ["q_unbounded"]), original_lower_bound=q_min, original_upper_bound=q_max),
BoundToUnbound(name_mapping = (["s1_phi"], ["s1_phi_unbounded"]) , original_lower_bound=0.0, original_upper_bound=2 * jnp.pi),
BoundToUnbound(name_mapping = (["s2_phi"], ["s2_phi_unbounded"]) , original_lower_bound=0.0, original_upper_bound=2 * jnp.pi),
BoundToUnbound(name_mapping = (["iota"], ["iota_unbounded"]) , original_lower_bound=0.0, original_upper_bound=jnp.pi),
BoundToUnbound(name_mapping = (["s1_theta"], ["s1_theta_unbounded"]) , original_lower_bound=0.0, original_upper_bound=jnp.pi),
BoundToUnbound(name_mapping = (["s2_theta"], ["s2_theta_unbounded"]) , original_lower_bound=0.0, original_upper_bound=jnp.pi),
BoundToUnbound(name_mapping = (["s1_mag"], ["s1_mag_unbounded"]) , original_lower_bound=0.0, original_upper_bound=0.99),
BoundToUnbound(name_mapping = (["s2_mag"], ["s2_mag_unbounded"]) , original_lower_bound=0.0, original_upper_bound=0.99),
BoundToUnbound(name_mapping = (["phase_det"], ["phase_det_unbounded"]), original_lower_bound=0.0, original_upper_bound=2 * jnp.pi),
BoundToUnbound(name_mapping = (["psi"], ["psi_unbounded"]), original_lower_bound=0.0, original_upper_bound=jnp.pi),
BoundToUnbound(name_mapping = (["zenith"], ["zenith_unbounded"]), original_lower_bound=0.0, original_upper_bound=jnp.pi),
BoundToUnbound(name_mapping = (["azimuth"], ["azimuth_unbounded"]), original_lower_bound=0.0, original_upper_bound=2 * jnp.pi),
BoundToUnbound(name_mapping=(["M_c"], [
"M_c_unbounded"]), original_lower_bound=M_c_min, original_upper_bound=M_c_max),
BoundToUnbound(name_mapping=(["q"], ["q_unbounded"]),
original_lower_bound=q_min, original_upper_bound=q_max),
BoundToUnbound(name_mapping=(["s1_phi"], [
"s1_phi_unbounded"]), original_lower_bound=0.0, original_upper_bound=2 * jnp.pi),
BoundToUnbound(name_mapping=(["s2_phi"], [
"s2_phi_unbounded"]), original_lower_bound=0.0, original_upper_bound=2 * jnp.pi),
BoundToUnbound(name_mapping=(["iota"], ["iota_unbounded"]),
original_lower_bound=0.0, original_upper_bound=jnp.pi),
BoundToUnbound(name_mapping=(["s1_theta"], [
"s1_theta_unbounded"]), original_lower_bound=0.0, original_upper_bound=jnp.pi),
BoundToUnbound(name_mapping=(["s2_theta"], [
"s2_theta_unbounded"]), original_lower_bound=0.0, original_upper_bound=jnp.pi),
BoundToUnbound(name_mapping=(["s1_mag"], [
"s1_mag_unbounded"]), original_lower_bound=0.0, original_upper_bound=0.99),
BoundToUnbound(name_mapping=(["s2_mag"], [
"s2_mag_unbounded"]), original_lower_bound=0.0, original_upper_bound=0.99),
BoundToUnbound(name_mapping=(["phase_det"], [
"phase_det_unbounded"]), original_lower_bound=0.0, original_upper_bound=2 * jnp.pi),
BoundToUnbound(name_mapping=(["psi"], ["psi_unbounded"]),
original_lower_bound=0.0, original_upper_bound=jnp.pi),
BoundToUnbound(name_mapping=(["zenith"], [
"zenith_unbounded"]), original_lower_bound=0.0, original_upper_bound=jnp.pi),
BoundToUnbound(name_mapping=(["azimuth"], [
"azimuth_unbounded"]), original_lower_bound=0.0, original_upper_bound=2 * jnp.pi),
]

likelihood_transforms = [
Expand All @@ -124,7 +161,7 @@


likelihood = TransientLikelihoodFD(
[H1, L1], waveform=waveform, trigger_time=gps, duration=4, post_trigger_duration=2
[H1, L1], waveform=waveform, f_min=fmin, f_max=fmax, trigger_time=gps
)


Expand All @@ -133,9 +170,9 @@
# mass_matrix = mass_matrix.at[9, 9].set(1e-3)
local_sampler_arg = {"step_size": mass_matrix * 1e-3}

Adam_optimizer = optimization_Adam(n_steps=3000, learning_rate=0.01, noise_level=1)
# Adam_optimizer = optimization_Adam(
# n_steps=3000, learning_rate=0.01, noise_level=1)

import optax

n_epochs = 20
n_loop_training = 100
Expand Down
125 changes: 112 additions & 13 deletions example/notebooks/GW150914.ipynb

Large diffs are not rendered by default.

3 changes: 3 additions & 0 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -29,3 +29,6 @@ python_requires = >=3.9

[options.packages.find]
where=src

[flake8]
ignore = F722
Loading
Loading