From 851975e20780ea4e3356ddac2b4d66826786999b Mon Sep 17 00:00:00 2001 From: Maxim Ziatdinov <34245227+ziatdinovmax@users.noreply.github.com> Date: Tue, 10 Oct 2023 17:53:43 -0400 Subject: [PATCH] utility to convert regular kernel to noise kernel --- gpax/utils/priors.py | 37 ++++++++++++++++++++++++++++++++++++- 1 file changed, 36 insertions(+), 1 deletion(-) diff --git a/gpax/utils/priors.py b/gpax/utils/priors.py index b24d8ca..9e5e356 100644 --- a/gpax/utils/priors.py +++ b/gpax/utils/priors.py @@ -16,6 +16,8 @@ import jax import jax.numpy as jnp +from ..kernels.kernels import square_scaled_distance, add_jitter, _sqrt + def place_normal_prior(param_name: str, loc: float = 0.0, scale: float = 1.0): """ @@ -227,12 +229,45 @@ def set_kernel_fn(func: Callable, transformed_code += custom_code - local_namespace = {"jit": jax.jit} + local_namespace = {"jit": jax.jit} exec(transformed_code, globals(), local_namespace) return local_namespace[func.__name__] +def set_noise_kernel_fn(func: Callable) -> Callable: + """ + Modifies the GPax kernel function to append "_noise" after "k" in dictionary keys it accesses. + + Args: + func (Callable): Original function. + + Returns: + Callable: Modified function. + """ + + # Get the source code of the function + source = inspect.getsource(func) + + # Split the source into decorators, definition, and body + decorators_and_def, body = source.split("\n", 1) + + # Replace all occurrences of params["k with params["k_noise in the body + modified_body = re.sub(r'params\["k', 'params["k_noise', body) + + # Combine decorators, definition, and modified body + modified_source = f"{decorators_and_def}\n{modified_body}" + + # Define local namespace including the jit decorator + local_namespace = {"jit": jax.jit} + + # Execute the modified source to redefine the function in the provided namespace + exec(modified_source, globals(), local_namespace) + + # Return the modified function + return local_namespace[func.__name__] + + def auto_priors(func: Callable, params_begin_with: int, dist_type: str = 'normal', loc: float = 0.0, scale: float = 1.0) -> Callable: """ Generates a function that, when invoked, samples from normal or log-normal distributions