diff --git a/src/jimgw/single_event/transforms.py b/src/jimgw/single_event/transforms.py index ad91a56a..16f84429 100644 --- a/src/jimgw/single_event/transforms.py +++ b/src/jimgw/single_event/transforms.py @@ -1,6 +1,6 @@ import jax.numpy as jnp from beartype import beartype as typechecker -from jaxtyping import Float, Array, jaxtyped +from jaxtyping import Float, Array, jaxtyped, Bool from astropy.time import Time from jimgw.single_event.detector import GroundBased2G @@ -227,9 +227,26 @@ def __init__( gps_time: Float, ifo: GroundBased2G, freq_ref: Float = None, + with_precession: Bool = False, ): name_mapping = [["phase_c"], ["phase_det"]] - conditional_names = ["ra", "dec", "psi", "iota"] + if with_precession: + conditional_names = [ + "ra", + "dec", + "psi", + "theta_jn", + "phi_jl", + "theta_1", + "theta_2", + "phi_12", + "a_1", + "a_2", + "q", + "M_c", + ] + else: + conditional_names = ["ra", "dec", "psi", "iota"] super().__init__(name_mapping, conditional_names) self.gmst = ( @@ -253,16 +270,13 @@ def __init__( and "phi_12" in conditional_names and "a_1" in conditional_names and "a_2" in conditional_names - and "q" in conditional_names and "M_c" in conditional_names and "q" in conditional_names ) ) ) - if "iota" in conditional_names: - self.get_iota = lambda x: x["iota"] - else: + if with_precession: self.get_iota = lambda x: spin_to_iota( x["theta_jn"], x["phi_jl"], @@ -276,6 +290,8 @@ def __init__( self.freq_ref, 0.0, ) + else: + self.get_iota = lambda x: x["iota"] @jnp.vectorize def _calc_R_det_arg(ra, dec, psi, iota, gmst): @@ -334,9 +350,26 @@ def __init__( dL_min: Float, dL_max: Float, freq_ref: Float = None, + with_precession: Bool = False, ): name_mapping = [["d_L"], ["d_hat_unbounded"]] - conditional_names = ["M_c", "ra", "dec", "psi", "iota"] + if with_precession: + conditional_names = [ + "M_c", + "ra", + "dec", + "psi", + "theta_jn", + "phi_jl", + "theta_1", + "theta_2", + "phi_12", + "a_1", + "a_2", + "q", + ] + else: + conditional_names = ["M_c", "ra", "dec", "psi", "iota"] super().__init__(name_mapping, conditional_names) self.gmst = ( @@ -368,9 +401,7 @@ def __init__( and "M_c" in conditional_names ) - if "iota" in conditional_names: - self.get_iota = lambda x: x["iota"] - else: + if with_precession: self.get_iota = lambda x: spin_to_iota( x["theta_jn"], x["phi_jl"], @@ -384,6 +415,8 @@ def __init__( self.freq_ref, 0.0, ) + else: + self.get_iota = lambda x: x["iota"] @jnp.vectorize def _calc_R_dets(ra, dec, psi, iota):