Towards fixing DDSF
davidnabergoj committed Oct 13, 2023
1 parent 3e9b483 commit 3a8298e
Showing 3 changed files with 38 additions and 8 deletions.
17 changes: 11 additions & 6 deletions normalizing_flows/bijections/finite/autoregressive/transformers/combination/sigmoid.py
@@ -11,9 +11,12 @@
 
 
 # As defined in the NAF paper
+# Note: when using dense sigmoid transformers, the input data should have relatively small magnitudes (abs(x) < 10).
+# Otherwise, inversion becomes unstable due to sigmoid saturation.
 
 def inverse_sigmoid(p):
-    return torch.log(p / (1 - p))
+    # return torch.log(p / (1 - p))
+    return torch.log(p) - torch.log1p(-p)
 
 
 class Sigmoid(Transformer):
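
The new note about small input magnitudes comes down to float32 sigmoid saturation: beyond roughly x = 17, torch.sigmoid returns exactly 1.0, so no logit formulation can recover x. A small illustrative snippet (not part of this commit, reusing the stable inverse introduced above):

import torch

def inverse_sigmoid(p):
    # Stable logit, as in the updated code above.
    return torch.log(p) - torch.log1p(-p)

x_small = torch.tensor(5.0)
print(inverse_sigmoid(torch.sigmoid(x_small)))   # ~5.0, the round trip is fine

x_large = torch.tensor(20.0)
print(torch.sigmoid(x_large))                    # exactly 1.0 in float32
print(inverse_sigmoid(torch.sigmoid(x_large)))   # inf, the input is unrecoverable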
@@ -87,7 +90,7 @@ def inverse_1d(self, z, h):
         def f(inputs):
             return self.forward_1d(inputs, h)
 
-        x, log_det = bisection_no_gradient(f, z)
+        x, log_det = bisection_no_gradient(f, z, a=-10.0, b=10.0)
         return x, log_det
 
     def forward(self, x: torch.Tensor, h: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
@@ -167,9 +170,11 @@ def forward_1d(self, x, h):
         a, b, w, log_w, u, log_u = self.extract_parameters(h)
 
         ux = torch.einsum('boi,bi->bo', u, x)  # (batch_size, output_size)
-        c = torch.sigmoid(a * ux + b)  # (batch_size, output_size)
-        d = torch.einsum('bij,bj->bi', w, c)  # Softmax weighing -> (batch_size, output_size)
+        c = a * ux + b  # (batch_size, output_size)
+        d = torch.einsum('bij,bj->bi', w, torch.sigmoid(c))  # Softmax weighing -> (batch_size, output_size)
         x = inverse_sigmoid(d)  # Inverse sigmoid (batch_size, output_size)
+        # The problem with NAF: we map c to sigmoid(c), alter it a bit, then map it back with the inverse.
+        # Precision gets lost when mapping back.
 
         log_t1 = (torch.log(d) - torch.log(1 - d))[:, :, None]  # (batch_size, output_size, 1)
         log_t2 = log_w  # (batch_size, output_size, output_size)
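
The new comment on lost precision can be reproduced with a toy version of this step. The sketch below is illustrative only (scalar weights instead of the einsum): moderately large pre-activations are mapped through sigmoid, a convex combination, and then the inverse sigmoid.

import torch

def inverse_sigmoid(p):
    return torch.log(p) - torch.log1p(-p)

c = torch.tensor([12.0, 14.0])      # pre-activations
w = torch.tensor([0.5, 0.5])        # convex (softmax-style) weights
d = (w * torch.sigmoid(c)).sum()    # d lies very close to 1
x = inverse_sigmoid(d)
print(d.item(), x.item())           # x comes back with only a few accurate digits,
                                    # because float32 rounding in (1 - d) dominates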
@@ -188,7 +193,7 @@ class DenseSigmoid(Transformer):
     def __init__(self,
                  event_shape: Union[torch.Size, Tuple[int, ...]],
                  n_dense_layers: int = 1,
-                 hidden_size: int = 30):
+                 hidden_size: int = 8):
         super().__init__(event_shape)
         self.n_dense_layers = n_dense_layers
         layers = [
@@ -242,7 +247,7 @@ def inverse_1d(self, z, h):
         def f(inputs):
             return self.forward_1d(inputs, h)
 
-        x, log_det = bisection_no_gradient(f, z)
+        x, log_det = bisection_no_gradient(f, z, a=-10.0, b=10.0)
         return x, log_det
 
     def inverse(self, z: torch.Tensor, h: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
9 changes: 7 additions & 2 deletions normalizing_flows/bijections/numerical_inversion.py
@@ -38,8 +38,8 @@ def bisection(f, y, a, b, n, h):
 
 def bisection_no_gradient(f: callable,
                           y: torch.Tensor,
-                          a: torch.Tensor = None,
-                          b: torch.Tensor = None,
+                          a: Union[torch.Tensor, float] = None,
+                          b: Union[torch.Tensor, float] = None,
                           n_iterations: int = 500,
                           atol: float = 1e-9):
     """
@@ -53,8 +53,13 @@ def bisection_no_gradient(f: callable,
 
     if a is None:
        a = torch.full_like(y, fill_value=-100.0)
+    elif isinstance(a, float):
+        a = torch.full_like(y, fill_value=a)
+
     if b is None:
         b = torch.full_like(y, fill_value=100.0)
+    elif isinstance(b, float):
+        b = torch.full_like(y, fill_value=b)
 
     c = (a + b) / 2
     log_det = None
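
For context, bisection_no_gradient now also accepts plain float bounds and broadcasts them to the shape of y; the transformer code above passes a=-10.0, b=10.0, matching the abs(x) < 10 note, since preimages outside the bracket cannot be found. A minimal standalone sketch of the same idea (not the library function, which additionally returns the log-determinant):

import torch

def bisection_invert(f, y, a=-10.0, b=10.0, n_iterations=50):
    # Invert a monotone increasing elementwise map f by bisection on [a, b].
    if isinstance(a, float):
        a = torch.full_like(y, fill_value=a)
    if isinstance(b, float):
        b = torch.full_like(y, fill_value=b)
    for _ in range(n_iterations):
        c = (a + b) / 2
        too_small = f(c) < y              # midpoint maps below the target
        a = torch.where(too_small, c, a)  # move the lower end up ...
        b = torch.where(too_small, b, c)  # ... or the upper end down
    return (a + b) / 2

y = torch.tensor([0.1, 0.5, 0.9])
x = bisection_invert(torch.sigmoid, y)
print(x, torch.sigmoid(x))  # recovers the logits of y, which lie well inside [-10, 10]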
20 changes: 20 additions & 0 deletions test/test_utils.py
@@ -3,6 +3,8 @@
 import pytest
 import torch
 
+from normalizing_flows.bijections.finite.autoregressive.transformers.combination.sigmoid import inverse_sigmoid
+from normalizing_flows.bijections.finite.autoregressive.transformers.combination.sigmoid_util import log_softmax
 from normalizing_flows.utils import get_batch_shape, vjp_tensor
 
 
@@ -59,3 +61,21 @@ def test_vjp_tensor_batched_quadratic():
     fval, vjp = torch.autograd.functional.vjp(lambda _in: _in ** 2, x, v)
 
     assert torch.allclose(vjp, 2 * x * v)
+
+
+def test_log_softmax():
+    torch.manual_seed(0)
+    x_pre = torch.randn(5, 10)
+    x = torch.softmax(x_pre, dim=1)
+    x_log_1 = log_softmax(x_pre, dim=1)
+    x_log_2 = torch.log(x)
+
+    assert torch.allclose(x_log_1, x_log_2)
+
+
+def test_inverse_sigmoid():
+    torch.manual_seed(0)
+    x = torch.randn(10)
+    s = torch.sigmoid(x)
+    xr = inverse_sigmoid(s)
+    assert torch.allclose(x, xr)
