Added the ability to use Cholesky decomposition instead of naive inverse for the ExactGP class. Added a test for Cholesky decomposition.
mjbajwa committed Sep 13, 2024
1 parent fad59bc commit 924ee82
Showing 2 changed files with 39 additions and 9 deletions.
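For context, the commit replaces an explicit inverse of the kernel matrix with a Cholesky solve. A minimal standalone sketch (not part of the commit; the toy matrix, vector, and tolerance below are made up for illustration) of why the two code paths agree:

# For a symmetric positive-definite matrix K, solving K x = b via its Cholesky
# factor gives the same result as multiplying by an explicitly formed inverse,
# while avoiding the explicit inversion.
import jax.numpy as jnp
from jax.scipy.linalg import cho_factor, cho_solve

A = jnp.array([[2.0, 0.5], [0.5, 1.0]])   # toy SPD matrix standing in for k_XX
b = jnp.array([1.0, -1.0])                 # stands in for y_residual

x_inv = jnp.matmul(jnp.linalg.inv(A), b)   # naive inverse (old code path)
x_cho = cho_solve(cho_factor(A), b)        # Cholesky solve (new code path)

print(jnp.allclose(x_inv, x_cho, atol=1e-6))  # True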
gpax/models/gp.py (33 changes: 24 additions & 9 deletions)
@@ -251,7 +251,7 @@ def get_samples(self, chain_dim: bool = False) -> Dict[str, jnp.ndarray]:
        return self.mcmc.get_samples(group_by_chain=chain_dim)

    def get_mvn_posterior(
-       self, X_new: jnp.ndarray, params: Dict[str, jnp.ndarray], noiseless: bool = False, **kwargs: float
+       self, X_new: jnp.ndarray, params: Dict[str, jnp.ndarray], noiseless: bool = False, use_cholesky: bool = False, **kwargs: float
    ) -> Tuple[jnp.ndarray, jnp.ndarray]:
        """
        Returns parameters (mean and cov) of multivariate normal posterior
@@ -267,13 +267,24 @@ def get_mvn_posterior(
        k_pp = self.kernel(X_new, X_new, params, noise_p, **kwargs)
        k_pX = self.kernel(X_new, self.X_train, params, jitter=0.0)
        k_XX = self.kernel(self.X_train, self.X_train, params, noise, **kwargs)
-       # compute the predictive covariance and mean
-       K_xx_inv = jnp.linalg.inv(k_XX)
-       cov = k_pp - jnp.matmul(k_pX, jnp.matmul(K_xx_inv, jnp.transpose(k_pX)))
-       mean = jnp.matmul(k_pX, jnp.matmul(K_xx_inv, y_residual))
+
+       # Compute the predictive covariance and mean
+       # since K_xx is symmetric positive-definite, we can use the more efficient and
+       # stable Cholesky decomposition instead of matrix inversion
+
+       if use_cholesky:
+           K_xx_cho = jax.scipy.linalg.cho_factor(k_XX)
+           cov = k_pp - jnp.matmul(k_pX, jax.scipy.linalg.cho_solve(K_xx_cho, k_pX.T))
+           mean = jnp.matmul(k_pX, jax.scipy.linalg.cho_solve(K_xx_cho, y_residual))
+       else:
+           K_xx_inv = jnp.linalg.inv(k_XX)
+           cov = k_pp - jnp.matmul(k_pX, jnp.matmul(K_xx_inv, jnp.transpose(k_pX)))
+           mean = jnp.matmul(k_pX, jnp.matmul(K_xx_inv, y_residual))

        if self.mean_fn is not None:
            args = [X_new, params] if self.mean_fn_prior else [X_new]
            mean += self.mean_fn(*args).squeeze()

        return mean, cov

@@ -283,11 +294,12 @@ def _predict(
        params: Dict[str, jnp.ndarray],
        n: int,
        noiseless: bool = False,
+       use_cholesky: bool = False,
        **kwargs: float
    ) -> Tuple[jnp.ndarray, jnp.ndarray]:
        """Prediction with a single sample of GP parameters"""
        # Get the predictive mean and covariance
-       y_mean, K = self.get_mvn_posterior(X_new, params, noiseless, **kwargs)
+       y_mean, K = self.get_mvn_posterior(X_new, params, noiseless, use_cholesky, **kwargs)
        # draw samples from the posterior predictive for a given set of parameters
        y_sampled = dist.MultivariateNormal(y_mean, K).sample(rng_key, sample_shape=(n,))
        return y_mean, y_sampled
@@ -304,10 +316,11 @@ def _predict_in_batches(
        predict_fn: Callable[[jnp.ndarray, int], Tuple[jnp.ndarray]] = None,
        noiseless: bool = False,
        device: Type[jaxlib.xla_extension.Device] = None,
+       use_cholesky: bool = False,
        **kwargs: float
    ) -> Tuple[jnp.ndarray, jnp.ndarray]:
        if predict_fn is None:
-           predict_fn = lambda xi: self.predict(rng_key, xi, samples, n, filter_nans, noiseless, device, **kwargs)
+           predict_fn = lambda xi: self.predict(rng_key, xi, samples, n, filter_nans, noiseless, device, use_cholesky, **kwargs)

        def predict_batch(Xi):
            out1, out2 = predict_fn(Xi)
@@ -333,6 +346,7 @@ def predict_in_batches(
        predict_fn: Callable[[jnp.ndarray, int], Tuple[jnp.ndarray]] = None,
        noiseless: bool = False,
        device: Type[jaxlib.xla_extension.Device] = None,
+       use_cholesky: bool = False,
        **kwargs: float
    ) -> Tuple[jnp.ndarray, jnp.ndarray]:
        """
@@ -342,7 +356,7 @@ def predict_in_batches(
        to avoid a memory overflow
        """
        y_pred, y_sampled = self._predict_in_batches(
-           rng_key, X_new, batch_size, 0, samples, n, filter_nans, predict_fn, noiseless, device, **kwargs
+           rng_key, X_new, batch_size, 0, samples, n, filter_nans, predict_fn, noiseless, device, use_cholesky, **kwargs
        )
        y_pred = jnp.concatenate(y_pred, 0)
        y_sampled = jnp.concatenate(y_sampled, -1)
@@ -357,6 +371,7 @@ def predict(
        filter_nans: bool = False,
        noiseless: bool = False,
        device: Type[jaxlib.xla_extension.Device] = None,
+       use_cholesky: bool = False,
        **kwargs: float
    ) -> Tuple[jnp.ndarray, jnp.ndarray]:
        """
@@ -391,7 +406,7 @@ def predict(
            samples = jax.device_put(samples, device)
        num_samples = len(next(iter(samples.values())))
        vmap_args = (jra.split(rng_key, num_samples), samples)
-       predictive = jax.vmap(lambda prms: self._predict(prms[0], X_new, prms[1], n, noiseless, **kwargs))
+       predictive = jax.vmap(lambda prms: self._predict(prms[0], X_new, prms[1], n, noiseless, use_cholesky, **kwargs))
        y_means, y_sampled = predictive(vmap_args)
        if filter_nans:
            y_sampled_ = [y_i for y_i in y_sampled if not jnp.isnan(y_i).any()]
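The gp.py changes above thread the new use_cholesky flag from predict and predict_in_batches down to get_mvn_posterior. A hypothetical end-to-end usage sketch (not from the repository; the import paths, the get_keys helper, and the placeholder X_train/y_train/X_test arrays are assumptions based on the usual gpax API and the tests below):

import numpy as np
from gpax.models.gp import ExactGP    # assumed import path, matching the file shown above
from gpax.utils import get_keys       # assumed helper, as used in the tests

# Placeholder 1D data (not from the repository)
X_train = np.linspace(-1, 1, 20)[:, None]
y_train = np.sin(3 * X_train).squeeze()
X_test = np.linspace(-1, 1, 50)[:, None]

rng_key_fit, rng_key_predict = get_keys()
gp_model = ExactGP(1, 'RBF')
gp_model.fit(rng_key_fit, X_train, y_train)

# use_cholesky=True routes get_mvn_posterior through cho_factor/cho_solve
y_mean, y_sampled = gp_model.predict(
    rng_key_predict, X_test, noiseless=True, use_cholesky=True
)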
tests/test_gp.py (15 changes: 15 additions & 0 deletions)
@@ -169,6 +169,21 @@ def test_get_mvn_posterior_noiseless():
    assert_array_equal(mean1, mean2)
    assert onp.count_nonzero(cov1 - cov2) > 0

+def test_get_mvn_posterior_cholesky():
+    X, y = get_dummy_data(unsqueeze=True)
+    X_test, _ = get_dummy_data(unsqueeze=True)
+    params = {"k_length": jnp.array([1.0]),
+              "k_scale": jnp.array(1.0),
+              "noise": jnp.array(0.1)}
+    m = ExactGP(1, 'RBF')
+    m.X_train = X
+    m.y_train = y
+    mean, cov = m.get_mvn_posterior(X_test, params, use_cholesky=True)
+    assert isinstance(mean, jnp.ndarray)
+    assert isinstance(cov, jnp.ndarray)
+    assert_equal(mean.shape, (X_test.shape[0],))
+    assert_equal(cov.shape, (X_test.shape[0], X_test.shape[0]))
+

def test_single_sample_prediction():
    rng_key = get_keys()[0]
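A possible follow-up test (not part of this commit) that the Cholesky branch reproduces the inverse-based result; it assumes the same module-level imports and the get_dummy_data helper used by the tests above, and the tolerances are arbitrary:

from numpy.testing import assert_allclose

def test_get_mvn_posterior_cholesky_matches_inverse():
    # Sketch only: mirrors the setup of test_get_mvn_posterior_cholesky above
    X, y = get_dummy_data(unsqueeze=True)
    X_test, _ = get_dummy_data(unsqueeze=True)
    params = {"k_length": jnp.array([1.0]),
              "k_scale": jnp.array(1.0),
              "noise": jnp.array(0.1)}
    m = ExactGP(1, 'RBF')
    m.X_train = X
    m.y_train = y
    mean_chol, cov_chol = m.get_mvn_posterior(X_test, params, use_cholesky=True)
    mean_inv, cov_inv = m.get_mvn_posterior(X_test, params, use_cholesky=False)
    # Both code paths should agree up to floating-point error
    assert_allclose(mean_chol, mean_inv, rtol=1e-4, atol=1e-5)
    assert_allclose(cov_chol, cov_inv, rtol=1e-4, atol=1e-5)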
