Skip to content

Commit

Permalink
Merge pull request #48 from ziatdinovmax/varnoise
Browse files Browse the repository at this point in the history
Add option to specify custom noise kernels
  • Loading branch information
ziatdinovmax authored Oct 11, 2023
2 parents 1da4588 + 8f69aa4 commit 95abe51
Show file tree
Hide file tree
Showing 4 changed files with 63 additions and 14 deletions.
24 changes: 13 additions & 11 deletions gpax/models/hskgp.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

from . import ExactGP
from ..kernels import get_kernel
from ..utils import _set_noise_kernel_fn

kernel_fn_type = Callable[[jnp.ndarray, jnp.ndarray, Dict[str, jnp.ndarray], jnp.ndarray], jnp.ndarray]

Expand Down Expand Up @@ -55,16 +56,20 @@ def __init__(
mean_fn: Optional[Callable[[jnp.ndarray, Dict[str, jnp.ndarray]], jnp.ndarray]] = None,
kernel_prior: Optional[Callable[[], Dict[str, jnp.ndarray]]] = None,
mean_fn_prior: Optional[Callable[[], Dict[str, jnp.ndarray]]] = None,
noise_kernel_prior: Optional[Callable[[], Dict[str, jnp.ndarray]]] = None,
lengthscale_prior_dist: Optional[dist.Distribution] = None,
noise_mean_fn: Optional[Callable[[jnp.ndarray, Dict[str, jnp.ndarray]], jnp.ndarray]] = None,
noise_mean_fn_prior: Optional[Callable[[], Dict[str, jnp.ndarray]]] = None,
noise_lengthscale_prior_dist: Optional[dist.Distribution] = None
) -> None:
args = (input_dim, kernel, mean_fn, kernel_prior, mean_fn_prior, None, None, lengthscale_prior_dist)
super(VarNoiseGP, self).__init__(*args)
self.noise_kernel = get_kernel(noise_kernel)
noise_kernel_ = get_kernel(noise_kernel)
self.noise_kernel = _set_noise_kernel_fn(noise_kernel_) if isinstance(noise_kernel, str) else noise_kernel_

self.noise_mean_fn = noise_mean_fn
self.noise_mean_fn_prior = noise_mean_fn_prior
self.noise_kernel_prior = noise_kernel_prior
self.noise_lengthscale_prior_dist = noise_lengthscale_prior_dist

def model(self, X: jnp.ndarray, y: jnp.ndarray = None, **kwargs: float) -> None:
Expand All @@ -74,7 +79,10 @@ def model(self, X: jnp.ndarray, y: jnp.ndarray = None, **kwargs: float) -> None:
noise_f_loc = jnp.zeros(X.shape[0])

# Sample noise kernel parameters
noise_kernel_params = self._sample_noise_kernel_params()
if self.noise_kernel_prior:
noise_kernel_params = self.noise_kernel_prior()
else:
noise_kernel_params = self._sample_noise_kernel_params()
# Add noise prior mean function (if any)
if self.noise_mean_fn is not None:
args = [X]
Expand Down Expand Up @@ -120,7 +128,7 @@ def _sample_noise_kernel_params(self) -> Dict[str, jnp.ndarray]:
noise_length_dist = dist.LogNormal(0, 1)
noise_scale = numpyro.sample("k_noise_scale", dist.LogNormal(0, 1))
noise_length = numpyro.sample("k_noise_length", noise_length_dist)
return {"k_length": noise_length, "k_scale": noise_scale}
return {"k_noise_length": noise_length, "k_noise_scale": noise_scale}

