diff --git a/gpax/models/hskgp.py b/gpax/models/hskgp.py
index 0c4be5a..ec92016 100644
--- a/gpax/models/hskgp.py
+++ b/gpax/models/hskgp.py
@@ -19,8 +19,6 @@
 kernel_fn_type = Callable[[jnp.ndarray, jnp.ndarray, Dict[str, jnp.ndarray], jnp.ndarray], jnp.ndarray]
 
-clear_cache = jax._src.dispatch.xla_primitive_callable.cache_clear
-
 
 
 class VarNoiseGP(ExactGP):
     """
@@ -70,8 +68,8 @@ def __init__(
         self.noise_lengthscale_prior_dist = noise_lengthscale_prior_dist
 
     def model(self, X: jnp.ndarray, y: jnp.ndarray = None, **kwargs: float) -> None:
-        """GP probabilistic model with inputs X and targets y"""
-        # Initialize mean function at zeros
+        """Heteroskedastic GP probabilistic model with inputs X and targets y"""
+        # Initialize mean functions at zeros
         f_loc = jnp.zeros(X.shape[0])
         noise_f_loc = jnp.zeros(X.shape[0])
 
@@ -102,10 +100,10 @@ def model(self, X: jnp.ndarray, y: jnp.ndarray = None, **kwargs: float) -> None:
             if self.mean_fn_prior is not None:
                 args += [self.mean_fn_prior()]
             f_loc += self.mean_fn(*args).squeeze()
-        # compute main kernel
+        # Compute main kernel
         k = self.kernel(X, X, kernel_params, 0, **kwargs)
-        # Sample y according to the standard Gaussian process formula. Note that instead of adding a fixed noise term to the kernel,
-        # we exponentiate the log_var samples to get the variance at each data point
+        # Sample y according to the standard Gaussian process formula. Note that instead of adding a fixed noise term to the kernel as in a regular GP,
+        # we exponentiate the log_var samples to get the variance at each data point and add it to the diagonal of the main kernel
         numpyro.sample(
             "y",
             dist.MultivariateNormal(loc=f_loc, covariance_matrix=k+jnp.diag(jnp.exp(points_log_var))),
@@ -129,18 +127,18 @@ def get_mvn_posterior(
     ) -> Tuple[jnp.ndarray, jnp.ndarray]:
         """
         Returns parameters (mean and cov) of multivariate normal posterior
-        for a single sample of GP parameters
+        for a single sample of heteroskedastic GP parameters
         """
         # Main GP part
         y_residual = self.y_train.copy()
         if self.mean_fn is not None:
             args = [self.X_train, params] if self.mean_fn_prior else [self.X_train]
             y_residual -= self.mean_fn(*args).squeeze()
-        # compute main kernel matrices for train and test data
+        # Compute main kernel matrices for train and test data
         k_pp = self.kernel(X_new, X_new, params, 0, **kwargs)
         k_pX = self.kernel(X_new, self.X_train, params, jitter=0.0)
         k_XX = self.kernel(self.X_train, self.X_train, params, 0, **kwargs)
-        # compute the predictive covariance and mean
+        # Compute the predictive covariance and mean
         K_xx_inv = jnp.linalg.inv(k_XX)
         cov = k_pp - jnp.matmul(k_pX, jnp.matmul(K_xx_inv, jnp.transpose(k_pX)))
         mean = jnp.matmul(k_pX, jnp.matmul(K_xx_inv, y_residual))
@@ -174,7 +172,7 @@ def get_mvn_posterior(
         return mean, cov + jnp.diag(predicted_noise_variance)
 
     def get_data_var_samples(self):
-        """Returns inferred (training) data variance samples"""
+        """Returns samples of the inferred (training) data variance, aka noise"""
         samples = self.mcmc.get_samples()
         log_var = samples["log_var"]
         if self.noise_mean_fn is not None:
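
For context, below the patch: a minimal, self-contained NumPyro sketch of the heteroskedastic likelihood the updated comment describes, where a second latent GP models the log-variance, which is then exponentiated and placed on the diagonal of the main kernel. This is not gpax's actual implementation (VarNoiseGP also supports noise mean functions, custom priors, and kernel choices); the names rbf_kernel and hsk_gp_model are illustrative.

import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist


def rbf_kernel(X, Z, length, scale):
    # Squared-exponential kernel between rows of X (n, d) and Z (m, d)
    d2 = ((X[:, None, :] - Z[None, :, :]) ** 2).sum(-1)
    return scale * jnp.exp(-0.5 * d2 / length**2)


def hsk_gp_model(X, y=None):
    n = X.shape[0]
    # Hyperparameters of the main GP and of the latent noise GP
    length = numpyro.sample("k_length", dist.LogNormal(0.0, 1.0))
    scale = numpyro.sample("k_scale", dist.LogNormal(0.0, 1.0))
    noise_length = numpyro.sample("k_noise_length", dist.LogNormal(0.0, 1.0))
    noise_scale = numpyro.sample("k_noise_scale", dist.LogNormal(0.0, 1.0))
    # The log-variance is itself a GP draw, so the noise level varies smoothly over X
    k_noise = rbf_kernel(X, X, noise_length, noise_scale) + 1e-6 * jnp.eye(n)
    log_var = numpyro.sample(
        "log_var", dist.MultivariateNormal(jnp.zeros(n), covariance_matrix=k_noise)
    )
    # Instead of a fixed noise term, exponentiate log_var and add it to the kernel diagonal
    k = rbf_kernel(X, X, length, scale) + 1e-6 * jnp.eye(n)
    numpyro.sample(
        "y",
        dist.MultivariateNormal(jnp.zeros(n), covariance_matrix=k + jnp.diag(jnp.exp(log_var))),
        obs=y,
    )

Running this with numpyro.infer.MCMC(numpyro.infer.NUTS(hsk_gp_model), ...) yields posterior log_var draws analogous to what get_data_var_samples exponentiates and returns in the patched class.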