Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Sync #147

Merged
merged 49 commits into from
Sep 10, 2024
Merged

Sync #147

Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
49 commits
Select commit Hold shift + click to select a range
71f8e95
Update utils.py
xuyuon Jul 9, 2024
722cff0
Update utils.py
xuyuon Jul 24, 2024
3a9cea2
Merge branch 'kazewong:main' into add-spin-transform
xuyuon Jul 24, 2024
dcd4662
Update utils.py
xuyuon Jul 24, 2024
3a92fd6
Update utils.py
xuyuon Jul 24, 2024
b1be707
Update utils.py
xuyuon Jul 24, 2024
f4cfd7f
Update utils.py
xuyuon Jul 24, 2024
f7e287f
Updated utils.py
Jul 24, 2024
4f3be70
Updated runManager.py
xuyuon Jul 24, 2024
6b09d72
updated SingleEventRun
xuyuon Jul 24, 2024
79b07fc
Update the runManager
xuyuon Jul 24, 2024
9e282b7
Updated runManager.py
xuyuon Jul 24, 2024
1828243
Create verify_transform.py
thomasckng Jul 24, 2024
c66c872
Updated runManager.py
xuyuon Jul 24, 2024
c3446b7
Add test for spin transform
thomasckng Jul 24, 2024
0c9028e
Finish spin transform test
thomasckng Jul 24, 2024
374f248
Add likelihood name check
thomasckng Jul 24, 2024
0c49b23
Rename test file
thomasckng Jul 24, 2024
d866fae
Updated runManager.py
xuyuon Jul 24, 2024
1ca2421
Updated runManager.py and Single_event_runManager.py
xuyuon Jul 25, 2024
227fc98
Updated runManager.py
xuyuon Jul 25, 2024
27998ca
Reformatted files
xuyuon Jul 25, 2024
94e5dfa
Updated runManager.py
xuyuon Jul 25, 2024
4267d97
Changed corner plot default setting
xuyuon Jul 25, 2024
93b6dc9
Merge pull request #107 from xuyuon/run-manager
kazewong Jul 26, 2024
dd27330
Updated Test files
xuyuon Jul 26, 2024
a46af6b
Merge pull request #88 from xuyuon/add-spin-transform
kazewong Jul 26, 2024
450772f
Added code to create and reverse transforms for more flexibility
ThibeauWouters Aug 17, 2024
83b0d14
Make sure 1d transforms are created correctly
ThibeauWouters Aug 18, 2024
e803823
Merge pull request #6 from kazewong/98-moving-naming-tracking-into-ji…
xuyuon Aug 21, 2024
c6605f6
Update transforms.py
xuyuon Aug 21, 2024
59dd222
Update runManager.py
xuyuon Aug 21, 2024
702ee20
Update runManager.py
xuyuon Aug 21, 2024
7910785
Merge pull request #137 from xuyuon/fix-jacobian
kazewong Aug 22, 2024
b621dc9
Making the tests lightweight again
ThibeauWouters Aug 22, 2024
c66c65c
Reverted some src code and updated tests accordingly
ThibeauWouters Aug 22, 2024
f0bca2f
Formatting
ThibeauWouters Aug 22, 2024
1d4bb21
Updated the tests to the new transforms
ThibeauWouters Aug 27, 2024
95bc462
New transforms
ThibeauWouters Aug 27, 2024
b3898e1
precommit
ThibeauWouters Aug 27, 2024
45ebe76
Delete test/integration/corner.jpeg
ThibeauWouters Aug 28, 2024
6f31acd
Delete test/integration/diagnostic.jpeg
ThibeauWouters Aug 28, 2024
98e9c0b
Delete test/integration/single_event_runrun_manager_summary.txt
ThibeauWouters Aug 28, 2024
ad90d09
Further consolidation as requested
ThibeauWouters Sep 2, 2024
0e96439
Merge pull request #132 from ThibeauWouters/98-moving-naming-tracking…
kazewong Sep 2, 2024
b003785
Merge branch 'jim-dev' of github.com:kazewong/JaxGW into 98-moving-na…
kazewong Sep 2, 2024
7ed4c54
Move transform test into unit. Note that the current one will fail be…
kazewong Sep 2, 2024
93ea485
Putting transform test in backburner
kazewong Sep 2, 2024
d2c0416
Merge pull request #108 from kazewong/98-moving-naming-tracking-into-…
kazewong Sep 2, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 14 additions & 9 deletions example/Single_event_runManager.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,21 +28,20 @@
]
)


