diff --git a/gpax/acquisition/batch_acquisition.py b/gpax/acquisition/batch_acquisition.py index 2413c9d..57f1bf3 100644 --- a/gpax/acquisition/batch_acquisition.py +++ b/gpax/acquisition/batch_acquisition.py @@ -31,6 +31,8 @@ def compute_batch_acquisition(acquisition_type: Callable, """ if model.mcmc is None: raise ValueError("The model needs to be fully Bayesian") + + X = X[:, None] if X.ndim < 2 else X samples = random_sample_dict(model.get_samples(), subsample_size) f = vmap(acquisition_type, in_axes=(None, None, 0) + (None,) * len(acq_args))