Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update Run Manager #135

Merged
merged 7 commits into from
Aug 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
108 changes: 53 additions & 55 deletions example/Single_event_runManager.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@

import jax
import jax.numpy as jnp

Expand All @@ -12,57 +11,50 @@
mass_matrix = mass_matrix.at[5, 5].set(1e-3)
mass_matrix = mass_matrix * 3e-3
local_sampler_arg = {"step_size": mass_matrix}
bounds = jnp.array(
[
[10.0, 40.0],
[0.125, 1.0],
[-1.0, 1.0],
[-1.0, 1.0],
[0.0, 2000.0],
[-0.05, 0.05],
[0.0, 2 * jnp.pi],
[-1.0, 1.0],
[0.0, jnp.pi],
[0.0, 2 * jnp.pi],
[-1.0, 1.0],
]
)

run = SingleEventRun(
seed=0,
detectors=["H1", "L1"],
data_parameters={
"trigger_time": 1126259462.4,
"duration": 4,
"post_trigger_duration": 2,
"f_min": 20.0,
"f_max": 1024.0,
"tukey_alpha": 0.2,
"f_sampling": 4096.0,
},
priors={
"M_c": {"name": "Unconstrained_Uniform", "xmin": 10.0, "xmax": 80.0},
"q": {"name": "MassRatio"},
"s1_z": {"name": "Unconstrained_Uniform", "xmin": -1.0, "xmax": 1.0},
"s2_z": {"name": "Unconstrained_Uniform", "xmin": -1.0, "xmax": 1.0},
"d_L": {"name": "Unconstrained_Uniform", "xmin": 0.0, "xmax": 2000.0},
"t_c": {"name": "Unconstrained_Uniform", "xmin": -0.05, "xmax": 0.05},
"phase_c": {"name": "Unconstrained_Uniform", "xmin": 0.0, "xmax": 2 * jnp.pi},
"cos_iota": {"name": "CosIota"},
"psi": {"name": "Unconstrained_Uniform", "xmin": 0.0, "xmax": jnp.pi},
"ra": {"name": "Unconstrained_Uniform", "xmin": 0.0, "xmax": 2 * jnp.pi},
"sin_dec": {"name": "SinDec"},
"M_c": {"name": "UniformPrior", "xmin": 10.0, "xmax": 80.0},
"q": {"name": "UniformPrior", "xmin": 0.0, "xmax": 1.0},
"s1_z": {"name": "UniformPrior", "xmin": -1.0, "xmax": 1.0},
"s2_z": {"name": "UniformPrior", "xmin": -1.0, "xmax": 1.0},
"d_L": {"name": "UniformPrior", "xmin": 1.0, "xmax": 2000.0},
"t_c": {"name": "UniformPrior", "xmin": -0.05, "xmax": 0.05},
"phase_c": {"name": "UniformPrior", "xmin": 0.0, "xmax": 2 * jnp.pi},
"iota": {"name": "SinePrior"},
"psi": {"name": "UniformPrior", "xmin": 0.0, "xmax": jnp.pi},
"ra": {"name": "UniformPrior", "xmin": 0.0, "xmax": 2 * jnp.pi},
"dec": {"name": "CosinePrior"},
},
waveform_parameters={"name": "RippleIMRPhenomD", "f_ref": 20.0},
jim_parameters={
"n_loop_training": 10,
"n_loop_production": 10,
"n_local_steps": 15,
"n_global_steps": 15,
"n_chains": 500,
"n_epochs": 10,
"learning_rate": 0.001,
"n_max_examples": 45000,
"momentum": 0.9,
"batch_size": 50000,
"use_global": True,
"keep_quantile": 0.0,
"train_thinning": 1,
"output_thinning": 10,
"local_sampler_arg": local_sampler_arg,
},
likelihood_parameters={"name": "TransientLikelihoodFD", "bounds": bounds},
likelihood_parameters={"name": "TransientLikelihoodFD"},
sample_transforms=[
{"name": "BoundToUnbound", "name_mapping": [["M_c"], ["M_c_unbounded"]], "original_lower_bound": 10.0, "original_upper_bound": 80.0,},
{"name": "BoundToUnbound", "name_mapping": [["q"], ["q_unbounded"]], "original_lower_bound": 0.0, "original_upper_bound": 1.0,},
{"name": "BoundToUnbound", "name_mapping": [["s1_z"], ["s1_z_unbounded"]], "original_lower_bound": -1.0, "original_upper_bound": 1.0,},
{"name": "BoundToUnbound", "name_mapping": [["s2_z"], ["s2_z_unbounded"]], "original_lower_bound": -1.0, "original_upper_bound": 1.0,},
{"name": "BoundToUnbound", "name_mapping": [["d_L"], ["d_L_unbounded"]], "original_lower_bound": 1.0, "original_upper_bound": 2000.0,},
{"name": "BoundToUnbound", "name_mapping": [["t_c"], ["t_c_unbounded"]], "original_lower_bound": -0.05, "original_upper_bound": 0.05,},
{"name": "BoundToUnbound", "name_mapping": [["phase_c"], ["phase_c_unbounded"]], "original_lower_bound": 0.0, "original_upper_bound": 2 * jnp.pi,},
{"name": "BoundToUnbound", "name_mapping": [["iota"], ["iota_unbounded"]], "original_lower_bound": 0.0, "original_upper_bound": jnp.pi,},
{"name": "BoundToUnbound", "name_mapping": [["psi"], ["psi_unbounded"]], "original_lower_bound": 0.0, "original_upper_bound": jnp.pi,},
{"name": "BoundToUnbound", "name_mapping": [["ra"], ["ra_unbounded"]], "original_lower_bound": 0.0, "original_upper_bound": 2 * jnp.pi,},
{"name": "BoundToUnbound", "name_mapping": [["dec"], ["dec_unbounded"]], "original_lower_bound": 0.0, "original_upper_bound": jnp.pi,},
],
likelihood_transforms=[
{"name": "MassRatioToSymmetricMassRatioTransform", "name_mapping": [["q"], ["eta"]]},
],
injection=True,
injection_parameters={
"M_c": 28.6,
Expand All @@ -77,22 +69,28 @@
"ra": 1.2,
"dec": 0.3,
},
data_parameters={
"trigger_time": 1126259462.4,
"duration": 4,
"post_trigger_duration": 2,
"f_min": 20.0,
"f_max": 1024.0,
"tukey_alpha": 0.2,
"f_sampling": 4096.0,
jim_parameters={
"n_loop_training": 100,
"n_loop_production": 20,
"n_local_steps": 10,
"n_global_steps": 1000,
"n_chains": 500,
"n_epochs": 30,
"learning_rate": 1e-4,
"n_max_examples": 30000,
"momentum": 0.9,
"batch_size": 30000,
"use_global": True,
"train_thinning": 1,
"output_thinning": 10,
"local_sampler_arg": local_sampler_arg,
},
)

run_manager = SingleEventPERunManager(run=run)
run_manager.jim.sample(jax.random.PRNGKey(42))
run_manager.sample()

# plot the corner plot and diagnostic plot
run_manager.plot_corner()
run_manager.plot_diagnostic()
run_manager.save_summary()

2 changes: 1 addition & 1 deletion src/jimgw/single_event/detector.py
Original file line number Diff line number Diff line change
Expand Up @@ -373,7 +373,7 @@ def inject_signal(
h_sky: dict[str, Float[Array, " n_sample"]],
params: dict[str, Float],
psd_file: str = "",
) -> None:
) -> tuple[Float, Float]:
"""
Inject a signal into the detector data.

Expand Down
154 changes: 87 additions & 67 deletions src/jimgw/single_event/runManager.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,59 +12,22 @@
from jaxlib.xla_extension import ArrayImpl
from jaxtyping import Array, Float, PyTree

from jimgw import prior
from jimgw import prior, transforms
from jimgw.single_event import prior as single_event_prior
from jimgw.single_event import transforms as single_event_transforms
from jimgw.base import RunManager
from jimgw.jim import Jim
from jimgw.single_event.detector import Detector, detector_preset
from jimgw.single_event.likelihood import SingleEventLiklihood, likelihood_presets
from jimgw.single_event.waveform import Waveform, waveform_preset



def jaxarray_representer(dumper: yaml.Dumper, data: ArrayImpl):
return dumper.represent_list(data.tolist())


yaml.add_representer(ArrayImpl, jaxarray_representer) # type: ignore

prior_presets = {
"Unconstrained_Uniform": prior.Unconstrained_Uniform,
"Uniform": prior.Uniform,
"Sphere": prior.Sphere,
"AlignedSpin": prior.AlignedSpin,
"PowerLaw": prior.PowerLaw,
"Composite": prior.Composite,
"MassRatio": lambda **kwargs: prior.Uniform(
0.125,
1.0,
naming=["q"],
transforms={"q": ("eta", lambda params: params["q"] / (1 + params["q"]) ** 2)},
),
"CosIota": lambda **kwargs: prior.Uniform(
-1.0,
1.0,
naming=["cos_iota"],
transforms={
"cos_iota": (
"iota",
lambda params: jnp.arccos(params["cos_iota"]),
)
},
),
"SinDec": lambda **kwargs: prior.Uniform(
-1.0,
1.0,
naming=["sin_dec"],
transforms={
"sin_dec": (
"dec",
lambda params: jnp.arcsin(params["sin_dec"]),
)
},
),
"EarthFrame": prior.EarthFrame,
}


@dataclass
class SingleEventRun:
Expand All @@ -75,7 +38,7 @@ class SingleEventRun:
str, dict[str, Union[str, float, int, bool]]
] # Transform cannot be included in this way, add it to preset if used often.
jim_parameters: dict[str, Union[str, float, int, bool, dict]]
path: str = "./experiment"
path: str = "single_event_run"
injection_parameters: dict[str, float] = field(default_factory=lambda: {})
injection: bool = False
likelihood_parameters: dict[str, Union[str, float, int, bool, PyTree]] = field(
Expand All @@ -95,7 +58,12 @@ class SingleEventRun:
"f_sampling": 4096.0,
}
)

sample_transforms: list[dict[str, Union[str, float, int, bool]]] = field(
default_factory=lambda: []
)
likelihood_transforms: list[dict[str, Union[str, float, int, bool]]] = field(
default_factory=lambda: []
)


class SingleEventPERunManager(RunManager):
Expand Down Expand Up @@ -135,7 +103,14 @@ def __init__(self, **kwargs):

local_prior = self.initialize_prior()
local_likelihood = self.initialize_likelihood(local_prior)
self.jim = Jim(local_likelihood, local_prior, **self.run.jim_parameters)
sample_transforms, likelihood_transforms = self.initialize_transforms()
self.jim = Jim(
local_likelihood,
local_prior,
sample_transforms,
likelihood_transforms,
**self.run.jim_parameters,
)

def save(self, path: str):
output_dict = asdict(self.run)
Expand All @@ -149,7 +124,7 @@ def load_from_path(self, path: str) -> SingleEventRun:

### Initialization functions ###

def initialize_likelihood(self, prior: prior.Prior) -> SingleEventLiklihood:
def initialize_likelihood(self, prior: prior.CombinePrior) -> SingleEventLiklihood:
"""
Since prior contains information about types, naming and ranges of parameters,
some of the likelihood class require the prior to be initialized, such as the
Expand Down Expand Up @@ -192,11 +167,11 @@ def initialize_likelihood(self, prior: prior.Prior) -> SingleEventLiklihood:
key, subkey = jax.random.split(jax.random.PRNGKey(self.run.seed + 1901))
SNRs = []
for detector in detectors:
optimal_SNR,_ = detector.inject_signal(subkey, freqs, h_sky, detector_parameters) # type: ignore
optimal_SNR, _ = detector.inject_signal(subkey, freqs, h_sky, detector_parameters) # type: ignore
SNRs.append(optimal_SNR)
key, subkey = jax.random.split(key)
self.SNRs = SNRs

