Add support for >1 sample in hutchinson power series trace estimation

davidnabergoj committed Oct 19, 2023
1 parent 53397ea commit 4a042de
Showing 2 changed files with 76 additions and 40 deletions.
46 changes: 38 additions & 8 deletions normalizing_flows/bijections/finite/residual/log_abs_det_estimators.py

@@ -4,18 +4,47 @@
 from normalizing_flows.utils import Geometric, vjp_tensor
 
 
-def hutchinson_log_abs_det_estimator(g: callable, x: torch.Tensor, noise: torch.Tensor, training: bool,
+def hutchinson_log_abs_det_estimator(g: callable,
+                                     x: torch.Tensor,
+                                     noise: torch.Tensor,
+                                     training: bool,
                                      n_iterations: int = 8):
     # f(x) = x + g(x)
-    w = noise
-    log_abs_det_jac_f = 0.0
+    # x.shape == (batch_size, event_size)
+    # noise.shape == (batch_size, event_size, n_hutchinson_samples)
+    # g(x).shape == (batch_size, event_size)
+
+    assert len(noise.shape) == 3
+    batch_size, event_size, n_hutchinson_samples = noise.shape
+    assert len(x.shape) == 2
+    assert x.shape == (batch_size, event_size)
+    assert n_iterations >= 2
+
+    w = noise  # (batch_size, event_size, n_hutchinson_samples)
+    log_abs_det_jac_f = torch.zeros(size=(batch_size, n_hutchinson_samples))  # this will be averaged at the end
+    g_value = None
     for k in range(1, n_iterations + 1):
-        g_value, w = torch.autograd.functional.vjp(g, x, w)
-        log_abs_det_jac_f += (-1) ** (k + 1) / k * torch.sum(w * noise, dim=-1)
+        # Compute VJP, reshape appropriately for hutchinson averaging
+        tmp = [torch.autograd.functional.vjp(g, x, w[..., i], strict=True) for i in range(n_hutchinson_samples)]
+        gs, ws = zip(*tmp)
+
+        if g_value is None:
+            gs = list(gs)
+            g_value = gs[0]
+
+        ws = list(ws)
+        # (batch_size, event_size, n_hutchinson_samples)
+        w = torch.concatenate([ws[i][..., None] for i in range(n_hutchinson_samples)], dim=2)
+        log_abs_det_jac_f += (-1) ** (k + 1) / k * torch.sum(w * noise, dim=1)  # sum over event dim
+    assert log_abs_det_jac_f.shape == (batch_size, n_hutchinson_samples)
+    log_abs_det_jac_f = torch.mean(log_abs_det_jac_f, dim=1)  # hutchinson averaging over the many different series
     return g_value, log_abs_det_jac_f
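
As background (an illustrative sketch, not code from this repository): the function above uses the identity log|det(I + J_g)| = sum_{k>=1} (-1)^(k+1)/k * tr(J_g^k), valid when Lip(g) < 1, and approximates each trace term with Hutchinson's estimator tr(A) ≈ E[v^T A v]. The sketch below applies the same two ideas to an explicit Jacobian matrix and compares against the exact value; every name in it is made up for illustration.

import torch

def power_series_logdet(jac_g: torch.Tensor, n_iterations: int = 50, n_samples: int = 100) -> torch.Tensor:
    # Truncated power series for log|det(I + jac_g)|, one Hutchinson sample per column of v.
    torch.manual_seed(0)
    event_size = jac_g.shape[0]
    v = torch.randn(event_size, n_samples)
    w = v.clone()
    estimates = torch.zeros(n_samples)
    for k in range(1, n_iterations + 1):
        w = jac_g @ w  # w == jac_g^k @ v, built iteratively
        estimates += (-1) ** (k + 1) / k * torch.sum(w * v, dim=0)  # v^T (J^k) v, per sample
    return estimates.mean()  # average over Hutchinson samples, as in the commit

jac_g = 0.5 * torch.eye(3)  # contraction: spectral norm 0.5 < 1, so the series converges
print(power_series_logdet(jac_g))          # stochastic estimate
print(torch.logdet(torch.eye(3) + jac_g))  # exact: 3 * log(1.5) ≈ 1.2164

More Hutchinson samples reduce the variance of the estimate without changing its expectation, which is exactly what the n_hutchinson_samples dimension added by this commit provides.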


-def neumann_log_abs_det_estimator(g: callable, x: torch.Tensor, noise: torch.Tensor, training: bool,
+def neumann_log_abs_det_estimator(g: callable,
+                                  x: torch.Tensor,
+                                  noise: torch.Tensor,
+                                  training: bool,
                                   p: float = 0.5):
     """
     Estimate log[abs(det(grad(f)))](x) with a roulette approach, where f(x) = x + g(x); Lip(g) < 1.
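
As background for the roulette approach named in the docstring above (an illustrative sketch under stated assumptions, not the repository's implementation): instead of a fixed truncation, the number of series terms is drawn as N ~ Geometric(p), and term k is divided by the survival probability P(N >= k) = (1 - p)^(k - 1), which keeps the randomly truncated estimator unbiased.

import torch

def roulette_logdet(jac_g: torch.Tensor, p: float = 0.5, n_estimates: int = 5000) -> torch.Tensor:
    # Russian-roulette estimate of log|det(I + jac_g)| with Geometric(p) truncation.
    torch.manual_seed(0)
    event_size = jac_g.shape[0]
    estimates = torch.zeros(n_estimates)
    for j in range(n_estimates):
        n_terms = int(torch.distributions.Geometric(probs=p).sample()) + 1  # at least one term
        v = torch.randn(event_size)
        w = v.clone()
        total = 0.0
        for k in range(1, n_terms + 1):
            w = jac_g @ w
            p_survive = (1 - p) ** (k - 1)  # P(N >= k); reweighting keeps the estimate unbiased
            total += (-1) ** (k + 1) / k * (w @ v) / p_survive
        estimates[j] = total
    return estimates.mean()  # noisy; approaches the exact value as n_estimates grows

jac_g = 0.5 * torch.eye(3)
print(roulette_logdet(jac_g), torch.logdet(torch.eye(3) + jac_g))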
@@ -124,8 +153,9 @@ def log_det_roulette(g: nn.Module, x: torch.Tensor, training: bool = False, p: f
     )
 
 
-def log_det_hutchinson(g: nn.Module, x: torch.Tensor, training: bool = False, n_iterations: int = 8):
-    noise = torch.randn_like(x)
+def log_det_hutchinson(g: nn.Module, x: torch.Tensor, training: bool = False, n_iterations: int = 8,
+                       n_hutchinson_samples: int = 1):
+    noise = torch.randn(size=(*x.shape, n_hutchinson_samples))
     return LogDeterminantEstimator.apply(
         lambda *args, **kwargs: hutchinson_log_abs_det_estimator(*args, **kwargs, n_iterations=n_iterations),
         g,
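
A hypothetical call of the extended interface (an untested sketch: it assumes the import path used by the test file below, and the (g_value, log_det) return convention that LogDeterminantEstimator.apply appears to forward from the estimator):

import torch
import torch.nn as nn

from normalizing_flows.bijections.finite.residual.log_abs_det_estimators import log_det_hutchinson

class HalfScale(nn.Module):
    # Illustrative contraction with Lip(g) = 0.5 < 1; not a class from the repository.
    def forward(self, inputs):
        return 0.5 * inputs

x = torch.randn(16, 10)  # (batch_size, event_size)
g_value, log_det = log_det_hutchinson(HalfScale(), x, training=False,
                                      n_iterations=8, n_hutchinson_samples=4)
print(log_det.shape)  # expected: (16,), one estimate per batch element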
70 changes: 38 additions & 32 deletions test/test_stochastic_log_det_estimation.py
@@ -5,37 +5,54 @@
 from normalizing_flows.bijections.finite.residual.log_abs_det_estimators import log_det_hutchinson, log_det_roulette
 
 
-@pytest.mark.parametrize('n_iterations', [4, 10, 25, 100])
-def test_hutchinson(n_iterations):
-    # an example of a Lipschitz continuous function with constant < 1: g(x) = 1/2 * x
-
-    n_data = 100
-    n_dim = 30
-
-    class TestFunction(nn.Module):
-        def __init__(self, *args, **kwargs):
-            super().__init__(*args, **kwargs)
-
-        def forward(self, inputs):
-            return 0.5 * inputs
-
-    def jac_f(inputs):
-        return torch.eye(n_dim) * 1.5
-
-    def log_det_jac_f(inputs):
-        return torch.log(torch.abs(torch.det(jac_f(inputs))))
-
-    g = TestFunction()
+class LipschitzTestData:
+    def __init__(self, n_dim):
+        self.n_dim = n_dim
+
+    class LipschitzFunction(nn.Module):
+        def __init__(self, *args, **kwargs):
+            super().__init__(*args, **kwargs)
+
+        def forward(self, inputs):
+            return 0.5 * inputs
+
+    def jac_f(self, _):
+        return torch.eye(self.n_dim) * (1 + 0.5)
+
+    def log_det_jac_f(self, inputs):
+        return torch.log(torch.abs(torch.det(self.jac_f(inputs))))
+
+
+@pytest.mark.parametrize('n_hutchinson_samples', [*list(range(25, 40))])
+@pytest.mark.parametrize('n_iterations', [4, 10, 25, 100])
+def test_hutchinson(n_iterations, n_hutchinson_samples):
+    # This test checks the validity of the hutchinson power series trace estimator.
+    # The estimator computes log|det(Jac_f)| where f(x) = x + g(x) and g is Lipschitz continuous with Lip(g) < 1.
+    # In this example, g(x) = 1/2 * x is Lipschitz continuous with constant Lip(g) = 1/2 < 1.
+
+    # The reference Jacobian of f is I * 1.5, because d/dx f(x) = d/dx [x + g(x)] = 1 + 1/2 = 1.5.
+
+    n_data = 1
+    n_dim = 1
+
+    test_data = LipschitzTestData(n_dim)
+    g = test_data.LipschitzFunction()
 
     torch.manual_seed(0)
     x = torch.randn(size=(n_data, n_dim))
-    g_value, log_det_f = log_det_hutchinson(g, x, training=False, n_iterations=n_iterations)
-    log_det_f_true = log_det_jac_f(x).ravel()
-
-    print(f'{log_det_f = }')
+    g_value, log_det_f_estimated = log_det_hutchinson(
+        g,
+        x,
+        training=False,
+        n_iterations=n_iterations,
+        n_hutchinson_samples=n_hutchinson_samples
+    )
+    log_det_f_true = test_data.log_det_jac_f(x).ravel()
+
+    print()
+    print(f'{log_det_f_estimated = }')
     print(f'{log_det_f_true = }')
-    print(f'{log_det_f.mean() = }')
-    assert torch.allclose(log_det_f, log_det_f_true)
+    assert torch.allclose(log_det_f_estimated, log_det_f_true)
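
For reference (a worked check, not part of the commit): with g(x) = x/2 and n_dim = 1, the ground truth is log|det(Jac_f)| = log(1.5) ≈ 0.405465, and the truncated series is the Mercator series for log(1 + 1/2):

import math

partial = sum((-1) ** (k + 1) / k * 0.5 ** k for k in range(1, 9))  # n_iterations = 8
print(partial, math.log(1.5))  # ≈ 0.405315 vs 0.405465; truncation error ≈ 1.5e-4

Any remaining gap in the test comes from this truncation plus the Hutchinson noise, both of which shrink as n_iterations and n_hutchinson_samples grow.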


@pytest.mark.parametrize('p', [0.01, 0.1, 0.5, 0.9, 0.99])
@@ -45,25 +62,14 @@ def test_roulette(p):
     n_data = 100
     n_dim = 30
 
-    class TestFunction(nn.Module):
-        def __init__(self, *args, **kwargs):
-            super().__init__(*args, **kwargs)
-
-        def forward(self, inputs):
-            return 0.5 * inputs
-
-    def jac_f(inputs):
-        return torch.eye(n_dim) * 1.5
-
-    def log_det_jac_f(inputs):
-        return torch.log(torch.abs(torch.det(jac_f(inputs))))
-
-    g = TestFunction()
+    test_data = LipschitzTestData(n_dim)
+
+    g = test_data.LipschitzFunction()
 
     torch.manual_seed(0)
     x = torch.randn(size=(n_data, n_dim))
     g_value, log_det_f = log_det_roulette(g, x, training=False, p=p)
-    log_det_f_true = log_det_jac_f(x).ravel()
+    log_det_f_true = test_data.log_det_jac_f(x).ravel()
 
     print(f'{log_det_f = }')
     print(f'{log_det_f_true = }')
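
For reference (a worked value, not part of the commit): in test_roulette the ground truth is log|det(Jac_f)| = n_dim * log(1.5) = 30 * 0.405465 ≈ 12.164, so the printed estimate can be eyeballed against that value.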
