Skip to content

Commit

Permalink
Merge pull request #69 from tsunhopang/print_SNR
Browse files Browse the repository at this point in the history
Printing injected SNRs
  • Loading branch information
kazewong authored Feb 16, 2024
2 parents c477887 + c09aa1a commit 9f6299e
Show file tree
Hide file tree
Showing 4 changed files with 23 additions and 13 deletions.
10 changes: 6 additions & 4 deletions src/jimgw/prior.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,7 +220,6 @@ def log_prob(self, x: dict[str, Float]) -> Float:


class Sphere(Prior):

"""
A prior on a sphere represented by Cartesian coordinates.
Expand Down Expand Up @@ -267,7 +266,12 @@ def log_prob(self, x: dict[str, Float]) -> Float:
phi = x[self.naming[1]]
mag = x[self.naming[2]]
output = jnp.where(
(mag > 1) | (mag < 0) | (phi > 2* jnp.pi) | (phi < 0) | (theta > 1) | (theta < -1),
(mag > 1)
| (mag < 0)
| (phi > 2 * jnp.pi)
| (phi < 0)
| (theta > 1)
| (theta < -1),
jnp.zeros_like(0) - jnp.inf,
jnp.log(mag**2 * jnp.sin(x[self.naming[0]])),
)
Expand All @@ -276,7 +280,6 @@ def log_prob(self, x: dict[str, Float]) -> Float:

@jaxtyped
class AlignedSpin(Prior):

"""
Prior distribution for the aligned (z) component of the spin.
Expand Down Expand Up @@ -390,7 +393,6 @@ def log_prob(self, x: dict[str, Float]) -> Float:

@jaxtyped
class PowerLaw(Prior):

"""
A prior following the power-law with alpha in the range [xmin, xmax).
p(x) ~ x^{\alpha}
Expand Down
8 changes: 8 additions & 0 deletions src/jimgw/single_event/detector.py
Original file line number Diff line number Diff line change
Expand Up @@ -406,6 +406,14 @@ def inject_signal(
signal = self.fd_response(freqs, h_sky, params) * align_time
self.data = signal + noise_real + 1j * noise_imag

# also calculate the optimal SNR and match filter SNR
optimal_SNR = jnp.sqrt(jnp.sum(signal * signal.conj() / var).real)
match_filter_SNR = jnp.sum(self.data * signal.conj() / var) / optimal_SNR

print(f"For detector {self.name}:")
print(f"The injected optimal SNR is {optimal_SNR}")
print(f"The injected match filter SNR is {match_filter_SNR}")

@jaxtyped
def load_psd(
self, freqs: Float[Array, " n_sample"], psd_file: str = ""
Expand Down
15 changes: 7 additions & 8 deletions src/jimgw/single_event/runManager.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from typing import Union
from jimgw.base import RunManager
from dataclasses import dataclass, field, asdict
from jimgw.single_event.likelihood import likelihood_presets, SingleEventLiklihood
Expand Down Expand Up @@ -65,18 +66,18 @@ class SingleEventRun:

detectors: list[str]
priors: dict[
str, dict[str, str | float | int | bool]
str, dict[str, Union[str, float, int, bool]]
] # Transform cannot be included in this way, add it to preset if used often.
jim_parameters: dict[str, str | float | int | bool | dict]
jim_parameters: dict[str, Union[str, float, int, bool, dict]]
injection_parameters: dict[str, float]
injection: bool = False
likelihood_parameters: dict[str, str | float | int | bool | PyTree] = field(
likelihood_parameters: dict[str, Union[str, float, int, bool, PyTree]] = field(
default_factory=lambda: {"name": "TransientLikelihoodFD"}
)
waveform_parameters: dict[str, str | float | int | bool] = field(
waveform_parameters: dict[str, Union[str, float, int, bool]] = field(
default_factory=lambda: {"name": ""}
)
data_parameters: dict[str, float | int] = field(
data_parameters: dict[str, Union[float, int]] = field(
default_factory=lambda: {
"trigger_time": 0.0,
"duration": 0,
Expand Down Expand Up @@ -249,9 +250,7 @@ def initialize_waveform(self) -> Waveform:

### Utility functions ###

def get_detector_waveform(
self, params: dict[str, float]
) -> tuple[
def get_detector_waveform(self, params: dict[str, float]) -> tuple[
Float[Array, " n_sample"],
dict[str, Float[Array, " n_sample"]],
dict[str, Float[Array, " n_sample"]],
Expand Down
3 changes: 2 additions & 1 deletion src/jimgw/single_event/utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import jax.numpy as jnp
from jax.scipy.integrate import trapezoid
from jax import jit
from jaxtyping import Float, Array

Expand Down Expand Up @@ -34,7 +35,7 @@ def inner_product(
# psd_interp = jnp.interp(frequency, psd_frequency, psd)
df = frequency[1] - frequency[0]
integrand = jnp.conj(h1) * h2 / psd
return 4.0 * jnp.real(jnp.trapz(integrand, dx=df))
return 4.0 * jnp.real(trapezoid(integrand, dx=df))


@jit
Expand Down

0 comments on commit 9f6299e

Please sign in to comment.