Refactor acquisition functions
- Add batch-mode acquisition functions
- Move ei, ucb, and poi to base_acq
- Update tests
ziatdinovmax committed Aug 20, 2023
1 parent ed68fe3 commit 9720a86
Showing 5 changed files with 361 additions and 272 deletions.
5 changes: 4 additions & 1 deletion gpax/acquisition/__init__.py
@@ -1 +1,4 @@
from .acquisition import *
from .acquisition import UCB, EI, POI, UE, Thompson
from .batch_acquisition import qEI, qPOI, qUCB

__all__ = ["UCB", "EI", "POI", "UE", "Thompson", "qEI", "qPOI", "qUCB"]
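
For orientation, here is a minimal sketch of how the re-exported API fits together after this refactor. It is a sketch under assumptions: the constructor and `fit` calls follow the gpax README conventions, the `qEI` keywords mirror the signature of the removed in-file `qEI` shown later in this diff, and all data and variable names are illustrative.

```python
import jax.numpy as jnp
import jax.random as jra

from gpax.models.gp import ExactGP
from gpax.acquisition import EI, qEI

rng_key = jra.PRNGKey(0)

# Toy 1D problem (invented for illustration)
X_train = jnp.linspace(-1.0, 1.0, 12)[:, None]
y_train = jnp.sin(3.0 * X_train.squeeze())
X_new = jnp.linspace(-1.0, 1.0, 100)[:, None]

gp_model = ExactGP(input_dim=1, kernel="RBF")
gp_model.fit(rng_key, X_train, y_train)   # HMC inference -> fully Bayesian model

acq = EI(rng_key, gp_model, X_new, maximize=True)    # sample-averaged EI
qacq = qEI(rng_key, gp_model, X_new, maximize=True,
           subsample_size=4)                         # one row per posterior draw
next_point = X_new[acq.argmax()]
```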
309 changes: 40 additions & 269 deletions gpax/acquisition/acquisition.py
Expand Up @@ -7,135 +7,19 @@
Created by Maxim Ziatdinov (email: [email protected])
"""

from typing import Type, Optional, Dict, Callable, Any
from typing import Type, Optional, Callable, Any

import jax.numpy as jnp
import jax.random as jra
from jax import vmap
import numpy as onp
import numpyro.distributions as dist

from ..models.gp import ExactGP
from ..utils import random_sample_dict
from .base_acq import ei, ucb, poi
from .penalties import compute_penalty


def ei(model: Type[ExactGP],
X: jnp.ndarray,
sample: Dict[str, jnp.ndarray],
maximize: bool = False,
noiseless: bool = False,
**kwargs) -> jnp.ndarray:
r"""
Expected Improvement
Args:
model: trained model
X: new inputs with shape (N, D), where D is a feature dimension
sample: a single sample with model parameters
maximize: If True, assumes that BO is solving maximization problem
noiseless:
Noise-free prediction. It is set to False by default as new/unseen data is assumed
to follow the same distribution as the training data. Hence, since we introduce model noise
for the training data, we also want to include that noise in our prediction.
**jitter:
Small positive term added to the diagonal part of a covariance
matrix for numerical stability (Default: 1e-6)
"""
if not isinstance(sample, (tuple, list)):
sample = (sample,)
# Get predictive mean and covariance for a single sample with kernel parameters
pred, cov = model.get_mvn_posterior(X, *sample, noiseless, **kwargs)
# Compute standard deviation
sigma = jnp.sqrt(cov.diagonal())
# Standard EI computation
best_f = pred.max() if maximize else pred.min()
u = (pred - best_f) / sigma
if not maximize:
u = -u
normal = dist.Normal(jnp.zeros_like(u), jnp.ones_like(u))
ucdf = normal.cdf(u)
updf = jnp.exp(normal.log_prob(u))
acq = sigma * (updf + u * ucdf)
return acq
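
For reference, the block above is the standard closed form of Expected Improvement. Writing $f^{*}$ for the incumbent best value, it computes

$$
\mathrm{EI}(x) = \sigma(x)\,\bigl[\varphi(u) + u\,\Phi(u)\bigr],
\qquad u = \frac{\mu(x) - f^{*}}{\sigma(x)},
$$

where $\varphi$ and $\Phi$ are the standard normal pdf and cdf; the sign of $u$ is flipped when minimizing.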


def ucb(model: Type[ExactGP],
X: jnp.ndarray,
sample: Dict[str, jnp.ndarray],
beta: float = 0.25,
maximize: bool = False,
noiseless: bool = False,
**kwargs) -> jnp.ndarray:
r"""
Upper confidence bound
Args:
model: trained model
X: new inputs with shape (N, D), where D is a feature dimension
sample: a single sample with model parameters
beta: coefficient balancing exploration-exploitation trade-off
maximize: If True, assumes that BO is solving maximization problem
noiseless:
Noise-free prediction. It is set to False by default as new/unseen data is assumed
to follow the same distribution as the training data. Hence, since we introduce model noise
for the training data, we also want to include that noise in our prediction.
**jitter:
Small positive term added to the diagonal part of a covariance
matrix for numerical stability (Default: 1e-6)
"""
if not isinstance(sample, (tuple, list)):
sample = (sample,)
# Get predictive mean and covariance for a single sample with kernel parameters
mean, cov = model.get_mvn_posterior(X, *sample, noiseless, **kwargs)
var = cov.diagonal()
delta = jnp.sqrt(beta * var)
if maximize:
acq = mean + delta
else:
acq = delta - mean # we return a negative acq for argmax in BO
return acq
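
This is the usual confidence-bound rule,

$$
\mathrm{UCB}(x) = \mu(x) + \sqrt{\beta\,\sigma^{2}(x)},
$$

with the minimization branch returning $\sqrt{\beta\,\sigma^{2}(x)} - \mu(x)$, i.e. the negated lower confidence bound, so that taking an argmax over the acquisition still selects the next point.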


def poi(model: Type[ExactGP],
X: jnp.ndarray,
sample: Dict[str, jnp.ndarray],
xi: float = 0.01,
maximize: bool = False,
noiseless: bool = False,
**kwargs) -> jnp.ndarray:
r"""
Probability of Improvement
Args:
model: trained model
X: new inputs with shape (N, D), where D is a feature dimension
sample: a single sample with model parameters
xi: Exploration-exploitation trade-off parameter (Defaults to 0.01)
maximize: If True, assumes that BO is solving maximization problem
noiseless:
Noise-free prediction. It is set to False by default as new/unseen data is assumed
to follow the same distribution as the training data. Hence, since we introduce model noise
for the training data, we also want to include that noise in our prediction.
**jitter:
Small positive term added to the diagonal part of a covariance
matrix for numerical stability (Default: 1e-6)
"""
if not isinstance(sample, (tuple, list)):
sample = (sample,)
# Get predictive mean and covariance for a single sample with kernel parameters
pred, cov = model.get_mvn_posterior(X, *sample, noiseless, **kwargs)
# Compute standard deviation
sigma = jnp.sqrt(cov.diagonal())
# Standard computation of poi
best_f = pred.max() if maximize else pred.min()
u = (pred - best_f - xi) / sigma
if not maximize:
u = -u
normal = dist.Normal(jnp.zeros_like(u), jnp.ones_like(u))
return normal.cdf(u)
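
In closed form, the function returns

$$
\mathrm{POI}(x) = \Phi\!\left(\frac{\mu(x) - f^{*} - \xi}{\sigma(x)}\right),
$$

with the sign of the argument flipped for minimization; a larger $\xi$ biases the search toward exploration.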


def compute_acquisition(
model: Type[ExactGP],
X: jnp.ndarray,
@@ -192,9 +76,43 @@ def compute_acquisition(
return acq


def compute_batched_acquisition():
# TBA
return
def compute_batch_acquisition(acquisition_type: Callable,
model: Type[ExactGP],
X: jnp.ndarray,
*acq_args,
maximize_distance: bool = False,
n_evals: int = 1,
subsample_size: int = 1,
indices: Optional[jnp.ndarray] = None,
**kwargs) -> jnp.ndarray:

"""
Batch-mode acquisition function fo a given type
"""

if model.mcmc is None:
raise ValueError("The model needs to be fully Bayesian")

    f = vmap(acquisition_type, in_axes=(None, None, 0) + (None,) * len(acq_args))

    if not maximize_distance:
        samples = random_sample_dict(model.get_samples(), subsample_size)
        acq = f(model, X, samples, *acq_args, **kwargs)
    else:
        X_ = jnp.array(indices) if indices is not None else jnp.array(X)
        acq_all, dist_all = [], []

        for _ in range(n_evals):
            # Re-draw the posterior subsample on every evaluation (as the
            # original qEI loop did), so candidate batches differ across restarts
            samples = random_sample_dict(model.get_samples(), subsample_size)
            acq = f(model, X_, samples, *acq_args, **kwargs)
            points = acq.argmax(-1)
            # Spread score: norm of the per-draw argmax indices
            d = jnp.linalg.norm(points).mean()
acq_all.append(acq)
dist_all.append(d)

idx = jnp.array(dist_all).argmax()
acq = acq_all[idx]

return acq
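
The `in_axes=(None, None, 0) + (None,) * len(acq_args)` pattern maps only over the leading axis of every array inside the posterior-sample dict, keeping the model, the inputs, and any extra acquisition arguments fixed. A self-contained sketch of the same trick on a toy function (no GP involved):

```python
import jax.numpy as jnp
from jax import vmap

def toy_acq(model, X, sample, beta):
    # One "posterior sample" is a dict of scalars here; model is unused
    return sample["a"] * X.squeeze() + beta * sample["b"]

samples = {"a": jnp.arange(3.0), "b": jnp.ones(3)}   # three stacked draws
X = jnp.linspace(0.0, 1.0, 5)[:, None]

f = vmap(toy_acq, in_axes=(None, None, 0, None))     # map over draws only
acq = f(None, X, samples, 0.25)
print(acq.shape)                                     # (3, 5): draws x candidates
```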


def EI(rng_key: jnp.ndarray, model: Type[ExactGP],
Expand Down Expand Up @@ -257,22 +175,6 @@ def EI(rng_key: jnp.ndarray, model: Type[ExactGP],
grid_indices=grid_indices, penalty_factor=penalty_factor,
**kwargs)

# if penalty and not isinstance(recent_points, (onp.ndarray, jnp.ndarray)):
# raise ValueError("Please provide an array of recently visited points")
# X = X[:, None] if X.ndim < 2 else X
# samples = model.get_samples()
# if model.mcmc is None:
# acq = ei(model, X, samples, maximize, noiseless, **kwargs)
# else:
# f = vmap(ei, in_axes=(None, None, 0, None, None))
# acq = f(model, X, samples, maximize, noiseless, **kwargs)
# acq = acq.mean(0)
# if penalty:
# X_ = grid_indices if grid_indices is not None else X
# penalties = compute_penalty(X_, recent_points, penalty, penalty_factor)
# acq -= penalties
# return acq


def POI(rng_key: jnp.ndarray,
model: Type[ExactGP],
@@ -400,58 +302,6 @@ def UCB(rng_key: jnp.ndarray,
penalty=penalty, recent_points=recent_points,
grid_indices=grid_indices, penalty_factor=penalty_factor,
**kwargs)
# if penalty and not isinstance(recent_points, (onp.ndarray, jnp.ndarray)):
# raise ValueError("Please provide an array of recently visited points")
# X = X[:, None] if X.ndim < 2 else X
# samples = model.get_samples()
# if model.mcmc is None:
# acq = ucb(model, X, samples, beta, maximize, noiseless, **kwargs)
# else:
# f = vmap(ucb, in_axes=(None, None, 0, None, None))
# acq = f(model, X, samples, maximize, noiseless, **kwargs)
# acq = acq.mean(0)
# if penalty:
# X_ = grid_indices if grid_indices is not None else X
# penalties = compute_penalty(X_, recent_points, penalty, penalty_factor)
# acq -= penalties
# return acq



def qEI(rng_key: jnp.ndarray,
model: Type[ExactGP],
X: jnp.ndarray,
maximize: bool = False, n: int = 1,
noiseless: bool = False,
maximize_distance: bool = False,
n_evals: int = 1,
subsample_size: int = 1,
indices: Optional[jnp.ndarray] = None,
**kwargs) -> jnp.ndarray:

if model.mcmc is None:
raise ValueError("qEI works only with fully Bayesian models")

if not maximize_distance:
samples = random_sample_dict(model.get_samples(), subsample_size)
f = vmap(ei, in_axes=(None, None, 0, None, None))
acq = f(model, X, samples, maximize, noiseless, **kwargs)

else: # draws samples multiple times and selects the ones where maxima are farthest apart from each other
X_ = jnp.array(indices) if indices is not None else jnp.array(X)
acq_all, dist_all = [], []
for _ in range(n_evals):
samples = random_sample_dict(model.get_samples(), subsample_size)
f = vmap(ei, in_axes=(None, None, 0, None, None))
acq = f(model, X_, samples, maximize, noiseless, **kwargs) # (subsample_size, len(X))
points = acq.argmax(-1)
d = jnp.linalg.norm(points).mean()
acq_all.append(acq)
dist_all.append(d)
idx = jnp.array(dist_all).argmax()
acq = acq_all[idx]

return acq
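
To make the `maximize_distance` branch concrete, here is the selection step in isolation on synthetic acquisition values (no model involved). The score is simply the norm of the per-draw argmax indices, a crude proxy for how spread out the proposed maxima are:

```python
import jax.numpy as jnp

# Two hypothetical restarts, each with 2 posterior draws over 3 candidates
acq_all = [
    jnp.array([[0.1, 0.9, 0.2],
               [0.2, 0.8, 0.1]]),   # both draws peak at index 1
    jnp.array([[0.9, 0.1, 0.2],
               [0.1, 0.2, 0.8]]),   # draws peak at indices 0 and 2
]
dist_all = [jnp.linalg.norm(a.argmax(-1).astype(jnp.float32)) for a in acq_all]
acq = acq_all[int(jnp.array(dist_all).argmax())]     # restart 1 wins here
```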


def UE(rng_key: jnp.ndarray,
@@ -555,83 +405,4 @@ def Thompson(rng_key: jnp.ndarray,
else:
_, tsample = model.sample_from_posterior(
rng_key, X, n=1, noiseless=noiseless, **kwargs)
return tsample
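
Since a single posterior draw is itself the acquisition in Thompson sampling, the selection step is just an argmax over that draw. A hypothetical continuation of the sketch after the __init__.py hunk above (`gp_model` and `X_new` are assumed from there):

```python
# Thompson step: one posterior function draw scores all candidates
tsample = Thompson(rng_key, gp_model, X_new, noiseless=True)
next_point = X_new[tsample.squeeze().argmax()]
```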


# def qUCB(rng_key: jnp.ndarray, model: Type[ExactGP],
# X: jnp.ndarray, indices: Optional[jnp.ndarray] = None,
# qbatch_size: int = 4, alpha: float = 1.0, beta: float = .25,
# maximize: bool = True, n: int = 500,
# n_restarts: int = 20, noiseless: bool = False,
# **kwargs) -> jnp.ndarray:
# """
# The acquisition function defined as alpha * mu + sqrt(beta) * sigma
# that can output a "batch" of next points to evaluate. It takes advantage of
# the fact that in MCMC-based GP or DKL we obtain a separate multivariate
# normal posterior for each set of sampled kernel hyperparameters.

# Args:
# rng_key: random number generator key
# model: ExactGP or DKL type of model
# X: input array
# indices: indices of data points in X array. For example, if
# each data point is an image patch, the indices should
# correspond to their (x, y) coordinates in the original image.
# qbatch_size: desired number of sampled points (default: 4)
# alpha: coefficient before mean prediction term (default: 1.0)
# beta: coefficient before variance term (default: 0.25)
# maximize: sign of variance term (+/- if True/False)
# n: number of draws from each multivariate normal posterior
# n_restarts: number of restarts to find a batch of maximally
# separated points to evaluate next
# noiseless: noise-free prediction for new/test data (default: False)

# Returns:
# Computed acquisition function with qbatch x features
# or task x qbatch x features dimensions
# """
# if model.mcmc is None:
# raise NotImplementedError(
# "Currently supports only ExactGP and DKL with MCMC inference")
# dist_all, obj_all = [], []
# X_ = jnp.array(indices) if indices is not None else jnp.array(X)
# for _ in range(n_restarts):
# y_sampled = obtain_samples(
# rng_key, model, X, qbatch_size, n, noiseless, **kwargs)
# mean, var = y_sampled.mean(1), y_sampled.var(1)
# delta = jnp.sqrt(beta * var)
# if maximize:
# obj = alpha * mean + delta
# points = X_[obj.argmax(-1)]
# else:
# obj = alpha * mean - delta
# points = X_[obj.argmin(-1)]
# d = jnp.linalg.norm(points, axis=-1).mean(0)
# dist_all.append(d)
# obj_all.append(obj)
# idx = jnp.array(dist_all).argmax(0)
# if idx.ndim > 0:
# obj_all = jnp.array(obj_all)
# return jnp.array([obj_all[j,:,i] for i, j in enumerate(idx)])
# return obj_all[idx]
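
For the record, the commented-out batch rule being deleted here is the one its docstring states,

$$
q\mathrm{UCB}(x) = \alpha\,\mu(x) \pm \sqrt{\beta\,\sigma^{2}(x)},
$$

where $\mu$ and $\sigma^{2}$ are the mean and variance over $n$ draws from each sampled posterior, and the sign follows `maximize`.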


# def obtain_samples(rng_key: jnp.ndarray, model: Type[ExactGP],
# X: jnp.ndarray, qbatch_size: int = 4,
# n: int = 500, noiseless: bool = False,
# **kwargs) -> jnp.ndarray:
# xbatch_size = kwargs.get("xbatch_size", 100)
# posterior_samples = model.get_samples()
# idx = onp.arange(0, len(posterior_samples["k_length"]))
# onp.random.shuffle(idx)
# idx = idx[:qbatch_size]
# samples = {k: v[idx] for (k, v) in posterior_samples.items()}
# if X.shape[0] > xbatch_size:
# _, y_sampled = model.predict(
# rng_key, X, samples, n,
# noiseless=noiseless, **kwargs)
# else:
# _, y_sampled = model.predict_in_batches(
# rng_key, X, xbatch_size, samples, n,
# noiseless=noiseless, **kwargs)
# return y_sampled
