From 5365df36300bf6661015ec88e9722f1e3a8dd2dd Mon Sep 17 00:00:00 2001 From: Till Hoffmann Date: Fri, 16 Dec 2022 15:39:35 -0500 Subject: [PATCH] Run more general tests for `evaluate_log_prob_rfft2`. --- gptools-util/tests/test_fft.py | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/gptools-util/tests/test_fft.py b/gptools-util/tests/test_fft.py index 6b5e987..5667ded 100644 --- a/gptools-util/tests/test_fft.py +++ b/gptools-util/tests/test_fft.py @@ -62,13 +62,16 @@ def test_transform_rfft_roundtrip(batch_shape: tuple[int], rfft_num: int, use_to np.testing.assert_allclose(z, x) -def test_evaluate_log_prob_rfft2(batch_shape: tuple[int], rfft2_shape: int, use_torch: bool) \ - -> None: - xs = coordgrid(*(np.linspace(0, 1, size, endpoint=False) for size in rfft2_shape)) - kernel = kernels.ExpQuadKernel(np.random.gamma(10, 0.1), np.random.gamma(10, 0.01), 1) \ - + kernels.DiagonalKernel(1e-2, 1) +@pytest.mark.parametrize("kernel", [ + kernels.ExpQuadKernel(np.random.gamma(10, 0.1), np.random.gamma(10, 0.01), 1.2), + kernels.MaternKernel(1.5, np.random.gamma(10, 0.1), np.random.gamma(10, 0.01), 1.3), + kernels.MaternKernel(2.5, np.random.gamma(10, 0.1), np.random.gamma(10, 0.01), 1.4), +]) +def test_evaluate_log_prob_rfft2(kernel: kernels.Kernel, batch_shape: tuple[int], rfft2_shape: int, + use_torch: bool) -> None: + xs = coordgrid(*(np.linspace(0, kernel.period, size, endpoint=False) for size in rfft2_shape)) + kernel = kernel + kernels.DiagonalKernel(1e-2, kernel.period) cov = kernel.evaluate(xs) - np.fill_diagonal(cov, np.diag(cov) + 1e-6) loc = np.random.normal(0, 1, xs.shape[0]) dist = stats.multivariate_normal(loc, cov) y = dist.rvs(batch_shape)