Skip to content

Commit

Permalink
fix sample_from_posterior
Browse files Browse the repository at this point in the history
in response to earlier change in the arg sequence
  • Loading branch information
ziatdinovmax committed Aug 15, 2023
1 parent 2cea61e commit 6516994
Showing 1 changed file with 3 additions and 2 deletions.
5 changes: 3 additions & 2 deletions gpax/models/vidkl.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,9 +227,10 @@ def sample_from_posterior(self, rng_key: jnp.ndarray,
"""
Samples from the DKL posterior at X_new points
"""
if self.y_train.ndim > 1:
raise NotImplementedError("Currently does not support a multi-channel regime")
y_mean, K = self.get_mvn_posterior(
self.X_train, self.y_train, X_new,
self.nn_params, self.kernel_params, noiseless, **kwargs)
X_new, self.nn_params, self.kernel_params, noiseless, **kwargs)
y_sampled = dist.MultivariateNormal(y_mean, K).sample(rng_key, sample_shape=(n,))
return y_mean, y_sampled

Expand Down

0 comments on commit 6516994

Please sign in to comment.