diff --git a/gpax/kernels/mtkernels.py b/gpax/kernels/mtkernels.py index 9765a2b..654d189 100644 --- a/gpax/kernels/mtkernels.py +++ b/gpax/kernels/mtkernels.py @@ -17,6 +17,9 @@ kernel_fn_type = Callable[[jnp.ndarray, jnp.ndarray, Dict[str, jnp.ndarray], jnp.ndarray], jnp.ndarray] +# Helper function to generate in_axes dictionary +get_in_axes = lambda data: ({key: 0 if key != "noise" else None for key in data.keys()},) + def index_kernel(indices1, indices2, params): r""" @@ -223,7 +226,8 @@ def LCMKernel(base_kernel, shared_input_space=True, num_tasks=None, **kwargs1): multi_kernel = MultitaskKernel(base_kernel, **kwargs1) def lcm_kernel(X, Z, params, noise=0, **kwargs2): - k = vmap(lambda p: multi_kernel(X, Z, p, noise, **kwargs2))(params) + axes = get_in_axes(params) + k = vmap(lambda p: multi_kernel(X, Z, p, noise, **kwargs2), in_axes=axes)(params) return k.sum(0) - return lcm_kernel \ No newline at end of file + return lcm_kernel