Commit
- Add batch-mode acquisition functions
- Move ei, ucb, and poi to base_acq
- Update tests
1 parent ed68fe3 · commit 9720a86
Showing 5 changed files with 361 additions and 272 deletions.
gpax/acquisition/__init__.py
@@ -1 +1,4 @@
-from .acquisition import *
+from .acquisition import UCB, EI, POI, UE, Thompson
+from .batch_acquisition import qEI, qPOI, qUCB
+
+__all__ = ["UCB", "EI", "POI", "UE", "Thompson", "qEI", "qPOI", "qUCB"]
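The batch-mode functions are now part of the package's public API. A minimal usage sketch (the ExactGP setup and the gpax.acquisition import path are assumptions for illustration, not shown in this commit):

import jax.numpy as jnp
import jax.random as jra
import gpax  # assumed top-level package name

rng_key = jra.PRNGKey(0)
X_train = jnp.linspace(0.0, 1.0, 8)[:, None]   # (N, D) training inputs
y_train = jnp.sin(10.0 * X_train.squeeze())    # toy targets
X_new = jnp.linspace(0.0, 1.0, 50)[:, None]    # candidate points

gp = gpax.ExactGP(1, kernel="RBF")             # fully Bayesian GP
gp.fit(rng_key, X_train, y_train)              # MCMC inference

# qEI evaluates EI under `subsample_size` posterior subsamples at once,
# returning one acquisition row per subsample: (subsample_size, len(X_new))
acq = gpax.acquisition.qEI(rng_key, gp, X_new, subsample_size=4)
next_points = X_new[acq.argmax(-1)]            # one proposal per subsample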
gpax/acquisition/acquisition.py
@@ -7,135 +7,19 @@
 Created by Maxim Ziatdinov (email: [email protected])
 """

-from typing import Type, Optional, Dict, Callable, Any
+from typing import Type, Optional, Callable, Any

 import jax.numpy as jnp
 import jax.random as jra
 from jax import vmap
 import numpy as onp
 import numpyro.distributions as dist

 from ..models.gp import ExactGP
 from ..utils import random_sample_dict
+from .base_acq import ei, ucb, poi
 from .penalties import compute_penalty


-def ei(model: Type[ExactGP],
-       X: jnp.ndarray,
-       sample: Dict[str, jnp.ndarray],
-       maximize: bool = False,
-       noiseless: bool = False,
-       **kwargs) -> jnp.ndarray:
-    r"""
-    Expected Improvement
-
-    Args:
-        model: trained model
-        X: new inputs with shape (N, D), where D is a feature dimension
-        sample: a single sample with model parameters
-        maximize: If True, assumes that BO is solving a maximization problem
-        noiseless:
-            Noise-free prediction. It is set to False by default because new/unseen data
-            is assumed to follow the same distribution as the training data. Hence, since
-            we introduce model noise for the training data, we also want to include that
-            noise in our prediction.
-        **jitter:
-            Small positive term added to the diagonal part of a covariance
-            matrix for numerical stability (Default: 1e-6)
-    """
-    if not isinstance(sample, (tuple, list)):
-        sample = (sample,)
-    # Get predictive mean and covariance for a single sample of kernel parameters
-    pred, cov = model.get_mvn_posterior(X, *sample, noiseless, **kwargs)
-    # Compute standard deviation
-    sigma = jnp.sqrt(cov.diagonal())
-    # Standard EI computation
-    best_f = pred.max() if maximize else pred.min()
-    u = (pred - best_f) / sigma
-    if not maximize:
-        u = -u
-    normal = dist.Normal(jnp.zeros_like(u), jnp.ones_like(u))
-    ucdf = normal.cdf(u)
-    updf = jnp.exp(normal.log_prob(u))
-    acq = sigma * (updf + u * ucdf)
-    return acq
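For reference, the closed form this ei computes (with the sign of u flipped for minimization, as in the code above):

\mathrm{EI}(x) = \sigma(x)\,\bigl[\varphi(u) + u\,\Phi(u)\bigr], \qquad u = \frac{\mu(x) - f^{*}}{\sigma(x)},

where \varphi and \Phi are the standard normal pdf and cdf, \mu and \sigma are the posterior mean and standard deviation from get_mvn_posterior, and f^{*} is the best predictive mean over X.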
-
-
-def ucb(model: Type[ExactGP],
-        X: jnp.ndarray,
-        sample: Dict[str, jnp.ndarray],
-        beta: float = 0.25,
-        maximize: bool = False,
-        noiseless: bool = False,
-        **kwargs) -> jnp.ndarray:
-    r"""
-    Upper Confidence Bound
-
-    Args:
-        model: trained model
-        X: new inputs with shape (N, D), where D is a feature dimension
-        sample: a single sample with model parameters
-        beta: coefficient balancing the exploration-exploitation trade-off
-        maximize: If True, assumes that BO is solving a maximization problem
-        noiseless:
-            Noise-free prediction. It is set to False by default because new/unseen data
-            is assumed to follow the same distribution as the training data. Hence, since
-            we introduce model noise for the training data, we also want to include that
-            noise in our prediction.
-        **jitter:
-            Small positive term added to the diagonal part of a covariance
-            matrix for numerical stability (Default: 1e-6)
-    """
-    if not isinstance(sample, (tuple, list)):
-        sample = (sample,)
-    # Get predictive mean and covariance for a single sample of kernel parameters
-    mean, cov = model.get_mvn_posterior(X, *sample, noiseless, **kwargs)
-    var = cov.diagonal()
-    delta = jnp.sqrt(beta * var)
-    if maximize:
-        acq = mean + delta
-    else:
-        acq = delta - mean  # return a negated acquisition for argmax-based BO
-    return acq
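In closed form, ucb returns \mu(x) + \sqrt{\beta}\,\sigma(x) when maximizing and \sqrt{\beta}\,\sigma(x) - \mu(x) when minimizing, so a single argmax convention serves both directions.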
-
-
-def poi(model: Type[ExactGP],
-        X: jnp.ndarray,
-        sample: Dict[str, jnp.ndarray],
-        xi: float = 0.01,
-        maximize: bool = False,
-        noiseless: bool = False,
-        **kwargs) -> jnp.ndarray:
-    r"""
-    Probability of Improvement
-
-    Args:
-        model: trained model
-        X: new inputs with shape (N, D), where D is a feature dimension
-        sample: a single sample with model parameters
-        xi: exploration-exploitation trade-off parameter (Default: 0.01)
-        maximize: If True, assumes that BO is solving a maximization problem
-        noiseless:
-            Noise-free prediction. It is set to False by default because new/unseen data
-            is assumed to follow the same distribution as the training data. Hence, since
-            we introduce model noise for the training data, we also want to include that
-            noise in our prediction.
-        **jitter:
-            Small positive term added to the diagonal part of a covariance
-            matrix for numerical stability (Default: 1e-6)
-    """
-    if not isinstance(sample, (tuple, list)):
-        sample = (sample,)
-    # Get predictive mean and covariance for a single sample of kernel parameters
-    pred, cov = model.get_mvn_posterior(X, *sample, noiseless, **kwargs)
-    # Compute standard deviation
-    sigma = jnp.sqrt(cov.diagonal())
-    # Standard POI computation
-    best_f = pred.max() if maximize else pred.min()
-    u = (pred - best_f - xi) / sigma
-    if not maximize:
-        u = -u
-    normal = dist.Normal(jnp.zeros_like(u), jnp.ones_like(u))
-    return normal.cdf(u)
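The corresponding closed form, with the same sign flip for minimization:

\mathrm{PI}(x) = \Phi\!\left(\frac{\mu(x) - f^{*} - \xi}{\sigma(x)}\right),

where a larger \xi shifts the balance toward exploration.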


def compute_acquisition(
        model: Type[ExactGP],
        X: jnp.ndarray,
@@ -192,9 +76,43 @@ def compute_acquisition(
     return acq


-def compute_batched_acquisition():
-    # TBA
-    return
+def compute_batch_acquisition(acquisition_type: Callable,
+                              model: Type[ExactGP],
+                              X: jnp.ndarray,
+                              *acq_args,
+                              maximize_distance: bool = False,
+                              n_evals: int = 1,
+                              subsample_size: int = 1,
+                              indices: Optional[jnp.ndarray] = None,
+                              **kwargs) -> jnp.ndarray:
+    """
+    Batch-mode acquisition function for a given acquisition type
+    """
+    if model.mcmc is None:
+        raise ValueError("The model needs to be fully Bayesian")
+
+    # Map the acquisition over the leading axis of the posterior subsample dict,
+    # broadcasting the model, the inputs, and any extra positional arguments
+    f = vmap(acquisition_type, in_axes=(None, None, 0) + (None,) * len(acq_args))
+
+    if not maximize_distance:
+        samples = random_sample_dict(model.get_samples(), subsample_size)
+        acq = f(model, X, samples, *acq_args, **kwargs)
+    else:
+        X_ = jnp.array(indices) if indices is not None else jnp.array(X)
+        acq_all, dist_all = [], []
+
+        for _ in range(n_evals):
+            # Redraw a fresh posterior subsample on every evaluation so the
+            # evaluations (and their argmax points) actually differ
+            samples = random_sample_dict(model.get_samples(), subsample_size)
+            acq = f(model, X_, samples, *acq_args, **kwargs)  # (subsample_size, len(X_))
+            points = acq.argmax(-1)
+            # Norm of the per-sample argmax indices serves as a crude spread proxy
+            d = jnp.linalg.norm(points).mean()
+            acq_all.append(acq)
+            dist_all.append(d)
+
+        # Keep the evaluation whose proposed points score highest on that proxy
+        idx = jnp.array(dist_all).argmax()
+        acq = acq_all[idx]
+
+    return acq
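The in_axes pattern above maps the acquisition over the leading axis of every array in the samples dict while broadcasting the model, the inputs, and any extra positional arguments. A self-contained toy sketch of that pattern (toy_acq and all shapes are hypothetical, for illustration only):

import jax.numpy as jnp
from jax import vmap

def toy_acq(model, X, sample, beta):
    # stand-in for ei/ucb/poi: consumes one posterior sample's parameters
    return sample["k_length"] * X.sum(-1) + beta

samples = {"k_length": jnp.arange(4.0)}  # leading axis = 4 posterior samples
X = jnp.ones((10, 2))

# model and X broadcast (None); samples map over axis 0 of each dict value;
# the one extra positional argument (beta) broadcasts: (None,) * 1
f = vmap(toy_acq, in_axes=(None, None, 0) + (None,) * 1)
acq = f(None, X, samples, 0.25)  # shape (4, 10): one acquisition row per sample
assert acq.shape == (4, 10)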


def EI(rng_key: jnp.ndarray, model: Type[ExactGP],
@@ -257,22 +175,6 @@ def EI(rng_key: jnp.ndarray, model: Type[ExactGP],
                               grid_indices=grid_indices, penalty_factor=penalty_factor,
                               **kwargs)
-
-    # if penalty and not isinstance(recent_points, (onp.ndarray, jnp.ndarray)):
-    #     raise ValueError("Please provide an array of recently visited points")
-    # X = X[:, None] if X.ndim < 2 else X
-    # samples = model.get_samples()
-    # if model.mcmc is None:
-    #     acq = ei(model, X, samples, maximize, noiseless, **kwargs)
-    # else:
-    #     f = vmap(ei, in_axes=(None, None, 0, None, None))
-    #     acq = f(model, X, samples, maximize, noiseless, **kwargs)
-    #     acq = acq.mean(0)
-    # if penalty:
-    #     X_ = grid_indices if grid_indices is not None else X
-    #     penalties = compute_penalty(X_, recent_points, penalty, penalty_factor)
-    #     acq -= penalties
-    # return acq


def POI(rng_key: jnp.ndarray,
        model: Type[ExactGP],
@@ -400,58 +302,6 @@ def UCB(rng_key: jnp.ndarray,
                               penalty=penalty, recent_points=recent_points,
                               grid_indices=grid_indices, penalty_factor=penalty_factor,
                               **kwargs)
-    # if penalty and not isinstance(recent_points, (onp.ndarray, jnp.ndarray)):
-    #     raise ValueError("Please provide an array of recently visited points")
-    # X = X[:, None] if X.ndim < 2 else X
-    # samples = model.get_samples()
-    # if model.mcmc is None:
-    #     acq = ucb(model, X, samples, beta, maximize, noiseless, **kwargs)
-    # else:
-    #     f = vmap(ucb, in_axes=(None, None, 0, None, None))
-    #     acq = f(model, X, samples, maximize, noiseless, **kwargs)
-    #     acq = acq.mean(0)
-    # if penalty:
-    #     X_ = grid_indices if grid_indices is not None else X
-    #     penalties = compute_penalty(X_, recent_points, penalty, penalty_factor)
-    #     acq -= penalties
-    # return acq
-
-
-def qEI(rng_key: jnp.ndarray,
-        model: Type[ExactGP],
-        X: jnp.ndarray,
-        maximize: bool = False, n: int = 1,
-        noiseless: bool = False,
-        maximize_distance: bool = False,
-        n_evals: int = 1,
-        subsample_size: int = 1,
-        indices: Optional[jnp.ndarray] = None,
-        **kwargs) -> jnp.ndarray:
-
-    if model.mcmc is None:
-        raise ValueError("qEI works only with fully Bayesian models")
-
-    if not maximize_distance:
-        samples = random_sample_dict(model.get_samples(), subsample_size)
-        f = vmap(ei, in_axes=(None, None, 0, None, None))
-        acq = f(model, X, samples, maximize, noiseless, **kwargs)
-    else:  # draw samples multiple times and keep the draw whose maxima are farthest apart
-        X_ = jnp.array(indices) if indices is not None else jnp.array(X)
-        acq_all, dist_all = [], []
-        for _ in range(n_evals):
-            samples = random_sample_dict(model.get_samples(), subsample_size)
-            f = vmap(ei, in_axes=(None, None, 0, None, None))
-            acq = f(model, X_, samples, maximize, noiseless, **kwargs)  # (subsample_size, len(X))
-            points = acq.argmax(-1)
-            d = jnp.linalg.norm(points).mean()
-            acq_all.append(acq)
-            dist_all.append(d)
-        idx = jnp.array(dist_all).argmax()
-        acq = acq_all[idx]
-
-    return acq
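With this refactor, qEI (along with qPOI and qUCB) presumably moves to batch_acquisition.py as a thin wrapper that passes ei to compute_batch_acquisition, which now carries the subsampling and maximize_distance logic shown above.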


def UE(rng_key: jnp.ndarray,
@@ -555,83 +405,4 @@ def Thompson(rng_key: jnp.ndarray,
     else:
         _, tsample = model.sample_from_posterior(
             rng_key, X, n=1, noiseless=noiseless, **kwargs)
     return tsample
-
-
-# def qUCB(rng_key: jnp.ndarray, model: Type[ExactGP],
-#          X: jnp.ndarray, indices: Optional[jnp.ndarray] = None,
-#          qbatch_size: int = 4, alpha: float = 1.0, beta: float = .25,
-#          maximize: bool = True, n: int = 500,
-#          n_restarts: int = 20, noiseless: bool = False,
-#          **kwargs) -> jnp.ndarray:
-#     """
-#     The acquisition function defined as alpha * mu + sqrt(beta) * sigma
-#     that can output a "batch" of next points to evaluate. It takes advantage of
-#     the fact that in MCMC-based GP or DKL we obtain a separate multivariate
-#     normal posterior for each set of sampled kernel hyperparameters.
-
-#     Args:
-#         rng_key: random number generator key
-#         model: ExactGP or DKL type of model
-#         X: input array
-#         indices: indices of data points in the X array. For example, if
-#             each data point is an image patch, the indices should
-#             correspond to their (x, y) coordinates in the original image.
-#         qbatch_size: desired number of sampled points (default: 4)
-#         alpha: coefficient before the mean prediction term (default: 1.0)
-#         beta: coefficient before the variance term (default: 0.25)
-#         maximize: sign of the variance term (+/- if True/False)
-#         n: number of draws from each multivariate normal posterior
-#         n_restarts: number of restarts to find a batch of maximally
-#             separated points to evaluate next
-#         noiseless: noise-free prediction for new/test data (default: False)

-#     Returns:
-#         Computed acquisition function with qbatch x features
-#         or task x qbatch x features dimensions
-#     """
-#     if model.mcmc is None:
-#         raise NotImplementedError(
-#             "Currently supports only ExactGP and DKL with MCMC inference")
-#     dist_all, obj_all = [], []
-#     X_ = jnp.array(indices) if indices is not None else jnp.array(X)
-#     for _ in range(n_restarts):
-#         y_sampled = obtain_samples(
-#             rng_key, model, X, qbatch_size, n, noiseless, **kwargs)
-#         mean, var = y_sampled.mean(1), y_sampled.var(1)
-#         delta = jnp.sqrt(beta * var)
-#         if maximize:
-#             obj = alpha * mean + delta
-#             points = X_[obj.argmax(-1)]
-#         else:
-#             obj = alpha * mean - delta
-#             points = X_[obj.argmin(-1)]
-#         d = jnp.linalg.norm(points, axis=-1).mean(0)
-#         dist_all.append(d)
-#         obj_all.append(obj)
-#     idx = jnp.array(dist_all).argmax(0)
-#     if idx.ndim > 0:
-#         obj_all = jnp.array(obj_all)
-#         return jnp.array([obj_all[j, :, i] for i, j in enumerate(idx)])
-#     return obj_all[idx]


-# def obtain_samples(rng_key: jnp.ndarray, model: Type[ExactGP],
-#                    X: jnp.ndarray, qbatch_size: int = 4,
-#                    n: int = 500, noiseless: bool = False,
-#                    **kwargs) -> jnp.ndarray:
-#     xbatch_size = kwargs.get("xbatch_size", 100)
-#     posterior_samples = model.get_samples()
-#     idx = onp.arange(0, len(posterior_samples["k_length"]))
-#     onp.random.shuffle(idx)
-#     idx = idx[:qbatch_size]
-#     samples = {k: v[idx] for (k, v) in posterior_samples.items()}
-#     if X.shape[0] > xbatch_size:
-#         _, y_sampled = model.predict(
-#             rng_key, X, samples, n,
-#             noiseless=noiseless, **kwargs)
-#     else:
-#         _, y_sampled = model.predict_in_batches(
-#             rng_key, X, xbatch_size, samples, n,
-#             noiseless=noiseless, **kwargs)
-#     return y_sampled