diff --git a/normalizing_flows/bijections/finite/autoregressive/transformers/combination/sigmoid.py b/normalizing_flows/bijections/finite/autoregressive/transformers/combination/sigmoid.py
index 314cbb2..e98260d 100644
--- a/normalizing_flows/bijections/finite/autoregressive/transformers/combination/sigmoid.py
+++ b/normalizing_flows/bijections/finite/autoregressive/transformers/combination/sigmoid.py
@@ -11,9 +11,12 @@
 
 
 # As defined in the NAF paper
+# Note: when using dense sigmoid transformers, the input data should have relatively small magnitudes (abs(x) < 10).
+# Otherwise, inversion becomes unstable due to sigmoid saturation.
 def inverse_sigmoid(p):
-    return torch.log(p / (1 - p))
+    # return torch.log(p / (1 - p))
+    return torch.log(p) - torch.log1p(-p)
 
 
 class Sigmoid(Transformer):
@@ -87,7 +90,7 @@ def inverse_1d(self, z, h):
         def f(inputs):
             return self.forward_1d(inputs, h)
 
-        x, log_det = bisection_no_gradient(f, z)
+        x, log_det = bisection_no_gradient(f, z, a=-10.0, b=10.0)
         return x, log_det
 
     def forward(self, x: torch.Tensor, h: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
@@ -167,9 +170,11 @@ def forward_1d(self, x, h):
         a, b, w, log_w, u, log_u = self.extract_parameters(h)
 
         ux = torch.einsum('boi,bi->bo', u, x)  # (batch_size, output_size)
-        c = torch.sigmoid(a * ux + b)  # (batch_size, output_size)
-        d = torch.einsum('bij,bj->bi', w, c)  # Softmax weighing -> (batch_size, output_size)
+        c = a * ux + b  # (batch_size, output_size)
+        d = torch.einsum('bij,bj->bi', w, torch.sigmoid(c))  # Softmax weighing -> (batch_size, output_size)
         x = inverse_sigmoid(d)  # Inverse sigmoid (batch_size, output_size)
+        # The problem with NAF: we map c to sigmoid(c), alter it a bit, then map it back with the inverse.
+        # Precision gets lost when mapping back.
 
         log_t1 = (torch.log(d) - torch.log(1 - d))[:, :, None]  # (batch_size, output_size, 1)
         log_t2 = log_w  # (batch_size, output_size, output_size)
@@ -188,7 +193,7 @@ class DenseSigmoid(Transformer):
     def __init__(self,
                  event_shape: Union[torch.Size, Tuple[int, ...]],
                  n_dense_layers: int = 1,
-                 hidden_size: int = 30):
+                 hidden_size: int = 8):
         super().__init__(event_shape)
         self.n_dense_layers = n_dense_layers
         layers = [
@@ -242,7 +247,7 @@ def inverse_1d(self, z, h):
         def f(inputs):
             return self.forward_1d(inputs, h)
 
-        x, log_det = bisection_no_gradient(f, z)
+        x, log_det = bisection_no_gradient(f, z, a=-10.0, b=10.0)
         return x, log_det
 
     def inverse(self, z: torch.Tensor, h: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
diff --git a/normalizing_flows/bijections/numerical_inversion.py b/normalizing_flows/bijections/numerical_inversion.py
index 8935e59..9a642c8 100644
--- a/normalizing_flows/bijections/numerical_inversion.py
+++ b/normalizing_flows/bijections/numerical_inversion.py
@@ -38,8 +38,8 @@ def bisection(f, y, a, b, n, h):
 
 def bisection_no_gradient(f: callable,
                           y: torch.Tensor,
-                          a: torch.Tensor = None,
-                          b: torch.Tensor = None,
+                          a: Union[torch.Tensor, float] = None,
+                          b: Union[torch.Tensor, float] = None,
                           n_iterations: int = 500,
                           atol: float = 1e-9):
     """
@@ -53,8 +53,13 @@ def bisection_no_gradient(f: callable,
 
     if a is None:
         a = torch.full_like(y, fill_value=-100.0)
+    elif isinstance(a, float):
+        a = torch.full_like(y, fill_value=a)
+
     if b is None:
         b = torch.full_like(y, fill_value=100.0)
+    elif isinstance(b, float):
+        b = torch.full_like(y, fill_value=b)
 
     c = (a + b) / 2
     log_det = None
diff --git a/test/test_utils.py b/test/test_utils.py
index 66cc330..30ffb64 100644
--- a/test/test_utils.py
+++ b/test/test_utils.py
@@ -3,6 +3,8 @@
 import pytest
 import torch
 
+from normalizing_flows.bijections.finite.autoregressive.transformers.combination.sigmoid import inverse_sigmoid
+from normalizing_flows.bijections.finite.autoregressive.transformers.combination.sigmoid_util import log_softmax
 from normalizing_flows.utils import get_batch_shape, vjp_tensor
 
 
@@ -59,3 +61,21 @@ def test_vjp_tensor_batched_quadratic():
     fval, vjp = torch.autograd.functional.vjp(lambda _in: _in ** 2, x, v)
 
     assert torch.allclose(vjp, 2 * x * v)
+
+
+def test_log_softmax():
+    torch.manual_seed(0)
+    x_pre = torch.randn(5, 10)
+    x = torch.softmax(x_pre, dim=1)
+    x_log_1 = log_softmax(x_pre, dim=1)
+    x_log_2 = torch.log(x)
+
+    assert torch.allclose(x_log_1, x_log_2)
+
+
+def test_inverse_sigmoid():
+    torch.manual_seed(0)
+    x = torch.randn(10)
+    s = torch.sigmoid(x)
+    xr = inverse_sigmoid(s)
+    assert torch.allclose(x, xr)
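
Note on the bracket change (not part of the patch): the sketch below is a minimal, standalone illustration of why narrowing the bisection bracket from [-100, 100] to [-10, 10] is reasonable. The helper bisect and the toy map are hypothetical; the repository's actual inversion goes through bisection_no_gradient. In float32, sigmoid-style maps saturate well before |x| = 100, so the extra range cannot be recovered anyway, and a tight bracket keeps the search in the region where the map is still invertible.

    import torch

    def bisect(f, y, a, b, n_iterations=100):
        # Toy elementwise bisection for a monotonically increasing f,
        # assuming the solution lies inside [a, b].
        for _ in range(n_iterations):
            c = (a + b) / 2
            lower = f(c) < y
            a = torch.where(lower, c, a)
            b = torch.where(lower, b, c)
        return (a + b) / 2

    x = torch.tensor([-8.0, -1.0, 0.5, 8.0, 30.0])
    y = torch.sigmoid(x)           # toy forward map that saturates like the sigmoid transformers
    a = torch.full_like(y, -10.0)  # the tightened bracket from this patch
    b = torch.full_like(y, 10.0)
    x_rec = bisect(torch.sigmoid, y, a, b)
    print(x_rec)  # inputs with |x| < 10 are recovered closely; x = 30 saturates
                  # (sigmoid(30) == 1.0 in float32), so the estimate is clipped
                  # to the upper bracket

This is also the motivation for the new comment above inverse_sigmoid: once the sigmoid saturates in float32, the round trip through sigmoid and its inverse cannot recover the input, and precision already degrades before full saturation.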