From cf52a6d5f4d40e7eb1c860c2045ff0d7ad3f12a9 Mon Sep 17 00:00:00 2001 From: Maxim Ziatdinov <34245227+ziatdinovmax@users.noreply.github.com> Date: Sun, 20 Aug 2023 19:11:22 -0400 Subject: [PATCH] check input dimensionality --- gpax/acquisition/batch_acquisition.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/gpax/acquisition/batch_acquisition.py b/gpax/acquisition/batch_acquisition.py index 2413c9d..57f1bf3 100644 --- a/gpax/acquisition/batch_acquisition.py +++ b/gpax/acquisition/batch_acquisition.py @@ -31,6 +31,8 @@ def compute_batch_acquisition(acquisition_type: Callable, """ if model.mcmc is None: raise ValueError("The model needs to be fully Bayesian") + + X = X[:, None] if X.ndim < 2 else X samples = random_sample_dict(model.get_samples(), subsample_size) f = vmap(acquisition_type, in_axes=(None, None, 0) + (None,) * len(acq_args))