Skip to content

Commit

Permalink
Ensure 'jit' is accessible in generated functions
Browse files Browse the repository at this point in the history
  • Loading branch information
ziatdinovmax committed Oct 10, 2023
1 parent 0fbd1ad commit fcb5e26
Showing 1 changed file with 6 additions and 5 deletions.
11 changes: 6 additions & 5 deletions gpax/utils/priors.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from typing import Union, Dict, Type, List, Callable, Optional

import numpyro
import jax
import jax.numpy as jnp


Expand Down Expand Up @@ -173,7 +174,7 @@ def set_fn(func: Callable) -> Callable:

def set_kernel_fn(func: Callable,
independent_vars: List[str] = ["X", "Z"],
decorators: Optional[List[str]] = ["@jit"],
jit_decorator: bool = True,
docstring: Optional[str] = None) -> Callable:
"""
Transforms the given kernel function to use a params dictionary for its hyperparameters.
Expand All @@ -182,7 +183,7 @@ def set_kernel_fn(func: Callable,
Args:
func (Callable): The kernel function to be transformed.
independent_vars (List[str], optional): List of independent variable names in the function. Defaults to ["X", "Z"].
decorators (Optional[List[str]], optional): List of decorators to be applied to the transformed function. Defaults to ["@jit"].
jit_decorator (bool, optional): @jax.jit decorator to be applied to the transformed function. Defaults to True.
docstring (Optional[str], optional): Docstring to be added to the transformed function. Defaults to None.
Returns:
Expand All @@ -195,8 +196,8 @@ def set_kernel_fn(func: Callable,
params_names.remove(var)

transformed_code = ""
for decorator in decorators:
transformed_code += decorator + "\n"
if jit_decorator:
transformed_code += "@jit" + "\n"

additional_args = "noise: int = 0, jitter: float = 1e-6, **kwargs"
transformed_code += f"def {func.__name__}({', '.join(independent_vars)}, params: Dict[str, jnp.ndarray], {additional_args}):\n"
Expand Down Expand Up @@ -226,7 +227,7 @@ def set_kernel_fn(func: Callable,

transformed_code += custom_code

local_namespace = {}
local_namespace = {"jit": jax.jit}
exec(transformed_code, globals(), local_namespace)

return local_namespace[func.__name__]
Expand Down

0 comments on commit fcb5e26

Please sign in to comment.