diff --git a/dev_requirements.txt b/dev_requirements.txt index 6de130f..d3571dc 100644 --- a/dev_requirements.txt +++ b/dev_requirements.txt @@ -502,10 +502,6 @@ more-itertools==9.0.0 # -r gptools-torch/test_requirements.txt # -r gptools-util/test_requirements.txt # jaraco-classes -mpmath==1.2.1 - # via - # -r doc_requirements.txt - # -r gptools-util/test_requirements.txt myst-nb==0.17.1 # via -r doc_requirements.txt myst-parser==0.18.1 diff --git a/doc_requirements.txt b/doc_requirements.txt index ab478f8..cd80ffd 100644 --- a/doc_requirements.txt +++ b/doc_requirements.txt @@ -224,8 +224,6 @@ mistune==2.0.4 # via nbconvert more-itertools==9.0.0 # via jaraco-classes -mpmath==1.2.1 - # via gp-tools-util myst-nb==0.17.1 # via -r doc_requirements.in myst-parser==0.18.1 diff --git a/gptools-stan/docs/poisson_regression/poisson_regression.md b/gptools-stan/docs/poisson_regression/poisson_regression.md index edd9acf..577b64d 100644 --- a/gptools-stan/docs/poisson_regression/poisson_regression.md +++ b/gptools-stan/docs/poisson_regression/poisson_regression.md @@ -236,9 +236,7 @@ Given the substantial performance improvements, we can readily increase the samp ```{code-cell} ipython3 x = np.arange(1024) -num_terms = 3 if os.environ.get("CI") else None -kernel = ExpQuadKernel(sigma=1.2, length_scale=15, period=x.size, num_terms=num_terms) \ - + DiagonalKernel(1e-3, x.size) +kernel = ExpQuadKernel(sigma=1.2, length_scale=15, period=x.size) + DiagonalKernel(1e-3, x.size) sample = simulate(x, kernel) plot_sample(sample).legend() ``` diff --git a/gptools-stan/docs/poisson_regression/poisson_regression_centered.stan b/gptools-stan/docs/poisson_regression/poisson_regression_centered.stan index b6c7df4..19dcc2b 100644 --- a/gptools-stan/docs/poisson_regression/poisson_regression_centered.stan +++ b/gptools-stan/docs/poisson_regression/poisson_regression_centered.stan @@ -21,7 +21,7 @@ parameters { model { // Gaussian process prior and observation model. matrix[n, n] cov = add_diag(gp_periodic_exp_quad_cov(X, X, sigma, rep_vector(length_scale, 1), - rep_vector(n, 1), 10), epsilon); + rep_vector(n, 1)), epsilon); eta ~ multi_normal(zeros_vector(n), cov); y ~ poisson_log(eta); } diff --git a/gptools-stan/docs/poisson_regression/poisson_regression_fourier_centered.stan b/gptools-stan/docs/poisson_regression/poisson_regression_fourier_centered.stan index 908cb55..77ac744 100644 --- a/gptools-stan/docs/poisson_regression/poisson_regression_fourier_centered.stan +++ b/gptools-stan/docs/poisson_regression/poisson_regression_fourier_centered.stan @@ -16,7 +16,7 @@ parameters { transformed parameters { // Evaluate covariance of the point at zero with everything else. - vector[n %/% 2 + 1] cov_rfft = gp_periodic_exp_quad_cov_rfft(n, sigma, length_scale, n, 10) + vector[n %/% 2 + 1] cov_rfft = gp_periodic_exp_quad_cov_rfft(n, sigma, length_scale, n) + epsilon; } diff --git a/gptools-stan/docs/poisson_regression/poisson_regression_fourier_non_centered.stan b/gptools-stan/docs/poisson_regression/poisson_regression_fourier_non_centered.stan index 0c79642..4986ed6 100644 --- a/gptools-stan/docs/poisson_regression/poisson_regression_fourier_non_centered.stan +++ b/gptools-stan/docs/poisson_regression/poisson_regression_fourier_non_centered.stan @@ -16,7 +16,7 @@ parameters { transformed parameters { vector[n] eta; - vector[n %/% 2 + 1] cov_rfft = gp_periodic_exp_quad_cov_rfft(n, sigma, length_scale, n, 10) + vector[n %/% 2 + 1] cov_rfft = gp_periodic_exp_quad_cov_rfft(n, sigma, length_scale, n) + epsilon; eta = gp_transform_inv_rfft(z, zeros_vector(n), cov_rfft); } diff --git a/gptools-stan/docs/poisson_regression/poisson_regression_non_centered.stan b/gptools-stan/docs/poisson_regression/poisson_regression_non_centered.stan index 6858ad9..cd87983 100644 --- a/gptools-stan/docs/poisson_regression/poisson_regression_non_centered.stan +++ b/gptools-stan/docs/poisson_regression/poisson_regression_non_centered.stan @@ -20,7 +20,7 @@ transformed parameters { // wrap the evaluation in braces because Stan only writes top-level variables to the output // CSV files, and we don't need to store the entire covariance matrix. matrix[n, n] cov = add_diag(gp_periodic_exp_quad_cov( - X, X, sigma, rep_vector(length_scale, 1), rep_vector(n, 1), 10), epsilon); + X, X, sigma, rep_vector(length_scale, 1), rep_vector(n, 1)), epsilon); eta = cholesky_decompose(cov) * z; } } diff --git a/gptools-stan/gptools/stan/gptools_kernels.stan b/gptools-stan/gptools/stan/gptools_kernels.stan index 198220a..ec77493 100644 --- a/gptools-stan/gptools/stan/gptools_kernels.stan +++ b/gptools-stan/gptools/stan/gptools_kernels.stan @@ -80,39 +80,27 @@ matrix dist2(array [] vector x, vector period, vector scale) { Evaluate the periodic squared exponential kernel. */ matrix gp_periodic_exp_quad_cov(array [] vector x1, array [] vector x2, real sigma, - vector length_scale, vector period, int nterms) { + vector length_scale, vector period) { int m = size(x1); int n = size(x2); - matrix[m, n] result; - vector[size(length_scale)] time = 2 * (pi() * length_scale ./ period) ^ 2; - vector[size(length_scale)] q = exp(-time); - real scale = sigma * sigma * prod(sqrt(time / pi())); - for (i in 1:m) { - for (j in 1:n) { - result[i, j] = scale * prod(jtheta(evaluate_residuals(x1[i], x2[j], period, period), q, nterms)); - } - } - return result; + return sigma * sigma * exp(-dist2(x1, x2, period, length_scale) / 2); } /** Evaluate the real fast Fourier transform of the periodic squared exponential kernel. */ -vector gp_periodic_exp_quad_cov_rfft(int n, real sigma, real length_scale, real period, - int nterms) { - real time = 2 * (pi() * length_scale / period) ^ 2; - return sigma * sigma * jtheta_rfft(n, exp(-time), nterms) * sqrt(time / pi()); +vector gp_periodic_exp_quad_cov_rfft(int n, real sigma, real length_scale, real period) { + int nrfft = n %/% 2 + 1; + return n * sigma ^ 2 * length_scale / period * sqrt(2 * pi()) + * exp(-2 * (pi() * linspaced_vector(nrfft, 0, nrfft - 1) * length_scale / period) ^ 2); } /** Evaluate the two-dimensional real fast Fourier transform of the periodic squared exponential kernel. */ -matrix gp_periodic_exp_quad_cov_rfft2(int m, int n, real sigma, vector length_scale, vector period, - int nterms) { - vector[m %/% 2 + 1] rfftm = gp_periodic_exp_quad_cov_rfft(m, sigma, length_scale[1], period[1], - nterms); - vector[n %/% 2 + 1] rfftn = gp_periodic_exp_quad_cov_rfft(n, 1, length_scale[2], period[2], - nterms); +matrix gp_periodic_exp_quad_cov_rfft2(int m, int n, real sigma, vector length_scale, vector period) { + vector[m %/% 2 + 1] rfftm = gp_periodic_exp_quad_cov_rfft(m, sigma, length_scale[1], period[1]); + vector[n %/% 2 + 1] rfftn = gp_periodic_exp_quad_cov_rfft(n, 1, length_scale[2], period[2]); return get_real(expand_rfft(rfftm, m)) * rfftn'; } diff --git a/gptools-stan/gptools/stan/gptools_util.stan b/gptools-stan/gptools/stan/gptools_util.stan index 08b36df..c4ca807 100644 --- a/gptools-stan/gptools/stan/gptools_util.stan +++ b/gptools-stan/gptools/stan/gptools_util.stan @@ -425,23 +425,3 @@ real std_normal_lpdf(complex_vector z) { real std_normal_lpdf(matrix z) { return std_normal_lpdf(to_vector(z)); } - -// Special functions ------------------------------------------------------------------------------- - -vector jtheta(vector z, vector q, int nterms) { - vector[size(z)] result = zeros_vector(size(z)); - for (n in 1:nterms) { - result += q ^ (n ^ 2) .* cos(2 * pi() * z * n); - } - return 1 + 2 * result; -} - -vector jtheta_rfft(int nz, real q, int nterms) { - int nrfft = nz %/% 2 + 1; - vector[nrfft] result = zeros_vector(nrfft); - vector[nrfft] k = linspaced_vector(nrfft, 0, nrfft - 1); - for (n in 0:nterms) { - result += q ^ ((k + n * nz) ^ 2) + q ^ ((nz - k + n * nz) ^ 2); - } - return nz * result; -} diff --git a/gptools-stan/gptools/stan/profile/fourier_centered.stan b/gptools-stan/gptools/stan/profile/fourier_centered.stan index d80dfcc..34607f3 100644 --- a/gptools-stan/gptools/stan/profile/fourier_centered.stan +++ b/gptools-stan/gptools/stan/profile/fourier_centered.stan @@ -13,7 +13,7 @@ parameters { } model { - vector[n %/% 2 + 1] cov_rfft = gp_periodic_exp_quad_cov_rfft(n, sigma, length_scale, n, 10); + vector[n %/% 2 + 1] cov_rfft = gp_periodic_exp_quad_cov_rfft(n, sigma, length_scale, n); eta ~ gp_rfft(zeros_vector(n), cov_rfft); y[observed_idx] ~ normal(eta[observed_idx], noise_scale); } diff --git a/gptools-stan/gptools/stan/profile/fourier_non_centered.stan b/gptools-stan/gptools/stan/profile/fourier_non_centered.stan index 524696e..d83f184 100644 --- a/gptools-stan/gptools/stan/profile/fourier_non_centered.stan +++ b/gptools-stan/gptools/stan/profile/fourier_non_centered.stan @@ -15,7 +15,7 @@ parameters { transformed parameters { vector[n] eta; { - vector[n %/% 2 + 1] cov_rfft = gp_periodic_exp_quad_cov_rfft(n, sigma, length_scale, n, 10); + vector[n %/% 2 + 1] cov_rfft = gp_periodic_exp_quad_cov_rfft(n, sigma, length_scale, n); eta = gp_transform_inv_rfft(eta_, zeros_vector(n), cov_rfft); } } diff --git a/gptools-stan/tests/test_stan_functions.py b/gptools-stan/tests/test_stan_functions.py index 87501e4..f7b463c 100644 --- a/gptools-stan/tests/test_stan_functions.py +++ b/gptools-stan/tests/test_stan_functions.py @@ -359,26 +359,6 @@ def assert_stan_python_allclose( "desired": np.zeros((m, n)), }) -for n in [5, 8]: - z = np.linspace(0, 1, n, endpoint=False) - q = np.random.uniform(0.25, 0.75, n) - add_configuration({ - "stan_function": "jtheta", - "arg_types": {"n_": "int", "z": "vector[n_]", "q": "vector[n_]", "nterms": "int"}, - "arg_values": {"n_": n, "z": z, "q": q, "nterms": 10}, - "result_type": "vector[n_]", - "includes": ["gptools_util.stan"], - "desired": kernels.jtheta(z, q), - }) - add_configuration({ - "stan_function": "jtheta_rfft", - "arg_types": {"n": "int", "q": "real", "nterms": "int"}, - "arg_values": {"n": n, "q": q[0], "nterms": 10}, - "result_type": "vector[n %/% 2 + 1]", - "includes": ["gptools_util.stan"], - "desired": kernels.jtheta_rfft(n, q[0]), - }) - for ndim in [1, 2, 3]: n = 1 + np.random.poisson(50) m = 1 + np.random.poisson(50) @@ -392,9 +372,9 @@ def assert_stan_python_allclose( "stan_function": "gp_periodic_exp_quad_cov", "arg_types": {"n_": "int", "m_": "int", "p_": "int", "x": "array [n_] vector[p_]", "y": "array [m_] vector[p_]", "sigma": "real", "length_scale": "vector[p_]", - "period": "vector[p_]", "nterms": "int"}, + "period": "vector[p_]"}, "arg_values": {"n_": n, "m_": m, "p_": ndim, "x": x, "y": y, "sigma": sigma, - "length_scale": length_scale, "period": period, "nterms": 100}, + "length_scale": length_scale, "period": period}, "result_type": "matrix[n_, m_]", "includes": ["gptools_util.stan", "gptools_kernels.stan"], "desired": kernel.evaluate(x[:, None], y[None]), @@ -428,10 +408,8 @@ def assert_stan_python_allclose( period = np.random.gamma(100, 0.1) add_configuration({ "stan_function": "gp_periodic_exp_quad_cov_rfft", - "arg_types": {"m": "int", "sigma": "real", "length_scale": "real", "period": "real", - "nterms": "int"}, - "arg_values": {"m": n, "sigma": sigma, "length_scale": length_scale, "period": period, - "nterms": 100}, + "arg_types": {"m": "int", "sigma": "real", "length_scale": "real", "period": "real"}, + "arg_values": {"m": n, "sigma": sigma, "length_scale": length_scale, "period": period}, "result_type": "vector[m %/% 2 + 1]", "includes": ["gptools_util.stan", "gptools_kernels.stan"], "desired": kernels.ExpQuadKernel(sigma, length_scale, period=period).evaluate_rfft([n]), @@ -453,9 +431,9 @@ def assert_stan_python_allclose( add_configuration({ "stan_function": "gp_periodic_exp_quad_cov_rfft2", "arg_types": {"m": "int", "n": "int", "sigma": "real", "length_scale": "vector[2]", - "period": "vector[2]", "nterms": "int"}, + "period": "vector[2]"}, "arg_values": {"m": m, "n": n, "sigma": sigma, "length_scale": length_scale, - "period": period, "nterms": 100}, + "period": period}, "result_type": "matrix[m, n %/% 2 + 1]", "includes": ["gptools_util.stan", "gptools_kernels.stan"], "desired": kernels.ExpQuadKernel(sigma, length_scale, period=period) diff --git a/gptools-util/gptools/util/kernels.py b/gptools-util/gptools/util/kernels.py index 08f1880..89c2616 100644 --- a/gptools-util/gptools/util/kernels.py +++ b/gptools-util/gptools/util/kernels.py @@ -1,8 +1,7 @@ import math -import numbers import numpy as np import operator -from typing import Callable, Optional +from typing import Callable from . import ArrayOrTensor, ArrayOrTensorDispatch, coordgrid, OptionalArrayOrTensor from .fft import expand_rfft @@ -10,67 +9,6 @@ dispatch = ArrayOrTensorDispatch() -def _jtheta_num_terms(q: ArrayOrTensor, rtol: float = 1e-9) -> int: - return math.ceil(math.log(rtol) / math.log(dispatch.max(q))) - - -def jtheta(z: ArrayOrTensor, q: ArrayOrTensor, nterms: Optional[int] = None, - max_batch_size: int = 1e6) -> ArrayOrTensor: - r""" - Evaluate the Jacobi theta function using a series approximation. - - .. math:: - - \vartheta_3\left(q,z\right) = 1 + 2 \sum_{n=1}^\infty q^{n^2} \cos\left(2\pi n z\right) - - Args: - z: Argument of the theta function. - q: Nome of the theta function with modulus less than one. - nterms: Number of terms in the series approximation (defaults to achieve a relative - tolerance of :math:`10^{-9}`, 197 terms for `q = 0.9`). - max_batch_size: Maximum number of terms per batch. - """ - # TODO: fix for torch. - q, z = np.broadcast_arrays(q, z) - nterms = nterms or _jtheta_num_terms(q) - # If the dimensions of q and z are large and the number of terms is large, we can run into - # memory issues here. We batch the evaluation if necessary to overcome this issue. The maximum - # number of terms should be no more than 10 ^ 6 elements (about 8MB at 64-bit precision). - batch_size = int(max(1, max_batch_size / (q.size * nterms))) - series = 0.0 - for offset in range(0, nterms, batch_size): - size = min(batch_size, nterms - offset) - n = 1 + offset + dispatch[z].arange(size) - series = series \ - + (q[..., None] ** (n ** 2) * dispatch.cos(2 * math.pi * z[..., None] * n)).sum(axis=-1) - - return 1 + 2 * series - - -def jtheta_rfft(nz: int, q: ArrayOrTensor, nterms: Optional[int] = None) -> ArrayOrTensor: - """ - Evaluate the real fast Fourier transform of the Jacobi theta function evaluated on the unit - interval with `nz` grid points. - - The :func:`jtheta` and :func:`jtheta_rfft` functions are related by - - >>> nz = ... - >>> q = ... - >>> z = np.linspace(0, 1, nz, endpoint=False) - >>> np.fft.rfft(jtheta(z, q)) == jtheta_rfft(nz, q) - - Args: - nz: Number of grid points. - q: Nome of the theta function with modulus less than one. - nterms: Number of terms in the series approximation (defaults to achieve a relative - tolerance of :math:`10^{-9}`, 197 terms for `q = 0.9`). - """ - nterms = nterms or _jtheta_num_terms(q) - k = np.arange(nz // 2 + 1) - ns = nz * np.arange(nterms)[:, None] - return nz * ((q ** ((k + ns) ** 2)).sum(axis=0) + (q ** ((nz - k + ns) ** 2)).sum(axis=0)) - - def evaluate_residuals(x: ArrayOrTensor, y: OptionalArrayOrTensor = None, period: OptionalArrayOrTensor = None) -> ArrayOrTensor: """ @@ -290,45 +228,28 @@ class ExpQuadKernel(Kernel): sigma: Scale of the covariance. length_scale: Correlation length. period: Period for circular boundary conditions. - num_terms: Number of terms in the series approximation of the heat equation solution. """ - def __init__(self, sigma: float, length_scale: float, period: OptionalArrayOrTensor = None, - num_terms: Optional[int] = None) -> None: + def __init__(self, sigma: float, length_scale: float, period: OptionalArrayOrTensor = None) \ + -> None: super().__init__(period) self.sigma = sigma self.length_scale = length_scale - if self.is_periodic: - # Evaluate the effective relaxation time of the heat kernel. - self.time = 2 * (math.pi * self.length_scale / self.period) ** 2 - if num_terms is None: - num_terms = _jtheta_num_terms(dispatch.exp(-self.time).max()) - if not isinstance(num_terms, numbers.Number): - num_terms = max(num_terms) - self.num_terms = int(num_terms) - else: - self.time = self.num_terms = None def evaluate(self, x: ArrayOrTensor, y: OptionalArrayOrTensor = None) -> ArrayOrTensor: - if self.is_periodic: - # The residuals will have shape `(..., num_dims)`. - residuals = evaluate_residuals(x, y, self.period) / self.period - value = jtheta(residuals, dispatch.exp(-self.time), self.num_terms) \ - * (self.time / math.pi) ** 0.5 - cov = self.sigma ** 2 * value.prod(axis=-1) - return cov - else: - residuals = evaluate_residuals(x, y) / self.length_scale - exponent = - dispatch.square(residuals).sum(axis=-1) / 2 - return self.sigma * self.sigma * dispatch.exp(exponent) + residuals = evaluate_residuals(x, y, self.period) / self.length_scale + exponent = - dispatch.square(residuals).sum(axis=-1) / 2 + return self.sigma * self.sigma * dispatch.exp(exponent) def evaluate_rfft(self, shape: tuple[int]) -> ArrayOrTensor: if not self.is_periodic: - raise ValueError("kernel must be periodic") + raise NotImplementedError ndim = len(shape) - time = self.time * np.ones(ndim) + rescaled_length_scale = self.length_scale * np.ones(ndim) / self.period value = None for i, size in enumerate(shape): - part = jtheta_rfft(size, np.exp(-time[i])) * (time[i] / math.pi) ** 0.5 + xi = np.arange(size // 2 + 1) + part = size * rescaled_length_scale[i] * (2 * math.pi) ** 0.5 \ + * np.exp(-2 * (math.pi * xi * rescaled_length_scale[i]) ** 2) if i != ndim - 1: part = expand_rfft(part, size) if value is None: @@ -359,22 +280,19 @@ def __init__(self, dof: float, sigma: float, length_scale: float, self.dof = dof def evaluate(self, x: ArrayOrTensor, y: OptionalArrayOrTensor = None) -> ArrayOrTensor: - if self.is_periodic: - raise NotImplementedError + residuals = evaluate_residuals(x, y, self.period) / self.length_scale + distance = (2 * self.dof * residuals * residuals).sum(axis=-1) ** 0.5 + if self.dof == 3 / 2: + value = 1 + distance + elif self.dof == 5 / 2: + value = 1 + distance + distance * distance / 3 else: - residuals = evaluate_residuals(x, y) / self.length_scale - distance = (2 * self.dof * residuals * residuals).sum(axis=-1) ** 0.5 - if self.dof == 3 / 2: - value = 1 + distance - elif self.dof == 5 / 2: - value = 1 + distance + distance * distance / 3 - else: - raise NotImplementedError - return self.sigma * self.sigma * value * dispatch.exp(-distance) + raise NotImplementedError + return self.sigma * self.sigma * value * dispatch.exp(-distance) def evaluate_rfft(self, shape: tuple[int]): if not self.is_periodic: - raise ValueError("kernel must be periodic") + raise NotImplementedError from scipy import special # Construct the grid to evaluate on. diff --git a/gptools-util/setup.py b/gptools-util/setup.py index c142716..a8e1f14 100644 --- a/gptools-util/setup.py +++ b/gptools-util/setup.py @@ -18,7 +18,6 @@ "flake8", "jupyter", "matplotlib", - "mpmath", "networkx", "pytest", "pytest-cov", diff --git a/gptools-util/test_requirements.txt b/gptools-util/test_requirements.txt index b9db652..1b622c8 100644 --- a/gptools-util/test_requirements.txt +++ b/gptools-util/test_requirements.txt @@ -169,8 +169,6 @@ mistune==2.0.4 # via nbconvert more-itertools==9.0.0 # via jaraco-classes -mpmath==1.2.1 - # via gp-tools-util nbclassic==0.4.5 # via notebook nbclient==0.5.13 diff --git a/gptools-util/tests/test_fft.py b/gptools-util/tests/test_fft.py index 1bd51dd..5667ded 100644 --- a/gptools-util/tests/test_fft.py +++ b/gptools-util/tests/test_fft.py @@ -62,11 +62,15 @@ def test_transform_rfft_roundtrip(batch_shape: tuple[int], rfft_num: int, use_to np.testing.assert_allclose(z, x) -def test_evaluate_log_prob_rfft2(batch_shape: tuple[int], rfft2_shape: int, use_torch: bool) \ - -> None: - xs = coordgrid(*(np.linspace(0, 1, size, endpoint=False) for size in rfft2_shape)) - kernel = kernels.ExpQuadKernel(np.random.gamma(10, 0.1), np.random.gamma(10, 0.01), 1) \ - + kernels.DiagonalKernel(1e-2, 1) +@pytest.mark.parametrize("kernel", [ + kernels.ExpQuadKernel(np.random.gamma(10, 0.1), np.random.gamma(10, 0.01), 1.2), + kernels.MaternKernel(1.5, np.random.gamma(10, 0.1), np.random.gamma(10, 0.01), 1.3), + kernels.MaternKernel(2.5, np.random.gamma(10, 0.1), np.random.gamma(10, 0.01), 1.4), +]) +def test_evaluate_log_prob_rfft2(kernel: kernels.Kernel, batch_shape: tuple[int], rfft2_shape: int, + use_torch: bool) -> None: + xs = coordgrid(*(np.linspace(0, kernel.period, size, endpoint=False) for size in rfft2_shape)) + kernel = kernel + kernels.DiagonalKernel(1e-2, kernel.period) cov = kernel.evaluate(xs) loc = np.random.normal(0, 1, xs.shape[0]) dist = stats.multivariate_normal(loc, cov) diff --git a/gptools-util/tests/test_kernels.py b/gptools-util/tests/test_kernels.py index 71d642a..bdd7eb5 100644 --- a/gptools-util/tests/test_kernels.py +++ b/gptools-util/tests/test_kernels.py @@ -1,38 +1,12 @@ from gptools.util import coordgrid, kernels from gptools.util.testing import KernelConfiguration import itertools as it -import mpmath import numpy as np import pytest from scipy.spatial.distance import cdist import torch as th -@pytest.mark.parametrize("q", [0.1, 0.8]) -def test_jtheta(q: float) -> None: - z = np.linspace(0, 1, 7, endpoint=False) - actual = kernels.jtheta(z, q, max_batch_size=3) - desired = np.vectorize(mpmath.jtheta)(3, np.pi * z, q).astype(float) - np.testing.assert_allclose(actual, desired) - - -def test_jtheta_batching() -> None: - z = np.linspace(0, 1, 7, endpoint=False) - result = kernels.jtheta(0.5, z, nterms=13) - for max_batch_size in [7, 13, 14, 21]: - np.testing.assert_allclose(result, kernels.jtheta(0.5, z, nterms=13, - max_batch_size=max_batch_size)) - - -@pytest.mark.parametrize("nz", [5, 6]) -@pytest.mark.parametrize("q", [0.1, 0.8]) -def test_jtheta_rfft(nz: int, q: float) -> None: - jtheta = kernels.jtheta(np.linspace(0, 1, nz, endpoint=False), q) - actual = kernels.jtheta_rfft(nz, q) - desired = np.fft.rfft(jtheta) - np.testing.assert_allclose(actual, desired) - - @pytest.mark.parametrize("shape", [(7,), (2, 3)]) def test_kernel(kernel_configuration: KernelConfiguration, shape: tuple, use_torch: bool) -> None: kernel = kernel_configuration() @@ -55,7 +29,7 @@ def test_periodic(kernel_configuration: KernelConfiguration): kernel = kernel_configuration() if not kernel.is_periodic: # Ensure the rfft cannot be evaluated and skip the rest. - with pytest.raises((NotImplementedError, ValueError)): + with pytest.raises(NotImplementedError): kernel.evaluate_rfft(tuple(range(13, 13 + len(kernel_configuration.dims)))) return @@ -85,8 +59,11 @@ def test_periodic(kernel_configuration: KernelConfiguration): fftcov = np.fft.rfft(cov) if dim == 1 else np.fft.rfft2(cov) np.testing.assert_allclose(fftcov.imag, 0, atol=1e-9) - # Ensure the numeric rfft matches the manual evaluation. - np.testing.assert_allclose(fftcov.real, kernel.evaluate_rfft(shape), atol=1e-9) + # We may want to check that the numerical and theoretic FFT match, but this requires a more + # "proper" implementation of the periodic kernels involving infinite sums (see + # https://github.com/tillahoffmann/gp-tools/issues/59 for details). For now, let's verify we + # have a positive-definite kernel. + np.testing.assert_array_less(-1e-12, kernel.evaluate_rfft(shape)) def test_kernel_composition(): @@ -134,15 +111,6 @@ def test_periodic_exp_quad_rfft(shape: int) -> None: raise ValueError assert rfft.shape == (*head, tail // 2 + 1,) np.testing.assert_allclose(rfft.imag, 0, atol=1e-9) - rfft = rfft.real - predicted = kernel.evaluate_rfft(shape) - np.testing.assert_allclose(rfft, predicted, atol=1e-9) - - -@pytest.mark.parametrize("num_terms", [None, 7, np.arange(4)]) -def test_periodic_exp_quad_kernel_num_terms(num_terms) -> None: - kernel = kernels.ExpQuadKernel(1, .5, 1, num_terms) - assert kernel.num_terms >= 1 def test_matern_invalid_dof() -> None: