diff --git a/normalizing_flows/bijections/finite/autoregressive/transformers/combination/sigmoid.py b/normalizing_flows/bijections/finite/autoregressive/transformers/combination/sigmoid.py
index 6398596..43f259b 100644
--- a/normalizing_flows/bijections/finite/autoregressive/transformers/combination/sigmoid.py
+++ b/normalizing_flows/bijections/finite/autoregressive/transformers/combination/sigmoid.py
@@ -62,7 +62,7 @@ def extract_parameters(self, h: torch.Tensor):
         log_w = log_softmax(w_pre, dim=-1)
         return a, b, w, log_w
 
-    def forward_1d(self, x, h):
+    def forward_1d(self, x, h, eps: float = 1e-6):
         """
         x.shape = (batch_size,)
         h.shape = (batch_size, hidden_size * 3)
@@ -73,8 +73,9 @@ def forward_1d(self, x, h):
         w.shape = (batch_size, hidden_size)
         """
         a, b, w, log_w = self.extract_parameters(h)
-        c = torch.sigmoid(a * x[:, None] + b)  # (batch_size, n_hidden)
-        d = torch.einsum('...i,...i->...', w, c)  # Softmax weighing -> (batch_size,)
+        c = a * x[:, None] + b  # (batch_size, n_hidden)
+        d = torch.clip(torch.einsum('...i,...i->...', w, torch.sigmoid(c)), eps,
+                       1 - eps)  # Softmax weighing -> (batch_size,)
         x = inverse_sigmoid(d)  # Inverse sigmoid ... (batch_size,)
 
         log_t1 = (torch.log(d) - torch.log(1 - d))[:, None]  # (batch_size, hidden_size)
@@ -162,7 +163,7 @@ def extract_parameters(self, h: torch.Tensor):
         log_w = log_softmax(w_pre, dim=-1)
         return a, b, w, log_w, u, log_u
 
-    def forward_1d(self, x, h):
+    def forward_1d(self, x, h, eps: float = 1e-6):
         # Compute y = inv_sig(w @ sig(a * u @ x + b))
 
         # h.shape = (batch_size, n_parameters)
@@ -183,7 +184,11 @@ def forward_1d(self, x, h):
 
         ux = torch.einsum('beoi,bei->beo', u, x)  # (batch_size, event_size, output_size)
         c = a * ux + b  # (batch_size, event_size, output_size)
-        d = torch.einsum('beij,bej->bei', w, torch.sigmoid(c))  # Softmax -> (batch_size, event_size, output_size)
+        d = torch.clip(
+            torch.einsum('beij,bej->bei', w, torch.sigmoid(c)),
+            eps,
+            1 - eps
+        )  # Softmax -> (batch_size, event_size, output_size)
         x = inverse_sigmoid(d)  # Inverse sigmoid (batch_size, event_size, output_size)
         # The problem with NAF: we map c to sigmoid(c), alter it a bit, then map it back with the inverse.
         # Precision gets lost when mapping back.
diff --git a/test/test_sigmoid_transformer.py b/test/test_sigmoid_transformer.py
index bba232f..9c5599f 100644
--- a/test/test_sigmoid_transformer.py
+++ b/test/test_sigmoid_transformer.py
@@ -78,7 +78,7 @@ def test_deep_sigmoid_coupling(event_shape, batch_shape):
     forward_layer = DSCoupling(torch.Size(event_shape))
     inverse_layer = invert(DSCoupling(torch.Size(event_shape)))
 
-    x = torch.randn(size=(*batch_shape, *event_shape))
+    x = torch.randn(size=(*batch_shape, *event_shape))  # Reduce magnitude for stability
     y, log_det_forward = forward_layer.forward(x)
 
     assert y.shape == x.shape
@@ -99,8 +99,18 @@ def test_deep_sigmoid_coupling(event_shape, batch_shape):
     assert torch.all(~torch.isinf(log_det_inverse))
 
 
-@pytest.mark.parametrize('batch_shape', [(7,), (25,), (13,), (2, 37)])
-@pytest.mark.parametrize('n_dim', [2, 5, 100, 1000])
+@pytest.mark.parametrize('batch_shape', [
+    (2, 37),
+    (7,),
+    (13,),
+    (25,),
+])
+@pytest.mark.parametrize('n_dim', [
+    1000,
+    2,
+    5,
+    100,
+])
 def test_deep_sigmoid_coupling_flow(n_dim, batch_shape):
     torch.manual_seed(0)
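
For reviewers, here is a minimal standalone sketch (not part of the PR) of the failure mode this change addresses. It assumes `inverse_sigmoid` is the usual logit function, consistent with the `torch.log(d) - torch.log(1 - d)` terms in the diff: without clipping, the convex combination of sigmoids can saturate to exactly 0 or 1 in float32, so the logit and the downstream log-determinant become non-finite; clipping into `(eps, 1 - eps)` keeps them finite.

```python
import torch


def inverse_sigmoid(d: torch.Tensor) -> torch.Tensor:
    # Stand-in for the repository's helper (assumed to be the logit function).
    return torch.log(d) - torch.log(1 - d)


c = torch.tensor([40.0])  # a large pre-activation a * x + b
d = torch.sigmoid(c)      # rounds to exactly 1.0 in float32

# Unclipped: logit(1.0) is +inf, which later poisons the log-determinant terms.
print(inverse_sigmoid(d))  # tensor([inf])

# Clipped as in this PR: the result stays finite.
eps = 1e-6
d_clipped = torch.clip(d, eps, 1 - eps)
print(inverse_sigmoid(d_clipped))  # ~13.8, i.e. logit(1 - 1e-6)
```

The default `eps = 1e-6` mirrors the new keyword argument on `forward_1d`: it trades a tiny bias at the tails of the sigmoid for guaranteed finite outputs and log-determinants.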