Run black on gp.py as a test
matthewcarbone committed Sep 14, 2023
1 parent e694266 commit 6f33776
Showing 1 changed file with 113 additions and 95 deletions.
208 changes: 113 additions & 95 deletions gpax/models/gp.py
@@ -23,7 +23,7 @@
from ..kernels import get_kernel
from ..utils import split_in_batches

- kernel_fn_type = Callable[[jnp.ndarray, jnp.ndarray, Dict[str, jnp.ndarray], jnp.ndarray], jnp.ndarray]
+ kernel_fn_type = Callable[[jnp.ndarray, jnp.ndarray, Dict[str, jnp.ndarray], jnp.ndarray], jnp.ndarray]

clear_cache = jax._src.dispatch.xla_primitive_callable.cache_clear

@@ -62,7 +62,7 @@ class ExactGP:
>>> y_pred, y_samples = gp_model.predict(rng_key_predict, X_new, noiseless=True)
GP with custom noise prior
>>> gp_model = gpax.ExactGP(
>>> input_dim=1, kernel='RBF',
>>> noise_prior_dist = numpyro.distributions.HalfNormal(.1)
@@ -73,7 +73,7 @@ class ExactGP:
>>> y_pred, y_samples = gp_model.predict(rng_key_predict, X_new, noiseless=True)
GP with custom probabilistic model as its mean function
>>> # Define a deterministic mean function
>>> mean_fn = lambda x, param: param["a"]*x + param["b"]
>>>
@@ -93,25 +93,34 @@ class ExactGP:
>>> y_pred, y_samples = gp_model.predict(rng_key_predict, X_new, noiseless=True)
"""

-    def __init__(self, input_dim: int, kernel: Union[str, kernel_fn_type],
-                 mean_fn: Optional[Callable[[jnp.ndarray, Dict[str, jnp.ndarray]], jnp.ndarray]] = None,
-                 kernel_prior: Optional[Callable[[], Dict[str, jnp.ndarray]]] = None,
-                 mean_fn_prior: Optional[Callable[[], Dict[str, jnp.ndarray]]] = None,
-                 noise_prior: Optional[Callable[[], Dict[str, jnp.ndarray]]] = None,
-                 noise_prior_dist: Optional[dist.Distribution] = None,
-                 lengthscale_prior_dist: Optional[dist.Distribution] = None
-                 ) -> None:
+    def __init__(
+        self,
+        input_dim: int,
+        kernel: Union[str, kernel_fn_type],
+        mean_fn: Optional[Callable[[jnp.ndarray, Dict[str, jnp.ndarray]], jnp.ndarray]] = None,
+        kernel_prior: Optional[Callable[[], Dict[str, jnp.ndarray]]] = None,
+        mean_fn_prior: Optional[Callable[[], Dict[str, jnp.ndarray]]] = None,
+        noise_prior: Optional[Callable[[], Dict[str, jnp.ndarray]]] = None,
+        noise_prior_dist: Optional[dist.Distribution] = None,
+        lengthscale_prior_dist: Optional[dist.Distribution] = None,
+    ) -> None:
clear_cache()
if noise_prior is not None:
-            warnings.warn("`noise_prior` is deprecated and will be removed in a future version. "
-                          "Please use `noise_prior_dist` instead, which accepts an instance of a "
-                          "numpyro.distributions Distribution object, e.g., `dist.HalfNormal(scale=0.1)`, "
-                          "rather than a function that calls `numpyro.sample`.", FutureWarning)
+            warnings.warn(
+                "`noise_prior` is deprecated and will be removed in a future version. "
+                "Please use `noise_prior_dist` instead, which accepts an instance of a "
+                "numpyro.distributions Distribution object, e.g., `dist.HalfNormal(scale=0.1)`, "
+                "rather than a function that calls `numpyro.sample`.",
+                FutureWarning,
+            )
if kernel_prior is not None:
-            warnings.warn("`kernel_prior` will remain available for complex priors. However, for "
-                          "modifying only the lengthscales, it is recommended to use `lengthscale_prior_dist` instead. "
-                          "`lengthscale_prior_dist` accepts an instance of a numpyro.distributions Distribution object, "
-                          "e.g., `dist.Gamma(2, 5)`, rather than a function that calls `numpyro.sample`.", UserWarning)
+            warnings.warn(
+                "`kernel_prior` will remain available for complex priors. However, for "
+                "modifying only the lengthscales, it is recommended to use `lengthscale_prior_dist` instead. "
+                "`lengthscale_prior_dist` accepts an instance of a numpyro.distributions Distribution object, "
+                "e.g., `dist.Gamma(2, 5)`, rather than a function that calls `numpyro.sample`.",
+                UserWarning,
+            )
self.kernel_dim = input_dim
self.kernel = get_kernel(kernel)
self.kernel_name = kernel if isinstance(kernel, str) else None
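
A minimal migration sketch for the two warnings above, in the doctest style of the class docstring; the `gpax` import and model names follow that docstring, and the specific distributions are just the ones the warnings mention:
>>> gp_model = gpax.ExactGP(
>>>     input_dim=1, kernel='RBF',
>>>     noise_prior_dist=numpyro.distributions.HalfNormal(0.1),  # instead of a noise_prior callable
>>>     lengthscale_prior_dist=numpyro.distributions.Gamma(2, 5),  # instead of a lengthscales-only kernel_prior
>>> )
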
@@ -125,11 +134,7 @@ def __init__(self, input_dim: int, kernel: Union[str, kernel_fn_type],
self.y_train = None
self.mcmc = None

-    def model(self,
-              X: jnp.ndarray,
-              y: jnp.ndarray = None,
-              **kwargs: float
-              ) -> None:
+    def model(self, X: jnp.ndarray, y: jnp.ndarray = None, **kwargs: float) -> None:
"""GP probabilistic model with inputs X and targets y"""
# Initialize mean function at zeros
f_loc = jnp.zeros(X.shape[0])
@@ -150,26 +155,28 @@ def model(self,
args += [self.mean_fn_prior()]
f_loc += self.mean_fn(*args).squeeze()
# compute kernel
-        k = self.kernel(
-            X, X,
-            kernel_params,
-            noise,
-            **kwargs
-        )
+        k = self.kernel(X, X, kernel_params, noise, **kwargs)
# sample y according to the standard Gaussian process formula
numpyro.sample(
"y",
dist.MultivariateNormal(loc=f_loc, covariance_matrix=k),
obs=y,
)

-    def fit(self, rng_key: jnp.array, X: jnp.ndarray, y: jnp.ndarray,
-            num_warmup: int = 2000, num_samples: int = 2000,
-            num_chains: int = 1, chain_method: str = 'sequential',
-            progress_bar: bool = True, print_summary: bool = True,
-            device: Type[jaxlib.xla_extension.Device] = None,
-            **kwargs: float
-            ) -> None:
+    def fit(
+        self,
+        rng_key: jnp.array,
+        X: jnp.ndarray,
+        y: jnp.ndarray,
+        num_warmup: int = 2000,
+        num_samples: int = 2000,
+        num_chains: int = 1,
+        chain_method: str = "sequential",
+        progress_bar: bool = True,
+        print_summary: bool = True,
+        device: Type[jaxlib.xla_extension.Device] = None,
+        **kwargs: float
+    ) -> None:
"""
        Run Hamiltonian Monte Carlo to infer the GP parameters
@@ -185,7 +192,7 @@ def fit(self, rng_key: jnp.array, X: jnp.ndarray, y: jnp.ndarray,
print_summary: print summary at the end of sampling
device:
optionally specify a cpu or gpu device on which to run the inference;
-            e.g., ``device=jax.devices("cpu")[0]``
+            e.g., ``device=jax.devices("cpu")[0]``
**jitter:
Small positive term added to the diagonal part of a covariance
matrix for numerical stability (Default: 1e-6)
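
A usage sketch built from the arguments documented above; `rng_key`, `X`, and `y` are placeholders for a JAX PRNG key and the training arrays:
>>> gp_model.fit(rng_key, X, y, num_warmup=2000, num_samples=2000)
>>> # optionally pin inference to a device and pass jitter through **kwargs
>>> gp_model.fit(rng_key, X, y, device=jax.devices("cpu")[0], jitter=1e-5)
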
@@ -206,7 +213,7 @@ def fit(self, rng_key: jnp.array, X: jnp.ndarray, y: jnp.ndarray,
num_chains=num_chains,
chain_method=chain_method,
progress_bar=progress_bar,
-            jit_model_args=False
+            jit_model_args=False,
)
self.mcmc.run(rng_key, X, y, **kwargs)
if print_summary:
@@ -228,28 +235,25 @@ def _sample_kernel_params(self, output_scale=True) -> Dict[str, jnp.ndarray]:
length_dist = self.lengthscale_prior_dist
else:
length_dist = dist.LogNormal(0.0, 1.0)
-        with numpyro.plate('ard', self.kernel_dim):  # allows using ARD kernel for kernel_dim > 1
+        with numpyro.plate("ard", self.kernel_dim):  # allows using ARD kernel for kernel_dim > 1
length = numpyro.sample("k_length", length_dist)
if output_scale:
scale = numpyro.sample("k_scale", dist.LogNormal(0.0, 1.0))
else:
scale = numpyro.deterministic("k_scale", jnp.array(1.0))
-        if self.kernel_name == 'Periodic':
+        if self.kernel_name == "Periodic":
period = numpyro.sample("period", dist.LogNormal(0.0, 1.0))
-        kernel_params = {
-            "k_length": length, "k_scale": scale,
-            "period": period if self.kernel_name == "Periodic" else None}
+        kernel_params = {"k_length": length, "k_scale": scale, "period": period if self.kernel_name == "Periodic" else None}
return kernel_params

def get_samples(self, chain_dim: bool = False) -> Dict[str, jnp.ndarray]:
"""Get posterior samples (after running the MCMC chains)"""
return self.mcmc.get_samples(group_by_chain=chain_dim)

-    #@partial(jit, static_argnames='self')
-    def get_mvn_posterior(self,
-                          X_new: jnp.ndarray, params: Dict[str, jnp.ndarray],
-                          noiseless: bool = False, **kwargs: float
-                          ) -> Tuple[jnp.ndarray, jnp.ndarray]:
+    # @partial(jit, static_argnames='self')
+    def get_mvn_posterior(
+        self, X_new: jnp.ndarray, params: Dict[str, jnp.ndarray], noiseless: bool = False, **kwargs: float
+    ) -> Tuple[jnp.ndarray, jnp.ndarray]:
"""
Returns parameters (mean and cov) of multivariate normal posterior
for a single sample of GP parameters
@@ -273,30 +277,38 @@ def get_mvn_posterior(self,
mean += self.mean_fn(*args).squeeze()
return mean, cov

-    def _predict(self, rng_key: jnp.ndarray, X_new: jnp.ndarray,
-                 params: Dict[str, jnp.ndarray], n: int, noiseless: bool = False,
-                 **kwargs: float) -> Tuple[jnp.ndarray, jnp.ndarray]:
+    def _predict(
+        self,
+        rng_key: jnp.ndarray,
+        X_new: jnp.ndarray,
+        params: Dict[str, jnp.ndarray],
+        n: int,
+        noiseless: bool = False,
+        **kwargs: float
+    ) -> Tuple[jnp.ndarray, jnp.ndarray]:
"""Prediction with a single sample of GP parameters"""
# Get the predictive mean and covariance
y_mean, K = self.get_mvn_posterior(X_new, params, noiseless, **kwargs)
# draw samples from the posterior predictive for a given set of parameters
y_sampled = dist.MultivariateNormal(y_mean, K).sample(rng_key, sample_shape=(n,))
return y_mean, y_sampled

-    def _predict_in_batches(self, rng_key: jnp.ndarray,
-                            X_new: jnp.ndarray, batch_size: int = 100,
-                            batch_dim: int = 0,
-                            samples: Optional[Dict[str, jnp.ndarray]] = None,
-                            n: int = 1, filter_nans: bool = False,
-                            predict_fn: Callable[[jnp.ndarray, int], Tuple[jnp.ndarray]] = None,
-                            noiseless: bool = False,
-                            device: Type[jaxlib.xla_extension.Device] = None,
-                            **kwargs: float
-                            ) -> Tuple[jnp.ndarray, jnp.ndarray]:
-
+    def _predict_in_batches(
+        self,
+        rng_key: jnp.ndarray,
+        X_new: jnp.ndarray,
+        batch_size: int = 100,
+        batch_dim: int = 0,
+        samples: Optional[Dict[str, jnp.ndarray]] = None,
+        n: int = 1,
+        filter_nans: bool = False,
+        predict_fn: Callable[[jnp.ndarray, int], Tuple[jnp.ndarray]] = None,
+        noiseless: bool = False,
+        device: Type[jaxlib.xla_extension.Device] = None,
+        **kwargs: float
+    ) -> Tuple[jnp.ndarray, jnp.ndarray]:
if predict_fn is None:
-            predict_fn = lambda xi: self.predict(
-                rng_key, xi, samples, n, filter_nans, noiseless, device, **kwargs)
+            predict_fn = lambda xi: self.predict(rng_key, xi, samples, n, filter_nans, noiseless, device, **kwargs)

def predict_batch(Xi):
out1, out2 = predict_fn(Xi)
@@ -311,33 +323,43 @@ def predict_batch(Xi):
y_out2.append(out2)
return y_out1, y_out2

-    def predict_in_batches(self, rng_key: jnp.ndarray,
-                           X_new: jnp.ndarray, batch_size: int = 100,
-                           samples: Optional[Dict[str, jnp.ndarray]] = None,
-                           n: int = 1, filter_nans: bool = False,
-                           predict_fn: Callable[[jnp.ndarray, int], Tuple[jnp.ndarray]] = None,
-                           noiseless: bool = False,
-                           device: Type[jaxlib.xla_extension.Device] = None,
-                           **kwargs: float
-                           ) -> Tuple[jnp.ndarray, jnp.ndarray]:
+    def predict_in_batches(
+        self,
+        rng_key: jnp.ndarray,
+        X_new: jnp.ndarray,
+        batch_size: int = 100,
+        samples: Optional[Dict[str, jnp.ndarray]] = None,
+        n: int = 1,
+        filter_nans: bool = False,
+        predict_fn: Callable[[jnp.ndarray, int], Tuple[jnp.ndarray]] = None,
+        noiseless: bool = False,
+        device: Type[jaxlib.xla_extension.Device] = None,
+        **kwargs: float
+    ) -> Tuple[jnp.ndarray, jnp.ndarray]:
"""
Make prediction at X_new with sampled GP parameters
        by splitting the input array into chunks ("batches") and running
predict_fn (defaults to self.predict) on each of them one-by-one
to avoid a memory overflow
"""
y_pred, y_sampled = self._predict_in_batches(
-            rng_key, X_new, batch_size, 0, samples, n,
-            filter_nans, predict_fn, noiseless, device, **kwargs)
+            rng_key, X_new, batch_size, 0, samples, n, filter_nans, predict_fn, noiseless, device, **kwargs
+        )
y_pred = jnp.concatenate(y_pred, 0)
y_sampled = jnp.concatenate(y_sampled, -1)
return y_pred, y_sampled
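
A sketch of the batched path described in the docstring above, using placeholder names; `X_new` may be large because it is split into chunks of `batch_size`:
>>> y_pred, y_sampled = gp_model.predict_in_batches(
>>>     rng_key_predict, X_new, batch_size=100, noiseless=True)
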

-    def predict(self, rng_key: jnp.ndarray, X_new: jnp.ndarray,
-                samples: Optional[Dict[str, jnp.ndarray]] = None,
-                n: int = 1, filter_nans: bool = False, noiseless: bool = False,
-                device: Type[jaxlib.xla_extension.Device] = None, **kwargs: float
-                ) -> Tuple[jnp.ndarray, jnp.ndarray]:
+    def predict(
+        self,
+        rng_key: jnp.ndarray,
+        X_new: jnp.ndarray,
+        samples: Optional[Dict[str, jnp.ndarray]] = None,
+        n: int = 1,
+        filter_nans: bool = False,
+        noiseless: bool = False,
+        device: Type[jaxlib.xla_extension.Device] = None,
+        **kwargs: float
+    ) -> Tuple[jnp.ndarray, jnp.ndarray]:
"""
Make prediction at X_new points using posterior samples for GP parameters
@@ -370,38 +392,34 @@ def predict(self, rng_key: jnp.ndarray, X_new: jnp.ndarray,
samples = jax.device_put(samples, device)
num_samples = samples["noise"].shape[0]
vmap_args = (jra.split(rng_key, num_samples), samples)
-        predictive = jax.vmap(
-            lambda prms: self._predict(prms[0], X_new, prms[1], n, noiseless, **kwargs))
+        predictive = jax.vmap(lambda prms: self._predict(prms[0], X_new, prms[1], n, noiseless, **kwargs))
y_means, y_sampled = predictive(vmap_args)
if filter_nans:
y_sampled_ = [y_i for y_i in y_sampled if not jnp.isnan(y_i).any()]
y_sampled = jnp.array(y_sampled_)
return y_means.mean(0), y_sampled

-    def sample_from_prior(self, rng_key: jnp.ndarray,
-                          X: jnp.ndarray, num_samples: int = 10):
+    def sample_from_prior(self, rng_key: jnp.ndarray, X: jnp.ndarray, num_samples: int = 10):
"""
Samples from prior predictive distribution at X
"""
X = self._set_data(X)
prior_predictive = Predictive(self.model, num_samples=num_samples)
samples = prior_predictive(rng_key, X)
-        return samples['y']
+        return samples["y"]
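
A quick prior-predictive check in the same doctest style, assuming the model and arrays from the sketches above:
>>> prior_draws = gp_model.sample_from_prior(rng_key, X, num_samples=10)
>>> print(prior_draws.shape)  # (num_samples, len(X))
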

-    def _set_data(self,
-                  X: jnp.ndarray,
-                  y: Optional[jnp.ndarray] = None
-                  ) -> Union[Tuple[jnp.ndarray], jnp.ndarray]:
+    def _set_data(self, X: jnp.ndarray, y: Optional[jnp.ndarray] = None) -> Union[Tuple[jnp.ndarray], jnp.ndarray]:
X = X if X.ndim > 1 else X[:, None]
if y is not None:
return X, y.squeeze()
return X

-    def _set_training_data(self,
-                           X_train_new: jnp.ndarray = None,
-                           y_train_new: jnp.ndarray = None,
-                           device: Type[jaxlib.xla_extension.Device] = None
-                           ) -> None:
+    def _set_training_data(
+        self,
+        X_train_new: jnp.ndarray = None,
+        y_train_new: jnp.ndarray = None,
+        device: Type[jaxlib.xla_extension.Device] = None,
+    ) -> None:
X_train = self.X_train if X_train_new is None else X_train_new
y_train = self.y_train if y_train_new is None else y_train_new
if device: