Skip to content

Commit

Permalink
Update arrival time transform
Browse files Browse the repository at this point in the history
  • Loading branch information
tsunhopang committed Aug 19, 2024
1 parent 6993dd9 commit ff65fcf
Showing 1 changed file with 42 additions and 16 deletions.
58 changes: 42 additions & 16 deletions src/jimgw/single_event/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,38 +196,64 @@ class GeocentricArrivalTimeToDetectorArrivalTimeTransform(

gmst: Float
ifo: GroundBased2G
tc_min: Float
tc_max: Float

def __init__(
self,
name_mapping: tuple[list[str], list[str]],
conditional_names: list[str],
gps_time: Float,
ifo: GroundBased2G,
tc_min: Float,
tc_max: Float,
):
super().__init__(name_mapping, conditional_names)

self.gmst = (
Time(gps_time, format="gps").sidereal_time("apparent", "greenwich").rad
)
self.ifo = ifo
self.tc_min = tc_min
self.tc_max = tc_max

assert "t_c" in name_mapping[0] and "t_det" in name_mapping[1]
assert "t_c" in name_mapping[0] and "t_det_unbounded" in name_mapping[1]
assert "ra" in conditional_names and "dec" in conditional_names

@jnp.vectorize
def time_delay(ra, dec, gmst):
return self.ifo.delay_from_geocenter(ra, dec, gmst)

def named_transform(x):
t_det = x["t_c"] + jnp.vectorize(self.ifo.delay_from_geocenter)(
x["ra"], x["dec"], self.gmst
)

time_shift = time_delay(x["ra"], x["dec"], self.gmst)

t_det = x["t_c"] + time_shift
t_det_min = self.tc_min + time_shift
t_det_max = self.tc_max + time_shift

y = (t_det - t_det_min) / (t_det_max - t_det_min)
t_det_unbounded = jnp.log(y / (1.0 - y))
return {
"t_det": t_det,
"t_det_unbounded": t_det_unbounded,
}

self.transform_func = named_transform

def named_inverse_transform(x):
t_c = x["t_det"] - jnp.vectorize(self.ifo.delay_from_geocenter)(

time_shift = jnp.vectorize(self.ifo.delay_from_geocenter)(
x["ra"], x["dec"], self.gmst
)

t_det_min = self.tc_min + time_shift
t_det_max = self.tc_max + time_shift
t_det = (t_det_max - t_det_min) / (
1.0 + jnp.exp(-x["t_det_unbounded"])
) + t_det_min

t_c = t_det - time_shift

return {
"t_c": t_c,
}
Expand Down Expand Up @@ -328,26 +354,26 @@ class DistanceToSNRWeightedDistanceTransform(ConditionalBijectiveTransform):

gmst: Float
ifos: list[GroundBased2G]
d_L_min: Float
d_L_max: Float
dL_min: Float
dL_max: Float

def __init__(
self,
name_mapping: tuple[list[str], list[str]],
conditional_names: list[str],
gps_time: Float,
ifos: list[GroundBased2G],
d_L_min: Float,
d_L_max: Float,
dL_min: Float,
dL_max: Float,
):
super().__init__(name_mapping, conditional_names)

self.gmst = (
Time(gps_time, format="gps").sidereal_time("apparent", "greenwich").rad
)
self.ifos = ifos
self.d_L_min = d_L_min
self.d_L_max = d_L_max
self.dL_min = dL_min
self.dL_max = dL_max

assert "d_L" in name_mapping[0] and "d_hat_unbounded" in name_mapping[1]
assert (
Expand Down Expand Up @@ -381,8 +407,8 @@ def named_transform(x):
scale_factor = 1.0 / jnp.power(M_c, 5.0 / 6.0) / R_dets
d_hat = scale_factor * d_L

d_hat_min = scale_factor * self.d_L_min
d_hat_max = scale_factor * self.d_L_max
d_hat_min = scale_factor * self.dL_min
d_hat_max = scale_factor * self.dL_max

y = (d_hat - d_hat_min) / (d_hat_max - d_hat_min)
d_hat_unbounded = jnp.log(y / (1.0 - y))
Expand All @@ -402,8 +428,8 @@ def named_inverse_transform(x):

scale_factor = 1.0 / jnp.power(M_c, 5.0 / 6.0) / R_dets

d_hat_min = scale_factor * self.d_L_min
d_hat_max = scale_factor * self.d_L_max
d_hat_min = scale_factor * self.dL_min
d_hat_max = scale_factor * self.dL_max

d_hat = (d_hat_max - d_hat_min) / (
1.0 + jnp.exp(-d_hat_unbounded)
Expand Down

0 comments on commit ff65fcf

Please sign in to comment.