From 4a042dea961fc70e9476c5818864841352730fab Mon Sep 17 00:00:00 2001
From: David Nabergoj
Date: Thu, 19 Oct 2023 19:26:08 +0200
Subject: [PATCH] Add support for >1 sample in Hutchinson power series trace
 estimation

---
 .../finite/residual/log_abs_det_estimators.py | 46 +++++++++---
 test/test_stochastic_log_det_estimation.py    | 70 ++++++++++---------
 2 files changed, 76 insertions(+), 40 deletions(-)

diff --git a/normalizing_flows/bijections/finite/residual/log_abs_det_estimators.py b/normalizing_flows/bijections/finite/residual/log_abs_det_estimators.py
index 835ed15..5be083e 100644
--- a/normalizing_flows/bijections/finite/residual/log_abs_det_estimators.py
+++ b/normalizing_flows/bijections/finite/residual/log_abs_det_estimators.py
@@ -4,18 +4,47 @@
 from normalizing_flows.utils import Geometric, vjp_tensor
 
 
-def hutchinson_log_abs_det_estimator(g: callable, x: torch.Tensor, noise: torch.Tensor, training: bool,
+def hutchinson_log_abs_det_estimator(g: callable,
+                                     x: torch.Tensor,
+                                     noise: torch.Tensor,
+                                     training: bool,
                                      n_iterations: int = 8):
     # f(x) = x + g(x)
-    w = noise
-    log_abs_det_jac_f = 0.0
+    # x.shape == (batch_size, event_size)
+    # noise.shape == (batch_size, event_size, n_hutchinson_samples)
+    # g(x).shape == (batch_size, event_size)
+
+    assert len(noise.shape) == 3
+    batch_size, event_size, n_hutchinson_samples = noise.shape
+    assert len(x.shape) == 2
+    assert x.shape == (batch_size, event_size)
+    assert n_iterations >= 2
+
+    w = noise  # (batch_size, event_size, n_hutchinson_samples)
+    log_abs_det_jac_f = torch.zeros(size=(batch_size, n_hutchinson_samples), dtype=x.dtype, device=x.device)  # averaged at the end
+    g_value = None
     for k in range(1, n_iterations + 1):
-        g_value, w = torch.autograd.functional.vjp(g, x, w)
-        log_abs_det_jac_f += (-1) ** (k + 1) / k * torch.sum(w * noise, dim=-1)
+        # Compute one VJP per Hutchinson sample
+        tmp = [torch.autograd.functional.vjp(g, x, w[..., i], strict=True) for i in range(n_hutchinson_samples)]
+        gs, ws = zip(*tmp)
+
+        if g_value is None:
+            # Every VJP call also returns g(x); keep it from the first sample only
+            g_value = gs[0]
+
+        # Stack the per-sample VJPs back into (batch_size, event_size, n_hutchinson_samples)
+        w = torch.stack(ws, dim=2)
+        assert w.shape == (batch_size, event_size, n_hutchinson_samples)
+        log_abs_det_jac_f += (-1) ** (k + 1) / k * torch.sum(w * noise, dim=1)  # sum over the event dimension
+    assert log_abs_det_jac_f.shape == (batch_size, n_hutchinson_samples)
+    log_abs_det_jac_f = torch.mean(log_abs_det_jac_f, dim=1)  # Hutchinson averaging over the independent series
     return g_value, log_abs_det_jac_f
 
 
-def neumann_log_abs_det_estimator(g: callable, x: torch.Tensor, noise: torch.Tensor, training: bool,
+def neumann_log_abs_det_estimator(g: callable,
+                                  x: torch.Tensor,
+                                  noise: torch.Tensor,
+                                  training: bool,
                                   p: float = 0.5):
     """
     Estimate log[abs(det(grad(f)))](x) with a roulette approach, where f(x) = x + g(x); Lip(g) < 1.
@@ -124,8 +153,9 @@ def log_det_roulette(g: nn.Module, x: torch.Tensor, training: bool = False, p: f
     )
 
 
-def log_det_hutchinson(g: nn.Module, x: torch.Tensor, training: bool = False, n_iterations: int = 8):
-    noise = torch.randn_like(x)
+def log_det_hutchinson(g: nn.Module, x: torch.Tensor, training: bool = False, n_iterations: int = 8,
+                       n_hutchinson_samples: int = 1):
+    noise = torch.randn(size=(*x.shape, n_hutchinson_samples), dtype=x.dtype, device=x.device)
     return LogDeterminantEstimator.apply(
         lambda *args, **kwargs: hutchinson_log_abs_det_estimator(*args, **kwargs, n_iterations=n_iterations),
         g,
diff --git a/test/test_stochastic_log_det_estimation.py b/test/test_stochastic_log_det_estimation.py
index 124cbc2..08a0ff6 100644
--- a/test/test_stochastic_log_det_estimation.py
+++ b/test/test_stochastic_log_det_estimation.py
@@ -5,37 +5,54 @@
 from normalizing_flows.bijections.finite.residual.log_abs_det_estimators import log_det_hutchinson, log_det_roulette
 
 
-@pytest.mark.parametrize('n_iterations', [4, 10, 25, 100])
-def test_hutchinson(n_iterations):
-    # an example of a Lipschitz continuous function with constant < 1: g(x) = 1/2 * x
-
-    n_data = 100
-    n_dim = 30
+class LipschitzTestData:
+    def __init__(self, n_dim):
+        self.n_dim = n_dim
 
-    class TestFunction(nn.Module):
+    class LipschitzFunction(nn.Module):
         def __init__(self, *args, **kwargs):
             super().__init__(*args, **kwargs)
 
         def forward(self, inputs):
             return 0.5 * inputs
 
-    def jac_f(inputs):
-        return torch.eye(n_dim) * 1.5
+    def jac_f(self, _):
+        return torch.eye(self.n_dim) * (1 + 0.5)
 
-    def log_det_jac_f(inputs):
-        return torch.log(torch.abs(torch.det(jac_f(inputs))))
+    def log_det_jac_f(self, inputs):
+        return torch.log(torch.abs(torch.det(self.jac_f(inputs))))
 
-    g = TestFunction()
+
+@pytest.mark.parametrize('n_hutchinson_samples', list(range(25, 40)))
+@pytest.mark.parametrize('n_iterations', [4, 10, 25, 100])
+def test_hutchinson(n_iterations, n_hutchinson_samples):
+    # This test checks the validity of the Hutchinson power series trace estimator.
+    # The estimator computes log|det(Jac_f)|, where f(x) = x + g(x) and g is Lipschitz continuous with Lip(g) < 1.
+    # In this example, a Lipschitz continuous function with constant < 1 is g(x) = 1/2 * x; Lip(g) = 1/2.
+
+    # The reference Jacobian of f is I * 1.5, because d/dx f(x) = d/dx (x + g(x)) = 1 + 1/2 = 1.5.
+
+    n_data = 1
+    n_dim = 1
+
+    test_data = LipschitzTestData(n_dim)
+    g = test_data.LipschitzFunction()
 
     torch.manual_seed(0)
     x = torch.randn(size=(n_data, n_dim))
-    g_value, log_det_f = log_det_hutchinson(g, x, training=False, n_iterations=n_iterations)
-    log_det_f_true = log_det_jac_f(x).ravel()
-
-    print(f'{log_det_f = }')
+    g_value, log_det_f_estimated = log_det_hutchinson(
+        g,
+        x,
+        training=False,
+        n_iterations=n_iterations,
+        n_hutchinson_samples=n_hutchinson_samples
+    )
+    log_det_f_true = test_data.log_det_jac_f(x).ravel()
+
+    print()
+    print(f'{log_det_f_estimated = }')
     print(f'{log_det_f_true = }')
-    print(f'{log_det_f.mean() = }')
-    assert torch.allclose(log_det_f, log_det_f_true)
+    assert torch.allclose(log_det_f_estimated, log_det_f_true)
 
 
 @pytest.mark.parametrize('p', [0.01, 0.1, 0.5, 0.9, 0.99])
@@ -45,25 +62,14 @@ def test_roulette(p):
     n_data = 100
     n_dim = 30
 
-    class TestFunction(nn.Module):
-        def __init__(self, *args, **kwargs):
-            super().__init__(*args, **kwargs)
-
-        def forward(self, inputs):
-            return 0.5 * inputs
-
-    def jac_f(inputs):
-        return torch.eye(n_dim) * 1.5
-
-    def log_det_jac_f(inputs):
-        return torch.log(torch.abs(torch.det(jac_f(inputs))))
+    test_data = LipschitzTestData(n_dim)
 
-    g = TestFunction()
+    g = test_data.LipschitzFunction()
 
     torch.manual_seed(0)
     x = torch.randn(size=(n_data, n_dim))
     g_value, log_det_f = log_det_roulette(g, x, training=False, p=p)
-    log_det_f_true = log_det_jac_f(x).ravel()
+    log_det_f_true = test_data.log_det_jac_f(x).ravel()
 
     print(f'{log_det_f = }')
     print(f'{log_det_f_true = }')
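
Note (not part of the patch): the estimator above combines two known facts. First, for f(x) = x + g(x) with Lip(g) < 1, the power series log|det(Jac_f)| = tr(log(I + Jac_g)) = sum_{k>=1} (-1)^(k+1) / k * tr(Jac_g^k) converges. Second, each trace is approximated with Hutchinson's estimator, tr(A) ~= mean_i of v_i^T A v_i over random probe vectors v_i; the patch averages the series over n_hutchinson_samples such probes. The standalone sketch below (the function name series_log_det is illustrative, not part of the repository) checks the truncated, Hutchinson-averaged series against torch.logdet for a fixed linear map with Jacobian Jac_g = 0.5 * I:

import torch

def series_log_det(jac_g: torch.Tensor, n_iterations: int, n_samples: int) -> torch.Tensor:
    # Hutchinson estimate of log|det(I + J_g)| via the truncated power series.
    event_size = jac_g.shape[0]
    noise = torch.randn(event_size, n_samples)  # one Gaussian probe vector per column
    w = noise.clone()
    estimate = torch.zeros(n_samples)
    for k in range(1, n_iterations + 1):
        w = jac_g.T @ w  # after k steps, each column holds (J_g^T)^k v
        estimate += (-1) ** (k + 1) / k * (w * noise).sum(dim=0)  # v^T J_g^k v per probe
    return estimate.mean()  # average over the Hutchinson samples

torch.manual_seed(0)
jac_g = 0.5 * torch.eye(3)  # Jacobian of g(x) = 0.5 * x
approximate = series_log_det(jac_g, n_iterations=100, n_samples=10_000)
exact = torch.logdet(torch.eye(3) + jac_g)  # 3 * log(1.5)
print(approximate, exact)  # should agree to within roughly 1%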
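
Note (usage sketch, not part of the patch): the snippet below exercises the patched entry point. The import path and the signature of log_det_hutchinson are taken from the diff above; the module Contraction is a made-up example, and the (g_value, log_det) return pair is inferred from the test. Averaging over n_hutchinson_samples independent probes reduces the variance of the returned estimate:

import torch
import torch.nn as nn

from normalizing_flows.bijections.finite.residual.log_abs_det_estimators import log_det_hutchinson

class Contraction(nn.Module):
    # g(x) = 0.5 * x satisfies Lip(g) = 0.5 < 1, as the estimator requires.
    def forward(self, inputs):
        return 0.5 * inputs

torch.manual_seed(0)
x = torch.randn(size=(16, 30))  # (batch_size, event_size)
g_value, log_det_f = log_det_hutchinson(
    Contraction(), x, training=False, n_iterations=25, n_hutchinson_samples=32
)
print(log_det_f.shape)   # torch.Size([16]): one estimate per batch element
print(log_det_f.mean())  # close to 30 * log(1.5), i.e. about 12.16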