Improve code readability
ziatdinovmax committed Oct 7, 2023
1 parent 7e2a2dd commit 72e6b94
Showing 1 changed file with 9 additions and 11 deletions.
gpax/models/hskgp.py: 9 additions & 11 deletions
@@ -19,8 +19,6 @@
 
 kernel_fn_type = Callable[[jnp.ndarray, jnp.ndarray, Dict[str, jnp.ndarray], jnp.ndarray], jnp.ndarray]
 
-clear_cache = jax._src.dispatch.xla_primitive_callable.cache_clear
-
 
 class VarNoiseGP(ExactGP):
     """
@@ -70,8 +68,8 @@ def __init__(
         self.noise_lengthscale_prior_dist = noise_lengthscale_prior_dist
 
     def model(self, X: jnp.ndarray, y: jnp.ndarray = None, **kwargs: float) -> None:
-        """GP probabilistic model with inputs X and targets y"""
-        # Initialize mean function at zeros
+        """Heteroskedastic GP probabilistic model with inputs X and targets y"""
+        # Initialize mean functions at zeros
         f_loc = jnp.zeros(X.shape[0])
         noise_f_loc = jnp.zeros(X.shape[0])
 
@@ -102,10 +100,10 @@ def model(self, X: jnp.ndarray, y: jnp.ndarray = None, **kwargs: float) -> None:
         if self.mean_fn_prior is not None:
             args += [self.mean_fn_prior()]
             f_loc += self.mean_fn(*args).squeeze()
-        # compute main kernel
+        # Compute main kernel
         k = self.kernel(X, X, kernel_params, 0, **kwargs)
-        # Sample y according to the standard Gaussian process formula. Note that instead of adding a fixed noise term to the kernel,
-        # we exponentiate the log_var samples to get the variance at each data point
+        # Sample y according to the standard Gaussian process formula. Note that instead of adding a fixed noise term to the kernel as in regular GP,
+        # we exponentiate and take a diagonal of the log_var samples to get the variance at each data point and add that variance to the main kernel
         numpyro.sample(
             "y",
             dist.MultivariateNormal(loc=f_loc, covariance_matrix=k+jnp.diag(jnp.exp(points_log_var))),
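
The rewritten comment above describes the key heteroskedastic step: instead of a single scalar noise added as noise * I, each training point carries its own variance exp(log_var_i) on the diagonal of the covariance. A minimal standalone sketch of that construction (illustrative only, not part of the commit; the function name and arguments are ad hoc):

import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist

def heteroskedastic_likelihood(f_loc, k, points_log_var, y=None):
    # A regular GP would use k + noise * jnp.eye(n) with one scalar noise;
    # here point i contributes its own variance exp(points_log_var[i]).
    cov = k + jnp.diag(jnp.exp(points_log_var))
    return numpyro.sample(
        "y", dist.MultivariateNormal(loc=f_loc, covariance_matrix=cov), obs=y
    )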
@@ -129,18 +127,18 @@ def get_mvn_posterior(
     ) -> Tuple[jnp.ndarray, jnp.ndarray]:
         """
         Returns parameters (mean and cov) of multivariate normal posterior
-        for a single sample of GP parameters
+        for a single sample of heteroskedastic GP parameters
         """
         # Main GP part
         y_residual = self.y_train.copy()
         if self.mean_fn is not None:
             args = [self.X_train, params] if self.mean_fn_prior else [self.X_train]
             y_residual -= self.mean_fn(*args).squeeze()
-        # compute main kernel matrices for train and test data
+        # Compute main kernel matrices for train and test data
         k_pp = self.kernel(X_new, X_new, params, 0, **kwargs)
         k_pX = self.kernel(X_new, self.X_train, params, jitter=0.0)
         k_XX = self.kernel(self.X_train, self.X_train, params, 0, **kwargs)
-        # compute the predictive covariance and mean
+        # Compute the predictive covariance and mean
         K_xx_inv = jnp.linalg.inv(k_XX)
         cov = k_pp - jnp.matmul(k_pX, jnp.matmul(K_xx_inv, jnp.transpose(k_pX)))
         mean = jnp.matmul(k_pX, jnp.matmul(K_xx_inv, y_residual))
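
An aside, not part of this commit: the predictive update above forms K_xx_inv with jnp.linalg.inv, which is simple but less numerically stable than solving against a Cholesky factor. A sketch of the equivalent Cholesky-based computation, reusing the variable names from the diff (a hypothetical helper, not in gpax):

import jax.numpy as jnp
from jax.scipy.linalg import cho_factor, cho_solve

def mvn_posterior_chol(k_pp, k_pX, k_XX, y_residual):
    # Factor k_XX once, then solve against it instead of forming the inverse.
    c_and_lower = cho_factor(k_XX)
    mean = k_pX @ cho_solve(c_and_lower, y_residual)    # k_pX @ k_XX^{-1} @ y
    cov = k_pp - k_pX @ cho_solve(c_and_lower, k_pX.T)  # k_pp - k_pX k_XX^{-1} k_Xp
    return mean, cov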
@@ -174,7 +172,7 @@ def get_mvn_posterior(
         return mean, cov + jnp.diag(predicted_noise_variance)
 
     def get_data_var_samples(self):
-        """Returns inferred (training) data variance samples"""
+        """Returns samples with inferred (training) data variance - aka noise"""
         samples = self.mcmc.get_samples()
         log_var = samples["log_var"]
         if self.noise_mean_fn is not None:
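For orientation, a hypothetical end-to-end usage sketch of VarNoiseGP: only get_data_var_samples appears in this diff, so the constructor arguments and the fit call below are assumptions based on gpax's usual ExactGP workflow.

import numpy as np
import jax.random as jra
from gpax.models.hskgp import VarNoiseGP

# Toy data whose noise level grows with x.
X = np.linspace(0, 1, 50)[:, None]
y = np.sin(8 * X.squeeze()) + np.random.normal(0, 0.05 + 0.3 * X.squeeze(), 50)

model = VarNoiseGP(input_dim=1, kernel="RBF")     # constructor signature assumed
model.fit(jra.PRNGKey(0), X, y)                   # MCMC inference, as in ExactGP
noise_var_samples = model.get_data_var_samples()  # per-point noise variance samples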
