Skip to content

Commit

Permalink
Update hskgp.py
Browse files Browse the repository at this point in the history
  • Loading branch information
ziatdinovmax authored Nov 16, 2023
1 parent 3e3e893 commit dffaaab
Showing 1 changed file with 16 additions and 36 deletions.
52 changes: 16 additions & 36 deletions gpax/models/hskgp.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,7 @@ class VarNoiseGP(ExactGP):
Optional priors over noise mean function
noise_lengthscale_prior_dist:
Optional custom prior distribution over noise kernel lengthscale. Defaults to LogNormal(0, 1).
Examples:
Examples:
Use two different kernels with default priors for main and noise processes
Expand Down Expand Up @@ -162,7 +161,7 @@ def _sample_noise_kernel_params(self) -> Dict[str, jnp.ndarray]:
return {"k_noise_length": noise_length, "k_noise_scale": noise_scale}

def get_mvn_posterior(
self, X_new: jnp.ndarray, params: Dict[str, jnp.ndarray], *arg, **kwargs
self, X_new: jnp.ndarray, params: Dict[str, jnp.ndarray], *args, **kwargs
) -> Tuple[jnp.ndarray, jnp.ndarray]:
"""
Returns parameters (mean and cov) of multivariate normal posterior
Expand All @@ -186,17 +185,6 @@ def get_mvn_posterior(
mean += self.mean_fn(*args).squeeze()

# Noise GP part
predicted_log_var = self.get_noise_mvn_posterior(X_new, params, **kwargs)
predicted_var = jnp.exp(predicted_log_var)

# Return the main GP's predictive mean and combined (main + noise) covariance matrix
return mean, cov + jnp.diag(predicted_var)

def get_noise_mvn_posterior(self,
X_new: jnp.ndarray,
params: Dict[str, jnp.ndarray],
**kwargs
) -> Tuple[jnp.ndarray, jnp.ndarray]:
# Compute noise kernel matrices
k_pX_noise = self.noise_kernel(X_new, self.X_train, params, jitter=0.0)
k_XX_noise = self.noise_kernel(self.X_train, self.X_train, params, 0, **kwargs)
Expand All @@ -210,30 +198,22 @@ def get_noise_mvn_posterior(self,
if self.noise_mean_fn is not None:
args = [X_new, params] if self.noise_mean_fn_prior else [X_new]
predicted_log_var += jnp.log(self.noise_mean_fn(*args)).squeeze()

#k_pp_noise = self.noise_kernel(X_new, X_new, params, 0, **kwargs)
#cov_noise = k_pp_noise - jnp.matmul(k_pX_noise, jnp.matmul(K_xx_noise_inv, jnp.transpose(k_pX_noise)))

return predicted_log_var

# def get_data_var_samples(self):
# """Returns samples with inferred (training) data variance - aka noise"""
# samples = self.mcmc.get_samples()
# log_var = samples["log_var"]
# if self.noise_mean_fn is not None:
# if self.noise_mean_fn_prior is not None:
# mean_ = jax.vmap(self.noise_mean_fn, in_axes=(None, 0))(self.X_train.squeeze(), samples)
# else:
# mean_ = self.noise_mean_fn(self.X_train.squeeze())
# log_var += jnp.log(mean_)
# return jnp.exp(samples["log_var"])

def get_data_var_samples(self, **kwargs):
predicted_noise_variance = jnp.exp(predicted_log_var)

# Return the main GP's predictive mean and combined (main + noise) covariance matrix
return mean, cov + jnp.diag(predicted_noise_variance)

def get_data_var_samples(self):
"""Returns samples with inferred (training) data variance - aka noise"""
predict_ = lambda p: self.get_noise_mvn_posterior(self.X_train, p, **kwargs)
samples = self.mcmc.get_samples()
predicted_log_var = jax.vmap(predict_)(samples)
return jnp.exp(predicted_log_var)
log_var = samples["log_var"]
if self.noise_mean_fn is not None:
if self.noise_mean_fn_prior is not None:
mean_ = jax.vmap(self.noise_mean_fn, in_axes=(None, 0))(self.X_train.squeeze(), samples)
else:
mean_ = self.noise_mean_fn(self.X_train.squeeze())
log_var += jnp.log(mean_)
return jnp.exp(samples["log_var"])

def _print_summary(self):
samples = self.get_samples(1)
Expand Down

0 comments on commit dffaaab

Please sign in to comment.