Skip to content

Commit

Permalink
Update nugget variance for rfft2 tests.
Browse files Browse the repository at this point in the history
  • Loading branch information
tillahoffmann committed Dec 16, 2022
1 parent f3aa65a commit eaa50d0
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions gptools-util/tests/test_fft.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ def test_evaluate_log_prob_rfft2(batch_shape: tuple[int], rfft2_shape: int, use_
kernel = kernels.ExpQuadKernel(np.random.gamma(10, 0.1), np.random.gamma(10, 0.01), 1) \
+ kernels.DiagonalKernel(1e-2, 1)
cov = kernel.evaluate(xs)
np.fill_diagonal(cov, np.diag(cov) + 1e-9)
np.fill_diagonal(cov, np.diag(cov) + 1e-6)
loc = np.random.normal(0, 1, xs.shape[0])
dist = stats.multivariate_normal(loc, cov)
y = dist.rvs(batch_shape)
Expand All @@ -80,7 +80,7 @@ def test_evaluate_log_prob_rfft2(batch_shape: tuple[int], rfft2_shape: int, use_
y2 = th.as_tensor(y2)
loc2 = th.as_tensor(loc2)
cov2 = th.as_tensor(cov2)
log_prob_rfft2 = fft.evaluate_log_prob_rfft2(y2, loc2, cov=cov2) + 1e-9
log_prob_rfft2 = fft.evaluate_log_prob_rfft2(y2, loc2, cov=cov2)
np.testing.assert_allclose(log_prob, log_prob_rfft2)


Expand Down

0 comments on commit eaa50d0

Please sign in to comment.