From fcb5e268ee8a51eff0f1467077161e6a7e3e2db7 Mon Sep 17 00:00:00 2001 From: Maxim Ziatdinov <34245227+ziatdinovmax@users.noreply.github.com> Date: Tue, 10 Oct 2023 12:57:15 -0400 Subject: [PATCH] Ensure 'jit' is accessible in generated functions --- gpax/utils/priors.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/gpax/utils/priors.py b/gpax/utils/priors.py index 3671430..b24d8ca 100644 --- a/gpax/utils/priors.py +++ b/gpax/utils/priors.py @@ -13,6 +13,7 @@ from typing import Union, Dict, Type, List, Callable, Optional import numpyro +import jax import jax.numpy as jnp @@ -173,7 +174,7 @@ def set_fn(func: Callable) -> Callable: def set_kernel_fn(func: Callable, independent_vars: List[str] = ["X", "Z"], - decorators: Optional[List[str]] = ["@jit"], + jit_decorator: bool = True, docstring: Optional[str] = None) -> Callable: """ Transforms the given kernel function to use a params dictionary for its hyperparameters. @@ -182,7 +183,7 @@ def set_kernel_fn(func: Callable, Args: func (Callable): The kernel function to be transformed. independent_vars (List[str], optional): List of independent variable names in the function. Defaults to ["X", "Z"]. - decorators (Optional[List[str]], optional): List of decorators to be applied to the transformed function. Defaults to ["@jit"]. + jit_decorator (bool, optional): @jax.jit decorator to be applied to the transformed function. Defaults to True. docstring (Optional[str], optional): Docstring to be added to the transformed function. Defaults to None. Returns: @@ -195,8 +196,8 @@ def set_kernel_fn(func: Callable, params_names.remove(var) transformed_code = "" - for decorator in decorators: - transformed_code += decorator + "\n" + if jit_decorator: + transformed_code += "@jit" + "\n" additional_args = "noise: int = 0, jitter: float = 1e-6, **kwargs" transformed_code += f"def {func.__name__}({', '.join(independent_vars)}, params: Dict[str, jnp.ndarray], {additional_args}):\n" @@ -226,7 +227,7 @@ def set_kernel_fn(func: Callable, transformed_code += custom_code - local_namespace = {} + local_namespace = {"jit": jax.jit} exec(transformed_code, globals(), local_namespace) return local_namespace[func.__name__]