Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Simplify implementation of periodic kernels. #60

Merged
merged 5 commits into from
Dec 16, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 0 additions & 4 deletions dev_requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -502,10 +502,6 @@ more-itertools==9.0.0
# -r gptools-torch/test_requirements.txt
# -r gptools-util/test_requirements.txt
# jaraco-classes
mpmath==1.2.1
# via
# -r doc_requirements.txt
# -r gptools-util/test_requirements.txt
myst-nb==0.17.1
# via -r doc_requirements.txt
myst-parser==0.18.1
Expand Down
2 changes: 0 additions & 2 deletions doc_requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -224,8 +224,6 @@ mistune==2.0.4
# via nbconvert
more-itertools==9.0.0
# via jaraco-classes
mpmath==1.2.1
# via gp-tools-util
myst-nb==0.17.1
# via -r doc_requirements.in
myst-parser==0.18.1
Expand Down
4 changes: 1 addition & 3 deletions gptools-stan/docs/poisson_regression/poisson_regression.md
Original file line number Diff line number Diff line change
Expand Up @@ -236,9 +236,7 @@ Given the substantial performance improvements, we can readily increase the samp

```{code-cell} ipython3
x = np.arange(1024)
num_terms = 3 if os.environ.get("CI") else None
kernel = ExpQuadKernel(sigma=1.2, length_scale=15, period=x.size, num_terms=num_terms) \
+ DiagonalKernel(1e-3, x.size)
kernel = ExpQuadKernel(sigma=1.2, length_scale=15, period=x.size) + DiagonalKernel(1e-3, x.size)
sample = simulate(x, kernel)
plot_sample(sample).legend()
```
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ parameters {
model {
// Gaussian process prior and observation model.
matrix[n, n] cov = add_diag(gp_periodic_exp_quad_cov(X, X, sigma, rep_vector(length_scale, 1),
rep_vector(n, 1), 10), epsilon);
rep_vector(n, 1)), epsilon);
eta ~ multi_normal(zeros_vector(n), cov);
y ~ poisson_log(eta);
}
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ parameters {

transformed parameters {
// Evaluate covariance of the point at zero with everything else.
vector[n %/% 2 + 1] cov_rfft = gp_periodic_exp_quad_cov_rfft(n, sigma, length_scale, n, 10)
vector[n %/% 2 + 1] cov_rfft = gp_periodic_exp_quad_cov_rfft(n, sigma, length_scale, n)
+ epsilon;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ parameters {

transformed parameters {
vector[n] eta;
vector[n %/% 2 + 1] cov_rfft = gp_periodic_exp_quad_cov_rfft(n, sigma, length_scale, n, 10)
vector[n %/% 2 + 1] cov_rfft = gp_periodic_exp_quad_cov_rfft(n, sigma, length_scale, n)
+ epsilon;
eta = gp_transform_inv_rfft(z, zeros_vector(n), cov_rfft);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ transformed parameters {
// wrap the evaluation in braces because Stan only writes top-level variables to the output
// CSV files, and we don't need to store the entire covariance matrix.
matrix[n, n] cov = add_diag(gp_periodic_exp_quad_cov(
X, X, sigma, rep_vector(length_scale, 1), rep_vector(n, 1), 10), epsilon);
X, X, sigma, rep_vector(length_scale, 1), rep_vector(n, 1)), epsilon);
eta = cholesky_decompose(cov) * z;
}
}
Expand Down
30 changes: 9 additions & 21 deletions gptools-stan/gptools/stan/gptools_kernels.stan
Original file line number Diff line number Diff line change
Expand Up @@ -80,39 +80,27 @@ matrix dist2(array [] vector x, vector period, vector scale) {
Evaluate the periodic squared exponential kernel.
*/
matrix gp_periodic_exp_quad_cov(array [] vector x1, array [] vector x2, real sigma,
vector length_scale, vector period, int nterms) {
vector length_scale, vector period) {
int m = size(x1);
int n = size(x2);
matrix[m, n] result;
vector[size(length_scale)] time = 2 * (pi() * length_scale ./ period) ^ 2;
vector[size(length_scale)] q = exp(-time);
real scale = sigma * sigma * prod(sqrt(time / pi()));
for (i in 1:m) {
for (j in 1:n) {
result[i, j] = scale * prod(jtheta(evaluate_residuals(x1[i], x2[j], period, period), q, nterms));
}
}
return result;
return sigma * sigma * exp(-dist2(x1, x2, period, length_scale) / 2);
}

/**
Evaluate the real fast Fourier transform of the periodic squared exponential kernel.
*/
vector gp_periodic_exp_quad_cov_rfft(int n, real sigma, real length_scale, real period,
int nterms) {
real time = 2 * (pi() * length_scale / period) ^ 2;
return sigma * sigma * jtheta_rfft(n, exp(-time), nterms) * sqrt(time / pi());
vector gp_periodic_exp_quad_cov_rfft(int n, real sigma, real length_scale, real period) {
int nrfft = n %/% 2 + 1;
return n * sigma ^ 2 * length_scale / period * sqrt(2 * pi())
* exp(-2 * (pi() * linspaced_vector(nrfft, 0, nrfft - 1) * length_scale / period) ^ 2);
}

/**
Evaluate the two-dimensional real fast Fourier transform of the periodic squared exponential kernel.
*/
matrix gp_periodic_exp_quad_cov_rfft2(int m, int n, real sigma, vector length_scale, vector period,
int nterms) {
vector[m %/% 2 + 1] rfftm = gp_periodic_exp_quad_cov_rfft(m, sigma, length_scale[1], period[1],
nterms);
vector[n %/% 2 + 1] rfftn = gp_periodic_exp_quad_cov_rfft(n, 1, length_scale[2], period[2],
nterms);
matrix gp_periodic_exp_quad_cov_rfft2(int m, int n, real sigma, vector length_scale, vector period) {
vector[m %/% 2 + 1] rfftm = gp_periodic_exp_quad_cov_rfft(m, sigma, length_scale[1], period[1]);
vector[n %/% 2 + 1] rfftn = gp_periodic_exp_quad_cov_rfft(n, 1, length_scale[2], period[2]);
return get_real(expand_rfft(rfftm, m)) * rfftn';
}

Expand Down
20 changes: 0 additions & 20 deletions gptools-stan/gptools/stan/gptools_util.stan
Original file line number Diff line number Diff line change
Expand Up @@ -425,23 +425,3 @@ real std_normal_lpdf(complex_vector z) {
real std_normal_lpdf(matrix z) {
return std_normal_lpdf(to_vector(z));
}

// Special functions -------------------------------------------------------------------------------

vector jtheta(vector z, vector q, int nterms) {
vector[size(z)] result = zeros_vector(size(z));
for (n in 1:nterms) {
result += q ^ (n ^ 2) .* cos(2 * pi() * z * n);
}
return 1 + 2 * result;
}

vector jtheta_rfft(int nz, real q, int nterms) {
int nrfft = nz %/% 2 + 1;
vector[nrfft] result = zeros_vector(nrfft);
vector[nrfft] k = linspaced_vector(nrfft, 0, nrfft - 1);
for (n in 0:nterms) {
result += q ^ ((k + n * nz) ^ 2) + q ^ ((nz - k + n * nz) ^ 2);
}
return nz * result;
}
2 changes: 1 addition & 1 deletion gptools-stan/gptools/stan/profile/fourier_centered.stan
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ parameters {
}

model {
vector[n %/% 2 + 1] cov_rfft = gp_periodic_exp_quad_cov_rfft(n, sigma, length_scale, n, 10);
vector[n %/% 2 + 1] cov_rfft = gp_periodic_exp_quad_cov_rfft(n, sigma, length_scale, n);
eta ~ gp_rfft(zeros_vector(n), cov_rfft);
y[observed_idx] ~ normal(eta[observed_idx], noise_scale);
}
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ parameters {
transformed parameters {
vector[n] eta;
{
vector[n %/% 2 + 1] cov_rfft = gp_periodic_exp_quad_cov_rfft(n, sigma, length_scale, n, 10);
vector[n %/% 2 + 1] cov_rfft = gp_periodic_exp_quad_cov_rfft(n, sigma, length_scale, n);
eta = gp_transform_inv_rfft(eta_, zeros_vector(n), cov_rfft);
}
}
Expand Down
34 changes: 6 additions & 28 deletions gptools-stan/tests/test_stan_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -359,26 +359,6 @@ def assert_stan_python_allclose(
"desired": np.zeros((m, n)),
})

for n in [5, 8]:
z = np.linspace(0, 1, n, endpoint=False)
q = np.random.uniform(0.25, 0.75, n)
add_configuration({
"stan_function": "jtheta",
"arg_types": {"n_": "int", "z": "vector[n_]", "q": "vector[n_]", "nterms": "int"},
"arg_values": {"n_": n, "z": z, "q": q, "nterms": 10},
"result_type": "vector[n_]",
"includes": ["gptools_util.stan"],
"desired": kernels.jtheta(z, q),
})
add_configuration({
"stan_function": "jtheta_rfft",
"arg_types": {"n": "int", "q": "real", "nterms": "int"},
"arg_values": {"n": n, "q": q[0], "nterms": 10},
"result_type": "vector[n %/% 2 + 1]",
"includes": ["gptools_util.stan"],
"desired": kernels.jtheta_rfft(n, q[0]),
})

for ndim in [1, 2, 3]:
n = 1 + np.random.poisson(50)
m = 1 + np.random.poisson(50)
Expand All @@ -392,9 +372,9 @@ def assert_stan_python_allclose(
"stan_function": "gp_periodic_exp_quad_cov",
"arg_types": {"n_": "int", "m_": "int", "p_": "int", "x": "array [n_] vector[p_]",
"y": "array [m_] vector[p_]", "sigma": "real", "length_scale": "vector[p_]",
"period": "vector[p_]", "nterms": "int"},
"period": "vector[p_]"},
"arg_values": {"n_": n, "m_": m, "p_": ndim, "x": x, "y": y, "sigma": sigma,
"length_scale": length_scale, "period": period, "nterms": 100},
"length_scale": length_scale, "period": period},
"result_type": "matrix[n_, m_]",
"includes": ["gptools_util.stan", "gptools_kernels.stan"],
"desired": kernel.evaluate(x[:, None], y[None]),
Expand Down Expand Up @@ -428,10 +408,8 @@ def assert_stan_python_allclose(
period = np.random.gamma(100, 0.1)
add_configuration({
"stan_function": "gp_periodic_exp_quad_cov_rfft",
"arg_types": {"m": "int", "sigma": "real", "length_scale": "real", "period": "real",
"nterms": "int"},
"arg_values": {"m": n, "sigma": sigma, "length_scale": length_scale, "period": period,
"nterms": 100},
"arg_types": {"m": "int", "sigma": "real", "length_scale": "real", "period": "real"},
"arg_values": {"m": n, "sigma": sigma, "length_scale": length_scale, "period": period},
"result_type": "vector[m %/% 2 + 1]",
"includes": ["gptools_util.stan", "gptools_kernels.stan"],
"desired": kernels.ExpQuadKernel(sigma, length_scale, period=period).evaluate_rfft([n]),
Expand All @@ -453,9 +431,9 @@ def assert_stan_python_allclose(
add_configuration({
"stan_function": "gp_periodic_exp_quad_cov_rfft2",
"arg_types": {"m": "int", "n": "int", "sigma": "real", "length_scale": "vector[2]",
"period": "vector[2]", "nterms": "int"},
"period": "vector[2]"},
"arg_values": {"m": m, "n": n, "sigma": sigma, "length_scale": length_scale,
"period": period, "nterms": 100},
"period": period},
"result_type": "matrix[m, n %/% 2 + 1]",
"includes": ["gptools_util.stan", "gptools_kernels.stan"],
"desired": kernels.ExpQuadKernel(sigma, length_scale, period=period)
Expand Down
122 changes: 20 additions & 102 deletions gptools-util/gptools/util/kernels.py
Original file line number Diff line number Diff line change
@@ -1,76 +1,14 @@
import math
import numbers
import numpy as np
import operator
from typing import Callable, Optional
from typing import Callable
from . import ArrayOrTensor, ArrayOrTensorDispatch, coordgrid, OptionalArrayOrTensor
from .fft import expand_rfft


dispatch = ArrayOrTensorDispatch()


def _jtheta_num_terms(q: ArrayOrTensor, rtol: float = 1e-9) -> int:
return math.ceil(math.log(rtol) / math.log(dispatch.max(q)))


def jtheta(z: ArrayOrTensor, q: ArrayOrTensor, nterms: Optional[int] = None,
max_batch_size: int = 1e6) -> ArrayOrTensor:
r"""
Evaluate the Jacobi theta function using a series approximation.

.. math::

\vartheta_3\left(q,z\right) = 1 + 2 \sum_{n=1}^\infty q^{n^2} \cos\left(2\pi n z\right)

Args:
z: Argument of the theta function.
q: Nome of the theta function with modulus less than one.
nterms: Number of terms in the series approximation (defaults to achieve a relative
tolerance of :math:`10^{-9}`, 197 terms for `q = 0.9`).
max_batch_size: Maximum number of terms per batch.
"""
# TODO: fix for torch.
q, z = np.broadcast_arrays(q, z)
nterms = nterms or _jtheta_num_terms(q)
# If the dimensions of q and z are large and the number of terms is large, we can run into
# memory issues here. We batch the evaluation if necessary to overcome this issue. The maximum
# number of terms should be no more than 10 ^ 6 elements (about 8MB at 64-bit precision).
batch_size = int(max(1, max_batch_size / (q.size * nterms)))
series = 0.0
for offset in range(0, nterms, batch_size):
size = min(batch_size, nterms - offset)
n = 1 + offset + dispatch[z].arange(size)
series = series \
+ (q[..., None] ** (n ** 2) * dispatch.cos(2 * math.pi * z[..., None] * n)).sum(axis=-1)

return 1 + 2 * series


def jtheta_rfft(nz: int, q: ArrayOrTensor, nterms: Optional[int] = None) -> ArrayOrTensor:
"""
Evaluate the real fast Fourier transform of the Jacobi theta function evaluated on the unit
interval with `nz` grid points.

The :func:`jtheta` and :func:`jtheta_rfft` functions are related by

>>> nz = ...
>>> q = ...
>>> z = np.linspace(0, 1, nz, endpoint=False)
>>> np.fft.rfft(jtheta(z, q)) == jtheta_rfft(nz, q)

Args:
nz: Number of grid points.
q: Nome of the theta function with modulus less than one.
nterms: Number of terms in the series approximation (defaults to achieve a relative
tolerance of :math:`10^{-9}`, 197 terms for `q = 0.9`).
"""
nterms = nterms or _jtheta_num_terms(q)
k = np.arange(nz // 2 + 1)
ns = nz * np.arange(nterms)[:, None]
return nz * ((q ** ((k + ns) ** 2)).sum(axis=0) + (q ** ((nz - k + ns) ** 2)).sum(axis=0))


def evaluate_residuals(x: ArrayOrTensor, y: OptionalArrayOrTensor = None,
period: OptionalArrayOrTensor = None) -> ArrayOrTensor:
"""
Expand Down Expand Up @@ -290,45 +228,28 @@ class ExpQuadKernel(Kernel):
sigma: Scale of the covariance.
length_scale: Correlation length.
period: Period for circular boundary conditions.
num_terms: Number of terms in the series approximation of the heat equation solution.
"""
def __init__(self, sigma: float, length_scale: float, period: OptionalArrayOrTensor = None,
num_terms: Optional[int] = None) -> None:
def __init__(self, sigma: float, length_scale: float, period: OptionalArrayOrTensor = None) \
-> None:
super().__init__(period)
self.sigma = sigma
self.length_scale = length_scale
if self.is_periodic:
# Evaluate the effective relaxation time of the heat kernel.
self.time = 2 * (math.pi * self.length_scale / self.period) ** 2
if num_terms is None:
num_terms = _jtheta_num_terms(dispatch.exp(-self.time).max())
if not isinstance(num_terms, numbers.Number):
num_terms = max(num_terms)
self.num_terms = int(num_terms)
else:
self.time = self.num_terms = None

def evaluate(self, x: ArrayOrTensor, y: OptionalArrayOrTensor = None) -> ArrayOrTensor:
if self.is_periodic:
# The residuals will have shape `(..., num_dims)`.
residuals = evaluate_residuals(x, y, self.period) / self.period
value = jtheta(residuals, dispatch.exp(-self.time), self.num_terms) \
* (self.time / math.pi) ** 0.5
cov = self.sigma ** 2 * value.prod(axis=-1)
return cov
else:
residuals = evaluate_residuals(x, y) / self.length_scale
exponent = - dispatch.square(residuals).sum(axis=-1) / 2
return self.sigma * self.sigma * dispatch.exp(exponent)
residuals = evaluate_residuals(x, y, self.period) / self.length_scale
exponent = - dispatch.square(residuals).sum(axis=-1) / 2
return self.sigma * self.sigma * dispatch.exp(exponent)

def evaluate_rfft(self, shape: tuple[int]) -> ArrayOrTensor:
if not self.is_periodic:
raise ValueError("kernel must be periodic")
raise NotImplementedError
ndim = len(shape)
time = self.time * np.ones(ndim)
rescaled_length_scale = self.length_scale * np.ones(ndim) / self.period
value = None
for i, size in enumerate(shape):
part = jtheta_rfft(size, np.exp(-time[i])) * (time[i] / math.pi) ** 0.5
xi = np.arange(size // 2 + 1)
part = size * rescaled_length_scale[i] * (2 * math.pi) ** 0.5 \
* np.exp(-2 * (math.pi * xi * rescaled_length_scale[i]) ** 2)
if i != ndim - 1:
part = expand_rfft(part, size)
if value is None:
Expand Down Expand Up @@ -359,22 +280,19 @@ def __init__(self, dof: float, sigma: float, length_scale: float,
self.dof = dof

def evaluate(self, x: ArrayOrTensor, y: OptionalArrayOrTensor = None) -> ArrayOrTensor:
if self.is_periodic:
raise NotImplementedError
residuals = evaluate_residuals(x, y, self.period) / self.length_scale
distance = (2 * self.dof * residuals * residuals).sum(axis=-1) ** 0.5
if self.dof == 3 / 2:
value = 1 + distance
elif self.dof == 5 / 2:
value = 1 + distance + distance * distance / 3
else:
residuals = evaluate_residuals(x, y) / self.length_scale
distance = (2 * self.dof * residuals * residuals).sum(axis=-1) ** 0.5
if self.dof == 3 / 2:
value = 1 + distance
elif self.dof == 5 / 2:
value = 1 + distance + distance * distance / 3
else:
raise NotImplementedError
return self.sigma * self.sigma * value * dispatch.exp(-distance)
raise NotImplementedError
return self.sigma * self.sigma * value * dispatch.exp(-distance)

def evaluate_rfft(self, shape: tuple[int]):
if not self.is_periodic:
raise ValueError("kernel must be periodic")
raise NotImplementedError
from scipy import special

# Construct the grid to evaluate on.
Expand Down
1 change: 0 additions & 1 deletion gptools-util/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
"flake8",
"jupyter",
"matplotlib",
"mpmath",
"networkx",
"pytest",
"pytest-cov",
Expand Down
Loading