
Streamline vmap compute in vi DKL
ziatdinovmax committed Aug 15, 2023
1 parent 13ccae5 commit 39c392e
Showing 1 changed file with 35 additions and 25 deletions: gpax/models/vidkl.py
@@ -44,12 +44,12 @@ class viDKL(ExactGP):
            Optional prior over the latent space (NN embedding); uses none by default
        guide:
            Auto-guide option, use 'delta' (default) or 'normal'
        **kwargs:
            Optional custom prior distributions over observational noise (noise_dist_prior)
            and kernel lengthscale (lengthscale_prior_dist)

    Examples:

        vi-DKL with image patches as inputs and a 1-d vector as targets
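
To make the docstring's options concrete, here is a minimal usage sketch. The constructor argument names `input_dim` and `z_dim` and the exact prior-distribution type are illustrative assumptions, not something shown in this diff:

    import numpyro
    import gpax

    # Hypothetical configuration: a 'normal' auto-guide plus a custom
    # lengthscale prior passed through **kwargs (argument names assumed)
    dkl = gpax.viDKL(
        input_dim=(64, 64),   # assumed: shape of a single input (e.g., an image patch)
        z_dim=2,              # assumed: dimensionality of the NN embedding
        guide='normal',       # 'delta' (default) or 'normal', per the docstring
        lengthscale_prior_dist=numpyro.distributions.Gamma(2.0, 5.0),
    )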
@@ -159,22 +159,27 @@ def fit(self, rng_key: jnp.array, X: jnp.ndarray, y: jnp.ndarray,
            print_summary: print summary at the end of sampling
            progress_bar: show progress bar (works only for scalar outputs)
        """
-        def _single_fit(x_i, y_i):
-            return self.single_fit(
-                rng_key, x_i, y_i, num_steps, step_size,
-                print_summary=False, progress_bar=False, **kwargs)
-
        self.X_train = X
        self.y_train = y

-        if X.ndim == len(self.data_dim) + 2:
-            self.nn_params, self.kernel_params, self.loss = jax.vmap(_single_fit)(X, y)
+        if y.ndim == 2:  # y has shape (channels, samples), so we use vmap to fit all channels in parallel
+
+            # Define a wrapper to use with vmap
+            def _single_fit(yi):
+                return self.single_fit(
+                    rng_key, X, yi, num_steps, step_size,
+                    print_summary=False, progress_bar=False, **kwargs)
+            # Apply vmap to the wrapper function
+            vfit = jax.vmap(_single_fit)
+            self.nn_params, self.kernel_params, self.loss = vfit(y)
            # Poor man version of the progress bar
            if progress_bar:
                avg_bw = [num_steps - num_steps // 20, num_steps]
                print("init loss: {}, final loss (avg) [{}-{}]: {} ".format(
                    self.loss[0].mean(), avg_bw[0], avg_bw[1],
                    self.loss.mean(0)[avg_bw[0]:avg_bw[1]].mean().round(4)))
-        else:
+
+        else:  # no channel dimension, so we use the regular single_fit
            self.nn_params, self.kernel_params, self.loss = self.single_fit(
                rng_key, X, y, num_steps, step_size, print_summary, progress_bar
            )
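
The new branch vmaps the fit over the channel axis of `y` only, closing over the shared `X` instead of batching it as the old `_single_fit(x_i, y_i)` did. A self-contained sketch of that pattern, with a toy closed-form `single_fit` standing in for the library's ELBO-based one:

    import jax
    import jax.numpy as jnp

    def single_fit(x, y_i):
        # Toy stand-in: closed-form least squares for one output channel
        w = jnp.linalg.solve(x.T @ x, x.T @ y_i)
        loss = jnp.mean((x @ w - y_i) ** 2)
        return w, loss

    X = jax.random.normal(jax.random.PRNGKey(0), (100, 3))   # shared inputs
    Y = jax.random.normal(jax.random.PRNGKey(1), (5, 100))   # 5 output channels

    # vmap only over the channel axis of Y; X is closed over, not re-batched
    w_all, loss_all = jax.vmap(lambda y_i: single_fit(X, y_i))(Y)
    print(w_all.shape, loss_all.shape)   # (5, 3) (5,)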
@@ -183,24 +188,25 @@ def _single_fit(x_i, y_i):

    @partial(jit, static_argnames='self')
    def get_mvn_posterior(self,
-                          X_train: jnp.ndarray,
-                          y_train: jnp.ndarray,
                          X_new: jnp.ndarray,
                          nn_params: Dict[str, jnp.ndarray],
                          k_params: Dict[str, jnp.ndarray],
                          noiseless: bool = False,
+                          y_residual: jnp.ndarray = None,
                          **kwargs
                          ) -> Tuple[jnp.ndarray, jnp.ndarray]:
        """
        Returns predictive mean and covariance at new points
        (mean and cov, where cov.diagonal() is 'uncertainty')
        given a single set of DKL parameters
        """
+        if y_residual is None:
+            y_residual = self.y_train
        noise = k_params.pop("noise")
        noise_p = noise * (1 - jnp.array(noiseless, int))
        # embed data into the latent space
        z_train = self.nn_module.apply(
-            nn_params, jax.random.PRNGKey(0), X_train)
+            nn_params, jax.random.PRNGKey(0), self.X_train)
        z_test = self.nn_module.apply(
            nn_params, jax.random.PRNGKey(0), X_new)
        # compute kernel matrices for train and test data
@@ -210,7 +216,7 @@ def get_mvn_posterior(self,
        # compute the predictive covariance and mean
        K_xx_inv = jnp.linalg.inv(k_XX)
        cov = k_pp - jnp.matmul(k_pX, jnp.matmul(K_xx_inv, jnp.transpose(k_pX)))
-        mean = jnp.matmul(k_pX, jnp.matmul(K_xx_inv, y_train))
+        mean = jnp.matmul(k_pX, jnp.matmul(K_xx_inv, y_residual))
        return mean, cov
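
For reference, these are the standard GP predictive equations: mean = k_pX K_XX^{-1} y and cov = k_pp - k_pX K_XX^{-1} k_pX^T. The explicit `jnp.linalg.inv` can be numerically fragile for ill-conditioned kernels; a Cholesky-based variant computes the same quantities with triangular solves. A sketch, assuming `k_XX` already carries the noise term on its diagonal as in the code above (this is not part of the commit):

    import jax.numpy as jnp
    import jax.scipy.linalg as jsp

    def mvn_posterior_chol(k_XX, k_pX, k_pp, y_residual):
        # k_XX = L @ L.T; two Cholesky solves replace the explicit inverse
        L = jnp.linalg.cholesky(k_XX)
        alpha = jsp.cho_solve((L, True), y_residual)   # K_XX^{-1} y
        v = jsp.cho_solve((L, True), k_pX.T)           # K_XX^{-1} k_pX.T
        mean = k_pX @ alpha
        cov = k_pp - k_pX @ v
        return mean, cov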

def sample_from_posterior(self, rng_key: jnp.ndarray,
@@ -266,23 +272,27 @@ def predict(self, rng_key: jnp.ndarray, X_new: jnp.ndarray,
        Returns:
            Predictive mean and variance
        """
-        def single_predict(x_train_i, y_train_i, x_new_i, nnpar_i, kpar_i):
-            mean, cov = self.get_mvn_posterior(
-                x_train_i, y_train_i, x_new_i, nnpar_i, kpar_i, noiseless, **kwargs)
-            return mean, cov.diagonal()
-
        if params is None:
            nn_params = self.nn_params
            k_params = self.kernel_params
        else:
            nn_params, k_params = params

-        p_args = (self.X_train, self.y_train, X_new, nn_params, k_params)
-        if self.X_train.ndim == len(self.data_dim) + 2:
-            mean, var = jax.vmap(single_predict)(*p_args)
-        else:
-            mean, var = single_predict(*p_args)
+        if self.y_train.ndim == 2:  # y has shape (channels, samples)
+            # Define a wrapper to use with vmap
+            def _get_mvn_posterior(nn_params_i, k_params_i, yi):
+                mean, cov = self.get_mvn_posterior(
+                    X_new, nn_params_i, k_params_i, noiseless, yi)
+                return mean, cov.diagonal()
+            # vectorize posterior predictive computation over the y's channel dimension
+            predictive = jax.vmap(_get_mvn_posterior)
+            mean, var = predictive(nn_params, k_params, self.y_train)
+
+        else:  # y has shape (samples,)
+            # Standard prediction
+            mean, cov = self.get_mvn_posterior(
+                X_new, nn_params, k_params, noiseless)
+            var = cov.diagonal()

        return mean, var
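
The channel-wise branch works because `jax.vmap` batches over the leading axis of every leaf in a pytree: after the vmapped fit, each leaf of `nn_params` and `kernel_params` carries a leading channel axis that lines up with the rows of `self.y_train`. A toy demonstration of that behavior (not library code):

    import jax
    import jax.numpy as jnp

    # Per-channel parameters stored as a pytree with a leading channel axis
    params = {'w': jnp.ones((5, 3)), 'b': jnp.zeros((5,))}
    Y = jnp.ones((5, 100))   # 5 channels of targets

    def predict_one(p, y_i):
        # Inside the vmapped call, p['w'] has shape (3,) and p['b'] is a scalar
        return p['w'].sum() + p['b'] + y_i.mean()

    out = jax.vmap(predict_one)(params, Y)
    print(out.shape)   # (5,)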