return likelihood_presets[name](
detectors,
waveform,
Expand All @@ -205,23 +180,67 @@ def initialize_likelihood(self, prior: prior.Prior) -> SingleEventLiklihood:
**self.run.data_parameters,
)

def initialize_prior(self) -> prior.Prior:
def initialize_prior(self) -> prior.CombinePrior:
priors = []
for name, parameters in self.run.priors.items():
if parameters["name"] not in prior_presets:
raise ValueError(f"Prior {name} not recognized.")
if parameters["name"] == "EarthFrame":
priors.append(
prior.EarthFrame(
gps=self.run.data_parameters["trigger_time"],
ifos=self.run.detectors,
assert isinstance(
parameters, dict
), "Prior parameters must be a dictionary."
assert "name" in parameters, "Prior name must be provided."
assert isinstance(parameters["name"], str), "Prior name must be a string."
try:
prior_class = getattr(single_event_prior, parameters["name"])
except AttributeError:
try:
prior_class = getattr(prior, parameters["name"])
except AttributeError:
raise ValueError(f"{parameters['name']} not recognized.")
parameters.pop("name")
priors.append(prior_class(parameter_names=[name], **parameters))
return prior.CombinePrior(priors)

def initialize_transforms(
self,
) -> tuple[list[transforms.BijectiveTransform], list[transforms.NtoMTransform]]:
sample_transforms = []
likelihood_transforms = []
if self.run.sample_transforms:
for transform in self.run.sample_transforms:
assert isinstance(transform, dict), "Transform must be a dictionary."
assert "name" in transform, "Transform name must be provided."
assert isinstance(
transform["name"], str
), "Transform name must be a string."
try:
transform_class = getattr(
single_event_transforms, transform["name"]
)
)
else:
priors.append(
prior_presets[parameters["name"]](naming=[name], **parameters)
)
return prior.Composite(priors)
except AttributeError:
try:
transform_class = getattr(transforms, transform["name"])
except AttributeError:
raise ValueError(f"{transform['name']} not recognized.")
transform.pop("name")
sample_transforms.append(transform_class(**transform))
if self.run.likelihood_transforms:
for transform in self.run.likelihood_transforms:
assert isinstance(transform, dict), "Transform must be a dictionary."
assert "name" in transform, "Transform name must be provided."
assert isinstance(
transform["name"], str
), "Transform name must be a string."
try:
transform_class = getattr(
single_event_transforms, transform["name"]
)
except AttributeError:
try:
transform_class = getattr(transforms, transform["name"])
except AttributeError:
raise ValueError(f"{transform['name']} not recognized.")
transform.pop("name")
likelihood_transforms.append(transform_class(**transform))
return sample_transforms, likelihood_transforms

def initialize_detector(self) -> list[Detector]:
"""
Expand Down Expand Up @@ -403,7 +422,7 @@ def plot_diagnostic(self, path: str = "diagnostic.jpeg", **kwargs):
"""
plot diagnostic plot of the samples.
"""
summary = self.jim.Sampler.get_sampler_state(training=True)
summary = self.jim.sampler.get_sampler_state(training=True)
chains, log_prob, local_accs, global_accs, loss_vals = summary.values()
log_prob = np.array(log_prob)

Expand Down Expand Up @@ -437,11 +456,12 @@ def plot_diagnostic(self, path: str = "diagnostic.jpeg", **kwargs):
plt.savefig(path)
plt.close()

def save_summary(self, path: str = "run_manager_summary.txt", **kwargs):
sys.stdout = open(path,'wt')
def save_summary(self, path: str = "", **kwargs):
if path == "":
path = self.run.path + "run_manager_summary.txt"
sys.stdout = open(path, "wt")
self.jim.print_summary()
#print(self.SNRs)
for detector, SNR in zip(self.detectors, self.SNRs):
print('SNR of detector ' + detector + ' is ' + str(SNR))
networkSNR = jnp.sum(jnp.array(self.SNRs)**2) ** (0.5)
print('network SNR is', networkSNR)
print("SNR of detector " + detector + " is " + str(SNR))
networkSNR = jnp.sum(jnp.array(self.SNRs) ** 2) ** (0.5)
print("network SNR is", networkSNR)
Loading
Loading