From e5e37d786e0dfc6faed28bdfb174e102f130a030 Mon Sep 17 00:00:00 2001 From: Maxim Ziatdinov <34245227+ziatdinovmax@users.noreply.github.com> Date: Thu, 24 Aug 2023 00:44:25 -0400 Subject: [PATCH] Exclude "noise" from vmap in multi-task kernel (it might make sense to do it for other kernels as well) --- gpax/kernels/mtkernels.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/gpax/kernels/mtkernels.py b/gpax/kernels/mtkernels.py index 9765a2b..654d189 100644 --- a/gpax/kernels/mtkernels.py +++ b/gpax/kernels/mtkernels.py @@ -17,6 +17,9 @@ kernel_fn_type = Callable[[jnp.ndarray, jnp.ndarray, Dict[str, jnp.ndarray], jnp.ndarray], jnp.ndarray] +# Helper function to generate in_axes dictionary +get_in_axes = lambda data: ({key: 0 if key != "noise" else None for key in data.keys()},) + def index_kernel(indices1, indices2, params): r""" @@ -223,7 +226,8 @@ def LCMKernel(base_kernel, shared_input_space=True, num_tasks=None, **kwargs1): multi_kernel = MultitaskKernel(base_kernel, **kwargs1) def lcm_kernel(X, Z, params, noise=0, **kwargs2): - k = vmap(lambda p: multi_kernel(X, Z, p, noise, **kwargs2))(params) + axes = get_in_axes(params) + k = vmap(lambda p: multi_kernel(X, Z, p, noise, **kwargs2), in_axes=axes)(params) return k.sum(0) - return lcm_kernel \ No newline at end of file + return lcm_kernel