Commit
Merge pull request #32 from ziatdinovmax/mpdkl
Minor bug fixes for vidkl and vigp
ziatdinovmax authored Aug 7, 2023
2 parents 63bcebd + 5fdf5c3 commit 728b59b
Showing 3 changed files with 12 additions and 10 deletions.
gpax/models/vi_mtdkl.py (2 changes: 1 addition & 1 deletion)
@@ -217,7 +217,7 @@ def get_mvn_posterior(self,
         z_test = self.nn_module.apply(
             nn_params, jax.random.PRNGKey(0),
             X_new if self.shared_input else X_new[:, :-1])
-        if self.shared_input:
+        if not self.shared_input:
             z_train = jnp.column_stack((z_train, X_train[:, -1]))
             z_test = jnp.column_stack((z_test, X_new[:, -1]))
         # compute kernel matrices for train and test data
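Why the flipped condition is the fix: in the multi-task model, the trailing input column carries the task index precisely when inputs are *not* shared across tasks, so that column must be re-attached to the learned embedding before the kernel sees it. A minimal sketch of the corrected logic with made-up shapes and a stand-in for the neural-net embedding (names here are illustrative, not gpax internals):

```python
import jax.numpy as jnp

# Hypothetical inputs: 5 rows, 3 features plus a trailing task-index column.
X_new = jnp.hstack([jnp.ones((5, 3)), jnp.arange(5.0).reshape(-1, 1)])
shared_input = False  # each task has its own input rows

# Stand-in for nn_module.apply: embed only the feature columns
# when inputs are not shared across tasks.
z_test = (X_new if shared_input else X_new[:, :-1]) @ jnp.ones((3, 2))

# Corrected condition: re-attach the task index in the *non*-shared case,
# so the multi-task kernel still knows which task each row belongs to.
if not shared_input:
    z_test = jnp.column_stack((z_test, X_new[:, -1]))

print(z_test.shape)  # (5, 3): two latent dimensions + one task-index column
```

Under the old condition the task index was dropped exactly when it was needed, so all rows looked like they belonged to one task.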
gpax/models/vidkl.py (14 changes: 8 additions & 6 deletions)
@@ -8,7 +8,7 @@
 """
 
 from functools import partial
-from typing import Callable, Dict, Optional, Tuple
+from typing import Callable, Dict, Optional, Tuple, Union
 
 import jax
 import jax.numpy as jnp
@@ -30,7 +30,7 @@ class viDKL(ExactGP):
     Args:
         input_dim:
-            Number of input dimensions
+            Input features dimensions (e.g. 64*64 for a stack of flattened 64-by-64 images)
         z_dim:
             Latent space dimensionality (defaults to 2)
         kernel:
@@ -66,7 +66,7 @@ class viDKL(ExactGP):
     >>> y_mean, y_var = dkl.predict(key2, X_new)
     """
 
-    def __init__(self, input_dim: int, z_dim: int = 2, kernel: str = 'RBF',
+    def __init__(self, input_dim: Union[int, Tuple[int]], z_dim: int = 2, kernel: str = 'RBF',
                  kernel_prior: Optional[Callable[[], Dict[str, jnp.ndarray]]] = None,
                  nn: Optional[Callable[[jnp.ndarray], jnp.ndarray]] = None,
                  latent_prior: Optional[Callable[[jnp.ndarray], Dict[str, jnp.ndarray]]] = None,
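The widened annotation lets input_dim be either a single integer (flattened features) or a shape tuple (structured inputs handled by a custom feature extractor). A hedged sketch of the two call styles, using only the constructor arguments visible in this diff; my_conv_net is a hypothetical user-supplied module, not part of gpax:

```python
# Flattened inputs: 4096 features per sample (e.g. 64*64 images unrolled).
dkl_flat = viDKL(input_dim=64 * 64, z_dim=2, kernel='RBF')

# Structured inputs: pass the per-sample shape as a tuple, which a custom
# convolutional `nn` module (my_conv_net, hypothetical) can consume directly.
dkl_conv = viDKL(input_dim=(64, 64), z_dim=2, kernel='RBF', nn=my_conv_net)
```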
@@ -229,6 +229,7 @@ def sample_from_posterior(self, rng_key: jnp.ndarray,
 
     def predict_in_batches(self, rng_key: jnp.ndarray,
                            X_new: jnp.ndarray, batch_size: int = 100,
+                           params: Optional[Dict[str, jnp.ndarray]] = None,
                            noiseless: bool = False,
                            **kwargs
                            ) -> Tuple[jnp.ndarray, jnp.ndarray]:
@@ -237,10 +238,11 @@ def predict_in_batches(self, rng_key: jnp.ndarray,
         by splitting the input array into chunks ("batches") and running
         self.predict on each of them one-by-one to avoid a memory overflow
         """
-        predict_fn = lambda xi: self.predict(rng_key, xi, noiseless=noiseless, **kwargs)
+        predict_fn = lambda xi: self.predict(
+            rng_key, xi, params, noiseless=noiseless, **kwargs)
         cat_dim = 1 if self.X_train.ndim == len(self.data_dim) + 2 else 0
         mean, var = self._predict_in_batches(
-            rng_key, X_new, batch_size, cat_dim, predict_fn=predict_fn)
+            rng_key, X_new, batch_size, cat_dim, params, predict_fn=predict_fn)
         mean = jnp.concatenate(mean, cat_dim)
         var = jnp.concatenate(var, cat_dim)
         return mean, var
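The new params argument lets a caller thread already-inferred variational parameters through to self.predict instead of relying on state stored on the model; the batching itself just chunks X_new and stitches the per-chunk outputs back together. A generic, self-contained sketch of that chunk-and-concatenate pattern (not the gpax internals):

```python
import jax.numpy as jnp

def predict_in_batches(predict_fn, X_new, batch_size=100, cat_dim=0):
    """Apply predict_fn to chunks of X_new and stitch the results back."""
    means, variances = [], []
    for i in range(0, X_new.shape[0], batch_size):
        m, v = predict_fn(X_new[i:i + batch_size])
        means.append(m)
        variances.append(v)
    return jnp.concatenate(means, cat_dim), jnp.concatenate(variances, cat_dim)

# Toy predictor returning (mean, variance) for each chunk:
toy_fn = lambda x: (x.sum(-1), jnp.ones(x.shape[0]))
mean, var = predict_in_batches(toy_fn, jnp.ones((250, 4)), batch_size=100)
print(mean.shape, var.shape)  # (250,) (250,)
```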
@@ -319,7 +321,7 @@ def single_fit_predict(key):
             self.fit(key, X, y, num_steps, step_size,
                      print_summary, progress_bar, **kwargs)
             mean, var = self.predict_in_batches(
-                key, X_new, batch_size, noiseless, **kwargs)
+                key, X_new, batch_size, None, noiseless, **kwargs)
             return mean, var
 
         if n_models > 1 and ensemble_method not in ["vectorized", "parallel"]:
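The single_fit_predict change above is the call-site counterpart of the new parameter: with the old positional call, noiseless would silently land in the freshly inserted params slot. A toy stand-in (not the gpax method) showing that failure mode and two safe call styles:

```python
def predict_in_batches(rng_key, X_new, batch_size=100,
                       params=None, noiseless=False):
    # Toy stand-in mirroring the new signature; returns what it received.
    return params, noiseless

# Old-style positional call: True was meant as `noiseless` but fills `params`.
assert predict_in_batches(0, None, 100, True) == (True, False)

# Fixed positional call, threading None through the new `params` slot:
assert predict_in_batches(0, None, 100, None, True) == (None, True)

# Keyword arguments avoid this class of bug outright:
assert predict_in_batches(0, None, 100, noiseless=True) == (None, True)
```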
gpax/models/vigp.py (6 changes: 3 additions & 3 deletions)
@@ -138,13 +138,13 @@ def predict_in_batches(self, rng_key: jnp.ndarray,
         """
         predict_fn = lambda xi: self.predict(
             rng_key, xi, samples, noiseless, **kwargs)
-        y_pred, y_sampled = self._predict_in_batches(
+        y_pred, y_var = self._predict_in_batches(
             rng_key, X_new, batch_size, 0, samples,
             predict_fn=predict_fn, noiseless=noiseless,
             device=device, **kwargs)
         y_pred = jnp.concatenate(y_pred, 0)
-        y_sampled = jnp.concatenate(y_sampled, -1)
-        return y_pred, y_sampled
+        y_var = jnp.concatenate(y_var, 0)
+        return y_pred, y_var
 
     def predict(self, rng_key: jnp.ndarray, X_new: jnp.ndarray,
                 samples: Optional[Dict[str, jnp.ndarray]] = None,
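The rename from y_sampled to y_var also changes the concatenation axis, which is the substantive part of the fix: posterior samples would be stacked as (num_samples, batch_points) and stitched along the last axis, whereas predictive variances are one value per point and must be stitched along axis 0, like the mean. A small shape illustration under those assumed layouts:

```python
import jax.numpy as jnp

# Per-batch posterior *samples*, shape (num_samples, batch_points):
# the old code stitched these along the last axis.
sampled = [jnp.zeros((50, 3)), jnp.zeros((50, 2))]
print(jnp.concatenate(sampled, -1).shape)  # (50, 5)

# Per-batch predictive *variances*, one value per point, shape (batch_points,):
# the fixed code stitches these along axis 0, matching the predictive mean.
variances = [jnp.ones(3), jnp.ones(2)]
print(jnp.concatenate(variances, 0).shape)  # (5,)
```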
