diff --git a/gpax/models/vi_mtdkl.py b/gpax/models/vi_mtdkl.py
index c84a9f3..8b59b79 100644
--- a/gpax/models/vi_mtdkl.py
+++ b/gpax/models/vi_mtdkl.py
@@ -195,12 +195,11 @@ def _sample_kernel_params(self):
 
     @partial(jit, static_argnames='self')
     def get_mvn_posterior(self,
-                          X_train: jnp.ndarray,
-                          y_train: jnp.ndarray,
                           X_new: jnp.ndarray,
                           nn_params: Dict[str, jnp.ndarray],
                           k_params: Dict[str, jnp.ndarray],
                           noiseless: bool = False,
+                          y_residual: jnp.ndarray = None,
                           **kwargs
                           ) -> Tuple[jnp.ndarray, jnp.ndarray]:
         """
@@ -208,17 +207,19 @@ def get_mvn_posterior(self,
         (mean and cov, where cov.diagonal() is 'uncertainty')
         given a single set of DKL parameters
         """
+        if y_residual is None:
+            y_residual = self.y_train
         noise = k_params.pop("noise")
         noise_p = noise * (1 - jnp.array(noiseless, int))
         # embed data into the latent space
         z_train = self.nn_module.apply(
             nn_params, jax.random.PRNGKey(0),
-            X_train if self.shared_input else X_train[:, :-1])
+            self.X_train if self.shared_input else self.X_train[:, :-1])
         z_test = self.nn_module.apply(
             nn_params, jax.random.PRNGKey(0),
             X_new if self.shared_input else X_new[:, :-1])
         if not self.shared_input:
-            z_train = jnp.column_stack((z_train, X_train[:, -1]))
+            z_train = jnp.column_stack((z_train, self.X_train[:, -1]))
             z_test = jnp.column_stack((z_test, X_new[:, -1]))
         # compute kernel matrices for train and test data
         k_pp = self.kernel(z_test, z_test, k_params, noise_p, **kwargs)
@@ -227,5 +228,5 @@
         # compute the predictive covariance and mean
         K_xx_inv = jnp.linalg.inv(k_XX)
         cov = k_pp - jnp.matmul(k_pX, jnp.matmul(K_xx_inv, jnp.transpose(k_pX)))
-        mean = jnp.matmul(k_pX, jnp.matmul(K_xx_inv, y_train))
+        mean = jnp.matmul(k_pX, jnp.matmul(K_xx_inv, y_residual))
         return mean, cov