diff --git a/gpax/acquisition/__init__.py b/gpax/acquisition/__init__.py index a12e435..f2a3a48 100644 --- a/gpax/acquisition/__init__.py +++ b/gpax/acquisition/__init__.py @@ -1 +1,4 @@ -from .acquisition import * \ No newline at end of file +from .acquisition import UCB, EI, POI, UE, Thompson +from .batch_acquisition import qEI, qPOI, qUCB + +__all__ = ["UCB", "EI", "POI", "UE", "Thompson", "qEI", "qPOI", "qUCB"] diff --git a/gpax/acquisition/acquisition.py b/gpax/acquisition/acquisition.py index 8ec5e3f..676f029 100644 --- a/gpax/acquisition/acquisition.py +++ b/gpax/acquisition/acquisition.py @@ -7,135 +7,19 @@ Created by Maxim Ziatdinov (email: maxim.ziatdinov@ai4microscopy.com) """ -from typing import Type, Optional, Dict, Callable, Any +from typing import Type, Optional, Callable, Any import jax.numpy as jnp import jax.random as jra from jax import vmap import numpy as onp -import numpyro.distributions as dist from ..models.gp import ExactGP from ..utils import random_sample_dict +from .base_acq import ei, ucb, poi from .penalties import compute_penalty -def ei(model: Type[ExactGP], - X: jnp.ndarray, - sample: Dict[str, jnp.ndarray], - maximize: bool = False, - noiseless: bool = False, - **kwargs) -> jnp.ndarray: - r""" - Expected Improvement - - Args: - model: trained model - X: new inputs with shape (N, D), where D is a feature dimension - sample: a single sample with model parameters - maximize: If True, assumes that BO is solving maximization problem - noiseless: - Noise-free prediction. It is set to False by default as new/unseen data is assumed - to follow the same distribution as the training data. Hence, since we introduce a model noise - for the training data, we also want to include that noise in our prediction. - **jitter: - Small positive term added to the diagonal part of a covariance - matrix for numerical stability (Default: 1e-6) - """ - if not isinstance(sample, (tuple, list)): - sample = (sample,) - # Get predictive mean and covariance for a single sample with kernel parameters - pred, cov = model.get_mvn_posterior(X, *sample, noiseless, **kwargs) - # Compute standard deviation - sigma = jnp.sqrt(cov.diagonal()) - # Standard EI computation - best_f = pred.max() if maximize else pred.min() - u = (pred - best_f) / sigma - if not maximize: - u = -u - normal = dist.Normal(jnp.zeros_like(u), jnp.ones_like(u)) - ucdf = normal.cdf(u) - updf = jnp.exp(normal.log_prob(u)) - acq = sigma * (updf + u * ucdf) - return acq - - -def ucb(model: Type[ExactGP], - X: jnp.ndarray, - sample: Dict[str, jnp.ndarray], - beta: float = 0.25, - maximize: bool = False, - noiseless: bool = False, - **kwargs) -> jnp.ndarray: - r""" - Upper confidence bound - - Args: - model: trained model - X: new inputs with shape (N, D), where D is a feature dimension - sample: a single sample with model parameters - beta: coefficient balancing exploration-exploitation trade-off - maximize: If True, assumes that BO is solving maximization problem - noiseless: - Noise-free prediction. It is set to False by default as new/unseen data is assumed - to follow the same distribution as the training data. Hence, since we introduce a model noise - for the training data, we also want to include that noise in our prediction. 
-        **jitter:
-            Small positive term added to the diagonal part of a covariance
-            matrix for numerical stability (Default: 1e-6)
-    """
-    if not isinstance(sample, (tuple, list)):
-        sample = (sample,)
-    # Get predictive mean and covariance for a single sample with kernel parameters
-    mean, cov = model.get_mvn_posterior(X, *sample, noiseless, **kwargs)
-    var = cov.diagonal()
-    delta = jnp.sqrt(beta * var)
-    if maximize:
-        acq = mean + delta
-    else:
-        acq = delta - mean  # we return a negative acq for argmax in BO
-    return acq
-
-
-def poi(model: Type[ExactGP],
-        X: jnp.ndarray,
-        sample: Dict[str, jnp.ndarray],
-        xi: float = 0.01,
-        maximize: bool = False,
-        noiseless: bool = False,
-        **kwargs) -> jnp.ndarray:
-    r"""
-    Probability of Improvement
-
-    Args:
-        model: trained model
-        X: new inputs with shape (N, D), where D is a feature dimension
-        sample: a single sample with model parameters
-        xi: Exploration-exploitation trade-off parameter (Defaults to 0.01)
-        maximize: If True, assumes that BO is solving maximization problem
-        noiseless:
-            Noise-free prediction. It is set to False by default as new/unseen data is assumed
-            to follow the same distribution as the training data. Hence, since we introduce a model noise
-            for the training data, we also want to include that noise in our prediction.
-        **jitter:
-            Small positive term added to the diagonal part of a covariance
-            matrix for numerical stability (Default: 1e-6)
-    """
-    if not isinstance(sample, (tuple, list)):
-        sample = (sample,)
-    # Get predictive mean and covariance for a single sample with kernel parameters
-    pred, cov = model.get_mvn_posterior(X, *sample, noiseless, **kwargs)
-    # Compute standard deviation
-    sigma = jnp.sqrt(cov.diagonal())
-    # Standard computation of poi
-    best_f = pred.max() if maximize else pred.min()
-    u = (pred - best_f - xi) / sigma
-    if not maximize:
-        u = -u
-    normal = dist.Normal(jnp.zeros_like(u), jnp.ones_like(u))
-    return normal.cdf(u)
-
-
 def compute_acquisition(
     model: Type[ExactGP],
     X: jnp.ndarray,
@@ -192,9 +76,4 @@ def compute_acquisition(
     return acq
 
 
-def compute_batched_acquisition():
-    # TBA
-    return
-
-
 def EI(rng_key: jnp.ndarray, model: Type[ExactGP],
@@ -257,22 +136,6 @@ def EI(rng_key: jnp.ndarray, model: Type[ExactGP],
         grid_indices=grid_indices, penalty_factor=penalty_factor,
         **kwargs)
-    # if penalty and not isinstance(recent_points, (onp.ndarray, jnp.ndarray)):
-    #     raise ValueError("Please provide an array of recently visited points")
-    # X = X[:, None] if X.ndim < 2 else X
-    # samples = 
model.get_samples() - # if model.mcmc is None: - # acq = ei(model, X, samples, maximize, noiseless, **kwargs) - # else: - # f = vmap(ei, in_axes=(None, None, 0, None, None)) - # acq = f(model, X, samples, maximize, noiseless, **kwargs) - # acq = acq.mean(0) - # if penalty: - # X_ = grid_indices if grid_indices is not None else X - # penalties = compute_penalty(X_, recent_points, penalty, penalty_factor) - # acq -= penalties - # return acq - def POI(rng_key: jnp.ndarray, model: Type[ExactGP], @@ -400,58 +302,6 @@ def UCB(rng_key: jnp.ndarray, penalty=penalty, recent_points=recent_points, grid_indices=grid_indices, penalty_factor=penalty_factor, **kwargs) - # if penalty and not isinstance(recent_points, (onp.ndarray, jnp.ndarray)): - # raise ValueError("Please provide an array of recently visited points") - # X = X[:, None] if X.ndim < 2 else X - # samples = model.get_samples() - # if model.mcmc is None: - # acq = ucb(model, X, samples, beta, maximize, noiseless, **kwargs) - # else: - # f = vmap(ucb, in_axes=(None, None, 0, None, None)) - # acq = f(model, X, samples, maximize, noiseless, **kwargs) - # acq = acq.mean(0) - # if penalty: - # X_ = grid_indices if grid_indices is not None else X - # penalties = compute_penalty(X_, recent_points, penalty, penalty_factor) - # acq -= penalties - # return acq - - - -def qEI(rng_key: jnp.ndarray, - model: Type[ExactGP], - X: jnp.ndarray, - maximize: bool = False, n: int = 1, - noiseless: bool = False, - maximize_distance: bool = False, - n_evals: int = 1, - subsample_size: int = 1, - indices: Optional[jnp.ndarray] = None, - **kwargs) -> jnp.ndarray: - - if model.mcmc is None: - raise ValueError("qEI works only with fully Bayesian models") - - if not maximize_distance: - samples = random_sample_dict(model.get_samples(), subsample_size) - f = vmap(ei, in_axes=(None, None, 0, None, None)) - acq = f(model, X, samples, maximize, noiseless, **kwargs) - - else: # draws samples multiple times and selects the ones where maxima are farthest apart from each other - X_ = jnp.array(indices) if indices is not None else jnp.array(X) - acq_all, dist_all = [], [] - for _ in range(n_evals): - samples = random_sample_dict(model.get_samples(), subsample_size) - f = vmap(ei, in_axes=(None, None, 0, None, None)) - acq = f(model, X_, samples, maximize, noiseless, **kwargs) # (subsample_size, len(X)) - points = acq.argmax(-1) - d = jnp.linalg.norm(points).mean() - acq_all.append(acq) - dist_all.append(d) - idx = jnp.array(dist_all).argmax() - acq = acq_all[idx] - - return acq def UE(rng_key: jnp.ndarray, @@ -555,83 +405,4 @@ def Thompson(rng_key: jnp.ndarray, else: _, tsample = model.sample_from_posterior( rng_key, X, n=1, noiseless=noiseless, **kwargs) - return tsample - - -# def qUCB(rng_key: jnp.ndarray, model: Type[ExactGP], -# X: jnp.ndarray, indices: Optional[jnp.ndarray] = None, -# qbatch_size: int = 4, alpha: float = 1.0, beta: float = .25, -# maximize: bool = True, n: int = 500, -# n_restarts: int = 20, noiseless: bool = False, -# **kwargs) -> jnp.ndarray: -# """ -# The acquisition function defined as alpha * mu + sqrt(beta) * sigma -# that can output a "batch" of next points to evaluate. It takes advantage of -# the fact that in MCMC-based GP or DKL we obtain a separate multivariate -# normal posterior for each set of sampled kernel hyperparameters. - -# Args: -# rng_key: random number generator key -# model: ExactGP or DKL type of model -# X: input array -# indices: indices of data points in X array. 
For example, if -# each data point is an image patch, the indices should -# correspond to their (x, y) coordinates in the original image. -# qbatch_size: desired number of sampled points (default: 4) -# alpha: coefficient before mean prediction term (default: 1.0) -# beta: coefficient before variance term (default: 0.25) -# maximize: sign of variance term (+/- if True/False) -# n: number of draws from each multivariate normal posterior -# n_restarts: number of restarts to find a batch of maximally -# separated points to evaluate next -# noiseless: noise-free prediction for new/test data (default: False) - -# Returns: -# Computed acquisition function with qbatch x features -# or task x qbatch x features dimensions -# """ -# if model.mcmc is None: -# raise NotImplementedError( -# "Currently supports only ExactGP and DKL with MCMC inference") -# dist_all, obj_all = [], [] -# X_ = jnp.array(indices) if indices is not None else jnp.array(X) -# for _ in range(n_restarts): -# y_sampled = obtain_samples( -# rng_key, model, X, qbatch_size, n, noiseless, **kwargs) -# mean, var = y_sampled.mean(1), y_sampled.var(1) -# delta = jnp.sqrt(beta * var) -# if maximize: -# obj = alpha * mean + delta -# points = X_[obj.argmax(-1)] -# else: -# obj = alpha * mean - delta -# points = X_[obj.argmin(-1)] -# d = jnp.linalg.norm(points, axis=-1).mean(0) -# dist_all.append(d) -# obj_all.append(obj) -# idx = jnp.array(dist_all).argmax(0) -# if idx.ndim > 0: -# obj_all = jnp.array(obj_all) -# return jnp.array([obj_all[j,:,i] for i, j in enumerate(idx)]) -# return obj_all[idx] - - -# def obtain_samples(rng_key: jnp.ndarray, model: Type[ExactGP], -# X: jnp.ndarray, qbatch_size: int = 4, -# n: int = 500, noiseless: bool = False, -# **kwargs) -> jnp.ndarray: -# xbatch_size = kwargs.get("xbatch_size", 100) -# posterior_samples = model.get_samples() -# idx = onp.arange(0, len(posterior_samples["k_length"])) -# onp.random.shuffle(idx) -# idx = idx[:qbatch_size] -# samples = {k: v[idx] for (k, v) in posterior_samples.items()} -# if X.shape[0] > xbatch_size: -# _, y_sampled = model.predict( -# rng_key, X, samples, n, -# noiseless=noiseless, **kwargs) -# else: -# _, y_sampled = model.predict_in_batches( -# rng_key, X, xbatch_size, samples, n, -# noiseless=noiseless, **kwargs) -# return y_sampled + return tsample \ No newline at end of file diff --git a/gpax/acquisition/base_acq.py b/gpax/acquisition/base_acq.py new file mode 100644 index 0000000..036d47d --- /dev/null +++ b/gpax/acquisition/base_acq.py @@ -0,0 +1,132 @@ +""" +base_acq.py +============== + +Base acquisition functions + +Created by Maxim Ziatdinov (email: maxim.ziatdinov@ai4microscopy.com) +""" + +from typing import Type, Dict + +import jax.numpy as jnp +import numpyro.distributions as dist + +from ..models.gp import ExactGP + + +def ei(model: Type[ExactGP], + X: jnp.ndarray, + sample: Dict[str, jnp.ndarray], + maximize: bool = False, + noiseless: bool = False, + **kwargs) -> jnp.ndarray: + r""" + Expected Improvement + + Args: + model: trained model + X: new inputs with shape (N, D), where D is a feature dimension + sample: a single sample with model parameters + maximize: If True, assumes that BO is solving maximization problem + noiseless: + Noise-free prediction. It is set to False by default as new/unseen data is assumed + to follow the same distribution as the training data. Hence, since we introduce a model noise + for the training data, we also want to include that noise in our prediction. 
+ **jitter: + Small positive term added to the diagonal part of a covariance + matrix for numerical stability (Default: 1e-6) + """ + if not isinstance(sample, (tuple, list)): + sample = (sample,) + # Get predictive mean and covariance for a single sample with kernel parameters + pred, cov = model.get_mvn_posterior(X, *sample, noiseless, **kwargs) + # Compute standard deviation + sigma = jnp.sqrt(cov.diagonal()) + # Standard EI computation + best_f = pred.max() if maximize else pred.min() + u = (pred - best_f) / sigma + if not maximize: + u = -u + normal = dist.Normal(jnp.zeros_like(u), jnp.ones_like(u)) + ucdf = normal.cdf(u) + updf = jnp.exp(normal.log_prob(u)) + acq = sigma * (updf + u * ucdf) + return acq + + +def ucb(model: Type[ExactGP], + X: jnp.ndarray, + sample: Dict[str, jnp.ndarray], + beta: float = 0.25, + maximize: bool = False, + noiseless: bool = False, + **kwargs) -> jnp.ndarray: + r""" + Upper confidence bound + + Args: + model: trained model + X: new inputs with shape (N, D), where D is a feature dimension + sample: a single sample with model parameters + beta: coefficient balancing exploration-exploitation trade-off + maximize: If True, assumes that BO is solving maximization problem + noiseless: + Noise-free prediction. It is set to False by default as new/unseen data is assumed + to follow the same distribution as the training data. Hence, since we introduce a model noise + for the training data, we also want to include that noise in our prediction. + **jitter: + Small positive term added to the diagonal part of a covariance + matrix for numerical stability (Default: 1e-6) + """ + if not isinstance(sample, (tuple, list)): + sample = (sample,) + # Get predictive mean and covariance for a single sample with kernel parameters + mean, cov = model.get_mvn_posterior(X, *sample, noiseless, **kwargs) + var = cov.diagonal() + # Standard UCB derivation + delta = jnp.sqrt(beta * var) + if maximize: + acq = mean + delta + else: + acq = delta - mean # we return a negative acq for argmax in BO + return acq + + +def poi(model: Type[ExactGP], + X: jnp.ndarray, + sample: Dict[str, jnp.ndarray], + xi: float = 0.01, + maximize: bool = False, + noiseless: bool = False, + **kwargs) -> jnp.ndarray: + r""" + Probability of Improvement + + Args: + model: trained model + X: new inputs with shape (N, D), where D is a feature dimension + sample: a single sample with model parameters + xi: Exploration-exploitation trade-off parameter (Defaults to 0.01) + maximize: If True, assumes that BO is solving maximization problem + noiseless: + Noise-free prediction. It is set to False by default as new/unseen data is assumed + to follow the same distribution as the training data. Hence, since we introduce a model noise + for the training data, we also want to include that noise in our prediction. 
+        **jitter:
+            Small positive term added to the diagonal part of a covariance
+            matrix for numerical stability (Default: 1e-6)
+    """
+    if not isinstance(sample, (tuple, list)):
+        sample = (sample,)
+    # Get predictive mean and covariance for a single sample with kernel parameters
+    pred, cov = model.get_mvn_posterior(X, *sample, noiseless, **kwargs)
+    # Compute standard deviation
+    sigma = jnp.sqrt(cov.diagonal())
+    # Standard computation of poi
+    best_f = pred.max() if maximize else pred.min()
+    u = (pred - best_f - xi) / sigma
+    if not maximize:
+        u = -u
+    normal = dist.Normal(jnp.zeros_like(u), jnp.ones_like(u))
+    return normal.cdf(u)
diff --git a/gpax/acquisition/batch_acquisition.py b/gpax/acquisition/batch_acquisition.py
new file mode 100644
index 0000000..39ded82
--- /dev/null
+++ b/gpax/acquisition/batch_acquisition.py
@@ -0,0 +1,182 @@
+"""
+batch_acquisition.py
+====================
+
+Batch-mode acquisition functions
+
+Created by Maxim Ziatdinov (email: maxim.ziatdinov@ai4microscopy.com)
+"""
+
+from typing import Type, Optional, Callable
+
+import jax.numpy as jnp
+from jax import vmap
+
+from ..models.gp import ExactGP
+from ..utils import random_sample_dict
+from .base_acq import ei, ucb, poi
+
+
+def compute_batch_acquisition(acquisition_type: Callable,
+                              model: Type[ExactGP],
+                              X: jnp.ndarray,
+                              *acq_args,
+                              maximize_distance: bool = False,
+                              n_evals: int = 1,
+                              subsample_size: int = 1,
+                              indices: Optional[jnp.ndarray] = None,
+                              **kwargs) -> jnp.ndarray:
+    """
+    Computes a batch-mode acquisition function of a given type
+    """
+    if model.mcmc is None:
+        raise ValueError("The model needs to be fully Bayesian")
+
+    f = vmap(acquisition_type, in_axes=(None, None, 0) + (None,) * len(acq_args))
+
+    if not maximize_distance:
+        samples = random_sample_dict(model.get_samples(), subsample_size)
+        acq = f(model, X, samples, *acq_args, **kwargs)
+    else:
+        X_ = jnp.array(indices) if indices is not None else jnp.array(X)
+        acq_all, dist_all = [], []
+
+        # Draw a fresh subsample of MCMC samples on each evaluation and keep
+        # the one whose acquisition maxima are farthest apart from each other
+        for _ in range(n_evals):
+            samples = random_sample_dict(model.get_samples(), subsample_size)
+            acq = f(model, X_, samples, *acq_args, **kwargs)
+            points = acq.argmax(-1)
+            d = jnp.linalg.norm(points).mean()
+            acq_all.append(acq)
+            dist_all.append(d)
+
+        idx = jnp.array(dist_all).argmax()
+        acq = acq_all[idx]
+
+    return acq
+
+
+def qEI(rng_key: jnp.ndarray,
+        model: Type[ExactGP],
+        X: jnp.ndarray,
+        maximize: bool = False,
+        noiseless: bool = False,
+        maximize_distance: bool = False,
+        n_evals: int = 1,
+        subsample_size: int = 1,
+        indices: Optional[jnp.ndarray] = None,
+        **kwargs) -> jnp.ndarray:
+    """
+    Batch-mode Expected Improvement
+
+    Args:
+        rng_key:
+            Random number generator key (currently unused; kept for API
+            consistency with the single-point acquisition functions)
+        model: trained model
+        X: new inputs
+        maximize: If True, assumes that BO is solving maximization problem
+        noiseless:
+            Noise-free prediction. It is set to False by default as new/unseen data is assumed
+            to follow the same distribution as the training data. Hence, since we introduce a model noise
+            for the training data, we also want to include that noise in our prediction.
+        maximize_distance:
+            Selects a subsample with a maximum distance between acq.argmax() points
+        n_evals:
+            Number of evaluations (how many times a random subsample is drawn)
+            when maximizing distance between maxima of different EIs in a batch.
+        subsample_size:
+            Size of the subsample from the GP model's MCMC samples.
+        indices:
+            Indices of the input points.
+
+    Returns:
+        The computed batch Expected Improvement values at the provided input points X.
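+
+    Example:
+        A minimal usage sketch (assumes a fully Bayesian ``ExactGP`` already
+        trained with MCMC; ``gp_model`` and ``X_new`` are illustrative names)::
+
+            >>> import jax.numpy as jnp
+            >>> from gpax.utils import get_keys
+            >>> rng_key, _ = get_keys()
+            >>> X_new = jnp.linspace(-2, 2, 100)[:, None]
+            >>> acq = qEI(rng_key, gp_model, X_new, subsample_size=8)
+            >>> acq.shape  # one EI evaluation per posterior subsample
+            (8, 100)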
+    """
+
+    return compute_batch_acquisition(
+        ei, model, X, maximize, noiseless,
+        maximize_distance=maximize_distance,
+        n_evals=n_evals, subsample_size=subsample_size,
+        indices=indices, **kwargs)
+
+
+def qUCB(rng_key: jnp.ndarray,
+         model: Type[ExactGP],
+         X: jnp.ndarray,
+         beta: float = 0.25,
+         maximize: bool = False,
+         noiseless: bool = False,
+         maximize_distance: bool = False,
+         n_evals: int = 1,
+         subsample_size: int = 1,
+         indices: Optional[jnp.ndarray] = None,
+         **kwargs) -> jnp.ndarray:
+    """
+    Batch-mode Upper Confidence Bound
+
+    Args:
+        rng_key:
+            Random number generator key (currently unused; kept for API
+            consistency with the single-point acquisition functions)
+        model: trained model
+        X: new inputs
+        beta: coefficient balancing the exploration-exploitation trade-off
+        maximize: If True, assumes that BO is solving maximization problem
+        noiseless:
+            Noise-free prediction. It is set to False by default as new/unseen data is assumed
+            to follow the same distribution as the training data. Hence, since we introduce a model noise
+            for the training data, we also want to include that noise in our prediction.
+        maximize_distance:
+            Selects a subsample with a maximum distance between acq.argmax() points
+        n_evals:
+            Number of evaluations (how many times a random subsample is drawn)
+            when maximizing distance between maxima of different UCBs in a batch.
+        subsample_size:
+            Size of the subsample from the GP model's MCMC samples.
+        indices:
+            Indices of the input points.
+
+    Returns:
+        The computed batch Upper Confidence Bound values at the provided input points X.
+    """
+
+    return compute_batch_acquisition(
+        ucb, model, X, beta, maximize, noiseless,
+        maximize_distance=maximize_distance,
+        n_evals=n_evals, subsample_size=subsample_size,
+        indices=indices, **kwargs)
+
+
+def qPOI(rng_key: jnp.ndarray,
+         model: Type[ExactGP],
+         X: jnp.ndarray,
+         xi: float = .001,
+         maximize: bool = False,
+         noiseless: bool = False,
+         maximize_distance: bool = False,
+         n_evals: int = 1,
+         subsample_size: int = 1,
+         indices: Optional[jnp.ndarray] = None,
+         **kwargs) -> jnp.ndarray:
+    """
+    Batch-mode Probability of Improvement
+
+    Args:
+        rng_key:
+            Random number generator key (currently unused; kept for API
+            consistency with the single-point acquisition functions)
+        model: trained model
+        X: new inputs
+        xi: exploration-exploitation trade-off parameter (Defaults to 0.001)
+        maximize: If True, assumes that BO is solving maximization problem
+        noiseless:
+            Noise-free prediction. It is set to False by default as new/unseen data is assumed
+            to follow the same distribution as the training data. Hence, since we introduce a model noise
+            for the training data, we also want to include that noise in our prediction.
+        maximize_distance:
+            Selects a subsample with a maximum distance between acq.argmax() points
+        n_evals:
+            Number of evaluations (how many times a random subsample is drawn)
+            when maximizing distance between maxima of different POIs in a batch.
+        subsample_size:
+            Size of the subsample from the GP model's MCMC samples.
+        indices:
+            Indices of the input points.
+
+    Returns:
+        The computed batch Probability of Improvement values at the provided input points X.
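+
+    Example:
+        A sketch of batch selection with distance maximization (assumes a
+        trained fully Bayesian model ``gp_model``; ``rng_key`` and ``X_new``
+        as in the qEI example above)::
+
+            >>> acq = qPOI(rng_key, gp_model, X_new, xi=0.05,
+            ...            maximize_distance=True, n_evals=10, subsample_size=4)
+            >>> next_points = X_new[acq.argmax(-1)]  # one candidate per subsample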
+    """
+
+    return compute_batch_acquisition(
+        poi, model, X, xi, maximize, noiseless,
+        maximize_distance=maximize_distance,
+        n_evals=n_evals, subsample_size=subsample_size,
+        indices=indices, **kwargs)
\ No newline at end of file
diff --git a/tests/test_acq.py b/tests/test_acq.py
index c1e5c39..f34ee31 100644
--- a/tests/test_acq.py
+++ b/tests/test_acq.py
@@ -10,8 +10,9 @@ from gpax.models.gp import ExactGP
 from gpax.models.vidkl import viDKL
 from gpax.utils import get_keys
 
-from gpax.acquisition import ei, ucb, poi, EI, UCB, UE, Thompson
-from gpax.acquisition.penalties import compute_penalty, penalty_point, find_and_replace_point_indices
+from gpax.acquisition.base_acq import ei, ucb, poi
+from gpax.acquisition import EI, UCB, UE, Thompson
+from gpax.acquisition.penalties import compute_penalty
 
 
 @pytest.mark.parametrize("base_acq", [ei, ucb, poi])