Skip to content

Commit

Permalink
Run more general tests for evaluate_log_prob_rfft2.
Browse files Browse the repository at this point in the history
  • Loading branch information
tillahoffmann committed Dec 16, 2022
1 parent eaa50d0 commit d472c00
Showing 1 changed file with 9 additions and 6 deletions.
15 changes: 9 additions & 6 deletions gptools-util/tests/test_fft.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,13 +62,16 @@ def test_transform_rfft_roundtrip(batch_shape: tuple[int], rfft_num: int, use_to
np.testing.assert_allclose(z, x)


def test_evaluate_log_prob_rfft2(batch_shape: tuple[int], rfft2_shape: int, use_torch: bool) \
-> None:
xs = coordgrid(*(np.linspace(0, 1, size, endpoint=False) for size in rfft2_shape))
kernel = kernels.ExpQuadKernel(np.random.gamma(10, 0.1), np.random.gamma(10, 0.01), 1) \
+ kernels.DiagonalKernel(1e-2, 1)
@pytest.mark.parametrize("kernel", [
kernels.ExpQuadKernel(np.random.gamma(10, 0.1), np.random.gamma(10, 0.01), 1.2),
kernels.MaternKernel(1.5, np.random.gamma(10, 0.1), np.random.gamma(10, 0.01), 1.3),
kernels.MaternKernel(2.5, np.random.gamma(10, 0.1), np.random.gamma(10, 0.01), 1.4),
])
def test_evaluate_log_prob_rfft2(kernel: kernels.Kernel, batch_shape: tuple[int], rfft2_shape: int,
use_torch: bool) -> None:
xs = coordgrid(*(np.linspace(0, kernel.period, size, endpoint=False) for size in rfft2_shape))
kernel = kernel + kernels.DiagonalKernel(1e-2, kernel.period)
cov = kernel.evaluate(xs)
np.fill_diagonal(cov, np.diag(cov) + 1e-6)
loc = np.random.normal(0, 1, xs.shape[0])
dist = stats.multivariate_normal(loc, cov)
y = dist.rvs(batch_shape)
Expand Down

0 comments on commit d472c00

Please sign in to comment.