diff --git a/gpax/models/vi_mtdkl.py b/gpax/models/vi_mtdkl.py
index b1cdc0a..c84a9f3 100644
--- a/gpax/models/vi_mtdkl.py
+++ b/gpax/models/vi_mtdkl.py
@@ -217,7 +217,7 @@ def get_mvn_posterior(self,
         z_test = self.nn_module.apply(
             nn_params, jax.random.PRNGKey(0),
             X_new if self.shared_input else X_new[:, :-1])
-        if self.shared_input:
+        if not self.shared_input:
             z_train = jnp.column_stack((z_train, X_train[:, -1]))
             z_test = jnp.column_stack((z_test, X_new[:, -1]))
         # compute kernel matrices for train and test data
diff --git a/gpax/models/vidkl.py b/gpax/models/vidkl.py
index eb38bd9..b166c91 100644
--- a/gpax/models/vidkl.py
+++ b/gpax/models/vidkl.py
@@ -8,7 +8,7 @@
 """
 
 from functools import partial
-from typing import Callable, Dict, Optional, Tuple
+from typing import Callable, Dict, Optional, Tuple, Union
 
 import jax
 import jax.numpy as jnp
@@ -30,7 +30,7 @@ class viDKL(ExactGP):
 
     Args:
         input_dim:
-            Number of input dimensions
+            Input feature dimensions (e.g. 64*64 for a stack of flattened 64-by-64 images)
         z_dim:
             Latent space dimensionality (defaults to 2)
         kernel:
@@ -66,7 +66,7 @@ class viDKL(ExactGP):
     >>> y_mean, y_var = dkl.predict(key2, X_new)
     """
 
-    def __init__(self, input_dim: int, z_dim: int = 2, kernel: str = 'RBF',
+    def __init__(self, input_dim: Union[int, Tuple[int]], z_dim: int = 2, kernel: str = 'RBF',
                  kernel_prior: Optional[Callable[[], Dict[str, jnp.ndarray]]] = None,
                  nn: Optional[Callable[[jnp.ndarray], jnp.ndarray]] = None,
                  latent_prior: Optional[Callable[[jnp.ndarray], Dict[str, jnp.ndarray]]] = None,
@@ -229,6 +229,7 @@ def sample_from_posterior(self, rng_key: jnp.ndarray,
 
     def predict_in_batches(self, rng_key: jnp.ndarray,
                            X_new: jnp.ndarray, batch_size: int = 100,
+                           params: Optional[Dict[str, jnp.ndarray]] = None,
                            noiseless: bool = False,
                            **kwargs
                            ) -> Tuple[jnp.ndarray, jnp.ndarray]:
@@ -237,10 +238,11 @@ def predict_in_batches(self, rng_key: jnp.ndarray,
         by splitting the input array into chunks ("batches") and running
         self.predict on each of them one-by-one to avoid a memory overflow
         """
-        predict_fn = lambda xi: self.predict(rng_key, xi, noiseless=noiseless, **kwargs)
+        predict_fn = lambda xi: self.predict(
+            rng_key, xi, params, noiseless=noiseless, **kwargs)
         cat_dim = 1 if self.X_train.ndim == len(self.data_dim) + 2 else 0
         mean, var = self._predict_in_batches(
-            rng_key, X_new, batch_size, cat_dim, predict_fn=predict_fn)
+            rng_key, X_new, batch_size, cat_dim, params, predict_fn=predict_fn)
         mean = jnp.concatenate(mean, cat_dim)
         var = jnp.concatenate(var, cat_dim)
         return mean, var
@@ -319,7 +321,7 @@ def single_fit_predict(key):
             self.fit(key, X, y, num_steps, step_size,
                      print_summary, progress_bar, **kwargs)
             mean, var = self.predict_in_batches(
-                key, X_new, batch_size, noiseless, **kwargs)
+                key, X_new, batch_size, None, noiseless, **kwargs)
             return mean, var
 
         if n_models > 1 and ensemble_method not in ["vectorized", "parallel"]:
diff --git a/gpax/models/vigp.py b/gpax/models/vigp.py
index f68781c..7ddd968 100644
--- a/gpax/models/vigp.py
+++ b/gpax/models/vigp.py
@@ -138,13 +138,13 @@ def predict_in_batches(self, rng_key: jnp.ndarray,
         """
         predict_fn = lambda xi: self.predict(
             rng_key, xi, samples, noiseless, **kwargs)
-        y_pred, y_sampled = self._predict_in_batches(
+        y_pred, y_var = self._predict_in_batches(
             rng_key, X_new, batch_size, 0, samples,
             predict_fn=predict_fn, noiseless=noiseless,
             device=device, **kwargs)
         y_pred = jnp.concatenate(y_pred, 0)
-        y_sampled = jnp.concatenate(y_sampled, -1)
-        return y_pred, y_sampled
+        y_var = jnp.concatenate(y_var, 0)
+        return y_pred, y_var
 
     def predict(self, rng_key: jnp.ndarray, X_new: jnp.ndarray,
                 samples: Optional[Dict[str, jnp.ndarray]] = None,
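
Usage sketch for the new `params` argument in `viDKL.predict_in_batches` (the toy data shapes, variable names, and keyword values below are illustrative assumptions, not part of the patch). Note that because `params` was inserted before `noiseless`, existing positional callers had to be updated, which is why `single_fit_predict` now passes an explicit `None`:

    import jax
    import jax.numpy as jnp
    from gpax.models.vidkl import viDKL

    key1, key2 = jax.random.split(jax.random.PRNGKey(0))
    # hypothetical toy data: 500 flattened 64x64 "images" with scalar targets
    X = jax.random.normal(key1, (500, 64 * 64))
    y = jnp.sin(X.sum(-1))
    X_new = jax.random.normal(key2, (1000, 64 * 64))

    dkl = viDKL(input_dim=64 * 64, z_dim=2, kernel='RBF')
    dkl.fit(key1, X, y, num_steps=100)

    # params=None falls back to the parameters learned during fit();
    # prediction runs batch-by-batch to avoid a memory overflow on large X_new
    mean, var = dkl.predict_in_batches(key2, X_new, batch_size=200, params=None)

Note also that `vigp.predict_in_batches` now returns the predictive mean and variance, both concatenated along axis 0, rather than `(y_pred, y_sampled)`, matching the `(mean, var)` convention used by `viDKL`.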