Skip to content

Commit

Permalink
utility to convert regular kernel to noise kernel
Browse files Browse the repository at this point in the history
  • Loading branch information
ziatdinovmax committed Oct 10, 2023
1 parent 823d99e commit 851975e
Showing 1 changed file with 36 additions and 1 deletion.
37 changes: 36 additions & 1 deletion gpax/utils/priors.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@
import jax
import jax.numpy as jnp

from ..kernels.kernels import square_scaled_distance, add_jitter, _sqrt


def place_normal_prior(param_name: str, loc: float = 0.0, scale: float = 1.0):
"""
Expand Down Expand Up @@ -227,12 +229,45 @@ def set_kernel_fn(func: Callable,

transformed_code += custom_code

local_namespace = {"jit": jax.jit}
local_namespace = {"jit": jax.jit}
exec(transformed_code, globals(), local_namespace)

return local_namespace[func.__name__]


def set_noise_kernel_fn(func: Callable) -> Callable:
"""
Modifies the GPax kernel function to append "_noise" after "k" in dictionary keys it accesses.
Args:
func (Callable): Original function.
Returns:
Callable: Modified function.
"""

# Get the source code of the function
source = inspect.getsource(func)

# Split the source into decorators, definition, and body
decorators_and_def, body = source.split("\n", 1)

# Replace all occurrences of params["k with params["k_noise in the body
modified_body = re.sub(r'params\["k', 'params["k_noise', body)

# Combine decorators, definition, and modified body
modified_source = f"{decorators_and_def}\n{modified_body}"

# Define local namespace including the jit decorator
local_namespace = {"jit": jax.jit}

# Execute the modified source to redefine the function in the provided namespace
exec(modified_source, globals(), local_namespace)

# Return the modified function
return local_namespace[func.__name__]


def auto_priors(func: Callable, params_begin_with: int, dist_type: str = 'normal', loc: float = 0.0, scale: float = 1.0) -> Callable:
"""
Generates a function that, when invoked, samples from normal or log-normal distributions
Expand Down

0 comments on commit 851975e

Please sign in to comment.