From 5ea134782fe0ad3e0b5324ecc82901780168184c Mon Sep 17 00:00:00 2001 From: ThibeauWouters Date: Tue, 7 Nov 2023 14:31:38 +0100 Subject: [PATCH 01/17] attempt to fix heterodyning for GW150914 --- example/GW150914.py | 4 ++-- src/jimgw/likelihood.py | 25 +++++++++++++++++++++---- 2 files changed, 23 insertions(+), 6 deletions(-) diff --git a/example/GW150914.py b/example/GW150914.py index 2e73ad58..da523ac4 100644 --- a/example/GW150914.py +++ b/example/GW150914.py @@ -47,8 +47,8 @@ "cos_iota": ("iota",lambda params: jnp.arccos(jnp.arcsin(jnp.sin(params['cos_iota']/2*jnp.pi))*2/jnp.pi)), "sin_dec": ("dec",lambda params: jnp.arcsin(jnp.arcsin(jnp.sin(params['sin_dec']/2*jnp.pi))*2/jnp.pi))} # sin and arcsin are periodize cos_iota and sin_dec ) -likelihood = TransientLikelihoodFD([H1, L1], waveform=RippleIMRPhenomD(), trigger_time=gps, duration=4, post_trigger_duration=2) -# likelihood = HeterodynedTransientLikelihoodFD([H1, L1], prior=prior, bounds=[prior.xmin, prior.xmax], waveform=RippleIMRPhenomD(), trigger_time=gps, duration=4, post_trigger_duration=2) +# likelihood = TransientLikelihoodFD([H1, L1], waveform=RippleIMRPhenomD(), trigger_time=gps, duration=4, post_trigger_duration=2) +likelihood = HeterodynedTransientLikelihoodFD([H1, L1], prior=prior, bounds=[prior.xmin, prior.xmax], waveform=RippleIMRPhenomD(), trigger_time=gps, duration=4, post_trigger_duration=2) mass_matrix = jnp.eye(11) diff --git a/src/jimgw/likelihood.py b/src/jimgw/likelihood.py index 579eca5b..376ab581 100644 --- a/src/jimgw/likelihood.py +++ b/src/jimgw/likelihood.py @@ -177,14 +177,24 @@ def __init__( f_max = jnp.max(f_valid) f_min = jnp.min(f_valid) - h_sky = h_sky[jnp.where((frequency_original>=f_min) & (frequency_original<=f_max))[0]] - h_sky_low = h_sky_low[jnp.where((self.freq_grid_low>=f_min) & (self.freq_grid_low<=f_max))[0]] - h_sky_center = h_sky_center[jnp.where((self.freq_grid_center>=f_min) & (self.freq_grid_center<=f_max))[0]] + # Apply the mask for frequencies to both polarization modes and for all waveforms currently used + for mode in ["p", "c"]: + h_sky[mode] = h_sky[mode][jnp.where((frequency_original>=f_min) & (frequency_original<=f_max))[0]] + # h_sky_low[mode] = h_sky_low[mode][jnp.where((self.freq_grid_low>=f_min) & (self.freq_grid_low<=f_max))[0]] + h_sky_low[mode] = h_sky[mode][:-1] # TODO does this work? + h_sky_center[mode] = h_sky_center[mode][jnp.where((self.freq_grid_center>=f_min) & (self.freq_grid_center<=f_max))[0]] frequency_original = frequency_original[jnp.where((frequency_original>=f_min) & (frequency_original<=f_max))[0]] - self.freq_grid_low = self.freq_grid_low[jnp.where((self.freq_grid_low>=f_min) & (self.freq_grid_low<=f_max))[0]] + freq_grid = freq_grid[jnp.where((freq_grid>=f_min) & (freq_grid<=f_max))[0]] + # self.freq_grid_low = self.freq_grid_low[jnp.where((self.freq_grid_low>=f_min) & (self.freq_grid_low<=f_max))[0]] + self.freq_grid_low = freq_grid[:-1] # TODO override, does this work? self.freq_grid_center = self.freq_grid_center[jnp.where((self.freq_grid_center>=f_min) & (self.freq_grid_center<=f_max))[0]] + print("len(self.freq_grid_low)") + print(len(self.freq_grid_low)) + print("len(self.freq_grid_center)") + print(len(self.freq_grid_center)) + align_time = jnp.exp( -1j * 2 @@ -208,6 +218,9 @@ def __init__( ) for detector in self.detectors: + # Also apply the mask of frequencies to the strain data + detector.data = detector.data[jnp.where((frequency_original>=f_min) & (frequency_original<=f_max))[0]] + # Get the reference waveforms waveform_ref = ( detector.fd_response(frequency_original, h_sky, self.ref_params) * align_time @@ -257,7 +270,11 @@ def evaluate(self, params: Array, data: dict) -> float: detector.fd_response(frequencies_center, waveform_sky_center, params) * align_time_center ) + r0 = waveform_center / self.waveform_center_ref[detector.name] + print(np.shape(r0)) + print(np.shape(waveform_low)) + print(np.shape(waveform_center)) r1 = (waveform_low / self.waveform_low_ref[detector.name] - r0) / ( frequencies_low - frequencies_center ) From 937924bbd6b3273e339b1a16df393d3a626c4f25 Mon Sep 17 00:00:00 2001 From: ThibeauWouters Date: Tue, 7 Nov 2023 15:45:02 +0100 Subject: [PATCH 02/17] fix heterodyning --- example/GW150914.py | 2 +- src/jimgw/detector.py | 2 +- src/jimgw/likelihood.py | 37 +++++++++++++++++-------------------- 3 files changed, 19 insertions(+), 22 deletions(-) diff --git a/example/GW150914.py b/example/GW150914.py index da523ac4..92c39075 100644 --- a/example/GW150914.py +++ b/example/GW150914.py @@ -59,7 +59,7 @@ jim = Jim( likelihood, prior, - n_loop_training=200, + n_loop_training=100, n_loop_production=10, n_local_steps=150, n_global_steps=150, diff --git a/src/jimgw/detector.py b/src/jimgw/detector.py index 0ff7478a..cf9ad7a9 100644 --- a/src/jimgw/detector.py +++ b/src/jimgw/detector.py @@ -203,7 +203,7 @@ def fd_response(self, frequency: Array, h_sky: dict, params: dict) -> Array: antenna_pattern = self.antenna_pattern(ra, dec, psi, gmst) timeshift = self.delay_from_geocenter(ra, dec, gmst) h_detector = jax.tree_util.tree_map(lambda h, antenna: h * antenna * jnp.exp(-2j * jnp.pi * frequency * timeshift), h_sky, antenna_pattern) - return jnp.sum(jnp.stack(jax.tree_util.tree_leaves(h_detector)),axis=0) + return jnp.sum(jnp.stack(jax.tree_util.tree_leaves(h_detector)), axis=0) def td_response(self, time: Array, h: Array, params: Array) -> Array: """ diff --git a/src/jimgw/likelihood.py b/src/jimgw/likelihood.py index 376ab581..fd9fcf2c 100644 --- a/src/jimgw/likelihood.py +++ b/src/jimgw/likelihood.py @@ -150,9 +150,11 @@ def __init__( detectors, waveform, trigger_time, duration, post_trigger_duration ) + # Get the original frequency grid frequency_original = self.detectors[0].frequencies + # Get the grid of the relative binning scheme (contains the final endpoint) and the center points freq_grid, self.freq_grid_center = self.make_binning_scheme( - np.array(frequency_original), n_bins + 1 + np.array(frequency_original), n_bins ) self.freq_grid_low = freq_grid[:-1] @@ -173,28 +175,26 @@ def __init__( h_sky_low = self.waveform(self.freq_grid_low, self.ref_params) h_sky_center = self.waveform(self.freq_grid_center, self.ref_params) + # Get frequency masks to be applied, for both original and heterodyne frequency grid f_valid = frequency_original[jnp.where((jnp.abs(h_sky['p'])+jnp.abs(h_sky['c']))>0)[0]] f_max = jnp.max(f_valid) f_min = jnp.min(f_valid) + + mask_original = jnp.where((frequency_original>=f_min) & (frequency_original<=f_max))[0] + mask_heterodyne = jnp.where((self.freq_grid_center>=f_min) & (self.freq_grid_center<=f_max))[0] # Apply the mask for frequencies to both polarization modes and for all waveforms currently used for mode in ["p", "c"]: - h_sky[mode] = h_sky[mode][jnp.where((frequency_original>=f_min) & (frequency_original<=f_max))[0]] - # h_sky_low[mode] = h_sky_low[mode][jnp.where((self.freq_grid_low>=f_min) & (self.freq_grid_low<=f_max))[0]] - h_sky_low[mode] = h_sky[mode][:-1] # TODO does this work? - h_sky_center[mode] = h_sky_center[mode][jnp.where((self.freq_grid_center>=f_min) & (self.freq_grid_center<=f_max))[0]] + h_sky[mode] = h_sky[mode][mask_original] + h_sky_low[mode] = h_sky_low[mode][mask_heterodyne] + h_sky_center[mode] = h_sky_center[mode][mask_heterodyne] - frequency_original = frequency_original[jnp.where((frequency_original>=f_min) & (frequency_original<=f_max))[0]] + frequency_original = frequency_original[mask_original] freq_grid = freq_grid[jnp.where((freq_grid>=f_min) & (freq_grid<=f_max))[0]] - # self.freq_grid_low = self.freq_grid_low[jnp.where((self.freq_grid_low>=f_min) & (self.freq_grid_low<=f_max))[0]] - self.freq_grid_low = freq_grid[:-1] # TODO override, does this work? - self.freq_grid_center = self.freq_grid_center[jnp.where((self.freq_grid_center>=f_min) & (self.freq_grid_center<=f_max))[0]] - - print("len(self.freq_grid_low)") - print(len(self.freq_grid_low)) - print("len(self.freq_grid_center)") - print(len(self.freq_grid_center)) + self.freq_grid_low = self.freq_grid_low[mask_heterodyne] + self.freq_grid_center = self.freq_grid_center[mask_heterodyne] + # Get phase shifts to align time of coalescence align_time = jnp.exp( -1j * 2 @@ -219,7 +219,7 @@ def __init__( for detector in self.detectors: # Also apply the mask of frequencies to the strain data - detector.data = detector.data[jnp.where((frequency_original>=f_min) & (frequency_original<=f_max))[0]] + detector.data = detector.data[mask_original] # Get the reference waveforms waveform_ref = ( detector.fd_response(frequency_original, h_sky, self.ref_params) @@ -240,7 +240,7 @@ def __init__( waveform_ref, detector.psd, frequency_original, - self.freq_grid_low, + freq_grid, self.freq_grid_center, ) self.A0_array[detector.name] = A0 @@ -272,9 +272,6 @@ def evaluate(self, params: Array, data: dict) -> float: ) r0 = waveform_center / self.waveform_center_ref[detector.name] - print(np.shape(r0)) - print(np.shape(waveform_low)) - print(np.shape(waveform_center)) r1 = (waveform_low / self.waveform_low_ref[detector.name] - r0) / ( frequencies_low - frequencies_center ) @@ -335,7 +332,7 @@ def make_binning_scheme(self, freqs, n_bins, chi=1): phase_diff_array = self.max_phase_diff(freqs, freqs[0], freqs[-1], chi=1) bin_f = interp1d(phase_diff_array, freqs) f_bins = np.array([]) - for i in np.linspace(phase_diff_array[0], phase_diff_array[-1], n_bins): + for i in np.linspace(phase_diff_array[0], phase_diff_array[-1], n_bins + 1): f_bins = np.append(f_bins, bin_f(i)) f_bins_center = (f_bins[:-1] + f_bins[1:]) / 2 return f_bins, f_bins_center From bb4aecf10ca96a10fa2a4a37eef0e7f9828ab46b Mon Sep 17 00:00:00 2001 From: ThibeauWouters Date: Mon, 13 Nov 2023 15:25:38 +0100 Subject: [PATCH 03/17] restore likelihood to official jim implementation --- src/jimgw/likelihood.py | 40 +++++++++++++--------------------------- 1 file changed, 13 insertions(+), 27 deletions(-) diff --git a/src/jimgw/likelihood.py b/src/jimgw/likelihood.py index fd9fcf2c..9ca990ec 100644 --- a/src/jimgw/likelihood.py +++ b/src/jimgw/likelihood.py @@ -150,11 +150,9 @@ def __init__( detectors, waveform, trigger_time, duration, post_trigger_duration ) - # Get the original frequency grid frequency_original = self.detectors[0].frequencies - # Get the grid of the relative binning scheme (contains the final endpoint) and the center points freq_grid, self.freq_grid_center = self.make_binning_scheme( - np.array(frequency_original), n_bins + np.array(frequency_original), n_bins + 1 ) self.freq_grid_low = freq_grid[:-1] @@ -175,26 +173,18 @@ def __init__( h_sky_low = self.waveform(self.freq_grid_low, self.ref_params) h_sky_center = self.waveform(self.freq_grid_center, self.ref_params) - # Get frequency masks to be applied, for both original and heterodyne frequency grid f_valid = frequency_original[jnp.where((jnp.abs(h_sky['p'])+jnp.abs(h_sky['c']))>0)[0]] f_max = jnp.max(f_valid) f_min = jnp.min(f_valid) - - mask_original = jnp.where((frequency_original>=f_min) & (frequency_original<=f_max))[0] - mask_heterodyne = jnp.where((self.freq_grid_center>=f_min) & (self.freq_grid_center<=f_max))[0] - - # Apply the mask for frequencies to both polarization modes and for all waveforms currently used - for mode in ["p", "c"]: - h_sky[mode] = h_sky[mode][mask_original] - h_sky_low[mode] = h_sky_low[mode][mask_heterodyne] - h_sky_center[mode] = h_sky_center[mode][mask_heterodyne] - - frequency_original = frequency_original[mask_original] - freq_grid = freq_grid[jnp.where((freq_grid>=f_min) & (freq_grid<=f_max))[0]] - self.freq_grid_low = self.freq_grid_low[mask_heterodyne] - self.freq_grid_center = self.freq_grid_center[mask_heterodyne] - - # Get phase shifts to align time of coalescence + + h_sky = h_sky[jnp.where((frequency_original>=f_min) & (frequency_original<=f_max))[0]] + h_sky_low = h_sky_low[jnp.where((self.freq_grid_low>=f_min) & (self.freq_grid_low<=f_max))[0]] + h_sky_center = h_sky_center[jnp.where((self.freq_grid_center>=f_min) & (self.freq_grid_center<=f_max))[0]] + + frequency_original = frequency_original[jnp.where((frequency_original>=f_min) & (frequency_original<=f_max))[0]] + self.freq_grid_low = self.freq_grid_low[jnp.where((self.freq_grid_low>=f_min) & (self.freq_grid_low<=f_max))[0]] + self.freq_grid_center = self.freq_grid_center[jnp.where((self.freq_grid_center>=f_min) & (self.freq_grid_center<=f_max))[0]] + align_time = jnp.exp( -1j * 2 @@ -218,9 +208,6 @@ def __init__( ) for detector in self.detectors: - # Also apply the mask of frequencies to the strain data - detector.data = detector.data[mask_original] - # Get the reference waveforms waveform_ref = ( detector.fd_response(frequency_original, h_sky, self.ref_params) * align_time @@ -240,7 +227,7 @@ def __init__( waveform_ref, detector.psd, frequency_original, - freq_grid, + self.freq_grid_low, self.freq_grid_center, ) self.A0_array[detector.name] = A0 @@ -270,7 +257,6 @@ def evaluate(self, params: Array, data: dict) -> float: detector.fd_response(frequencies_center, waveform_sky_center, params) * align_time_center ) - r0 = waveform_center / self.waveform_center_ref[detector.name] r1 = (waveform_low / self.waveform_low_ref[detector.name] - r0) / ( frequencies_low - frequencies_center @@ -332,7 +318,7 @@ def make_binning_scheme(self, freqs, n_bins, chi=1): phase_diff_array = self.max_phase_diff(freqs, freqs[0], freqs[-1], chi=1) bin_f = interp1d(phase_diff_array, freqs) f_bins = np.array([]) - for i in np.linspace(phase_diff_array[0], phase_diff_array[-1], n_bins + 1): + for i in np.linspace(phase_diff_array[0], phase_diff_array[-1], n_bins): f_bins = np.append(f_bins, bin_f(i)) f_bins_center = (f_bins[:-1] + f_bins[1:]) / 2 return f_bins, f_bins_center @@ -395,4 +381,4 @@ def maximize_likelihood( optimizer = EvolutionaryOptimizer(len(bounds), verbose=True) state = optimizer.optimize(y, bounds, n_loops=n_loops) best_fit = optimizer.get_result()[0] - return prior.add_name(best_fit, transform_name=True, transform_value=True) + return prior.add_name(best_fit, transform_name=True, transform_value=True) \ No newline at end of file From dcd7c56756a62203802811548eb7b09b3254461f Mon Sep 17 00:00:00 2001 From: ThibeauWouters Date: Tue, 14 Nov 2023 09:31:52 +0100 Subject: [PATCH 04/17] copy pasted relative binning fix code --- src/jimgw/likelihood.py | 61 ++++++++++++++++++++++++++++++++--------- 1 file changed, 48 insertions(+), 13 deletions(-) diff --git a/src/jimgw/likelihood.py b/src/jimgw/likelihood.py index 9ca990ec..e538f9ba 100644 --- a/src/jimgw/likelihood.py +++ b/src/jimgw/likelihood.py @@ -144,15 +144,17 @@ def __init__( duration: float = 4, post_trigger_duration: float = 2, n_walkers: int = 100, - n_loops: int = 2000, + n_loops: int = 20, ) -> None: super().__init__( detectors, waveform, trigger_time, duration, post_trigger_duration ) + # Get the original frequency grid frequency_original = self.detectors[0].frequencies + # Get the grid of the relative binning scheme (contains the final endpoint) and the center points freq_grid, self.freq_grid_center = self.make_binning_scheme( - np.array(frequency_original), n_bins + 1 + np.array(frequency_original), n_bins ) self.freq_grid_low = freq_grid[:-1] @@ -173,18 +175,47 @@ def __init__( h_sky_low = self.waveform(self.freq_grid_low, self.ref_params) h_sky_center = self.waveform(self.freq_grid_center, self.ref_params) + # Get frequency masks to be applied, for both original and heterodyne frequency grid f_valid = frequency_original[jnp.where((jnp.abs(h_sky['p'])+jnp.abs(h_sky['c']))>0)[0]] f_max = jnp.max(f_valid) f_min = jnp.min(f_valid) - - h_sky = h_sky[jnp.where((frequency_original>=f_min) & (frequency_original<=f_max))[0]] - h_sky_low = h_sky_low[jnp.where((self.freq_grid_low>=f_min) & (self.freq_grid_low<=f_max))[0]] - h_sky_center = h_sky_center[jnp.where((self.freq_grid_center>=f_min) & (self.freq_grid_center<=f_max))[0]] - - frequency_original = frequency_original[jnp.where((frequency_original>=f_min) & (frequency_original<=f_max))[0]] - self.freq_grid_low = self.freq_grid_low[jnp.where((self.freq_grid_low>=f_min) & (self.freq_grid_low<=f_max))[0]] - self.freq_grid_center = self.freq_grid_center[jnp.where((self.freq_grid_center>=f_min) & (self.freq_grid_center<=f_max))[0]] - + + # TODO replace this + def get_mask(f: Array, f_min: float, f_max: float) -> Array: + """Slice an array f by containing all elements in f that are greater or equal to f_min, and all elements smaller than or equal + to f_max, and the element just right after that. + + Args: + f (Array): Frequency array to be sliced + f_min (float): Min frequency to be included + f_max (float): Max frequency to be included + + Returns: + Array: Sliced array. + """ + mask = np.array([False for value in f]) + index_f_min = np.argwhere(f >= f_min).flatten()[0] + index_f_max = np.argwhere(f <= f_max).flatten()[-1] + index_f_max = min(index_f_max + 1, len(f) - 1) + mask[index_f_min:index_f_max + 1] = True + return mask + + mask_original = get_mask(frequency_original, f_min, f_max) + mask_heterodyne_low = get_mask(self.freq_grid_low, f_min, f_max) + mask_heterodyne_center = get_mask(self.freq_grid_center, f_min, f_max) + + # Apply the mask for frequencies to both polarization modes and for all waveforms currently used + for mode in ["p", "c"]: + h_sky[mode] = h_sky[mode][mask_original] + h_sky_low[mode] = h_sky_low[mode][mask_heterodyne_low] + h_sky_center[mode] = h_sky_center[mode][mask_heterodyne_center] + + frequency_original = frequency_original[mask_original] + freq_grid = freq_grid[get_mask(freq_grid, f_min, f_max)] + self.freq_grid_low = self.freq_grid_low[mask_heterodyne_low] + self.freq_grid_center = self.freq_grid_center[mask_heterodyne_center] + + # Get phase shifts to align time of coalescence align_time = jnp.exp( -1j * 2 @@ -208,6 +239,9 @@ def __init__( ) for detector in self.detectors: + # Also apply the mask of frequencies to the strain data + detector.data = detector.data[mask_original] + # Get the reference waveforms waveform_ref = ( detector.fd_response(frequency_original, h_sky, self.ref_params) * align_time @@ -227,7 +261,7 @@ def __init__( waveform_ref, detector.psd, frequency_original, - self.freq_grid_low, + freq_grid, self.freq_grid_center, ) self.A0_array[detector.name] = A0 @@ -257,6 +291,7 @@ def evaluate(self, params: Array, data: dict) -> float: detector.fd_response(frequencies_center, waveform_sky_center, params) * align_time_center ) + r0 = waveform_center / self.waveform_center_ref[detector.name] r1 = (waveform_low / self.waveform_low_ref[detector.name] - r0) / ( frequencies_low - frequencies_center @@ -318,7 +353,7 @@ def make_binning_scheme(self, freqs, n_bins, chi=1): phase_diff_array = self.max_phase_diff(freqs, freqs[0], freqs[-1], chi=1) bin_f = interp1d(phase_diff_array, freqs) f_bins = np.array([]) - for i in np.linspace(phase_diff_array[0], phase_diff_array[-1], n_bins): + for i in np.linspace(phase_diff_array[0], phase_diff_array[-1], n_bins + 1): f_bins = np.append(f_bins, bin_f(i)) f_bins_center = (f_bins[:-1] + f_bins[1:]) / 2 return f_bins, f_bins_center From 3c38bab0f986d9c0c6410e47e0ee3372b0e56592 Mon Sep 17 00:00:00 2001 From: Kaze Wong Date: Fri, 1 Dec 2023 15:33:09 -0500 Subject: [PATCH 05/17] Minor phrasing fix --- src/jimgw/detector.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/jimgw/detector.py b/src/jimgw/detector.py index cf9ad7a9..81685f8f 100644 --- a/src/jimgw/detector.py +++ b/src/jimgw/detector.py @@ -190,7 +190,7 @@ def load_data(self, trigger_time:float, psd_data_td = TimeSeries.fetch_open_data(self.name, start_psd, end_psd, cache=True) psd = psd_data_td.psd(fftlength=segment_length).value # TODO: Check whether this is sright. - print("Finished generating data.") + print("Finished loading data.") self.frequencies = freq[(freq>f_min)&(freqf_min)&(freq Date: Fri, 1 Dec 2023 15:33:52 -0500 Subject: [PATCH 06/17] Do not use heterodyne by default --- example/GW150914.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/example/GW150914.py b/example/GW150914.py index 92c39075..949fee63 100644 --- a/example/GW150914.py +++ b/example/GW150914.py @@ -47,9 +47,7 @@ "cos_iota": ("iota",lambda params: jnp.arccos(jnp.arcsin(jnp.sin(params['cos_iota']/2*jnp.pi))*2/jnp.pi)), "sin_dec": ("dec",lambda params: jnp.arcsin(jnp.arcsin(jnp.sin(params['sin_dec']/2*jnp.pi))*2/jnp.pi))} # sin and arcsin are periodize cos_iota and sin_dec ) -# likelihood = TransientLikelihoodFD([H1, L1], waveform=RippleIMRPhenomD(), trigger_time=gps, duration=4, post_trigger_duration=2) -likelihood = HeterodynedTransientLikelihoodFD([H1, L1], prior=prior, bounds=[prior.xmin, prior.xmax], waveform=RippleIMRPhenomD(), trigger_time=gps, duration=4, post_trigger_duration=2) - +likelihood = TransientLikelihoodFD([H1, L1], waveform=RippleIMRPhenomD(), trigger_time=gps, duration=4, post_trigger_duration=2) mass_matrix = jnp.eye(11) mass_matrix = mass_matrix.at[1, 1].set(1e-3) From 866dd41825c29bb93e4191c90a8811734dd2ce34 Mon Sep 17 00:00:00 2001 From: Kaze Wong Date: Fri, 1 Dec 2023 17:12:53 -0500 Subject: [PATCH 07/17] Fixing heterodyne. It is not working as intended --- example/GW170817.py | 102 ++++++++++++++++++++++++++++++++++++++++ src/jimgw/detector.py | 11 +++-- src/jimgw/likelihood.py | 4 +- 3 files changed, 111 insertions(+), 6 deletions(-) create mode 100644 example/GW170817.py diff --git a/example/GW170817.py b/example/GW170817.py new file mode 100644 index 00000000..8f01cc23 --- /dev/null +++ b/example/GW170817.py @@ -0,0 +1,102 @@ +import time +from jimgw.jim import Jim +from jimgw.detector import H1, L1, V1 +from jimgw.likelihood import HeterodynedTransientLikelihoodFD +from jimgw.waveform import RippleIMRPhenomD +from jimgw.prior import Uniform +from gwosc.datasets import event_gps +import jax.numpy as jnp +import jax + +jax.config.update("jax_enable_x64", True) + +########################################### +########## First we grab data ############# +########################################### + +total_time_start = time.time() + +gps = event_gps("GW170817") +duration = 128 +post_trigger_duration = 32 +start_pad = duration - post_trigger_duration +end_pad = post_trigger_duration +fmin = 20.0 +fmax = 1024.0 + +ifos = ["H1", "L1"]#, "V1"] + +H1.load_data(gps, start_pad, end_pad, fmin, fmax, psd_pad=4*duration, tukey_alpha=0.05, gwpy_kwargs={"version": 2, "cache": False}) +L1.load_data(gps, start_pad, end_pad, fmin, fmax, psd_pad=4*duration, tukey_alpha=0.05, gwpy_kwargs={"version": 2, "cache": False}) +# V1.load_data(gps, start_pad, end_pad, fmin, fmax, psd_pad=16, tukey_alpha=0.05) + +prior = Uniform( + xmin=[1.18, 0.125, -0.3, -0.3, 1., -0.1, 0.0, -1, 0.0, 0.0, -1.0], + xmax=[1.21, 1.0, 0.3, 0.3, 75., 0.1, 2 * jnp.pi, 1.0, jnp.pi, 2 * jnp.pi, 1.0], + naming=[ + "M_c", + "q", + "s1_z", + "s2_z", + "d_L", + "t_c", + "phase_c", + "cos_iota", + "psi", + "ra", + "sin_dec", + ], + transforms={ + "q": ("eta", lambda params: params["q"] / (1 + params["q"]) ** 2), + "cos_iota": ( + "iota", + lambda params: jnp.arccos( + jnp.arcsin(jnp.sin(params["cos_iota"] / 2 * jnp.pi)) * 2 / jnp.pi + ), + ), + "sin_dec": ( + "dec", + lambda params: jnp.arcsin( + jnp.arcsin(jnp.sin(params["sin_dec"] / 2 * jnp.pi)) * 2 / jnp.pi + ), + ), + }, # sin and arcsin are periodize cos_iota and sin_dec +) + +likelihood = HeterodynedTransientLikelihoodFD( + [H1], + prior=prior, + bounds=[prior.xmin, prior.xmax], + waveform=RippleIMRPhenomD(), + trigger_time=gps, + duration=duration, + post_trigger_duration=post_trigger_duration, + n_loops=1000 +) + +# mass_matrix = jnp.eye(11) +# mass_matrix = mass_matrix.at[1, 1].set(1e-3) +# mass_matrix = mass_matrix.at[5, 5].set(1e-3) +# local_sampler_arg = {"step_size": mass_matrix * 3e-3} + +# jim = Jim( +# likelihood, +# prior, +# n_loop_training=100, +# n_loop_production=10, +# n_local_steps=150, +# n_global_steps=150, +# n_chains=500, +# n_epochs=50, +# learning_rate=0.001, +# max_samples=45000, +# momentum=0.9, +# batch_size=50000, +# use_global=True, +# keep_quantile=0.0, +# train_thinning=1, +# output_thinning=10, +# local_sampler_arg=local_sampler_arg, +# ) + +# jim.sample(jax.random.PRNGKey(42)) diff --git a/src/jimgw/detector.py b/src/jimgw/detector.py index 81685f8f..10faf04d 100644 --- a/src/jimgw/detector.py +++ b/src/jimgw/detector.py @@ -152,7 +152,8 @@ def load_data(self, trigger_time:float, f_min: float, f_max: float, psd_pad: int = 16, - tukey_alpha: float = 0.2) -> None: + tukey_alpha: float = 0.2, + gwpy_kwargs: dict = {"cache":True}) -> None: """ Load data from the detector. @@ -176,18 +177,18 @@ def load_data(self, trigger_time:float, """ print("Fetching data from {}...".format(self.name)) - data_td = TimeSeries.fetch_open_data(self.name, trigger_time - gps_start_pad, trigger_time + gps_end_pad, cache=True) + data_td = TimeSeries.fetch_open_data(self.name, trigger_time - gps_start_pad, trigger_time + gps_end_pad, **gwpy_kwargs) segment_length = data_td.duration.value n = len(data_td) delta_t = data_td.dt.value data = jnp.fft.rfft(jnp.array(data_td.value)*tukey(n, tukey_alpha))*delta_t freq = jnp.fft.rfftfreq(n, delta_t) # TODO: Check if this is the right way to fetch PSD - start_psd = int(trigger_time) - gps_start_pad - psd_pad # What does Int do here? - end_psd = int(trigger_time) + gps_end_pad + psd_pad + start_psd = int(trigger_time) - gps_start_pad - 2*psd_pad # What does Int do here? + end_psd = int(trigger_time) - gps_start_pad - psd_pad print("Fetching PSD data...") - psd_data_td = TimeSeries.fetch_open_data(self.name, start_psd, end_psd, cache=True) + psd_data_td = TimeSeries.fetch_open_data(self.name, start_psd, end_psd, **gwpy_kwargs) psd = psd_data_td.psd(fftlength=segment_length).value # TODO: Check whether this is sright. print("Finished loading data.") diff --git a/src/jimgw/likelihood.py b/src/jimgw/likelihood.py index e538f9ba..a89ac184 100644 --- a/src/jimgw/likelihood.py +++ b/src/jimgw/likelihood.py @@ -144,7 +144,7 @@ def __init__( duration: float = 4, post_trigger_duration: float = 2, n_walkers: int = 100, - n_loops: int = 20, + n_loops: int = 200, ) -> None: super().__init__( detectors, waveform, trigger_time, duration, post_trigger_duration @@ -162,6 +162,8 @@ def __init__( bounds=bounds, prior=prior, set_nwalkers=n_walkers, n_loops=n_loops ) + print("Constructing reference waveforms..") + self.ref_params["gmst"] = self.gmst self.waveform_low_ref = {} From 4a20cd4bb62998945daf07081cb40920da6c06ef Mon Sep 17 00:00:00 2001 From: Kaze Wong Date: Sat, 2 Dec 2023 14:26:52 -0500 Subject: [PATCH 08/17] Add frequnecy check to likelihood --- src/jimgw/likelihood.py | 38 ++++++++++++++++++++++++++------------ 1 file changed, 26 insertions(+), 12 deletions(-) diff --git a/src/jimgw/likelihood.py b/src/jimgw/likelihood.py index a89ac184..41350044 100644 --- a/src/jimgw/likelihood.py +++ b/src/jimgw/likelihood.py @@ -1,14 +1,16 @@ from abc import ABC, abstractmethod -from jaxtyping import Array, Float -from jimgw.waveform import Waveform -from jimgw.detector import Detector + +import jax import jax.numpy as jnp -from astropy.time import Time import numpy as np -from scipy.interpolate import interp1d -import jax +from astropy.time import Time from flowMC.utils.EvolutionaryOptimizer import EvolutionaryOptimizer +from jaxtyping import Array, Float +from scipy.interpolate import interp1d + +from jimgw.detector import Detector from jimgw.prior import Prior +from jimgw.waveform import Waveform class LikelihoodBase(ABC): @@ -151,6 +153,16 @@ def __init__( ) # Get the original frequency grid + + assert jnp.all( + jnp.array( + [ + (self.detectors[0].frequencies == detector.frequencies).all() + for detector in self.detectors + ] + ) + ), "The detectors must have the same frequency grid" + frequency_original = self.detectors[0].frequencies # Get the grid of the relative binning scheme (contains the final endpoint) and the center points freq_grid, self.freq_grid_center = self.make_binning_scheme( @@ -178,10 +190,12 @@ def __init__( h_sky_center = self.waveform(self.freq_grid_center, self.ref_params) # Get frequency masks to be applied, for both original and heterodyne frequency grid - f_valid = frequency_original[jnp.where((jnp.abs(h_sky['p'])+jnp.abs(h_sky['c']))>0)[0]] + f_valid = frequency_original[ + jnp.where((jnp.abs(h_sky["p"]) + jnp.abs(h_sky["c"])) > 0)[0] + ] f_max = jnp.max(f_valid) f_min = jnp.min(f_valid) - + # TODO replace this def get_mask(f: Array, f_min: float, f_max: float) -> Array: """Slice an array f by containing all elements in f that are greater or equal to f_min, and all elements smaller than or equal @@ -199,9 +213,9 @@ def get_mask(f: Array, f_min: float, f_max: float) -> Array: index_f_min = np.argwhere(f >= f_min).flatten()[0] index_f_max = np.argwhere(f <= f_max).flatten()[-1] index_f_max = min(index_f_max + 1, len(f) - 1) - mask[index_f_min:index_f_max + 1] = True + mask[index_f_min : index_f_max + 1] = True return mask - + mask_original = get_mask(frequency_original, f_min, f_max) mask_heterodyne_low = get_mask(self.freq_grid_low, f_min, f_max) mask_heterodyne_center = get_mask(self.freq_grid_center, f_min, f_max) @@ -293,7 +307,7 @@ def evaluate(self, params: Array, data: dict) -> float: detector.fd_response(frequencies_center, waveform_sky_center, params) * align_time_center ) - + r0 = waveform_center / self.waveform_center_ref[detector.name] r1 = (waveform_low / self.waveform_low_ref[detector.name] - r0) / ( frequencies_low - frequencies_center @@ -418,4 +432,4 @@ def maximize_likelihood( optimizer = EvolutionaryOptimizer(len(bounds), verbose=True) state = optimizer.optimize(y, bounds, n_loops=n_loops) best_fit = optimizer.get_result()[0] - return prior.add_name(best_fit, transform_name=True, transform_value=True) \ No newline at end of file + return prior.add_name(best_fit, transform_name=True, transform_value=True) From d7c3ac7ad019fe4cfb83413f3047eb1fa14ea99f Mon Sep 17 00:00:00 2001 From: Kaze Wong Date: Sat, 2 Dec 2023 14:39:41 -0500 Subject: [PATCH 09/17] update likelihood. Turning off precommit now. --- .pre-commit-config.yaml | 6 +++--- src/jimgw/likelihood.py | 48 +++++++++++++++++++++++++---------------- 2 files changed, 32 insertions(+), 22 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 213d97c8..b5595520 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -4,17 +4,17 @@ repos: hooks: - id: black - repo: https://github.com/charliermarsh/ruff-pre-commit - rev: 'v0.0.290' + rev: 'v0.1.6' hooks: - id: ruff args: ["--fix"] - repo: https://github.com/RobertCraigie/pyright-python - rev: v1.1.327 + rev: v1.1.338 hooks: - id: pyright additional_dependencies: [beartype, einops, jax, jaxtyping, pytest, tensorflow, tf2onnx, typing_extensions] - repo: https://github.com/nbQA-dev/nbQA - rev: 1.7.0 + rev: 1.7.1 hooks: - id: nbqa-black additional_dependencies: [ipython==8.12, black] diff --git a/src/jimgw/likelihood.py b/src/jimgw/likelihood.py index 41350044..5e0a021d 100644 --- a/src/jimgw/likelihood.py +++ b/src/jimgw/likelihood.py @@ -16,11 +16,13 @@ class LikelihoodBase(ABC): """ Base class for likelihoods. - Note that this likelihood class should work for a some what general class of problems. - In light of that, this class would be some what abstract, but the idea behind it is this - handles two main components of a likelihood: the data and the model. - - It should be able to take the data and model and evaluate the likelihood for a given set of parameters. + Note that this likelihood class should work + for a some what general class of problems. + In light of that, this class would be some what abstract, + but the idea behind it is this handles two main components of a likelihood: + the data and the model. + It should be able to take the data and model and evaluate the likelihood for + a given set of parameters. """ @@ -47,7 +49,6 @@ def evaluate(self, params) -> float: class TransientLikelihoodFD(LikelihoodBase): - detectors: list[Detector] waveform: Waveform @@ -86,7 +87,9 @@ def ifos(self): def evaluate( self, params: Array, data: dict - ) -> float: # TODO: Test whether we need to pass data in or with class changes is fine. + ) -> ( + float + ): # TODO: Test whether we need to pass data in or with class changes is fine. """ Evaluate the likelihood for a given set of parameters. """ @@ -119,7 +122,6 @@ def evaluate( class HeterodynedTransientLikelihoodFD(TransientLikelihoodFD): - n_bins: int # Number of bins to use for the likelihood ref_params: dict # Reference parameters for the likelihood freq_grid_low: Array # Heterodyned frequency grid @@ -164,7 +166,8 @@ def __init__( ), "The detectors must have the same frequency grid" frequency_original = self.detectors[0].frequencies - # Get the grid of the relative binning scheme (contains the final endpoint) and the center points + # Get the grid of the relative binning scheme (contains the final endpoint) + # and the center points freq_grid, self.freq_grid_center = self.make_binning_scheme( np.array(frequency_original), n_bins ) @@ -189,7 +192,8 @@ def __init__( h_sky_low = self.waveform(self.freq_grid_low, self.ref_params) h_sky_center = self.waveform(self.freq_grid_center, self.ref_params) - # Get frequency masks to be applied, for both original and heterodyne frequency grid + # Get frequency masks to be applied, for both original + # and heterodyne frequency grid f_valid = frequency_original[ jnp.where((jnp.abs(h_sky["p"]) + jnp.abs(h_sky["c"])) > 0)[0] ] @@ -198,7 +202,8 @@ def __init__( # TODO replace this def get_mask(f: Array, f_min: float, f_max: float) -> Array: - """Slice an array f by containing all elements in f that are greater or equal to f_min, and all elements smaller than or equal + """Slice an array f by containing all elements in f + that are greater or equal to f_min, and all elements smaller than or equal to f_max, and the element just right after that. Args: @@ -220,7 +225,8 @@ def get_mask(f: Array, f_min: float, f_max: float) -> Array: mask_heterodyne_low = get_mask(self.freq_grid_low, f_min, f_max) mask_heterodyne_center = get_mask(self.freq_grid_center, f_min, f_max) - # Apply the mask for frequencies to both polarization modes and for all waveforms currently used + # Apply the mask for frequencies to both polarization modes + # and for all waveforms currently used for mode in ["p", "c"]: h_sky[mode] = h_sky[mode][mask_original] h_sky_low[mode] = h_sky_low[mode][mask_heterodyne_low] @@ -326,7 +332,9 @@ def evaluate(self, params: Array, data: dict) -> float: def evaluate_original( self, params: Array, data: dict - ) -> float: # TODO: Test whether we need to pass data in or with class changes is fine. + ) -> ( + float + ): # TODO: Test whether we need to pass data in or with class changes is fine. """ Evaluate the likelihood for a given set of parameters. """ @@ -365,8 +373,8 @@ def max_phase_diff(f, f_low, f_high, chi=1): f_star[gamma >= 0] = f_high return 2 * np.pi * chi * np.sum((f / f_star) ** gamma * np.sign(gamma), axis=1) - def make_binning_scheme(self, freqs, n_bins, chi=1): - phase_diff_array = self.max_phase_diff(freqs, freqs[0], freqs[-1], chi=1) + def make_binning_scheme(self, freqs: Float[Array, "dim"], n_bins, chi=1): + phase_diff_array = self.max_phase_diff(freqs, freqs[0], freqs[-1], chi=chi) bin_f = interp1d(phase_diff_array, freqs) f_bins = np.array([]) for i in np.linspace(phase_diff_array[0], phase_diff_array[-1], n_bins + 1): @@ -423,13 +431,15 @@ def maximize_likelihood( bounds = jnp.array(bounds).T set_nwalkers = set_nwalkers - y = lambda x: -self.evaluate_original( - prior.add_name(x, transform_name=True, transform_value=True), None - ) + def y(x): + return -self.evaluate_original( + prior.add_name(x, transform_name=True, transform_value=True), None + ) + y = jax.jit(jax.vmap(y)) print("Starting the optimizer") optimizer = EvolutionaryOptimizer(len(bounds), verbose=True) - state = optimizer.optimize(y, bounds, n_loops=n_loops) + optimizer.optimize(y, bounds, n_loops=n_loops) best_fit = optimizer.get_result()[0] return prior.add_name(best_fit, transform_name=True, transform_value=True) From e6d1ce7b2bc7125cbb2df5d5add5c2ba6a1fcc70 Mon Sep 17 00:00:00 2001 From: Kaze Wong Date: Sat, 2 Dec 2023 14:40:06 -0500 Subject: [PATCH 10/17] Add residual typing fix --- src/jimgw/detector.py | 276 ++++++++++++++++++++++++---------------- src/jimgw/likelihood.py | 2 + 2 files changed, 165 insertions(+), 113 deletions(-) diff --git a/src/jimgw/detector.py b/src/jimgw/detector.py index 10faf04d..dd9cbcd3 100644 --- a/src/jimgw/detector.py +++ b/src/jimgw/detector.py @@ -1,26 +1,27 @@ -import jax.numpy as jnp -from jimgw.constants import * -from jimgw.wave import Polarization -from scipy.signal.windows import tukey from abc import ABC, abstractmethod -import equinox as eqx -from jaxtyping import Array, PRNGKeyArray + import jax -from gwpy.timeseries import TimeSeries -from typing import Callable -import requests +import jax.numpy as jnp import numpy as np +import requests +from gwpy.timeseries import TimeSeries +from jaxtyping import Array, PRNGKeyArray from scipy.interpolate import interp1d +from scipy.signal.windows import tukey -DEG_TO_RAD = jnp.pi/180 +from jimgw.constants import * +from jimgw.wave import Polarization + +DEG_TO_RAD = jnp.pi / 180 # TODO: Need to expand this list. Currently it is only O3. -psd_file_dict= { +psd_file_dict = { "H1": "https://dcc.ligo.org/public/0169/P2000251/001/O3-H1-C01_CLEAN_SUB60HZ-1251752040.0_sensitivity_strain_asd.txt", "L1": "https://dcc.ligo.org/public/0169/P2000251/001/O3-L1-C01_CLEAN_SUB60HZ-1240573680.0_sensitivity_strain_asd.txt", "V1": "https://dcc.ligo.org/public/0169/P2000251/001/O3-V1_sensitivity_strain_asd.txt", } + def np2(x): """ Returns the next power of two as big as or larger than x.""" @@ -29,13 +30,14 @@ def np2(x): p = p << 1 return p + class Detector(ABC): - """ + """ Base class for all detectors. """ - name: str + name: str @abstractmethod def load_data(self, data): @@ -44,20 +46,23 @@ def load_data(self, data): @abstractmethod def fd_response(self, frequency: Array, h: Array, params: dict) -> Array: """ - Modulate the waveform in the sky frame by the detector response in the frequency domain.""" + Modulate the waveform in the sky frame by the detector response + in the frequency domain.""" pass @abstractmethod def td_response(self, time: Array, h: Array, params: dict) -> Array: """ - Modulate the waveform in the sky frame by the detector response in the time domain.""" + Modulate the waveform in the sky frame by the detector response + in the time domain.""" pass - + + class GroundBased2G(Detector): polarization_mode: list[Polarization] frequencies: Array = None - data : Array = None + data: Array = None psd: Array = None latitude: float = 0 @@ -71,14 +76,14 @@ class GroundBased2G(Detector): def __init__(self, name: str, **kwargs) -> None: self.name = name - self.latitude = kwargs.get('latitude', 0) - self.longitude = kwargs.get('longitude', 0) - self.elevation = kwargs.get('elevation', 0) - self.xarm_azimuth = kwargs.get('xarm_azimuth', 0) - self.yarm_azimuth = kwargs.get('yarm_azimuth', 0) - self.xarm_tilt = kwargs.get('xarm_tilt', 0) - self.yarm_tilt = kwargs.get('yarm_tilt', 0) - modes = kwargs.get('mode', 'pc') + self.latitude = kwargs.get("latitude", 0) + self.longitude = kwargs.get("longitude", 0) + self.elevation = kwargs.get("elevation", 0) + self.xarm_azimuth = kwargs.get("xarm_azimuth", 0) + self.yarm_azimuth = kwargs.get("yarm_azimuth", 0) + self.xarm_tilt = kwargs.get("xarm_tilt", 0) + self.yarm_tilt = kwargs.get("yarm_tilt", 0) + modes = kwargs.get("mode", "pc") self.polarization_mode = [Polarization(m) for m in modes] @@ -99,33 +104,42 @@ def _get_arm(lat, lon, tilt, azimuth): arm azimuth in rad. """ e_lon = jnp.array([-jnp.sin(lon), jnp.cos(lon), 0]) - e_lat = jnp.array([-jnp.sin(lat) * jnp.cos(lon), - -jnp.sin(lat) * jnp.sin(lon), jnp.cos(lat)]) - e_h = jnp.array([jnp.cos(lat) * jnp.cos(lon), - jnp.cos(lat) * jnp.sin(lon), jnp.sin(lat)]) - - return (jnp.cos(tilt) * jnp.cos(azimuth) * e_lon + - jnp.cos(tilt) * jnp.sin(azimuth) * e_lat + - jnp.sin(tilt) * e_h) + e_lat = jnp.array( + [-jnp.sin(lat) * jnp.cos(lon), -jnp.sin(lat) * jnp.sin(lon), jnp.cos(lat)] + ) + e_h = jnp.array( + [jnp.cos(lat) * jnp.cos(lon), jnp.cos(lat) * jnp.sin(lon), jnp.sin(lat)] + ) + + return ( + jnp.cos(tilt) * jnp.cos(azimuth) * e_lon + + jnp.cos(tilt) * jnp.sin(azimuth) * e_lat + + jnp.sin(tilt) * e_h + ) @property def arms(self): """ Detector arm vectors (x, y). """ - x = self._get_arm(self.latitude, self.longitude, self.xarm_tilt, self.xarm_azimuth) - y = self._get_arm(self.latitude, self.longitude, self.yarm_tilt, self.yarm_azimuth) + x = self._get_arm( + self.latitude, self.longitude, self.xarm_tilt, self.xarm_azimuth + ) + y = self._get_arm( + self.latitude, self.longitude, self.yarm_tilt, self.yarm_azimuth + ) return x, y - + @property def tensor(self): """ Detector tensor defining the strain measurement. """ - #TODO: this could easily be generalized for other detector geometries + # TODO: this could easily be generalized for other detector geometries arm1, arm2 = self.arms - return 0.5 * (jnp.einsum('i,j->ij', arm1, arm1) - - jnp.einsum('i,j->ij', arm2, arm2)) + return 0.5 * ( + jnp.einsum("i,j->ij", arm1, arm1) - jnp.einsum("i,j->ij", arm2, arm2) + ) @property def vertex(self): @@ -140,20 +154,25 @@ def vertex(self): h = self.elevation major, minor = EARTH_SEMI_MAJOR_AXIS, EARTH_SEMI_MINOR_AXIS # compute vertex location - r = major**2*(major**2*jnp.cos(lat)**2 + minor**2*jnp.sin(lat)**2)**(-0.5) + r = major**2 * ( + major**2 * jnp.cos(lat) ** 2 + minor**2 * jnp.sin(lat) ** 2 + ) ** (-0.5) x = (r + h) * jnp.cos(lat) * jnp.cos(lon) y = (r + h) * jnp.cos(lat) * jnp.sin(lon) - z = ((minor / major)**2 * r + h)*jnp.sin(lat) + z = ((minor / major) ** 2 * r + h) * jnp.sin(lat) return jnp.array([x, y, z]) - def load_data(self, trigger_time:float, - gps_start_pad: int, - gps_end_pad: int, - f_min: float, - f_max: float, - psd_pad: int = 16, - tukey_alpha: float = 0.2, - gwpy_kwargs: dict = {"cache":True}) -> None: + def load_data( + self, + trigger_time: float, + gps_start_pad: int, + gps_end_pad: int, + f_min: float, + f_max: float, + psd_pad: int = 16, + tukey_alpha: float = 0.2, + gwpy_kwargs: dict = {"cache": True}, + ) -> None: """ Load data from the detector. @@ -177,33 +196,50 @@ def load_data(self, trigger_time:float, """ print("Fetching data from {}...".format(self.name)) - data_td = TimeSeries.fetch_open_data(self.name, trigger_time - gps_start_pad, trigger_time + gps_end_pad, **gwpy_kwargs) + data_td = TimeSeries.fetch_open_data( + self.name, + trigger_time - gps_start_pad, + trigger_time + gps_end_pad, + **gwpy_kwargs + ) segment_length = data_td.duration.value n = len(data_td) delta_t = data_td.dt.value - data = jnp.fft.rfft(jnp.array(data_td.value)*tukey(n, tukey_alpha))*delta_t + data = jnp.fft.rfft(jnp.array(data_td.value) * tukey(n, tukey_alpha)) * delta_t freq = jnp.fft.rfftfreq(n, delta_t) # TODO: Check if this is the right way to fetch PSD - start_psd = int(trigger_time) - gps_start_pad - 2*psd_pad # What does Int do here? + start_psd = ( + int(trigger_time) - gps_start_pad - 2 * psd_pad + ) # What does Int do here? end_psd = int(trigger_time) - gps_start_pad - psd_pad print("Fetching PSD data...") - psd_data_td = TimeSeries.fetch_open_data(self.name, start_psd, end_psd, **gwpy_kwargs) - psd = psd_data_td.psd(fftlength=segment_length).value # TODO: Check whether this is sright. + psd_data_td = TimeSeries.fetch_open_data( + self.name, start_psd, end_psd, **gwpy_kwargs + ) + psd = psd_data_td.psd( + fftlength=segment_length + ).value # TODO: Check whether this is sright. print("Finished loading data.") - self.frequencies = freq[(freq>f_min)&(freqf_min)&(freqf_min)&(freq f_min) & (freq < f_max)] + self.data = data[(freq > f_min) & (freq < f_max)] + self.psd = psd[(freq > f_min) & (freq < f_max)] def fd_response(self, frequency: Array, h_sky: dict, params: dict) -> Array: """ Modulate the waveform in the sky frame by the detector response in the frequency domain.""" - ra, dec, psi, gmst = params['ra'], params['dec'], params['psi'], params['gmst'] + ra, dec, psi, gmst = params["ra"], params["dec"], params["psi"], params["gmst"] antenna_pattern = self.antenna_pattern(ra, dec, psi, gmst) timeshift = self.delay_from_geocenter(ra, dec, gmst) - h_detector = jax.tree_util.tree_map(lambda h, antenna: h * antenna * jnp.exp(-2j * jnp.pi * frequency * timeshift), h_sky, antenna_pattern) + h_detector = jax.tree_util.tree_map( + lambda h, antenna: h + * antenna + * jnp.exp(-2j * jnp.pi * frequency * timeshift), + h_sky, + antenna_pattern, + ) return jnp.sum(jnp.stack(jax.tree_util.tree_leaves(h_detector)), axis=0) def td_response(self, time: Array, h: Array, params: Array) -> Array: @@ -211,10 +247,8 @@ def td_response(self, time: Array, h: Array, params: Array) -> Array: Modulate the waveform in the sky frame by the detector response in the time domain.""" pass - - def delay_from_geocenter(self, ra: float, dec: float, gmst: float) -> float: - """ + """ Calculate time delay between two detectors in geocentric coordinates based on XLALArrivaTimeDiff in TimeDelay.c @@ -237,12 +271,16 @@ def delay_from_geocenter(self, ra: float, dec: float, gmst: float) -> float: gmst = jnp.mod(gmst, 2 * jnp.pi) phi = ra - gmst theta = jnp.pi / 2 - dec - omega = jnp.array([jnp.sin(theta)*jnp.cos(phi), - jnp.sin(theta)*jnp.sin(phi), - jnp.cos(theta)]) + omega = jnp.array( + [ + jnp.sin(theta) * jnp.cos(phi), + jnp.sin(theta) * jnp.sin(phi), + jnp.cos(theta), + ] + ) return jnp.dot(omega, delta_d) / C_SI - def antenna_pattern(self, ra:float, dec:float, psi:float, gmst:float) -> dict: + def antenna_pattern(self, ra: float, dec: float, psi: float, gmst: float) -> dict: """ Computes {name} antenna patterns for {modes} polarizations at the specified sky location, orientation and GMST. @@ -263,78 +301,90 @@ def antenna_pattern(self, ra:float, dec:float, psi:float, gmst:float) -> dict: Greenwich mean sidereal time (GMST) in radians. modes : str string of polarizations to include, defaults to tensor modes: 'pc'. - + Returns ------- result : list antenna pattern values for {modes}. - """ + """ detector_tensor = self.tensor antenna_patterns = {} for polarization in self.polarization_mode: wave_tensor = polarization.tensor_from_sky(ra, dec, psi, gmst) - antenna_patterns[polarization.name] = jnp.einsum('ij,ij->', detector_tensor, wave_tensor) + antenna_patterns[polarization.name] = jnp.einsum( + "ij,ij->", detector_tensor, wave_tensor + ) return antenna_patterns - def inject_signal(self, - key: PRNGKeyArray, - freqs: Array, - h_sky: dict, - params: dict, - psd_file: str = None) -> None: - """ - """ + def inject_signal( + self, + key: PRNGKeyArray, + freqs: Array, + h_sky: dict, + params: dict, + psd_file: str = None, + ) -> None: + """ """ self.frequencies = freqs self.psd = self.load_psd(freqs, psd_file) key, subkey = jax.random.split(key, 2) var = self.psd / (4 * (freqs[1] - freqs[0])) - noise_real = jax.random.normal(key, shape=freqs.shape)*jnp.sqrt(var) - noise_imag = jax.random.normal(subkey, shape=freqs.shape)*jnp.sqrt(var) - align_time = jnp.exp(-1j*2*jnp.pi*freqs*(params['epoch']+params['t_c'])) + noise_real = jax.random.normal(key, shape=freqs.shape) * jnp.sqrt(var) + noise_imag = jax.random.normal(subkey, shape=freqs.shape) * jnp.sqrt(var) + align_time = jnp.exp( + -1j * 2 * jnp.pi * freqs * (params["epoch"] + params["t_c"]) + ) signal = self.fd_response(freqs, h_sky, params) * align_time - self.data = signal + noise_real + 1j*noise_imag + self.data = signal + noise_real + 1j * noise_imag def load_psd(self, freqs: Array, psd_file: str = None) -> None: if psd_file is None: - print("Grabbing GWTC-2 PSD for "+self.name) + print("Grabbing GWTC-2 PSD for " + self.name) url = psd_file_dict[self.name] data = requests.get(url) - open(self.name+".txt", "wb").write(data.content) - f, asd_vals = np.loadtxt(self.name+".txt", unpack=True) + open(self.name + ".txt", "wb").write(data.content) + f, asd_vals = np.loadtxt(self.name + ".txt", unpack=True) else: f, asd_vals = np.loadtxt(psd_file, unpack=True) psd_vals = asd_vals**2 psd = interp1d(f, psd_vals, fill_value=(psd_vals[0], psd_vals[-1]))(freqs) return psd -H1 = GroundBased2G('H1', -latitude = (46 + 27. / 60 + 18.528 / 3600) * DEG_TO_RAD, -longitude = -(119 + 24. / 60 + 27.5657 / 3600) * DEG_TO_RAD, -xarm_azimuth = 125.9994 * DEG_TO_RAD, -yarm_azimuth = 215.9994 * DEG_TO_RAD, -xarm_tilt = -6.195e-4, -yarm_tilt = 1.25e-5, -elevation = 142.554, -mode='pc') - -L1 = GroundBased2G('L1', -latitude = (30 + 33. / 60 + 46.4196 / 3600) * DEG_TO_RAD, -longitude = -(90 + 46. / 60 + 27.2654 / 3600) * DEG_TO_RAD, -xarm_azimuth = 197.7165 * DEG_TO_RAD, -yarm_azimuth = 287.7165 * DEG_TO_RAD, -xarm_tilt = 0 , -yarm_tilt = 0, -elevation = -6.574, -mode='pc') - -V1 = GroundBased2G('V1', -latitude = (43 + 37. / 60 + 53.0921 / 3600) * DEG_TO_RAD, -longitude = (10 + 30. / 60 + 16.1887 / 3600) * DEG_TO_RAD, -xarm_azimuth = 243. * DEG_TO_RAD, -yarm_azimuth = 333. * DEG_TO_RAD, -xarm_tilt = 0 , -yarm_tilt = 0, -elevation = 51.884, -mode='pc') \ No newline at end of file + +H1 = GroundBased2G( + "H1", + latitude=(46 + 27.0 / 60 + 18.528 / 3600) * DEG_TO_RAD, + longitude=-(119 + 24.0 / 60 + 27.5657 / 3600) * DEG_TO_RAD, + xarm_azimuth=125.9994 * DEG_TO_RAD, + yarm_azimuth=215.9994 * DEG_TO_RAD, + xarm_tilt=-6.195e-4, + yarm_tilt=1.25e-5, + elevation=142.554, + mode="pc", +) + +L1 = GroundBased2G( + "L1", + latitude=(30 + 33.0 / 60 + 46.4196 / 3600) * DEG_TO_RAD, + longitude=-(90 + 46.0 / 60 + 27.2654 / 3600) * DEG_TO_RAD, + xarm_azimuth=197.7165 * DEG_TO_RAD, + yarm_azimuth=287.7165 * DEG_TO_RAD, + xarm_tilt=0, + yarm_tilt=0, + elevation=-6.574, + mode="pc", +) + +V1 = GroundBased2G( + "V1", + latitude=(43 + 37.0 / 60 + 53.0921 / 3600) * DEG_TO_RAD, + longitude=(10 + 30.0 / 60 + 16.1887 / 3600) * DEG_TO_RAD, + xarm_azimuth=243.0 * DEG_TO_RAD, + yarm_azimuth=333.0 * DEG_TO_RAD, + xarm_tilt=0, + yarm_tilt=0, + elevation=51.884, + mode="pc", +) diff --git a/src/jimgw/likelihood.py b/src/jimgw/likelihood.py index 5e0a021d..1b1406ac 100644 --- a/src/jimgw/likelihood.py +++ b/src/jimgw/likelihood.py @@ -25,6 +25,8 @@ class LikelihoodBase(ABC): a given set of parameters. """ + _model: object + _data: object @property def model(self): From d323f6daef453c888a5bc321005562b4ee463ae3 Mon Sep 17 00:00:00 2001 From: Kaze Wong Date: Sat, 2 Dec 2023 14:48:13 -0500 Subject: [PATCH 11/17] Change doc string style to numpy --- src/jimgw/jim.py | 13 +++++++++---- src/jimgw/likelihood.py | 26 +++++++++++++++++++++++++- 2 files changed, 34 insertions(+), 5 deletions(-) diff --git a/src/jimgw/jim.py b/src/jimgw/jim.py index 12aa89c1..ad5a3152 100644 --- a/src/jimgw/jim.py +++ b/src/jimgw/jim.py @@ -109,11 +109,16 @@ def get_samples(self, training: bool = False) -> dict: """ Get the samples from the sampler - Args: - training (bool, optional): If True, return the training samples. Defaults to False. + Parameters + ---------- + training : bool, optional + Whether to get the training samples or the production samples, by default False + + Returns + ------- + dict + Dictionary of samples - Returns: - Array: Samples """ if training: chains = self.Sampler.get_sampler_state(training=True)["chains"] diff --git a/src/jimgw/likelihood.py b/src/jimgw/likelihood.py index 1b1406ac..5828202c 100644 --- a/src/jimgw/likelihood.py +++ b/src/jimgw/likelihood.py @@ -25,6 +25,7 @@ class LikelihoodBase(ABC): a given set of parameters. """ + _model: object _data: object @@ -375,7 +376,30 @@ def max_phase_diff(f, f_low, f_high, chi=1): f_star[gamma >= 0] = f_high return 2 * np.pi * chi * np.sum((f / f_star) ** gamma * np.sign(gamma), axis=1) - def make_binning_scheme(self, freqs: Float[Array, "dim"], n_bins, chi=1): + def make_binning_scheme( + self, freqs: Float[Array, "dim"], n_bins: int, chi: float = 1 + ) -> tuple[Float[Array, "n_bins+1"], Float[Array, "n_bins"]]: + """ + Make a binning scheme based on the maximum phase difference between the + frequencies in the array. + + Parameters + ---------- + freqs: Float[Array, "dim"] + Array of frequencies to be binned. + n_bins: int + Number of bins to be used. + chi: float = 1 + The chi parameter used in the phase difference calculation. + + Returns + ------- + f_bins: Float[Array, "n_bins+1"] + The bin edges. + f_bins_center: Float[Array, "n_bins"] + The bin centers. + """ + phase_diff_array = self.max_phase_diff(freqs, freqs[0], freqs[-1], chi=chi) bin_f = interp1d(phase_diff_array, freqs) f_bins = np.array([]) From f7363845e29086bdc7a7c269a33aa1c60f495dbb Mon Sep 17 00:00:00 2001 From: Kaze Wong Date: Sat, 2 Dec 2023 15:21:19 -0500 Subject: [PATCH 12/17] Remove unnecessary masking of the original data --- src/jimgw/likelihood.py | 39 ++++++++++++++++++++++++++++++++------- 1 file changed, 32 insertions(+), 7 deletions(-) diff --git a/src/jimgw/likelihood.py b/src/jimgw/likelihood.py index 5828202c..cb87f56e 100644 --- a/src/jimgw/likelihood.py +++ b/src/jimgw/likelihood.py @@ -157,6 +157,8 @@ def __init__( detectors, waveform, trigger_time, duration, post_trigger_duration ) + print("Initializing heterodyned likelihood..") + # Get the original frequency grid assert jnp.all( @@ -176,6 +178,8 @@ def __init__( ) self.freq_grid_low = freq_grid[:-1] + print("Finding reference parameters..") + self.ref_params = self.maximize_likelihood( bounds=bounds, prior=prior, set_nwalkers=n_walkers, n_loops=n_loops ) @@ -224,18 +228,15 @@ def get_mask(f: Array, f_min: float, f_max: float) -> Array: mask[index_f_min : index_f_max + 1] = True return mask - mask_original = get_mask(frequency_original, f_min, f_max) mask_heterodyne_low = get_mask(self.freq_grid_low, f_min, f_max) mask_heterodyne_center = get_mask(self.freq_grid_center, f_min, f_max) # Apply the mask for frequencies to both polarization modes # and for all waveforms currently used - for mode in ["p", "c"]: - h_sky[mode] = h_sky[mode][mask_original] + for mode in h_sky.keys(): h_sky_low[mode] = h_sky_low[mode][mask_heterodyne_low] h_sky_center[mode] = h_sky_center[mode][mask_heterodyne_center] - frequency_original = frequency_original[mask_original] freq_grid = freq_grid[get_mask(freq_grid, f_min, f_max)] self.freq_grid_low = self.freq_grid_low[mask_heterodyne_low] self.freq_grid_center = self.freq_grid_center[mask_heterodyne_center] @@ -265,7 +266,6 @@ def get_mask(f: Array, f_min: float, f_max: float) -> Array: for detector in self.detectors: # Also apply the mask of frequencies to the strain data - detector.data = detector.data[mask_original] # Get the reference waveforms waveform_ref = ( detector.fd_response(frequency_original, h_sky, self.ref_params) @@ -369,7 +369,32 @@ def evaluate_original( return log_likelihood @staticmethod - def max_phase_diff(f, f_low, f_high, chi=1): + def max_phase_diff( + f: Float[Array, "n_dim"], + f_low: float, + f_high: float, + chi: float =1, + ): + """ + Compute the maximum phase difference between the frequencies in the array. + + Parameters + ---------- + f: Float[Array, "n_dim"] + Array of frequencies to be binned. + f_low: float + Lower frequency bound. + f_high: float + Upper frequency bound. + chi: float + Power law index. + + Returns + ------- + Float[Array, "n_dim"] + Maximum phase difference between the frequencies in the array. + """ + gamma = np.arange(-5, 6, 1) / 3.0 f = np.repeat(f[:, None], len(gamma), axis=1) f_star = np.repeat(f_low, len(gamma)) @@ -377,7 +402,7 @@ def max_phase_diff(f, f_low, f_high, chi=1): return 2 * np.pi * chi * np.sum((f / f_star) ** gamma * np.sign(gamma), axis=1) def make_binning_scheme( - self, freqs: Float[Array, "dim"], n_bins: int, chi: float = 1 + self, freqs: Float[Array, "n_dim"], n_bins: int, chi: float = 1 ) -> tuple[Float[Array, "n_bins+1"], Float[Array, "n_bins"]]: """ Make a binning scheme based on the maximum phase difference between the From f29a96384924c689adcfc6216396ec4d02c15b31 Mon Sep 17 00:00:00 2001 From: Kaze Wong Date: Sat, 2 Dec 2023 15:35:54 -0500 Subject: [PATCH 13/17] Move get mask outside init --- src/jimgw/likelihood.py | 52 +++++++++++++++++++++-------------------- 1 file changed, 27 insertions(+), 25 deletions(-) diff --git a/src/jimgw/likelihood.py b/src/jimgw/likelihood.py index cb87f56e..8514d450 100644 --- a/src/jimgw/likelihood.py +++ b/src/jimgw/likelihood.py @@ -208,28 +208,10 @@ def __init__( f_min = jnp.min(f_valid) # TODO replace this - def get_mask(f: Array, f_min: float, f_max: float) -> Array: - """Slice an array f by containing all elements in f - that are greater or equal to f_min, and all elements smaller than or equal - to f_max, and the element just right after that. - - Args: - f (Array): Frequency array to be sliced - f_min (float): Min frequency to be included - f_max (float): Max frequency to be included - - Returns: - Array: Sliced array. - """ - mask = np.array([False for value in f]) - index_f_min = np.argwhere(f >= f_min).flatten()[0] - index_f_max = np.argwhere(f <= f_max).flatten()[-1] - index_f_max = min(index_f_max + 1, len(f) - 1) - mask[index_f_min : index_f_max + 1] = True - return mask - - mask_heterodyne_low = get_mask(self.freq_grid_low, f_min, f_max) - mask_heterodyne_center = get_mask(self.freq_grid_center, f_min, f_max) + + + mask_heterodyne_low = self.get_mask(self.freq_grid_low, f_min, f_max) + mask_heterodyne_center = self.get_mask(self.freq_grid_center, f_min, f_max) # Apply the mask for frequencies to both polarization modes # and for all waveforms currently used @@ -237,7 +219,7 @@ def get_mask(f: Array, f_min: float, f_max: float) -> Array: h_sky_low[mode] = h_sky_low[mode][mask_heterodyne_low] h_sky_center[mode] = h_sky_center[mode][mask_heterodyne_center] - freq_grid = freq_grid[get_mask(freq_grid, f_min, f_max)] + freq_grid = freq_grid[self.get_mask(freq_grid, f_min, f_max)] self.freq_grid_low = self.freq_grid_low[mask_heterodyne_low] self.freq_grid_center = self.freq_grid_center[mask_heterodyne_center] @@ -265,7 +247,6 @@ def get_mask(f: Array, f_min: float, f_max: float) -> Array: ) for detector in self.detectors: - # Also apply the mask of frequencies to the strain data # Get the reference waveforms waveform_ref = ( detector.fd_response(frequency_original, h_sky, self.ref_params) @@ -373,7 +354,7 @@ def max_phase_diff( f: Float[Array, "n_dim"], f_low: float, f_high: float, - chi: float =1, + chi: float = 1, ): """ Compute the maximum phase difference between the frequencies in the array. @@ -472,6 +453,27 @@ def compute_coefficients(data, h_ref, psd, freqs, f_bins, f_bins_center): B1_array = jnp.array(B1_array) return A0_array, A1_array, B0_array, B1_array + @staticmethod + def get_mask(f: Float[Array, "f"], f_min: float, f_max: float) -> Float[Array, "small_f"]: + """Slice an array f by containing all elements in f + that are greater or equal to f_min, and all elements smaller than or equal + to f_max, and the element just right after that. + + Args: + f (Array): Frequency array to be sliced + f_min (float): Min frequency to be included + f_max (float): Max frequency to be included + + Returns: + Array: Sliced array. + """ + mask = np.zeros(f.shape).astype(bool) + index_f_min = np.argwhere(f >= f_min).flatten()[0] + index_f_max = np.argwhere(f <= f_max).flatten()[-1] + index_f_max = min(index_f_max + 1, len(f) - 1) + mask[index_f_min : index_f_max + 1] = True + return mask + def maximize_likelihood( self, bounds: tuple[Array, Array], From e76ad23d0277d99a88661500b0c489bb26b4908a Mon Sep 17 00:00:00 2001 From: Kaze Wong Date: Sat, 2 Dec 2023 17:18:22 -0500 Subject: [PATCH 14/17] Heterodyne likelihood now can evaluate on GW150914, need to test on PE result still tho. --- src/jimgw/likelihood.py | 60 +++++++++++++---------------------------- 1 file changed, 18 insertions(+), 42 deletions(-) diff --git a/src/jimgw/likelihood.py b/src/jimgw/likelihood.py index 8514d450..260e195d 100644 --- a/src/jimgw/likelihood.py +++ b/src/jimgw/likelihood.py @@ -146,7 +146,7 @@ def __init__( waveform: Waveform, prior: Prior, bounds: tuple[Array, Array], - n_bins: int = 101, + n_bins: int = 100, trigger_time: float = 0, duration: float = 4, post_trigger_duration: float = 2, @@ -196,33 +196,27 @@ def __init__( self.B1_array = {} h_sky = self.waveform(frequency_original, self.ref_params) - h_sky_low = self.waveform(self.freq_grid_low, self.ref_params) - h_sky_center = self.waveform(self.freq_grid_center, self.ref_params) + # Get frequency masks to be applied, for both original # and heterodyne frequency grid + h_amp = jnp.sum(jnp.array([jnp.abs(h_sky[key]) for key in h_sky.keys()]),axis = 0) f_valid = frequency_original[ - jnp.where((jnp.abs(h_sky["p"]) + jnp.abs(h_sky["c"])) > 0)[0] + jnp.where(h_amp > 0)[0] ] f_max = jnp.max(f_valid) f_min = jnp.min(f_valid) - # TODO replace this - - - mask_heterodyne_low = self.get_mask(self.freq_grid_low, f_min, f_max) - mask_heterodyne_center = self.get_mask(self.freq_grid_center, f_min, f_max) - - # Apply the mask for frequencies to both polarization modes - # and for all waveforms currently used - for mode in h_sky.keys(): - h_sky_low[mode] = h_sky_low[mode][mask_heterodyne_low] - h_sky_center[mode] = h_sky_center[mode][mask_heterodyne_center] - - freq_grid = freq_grid[self.get_mask(freq_grid, f_min, f_max)] + mask_heterodyne_grid = jnp.where((freq_grid <= f_max)&(freq_grid >= f_min))[0] + mask_heterodyne_low = jnp.where((self.freq_grid_low <= f_max)&(self.freq_grid_low >= f_min))[0] + mask_heterodyne_center = jnp.where((self.freq_grid_center <= f_max)&(self.freq_grid_center >= f_min))[0] + freq_grid = freq_grid[mask_heterodyne_grid] self.freq_grid_low = self.freq_grid_low[mask_heterodyne_low] self.freq_grid_center = self.freq_grid_center[mask_heterodyne_center] + h_sky_low = self.waveform(self.freq_grid_low, self.ref_params) + h_sky_center = self.waveform(self.freq_grid_center, self.ref_params) + # Get phase shifts to align time of coalescence align_time = jnp.exp( -1j @@ -246,6 +240,7 @@ def __init__( * (self.epoch + self.ref_params["t_c"]) ) + for detector in self.detectors: # Get the reference waveforms waveform_ref = ( @@ -270,10 +265,12 @@ def __init__( freq_grid, self.freq_grid_center, ) - self.A0_array[detector.name] = A0 - self.A1_array[detector.name] = A1 - self.B0_array[detector.name] = B0 - self.B1_array[detector.name] = B1 + + self.A0_array[detector.name] = A0[mask_heterodyne_center] + self.A1_array[detector.name] = A1[mask_heterodyne_center] + self.B0_array[detector.name] = B0[mask_heterodyne_center] + self.B1_array[detector.name] = B1[mask_heterodyne_center] + def evaluate(self, params: Array, data: dict) -> float: log_likelihood = 0 @@ -453,27 +450,6 @@ def compute_coefficients(data, h_ref, psd, freqs, f_bins, f_bins_center): B1_array = jnp.array(B1_array) return A0_array, A1_array, B0_array, B1_array - @staticmethod - def get_mask(f: Float[Array, "f"], f_min: float, f_max: float) -> Float[Array, "small_f"]: - """Slice an array f by containing all elements in f - that are greater or equal to f_min, and all elements smaller than or equal - to f_max, and the element just right after that. - - Args: - f (Array): Frequency array to be sliced - f_min (float): Min frequency to be included - f_max (float): Max frequency to be included - - Returns: - Array: Sliced array. - """ - mask = np.zeros(f.shape).astype(bool) - index_f_min = np.argwhere(f >= f_min).flatten()[0] - index_f_max = np.argwhere(f <= f_max).flatten()[-1] - index_f_max = min(index_f_max + 1, len(f) - 1) - mask[index_f_min : index_f_max + 1] = True - return mask - def maximize_likelihood( self, bounds: tuple[Array, Array], From a1b825aab4a791534d4ba6fa6f2216f1038f1a92 Mon Sep 17 00:00:00 2001 From: Kaze Wong Date: Sat, 2 Dec 2023 17:19:09 -0500 Subject: [PATCH 15/17] stylzing --- src/jimgw/detector.py | 6 +-- src/jimgw/jim.py | 91 ++++++++++++++++++++++++++--------------- src/jimgw/likelihood.py | 21 +++++----- src/jimgw/waveform.py | 4 +- 4 files changed, 75 insertions(+), 47 deletions(-) diff --git a/src/jimgw/detector.py b/src/jimgw/detector.py index dd9cbcd3..d7580335 100644 --- a/src/jimgw/detector.py +++ b/src/jimgw/detector.py @@ -92,7 +92,7 @@ def _get_arm(lat, lon, tilt, azimuth): """ Construct detector-arm vectors in Earth-centric Cartesian coordinates. - Arguments + Parameters --------- lat : float vertex latitude in rad. @@ -254,7 +254,7 @@ def delay_from_geocenter(self, ra: float, dec: float, gmst: float) -> float: https://lscsoft.docs.ligo.org/lalsuite/lal/group___time_delay__h.html - Arguments + Parameters --------- ra : float right ascension of the source in rad. @@ -289,7 +289,7 @@ def antenna_pattern(self, ra: float, dec: float, psi: float, gmst: float) -> dic given polarization is the dyadic product between the detector tensor and the corresponding polarization tensor. - Arguments + Parameters --------- ra : float source right ascension in radians. diff --git a/src/jimgw/jim.py b/src/jimgw/jim.py index ad5a3152..9f4cae5b 100644 --- a/src/jimgw/jim.py +++ b/src/jimgw/jim.py @@ -9,10 +9,11 @@ import jax import jax.numpy as jnp + class Jim(object): """ Master class for interfacing with flowMC - + """ def __init__(self, likelihood: LikelihoodBase, prior: Prior, **kwargs): @@ -23,24 +24,29 @@ def __init__(self, likelihood: LikelihoodBase, prior: Prior, **kwargs): rng_key_set = initialize_rng_keys(n_chains, seed=seed) num_layers = kwargs.get("num_layers", 10) - hidden_size = kwargs.get("hidden_size", [128,128]) + hidden_size = kwargs.get("hidden_size", [128, 128]) num_bins = kwargs.get("num_bins", 8) local_sampler_arg = kwargs.get("local_sampler_arg", {}) - local_sampler = MALA(self.posterior, True, local_sampler_arg) # Remember to add routine to find automated mass matrix + local_sampler = MALA( + self.posterior, True, local_sampler_arg + ) # Remember to add routine to find automated mass matrix - model = MaskedCouplingRQSpline(self.Prior.n_dim, num_layers, hidden_size, num_bins, rng_key_set[-1]) + model = MaskedCouplingRQSpline( + self.Prior.n_dim, num_layers, hidden_size, num_bins, rng_key_set[-1] + ) self.Sampler = Sampler( - self.Prior.n_dim, - rng_key_set, - None, - local_sampler, - model, - **kwargs) - - - def maximize_likelihood(self, bounds: tuple[Array,Array], set_nwalkers: int = 100, n_loops: int = 2000, seed = 92348): + self.Prior.n_dim, rng_key_set, None, local_sampler, model, **kwargs + ) + + def maximize_likelihood( + self, + bounds: tuple[Array, Array], + set_nwalkers: int = 100, + n_loops: int = 2000, + seed=92348, + ): bounds = jnp.array(bounds).T key = jax.random.PRNGKey(seed) set_nwalkers = set_nwalkers @@ -53,17 +59,20 @@ def maximize_likelihood(self, bounds: tuple[Array,Array], set_nwalkers: int = 10 print("Done compiling") print("Starting the optimizer") - optimizer = EvolutionaryOptimizer(self.Prior.n_dim, verbose = True) + optimizer = EvolutionaryOptimizer(self.Prior.n_dim, verbose=True) state = optimizer.optimize(y, bounds, n_loops=n_loops) best_fit = optimizer.get_result()[0] return best_fit def posterior(self, params: Array, data: dict): - named_params = self.Prior.add_name(params, transform_name=True, transform_value=True) - return self.Likelihood.evaluate(named_params, data) + self.Prior.log_prob(params) - - def sample(self, key: jax.random.PRNGKey, - initial_guess: Array = None): + named_params = self.Prior.add_name( + params, transform_name=True, transform_value=True + ) + return self.Likelihood.evaluate(named_params, data) + self.Prior.log_prob( + params + ) + + def sample(self, key: jax.random.PRNGKey, initial_guess: Array = None): if initial_guess is None: initial_guess = self.Prior.sample(key, self.Sampler.n_chains) self.Sampler.sample(initial_guess, None) @@ -89,21 +98,39 @@ def print_summary(self): production_global_acceptance: Array = production_summary["global_accs"] print("Training summary") - print('=' * 10) + print("=" * 10) for index in range(len(self.Prior.naming)): - print(f"{self.Prior.naming[index]}: {training_chain[:, :, index].mean():.3f} +/- {training_chain[:, :, index].std():.3f}") - print(f"Log probability: {training_log_prob.mean():.3f} +/- {training_log_prob.std():.3f}") - print(f"Local acceptance: {training_local_acceptance.mean():.3f} +/- {training_local_acceptance.std():.3f}") - print(f"Global acceptance: {training_global_acceptance.mean():.3f} +/- {training_global_acceptance.std():.3f}") - print(f"Max loss: {training_loss.max():.3f}, Min loss: {training_loss.min():.3f}") + print( + f"{self.Prior.naming[index]}: {training_chain[:, :, index].mean():.3f} +/- {training_chain[:, :, index].std():.3f}" + ) + print( + f"Log probability: {training_log_prob.mean():.3f} +/- {training_log_prob.std():.3f}" + ) + print( + f"Local acceptance: {training_local_acceptance.mean():.3f} +/- {training_local_acceptance.std():.3f}" + ) + print( + f"Global acceptance: {training_global_acceptance.mean():.3f} +/- {training_global_acceptance.std():.3f}" + ) + print( + f"Max loss: {training_loss.max():.3f}, Min loss: {training_loss.min():.3f}" + ) print("Production summary") - print('=' * 10) + print("=" * 10) for index in range(len(self.Prior.naming)): - print(f"{self.Prior.naming[index]}: {production_chain[:, :, index].mean():.3f} +/- {production_chain[:, :, index].std():.3f}") - print(f"Log probability: {production_log_prob.mean():.3f} +/- {production_log_prob.std():.3f}") - print(f"Local acceptance: {production_local_acceptance.mean():.3f} +/- {production_local_acceptance.std():.3f}") - print(f"Global acceptance: {production_global_acceptance.mean():.3f} +/- {production_global_acceptance.std():.3f}") + print( + f"{self.Prior.naming[index]}: {production_chain[:, :, index].mean():.3f} +/- {production_chain[:, :, index].std():.3f}" + ) + print( + f"Log probability: {production_log_prob.mean():.3f} +/- {production_log_prob.std():.3f}" + ) + print( + f"Local acceptance: {production_local_acceptance.mean():.3f} +/- {production_local_acceptance.std():.3f}" + ) + print( + f"Global acceptance: {production_global_acceptance.mean():.3f} +/- {production_global_acceptance.std():.3f}" + ) def get_samples(self, training: bool = False) -> dict: """ @@ -125,8 +152,8 @@ def get_samples(self, training: bool = False) -> dict: else: chains = self.Sampler.get_sampler_state(training=False)["chains"] - chains = self.Prior.add_name(chains.transpose(2,0,1), transform_name=True) + chains = self.Prior.add_name(chains.transpose(2, 0, 1), transform_name=True) return chains def plot(self): - pass \ No newline at end of file + pass diff --git a/src/jimgw/likelihood.py b/src/jimgw/likelihood.py index 260e195d..7bcdd6c0 100644 --- a/src/jimgw/likelihood.py +++ b/src/jimgw/likelihood.py @@ -197,19 +197,22 @@ def __init__( h_sky = self.waveform(frequency_original, self.ref_params) - # Get frequency masks to be applied, for both original # and heterodyne frequency grid - h_amp = jnp.sum(jnp.array([jnp.abs(h_sky[key]) for key in h_sky.keys()]),axis = 0) - f_valid = frequency_original[ - jnp.where(h_amp > 0)[0] - ] + h_amp = jnp.sum( + jnp.array([jnp.abs(h_sky[key]) for key in h_sky.keys()]), axis=0 + ) + f_valid = frequency_original[jnp.where(h_amp > 0)[0]] f_max = jnp.max(f_valid) f_min = jnp.min(f_valid) - mask_heterodyne_grid = jnp.where((freq_grid <= f_max)&(freq_grid >= f_min))[0] - mask_heterodyne_low = jnp.where((self.freq_grid_low <= f_max)&(self.freq_grid_low >= f_min))[0] - mask_heterodyne_center = jnp.where((self.freq_grid_center <= f_max)&(self.freq_grid_center >= f_min))[0] + mask_heterodyne_grid = jnp.where((freq_grid <= f_max) & (freq_grid >= f_min))[0] + mask_heterodyne_low = jnp.where( + (self.freq_grid_low <= f_max) & (self.freq_grid_low >= f_min) + )[0] + mask_heterodyne_center = jnp.where( + (self.freq_grid_center <= f_max) & (self.freq_grid_center >= f_min) + )[0] freq_grid = freq_grid[mask_heterodyne_grid] self.freq_grid_low = self.freq_grid_low[mask_heterodyne_low] self.freq_grid_center = self.freq_grid_center[mask_heterodyne_center] @@ -240,7 +243,6 @@ def __init__( * (self.epoch + self.ref_params["t_c"]) ) - for detector in self.detectors: # Get the reference waveforms waveform_ref = ( @@ -271,7 +273,6 @@ def __init__( self.B0_array[detector.name] = B0[mask_heterodyne_center] self.B1_array[detector.name] = B1[mask_heterodyne_center] - def evaluate(self, params: Array, data: dict) -> float: log_likelihood = 0 frequencies_low = self.freq_grid_low diff --git a/src/jimgw/waveform.py b/src/jimgw/waveform.py index c94b81ad..11220021 100644 --- a/src/jimgw/waveform.py +++ b/src/jimgw/waveform.py @@ -9,7 +9,7 @@ class Waveform(ABC): def __init__(self): return NotImplemented - def __call__(self, axis: Array, params: Array) -> Array: + def __call__(self, axis: Array, params: Array) -> dict: return NotImplemented @@ -47,7 +47,7 @@ class RippleIMRPhenomPv2(Waveform): def __init__(self, f_ref: float = 20.0): self.f_ref = f_ref - def __call__(self, frequency: Array, params: dict) -> Array: + def __call__(self, frequency: Array, params: dict) -> dict: output = {} theta = [ params["M_c"], From 7771df1980e07f6fab133b2004b9d3802ae1ced8 Mon Sep 17 00:00:00 2001 From: Kaze Wong Date: Sun, 3 Dec 2023 10:19:23 -0500 Subject: [PATCH 16/17] tested GW150914 --- example/GW150914.py | 22 +++++++--- example/GW150914_heterodyne.py | 80 ++++++++++++++++++++++++++++++++++ 2 files changed, 97 insertions(+), 5 deletions(-) create mode 100644 example/GW150914_heterodyne.py diff --git a/example/GW150914.py b/example/GW150914.py index 949fee63..f9f1e746 100644 --- a/example/GW150914.py +++ b/example/GW150914.py @@ -17,15 +17,17 @@ # first, fetch a 4s segment centered on GW150914 gps = 1126259462.4 -start = gps - 2 -end = gps + 2 +duration = 4 +post_trigger_duration = 2 +start_pad = duration - post_trigger_duration +end_pad = post_trigger_duration fmin = 20.0 fmax = 1024.0 ifos = ["H1", "L1"] -H1.load_data(gps, 2, 2, fmin, fmax, psd_pad=16, tukey_alpha=0.2) -L1.load_data(gps, 2, 2, fmin, fmax, psd_pad=16, tukey_alpha=0.2) +H1.load_data(gps, start_pad, end_pad, fmin, fmax, psd_pad=16, tukey_alpha=0.2) +L1.load_data(gps, start_pad, end_pad, fmin, fmax, psd_pad=16, tukey_alpha=0.2) prior = Uniform( xmin=[10, 0.125, -1.0, -1.0, 0.0, -0.05, 0.0, -1, 0.0, 0.0, -1.0], @@ -49,6 +51,17 @@ ) likelihood = TransientLikelihoodFD([H1, L1], waveform=RippleIMRPhenomD(), trigger_time=gps, duration=4, post_trigger_duration=2) +likelihood = HeterodynedTransientLikelihoodFD( + [H1, L1], + prior=prior, + bounds=[prior.xmin, prior.xmax], + waveform=RippleIMRPhenomD(), + trigger_time=gps, + duration=duration, + post_trigger_duration=post_trigger_duration, + n_loops=300 +) + mass_matrix = jnp.eye(11) mass_matrix = mass_matrix.at[1, 1].set(1e-3) mass_matrix = mass_matrix.at[5, 5].set(1e-3) @@ -74,5 +87,4 @@ local_sampler_arg=local_sampler_arg, ) -jim.maximize_likelihood([prior.xmin, prior.xmax]) jim.sample(jax.random.PRNGKey(42)) diff --git a/example/GW150914_heterodyne.py b/example/GW150914_heterodyne.py new file mode 100644 index 00000000..c9d5df9f --- /dev/null +++ b/example/GW150914_heterodyne.py @@ -0,0 +1,80 @@ +import time +from jimgw.jim import Jim +from jimgw.detector import H1, L1 +from jimgw.likelihood import HeterodynedTransientLikelihoodFD, TransientLikelihoodFD +from jimgw.waveform import RippleIMRPhenomD +from jimgw.prior import Uniform +import jax.numpy as jnp +import jax + +jax.config.update("jax_enable_x64", True) + +########################################### +########## First we grab data ############# +########################################### + +total_time_start = time.time() + +# first, fetch a 4s segment centered on GW150914 +gps = 1126259462.4 +duration = 4 +post_trigger_duration = 2 +start_pad = duration - post_trigger_duration +end_pad = post_trigger_duration +fmin = 20.0 +fmax = 1024.0 + +ifos = ["H1", "L1"] + +H1.load_data(gps, start_pad, end_pad, fmin, fmax, psd_pad=16, tukey_alpha=0.2) +L1.load_data(gps, start_pad, end_pad, fmin, fmax, psd_pad=16, tukey_alpha=0.2) + +prior = Uniform( + xmin=[10, 0.125, -1.0, -1.0, 0.0, -0.05, 0.0, -1, 0.0, 0.0, -1.0], + xmax=[80.0, 1.0, 1.0, 1.0, 2000.0, 0.05, 2 * jnp.pi, 1.0, jnp.pi, 2 * jnp.pi, 1.0], + naming=[ + "M_c", + "q", + "s1_z", + "s2_z", + "d_L", + "t_c", + "phase_c", + "cos_iota", + "psi", + "ra", + "sin_dec", + ], + transforms = {"q": ("eta", lambda params: params['q']/(1+params['q'])**2), + "cos_iota": ("iota",lambda params: jnp.arccos(jnp.arcsin(jnp.sin(params['cos_iota']/2*jnp.pi))*2/jnp.pi)), + "sin_dec": ("dec",lambda params: jnp.arcsin(jnp.arcsin(jnp.sin(params['sin_dec']/2*jnp.pi))*2/jnp.pi))} # sin and arcsin are periodize cos_iota and sin_dec +) +likelihood = TransientLikelihoodFD([H1, L1], waveform=RippleIMRPhenomD(), trigger_time=gps, duration=4, post_trigger_duration=2) + + +mass_matrix = jnp.eye(11) +mass_matrix = mass_matrix.at[1, 1].set(1e-3) +mass_matrix = mass_matrix.at[5, 5].set(1e-3) +local_sampler_arg = {"step_size": mass_matrix * 3e-3} + +jim = Jim( + likelihood, + prior, + n_loop_training=100, + n_loop_production=10, + n_local_steps=150, + n_global_steps=150, + n_chains=500, + n_epochs=50, + learning_rate=0.001, + max_samples=45000, + momentum=0.9, + batch_size=50000, + use_global=True, + keep_quantile=0.0, + train_thinning=1, + output_thinning=10, + local_sampler_arg=local_sampler_arg, +) + +jim.sample(jax.random.PRNGKey(42)) From cf7e7afa489101e456b90be11aed6210a717749b Mon Sep 17 00:00:00 2001 From: Kaze Wong Date: Sun, 3 Dec 2023 10:21:57 -0500 Subject: [PATCH 17/17] update GW150914.py --- example/GW150914.py | 11 ----------- example/GW150914_heterodyne.py | 11 ++++++++++- 2 files changed, 10 insertions(+), 12 deletions(-) diff --git a/example/GW150914.py b/example/GW150914.py index f9f1e746..9c373b6a 100644 --- a/example/GW150914.py +++ b/example/GW150914.py @@ -51,17 +51,6 @@ ) likelihood = TransientLikelihoodFD([H1, L1], waveform=RippleIMRPhenomD(), trigger_time=gps, duration=4, post_trigger_duration=2) -likelihood = HeterodynedTransientLikelihoodFD( - [H1, L1], - prior=prior, - bounds=[prior.xmin, prior.xmax], - waveform=RippleIMRPhenomD(), - trigger_time=gps, - duration=duration, - post_trigger_duration=post_trigger_duration, - n_loops=300 -) - mass_matrix = jnp.eye(11) mass_matrix = mass_matrix.at[1, 1].set(1e-3) mass_matrix = mass_matrix.at[5, 5].set(1e-3) diff --git a/example/GW150914_heterodyne.py b/example/GW150914_heterodyne.py index c9d5df9f..08e091b6 100644 --- a/example/GW150914_heterodyne.py +++ b/example/GW150914_heterodyne.py @@ -49,8 +49,17 @@ "cos_iota": ("iota",lambda params: jnp.arccos(jnp.arcsin(jnp.sin(params['cos_iota']/2*jnp.pi))*2/jnp.pi)), "sin_dec": ("dec",lambda params: jnp.arcsin(jnp.arcsin(jnp.sin(params['sin_dec']/2*jnp.pi))*2/jnp.pi))} # sin and arcsin are periodize cos_iota and sin_dec ) -likelihood = TransientLikelihoodFD([H1, L1], waveform=RippleIMRPhenomD(), trigger_time=gps, duration=4, post_trigger_duration=2) +likelihood = HeterodynedTransientLikelihoodFD( + [H1, L1], + prior=prior, + bounds=[prior.xmin, prior.xmax], + waveform=RippleIMRPhenomD(), + trigger_time=gps, + duration=duration, + post_trigger_duration=post_trigger_duration, + n_loops=300 +) mass_matrix = jnp.eye(11) mass_matrix = mass_matrix.at[1, 1].set(1e-3)