From 65169949eac47e071c4db9fef09d12cd353aa397 Mon Sep 17 00:00:00 2001 From: Maxim Ziatdinov <34245227+ziatdinovmax@users.noreply.github.com> Date: Tue, 15 Aug 2023 18:21:56 -0400 Subject: [PATCH] fix sample_from_posterior in response to earlier change in the arg sequence --- gpax/models/vidkl.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/gpax/models/vidkl.py b/gpax/models/vidkl.py index f810226..54517ee 100644 --- a/gpax/models/vidkl.py +++ b/gpax/models/vidkl.py @@ -227,9 +227,10 @@ def sample_from_posterior(self, rng_key: jnp.ndarray, """ Samples from the DKL posterior at X_new points """ + if self.y_train.ndim > 1: + raise NotImplementedError("Currently does not support a multi-channel regime") y_mean, K = self.get_mvn_posterior( - self.X_train, self.y_train, X_new, - self.nn_params, self.kernel_params, noiseless, **kwargs) + X_new, self.nn_params, self.kernel_params, noiseless, **kwargs) y_sampled = dist.MultivariateNormal(y_mean, K).sample(rng_key, sample_shape=(n,)) return y_mean, y_sampled