Skip to content

Commit

Permalink
Add 'get_samples' for viDKL
Browse files Browse the repository at this point in the history
  • Loading branch information
ziatdinovmax committed Aug 16, 2023
1 parent 6516994 commit 484f70c
Showing 1 changed file with 4 additions and 0 deletions.
4 changes: 4 additions & 0 deletions gpax/models/vidkl.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,6 +233,10 @@ def sample_from_posterior(self, rng_key: jnp.ndarray,
X_new, self.nn_params, self.kernel_params, noiseless, **kwargs)
y_sampled = dist.MultivariateNormal(y_mean, K).sample(rng_key, sample_shape=(n,))
return y_mean, y_sampled

def get_samples(self) -> Tuple[Dict['str', jnp.ndarray]]:
"""Returns a tuple with trained NN weights and kernel hyperparameters"""
return self.nn_params, self.kernel_params

def predict_in_batches(self, rng_key: jnp.ndarray,
X_new: jnp.ndarray, batch_size: int = 100,
Expand Down

0 comments on commit 484f70c

Please sign in to comment.