run = SingleEventRun(
seed=0,
detectors=["H1", "L1"],
priors={
"M_c": {"name": "Uniform", "xmin": 10.0, "xmax": 80.0},
"M_c": {"name": "Unconstrained_Uniform", "xmin": 10.0, "xmax": 80.0},
"q": {"name": "MassRatio"},
"s1_z": {"name": "Uniform", "xmin": -1.0, "xmax": 1.0},
"s2_z": {"name": "Uniform", "xmin": -1.0, "xmax": 1.0},
"d_L": {"name": "Uniform", "xmin": 0.0, "xmax": 2000.0},
"t_c": {"name": "Uniform", "xmin": -0.05, "xmax": 0.05},
"phase_c": {"name": "Uniform", "xmin": 0.0, "xmax": 2 * jnp.pi},
"s1_z": {"name": "Unconstrained_Uniform", "xmin": -1.0, "xmax": 1.0},
"s2_z": {"name": "Unconstrained_Uniform", "xmin": -1.0, "xmax": 1.0},
"d_L": {"name": "Unconstrained_Uniform", "xmin": 0.0, "xmax": 2000.0},
"t_c": {"name": "Unconstrained_Uniform", "xmin": -0.05, "xmax": 0.05},
"phase_c": {"name": "Unconstrained_Uniform", "xmin": 0.0, "xmax": 2 * jnp.pi},
"cos_iota": {"name": "CosIota"},
"psi": {"name": "Uniform", "xmin": 0.0, "xmax": jnp.pi},
"ra": {"name": "Uniform", "xmin": 0.0, "xmax": 2 * jnp.pi},
"psi": {"name": "Unconstrained_Uniform", "xmin": 0.0, "xmax": jnp.pi},
"ra": {"name": "Unconstrained_Uniform", "xmin": 0.0, "xmax": 2 * jnp.pi},
"sin_dec": {"name": "SinDec"},
},
waveform_parameters={"name": "RippleIMRPhenomD", "f_ref": 20.0},
Expand Down Expand Up @@ -90,3 +89,9 @@
)

run_manager = SingleEventPERunManager(run=run)
run_manager.jim.sample(jax.random.PRNGKey(42))

# plot the corner plot and diagnostic plot
run_manager.plot_corner()
run_manager.plot_diagnostic()

79 changes: 78 additions & 1 deletion src/jimgw/single_event/runManager.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import corner
import numpy as np
import yaml
from astropy.time import Time
from jaxlib.xla_extension import ArrayImpl
Expand Down Expand Up @@ -71,7 +73,8 @@ class SingleEventRun:
str, dict[str, Union[str, float, int, bool]]
] # Transform cannot be included in this way, add it to preset if used often.
jim_parameters: dict[str, Union[str, float, int, bool, dict]]
injection_parameters: dict[str, float]
path: str = "./experiment"
injection_parameters: dict[str, float] = field(default_factory=lambda: {})
injection: bool = False
likelihood_parameters: dict[str, Union[str, float, int, bool, PyTree]] = field(
default_factory=lambda: {"name": "TransientLikelihoodFD"}
Expand Down Expand Up @@ -123,6 +126,9 @@ def __init__(self, **kwargs):
print("Neither run instance nor path provided.")
raise ValueError

if self.run.injection and not self.run.injection_parameters:
raise ValueError("Injection mode requires injection parameters.")

local_prior = self.initialize_prior()
local_likelihood = self.initialize_likelihood(local_prior)
self.jim = Jim(local_likelihood, local_prior, **self.run.jim_parameters)
Expand Down Expand Up @@ -150,6 +156,7 @@ def initialize_likelihood(self, prior: prior.Prior) -> SingleEventLiklihood:
waveform = self.initialize_waveform()
name = self.run.likelihood_parameters["name"]
assert isinstance(name, str), "Likelihood name must be a string."
assert name in likelihood_presets, f"Likelihood {name} not recognized."
if self.run.injection:
freqs = jnp.linspace(
self.run.data_parameters["f_min"],
Expand Down Expand Up @@ -351,3 +358,73 @@ def plot_data(self, path: str):
plt.ylabel("Amplitude")
plt.legend()
plt.savefig(path)

def sample(self):
self.jim.sample(jax.random.PRNGKey(self.run.seed))

def get_samples(self):
return self.jim.get_samples()

def plot_corner(self, path: str = "corner.jpeg", **kwargs):
"""
plot corner plot of the samples.
"""
plot_datapoint = kwargs.get("plot_datapoints", False)
title_quantiles = kwargs.get("title_quantiles", [0.16, 0.5, 0.84])
show_titles = kwargs.get("show_titles", True)
title_fmt = kwargs.get("title_fmt", ".2E")
use_math_text = kwargs.get("use_math_text", True)

samples = self.jim.get_samples()
param_names = list(samples.keys())
samples = np.array(list(samples.values())).reshape(int(len(param_names)), -1).T
corner.corner(
samples,
labels=param_names,
plot_datapoints=plot_datapoint,
title_quantiles=title_quantiles,
show_titles=show_titles,
title_fmt=title_fmt,
use_math_text=use_math_text,
**kwargs,
)
plt.savefig(path)
plt.close()

def plot_diagnostic(self, path: str = "diagnostic.jpeg", **kwargs):
"""
plot diagnostic plot of the samples.
"""
summary = self.jim.Sampler.get_sampler_state(training=True)
chains, log_prob, local_accs, global_accs, loss_vals = summary.values()
log_prob = np.array(log_prob)

plt.figure(figsize=(10, 10))
axs = [plt.subplot(2, 2, i + 1) for i in range(4)]
plt.sca(axs[0])
plt.title("log probability")
plt.plot(log_prob.mean(0))
plt.xlabel("iteration")
plt.xlim(0, None)

plt.sca(axs[1])
plt.title("NF loss")
plt.plot(loss_vals.reshape(-1))
plt.xlabel("iteration")
plt.xlim(0, None)

plt.sca(axs[2])
plt.title("Local Acceptance")
plt.plot(local_accs.mean(0))
plt.xlabel("iteration")
plt.xlim(0, None)

plt.sca(axs[3])
plt.title("Global Acceptance")
plt.plot(global_accs.mean(0))
plt.xlabel("iteration")
plt.xlim(0, None)
plt.tight_layout()

plt.savefig(path)
plt.close()
Loading
Loading