Skip to content

Commit

Permalink
Exclude "noise" from vmap in multi-task kernel
Browse files Browse the repository at this point in the history
(it might make sense to do it for other kernels as well)
  • Loading branch information
ziatdinovmax committed Aug 24, 2023
1 parent 218c0a7 commit e5e37d7
Showing 1 changed file with 6 additions and 2 deletions.
8 changes: 6 additions & 2 deletions gpax/kernels/mtkernels.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,9 @@

kernel_fn_type = Callable[[jnp.ndarray, jnp.ndarray, Dict[str, jnp.ndarray], jnp.ndarray], jnp.ndarray]

# Helper function to generate in_axes dictionary
get_in_axes = lambda data: ({key: 0 if key != "noise" else None for key in data.keys()},)


def index_kernel(indices1, indices2, params):
r"""
Expand Down Expand Up @@ -223,7 +226,8 @@ def LCMKernel(base_kernel, shared_input_space=True, num_tasks=None, **kwargs1):
multi_kernel = MultitaskKernel(base_kernel, **kwargs1)

def lcm_kernel(X, Z, params, noise=0, **kwargs2):
k = vmap(lambda p: multi_kernel(X, Z, p, noise, **kwargs2))(params)
axes = get_in_axes(params)
k = vmap(lambda p: multi_kernel(X, Z, p, noise, **kwargs2), in_axes=axes)(params)
return k.sum(0)

return lcm_kernel
return lcm_kernel

0 comments on commit e5e37d7

Please sign in to comment.