Skip to content

Commit

Permalink
Add option for using a custom noise kernel prior
Browse files Browse the repository at this point in the history
  • Loading branch information
ziatdinovmax committed Oct 10, 2023
1 parent 851975e commit cb1fc81
Showing 1 changed file with 6 additions and 1 deletion.
7 changes: 6 additions & 1 deletion gpax/models/hskgp.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ def __init__(
mean_fn: Optional[Callable[[jnp.ndarray, Dict[str, jnp.ndarray]], jnp.ndarray]] = None,
kernel_prior: Optional[Callable[[], Dict[str, jnp.ndarray]]] = None,
mean_fn_prior: Optional[Callable[[], Dict[str, jnp.ndarray]]] = None,
noise_kernel_prior: Optional[Callable[[], Dict[str, jnp.ndarray]]] = None,
lengthscale_prior_dist: Optional[dist.Distribution] = None,
noise_mean_fn: Optional[Callable[[jnp.ndarray, Dict[str, jnp.ndarray]], jnp.ndarray]] = None,
noise_mean_fn_prior: Optional[Callable[[], Dict[str, jnp.ndarray]]] = None,
Expand All @@ -68,6 +69,7 @@ def __init__(

self.noise_mean_fn = noise_mean_fn
self.noise_mean_fn_prior = noise_mean_fn_prior
self.noise_kernel_prior = noise_kernel_prior
self.noise_lengthscale_prior_dist = noise_lengthscale_prior_dist

def model(self, X: jnp.ndarray, y: jnp.ndarray = None, **kwargs: float) -> None:
Expand All @@ -77,7 +79,10 @@ def model(self, X: jnp.ndarray, y: jnp.ndarray = None, **kwargs: float) -> None:
noise_f_loc = jnp.zeros(X.shape[0])

# Sample noise kernel parameters
noise_kernel_params = self._sample_noise_kernel_params()
if self.noise_kernel_prior:
noise_kernel_params = self.noise_kernel_prior()
else:
noise_kernel_params = self._sample_noise_kernel_params()
# Add noise prior mean function (if any)
if self.noise_mean_fn is not None:
args = [X]
Expand Down

0 comments on commit cb1fc81

Please sign in to comment.