diff --git a/gpax/models/hskgp.py b/gpax/models/hskgp.py index ec92016..25c7a90 100644 --- a/gpax/models/hskgp.py +++ b/gpax/models/hskgp.py @@ -16,6 +16,7 @@ from . import ExactGP from ..kernels import get_kernel +from ..utils import _set_noise_kernel_fn kernel_fn_type = Callable[[jnp.ndarray, jnp.ndarray, Dict[str, jnp.ndarray], jnp.ndarray], jnp.ndarray] @@ -55,6 +56,7 @@ def __init__( mean_fn: Optional[Callable[[jnp.ndarray, Dict[str, jnp.ndarray]], jnp.ndarray]] = None, kernel_prior: Optional[Callable[[], Dict[str, jnp.ndarray]]] = None, mean_fn_prior: Optional[Callable[[], Dict[str, jnp.ndarray]]] = None, + noise_kernel_prior: Optional[Callable[[], Dict[str, jnp.ndarray]]] = None, lengthscale_prior_dist: Optional[dist.Distribution] = None, noise_mean_fn: Optional[Callable[[jnp.ndarray, Dict[str, jnp.ndarray]], jnp.ndarray]] = None, noise_mean_fn_prior: Optional[Callable[[], Dict[str, jnp.ndarray]]] = None, @@ -62,9 +64,12 @@ def __init__( ) -> None: args = (input_dim, kernel, mean_fn, kernel_prior, mean_fn_prior, None, None, lengthscale_prior_dist) super(VarNoiseGP, self).__init__(*args) - self.noise_kernel = get_kernel(noise_kernel) + noise_kernel_ = get_kernel(noise_kernel) + self.noise_kernel = _set_noise_kernel_fn(noise_kernel_) if isinstance(noise_kernel, str) else noise_kernel_ + self.noise_mean_fn = noise_mean_fn self.noise_mean_fn_prior = noise_mean_fn_prior + self.noise_kernel_prior = noise_kernel_prior self.noise_lengthscale_prior_dist = noise_lengthscale_prior_dist def model(self, X: jnp.ndarray, y: jnp.ndarray = None, **kwargs: float) -> None: @@ -74,7 +79,10 @@ def model(self, X: jnp.ndarray, y: jnp.ndarray = None, **kwargs: float) -> None: noise_f_loc = jnp.zeros(X.shape[0]) # Sample noise kernel parameters - noise_kernel_params = self._sample_noise_kernel_params() + if self.noise_kernel_prior: + noise_kernel_params = self.noise_kernel_prior() + else: + noise_kernel_params = self._sample_noise_kernel_params() # Add noise prior mean function (if any) if self.noise_mean_fn is not None: args = [X] @@ -120,7 +128,7 @@ def _sample_noise_kernel_params(self) -> Dict[str, jnp.ndarray]: noise_length_dist = dist.LogNormal(0, 1) noise_scale = numpyro.sample("k_noise_scale", dist.LogNormal(0, 1)) noise_length = numpyro.sample("k_noise_length", noise_length_dist) - return {"k_length": noise_length, "k_scale": noise_scale} + return {"k_noise_length": noise_length, "k_noise_scale": noise_scale} def get_mvn_posterior( self, X_new: jnp.ndarray, params: Dict[str, jnp.ndarray], *args, **kwargs @@ -148,14 +156,8 @@ def get_mvn_posterior( # Noise GP part # Compute noise kernel matrices - k_pX_noise = self.noise_kernel( - X_new, self.X_train, - {"k_length": params["k_noise_length"], "k_scale": params["k_noise_scale"]}, - jitter=0.0) - k_XX_noise = self.noise_kernel( - self.X_train, self.X_train, - {"k_length": params["k_noise_length"], "k_scale": params["k_noise_scale"]}, - 0, **kwargs) + k_pX_noise = self.noise_kernel(X_new, self.X_train, params, jitter=0.0) + k_XX_noise = self.noise_kernel(self.X_train, self.X_train, params, 0, **kwargs) # Compute noise predictive mean log_var_residual = params["log_var"].copy() if self.noise_mean_fn is not None: diff --git a/gpax/utils/__init__.py b/gpax/utils/__init__.py index 81a0092..245c13a 100644 --- a/gpax/utils/__init__.py +++ b/gpax/utils/__init__.py @@ -1,2 +1,3 @@ from .utils import * -from .priors import * \ No newline at end of file +from .priors import * +from .priors import _set_noise_kernel_fn diff --git a/gpax/utils/priors.py b/gpax/utils/priors.py index b24d8ca..ff105b1 100644 --- a/gpax/utils/priors.py +++ b/gpax/utils/priors.py @@ -16,6 +16,8 @@ import jax import jax.numpy as jnp +from ..kernels.kernels import square_scaled_distance, add_jitter, _sqrt + def place_normal_prior(param_name: str, loc: float = 0.0, scale: float = 1.0): """ @@ -227,12 +229,45 @@ def set_kernel_fn(func: Callable, transformed_code += custom_code - local_namespace = {"jit": jax.jit} + local_namespace = {"jit": jax.jit} exec(transformed_code, globals(), local_namespace) return local_namespace[func.__name__] +def _set_noise_kernel_fn(func: Callable) -> Callable: + """ + Modifies the GPax kernel function to append "_noise" after "k" in dictionary keys it accesses. + + Args: + func (Callable): Original function. + + Returns: + Callable: Modified function. + """ + + # Get the source code of the function + source = inspect.getsource(func) + + # Split the source into decorators, definition, and body + decorators_and_def, body = source.split("\n", 1) + + # Replace all occurrences of params["k with params["k_noise in the body + modified_body = re.sub(r'params\["k', 'params["k_noise', body) + + # Combine decorators, definition, and modified body + modified_source = f"{decorators_and_def}\n{modified_body}" + + # Define local namespace including the jit decorator + local_namespace = {"jit": jax.jit} + + # Execute the modified source to redefine the function in the provided namespace + exec(modified_source, globals(), local_namespace) + + # Return the modified function + return local_namespace[func.__name__] + + def auto_priors(func: Callable, params_begin_with: int, dist_type: str = 'normal', loc: float = 0.0, scale: float = 1.0) -> Callable: """ Generates a function that, when invoked, samples from normal or log-normal distributions diff --git a/tests/test_utilpriors.py b/tests/test_utilpriors.py index 28a4483..60ee3c1 100644 --- a/tests/test_utilpriors.py +++ b/tests/test_utilpriors.py @@ -8,7 +8,8 @@ from gpax.utils import place_normal_prior, place_halfnormal_prior, place_uniform_prior, place_gamma_prior, place_lognormal_prior from gpax.utils import uniform_dist, normal_dist, halfnormal_dist, lognormal_dist, gamma_dist -from gpax.utils import set_fn, set_kernel_fn, auto_lognormal_priors, auto_normal_priors, auto_lognormal_kernel_priors, auto_normal_kernel_priors, auto_priors +from gpax.utils import auto_lognormal_priors, auto_normal_priors, auto_lognormal_kernel_priors, auto_normal_kernel_priors, auto_priors +from gpax.utils import set_fn, set_kernel_fn, _set_noise_kernel_fn def linear_kernel_test(X, Z, k_scale): @@ -233,3 +234,13 @@ def test_auto_normal_kernel_priors(autopriors): with numpyro.handlers.trace() as tr: priors_fn() assert_('k_scale' in tr) + + +def test_set_noise_kernel_fn(): + from gpax.kernels import RBFKernel + + X = jnp.array([[1, 2], [3, 4], [5, 6]]) + params_i = {"k_length": jnp.array([1.0]), "k_scale": jnp.array(1.0)} + params = {"k_noise_length": jnp.array([1.0]), "k_noise_scale": jnp.array(1.0)} + noise_rbf = _set_noise_kernel_fn(RBFKernel) + assert_(jnp.array_equal(noise_rbf(X, X, params), RBFKernel(X, X, params_i)))