Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add option to specify custom noise kernels #48

Merged
merged 5 commits into from
Oct 11, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 13 additions & 11 deletions gpax/models/hskgp.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

from . import ExactGP
from ..kernels import get_kernel
from ..utils import _set_noise_kernel_fn

kernel_fn_type = Callable[[jnp.ndarray, jnp.ndarray, Dict[str, jnp.ndarray], jnp.ndarray], jnp.ndarray]

Expand Down Expand Up @@ -55,16 +56,20 @@ def __init__(
mean_fn: Optional[Callable[[jnp.ndarray, Dict[str, jnp.ndarray]], jnp.ndarray]] = None,
kernel_prior: Optional[Callable[[], Dict[str, jnp.ndarray]]] = None,
mean_fn_prior: Optional[Callable[[], Dict[str, jnp.ndarray]]] = None,
noise_kernel_prior: Optional[Callable[[], Dict[str, jnp.ndarray]]] = None,
lengthscale_prior_dist: Optional[dist.Distribution] = None,
noise_mean_fn: Optional[Callable[[jnp.ndarray, Dict[str, jnp.ndarray]], jnp.ndarray]] = None,
noise_mean_fn_prior: Optional[Callable[[], Dict[str, jnp.ndarray]]] = None,
noise_lengthscale_prior_dist: Optional[dist.Distribution] = None
) -> None:
args = (input_dim, kernel, mean_fn, kernel_prior, mean_fn_prior, None, None, lengthscale_prior_dist)
super(VarNoiseGP, self).__init__(*args)
self.noise_kernel = get_kernel(noise_kernel)
noise_kernel_ = get_kernel(noise_kernel)
self.noise_kernel = _set_noise_kernel_fn(noise_kernel_) if isinstance(noise_kernel, str) else noise_kernel_

self.noise_mean_fn = noise_mean_fn
self.noise_mean_fn_prior = noise_mean_fn_prior
self.noise_kernel_prior = noise_kernel_prior
self.noise_lengthscale_prior_dist = noise_lengthscale_prior_dist

def model(self, X: jnp.ndarray, y: jnp.ndarray = None, **kwargs: float) -> None:
Expand All @@ -74,7 +79,10 @@ def model(self, X: jnp.ndarray, y: jnp.ndarray = None, **kwargs: float) -> None:
noise_f_loc = jnp.zeros(X.shape[0])

# Sample noise kernel parameters
noise_kernel_params = self._sample_noise_kernel_params()
if self.noise_kernel_prior:
noise_kernel_params = self.noise_kernel_prior()
else:
noise_kernel_params = self._sample_noise_kernel_params()
# Add noise prior mean function (if any)
if self.noise_mean_fn is not None:
args = [X]
Expand Down Expand Up @@ -120,7 +128,7 @@ def _sample_noise_kernel_params(self) -> Dict[str, jnp.ndarray]:
noise_length_dist = dist.LogNormal(0, 1)
noise_scale = numpyro.sample("k_noise_scale", dist.LogNormal(0, 1))
noise_length = numpyro.sample("k_noise_length", noise_length_dist)
return {"k_length": noise_length, "k_scale": noise_scale}
return {"k_noise_length": noise_length, "k_noise_scale": noise_scale}

def get_mvn_posterior(
self, X_new: jnp.ndarray, params: Dict[str, jnp.ndarray], *args, **kwargs
Expand Down Expand Up @@ -148,14 +156,8 @@ def get_mvn_posterior(

# Noise GP part
# Compute noise kernel matrices
k_pX_noise = self.noise_kernel(
X_new, self.X_train,
{"k_length": params["k_noise_length"], "k_scale": params["k_noise_scale"]},
jitter=0.0)
k_XX_noise = self.noise_kernel(
self.X_train, self.X_train,
{"k_length": params["k_noise_length"], "k_scale": params["k_noise_scale"]},
0, **kwargs)
k_pX_noise = self.noise_kernel(X_new, self.X_train, params, jitter=0.0)
k_XX_noise = self.noise_kernel(self.X_train, self.X_train, params, 0, **kwargs)
# Compute noise predictive mean
log_var_residual = params["log_var"].copy()
if self.noise_mean_fn is not None:
Expand Down
3 changes: 2 additions & 1 deletion gpax/utils/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
from .utils import *
from .priors import *
from .priors import *
from .priors import _set_noise_kernel_fn
37 changes: 36 additions & 1 deletion gpax/utils/priors.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@
import jax
import jax.numpy as jnp

from ..kernels.kernels import square_scaled_distance, add_jitter, _sqrt


def place_normal_prior(param_name: str, loc: float = 0.0, scale: float = 1.0):
"""
Expand Down Expand Up @@ -227,12 +229,45 @@ def set_kernel_fn(func: Callable,

transformed_code += custom_code

local_namespace = {"jit": jax.jit}
local_namespace = {"jit": jax.jit}
exec(transformed_code, globals(), local_namespace)

return local_namespace[func.__name__]


def _set_noise_kernel_fn(func: Callable) -> Callable:
"""
Modifies the GPax kernel function to append "_noise" after "k" in dictionary keys it accesses.

Args:
func (Callable): Original function.

Returns:
Callable: Modified function.
"""

# Get the source code of the function
source = inspect.getsource(func)

# Split the source into decorators, definition, and body
decorators_and_def, body = source.split("\n", 1)

# Replace all occurrences of params["k with params["k_noise in the body
modified_body = re.sub(r'params\["k', 'params["k_noise', body)

# Combine decorators, definition, and modified body
modified_source = f"{decorators_and_def}\n{modified_body}"

# Define local namespace including the jit decorator
local_namespace = {"jit": jax.jit}

# Execute the modified source to redefine the function in the provided namespace
exec(modified_source, globals(), local_namespace)

# Return the modified function
return local_namespace[func.__name__]


def auto_priors(func: Callable, params_begin_with: int, dist_type: str = 'normal', loc: float = 0.0, scale: float = 1.0) -> Callable:
"""
Generates a function that, when invoked, samples from normal or log-normal distributions
Expand Down
13 changes: 12 additions & 1 deletion tests/test_utilpriors.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@

from gpax.utils import place_normal_prior, place_halfnormal_prior, place_uniform_prior, place_gamma_prior, place_lognormal_prior
from gpax.utils import uniform_dist, normal_dist, halfnormal_dist, lognormal_dist, gamma_dist
from gpax.utils import set_fn, set_kernel_fn, auto_lognormal_priors, auto_normal_priors, auto_lognormal_kernel_priors, auto_normal_kernel_priors, auto_priors
from gpax.utils import auto_lognormal_priors, auto_normal_priors, auto_lognormal_kernel_priors, auto_normal_kernel_priors, auto_priors
from gpax.utils import set_fn, set_kernel_fn, _set_noise_kernel_fn


def linear_kernel_test(X, Z, k_scale):
Expand Down Expand Up @@ -233,3 +234,13 @@ def test_auto_normal_kernel_priors(autopriors):
with numpyro.handlers.trace() as tr:
priors_fn()
assert_('k_scale' in tr)


def test_set_noise_kernel_fn():
from gpax.kernels import RBFKernel

X = jnp.array([[1, 2], [3, 4], [5, 6]])
params_i = {"k_length": jnp.array([1.0]), "k_scale": jnp.array(1.0)}
params = {"k_noise_length": jnp.array([1.0]), "k_noise_scale": jnp.array(1.0)}
noise_rbf = _set_noise_kernel_fn(RBFKernel)
assert_(jnp.array_equal(noise_rbf(X, X, params), RBFKernel(X, X, params_i)))
Loading