Skip to content

Commit

Permalink
Merge pull request #60 from tillahoffmann/del-jtheta
Browse files Browse the repository at this point in the history
Simplify implementation of periodic kernels.
  • Loading branch information
tillahoffmann authored Dec 16, 2022
2 parents 45f08a0 + 5365df3 commit 799ef25
Show file tree
Hide file tree
Showing 17 changed files with 57 additions and 232 deletions.
4 changes: 0 additions & 4 deletions dev_requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -502,10 +502,6 @@ more-itertools==9.0.0
# -r gptools-torch/test_requirements.txt
# -r gptools-util/test_requirements.txt
# jaraco-classes
mpmath==1.2.1
# via
# -r doc_requirements.txt
# -r gptools-util/test_requirements.txt
myst-nb==0.17.1
# via -r doc_requirements.txt
myst-parser==0.18.1
Expand Down
2 changes: 0 additions & 2 deletions doc_requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -224,8 +224,6 @@ mistune==2.0.4
# via nbconvert
more-itertools==9.0.0
# via jaraco-classes
mpmath==1.2.1
# via gp-tools-util
myst-nb==0.17.1
# via -r doc_requirements.in
myst-parser==0.18.1
Expand Down
4 changes: 1 addition & 3 deletions gptools-stan/docs/poisson_regression/poisson_regression.md
Original file line number Diff line number Diff line change
Expand Up @@ -236,9 +236,7 @@ Given the substantial performance improvements, we can readily increase the samp

