Skip to content

Commit

Permalink
Merge pull request #122 from xuyuon/98-moving-naming-tracking-into-ji…
Browse files Browse the repository at this point in the history
…m-class-from-prior-class

Updated jim output
  • Loading branch information
kazewong authored Aug 2, 2024
2 parents 9d9c833 + 6374772 commit 87440db
Show file tree
Hide file tree
Showing 4 changed files with 382 additions and 68 deletions.
85 changes: 18 additions & 67 deletions src/jimgw/jim.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,7 @@ def negative_posterior(x: Float[Array, " n_dim"]):
best_fit = optimizer.get_result()[0]
return best_fit

def print_summary(self):
def print_summary(self, transform: bool = True):
"""
Generate summary of the run
Expand All @@ -144,59 +144,29 @@ def print_summary(self):
train_summary = self.sampler.get_sampler_state(training=True)
production_summary = self.sampler.get_sampler_state(training=False)

training_chain = train_summary["chains"].reshape(-1, len(self.parameter_names))
if self.sample_transforms:
# Need rewrite to vectorize
transformed_chain = {}
named_sample = self.add_name(training_chain[0])
for transform in self.sample_transforms:
named_sample = transform.backward(named_sample)
for key, value in named_sample.items():
transformed_chain[key] = [value]
for sample in training_chain[1:]:
named_sample = self.add_name(sample)
for transform in self.sample_transforms:
named_sample = transform.backward(named_sample)
for key, value in named_sample.items():
transformed_chain[key].append(value)
training_chain = transformed_chain
else:
training_chain = self.add_name(training_chain)
training_chain = train_summary["chains"].reshape(-1, self.prior.n_dim).T
training_chain = self.add_name(training_chain)
if transform:
for sample_transform in self.sample_transforms:
training_chain = sample_transform.backward(training_chain)
training_log_prob = train_summary["log_prob"]
training_local_acceptance = train_summary["local_accs"]
training_global_acceptance = train_summary["global_accs"]
training_loss = train_summary["loss_vals"]

production_chain = production_summary["chains"].reshape(
-1, len(self.parameter_names)
)
if self.sample_transforms:
# Need rewrite to vectorize
transformed_chain = {}
named_sample = self.add_name(production_chain[0])
for transform in self.sample_transforms:
named_sample = transform.backward(named_sample)
for key, value in named_sample.items():
transformed_chain[key] = [value]
for sample in production_chain[1:]:
named_sample = self.add_name(sample)
for transform in self.sample_transforms:
named_sample = transform.backward(named_sample)
for key, value in named_sample.items():
transformed_chain[key].append(value)
production_chain = transformed_chain
else:
production_chain = self.add_name(production_chain)
production_chain = production_summary["chains"].reshape(-1, self.prior.n_dim).T
production_chain = self.add_name(production_chain)
if transform:
for sample_transform in self.sample_transforms:
production_chain = sample_transform.backward(production_chain)
production_log_prob = production_summary["log_prob"]
production_local_acceptance = production_summary["local_accs"]
production_global_acceptance = production_summary["global_accs"]

print("Training summary")
print("=" * 10)
for key, value in training_chain.items():
print(
f"{key}: {jnp.array(value).mean():.3f} +/- {jnp.array(value).std():.3f}"
)
print(f"{key}: {value.mean():.3f} +/- {value.std():.3f}")
print(
f"Log probability: {training_log_prob.mean():.3f} +/- {training_log_prob.std():.3f}"
)
Expand All @@ -213,9 +183,7 @@ def print_summary(self):
print("Production summary")
print("=" * 10)
for key, value in production_chain.items():
print(
f"{key}: {jnp.array(value).mean():.3f} +/- {jnp.array(value).std():.3f}"
)
print(f"{key}: {value.mean():.3f} +/- {value.std():.3f}")
print(
f"Log probability: {production_log_prob.mean():.3f} +/- {production_log_prob.std():.3f}"
)
Expand Down Expand Up @@ -246,28 +214,11 @@ def get_samples(self, training: bool = False) -> dict:
else:
chains = self.sampler.get_sampler_state(training=False)["chains"]

# Need rewrite to output chains instead of flattened samples and vectorize
chains = chains.reshape(-1, len(self.parameter_names))
if self.sample_transforms:
transformed_chain = {}
named_sample = self.add_name(chains[0])
for transform in self.sample_transforms:
named_sample = transform.backward(named_sample)
for key, value in named_sample.items():
transformed_chain[key] = [value]
for sample in chains[1:]:
named_sample = self.add_name(sample)
for transform in self.sample_transforms:
named_sample = transform.backward(named_sample)
for key, value in named_sample.items():
transformed_chain[key].append(value)
output = transformed_chain
else:
output = self.add_name(chains)

for key in output.keys():
output[key] = jnp.array(output[key])
return output
chains = chains.transpose(2, 0, 1)
chains = self.add_name(chains)
for sample_transform in self.sample_transforms:
chains = sample_transform.backward(chains)
return chains

def plot(self):
pass
64 changes: 63 additions & 1 deletion src/jimgw/single_event/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from astropy.time import Time

from jimgw.single_event.detector import GroundBased2G
from jimgw.transforms import BijectiveTransform
from jimgw.transforms import BijectiveTransform, NtoNTransform
from jimgw.single_event.utils import (
m1_m2_to_Mc_q,
Mc_q_to_m1_m2,
Expand All @@ -15,6 +15,7 @@
ra_dec_to_zenith_azimuth,
zenith_azimuth_to_ra_dec,
euler_rotation,
spin_to_cartesian_spin,
)


Expand Down Expand Up @@ -167,3 +168,64 @@ def named_inverse_transform(x):
return {"ra": ra, "dec": dec}

self.inverse_transform_func = named_inverse_transform


@jaxtyped(typechecker=typechecker)
class SpinToCartesianSpinTransform(NtoNTransform):
"""
Spin to Cartesian spin transformation
"""

freq_ref: Float

def __init__(
self,
name_mapping: tuple[list[str], list[str]],
freq_ref: Float,
):
super().__init__(name_mapping)

self.freq_ref = freq_ref

assert (
"theta_jn" in name_mapping[0]
and "phi_jl" in name_mapping[0]
and "theta_1" in name_mapping[0]
and "theta_2" in name_mapping[0]
and "phi_12" in name_mapping[0]
and "a_1" in name_mapping[0]
and "a_2" in name_mapping[0]
and "iota" in name_mapping[1]
and "s1_x" in name_mapping[1]
and "s1_y" in name_mapping[1]
and "s1_z" in name_mapping[1]
and "s2_x" in name_mapping[1]
and "s2_y" in name_mapping[1]
and "s2_z" in name_mapping[1]
)

def named_transform(x):
iota, s1x, s1y, s1z, s2x, s2y, s2z = spin_to_cartesian_spin(
x["theta_jn"],
x["phi_jl"],
x["theta_1"],
x["theta_2"],
x["phi_12"],
x["a_1"],
x["a_2"],
x["M_c"],
x["q"],
self.freq_ref,
x["phase_c"],
)
return {
"iota": iota,
"s1_x": s1x,
"s1_y": s1y,
"s1_z": s1z,
"s2_x": s2x,
"s2_y": s2y,
"s2_z": s2z,
}

self.transform_func = named_transform
160 changes: 160 additions & 0 deletions src/jimgw/single_event/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
from jax.scipy.integrate import trapezoid
from jaxtyping import Array, Float

from jimgw.constants import Msun


def inner_product(
h1: Float[Array, " n_sample"],
Expand Down Expand Up @@ -391,6 +393,164 @@ def theta_phi_to_ra_dec(theta: Float, phi: Float, gmst: Float) -> tuple[Float, F
return ra, dec


def spin_to_cartesian_spin(
thetaJN: Float,
phiJL: Float,
theta1: Float,
theta2: Float,
phi12: Float,
chi1: Float,
chi2: Float,
M_c: Float,
q: Float,
fRef: Float,
phiRef: Float,
) -> tuple[Float, Float, Float, Float, Float, Float, Float]:
"""
Transforming the spin parameters
The code is based on the approach used in LALsimulation:
https://lscsoft.docs.ligo.org/lalsuite/lalsimulation/group__lalsimulation__inference.html
Parameters:
-------
thetaJN: Float
Zenith angle between the total angular momentum and the line of sight
phiJL: Float
Difference between total and orbital angular momentum azimuthal angles
theta1: Float
Zenith angle between the spin and orbital angular momenta for the primary object
theta2: Float
Zenith angle between the spin and orbital angular momenta for the secondary object
phi12: Float
Difference between the azimuthal angles of the individual spin vector projections
onto the orbital plane
chi1: Float
Primary object aligned spin:
chi2: Float
Secondary object aligned spin:
M_c: Float
The chirp mass
eta: Float
The symmetric mass ratio
fRef: Float
The reference frequency
phiRef: Float
Binary phase at a reference frequency
Returns:
-------
iota: Float
Zenith angle between the orbital angular momentum and the line of sight
S1x: Float
The x-component of the primary spin
S1y: Float
The y-component of the primary spin
S1z: Float
The z-component of the primary spin
S2x: Float
The x-component of the secondary spin
S2y: Float
The y-component of the secondary spin
S2z: Float
The z-component of the secondary spin
"""

def rotate_y(angle, vec):
"""
Rotate the vector (x, y, z) about y-axis
"""
cos_angle = jnp.cos(angle)
sin_angle = jnp.sin(angle)
rotation_matrix = jnp.array(
[[cos_angle, 0, sin_angle], [0, 1, 0], [-sin_angle, 0, cos_angle]]
)
rotated_vec = jnp.dot(rotation_matrix, vec)
return rotated_vec

def rotate_z(angle, vec):
"""
Rotate the vector (x, y, z) about z-axis
"""
cos_angle = jnp.cos(angle)
sin_angle = jnp.sin(angle)
rotation_matrix = jnp.array(
[[cos_angle, -sin_angle, 0], [sin_angle, cos_angle, 0], [0, 0, 1]]
)
rotated_vec = jnp.dot(rotation_matrix, vec)
return rotated_vec

LNh = jnp.array([0.0, 0.0, 1.0])

s1hat = jnp.array(
[
jnp.sin(theta1) * jnp.cos(phiRef),
jnp.sin(theta1) * jnp.sin(phiRef),
jnp.cos(theta1),
]
)
s2hat = jnp.array(
[
jnp.sin(theta2) * jnp.cos(phi12 + phiRef),
jnp.sin(theta2) * jnp.sin(phi12 + phiRef),
jnp.cos(theta2),
]
)

m1, m2 = Mc_q_to_m1_m2(M_c, q)
eta = q / (1 + q) ** 2
v0 = jnp.cbrt((m1 + m2) * Msun * jnp.pi * fRef)

Lmag = ((m1 + m2) * (m1 + m2) * eta / v0) * (1.0 + v0 * v0 * (1.5 + eta / 6.0))
s1 = m1 * m1 * chi1 * s1hat
s2 = m2 * m2 * chi2 * s2hat
J = s1 + s2 + jnp.array([0.0, 0.0, Lmag])

Jhat = J / jnp.linalg.norm(J)
theta0 = jnp.arccos(Jhat[2])
phi0 = jnp.arctan2(Jhat[1], Jhat[0])

# Rotation 1:
s1hat = rotate_z(-phi0, s1hat)
s2hat = rotate_z(-phi0, s2hat)

# Rotation 2:
LNh = rotate_y(-theta0, LNh)
s1hat = rotate_y(-theta0, s1hat)
s2hat = rotate_y(-theta0, s2hat)

# Rotation 3:
LNh = rotate_z(phiJL - jnp.pi, LNh)
s1hat = rotate_z(phiJL - jnp.pi, s1hat)
s2hat = rotate_z(phiJL - jnp.pi, s2hat)

# Compute iota
N = jnp.array([0.0, jnp.sin(thetaJN), jnp.cos(thetaJN)])
iota = jnp.arccos(jnp.dot(N, LNh))

thetaLJ = jnp.arccos(LNh[2])
phiL = jnp.arctan2(LNh[1], LNh[0])

# Rotation 4:
s1hat = rotate_z(-phiL, s1hat)
s2hat = rotate_z(-phiL, s2hat)
N = rotate_z(-phiL, N)

# Rotation 5:
s1hat = rotate_y(-thetaLJ, s1hat)
s2hat = rotate_y(-thetaLJ, s2hat)
N = rotate_y(-thetaLJ, N)

# Rotation 6:
phiN = jnp.arctan2(N[1], N[0])
s1hat = rotate_z(jnp.pi / 2.0 - phiN - phiRef, s1hat)
s2hat = rotate_z(jnp.pi / 2.0 - phiN - phiRef, s2hat)

S1 = s1hat * chi1
S2 = s2hat * chi2
return iota, S1[0], S1[1], S1[2], S2[0], S2[1], S2[2]


def zenith_azimuth_to_ra_dec(
zenith: Float, azimuth: Float, gmst: Float, rotation: Float[Array, " 3 3"]
) -> tuple[Float, Float]:
Expand Down
Loading

0 comments on commit 87440db

Please sign in to comment.