Skip to content

Commit

Permalink
Update the args sequence in get_mvn_posterior
Browse files Browse the repository at this point in the history
  • Loading branch information
ziatdinovmax committed Aug 15, 2023
1 parent 06fa254 commit 2cea61e
Showing 1 changed file with 6 additions and 5 deletions.
11 changes: 6 additions & 5 deletions gpax/models/vi_mtdkl.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,30 +195,31 @@ def _sample_kernel_params(self):

@partial(jit, static_argnames='self')
def get_mvn_posterior(self,
X_train: jnp.ndarray,
y_train: jnp.ndarray,
X_new: jnp.ndarray,
nn_params: Dict[str, jnp.ndarray],
k_params: Dict[str, jnp.ndarray],
noiseless: bool = False,
y_residual: jnp.ndarray = None,
**kwargs
) -> Tuple[jnp.ndarray, jnp.ndarray]:
"""
Returns predictive mean and covariance at new points
(mean and cov, where cov.diagonal() is 'uncertainty')
given a single set of DKL parameters
"""
if y_residual is None:
y_residual = self.y_train
noise = k_params.pop("noise")
noise_p = noise * (1 - jnp.array(noiseless, int))
# embed data into the latent space
z_train = self.nn_module.apply(
nn_params, jax.random.PRNGKey(0),
X_train if self.shared_input else X_train[:, :-1])
self.X_train if self.shared_input else self.X_train[:, :-1])
z_test = self.nn_module.apply(
nn_params, jax.random.PRNGKey(0),
X_new if self.shared_input else X_new[:, :-1])
if not self.shared_input:
z_train = jnp.column_stack((z_train, X_train[:, -1]))
z_train = jnp.column_stack((z_train, self.X_train[:, -1]))
z_test = jnp.column_stack((z_test, X_new[:, -1]))
# compute kernel matrices for train and test data
k_pp = self.kernel(z_test, z_test, k_params, noise_p, **kwargs)
Expand All @@ -227,5 +228,5 @@ def get_mvn_posterior(self,
# compute the predictive covariance and mean
K_xx_inv = jnp.linalg.inv(k_XX)
cov = k_pp - jnp.matmul(k_pX, jnp.matmul(K_xx_inv, jnp.transpose(k_pX)))
mean = jnp.matmul(k_pX, jnp.matmul(K_xx_inv, y_train))
mean = jnp.matmul(k_pX, jnp.matmul(K_xx_inv, y_residual))
return mean, cov

0 comments on commit 2cea61e

Please sign in to comment.