```{code-cell} ipython3
x = np.arange(1024)
num_terms = 3 if os.environ.get("CI") else None
kernel = ExpQuadKernel(sigma=1.2, length_scale=15, period=x.size, num_terms=num_terms) \
+ DiagonalKernel(1e-3, x.size)
kernel = ExpQuadKernel(sigma=1.2, length_scale=15, period=x.size) + DiagonalKernel(1e-3, x.size)
sample = simulate(x, kernel)
plot_sample(sample).legend()
```
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ parameters {
model {
// Gaussian process prior and observation model.
matrix[n, n] cov = add_diag(gp_periodic_exp_quad_cov(X, X, sigma, rep_vector(length_scale, 1),
rep_vector(n, 1), 10), epsilon);
rep_vector(n, 1)), epsilon);
eta ~ multi_normal(zeros_vector(n), cov);
y ~ poisson_log(eta);
}
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ parameters {

transformed parameters {
// Evaluate covariance of the point at zero with everything else.
vector[n %/% 2 + 1] cov_rfft = gp_periodic_exp_quad_cov_rfft(n, sigma, length_scale, n, 10)
vector[n %/% 2 + 1] cov_rfft = gp_periodic_exp_quad_cov_rfft(n, sigma, length_scale, n)
+ epsilon;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ parameters {

transformed parameters {
vector[n] eta;
vector[n %/% 2 + 1] cov_rfft = gp_periodic_exp_quad_cov_rfft(n, sigma, length_scale, n, 10)
vector[n %/% 2 + 1] cov_rfft = gp_periodic_exp_quad_cov_rfft(n, sigma, length_scale, n)
+ epsilon;
eta = gp_transform_inv_rfft(z, zeros_vector(n), cov_rfft);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ transformed parameters {
// wrap the evaluation in braces because Stan only writes top-level variables to the output
// CSV files, and we don't need to store the entire covariance matrix.
matrix[n, n] cov = add_diag(gp_periodic_exp_quad_cov(
X, X, sigma, rep_vector(length_scale, 1), rep_vector(n, 1), 10), epsilon);
X, X, sigma, rep_vector(length_scale, 1), rep_vector(n, 1)), epsilon);
eta = cholesky_decompose(cov) * z;
}
}
Expand Down
30 changes: 9 additions & 21 deletions gptools-stan/gptools/stan/gptools_kernels.stan
Original file line number Diff line number Diff line change
Expand Up @@ -80,39 +80,27 @@ matrix dist2(array [] vector x, vector period, vector scale) {
Evaluate the periodic squared exponential kernel.
*/
matrix gp_periodic_exp_quad_cov(array [] vector x1, array [] vector x2, real sigma,
vector length_scale, vector period, int nterms) {
vector length_scale, vector period) {
int m = size(x1);
int n = size(x2);
matrix[m, n] result;
vector[size(length_scale)] time = 2 * (pi() * length_scale ./ period) ^ 2;
vector[size(length_scale)] q = exp(-time);
real scale = sigma * sigma * prod(sqrt(time / pi()));
for (i in 1:m) {
for (j in 1:n) {
result[i, j] = scale * prod(jtheta(evaluate_residuals(x1[i], x2[j], period, period), q, nterms));
}
}
return result;
return sigma * sigma * exp(-dist2(x1, x2, period, length_scale) / 2);
}

/**
Evaluate the real fast Fourier transform of the periodic squared exponential kernel.
*/
vector gp_periodic_exp_quad_cov_rfft(int n, real sigma, real length_scale, real period,
int nterms) {
real time = 2 * (pi() * length_scale / period) ^ 2;
return sigma * sigma * jtheta_rfft(n, exp(-time), nterms) * sqrt(time / pi());
vector gp_periodic_exp_quad_cov_rfft(int n, real sigma, real length_scale, real period) {
int nrfft = n %/% 2 + 1;
return n * sigma ^ 2 * length_scale / period * sqrt(2 * pi())
* exp(-2 * (pi() * linspaced_vector(nrfft, 0, nrfft - 1) * length_scale / period) ^ 2);
}

/**
Evaluate the two-dimensional real fast Fourier transform of the periodic squared exponential kernel.
*/
matrix gp_periodic_exp_quad_cov_rfft2(int m, int n, real sigma, vector length_scale, vector period,
int nterms) {
vector[m %/% 2 + 1] rfftm = gp_periodic_exp_quad_cov_rfft(m, sigma, length_scale[1], period[1],
nterms);
vector[n %/% 2 + 1] rfftn = gp_periodic_exp_quad_cov_rfft(n, 1, length_scale[2], period[2],
nterms);
matrix gp_periodic_exp_quad_cov_rfft2(int m, int n, real sigma, vector length_scale, vector period) {
vector[m %/% 2 + 1] rfftm = gp_periodic_exp_quad_cov_rfft(m, sigma, length_scale[1], period[1]);
vector[n %/% 2 + 1] rfftn = gp_periodic_exp_quad_cov_rfft(n, 1, length_scale[2], period[2]);
return get_real(expand_rfft(rfftm, m)) * rfftn';
}

Expand Down
20 changes: 0 additions & 20 deletions gptools-stan/gptools/stan/gptools_util.stan
Original file line number Diff line number Diff line change
Expand Up @@ -425,23 +425,3 @@ real std_normal_lpdf(complex_vector z) {
real std_normal_lpdf(matrix z) {
return std_normal_lpdf(to_vector(z));
}

// Special functions -------------------------------------------------------------------------------

vector jtheta(vector z, vector q, int nterms) {
vector[size(z)] result = zeros_vector(size(z));
for (n in 1:nterms) {
result += q ^ (n ^ 2) .* cos(2 * pi() * z * n);
}
return 1 + 2 * result;
}

vector jtheta_rfft(int nz, real q, int nterms) {
int nrfft = nz %/% 2 + 1;
vector[nrfft] result = zeros_vector(nrfft);
vector[nrfft] k = linspaced_vector(nrfft, 0, nrfft - 1);
for (n in 0:nterms) {
result += q ^ ((k + n * nz) ^ 2) + q ^ ((nz - k + n * nz) ^ 2);
}
return nz * result;
}
2 changes: 1 addition & 1 deletion gptools-stan/gptools/stan/profile/fourier_centered.stan
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ parameters {
}

model {
vector[n %/% 2 + 1] cov_rfft = gp_periodic_exp_quad_cov_rfft(n, sigma, length_scale, n, 10);
vector[n %/% 2 + 1] cov_rfft = gp_periodic_exp_quad_cov_rfft(n, sigma, length_scale, n);
eta ~ gp_rfft(zeros_vector(n), cov_rfft);
y[observed_idx] ~ normal(eta[observed_idx], noise_scale);
}
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ parameters {
transformed parameters {
vector[n] eta;
{
vector[n %/% 2 + 1] cov_rfft = gp_periodic_exp_quad_cov_rfft(n, sigma, length_scale, n, 10);
vector[n %/% 2 + 1] cov_rfft = gp_periodic_exp_quad_cov_rfft(n, sigma, length_scale, n);
eta = gp_transform_inv_rfft(eta_, zeros_vector(n), cov_rfft);
}
}
Expand Down
34 changes: 6 additions & 28 deletions gptools-stan/tests/test_stan_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -359,26 +359,6 @@ def assert_stan_python_allclose(
"desired": np.zeros((m, n)),
})

for n in [5, 8]:
z = np.linspace(0, 1, n, endpoint=False)
q = np.random.uniform(0.25, 0.75, n)
add_configuration({
"stan_function": "jtheta",
"arg_types": {"n_": "int", "z": "vector[n_]", "q": "vector[n_]", "nterms": "int"},
"arg_values": {"n_": n, "z": z, "q": q, "nterms": 10},
"result_type": "vector[n_]",
"includes": ["gptools_util.stan"],
"desired": kernels.jtheta(z, q),
})
add_configuration({
"stan_function": "jtheta_rfft",
"arg_types": {"n": "int", "q": "real", "nterms": "int"},
"arg_values": {"n": n, "q": q[0], "nterms": 10},
"result_type": "vector[n %/% 2 + 1]",
"includes": ["gptools_util.stan"],
"desired": kernels.jtheta_rfft(n, q[0]),
})

for ndim in [1, 2, 3]:
n = 1 + np.random.poisson(50)
m = 1 + np.random.poisson(50)
Expand All @@ -392,9 +372,9 @@ def assert_stan_python_allclose(
"stan_function": "gp_periodic_exp_quad_cov",
"arg_types": {"n_": "int", "m_": "int", "p_": "int", "x": "array [n_] vector[p_]",
"y": "array [m_] vector[p_]", "sigma": "real", "length_scale": "vector[p_]",
"period": "vector[p_]", "nterms": "int"},
"period": "vector[p_]"},
"arg_values": {"n_": n, "m_": m, "p_": ndim, "x": x, "y": y, "sigma": sigma,
"length_scale": length_scale, "period": period, "nterms": 100},
"length_scale": length_scale, "period": period},
"result_type": "matrix[n_, m_]",
"includes": ["gptools_util.stan", "gptools_kernels.stan"],
"desired": kernel.evaluate(x[:, None], y[None]),
Expand Down Expand Up @@ -428,10 +408,8 @@ def assert_stan_python_allclose(
period = np.random.gamma(100, 0.1)
add_configuration({
"stan_function": "gp_periodic_exp_quad_cov_rfft",
"arg_types": {"m": "int", "sigma": "real", "length_scale": "real", "period": "real",
"nterms": "int"},
"arg_values": {"m": n, "sigma": sigma, "length_scale": length_scale, "period": period,
"nterms": 100},
"arg_types": {"m": "int", "sigma": "real", "length_scale": "real", "period": "real"},
"arg_values": {"m": n, "sigma": sigma, "length_scale": length_scale, "period": period},
"result_type": "vector[m %/% 2 + 1]",
"includes": ["gptools_util.stan", "gptools_kernels.stan"],
"desired": kernels.ExpQuadKernel(sigma, length_scale, period=period).evaluate_rfft([n]),
Expand All @@ -453,9 +431,9 @@ def assert_stan_python_allclose(
add_configuration({
"stan_function": "gp_periodic_exp_quad_cov_rfft2",
"arg_types": {"m": "int", "n": "int", "sigma": "real", "length_scale": "vector[2]",
"period": "vector[2]", "nterms": "int"},
"period": "vector[2]"},
"arg_values": {"m": m, "n": n, "sigma": sigma, "length_scale": length_scale,
"period": period, "nterms": 100},
"period": period},
"result_type": "matrix[m, n %/% 2 + 1]",
"includes": ["gptools_util.stan", "gptools_kernels.stan"],
"desired": kernels.ExpQuadKernel(sigma, length_scale, period=period)
Expand Down
122 changes: 20 additions & 102 deletions gptools-util/gptools/util/kernels.py
Original file line number Diff line number Diff line change
@@ -1,76 +1,14 @@
import math
import numbers
import numpy as np
import operator
from typing import Callable, Optional
from typing import Callable
from . import ArrayOrTensor, ArrayOrTensorDispatch, coordgrid, OptionalArrayOrTensor
from .fft import expand_rfft


dispatch = ArrayOrTensorDispatch()


def _jtheta_num_terms(q: ArrayOrTensor, rtol: float = 1e-9) -> int:
return math.ceil(math.log(rtol) / math.log(dispatch.max(q)))


def jtheta(z: ArrayOrTensor, q: ArrayOrTensor, nterms: Optional[int] = None,
max_batch_size: int = 1e6) -> ArrayOrTensor:
r"""
Evaluate the Jacobi theta function using a series approximation.
.. math::
\vartheta_3\left(q,z\right) = 1 + 2 \sum_{n=1}^\infty q^{n^2} \cos\left(2\pi n z\right)
Args:
z: Argument of the theta function.
q: Nome of the theta function with modulus less than one.
nterms: Number of terms in the series approximation (defaults to achieve a relative
tolerance of :math:`10^{-9}`, 197 terms for `q = 0.9`).
max_batch_size: Maximum number of terms per batch.
"""
# TODO: fix for torch.
q, z = np.broadcast_arrays(q, z)
nterms = nterms or _jtheta_num_terms(q)
# If the dimensions of q and z are large and the number of terms is large, we can run into
# memory issues here. We batch the evaluation if necessary to overcome this issue. The maximum
# number of terms should be no more than 10 ^ 6 elements (about 8MB at 64-bit precision).
batch_size = int(max(1, max_batch_size / (q.size * nterms)))
series = 0.0
for offset in range(0, nterms, batch_size):
size = min(batch_size, nterms - offset)
n = 1 + offset + dispatch[z].arange(size)
series = series \
+ (q[..., None] ** (n ** 2) * dispatch.cos(2 * math.pi * z[..., None] * n)).sum(axis=-1)

return 1 + 2 * series


def jtheta_rfft(nz: int, q: ArrayOrTensor, nterms: Optional[int] = None) -> ArrayOrTensor:
"""
Evaluate the real fast Fourier transform of the Jacobi theta function evaluated on the unit
interval with `nz` grid points.
The :func:`jtheta` and :func:`jtheta_rfft` functions are related by
>>> nz = ...
>>> q = ...
>>> z = np.linspace(0, 1, nz, endpoint=False)
>>> np.fft.rfft(jtheta(z, q)) == jtheta_rfft(nz, q)
Args:
nz: Number of grid points.
q: Nome of the theta function with modulus less than one.
nterms: Number of terms in the series approximation (defaults to achieve a relative
tolerance of :math:`10^{-9}`, 197 terms for `q = 0.9`).
"""
nterms = nterms or _jtheta_num_terms(q)
k = np.arange(nz // 2 + 1)
ns = nz * np.arange(nterms)[:, None]
return nz * ((q ** ((k + ns) ** 2)).sum(axis=0) + (q ** ((nz - k + ns) ** 2)).sum(axis=0))


def evaluate_residuals(x: ArrayOrTensor, y: OptionalArrayOrTensor = None,
period: OptionalArrayOrTensor = None) -> ArrayOrTensor:
"""
Expand Down Expand Up @@ -290,45 +228,28 @@ class ExpQuadKernel(Kernel):
sigma: Scale of the covariance.
length_scale: Correlation length.
period: Period for circular boundary conditions.
num_terms: Number of terms in the series approximation of the heat equation solution.
"""
def __init__(self, sigma: float, length_scale: float, period: OptionalArrayOrTensor = None,
num_terms: Optional[int] = None) -> None:
def __init__(self, sigma: float, length_scale: float, period: OptionalArrayOrTensor = None) \
-> None:
super().__init__(period)
self.sigma = sigma
self.length_scale = length_scale
if self.is_periodic:
# Evaluate the effective relaxation time of the heat kernel.
self.time = 2 * (math.pi * self.length_scale / self.period) ** 2
if num_terms is None:
num_terms = _jtheta_num_terms(dispatch.exp(-self.time).max())
if not isinstance(num_terms, numbers.Number):
num_terms = max(num_terms)
self.num_terms = int(num_terms)
else:
self.time = self.num_terms = None

def evaluate(self, x: ArrayOrTensor, y: OptionalArrayOrTensor = None) -> ArrayOrTensor:
if self.is_periodic:
# The residuals will have shape `(..., num_dims)`.
residuals = evaluate_residuals(x, y, self.period) / self.period
value = jtheta(residuals, dispatch.exp(-self.time), self.num_terms) \
* (self.time / math.pi) ** 0.5
cov = self.sigma ** 2 * value.prod(axis=-1)
return cov
else:
residuals = evaluate_residuals(x, y) / self.length_scale
exponent = - dispatch.square(residuals).sum(axis=-1) / 2
return self.sigma * self.sigma * dispatch.exp(exponent)
residuals = evaluate_residuals(x, y, self.period) / self.length_scale
exponent = - dispatch.square(residuals).sum(axis=-1) / 2
return self.sigma * self.sigma * dispatch.exp(exponent)

def evaluate_rfft(self, shape: tuple[int]) -> ArrayOrTensor:
if not self.is_periodic:
raise ValueError("kernel must be periodic")
raise NotImplementedError
ndim = len(shape)
time = self.time * np.ones(ndim)
rescaled_length_scale = self.length_scale * np.ones(ndim) / self.period
value = None
for i, size in enumerate(shape):
part = jtheta_rfft(size, np.exp(-time[i])) * (time[i] / math.pi) ** 0.5
xi = np.arange(size // 2 + 1)
part = size * rescaled_length_scale[i] * (2 * math.pi) ** 0.5 \
* np.exp(-2 * (math.pi * xi * rescaled_length_scale[i]) ** 2)
if i != ndim - 1:
part = expand_rfft(part, size)
if value is None:
Expand Down Expand Up @@ -359,22 +280,19 @@ def __init__(self, dof: float, sigma: float, length_scale: float,
self.dof = dof

def evaluate(self, x: ArrayOrTensor, y: OptionalArrayOrTensor = None) -> ArrayOrTensor:
if self.is_periodic:
raise NotImplementedError
residuals = evaluate_residuals(x, y, self.period) / self.length_scale
distance = (2 * self.dof * residuals * residuals).sum(axis=-1) ** 0.5
if self.dof == 3 / 2:
value = 1 + distance
elif self.dof == 5 / 2:
value = 1 + distance + distance * distance / 3
else:
residuals = evaluate_residuals(x, y) / self.length_scale
distance = (2 * self.dof * residuals * residuals).sum(axis=-1) ** 0.5
if self.dof == 3 / 2:
value = 1 + distance
elif self.dof == 5 / 2:
value = 1 + distance + distance * distance / 3
else:
raise NotImplementedError
return self.sigma * self.sigma * value * dispatch.exp(-distance)
raise NotImplementedError
return self.sigma * self.sigma * value * dispatch.exp(-distance)

def evaluate_rfft(self, shape: tuple[int]):
if not self.is_periodic:
raise ValueError("kernel must be periodic")
raise NotImplementedError
from scipy import special

# Construct the grid to evaluate on.
Expand Down
1 change: 0 additions & 1 deletion gptools-util/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
"flake8",
"jupyter",
"matplotlib",
"mpmath",
"networkx",
"pytest",
"pytest-cov",
Expand Down
Loading

0 comments on commit 799ef25

Please sign in to comment.