Skip to content

Commit

Permalink
refactored jim for plotting
Browse files Browse the repository at this point in the history
  • Loading branch information
ThibeauWouters committed Feb 3, 2024
1 parent aaf25fc commit e15e705
Show file tree
Hide file tree
Showing 3 changed files with 26 additions and 11 deletions.
5 changes: 5 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -140,3 +140,8 @@ H1.txt
L1.txt
V1.txt
test_data

*.png
*.npz
*.pdf
*.txt
4 changes: 4 additions & 0 deletions example/GW150914.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,3 +123,7 @@
)

jim.sample(jax.random.PRNGKey(42))

jim.print_summary()
jim.Sampler.plot_summary("training")
jim.Sampler.plot_summary("production")
28 changes: 17 additions & 11 deletions src/jimgw/jim.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,12 @@
from flowMC.nfmodel.rqSpline import MaskedCouplingRQSpline
from flowMC.utils.PRNG_keys import initialize_rng_keys
from flowMC.utils.EvolutionaryOptimizer import EvolutionaryOptimizer
from flowMC.sampler.flowHMC import flowHMC
# from flowMC.sampler.flowHMC import flowHMC

from jimgw.prior import Prior
from jimgw.base import LikelihoodBase
from jimgw.utils.hyperparameters import jim_default_hyperparameters



class Jim(object):
Expand All @@ -23,14 +25,18 @@ def __init__(self, likelihood: LikelihoodBase, prior: Prior, **kwargs):
self.Likelihood = likelihood
self.Prior = prior

seed = kwargs.get("seed", 0)
n_chains = kwargs.get("n_chains", 20)

rng_key_set = initialize_rng_keys(n_chains, seed=seed)
num_layers = kwargs.get("num_layers", 10)
hidden_size = kwargs.get("hidden_size", [128, 128])
num_bins = kwargs.get("num_bins", 8)

# Set and override any given hyperparameters, and save as attribute
self.hyperparameters = jim_default_hyperparameters
hyperparameter_names = list(self.hyperparameters.keys())

for key, value in kwargs.items():
if key in hyperparameter_names:
self.hyperparameters[key] = value

for key, value in self.hyperparameters.items():
setattr(self, key, value)

self.rng_key_set = initialize_rng_keys(self.hyperparameters["n_chains"], seed=self.hyperparameters["seed"])
local_sampler_arg = kwargs.get("local_sampler_arg", {})

local_sampler = MALA(
Expand All @@ -39,7 +45,7 @@ def __init__(self, likelihood: LikelihoodBase, prior: Prior, **kwargs):

flowHMC_params = kwargs.get("flowHMC_params", {})
model = MaskedCouplingRQSpline(
self.Prior.n_dim, num_layers, hidden_size, num_bins, rng_key_set[-1]
self.Prior.n_dim, self.num_layers, self.hidden_size, self.num_bins, self.rng_key_set[-1]
)
if len(flowHMC_params) > 0:
global_sampler = flowHMC(
Expand All @@ -57,7 +63,7 @@ def __init__(self, likelihood: LikelihoodBase, prior: Prior, **kwargs):

self.Sampler = Sampler(
self.Prior.n_dim,
rng_key_set,
self.rng_key_set,
None, # type: ignore
local_sampler,
model,
Expand Down

0 comments on commit e15e705

Please sign in to comment.