Skip to content

Commit

Permalink
Move priors out of utils
Browse files Browse the repository at this point in the history
  • Loading branch information
ziatdinovmax committed Mar 18, 2024
1 parent 64bbec2 commit 8669f14
Show file tree
Hide file tree
Showing 8 changed files with 236 additions and 207 deletions.
3 changes: 2 additions & 1 deletion gpax/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from .__version__ import version as __version__
from . import priors
from . import utils
from . import kernels
from . import acquisition
Expand All @@ -7,6 +8,6 @@
vi_iBNN, viDKL, viGP, sPM, viMTDKL, VarNoiseGP, UIGP,
MeasuredNoiseGP, viSparseGP, BNN)

__all__ = ["utils", "kernels", "mtkernels", "acquisition", "ExactGP", "vExactGP", "DKL",
__all__ = ["priors", "utils", "kernels", "mtkernels", "acquisition", "ExactGP", "vExactGP", "DKL",
"viDKL", "iBNN", "vi_iBNN", "MultiTaskGP", "viMTDKL", "viGP", "sPM", "VarNoiseGP",
"UIGP", "MeasuredNoiseGP", "viSparseGP", "CoregGP", "BNN", "sample_next", "__version__"]
1 change: 1 addition & 0 deletions gpax/priors/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from .priors import *
139 changes: 2 additions & 137 deletions gpax/utils/priors.py → gpax/priors/priors.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,20 +4,16 @@
Utility functions for setting priors
Created by Maxim Ziatdinov (email: maxim.ziatdinov@ai4microscopy.com)
Created by Maxim Ziatdinov (email: maxim.ziatdinov@gmail.com)
"""

import inspect
import re

from typing import Union, Dict, Type, List, Callable, Optional
from typing import Union, Dict, Type, Callable

import numpyro
import jax
import jax.numpy as jnp

from ..kernels.kernels import square_scaled_distance, add_jitter, _sqrt


def place_normal_prior(param_name: str, loc: float = 0.0, scale: float = 1.0):
"""
Expand Down Expand Up @@ -183,137 +179,6 @@ def uniform_dist(low: float = None,
return numpyro.distributions.Uniform(low, high)


def set_fn(func: Callable) -> Callable:
"""
Transforms the given deterministic function to use a params dictionary
for its parameters, excluding the first one (assumed to be the dependent variable).
Args:
- func (Callable): The deterministic function to be transformed.
Returns:
- Callable: The transformed function where parameters are accessed
from a `params` dictionary.
"""
# Extract parameter names excluding the first one (assumed to be the dependent variable)
params_names = list(inspect.signature(func).parameters.keys())[1:]

# Create the transformed function definition
transformed_code = f"def {func.__name__}(x, params):\n"

# Retrieve the source code of the function and indent it to be a valid function body
source = inspect.getsource(func).split("\n", 1)[1]
source = " " + source.replace("\n", "\n ")

# Replace each parameter name with its dictionary lookup using regex
for name in params_names:
source = re.sub(rf'\b{name}\b', f'params["{name}"]', source)

# Combine to get the full source
transformed_code += source

# Define the transformed function in the local namespace
local_namespace = {}
exec(transformed_code, globals(), local_namespace)

# Return the transformed function
return local_namespace[func.__name__]


def set_kernel_fn(func: Callable,
independent_vars: List[str] = ["X", "Z"],
jit_decorator: bool = True,
docstring: Optional[str] = None) -> Callable:
"""
Transforms the given kernel function to use a params dictionary for its hyperparameters.
The resultant function will always add jitter before returning the computed kernel.
Args:
func (Callable): The kernel function to be transformed.
independent_vars (List[str], optional): List of independent variable names in the function. Defaults to ["X", "Z"].
jit_decorator (bool, optional): @jax.jit decorator to be applied to the transformed function. Defaults to True.
docstring (Optional[str], optional): Docstring to be added to the transformed function. Defaults to None.
Returns:
Callable: The transformed kernel function where hyperparameters are accessed from a `params` dictionary.
"""

# Extract parameter names excluding the independent variables
params_names = [k for k, v in inspect.signature(func).parameters.items() if v.default == v.empty]
for var in independent_vars:
params_names.remove(var)

transformed_code = ""
if jit_decorator:
transformed_code += "@jit" + "\n"

additional_args = "noise: int = 0, jitter: float = 1e-6, **kwargs"
transformed_code += f"def {func.__name__}({', '.join(independent_vars)}, params: Dict[str, jnp.ndarray], {additional_args}):\n"

if docstring:
transformed_code += ' """' + docstring + '"""\n'

source = inspect.getsource(func).split("\n", 1)[1]
lines = source.split("\n")

for idx, line in enumerate(lines):
# Convert all parameter names to their dictionary lookup throughout the function body
for name in params_names:
lines[idx] = re.sub(rf'\b{name}\b', f'params["{name}"]', lines[idx])

# Combine lines back and then split again by return
modified_source = '\n'.join(lines)
pre_return, return_statement = modified_source.split('return', 1)

# Append custom jitter code
custom_code = f" {pre_return.strip()}\n k = {return_statement.strip()}\n"
custom_code += """
if X.shape == Z.shape:
k += (noise + jitter) * jnp.eye(X.shape[0])
return k
"""

transformed_code += custom_code

local_namespace = {"jit": jax.jit}
exec(transformed_code, globals(), local_namespace)

return local_namespace[func.__name__]


def _set_noise_kernel_fn(func: Callable) -> Callable:
"""
Modifies the GPax kernel function to append "_noise" after "k" in dictionary keys it accesses.
Args:
func (Callable): Original function.
Returns:
Callable: Modified function.
"""

# Get the source code of the function
source = inspect.getsource(func)

# Split the source into decorators, definition, and body
decorators_and_def, body = source.split("\n", 1)

# Replace all occurrences of params["k with params["k_noise in the body
modified_body = re.sub(r'params\["k', 'params["k_noise', body)

# Combine decorators, definition, and modified body
modified_source = f"{decorators_and_def}\n{modified_body}"

# Define local namespace including the jit decorator
local_namespace = {"jit": jax.jit}

# Execute the modified source to redefine the function in the provided namespace
exec(modified_source, globals(), local_namespace)

# Return the modified function
return local_namespace[func.__name__]


def auto_priors(func: Callable, params_begin_with: int, dist_type: str = 'normal', loc: float = 0.0, scale: float = 1.0) -> Callable:
"""
Generates a function that, when invoked, samples from normal or log-normal distributions
Expand Down
4 changes: 2 additions & 2 deletions gpax/utils/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
from .utils import *
from .priors import *
from .priors import _set_noise_kernel_fn
from .fn import *
from .fn import _set_noise_kernel_fn
148 changes: 148 additions & 0 deletions gpax/utils/fn.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,148 @@
"""
fn.py
=====
Utilities for setting up custom mean and kernel functions
Created by Maxim Ziatdinov (email: [email protected])
"""

import inspect
import re

from typing import List, Callable, Optional

import jax

from ..kernels.kernels import square_scaled_distance, add_jitter, _sqrt


def set_fn(func: Callable) -> Callable:
"""
Transforms the given deterministic function to use a params dictionary
for its parameters, excluding the first one (assumed to be the dependent variable).
Args:
- func (Callable): The deterministic function to be transformed.
Returns:
- Callable: The transformed function where parameters are accessed
from a `params` dictionary.
"""
# Extract parameter names excluding the first one (assumed to be the dependent variable)
params_names = list(inspect.signature(func).parameters.keys())[1:]

# Create the transformed function definition
transformed_code = f"def {func.__name__}(x, params):\n"

# Retrieve the source code of the function and indent it to be a valid function body
source = inspect.getsource(func).split("\n", 1)[1]
source = " " + source.replace("\n", "\n ")

# Replace each parameter name with its dictionary lookup using regex
for name in params_names:
source = re.sub(rf'\b{name}\b', f'params["{name}"]', source)

# Combine to get the full source
transformed_code += source

# Define the transformed function in the local namespace
local_namespace = {}
exec(transformed_code, globals(), local_namespace)

# Return the transformed function
return local_namespace[func.__name__]


def set_kernel_fn(func: Callable,
independent_vars: List[str] = ["X", "Z"],
jit_decorator: bool = True,
docstring: Optional[str] = None) -> Callable:
"""
Transforms the given kernel function to use a params dictionary for its hyperparameters.
The resultant function will always add jitter before returning the computed kernel.
Args:
func (Callable): The kernel function to be transformed.
independent_vars (List[str], optional): List of independent variable names in the function. Defaults to ["X", "Z"].
jit_decorator (bool, optional): @jax.jit decorator to be applied to the transformed function. Defaults to True.
docstring (Optional[str], optional): Docstring to be added to the transformed function. Defaults to None.
Returns:
Callable: The transformed kernel function where hyperparameters are accessed from a `params` dictionary.
"""

# Extract parameter names excluding the independent variables
params_names = [k for k, v in inspect.signature(func).parameters.items() if v.default == v.empty]
for var in independent_vars:
params_names.remove(var)

transformed_code = ""
if jit_decorator:
transformed_code += "@jit" + "\n"

additional_args = "noise: int = 0, jitter: float = 1e-6, **kwargs"
transformed_code += f"def {func.__name__}({', '.join(independent_vars)}, params: Dict[str, jnp.ndarray], {additional_args}):\n"

if docstring:
transformed_code += ' """' + docstring + '"""\n'

source = inspect.getsource(func).split("\n", 1)[1]
lines = source.split("\n")

for idx, line in enumerate(lines):
# Convert all parameter names to their dictionary lookup throughout the function body
for name in params_names:
lines[idx] = re.sub(rf'\b{name}\b', f'params["{name}"]', lines[idx])

# Combine lines back and then split again by return
modified_source = '\n'.join(lines)
pre_return, return_statement = modified_source.split('return', 1)

# Append custom jitter code
custom_code = f" {pre_return.strip()}\n k = {return_statement.strip()}\n"
custom_code += """
if X.shape == Z.shape:
k += (noise + jitter) * jnp.eye(X.shape[0])
return k
"""

transformed_code += custom_code

local_namespace = {"jit": jax.jit}
exec(transformed_code, globals(), local_namespace)

return local_namespace[func.__name__]


def _set_noise_kernel_fn(func: Callable) -> Callable:
"""
Modifies the GPax kernel function to append "_noise" after "k" in dictionary keys it accesses.
Args:
func (Callable): Original function.
Returns:
Callable: Modified function.
"""

# Get the source code of the function
source = inspect.getsource(func)

# Split the source into decorators, definition, and body
decorators_and_def, body = source.split("\n", 1)

# Replace all occurrences of params["k with params["k_noise in the body
modified_body = re.sub(r'params\["k', 'params["k_noise', body)

# Combine decorators, definition, and modified body
modified_source = f"{decorators_and_def}\n{modified_body}"

# Define local namespace including the jit decorator
local_namespace = {"jit": jax.jit}

# Execute the modified source to redefine the function in the provided namespace
exec(modified_source, globals(), local_namespace)

# Return the modified function
return local_namespace[func.__name__]
Loading

0 comments on commit 8669f14

Please sign in to comment.