From 27005b7a384d2221167ad03eed980e5aa5147815 Mon Sep 17 00:00:00 2001 From: Kaze Wong Date: Thu, 9 Nov 2023 16:34:04 -0500 Subject: [PATCH 01/23] Adding flowHMC option --- example/GW150914_PV2_newglobal.py | 80 +++++++++++++++++++++++++++++++ src/jimgw/jim.py | 12 +++++ 2 files changed, 92 insertions(+) create mode 100644 example/GW150914_PV2_newglobal.py diff --git a/example/GW150914_PV2_newglobal.py b/example/GW150914_PV2_newglobal.py new file mode 100644 index 00000000..2456e767 --- /dev/null +++ b/example/GW150914_PV2_newglobal.py @@ -0,0 +1,80 @@ +import time +from jimgw.jim import Jim +from jimgw.detector import H1, L1 +from jimgw.likelihood import HeterodynedTransientLikelihoodFD, TransientLikelihoodFD +from jimgw.waveform import RippleIMRPhenomD, RippleIMRPhenomPv2 +from jimgw.prior import Uniform +import jax.numpy as jnp +import jax + + +jax.config.update("jax_enable_x64", True) + +########################################### +########## First we grab data ############# +########################################### + +total_time_start = time.time() + +# first, fetch a 4s segment centered on GW150914 +gps = 1126259462.4 +start = gps - 2 +end = gps + 2 +fmin = 20.0 +fmax = 1024.0 + +ifos = ["H1", "L1"] + +H1.load_data(gps, 2, 2, fmin, fmax, psd_pad=16, tukey_alpha=0.2) +L1.load_data(gps, 2, 2, fmin, fmax, psd_pad=16, tukey_alpha=0.2) + +waveform = RippleIMRPhenomPv2(f_ref=20) +prior = Uniform( + xmin = [10, 0.125, 0, 0, 0, 0, 0, 0, 0., -0.05, 0., -1, 0., 0.,-1.], + xmax = [80., 1., jnp.pi, 2*jnp.pi, 1., jnp.pi, 2*jnp.pi, 1., 2000., 0.05, 2*jnp.pi, 1., jnp.pi, 2*jnp.pi, 1.], + naming = ["M_c", "q", "s1_theta", "s1_phi", "s1_mag", "s2_theta", "s2_phi", "s2_mag", "d_L", "t_c", "phase_c", "cos_iota", "psi", "ra", "sin_dec"], + transforms = {"q": ("eta", lambda params: params['q']/(1+params['q'])**2), + "s1_theta": ("s1_x", lambda params: jnp.sin(params['s1_theta'])*jnp.cos(params['s1_phi'])*params['s1_mag']), + "s1_phi": ("s1_y", lambda params: jnp.sin(params['s1_theta'])*jnp.sin(params['s1_phi'])*params['s1_mag']), + "s1_mag": ("s1_z", lambda params: jnp.cos(params['s1_theta'])*params['s1_mag']), + "s2_theta": ("s2_x", lambda params: jnp.sin(params['s2_theta'])*jnp.cos(params['s2_phi'])*params['s2_mag']), + "s2_phi": ("s2_y", lambda params: jnp.sin(params['s2_theta'])*jnp.sin(params['s2_phi'])*params['s2_mag']), + "s2_mag": ("s2_z", lambda params: jnp.cos(params['s2_theta'])*params['s2_mag']), + "cos_iota": ("iota",lambda params: jnp.arccos(jnp.arcsin(jnp.sin(params['cos_iota']/2*jnp.pi))*2/jnp.pi)), + "sin_dec": ("dec",lambda params: jnp.arcsin(jnp.arcsin(jnp.sin(params['sin_dec']/2*jnp.pi))*2/jnp.pi))} # sin and arcsin are periodize cos_iota and sin_dec +) +likelihood = TransientLikelihoodFD([H1, L1], waveform=waveform, trigger_time=gps, duration=4, post_trigger_duration=2) + + +mass_matrix = jnp.eye(prior.n_dim) +mass_matrix = mass_matrix.at[1, 1].set(1e-3) +mass_matrix = mass_matrix.at[9, 9].set(1e-3) +local_sampler_arg = {"step_size": mass_matrix * 3e-3} + + +jim = Jim( + likelihood, + prior, + n_loop_training=10, + n_loop_production=10, + n_local_steps=300, + n_global_steps=300, + n_chains=500, + n_epochs=300, + learning_rate=0.001, + max_samples = 60000, + momentum=0.9, + batch_size=30000, + use_global=True, + keep_quantile=0., + train_thinning=1, + output_thinning=30, + local_sampler_arg=local_sampler_arg, + num_layers = 6, + hidden_size = [32,32], + num_bins = 8 +) + +jim.maximize_likelihood([prior.xmin, prior.xmax]) +# initial_guess = 
jnp.array(jnp.load('initial.npz')['chain']) +jim.sample(jax.random.PRNGKey(42)) diff --git a/src/jimgw/jim.py b/src/jimgw/jim.py index 12aa89c1..203065b7 100644 --- a/src/jimgw/jim.py +++ b/src/jimgw/jim.py @@ -8,6 +8,7 @@ from jaxtyping import Array import jax import jax.numpy as jnp +from flowMC.sampler.flowHMC import flowHMC class Jim(object): """ @@ -31,12 +32,23 @@ def __init__(self, likelihood: LikelihoodBase, prior: Prior, **kwargs): local_sampler = MALA(self.posterior, True, local_sampler_arg) # Remember to add routine to find automated mass matrix model = MaskedCouplingRQSpline(self.Prior.n_dim, num_layers, hidden_size, num_bins, rng_key_set[-1]) + flowHMC_sampler = flowHMC( + self.posterior, + True, + model, + params={ + "step_size": 1e-4, + "n_leapfrog": 5, + "inverse_metric": jnp.ones(prior.n_dim), + }, + ) self.Sampler = Sampler( self.Prior.n_dim, rng_key_set, None, local_sampler, model, + global_sampler = flowHMC_sampler, **kwargs) From 2de19cff0d15a96a0b3542b8d75c40daa4f44898 Mon Sep 17 00:00:00 2001 From: Kaze Wong Date: Fri, 10 Nov 2023 10:20:15 -0500 Subject: [PATCH 02/23] Unconstrained uniform should work now --- src/jimgw/prior.py | 45 ++++++++++++++++++++++++++++++++++++++++++--- 1 file changed, 42 insertions(+), 3 deletions(-) diff --git a/src/jimgw/prior.py b/src/jimgw/prior.py index 4baf4298..ed595861 100644 --- a/src/jimgw/prior.py +++ b/src/jimgw/prior.py @@ -81,8 +81,8 @@ def add_name(self, x: Array, transform_name: bool = False, transform_value: bool class Uniform(Prior): - xmin: Array - xmax: Array + xmin: Union[float,Array] = 0. + xmax: Union[float,Array] = 1. def __init__(self, xmin: Union[float,Array], xmax: Union[float,Array], **kwargs): super().__init__(kwargs.get("naming"), kwargs.get("transforms")) @@ -111,4 +111,43 @@ def sample(self, rng_key: jax.random.PRNGKey, n_samples: int) -> Array: def log_prob(self, x: Array) -> Float: output = jnp.sum(jnp.where((x>=self.xmax) | (x<=self.xmin), jnp.zeros_like(x)-jnp.inf, jnp.zeros_like(x))) - return output + jnp.sum(jnp.log(1./(self.xmax-self.xmin))) + return output + jnp.sum(jnp.log(1./(self.xmax-self.xmin))) + +class Unconstrained_Uniform(Prior): + + xmin: float = 0. + xmax: float = 1. + + def __init__(self, xmin: float, xmax: float, **kwargs): + super().__init__(kwargs.get("naming"), kwargs.get("transforms")) + assert isinstance(xmin, float), "xmin must be a float" + assert isinstance(xmax, float), "xmax must be a float" + assert self.n_dim == 1, "Unconstrained_Uniform only works for 1D distributions" + self.xmax = xmax + self.xmin = xmin + self.transforms = {"y": ("x", lambda param: (self.xmax - self.xmin)/(1+jnp.exp(-param['x']))+self.xmin)} + + def sample(self, rng_key: jax.random.PRNGKey, n_samples: int) -> Array: + """ + Sample from a uniform distribution. + + Parameters + ---------- + rng_key : jax.random.PRNGKey + A random key to use for sampling. + n_samples : int + The number of samples to draw. + + Returns + ------- + samples : Array + An array of shape (n_samples, n_dim) containing the samples. + + """ + samples = jax.random.uniform(rng_key, (n_samples,), minval=0, maxval=1) + samples = jnp.log(samples/(1-samples)) + return samples + + def log_prob(self, x: Array) -> Float: + y = 1. 
/ 1 + jnp.exp(-x) + return (1/(self.xmax-self.xmin))*(1/(y-y*y)) \ No newline at end of file From cbdecfb01792f2db51f92cc371748891824fd021 Mon Sep 17 00:00:00 2001 From: Kaze Wong Date: Fri, 10 Nov 2023 10:35:29 -0500 Subject: [PATCH 03/23] Changing uniform to univariate in favor of composite prior --- src/jimgw/prior.py | 57 +++++++++++++++++++++++++++++++++++++--------- 1 file changed, 46 insertions(+), 11 deletions(-) diff --git a/src/jimgw/prior.py b/src/jimgw/prior.py index ed595861..471df79e 100644 --- a/src/jimgw/prior.py +++ b/src/jimgw/prior.py @@ -1,3 +1,4 @@ +from abc import abstractmethod import jax import jax.numpy as jnp from flowMC.nfmodel.base import Distribution @@ -43,7 +44,9 @@ def make_lambda(name): if name in transforms: self.transforms[name] = transforms[name] else: - self.transforms[name] = (name, make_lambda(name)) # Without the function, the lambda will refer to the variable name instead of its value, which will make lambda reference the last value of the variable name + # Without the function, the lambda will refer to the variable name instead of its value, + # which will make lambda reference the last value of the variable name + self.transforms[name] = (name, make_lambda(name)) def transform(self, x: Array) -> Array: """ @@ -79,15 +82,23 @@ def add_name(self, x: Array, transform_name: bool = False, transform_value: bool value = x return dict(zip(naming,value)) + @abstractmethod + def sample(self, rng_key: jax.random.PRNGKey, n_samples: int) -> Array: + raise NotImplementedError + + @abstractmethod + def logpdf(self, x: dict) -> Float: + raise NotImplementedError + class Uniform(Prior): - xmin: Union[float,Array] = 0. - xmax: Union[float,Array] = 1. + xmin: float = 0. + xmax: float = 1. - def __init__(self, xmin: Union[float,Array], xmax: Union[float,Array], **kwargs): + def __init__(self, xmin: float, xmax: float, **kwargs): super().__init__(kwargs.get("naming"), kwargs.get("transforms")) - self.xmax = jnp.array(xmax) - self.xmin = jnp.array(xmin) + self.xmax = xmax + self.xmin = xmin def sample(self, rng_key: jax.random.PRNGKey, n_samples: int) -> Array: """ @@ -106,11 +117,12 @@ def sample(self, rng_key: jax.random.PRNGKey, n_samples: int) -> Array: An array of shape (n_samples, n_dim) containing the samples. """ - samples = jax.random.uniform(rng_key, (n_samples,self.n_dim), minval=self.xmin, maxval=self.xmax) - return samples # TODO: remember to cast this to a named array + samples = jax.random.uniform(rng_key, (n_samples,), minval=self.xmin, maxval=self.xmax) + return samples - def log_prob(self, x: Array) -> Float: - output = jnp.sum(jnp.where((x>=self.xmax) | (x<=self.xmin), jnp.zeros_like(x)-jnp.inf, jnp.zeros_like(x))) + def log_prob(self, x: dict) -> Float: + variable = x[self.naming[0]] + output = jnp.sum(jnp.where((variable>=self.xmax) | (variable<=self.xmin), jnp.zeros_like(variable)-jnp.inf, jnp.zeros_like(variable))) return output + jnp.sum(jnp.log(1./(self.xmax-self.xmin))) class Unconstrained_Uniform(Prior): @@ -150,4 +162,27 @@ def sample(self, rng_key: jax.random.PRNGKey, n_samples: int) -> Array: def log_prob(self, x: Array) -> Float: y = 1. 
/ 1 + jnp.exp(-x) - return (1/(self.xmax-self.xmin))*(1/(y-y*y)) \ No newline at end of file + return (1/(self.xmax-self.xmin))*(1/(y-y*y)) + +class Composite(Prior): + + priors: list[Prior] = [] + + def __init__(self, priors: list[Prior], **kwargs): + naming = [] + transforms = {} + for prior in priors: + naming += prior.naming + transforms.update(prior.transforms) + self.priors = priors + self.naming = naming + self.transforms = transforms + + def sample(self, rng_key: jax.random.PRNGKey, n_samples: int) -> Array: + for prior in self.priors: + rng_key, subkey = jax.random.split(rng_key) + prior.sample(subkey, n_samples) + + def log_prob(self, x: Array) -> Float: + for prior in self.priors: + prior.log_prob(x) \ No newline at end of file From 0afb2bd387611b2511c8fe08a2a184c0751b45e8 Mon Sep 17 00:00:00 2001 From: Kaze Wong Date: Fri, 10 Nov 2023 10:51:06 -0500 Subject: [PATCH 04/23] Making prior test --- src/jimgw/prior.py | 19 ++++++++++--------- test/test_prior.py | 0 2 files changed, 10 insertions(+), 9 deletions(-) create mode 100644 test/test_prior.py diff --git a/src/jimgw/prior.py b/src/jimgw/prior.py index 471df79e..7ae9c0bb 100644 --- a/src/jimgw/prior.py +++ b/src/jimgw/prior.py @@ -82,11 +82,9 @@ def add_name(self, x: Array, transform_name: bool = False, transform_value: bool value = x return dict(zip(naming,value)) - @abstractmethod def sample(self, rng_key: jax.random.PRNGKey, n_samples: int) -> Array: raise NotImplementedError - @abstractmethod def logpdf(self, x: dict) -> Float: raise NotImplementedError @@ -95,8 +93,11 @@ class Uniform(Prior): xmin: float = 0. xmax: float = 1. - def __init__(self, xmin: float, xmax: float, **kwargs): - super().__init__(kwargs.get("naming"), kwargs.get("transforms")) + def __init__(self, xmin: float, xmax: float, naming: list[str], transforms: dict[tuple[str,Callable]] = {}): + super().__init__(naming, transforms) + assert isinstance(xmin, float), "xmin must be a float" + assert isinstance(xmax, float), "xmax must be a float" + assert self.n_dim == 1, "Uniform needs to be 1D distributions" self.xmax = xmax self.xmin = xmin @@ -122,8 +123,8 @@ def sample(self, rng_key: jax.random.PRNGKey, n_samples: int) -> Array: def log_prob(self, x: dict) -> Float: variable = x[self.naming[0]] - output = jnp.sum(jnp.where((variable>=self.xmax) | (variable<=self.xmin), jnp.zeros_like(variable)-jnp.inf, jnp.zeros_like(variable))) - return output + jnp.sum(jnp.log(1./(self.xmax-self.xmin))) + output = jnp.where((variable>=self.xmax) | (variable<=self.xmin), jnp.zeros_like(variable)-jnp.inf, jnp.zeros_like(variable)) + return output + jnp.log(1./(self.xmax-self.xmin)) class Unconstrained_Uniform(Prior): @@ -134,7 +135,7 @@ def __init__(self, xmin: float, xmax: float, **kwargs): super().__init__(kwargs.get("naming"), kwargs.get("transforms")) assert isinstance(xmin, float), "xmin must be a float" assert isinstance(xmax, float), "xmax must be a float" - assert self.n_dim == 1, "Unconstrained_Uniform only works for 1D distributions" + assert self.n_dim == 1, "Unconstrained_Uniform needs to be 1D distributions" self.xmax = xmax self.xmin = xmin self.transforms = {"y": ("x", lambda param: (self.xmax - self.xmin)/(1+jnp.exp(-param['x']))+self.xmin)} @@ -162,11 +163,11 @@ def sample(self, rng_key: jax.random.PRNGKey, n_samples: int) -> Array: def log_prob(self, x: Array) -> Float: y = 1. 
/ 1 + jnp.exp(-x) - return (1/(self.xmax-self.xmin))*(1/(y-y*y)) + return jnp.log((1/(self.xmax-self.xmin))*(1/(y-y*y))) class Composite(Prior): - priors: list[Prior] = [] + priors: list[Prior] = field(default_factory=list) def __init__(self, priors: list[Prior], **kwargs): naming = [] diff --git a/test/test_prior.py b/test/test_prior.py new file mode 100644 index 00000000..e69de29b From 56306edf1f2d9ac1891990a18ee5f16fb2b581e0 Mon Sep 17 00:00:00 2001 From: Kaze Wong Date: Fri, 10 Nov 2023 10:54:12 -0500 Subject: [PATCH 05/23] format --- src/jimgw/prior.py | 77 ++++++++++++++++++++++++++++++---------------- test/test_prior.py | 3 ++ 2 files changed, 53 insertions(+), 27 deletions(-) diff --git a/src/jimgw/prior.py b/src/jimgw/prior.py index 7ae9c0bb..5c7be343 100644 --- a/src/jimgw/prior.py +++ b/src/jimgw/prior.py @@ -1,11 +1,11 @@ -from abc import abstractmethod import jax import jax.numpy as jnp from flowMC.nfmodel.base import Distribution from jaxtyping import Array, Float -from typing import Callable, Union +from typing import Callable from dataclasses import field + class Prior(Distribution): """ A thin wrapper build on top of flowMC distributions to do book keeping. @@ -17,13 +17,13 @@ class Prior(Distribution): """ naming: list[str] - transforms: dict[tuple[str,Callable]] = field(default_factory=dict) + transforms: dict[tuple[str, Callable]] = field(default_factory=dict) @property def n_dim(self): return len(self.naming) - - def __init__(self, naming: list[str], transforms: dict[tuple[str,Callable]] = {}): + + def __init__(self, naming: list[str], transforms: dict[tuple[str, Callable]] = {}): """ Parameters ---------- @@ -38,7 +38,7 @@ def __init__(self, naming: list[str], transforms: dict[tuple[str,Callable]] = {} self.transforms = {} def make_lambda(name): - return lambda x: x[name] + return lambda x: x[name] for name in naming: if name in transforms: @@ -46,7 +46,7 @@ def make_lambda(name): else: # Without the function, the lambda will refer to the variable name instead of its value, # which will make lambda reference the last value of the variable name - self.transforms[name] = (name, make_lambda(name)) + self.transforms[name] = (name, make_lambda(name)) def transform(self, x: Array) -> Array: """ @@ -62,12 +62,14 @@ def transform(self, x: Array) -> Array: x : dict A dictionary of parameters with the transforms applied. """ - output = self.add_name(x, transform_name = False, transform_value = False) + output = self.add_name(x, transform_name=False, transform_value=False) for i, (key, value) in enumerate(self.transforms.items()): x = x.at[i].set(value[1](output)) return x - def add_name(self, x: Array, transform_name: bool = False, transform_value: bool = False) -> dict: + def add_name( + self, x: Array, transform_name: bool = False, transform_value: bool = False + ) -> dict: """ Turn an array into a dictionary """ @@ -80,7 +82,7 @@ def add_name(self, x: Array, transform_name: bool = False, transform_value: bool value = x else: value = x - return dict(zip(naming,value)) + return dict(zip(naming, value)) def sample(self, rng_key: jax.random.PRNGKey, n_samples: int) -> Array: raise NotImplementedError @@ -88,19 +90,26 @@ def sample(self, rng_key: jax.random.PRNGKey, n_samples: int) -> Array: def logpdf(self, x: dict) -> Float: raise NotImplementedError + class Uniform(Prior): - xmin: float = 0. - xmax: float = 1. 
+ xmin: float = 0.0 + xmax: float = 1.0 - def __init__(self, xmin: float, xmax: float, naming: list[str], transforms: dict[tuple[str,Callable]] = {}): + def __init__( + self, + xmin: float, + xmax: float, + naming: list[str], + transforms: dict[tuple[str, Callable]] = {}, + ): super().__init__(naming, transforms) assert isinstance(xmin, float), "xmin must be a float" assert isinstance(xmax, float), "xmax must be a float" assert self.n_dim == 1, "Uniform needs to be 1D distributions" self.xmax = xmax self.xmin = xmin - + def sample(self, rng_key: jax.random.PRNGKey, n_samples: int) -> Array: """ Sample from a uniform distribution. @@ -116,20 +125,27 @@ def sample(self, rng_key: jax.random.PRNGKey, n_samples: int) -> Array: ------- samples : Array An array of shape (n_samples, n_dim) containing the samples. - + """ - samples = jax.random.uniform(rng_key, (n_samples,), minval=self.xmin, maxval=self.xmax) + samples = jax.random.uniform( + rng_key, (n_samples,), minval=self.xmin, maxval=self.xmax + ) return samples def log_prob(self, x: dict) -> Float: variable = x[self.naming[0]] - output = jnp.where((variable>=self.xmax) | (variable<=self.xmin), jnp.zeros_like(variable)-jnp.inf, jnp.zeros_like(variable)) - return output + jnp.log(1./(self.xmax-self.xmin)) + output = jnp.where( + (variable >= self.xmax) | (variable <= self.xmin), + jnp.zeros_like(variable) - jnp.inf, + jnp.zeros_like(variable), + ) + return output + jnp.log(1.0 / (self.xmax - self.xmin)) + class Unconstrained_Uniform(Prior): - xmin: float = 0. - xmax: float = 1. + xmin: float = 0.0 + xmax: float = 1.0 def __init__(self, xmin: float, xmax: float, **kwargs): super().__init__(kwargs.get("naming"), kwargs.get("transforms")) @@ -138,8 +154,14 @@ def __init__(self, xmin: float, xmax: float, **kwargs): assert self.n_dim == 1, "Unconstrained_Uniform needs to be 1D distributions" self.xmax = xmax self.xmin = xmin - self.transforms = {"y": ("x", lambda param: (self.xmax - self.xmin)/(1+jnp.exp(-param['x']))+self.xmin)} - + self.transforms = { + "y": ( + "x", + lambda param: (self.xmax - self.xmin) / (1 + jnp.exp(-param["x"])) + + self.xmin, + ) + } + def sample(self, rng_key: jax.random.PRNGKey, n_samples: int) -> Array: """ Sample from a uniform distribution. @@ -155,15 +177,16 @@ def sample(self, rng_key: jax.random.PRNGKey, n_samples: int) -> Array: ------- samples : Array An array of shape (n_samples, n_dim) containing the samples. - + """ samples = jax.random.uniform(rng_key, (n_samples,), minval=0, maxval=1) - samples = jnp.log(samples/(1-samples)) + samples = jnp.log(samples / (1 - samples)) return samples def log_prob(self, x: Array) -> Float: - y = 1. 
/ 1 + jnp.exp(-x) - return jnp.log((1/(self.xmax-self.xmin))*(1/(y-y*y))) + y = 1.0 / 1 + jnp.exp(-x) + return jnp.log((1 / (self.xmax - self.xmin)) * (1 / (y - y * y))) + class Composite(Prior): @@ -186,4 +209,4 @@ def sample(self, rng_key: jax.random.PRNGKey, n_samples: int) -> Array: def log_prob(self, x: Array) -> Float: for prior in self.priors: - prior.log_prob(x) \ No newline at end of file + prior.log_prob(x) diff --git a/test/test_prior.py b/test/test_prior.py index e69de29b..0b0cb3c0 100644 --- a/test/test_prior.py +++ b/test/test_prior.py @@ -0,0 +1,3 @@ +from jimgw.prior import Uniform, Unconstrained_Uniform, Composite + + From 960fc38c43a100a76264bea1d7943cf689432d11 Mon Sep 17 00:00:00 2001 From: Kaze Wong Date: Fri, 10 Nov 2023 10:56:00 -0500 Subject: [PATCH 06/23] update --- src/jimgw/prior.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/src/jimgw/prior.py b/src/jimgw/prior.py index 5c7be343..e322e657 100644 --- a/src/jimgw/prior.py +++ b/src/jimgw/prior.py @@ -147,8 +147,14 @@ class Unconstrained_Uniform(Prior): xmin: float = 0.0 xmax: float = 1.0 - def __init__(self, xmin: float, xmax: float, **kwargs): - super().__init__(kwargs.get("naming"), kwargs.get("transforms")) + def __init__( + self, + xmin: float, + xmax: float, + naming: list[str], + transforms: dict[tuple[str, Callable]] = {}, + ): + super().__init__(naming, transforms) assert isinstance(xmin, float), "xmin must be a float" assert isinstance(xmax, float), "xmax must be a float" assert self.n_dim == 1, "Unconstrained_Uniform needs to be 1D distributions" From 82c1ffa378deb3b0a851ca0450ba484c07d171c8 Mon Sep 17 00:00:00 2001 From: Kaze Wong Date: Fri, 10 Nov 2023 10:56:40 -0500 Subject: [PATCH 07/23] Add sphere --- src/jimgw/prior.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/src/jimgw/prior.py b/src/jimgw/prior.py index e322e657..ea17b462 100644 --- a/src/jimgw/prior.py +++ b/src/jimgw/prior.py @@ -193,6 +193,16 @@ def log_prob(self, x: Array) -> Float: y = 1.0 / 1 + jnp.exp(-x) return jnp.log((1 / (self.xmax - self.xmin)) * (1 / (y - y * y))) +class Sphere(Prior): + + def __init__(self, naming: list[str], transforms: dict[tuple[str, Callable]] = {}): + super().__init__(naming, transforms) + + def sample(self, rng_key: jax.random.PRNGKey, n_samples: int) -> Array: + return super().sample(rng_key, n_samples) + + def log_prob(self, x: Array) -> Array: + return super().log_prob(x) class Composite(Prior): From 004c8634a44a1440f4767c80924e714acd302daf Mon Sep 17 00:00:00 2001 From: Kaze Wong Date: Fri, 10 Nov 2023 14:25:48 -0500 Subject: [PATCH 08/23] Let's see if keeping the naming is a bad idea --- src/jimgw/prior.py | 21 ++++++++++++--------- 1 file changed, 12 insertions(+), 9 deletions(-) diff --git a/src/jimgw/prior.py b/src/jimgw/prior.py index ea17b462..6715fb4e 100644 --- a/src/jimgw/prior.py +++ b/src/jimgw/prior.py @@ -48,7 +48,7 @@ def make_lambda(name): # which will make lambda reference the last value of the variable name self.transforms[name] = (name, make_lambda(name)) - def transform(self, x: Array) -> Array: + def transform(self, x: dict) -> dict: """ Apply the transforms to the parameters. @@ -62,9 +62,9 @@ def transform(self, x: Array) -> Array: x : dict A dictionary of parameters with the transforms applied. 
""" - output = self.add_name(x, transform_name=False, transform_value=False) + # output = self.add_name(x, transform_name=False, transform_value=False) for i, (key, value) in enumerate(self.transforms.items()): - x = x.at[i].set(value[1](output)) + x[key] = value[1](x) return x def add_name( @@ -130,7 +130,7 @@ def sample(self, rng_key: jax.random.PRNGKey, n_samples: int) -> Array: samples = jax.random.uniform( rng_key, (n_samples,), minval=self.xmin, maxval=self.xmax ) - return samples + return self.add_name(samples[None]) def log_prob(self, x: dict) -> Float: variable = x[self.naming[0]] @@ -160,11 +160,14 @@ def __init__( assert self.n_dim == 1, "Unconstrained_Uniform needs to be 1D distributions" self.xmax = xmax self.xmin = xmin + local_transform = self.transforms + def new_transform(param): + result = (self.xmax - self.xmin) / (1 + jnp.exp(-param[self.naming[0]])) + self.xmin + return local_transform[self.naming[0]][1]({self.naming[0]:result}) self.transforms = { - "y": ( - "x", - lambda param: (self.xmax - self.xmin) / (1 + jnp.exp(-param["x"])) - + self.xmin, + self.naming[0]: ( + self.naming[0], + new_transform ) } @@ -187,7 +190,7 @@ def sample(self, rng_key: jax.random.PRNGKey, n_samples: int) -> Array: """ samples = jax.random.uniform(rng_key, (n_samples,), minval=0, maxval=1) samples = jnp.log(samples / (1 - samples)) - return samples + return self.add_name(samples[None]) def log_prob(self, x: Array) -> Float: y = 1.0 / 1 + jnp.exp(-x) From c587863152f26cc52f8ce554d4dddccfa66903fb Mon Sep 17 00:00:00 2001 From: Kaze Wong Date: Fri, 10 Nov 2023 14:55:51 -0500 Subject: [PATCH 09/23] Composite seems working fine --- src/jimgw/prior.py | 31 ++++++++++++++++++------------- 1 file changed, 18 insertions(+), 13 deletions(-) diff --git a/src/jimgw/prior.py b/src/jimgw/prior.py index 6715fb4e..280ffbb5 100644 --- a/src/jimgw/prior.py +++ b/src/jimgw/prior.py @@ -84,7 +84,7 @@ def add_name( value = x return dict(zip(naming, value)) - def sample(self, rng_key: jax.random.PRNGKey, n_samples: int) -> Array: + def sample(self, rng_key: jax.random.PRNGKey, n_samples: int) -> dict: raise NotImplementedError def logpdf(self, x: dict) -> Float: @@ -110,7 +110,7 @@ def __init__( self.xmax = xmax self.xmin = xmin - def sample(self, rng_key: jax.random.PRNGKey, n_samples: int) -> Array: + def sample(self, rng_key: jax.random.PRNGKey, n_samples: int) -> dict: """ Sample from a uniform distribution. @@ -123,8 +123,8 @@ def sample(self, rng_key: jax.random.PRNGKey, n_samples: int) -> Array: Returns ------- - samples : Array - An array of shape (n_samples, n_dim) containing the samples. + samples : dict + Samples from the distribution. The keys are the names of the parameters. """ samples = jax.random.uniform( @@ -171,7 +171,7 @@ def new_transform(param): ) } - def sample(self, rng_key: jax.random.PRNGKey, n_samples: int) -> Array: + def sample(self, rng_key: jax.random.PRNGKey, n_samples: int) -> dict: """ Sample from a uniform distribution. @@ -184,7 +184,7 @@ def sample(self, rng_key: jax.random.PRNGKey, n_samples: int) -> Array: Returns ------- - samples : Array + samples : An array of shape (n_samples, n_dim) containing the samples. 
""" @@ -192,8 +192,9 @@ def sample(self, rng_key: jax.random.PRNGKey, n_samples: int) -> Array: samples = jnp.log(samples / (1 - samples)) return self.add_name(samples[None]) - def log_prob(self, x: Array) -> Float: - y = 1.0 / 1 + jnp.exp(-x) + def log_prob(self, x: dict) -> Float: + variable = x[self.naming[0]] + y = 1.0 / (1 + jnp.exp(-variable)) return jnp.log((1 / (self.xmax - self.xmin)) * (1 / (y - y * y))) class Sphere(Prior): @@ -204,7 +205,7 @@ def __init__(self, naming: list[str], transforms: dict[tuple[str, Callable]] = { def sample(self, rng_key: jax.random.PRNGKey, n_samples: int) -> Array: return super().sample(rng_key, n_samples) - def log_prob(self, x: Array) -> Array: + def log_prob(self, x: dict) -> Float: return super().log_prob(x) class Composite(Prior): @@ -221,11 +222,15 @@ def __init__(self, priors: list[Prior], **kwargs): self.naming = naming self.transforms = transforms - def sample(self, rng_key: jax.random.PRNGKey, n_samples: int) -> Array: + def sample(self, rng_key: jax.random.PRNGKey, n_samples: int) -> dict: + output = {} for prior in self.priors: rng_key, subkey = jax.random.split(rng_key) - prior.sample(subkey, n_samples) + output.update(prior.sample(subkey, n_samples)) + return output - def log_prob(self, x: Array) -> Float: + def log_prob(self, x: dict) -> Float: + output = 0.0 for prior in self.priors: - prior.log_prob(x) + output += prior.log_prob(x) + return output \ No newline at end of file From d4d832af1fd1b6066ef99aafc62e056de08ebe28 Mon Sep 17 00:00:00 2001 From: Kaze Wong Date: Fri, 10 Nov 2023 15:13:51 -0500 Subject: [PATCH 10/23] Composite seems working fine now --- src/jimgw/jim.py | 2 +- src/jimgw/prior.py | 12 ++++++++---- 2 files changed, 9 insertions(+), 5 deletions(-) diff --git a/src/jimgw/jim.py b/src/jimgw/jim.py index 203065b7..00cfa049 100644 --- a/src/jimgw/jim.py +++ b/src/jimgw/jim.py @@ -72,7 +72,7 @@ def maximize_likelihood(self, bounds: tuple[Array,Array], set_nwalkers: int = 10 def posterior(self, params: Array, data: dict): named_params = self.Prior.add_name(params, transform_name=True, transform_value=True) - return self.Likelihood.evaluate(named_params, data) + self.Prior.log_prob(params) + return self.Likelihood.evaluate(named_params, data) + self.Prior.log_prob(named_params) def sample(self, key: jax.random.PRNGKey, initial_guess: Array = None): diff --git a/src/jimgw/prior.py b/src/jimgw/prior.py index 280ffbb5..ed7dfd10 100644 --- a/src/jimgw/prior.py +++ b/src/jimgw/prior.py @@ -72,17 +72,21 @@ def add_name( ) -> dict: """ Turn an array into a dictionary + + Parameters + ---------- + x : Array + An array of parameters. Shape (n_dim, n_sample). """ if transform_name: naming = [value[0] for value in self.transforms.values()] else: naming = self.naming + x = dict(zip(naming, x)) if transform_value: - x = self.transform(x) - value = x + return self.transform(x) else: - value = x - return dict(zip(naming, value)) + return x def sample(self, rng_key: jax.random.PRNGKey, n_samples: int) -> dict: raise NotImplementedError From d31a83607d6d97293505b95754e8283f271af999 Mon Sep 17 00:00:00 2001 From: Kaze Wong Date: Sun, 12 Nov 2023 16:19:00 -0500 Subject: [PATCH 11/23] Something seems to work with flowHMC. 
Prior class seems pretty functional now --- example/GW150914_PV2_newglobal.py | 39 ++++++++++----------- src/jimgw/jim.py | 8 +++-- src/jimgw/prior.py | 56 +++++++++++++++++-------------- 3 files changed, 56 insertions(+), 47 deletions(-) diff --git a/example/GW150914_PV2_newglobal.py b/example/GW150914_PV2_newglobal.py index 2456e767..d995df98 100644 --- a/example/GW150914_PV2_newglobal.py +++ b/example/GW150914_PV2_newglobal.py @@ -3,7 +3,7 @@ from jimgw.detector import H1, L1 from jimgw.likelihood import HeterodynedTransientLikelihoodFD, TransientLikelihoodFD from jimgw.waveform import RippleIMRPhenomD, RippleIMRPhenomPv2 -from jimgw.prior import Uniform +from jimgw.prior import Uniform, Unconstrained_Uniform, Composite, Sphere import jax.numpy as jnp import jax @@ -29,33 +29,34 @@ L1.load_data(gps, 2, 2, fmin, fmax, psd_pad=16, tukey_alpha=0.2) waveform = RippleIMRPhenomPv2(f_ref=20) -prior = Uniform( - xmin = [10, 0.125, 0, 0, 0, 0, 0, 0, 0., -0.05, 0., -1, 0., 0.,-1.], - xmax = [80., 1., jnp.pi, 2*jnp.pi, 1., jnp.pi, 2*jnp.pi, 1., 2000., 0.05, 2*jnp.pi, 1., jnp.pi, 2*jnp.pi, 1.], - naming = ["M_c", "q", "s1_theta", "s1_phi", "s1_mag", "s2_theta", "s2_phi", "s2_mag", "d_L", "t_c", "phase_c", "cos_iota", "psi", "ra", "sin_dec"], - transforms = {"q": ("eta", lambda params: params['q']/(1+params['q'])**2), - "s1_theta": ("s1_x", lambda params: jnp.sin(params['s1_theta'])*jnp.cos(params['s1_phi'])*params['s1_mag']), - "s1_phi": ("s1_y", lambda params: jnp.sin(params['s1_theta'])*jnp.sin(params['s1_phi'])*params['s1_mag']), - "s1_mag": ("s1_z", lambda params: jnp.cos(params['s1_theta'])*params['s1_mag']), - "s2_theta": ("s2_x", lambda params: jnp.sin(params['s2_theta'])*jnp.cos(params['s2_phi'])*params['s2_mag']), - "s2_phi": ("s2_y", lambda params: jnp.sin(params['s2_theta'])*jnp.sin(params['s2_phi'])*params['s2_mag']), - "s2_mag": ("s2_z", lambda params: jnp.cos(params['s2_theta'])*params['s2_mag']), - "cos_iota": ("iota",lambda params: jnp.arccos(jnp.arcsin(jnp.sin(params['cos_iota']/2*jnp.pi))*2/jnp.pi)), - "sin_dec": ("dec",lambda params: jnp.arcsin(jnp.arcsin(jnp.sin(params['sin_dec']/2*jnp.pi))*2/jnp.pi))} # sin and arcsin are periodize cos_iota and sin_dec -) + +Mc_prior = Unconstrained_Uniform(10., 80., naming=["M_c"]) +q_prior = Unconstrained_Uniform(0.125, 1., naming=["q"], transforms={"q": ("eta", lambda params: params['q']/(1+params['q'])**2)}) +s1_prior = Sphere("s1") +s2_prior = Sphere("s2") +dL_prior = Unconstrained_Uniform(0., 2000., naming=["d_L"]) +t_c_prior = Unconstrained_Uniform(-0.05, 0.05, naming=["t_c"]) +phase_c_prior = Unconstrained_Uniform(0., 2*jnp.pi, naming=["phase_c"]) +cos_iota_prior = Unconstrained_Uniform(-1., 1., naming=["cos_iota"], transforms={"cos_iota": ("iota",lambda params: jnp.arccos(jnp.arcsin(jnp.sin(params['cos_iota']/2*jnp.pi))*2/jnp.pi))}) +psi_prior = Unconstrained_Uniform(0., jnp.pi, naming=["psi"]) +ra_prior = Unconstrained_Uniform(0., 2*jnp.pi, naming=["ra"]) +sin_dec_prior = Unconstrained_Uniform(-1., 1., naming=["sin_dec"], transforms={"sin_dec": ("dec",lambda params: jnp.arcsin(jnp.arcsin(jnp.sin(params['sin_dec']/2*jnp.pi))*2/jnp.pi))}) + +prior = Composite([Mc_prior, q_prior, s1_prior, s2_prior, dL_prior, t_c_prior, phase_c_prior, cos_iota_prior, psi_prior, ra_prior, sin_dec_prior]) + likelihood = TransientLikelihoodFD([H1, L1], waveform=waveform, trigger_time=gps, duration=4, post_trigger_duration=2) mass_matrix = jnp.eye(prior.n_dim) -mass_matrix = mass_matrix.at[1, 1].set(1e-3) -mass_matrix = mass_matrix.at[9, 9].set(1e-3) +# 
mass_matrix = mass_matrix.at[1, 1].set(1e-3) +# mass_matrix = mass_matrix.at[9, 9].set(1e-3) local_sampler_arg = {"step_size": mass_matrix * 3e-3} jim = Jim( likelihood, prior, - n_loop_training=10, + n_loop_training=50, n_loop_production=10, n_local_steps=300, n_global_steps=300, @@ -75,6 +76,6 @@ num_bins = 8 ) -jim.maximize_likelihood([prior.xmin, prior.xmax]) +# jim.maximize_likelihood([prior.xmin, prior.xmax]) # initial_guess = jnp.array(jnp.load('initial.npz')['chain']) jim.sample(jax.random.PRNGKey(42)) diff --git a/src/jimgw/jim.py b/src/jimgw/jim.py index 00cfa049..fab1a447 100644 --- a/src/jimgw/jim.py +++ b/src/jimgw/jim.py @@ -37,7 +37,7 @@ def __init__(self, likelihood: LikelihoodBase, prior: Prior, **kwargs): True, model, params={ - "step_size": 1e-4, + "step_size": 1e-2, "n_leapfrog": 5, "inverse_metric": jnp.ones(prior.n_dim), }, @@ -71,13 +71,15 @@ def maximize_likelihood(self, bounds: tuple[Array,Array], set_nwalkers: int = 10 return best_fit def posterior(self, params: Array, data: dict): - named_params = self.Prior.add_name(params, transform_name=True, transform_value=True) - return self.Likelihood.evaluate(named_params, data) + self.Prior.log_prob(named_params) + prior_params = self.Prior.add_name(params.T) + prior = self.Prior.log_prob(prior_params) + return self.Likelihood.evaluate(self.Prior.transform(prior_params), data) + prior def sample(self, key: jax.random.PRNGKey, initial_guess: Array = None): if initial_guess is None: initial_guess = self.Prior.sample(key, self.Sampler.n_chains) + initial_guess = jnp.stack([i for i in initial_guess.values()]).T self.Sampler.sample(initial_guess, None) def print_summary(self): diff --git a/src/jimgw/prior.py b/src/jimgw/prior.py index ed7dfd10..fc898fb3 100644 --- a/src/jimgw/prior.py +++ b/src/jimgw/prior.py @@ -62,14 +62,12 @@ def transform(self, x: dict) -> dict: x : dict A dictionary of parameters with the transforms applied. """ - # output = self.add_name(x, transform_name=False, transform_value=False) - for i, (key, value) in enumerate(self.transforms.items()): - x[key] = value[1](x) - return x - - def add_name( - self, x: Array, transform_name: bool = False, transform_value: bool = False - ) -> dict: + output = {} + for value in self.transforms.values(): + output[value[0]] = value[1](x) + return output + + def add_name(self, x: Array) -> dict: """ Turn an array into a dictionary @@ -78,15 +76,8 @@ def add_name( x : Array An array of parameters. Shape (n_dim, n_sample). 
""" - if transform_name: - naming = [value[0] for value in self.transforms.values()] - else: - naming = self.naming - x = dict(zip(naming, x)) - if transform_value: - return self.transform(x) - else: - return x + + return dict(zip(self.naming, x)) def sample(self, rng_key: jax.random.PRNGKey, n_samples: int) -> dict: raise NotImplementedError @@ -166,11 +157,11 @@ def __init__( self.xmin = xmin local_transform = self.transforms def new_transform(param): - result = (self.xmax - self.xmin) / (1 + jnp.exp(-param[self.naming[0]])) + self.xmin - return local_transform[self.naming[0]][1]({self.naming[0]:result}) + param[self.naming[0]] = (self.xmax - self.xmin) / (1 + jnp.exp(-param[self.naming[0]])) + self.xmin + return local_transform[self.naming[0]][1](param) self.transforms = { self.naming[0]: ( - self.naming[0], + local_transform[self.naming[0]][0], new_transform ) } @@ -203,14 +194,29 @@ def log_prob(self, x: dict) -> Float: class Sphere(Prior): - def __init__(self, naming: list[str], transforms: dict[tuple[str, Callable]] = {}): - super().__init__(naming, transforms) + """ + A prior on a sphere represented by Cartesian coordinates. + + Magnitude is sampled from a uniform distribution. + """ + + def __init__(self, naming: str): + self.naming = [f"{naming}_theta", f"{naming}_phi", f"{naming}_mag"] + self.transforms = { + self.naming[0]: (f"{naming}_x", lambda params: jnp.sin(params[self.naming[0]]) * jnp.cos(params[self.naming[1]]) * params[self.naming[2]]), + self.naming[1]: (f"{naming}_y", lambda params: jnp.sin(params[self.naming[0]]) * jnp.sin(params[self.naming[1]]) * params[self.naming[2]]), + self.naming[2]: (f"{naming}_z", lambda params: jnp.cos(params[self.naming[0]]) * params[self.naming[2]]), + } def sample(self, rng_key: jax.random.PRNGKey, n_samples: int) -> Array: - return super().sample(rng_key, n_samples) - + rng_keys = jax.random.split(rng_key,3) + theta = jax.random.uniform(rng_keys[0], (n_samples,), minval=0, maxval=2*jnp.pi) + phi = jnp.arccos(jax.random.uniform(rng_keys[1], (n_samples,), minval=-1., maxval=1.)) + mag = jax.random.uniform(rng_keys[2], (n_samples,), minval=0, maxval=1) + return self.add_name(jnp.stack([theta, phi, mag], axis=1).T) + def log_prob(self, x: dict) -> Float: - return super().log_prob(x) + return jnp.log(x[self.naming[2]]**2*jnp.sin(x[self.naming[1]])) class Composite(Prior): From 72e5cbf6ba996cf1f3d04d453c3d51cbfe19e608 Mon Sep 17 00:00:00 2001 From: Kaze Wong Date: Wed, 22 Nov 2023 16:09:11 -0500 Subject: [PATCH 12/23] update Prior --- src/jimgw/prior.py | 15 +++++++++++---- 1 file changed, 11 insertions(+), 4 deletions(-) diff --git a/src/jimgw/prior.py b/src/jimgw/prior.py index fc898fb3..d8f54473 100644 --- a/src/jimgw/prior.py +++ b/src/jimgw/prior.py @@ -137,10 +137,11 @@ def log_prob(self, x: dict) -> Float: return output + jnp.log(1.0 / (self.xmax - self.xmin)) -class Unconstrained_Uniform(Prior): +class UUniform(Prior): xmin: float = 0.0 xmax: float = 1.0 + to_range: Callable = lambda x: x def __init__( self, @@ -156,8 +157,9 @@ def __init__( self.xmax = xmax self.xmin = xmin local_transform = self.transforms + self.to_range = lambda x: (self.xmax - self.xmin) / (1 + jnp.exp(-x[self.naming[0]])) + self.xmin def new_transform(param): - param[self.naming[0]] = (self.xmax - self.xmin) / (1 + jnp.exp(-param[self.naming[0]])) + self.xmin + param[self.naming[0]] = self.to_range(param) return local_transform[self.naming[0]][1](param) self.transforms = { self.naming[0]: ( @@ -189,8 +191,13 @@ def sample(self, rng_key: jax.random.PRNGKey, 
n_samples: int) -> dict:
 
     def log_prob(self, x: dict) -> Float:
         variable = x[self.naming[0]]
-        y = 1.0 / (1 + jnp.exp(-variable))
-        return jnp.log((1 / (self.xmax - self.xmin)) * (1 / (y - y * y)))
+        variable = (self.xmax - self.xmin) / (1 + jnp.exp(-variable)) + self.xmin
+        output = jnp.where(
+            (variable >= self.xmax) | (variable <= self.xmin),
+            jnp.zeros_like(variable) - jnp.inf,
+            jnp.zeros_like(variable) + jnp.log(1.0 / (self.xmax - self.xmin)),
+        )
+        return output
 
 
 class Sphere(Prior):

From 789837f0722f7596f041d2675d6f5546c3d51a10 Mon Sep 17 00:00:00 2001
From: Kaze Wong
Date: Wed, 22 Nov 2023 16:09:51 -0500
Subject: [PATCH 13/23] Revert weird VScode changes

---
 src/jimgw/prior.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/jimgw/prior.py b/src/jimgw/prior.py
index d8f54473..1e8454c9 100644
--- a/src/jimgw/prior.py
+++ b/src/jimgw/prior.py
@@ -137,7 +137,7 @@ def log_prob(self, x: dict) -> Float:
         return output + jnp.log(1.0 / (self.xmax - self.xmin))
 
 
-class UUniform(Prior):
+class Unconstrained_Uniform(Prior):

From 677623e2f5beb3af8b547ab258d2502ef64eca38 Mon Sep 17 00:00:00 2001
From: Kaze Wong
Date: Wed, 22 Nov 2023 16:22:52 -0500
Subject: [PATCH 14/23] Fix bug in sphere prior

---
 src/jimgw/prior.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/jimgw/prior.py b/src/jimgw/prior.py
index 1e8454c9..f28ea5fc 100644
--- a/src/jimgw/prior.py
+++ b/src/jimgw/prior.py
@@ -223,7 +223,7 @@ def sample(self, rng_key: jax.random.PRNGKey, n_samples: int) -> Array:
         return self.add_name(jnp.stack([theta, phi, mag], axis=1).T)
 
     def log_prob(self, x: dict) -> Float:
-        return jnp.log(x[self.naming[2]]**2*jnp.sin(x[self.naming[1]]))
+        return jnp.log(x[self.naming[2]]**2*jnp.sin(jnp.arccos(x[self.naming[1]])))

From 46b017a46cc2dce18828cb84eabd4b2570e61665 Mon Sep 17 00:00:00 2001
From: Kaze Wong
Date: Wed, 22 Nov 2023 17:31:01 -0500
Subject: [PATCH 15/23] Revert fix for non-existent bug in Sphere prior

---
 src/jimgw/prior.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/jimgw/prior.py b/src/jimgw/prior.py
index f28ea5fc..1e8454c9 100644
--- a/src/jimgw/prior.py
+++ b/src/jimgw/prior.py
@@ -223,7 +223,7 @@ def sample(self, rng_key: jax.random.PRNGKey, n_samples: int) -> Array:
         return self.add_name(jnp.stack([theta, phi, mag], axis=1).T)
 
     def log_prob(self, x: dict) -> Float:
-        return jnp.log(x[self.naming[2]]**2*jnp.sin(jnp.arccos(x[self.naming[1]])))
+        return jnp.log(x[self.naming[2]]**2*jnp.sin(x[self.naming[1]]))

From a9b3efcc11809f6ac5370757535ee8e5e3267949 Mon Sep 17 00:00:00 2001
From: Kaze Wong
Date: Thu, 23 Nov 2023 19:14:48 -0500
Subject: [PATCH 16/23] Add option to use flowHMC

---
 src/jimgw/jim.py   | 26 ++++++++++++++++----------
 test/test_prior.py |  4 +---
 2 files changed, 17 insertions(+), 13 deletions(-)

diff --git a/src/jimgw/jim.py b/src/jimgw/jim.py
index fab1a447..4aece5d0 100644
--- a/src/jimgw/jim.py
+++ b/src/jimgw/jim.py
@@ -31,17 +31,23 @@ def __init__(self, likelihood: LikelihoodBase, prior: Prior, **kwargs):
 
         local_sampler = MALA(self.posterior, True, local_sampler_arg) # Remember to add routine to find automated mass matrix
 
+        flowHMC_params = kwargs.get("flowHMC_params", {})
+        if len(flowHMC_params) > 0:
+            flowHMC_sampler = flowHMC(
+                self.posterior,
+                True,
+                model,
+                params={
+                    "step_size": flowHMC_params["step_size"],
+                    "n_leapfrog": flowHMC_params["n_leapfrog"],
+                    "condition_matrix": flowHMC_params["condition_matrix"],
+                },
+            )
+        else: 
flowHMC_sampler = None + model = MaskedCouplingRQSpline(self.Prior.n_dim, num_layers, hidden_size, num_bins, rng_key_set[-1]) - flowHMC_sampler = flowHMC( - self.posterior, - True, - model, - params={ - "step_size": 1e-2, - "n_leapfrog": 5, - "inverse_metric": jnp.ones(prior.n_dim), - }, - ) + self.Sampler = Sampler( self.Prior.n_dim, rng_key_set, diff --git a/test/test_prior.py b/test/test_prior.py index 0b0cb3c0..ab99431c 100644 --- a/test/test_prior.py +++ b/test/test_prior.py @@ -1,3 +1 @@ -from jimgw.prior import Uniform, Unconstrained_Uniform, Composite - - +from jimgw.prior import Uniform, UUniform, Composite From abe4569c3f9fc37948791771cbbdb16ffe5367b8 Mon Sep 17 00:00:00 2001 From: Kaze Wong Date: Fri, 24 Nov 2023 13:04:56 -0500 Subject: [PATCH 17/23] Fix Unconstrained uniform prior --- src/jimgw/prior.py | 59 +++++++++++++++++++++++++++++----------------- 1 file changed, 37 insertions(+), 22 deletions(-) diff --git a/src/jimgw/prior.py b/src/jimgw/prior.py index 1e8454c9..5b6fa21f 100644 --- a/src/jimgw/prior.py +++ b/src/jimgw/prior.py @@ -157,15 +157,17 @@ def __init__( self.xmax = xmax self.xmin = xmin local_transform = self.transforms - self.to_range = lambda x: (self.xmax - self.xmin) / (1 + jnp.exp(-x[self.naming[0]])) + self.xmin + self.to_range = ( + lambda x: (self.xmax - self.xmin) / (1 + jnp.exp(-x[self.naming[0]])) + + self.xmin + ) + def new_transform(param): param[self.naming[0]] = self.to_range(param) return local_transform[self.naming[0]][1](param) + self.transforms = { - self.naming[0]: ( - local_transform[self.naming[0]][0], - new_transform - ) + self.naming[0]: (local_transform[self.naming[0]][0], new_transform) } def sample(self, rng_key: jax.random.PRNGKey, n_samples: int) -> dict: @@ -181,7 +183,7 @@ def sample(self, rng_key: jax.random.PRNGKey, n_samples: int) -> dict: Returns ------- - samples : + samples : An array of shape (n_samples, n_dim) containing the samples. 
""" @@ -191,13 +193,8 @@ def sample(self, rng_key: jax.random.PRNGKey, n_samples: int) -> dict: def log_prob(self, x: dict) -> Float: variable = x[self.naming[0]] - variable = (self.xmax - self.xmin) / (1 + jnp.exp(-variable)) + self.xmin - output = jnp.where( - (variable >= self.xmax) | (variable <= self.xmin), - jnp.zeros_like(variable) - jnp.inf, - jnp.zeros_like(variable) + jnp.log(1.0 / (self.xmax - self.xmin)), - ) - return output + return jnp.log(jnp.exp(-variable)/(1 + jnp.exp(-variable))**2) + class Sphere(Prior): @@ -210,20 +207,38 @@ class Sphere(Prior): def __init__(self, naming: str): self.naming = [f"{naming}_theta", f"{naming}_phi", f"{naming}_mag"] self.transforms = { - self.naming[0]: (f"{naming}_x", lambda params: jnp.sin(params[self.naming[0]]) * jnp.cos(params[self.naming[1]]) * params[self.naming[2]]), - self.naming[1]: (f"{naming}_y", lambda params: jnp.sin(params[self.naming[0]]) * jnp.sin(params[self.naming[1]]) * params[self.naming[2]]), - self.naming[2]: (f"{naming}_z", lambda params: jnp.cos(params[self.naming[0]]) * params[self.naming[2]]), + self.naming[0]: ( + f"{naming}_x", + lambda params: jnp.sin(params[self.naming[0]]) + * jnp.cos(params[self.naming[1]]) + * params[self.naming[2]], + ), + self.naming[1]: ( + f"{naming}_y", + lambda params: jnp.sin(params[self.naming[0]]) + * jnp.sin(params[self.naming[1]]) + * params[self.naming[2]], + ), + self.naming[2]: ( + f"{naming}_z", + lambda params: jnp.cos(params[self.naming[0]]) * params[self.naming[2]], + ), } def sample(self, rng_key: jax.random.PRNGKey, n_samples: int) -> Array: - rng_keys = jax.random.split(rng_key,3) - theta = jax.random.uniform(rng_keys[0], (n_samples,), minval=0, maxval=2*jnp.pi) - phi = jnp.arccos(jax.random.uniform(rng_keys[1], (n_samples,), minval=-1., maxval=1.)) + rng_keys = jax.random.split(rng_key, 3) + theta = jax.random.uniform( + rng_keys[0], (n_samples,), minval=0, maxval=2 * jnp.pi + ) + phi = jnp.arccos( + jax.random.uniform(rng_keys[1], (n_samples,), minval=-1.0, maxval=1.0) + ) mag = jax.random.uniform(rng_keys[2], (n_samples,), minval=0, maxval=1) return self.add_name(jnp.stack([theta, phi, mag], axis=1).T) - + def log_prob(self, x: dict) -> Float: - return jnp.log(x[self.naming[2]]**2*jnp.sin(x[self.naming[1]])) + return jnp.log(x[self.naming[2]] ** 2 * jnp.sin(x[self.naming[1]])) + class Composite(Prior): @@ -250,4 +265,4 @@ def log_prob(self, x: dict) -> Float: output = 0.0 for prior in self.priors: output += prior.log_prob(x) - return output \ No newline at end of file + return output From cfdc30cf1b1f80c55c0297249d35346930e87b60 Mon Sep 17 00:00:00 2001 From: Kaze Wong Date: Fri, 24 Nov 2023 13:05:19 -0500 Subject: [PATCH 18/23] Add experimental flowHMC support --- src/jimgw/jim.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/jimgw/jim.py b/src/jimgw/jim.py index 4aece5d0..41dd8ef9 100644 --- a/src/jimgw/jim.py +++ b/src/jimgw/jim.py @@ -32,8 +32,9 @@ def __init__(self, likelihood: LikelihoodBase, prior: Prior, **kwargs): local_sampler = MALA(self.posterior, True, local_sampler_arg) # Remember to add routine to find automated mass matrix flowHMC_params = kwargs.get("flowHMC_params", {}) + model = MaskedCouplingRQSpline(self.Prior.n_dim, num_layers, hidden_size, num_bins, rng_key_set[-1]) if len(flowHMC_params) > 0: - flowHMC_sampler = flowHMC( + global_sampler = flowHMC( self.posterior, True, model, @@ -44,9 +45,8 @@ def __init__(self, likelihood: LikelihoodBase, prior: Prior, **kwargs): }, ) else: - flowHMC_sampler = None + 
global_sampler = None - model = MaskedCouplingRQSpline(self.Prior.n_dim, num_layers, hidden_size, num_bins, rng_key_set[-1]) self.Sampler = Sampler( self.Prior.n_dim, @@ -54,7 +54,7 @@ def __init__(self, likelihood: LikelihoodBase, prior: Prior, **kwargs): None, local_sampler, model, - global_sampler = flowHMC_sampler, + global_sampler = global_sampler, **kwargs) From e1ba101eb3adf5aabc18da74672189bee5ee34a7 Mon Sep 17 00:00:00 2001 From: Kaze Wong Date: Fri, 24 Nov 2023 13:11:18 -0500 Subject: [PATCH 19/23] Unconstrained_uniform tested on GW150914, seems working now --- example/GW150914.py | 79 ++++++++++++++++++++++++++++++++------------- 1 file changed, 57 insertions(+), 22 deletions(-) diff --git a/example/GW150914.py b/example/GW150914.py index 2e73ad58..38eae6c4 100644 --- a/example/GW150914.py +++ b/example/GW150914.py @@ -3,7 +3,7 @@ from jimgw.detector import H1, L1 from jimgw.likelihood import HeterodynedTransientLikelihoodFD, TransientLikelihoodFD from jimgw.waveform import RippleIMRPhenomD -from jimgw.prior import Uniform +from jimgw.prior import Unconstrained_Uniform, Composite import jax.numpy as jnp import jax @@ -27,28 +27,63 @@ H1.load_data(gps, 2, 2, fmin, fmax, psd_pad=16, tukey_alpha=0.2) L1.load_data(gps, 2, 2, fmin, fmax, psd_pad=16, tukey_alpha=0.2) -prior = Uniform( - xmin=[10, 0.125, -1.0, -1.0, 0.0, -0.05, 0.0, -1, 0.0, 0.0, -1.0], - xmax=[80.0, 1.0, 1.0, 1.0, 2000.0, 0.05, 2 * jnp.pi, 1.0, jnp.pi, 2 * jnp.pi, 1.0], - naming=[ - "M_c", - "q", - "s1_z", - "s2_z", - "d_L", - "t_c", - "phase_c", - "cos_iota", - "psi", - "ra", - "sin_dec", - ], - transforms = {"q": ("eta", lambda params: params['q']/(1+params['q'])**2), - "cos_iota": ("iota",lambda params: jnp.arccos(jnp.arcsin(jnp.sin(params['cos_iota']/2*jnp.pi))*2/jnp.pi)), - "sin_dec": ("dec",lambda params: jnp.arcsin(jnp.arcsin(jnp.sin(params['sin_dec']/2*jnp.pi))*2/jnp.pi))} # sin and arcsin are periodize cos_iota and sin_dec +Mc_prior = Unconstrained_Uniform(10.0, 80.0, naming=["M_c"]) +q_prior = Unconstrained_Uniform( + 0.125, + 1.0, + naming=["q"], + transforms={"q": ("eta", lambda params: params["q"] / (1 + params["q"]) ** 2)}, +) +s1z_prior = Unconstrained_Uniform(-1.0, 1.0, naming=["s1_z"]) +s2z_prior = Unconstrained_Uniform(-1.0, 1.0, naming=["s2_z"]) +dL_prior = Unconstrained_Uniform(0.0, 2000.0, naming=["d_L"]) +t_c_prior = Unconstrained_Uniform(-0.05, 0.05, naming=["t_c"]) +phase_c_prior = Unconstrained_Uniform(0.0, 2 * jnp.pi, naming=["phase_c"]) +cos_iota_prior = Unconstrained_Uniform( + -1.0, + 1.0, + naming=["cos_iota"], + transforms={ + "cos_iota": ( + "iota", + lambda params: jnp.arccos( + jnp.arcsin(jnp.sin(params["cos_iota"] / 2 * jnp.pi)) * 2 / jnp.pi + ), + ) + }, +) +psi_prior = Unconstrained_Uniform(0.0, jnp.pi, naming=["psi"]) +ra_prior = Unconstrained_Uniform(0.0, 2 * jnp.pi, naming=["ra"]) +sin_dec_prior = Unconstrained_Uniform( + -1.0, + 1.0, + naming=["sin_dec"], + transforms={ + "sin_dec": ( + "dec", + lambda params: jnp.arcsin( + jnp.arcsin(jnp.sin(params["sin_dec"] / 2 * jnp.pi)) * 2 / jnp.pi + ), + ) + }, +) + +prior = Composite( + [ + Mc_prior, + q_prior, + s1z_prior, + s2z_prior, + dL_prior, + t_c_prior, + phase_c_prior, + cos_iota_prior, + psi_prior, + ra_prior, + sin_dec_prior, + ] ) likelihood = TransientLikelihoodFD([H1, L1], waveform=RippleIMRPhenomD(), trigger_time=gps, duration=4, post_trigger_duration=2) -# likelihood = HeterodynedTransientLikelihoodFD([H1, L1], prior=prior, bounds=[prior.xmin, prior.xmax], waveform=RippleIMRPhenomD(), trigger_time=gps, duration=4, 
post_trigger_duration=2) mass_matrix = jnp.eye(11) @@ -76,5 +111,5 @@ local_sampler_arg=local_sampler_arg, ) -jim.maximize_likelihood([prior.xmin, prior.xmax]) +# jim.maximize_likelihood([prior.xmin, prior.xmax]) jim.sample(jax.random.PRNGKey(42)) From d011acaef102c478dfb102c55c886b191ec2345b Mon Sep 17 00:00:00 2001 From: Kaze Wong Date: Mon, 27 Nov 2023 12:19:24 -0500 Subject: [PATCH 20/23] Black formatting --- example/GW150914.py | 8 +++++++- src/jimgw/prior.py | 2 +- test/test_prior.py | 2 +- 3 files changed, 9 insertions(+), 3 deletions(-) diff --git a/example/GW150914.py b/example/GW150914.py index 38eae6c4..a9c1c9c8 100644 --- a/example/GW150914.py +++ b/example/GW150914.py @@ -83,7 +83,13 @@ sin_dec_prior, ] ) -likelihood = TransientLikelihoodFD([H1, L1], waveform=RippleIMRPhenomD(), trigger_time=gps, duration=4, post_trigger_duration=2) +likelihood = TransientLikelihoodFD( + [H1, L1], + waveform=RippleIMRPhenomD(), + trigger_time=gps, + duration=4, + post_trigger_duration=2, +) mass_matrix = jnp.eye(11) diff --git a/src/jimgw/prior.py b/src/jimgw/prior.py index 5b6fa21f..8a32e2c3 100644 --- a/src/jimgw/prior.py +++ b/src/jimgw/prior.py @@ -193,7 +193,7 @@ def sample(self, rng_key: jax.random.PRNGKey, n_samples: int) -> dict: def log_prob(self, x: dict) -> Float: variable = x[self.naming[0]] - return jnp.log(jnp.exp(-variable)/(1 + jnp.exp(-variable))**2) + return jnp.log(jnp.exp(-variable) / (1 + jnp.exp(-variable)) ** 2) class Sphere(Prior): diff --git a/test/test_prior.py b/test/test_prior.py index ab99431c..d4c5be59 100644 --- a/test/test_prior.py +++ b/test/test_prior.py @@ -1 +1 @@ -from jimgw.prior import Uniform, UUniform, Composite +from jimgw.prior import Uniform, Unconstrained_Uniform, Composite From 7227fdcbde46e179f332e55721c9873599f205dc Mon Sep 17 00:00:00 2001 From: Kaze Wong Date: Mon, 27 Nov 2023 15:20:57 -0500 Subject: [PATCH 21/23] Fixing convention bugs in sphere prior --- src/jimgw/prior.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/jimgw/prior.py b/src/jimgw/prior.py index 8a32e2c3..91befa2a 100644 --- a/src/jimgw/prior.py +++ b/src/jimgw/prior.py @@ -228,16 +228,16 @@ def __init__(self, naming: str): def sample(self, rng_key: jax.random.PRNGKey, n_samples: int) -> Array: rng_keys = jax.random.split(rng_key, 3) theta = jax.random.uniform( - rng_keys[0], (n_samples,), minval=0, maxval=2 * jnp.pi + rng_keys[0], (n_samples,), minval=-1.0, maxval=1.0 ) phi = jnp.arccos( - jax.random.uniform(rng_keys[1], (n_samples,), minval=-1.0, maxval=1.0) + jax.random.uniform(rng_keys[1], (n_samples,), minval=0, maxval=2*jnp.pi) ) mag = jax.random.uniform(rng_keys[2], (n_samples,), minval=0, maxval=1) return self.add_name(jnp.stack([theta, phi, mag], axis=1).T) def log_prob(self, x: dict) -> Float: - return jnp.log(x[self.naming[2]] ** 2 * jnp.sin(x[self.naming[1]])) + return jnp.log(x[self.naming[2]] ** 2 * jnp.sin(x[self.naming[0]])) class Composite(Prior): From 09c0d12e121dfed10085e194b84e46b0834661a4 Mon Sep 17 00:00:00 2001 From: Kaze Wong Date: Tue, 28 Nov 2023 14:07:04 -0500 Subject: [PATCH 22/23] I think this version of PV2 works now --- example/GW150914_PV2.py | 86 ++++++++++++++++++++++++++++++----------- src/jimgw/prior.py | 18 ++++----- 2 files changed, 73 insertions(+), 31 deletions(-) diff --git a/example/GW150914_PV2.py b/example/GW150914_PV2.py index 6dc91e79..140af05b 100644 --- a/example/GW150914_PV2.py +++ b/example/GW150914_PV2.py @@ -2,8 +2,8 @@ from jimgw.jim import Jim from jimgw.detector import H1, L1 from 
jimgw.likelihood import HeterodynedTransientLikelihoodFD, TransientLikelihoodFD -from jimgw.waveform import RippleIMRPhenomD, RippleIMRPhenomPv2 -from jimgw.prior import Uniform +from jimgw.waveform import RippleIMRPhenomPv2 +from jimgw.prior import Uniform, Composite, Sphere import jax.numpy as jnp import jax @@ -28,19 +28,66 @@ L1.load_data(gps, 2, 2, fmin, fmax, psd_pad=16, tukey_alpha=0.2) waveform = RippleIMRPhenomPv2(f_ref=20) -prior = Uniform( - xmin = [10, 0.125, 0, 0, 0, 0, 0, 0, 0., -0.05, 0., -1, 0., 0.,-1.], - xmax = [80., 1., jnp.pi, 2*jnp.pi, 1., jnp.pi, 2*jnp.pi, 1., 2000., 0.05, 2*jnp.pi, 1., jnp.pi, 2*jnp.pi, 1.], - naming = ["M_c", "q", "s1_theta", "s1_phi", "s1_mag", "s2_theta", "s2_phi", "s2_mag", "d_L", "t_c", "phase_c", "cos_iota", "psi", "ra", "sin_dec"], - transforms = {"q": ("eta", lambda params: params['q']/(1+params['q'])**2), - "s1_theta": ("s1_x", lambda params: jnp.sin(params['s1_theta'])*jnp.cos(params['s1_phi'])*params['s1_mag']), - "s1_phi": ("s1_y", lambda params: jnp.sin(params['s1_theta'])*jnp.sin(params['s1_phi'])*params['s1_mag']), - "s1_mag": ("s1_z", lambda params: jnp.cos(params['s1_theta'])*params['s1_mag']), - "s2_theta": ("s2_x", lambda params: jnp.sin(params['s2_theta'])*jnp.cos(params['s2_phi'])*params['s2_mag']), - "s2_phi": ("s2_y", lambda params: jnp.sin(params['s2_theta'])*jnp.sin(params['s2_phi'])*params['s2_mag']), - "s2_mag": ("s2_z", lambda params: jnp.cos(params['s2_theta'])*params['s2_mag']), - "cos_iota": ("iota",lambda params: jnp.arccos(jnp.arcsin(jnp.sin(params['cos_iota']/2*jnp.pi))*2/jnp.pi)), - "sin_dec": ("dec",lambda params: jnp.arcsin(jnp.arcsin(jnp.sin(params['sin_dec']/2*jnp.pi))*2/jnp.pi))} # sin and arcsin are periodize cos_iota and sin_dec + +########################################### +########## Set up priors ################## +########################################### + +Mc_prior = Uniform(10.0, 80.0, naming=["M_c"]) +q_prior = Uniform( + 0.125, + 1.0, + naming=["q"], + transforms={"q": ("eta", lambda params: params["q"] / (1 + params["q"]) ** 2)}, +) +s1_prior = Sphere(naming="s1") +s2_prior = Sphere(naming="s2") +dL_prior = Uniform(0.0, 2000.0, naming=["d_L"]) +t_c_prior = Uniform(-0.05, 0.05, naming=["t_c"]) +phase_c_prior = Uniform(0.0, 2 * jnp.pi, naming=["phase_c"]) +cos_iota_prior = Uniform( + -1.0, + 1.0, + naming=["cos_iota"], + transforms={ + "cos_iota": ( + "iota", + lambda params: jnp.arccos( + jnp.arcsin(jnp.sin(params["cos_iota"] / 2 * jnp.pi)) * 2 / jnp.pi + ), + ) + }, +) +psi_prior = Uniform(0.0, jnp.pi, naming=["psi"]) +ra_prior = Uniform(0.0, 2 * jnp.pi, naming=["ra"]) +sin_dec_prior = Uniform( + -1.0, + 1.0, + naming=["sin_dec"], + transforms={ + "sin_dec": ( + "dec", + lambda params: jnp.arcsin( + jnp.arcsin(jnp.sin(params["sin_dec"] / 2 * jnp.pi)) * 2 / jnp.pi + ), + ) + }, +) + +prior = Composite( + [ + Mc_prior, + q_prior, + s1_prior, + s2_prior, + dL_prior, + t_c_prior, + phase_c_prior, + cos_iota_prior, + psi_prior, + ra_prior, + sin_dec_prior, + ], ) likelihood = TransientLikelihoodFD([H1, L1], waveform=waveform, trigger_time=gps, duration=4, post_trigger_duration=2) # likelihood = HeterodynedTransientLikelihoodFD([H1, L1], prior=prior, bounds=[prior.xmin, prior.xmax], waveform=RippleIMRPhenomD(), trigger_time=gps, duration=4, post_trigger_duration=2) @@ -54,7 +101,7 @@ jim = Jim( likelihood, prior, - n_loop_training=200, + n_loop_training=100, n_loop_production=10, n_local_steps=300, n_global_steps=300, @@ -63,17 +110,12 @@ learning_rate=0.001, max_samples = 60000, momentum=0.9, - 
batch_size=30000, + batch_size=60000, use_global=True, keep_quantile=0., train_thinning=1, output_thinning=30, local_sampler_arg=local_sampler_arg, - num_layers = 6, - hidden_size = [32,32], - num_bins = 8 ) -jim.maximize_likelihood([prior.xmin, prior.xmax]) -# initial_guess = jnp.array(jnp.load('initial.npz')['chain']) jim.sample(jax.random.PRNGKey(42)) diff --git a/src/jimgw/prior.py b/src/jimgw/prior.py index 91befa2a..2137b018 100644 --- a/src/jimgw/prior.py +++ b/src/jimgw/prior.py @@ -227,12 +227,12 @@ def __init__(self, naming: str): def sample(self, rng_key: jax.random.PRNGKey, n_samples: int) -> Array: rng_keys = jax.random.split(rng_key, 3) - theta = jax.random.uniform( - rng_keys[0], (n_samples,), minval=-1.0, maxval=1.0 - ) - phi = jnp.arccos( - jax.random.uniform(rng_keys[1], (n_samples,), minval=0, maxval=2*jnp.pi) + theta = jnp.arccos( + jax.random.uniform( + rng_keys[0], (n_samples,), minval=-1.0, maxval=1.0 + ) ) + phi = jax.random.uniform(rng_keys[1], (n_samples,), minval=0, maxval=2*jnp.pi) mag = jax.random.uniform(rng_keys[2], (n_samples,), minval=0, maxval=1) return self.add_name(jnp.stack([theta, phi, mag], axis=1).T) @@ -244,15 +244,15 @@ class Composite(Prior): priors: list[Prior] = field(default_factory=list) - def __init__(self, priors: list[Prior], **kwargs): + def __init__(self, priors: list[Prior], transforms: dict[tuple[str, Callable]] = {}): naming = [] - transforms = {} + self.transforms = {} for prior in priors: naming += prior.naming - transforms.update(prior.transforms) + self.transforms.update(prior.transforms) self.priors = priors self.naming = naming - self.transforms = transforms + self.transforms.update(transforms) def sample(self, rng_key: jax.random.PRNGKey, n_samples: int) -> dict: output = {} From 627cc8554ac612e203e2ba515941feb918240213 Mon Sep 17 00:00:00 2001 From: Kaze Wong Date: Tue, 28 Nov 2023 14:38:20 -0500 Subject: [PATCH 23/23] Update action --- .github/workflows/python-package.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/.github/workflows/python-package.yml b/.github/workflows/python-package.yml index 56ce5dd7..44d9bb6b 100644 --- a/.github/workflows/python-package.yml +++ b/.github/workflows/python-package.yml @@ -29,6 +29,7 @@ jobs: python -m pip install --upgrade pip python -m pip install pytest if [ -f requirements.txt ]; then pip install -r requirements.txt; fi + python -m pip install . - name: Test with pytest run: | pytest
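
Notes on the end state of this series. Patches 02, 09, 12 and 17 settle Unconstrained_Uniform on a logit-style reparametrization: samples live on the whole real line, the declared transform pushes them through a sigmoid into (xmin, xmax), and the density in unconstrained space is the standard logistic pdf exp(-x) / (1 + exp(-x))**2. That density is independent of xmin and xmax because the affine map from (0, 1) to (xmin, xmax) has a constant Jacobian that cancels against the 1/(xmax - xmin) of the uniform. A minimal sanity-check sketch, assuming the series is applied in full (the parameter name M_c is only illustrative):

import jax
import jax.numpy as jnp

from jimgw.prior import Unconstrained_Uniform

prior = Unconstrained_Uniform(10.0, 80.0, naming=["M_c"])

# Samples come back as a dict of unconstrained values on the real line.
samples = prior.sample(jax.random.PRNGKey(0), 1000)

# log_prob is evaluated in unconstrained space: the standard logistic density.
log_p = prior.log_prob(samples)
expected = -samples["M_c"] - 2 * jnp.log1p(jnp.exp(-samples["M_c"]))
assert jnp.allclose(log_p, expected, atol=1e-5)

# The declared transform maps every sample back into the open interval (10, 80).
# dict() takes a shallow copy because the transform writes into its argument.
constrained = prior.transform(dict(samples))["M_c"]
assert jnp.all((constrained > 10.0) & (constrained < 80.0))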
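
The back-and-forth in patches 14, 15, 21 and 22 is about which Sphere angle is the polar one. The convention the series lands on samples theta as the arccosine of a uniform variate on [-1, 1] (polar angle) and phi uniformly on [0, 2*pi) (azimuth), the standard recipe for an isotropic direction, with log_prob carrying the corresponding mag**2 * sin(theta) volume element. A short consistency sketch for that final convention (the label s1 is only illustrative):

import jax
import jax.numpy as jnp

from jimgw.prior import Sphere

s1_prior = Sphere("s1")
samples = s1_prior.sample(jax.random.PRNGKey(1), 1000)

# The declared transforms turn (theta, phi, mag) into Cartesian components,
# so the recovered norm should match the sampled magnitude.
xyz = s1_prior.transform(dict(samples))
norm = jnp.sqrt(xyz["s1_x"] ** 2 + xyz["s1_y"] ** 2 + xyz["s1_z"] ** 2)
assert jnp.allclose(norm, samples["s1_mag"], atol=1e-5)

# Polar angle lands in [0, pi], azimuth in [0, 2*pi).
assert jnp.all((samples["s1_theta"] >= 0) & (samples["s1_theta"] <= jnp.pi))
assert jnp.all((samples["s1_phi"] >= 0) & (samples["s1_phi"] < 2 * jnp.pi))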
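
Patches 03, 09 and 22 replace the single multivariate Uniform block of the original examples with one named 1D prior per parameter, composed through Composite: the naming lists concatenate, the transforms dictionaries merge, sample returns one merged dict, and log_prob is the sum of the component log-priors. A reduced sketch of the pattern used in example/GW150914_PV2.py, showing only two of the fifteen parameters:

import jax
import jax.numpy as jnp

from jimgw.prior import Uniform, Composite

Mc_prior = Uniform(10.0, 80.0, naming=["M_c"])
q_prior = Uniform(
    0.125,
    1.0,
    naming=["q"],
    # Sampled as the mass ratio q, handed onward as the symmetric mass ratio eta.
    transforms={"q": ("eta", lambda params: params["q"] / (1 + params["q"]) ** 2)},
)
prior = Composite([Mc_prior, q_prior])

samples = prior.sample(jax.random.PRNGKey(2), 100)  # {"M_c": ..., "q": ...}
log_p = prior.log_prob(samples)                     # sum over the two components
transformed = prior.transform(dict(samples))        # {"M_c": ..., "eta": ...}
assert jnp.all(transformed["eta"] <= 0.25)          # eta peaks at 0.25 for q = 1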
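
After patches 16 and 18 the flowHMC global sampler is strictly opt-in: Jim only builds it when a flowHMC_params dict arrives through **kwargs, and otherwise passes global_sampler=None so flowMC falls back to its default normalizing-flow proposal. A sketch of how a run script would switch it on; the likelihood and prior are assumed to be built as in example/GW150914.py, the step size and leapfrog count are illustrative, and the shape expected for condition_matrix is an assumption here (the hard-wired version in PATCH 01 passed a vector, jnp.ones(prior.n_dim), for the metric):

import jax
import jax.numpy as jnp

from jimgw.jim import Jim

# likelihood and prior built exactly as in example/GW150914.py (omitted here).
mass_matrix = jnp.eye(prior.n_dim)
jim = Jim(
    likelihood,
    prior,
    n_loop_training=100,
    n_loop_production=10,
    n_chains=500,
    local_sampler_arg={"step_size": mass_matrix * 3e-3},
    flowHMC_params={  # presence of this dict is what enables flowHMC
        "step_size": 1e-2,  # illustrative
        "n_leapfrog": 5,    # illustrative
        "condition_matrix": jnp.eye(prior.n_dim),  # shape is an assumption
    },
)
jim.sample(jax.random.PRNGKey(42))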