def get_mvn_posterior(
self, X_new: jnp.ndarray, params: Dict[str, jnp.ndarray], *args, **kwargs
Expand Down Expand Up @@ -148,14 +156,8 @@ def get_mvn_posterior(

# Noise GP part
# Compute noise kernel matrices
k_pX_noise = self.noise_kernel(
X_new, self.X_train,
{"k_length": params["k_noise_length"], "k_scale": params["k_noise_scale"]},
jitter=0.0)
k_XX_noise = self.noise_kernel(
self.X_train, self.X_train,
{"k_length": params["k_noise_length"], "k_scale": params["k_noise_scale"]},
0, **kwargs)
k_pX_noise = self.noise_kernel(X_new, self.X_train, params, jitter=0.0)
k_XX_noise = self.noise_kernel(self.X_train, self.X_train, params, 0, **kwargs)
# Compute noise predictive mean
log_var_residual = params["log_var"].copy()
if self.noise_mean_fn is not None:
Expand Down
3 changes: 2 additions & 1 deletion gpax/utils/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
from .utils import *
from .priors import *
from .priors import *
from .priors import _set_noise_kernel_fn
37 changes: 36 additions & 1 deletion gpax/utils/priors.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@
import jax
import jax.numpy as jnp

from ..kernels.kernels import square_scaled_distance, add_jitter, _sqrt


def place_normal_prior(param_name: str, loc: float = 0.0, scale: float = 1.0):
"""
Expand Down Expand Up @@ -227,12 +229,45 @@ def set_kernel_fn(func: Callable,

transformed_code += custom_code

local_namespace = {"jit": jax.jit}
local_namespace = {"jit": jax.jit}
exec(transformed_code, globals(), local_namespace)

return local_namespace[func.__name__]


def _set_noise_kernel_fn(func: Callable) -> Callable:
"""
Modifies the GPax kernel function to append "_noise" after "k" in dictionary keys it accesses.
Args:
func (Callable): Original function.
Returns:
Callable: Modified function.
"""

# Get the source code of the function
source = inspect.getsource(func)

# Split the source into decorators, definition, and body
decorators_and_def, body = source.split("\n", 1)

# Replace all occurrences of params["k with params["k_noise in the body
modified_body = re.sub(r'params\["k', 'params["k_noise', body)

# Combine decorators, definition, and modified body
modified_source = f"{decorators_and_def}\n{modified_body}"

# Define local namespace including the jit decorator
local_namespace = {"jit": jax.jit}

# Execute the modified source to redefine the function in the provided namespace
exec(modified_source, globals(), local_namespace)

# Return the modified function
return local_namespace[func.__name__]


def auto_priors(func: Callable, params_begin_with: int, dist_type: str = 'normal', loc: float = 0.0, scale: float = 1.0) -> Callable:
"""
Generates a function that, when invoked, samples from normal or log-normal distributions
Expand Down
13 changes: 12 additions & 1 deletion tests/test_utilpriors.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@

from gpax.utils import place_normal_prior, place_halfnormal_prior, place_uniform_prior, place_gamma_prior, place_lognormal_prior
from gpax.utils import uniform_dist, normal_dist, halfnormal_dist, lognormal_dist, gamma_dist
from gpax.utils import set_fn, set_kernel_fn, auto_lognormal_priors, auto_normal_priors, auto_lognormal_kernel_priors, auto_normal_kernel_priors, auto_priors
from gpax.utils import auto_lognormal_priors, auto_normal_priors, auto_lognormal_kernel_priors, auto_normal_kernel_priors, auto_priors
from gpax.utils import set_fn, set_kernel_fn, _set_noise_kernel_fn


def linear_kernel_test(X, Z, k_scale):
Expand Down Expand Up @@ -233,3 +234,13 @@ def test_auto_normal_kernel_priors(autopriors):
with numpyro.handlers.trace() as tr:
priors_fn()
assert_('k_scale' in tr)


def test_set_noise_kernel_fn():
from gpax.kernels import RBFKernel

X = jnp.array([[1, 2], [3, 4], [5, 6]])
params_i = {"k_length": jnp.array([1.0]), "k_scale": jnp.array(1.0)}
params = {"k_noise_length": jnp.array([1.0]), "k_noise_scale": jnp.array(1.0)}
noise_rbf = _set_noise_kernel_fn(RBFKernel)
assert_(jnp.array_equal(noise_rbf(X, X, params), RBFKernel(X, X, params_i)))

0 comments on commit 95abe51

Please sign in to comment.