Skip to content

Commit

Permalink
Add comment regarding stochastic trace estimation tests
Browse files Browse the repository at this point in the history
  • Loading branch information
davidnabergoj committed Oct 19, 2023
1 parent 64ddd42 commit 8620d19
Showing 1 changed file with 5 additions and 2 deletions.
7 changes: 5 additions & 2 deletions test/test_stochastic_log_det_estimation.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,13 +25,16 @@ def log_det_jac_f(self, inputs):

@pytest.mark.parametrize('n_hutchinson_samples', [*list(range(25, 40))])
@pytest.mark.parametrize('n_iterations', [4, 10, 25, 100])
def test_hutchinson(n_iterations, n_hutchinson_samples):
def test_hutchinson_power_series(n_iterations, n_hutchinson_samples):
# This test checks for validity of the hutchinson power series trace estimator.
# The estimator computes log|det(Jac_f)| where f(x) = x + g(x) and x is Lipschitz continuous with Lip(g) < 1.
# In this example: a Lipschitz continuous function with constant < 1 is g(x) = 1/2 * x; Lip(g) = 1/2.

# The reference jacobian of f is I * 1.5, because d/dx f(x) = d/dx x + g(x) = d/dx x + 1/2 * x = 1 + 1/2 = 1.5

# TODO: use the analytical variance of the Monte Carlo Hutchinson trace estimator to compute the variance of the
# Hutchinson power series estimator. Then make sure that the power series error is below 4 * variance.

n_data = 1
n_dim = 1

Expand All @@ -56,7 +59,7 @@ def test_hutchinson(n_iterations, n_hutchinson_samples):


@pytest.mark.parametrize('p', [0.01, 0.1, 0.5, 0.9, 0.99])
def test_roulette(p):
def test_roulette_power_series(p):
# an example of a Lipschitz continuous function with constant < 1: g(x) = 1/2 * x

n_data = 100
Expand Down

0 comments on commit 8620d19

Please sign in to comment.