Fix stability issue in sigmoid transformer
davidnabergoj committed Oct 13, 2023
1 parent 98151a4 commit c681f84
Showing 2 changed files with 23 additions and 8 deletions.
@@ -62,7 +62,7 @@ def extract_parameters(self, h: torch.Tensor):
log_w = log_softmax(w_pre, dim=-1)
return a, b, w, log_w

-def forward_1d(self, x, h):
+def forward_1d(self, x, h, eps: float = 1e-6):
"""
x.shape = (batch_size,)
h.shape = (batch_size, hidden_size * 3)
@@ -73,8 +73,9 @@ def forward_1d(self, x, h):
w.shape = (batch_size, hidden_size)
"""
a, b, w, log_w = self.extract_parameters(h)
-c = torch.sigmoid(a * x[:, None] + b)  # (batch_size, n_hidden)
-d = torch.einsum('...i,...i->...', w, c)  # Softmax weighing -> (batch_size,)
+c = a * x[:, None] + b  # (batch_size, n_hidden)
+d = torch.clip(torch.einsum('...i,...i->...', w, torch.sigmoid(c)), eps,
+               1 - eps)  # Softmax weighing -> (batch_size,)
x = inverse_sigmoid(d) # Inverse sigmoid ... (batch_size,)

log_t1 = (torch.log(d) - torch.log(1 - d))[:, None] # (batch_size, hidden_size)
@@ -162,7 +163,7 @@ def extract_parameters(self, h: torch.Tensor):
log_w = log_softmax(w_pre, dim=-1)
return a, b, w, log_w, u, log_u

-def forward_1d(self, x, h):
+def forward_1d(self, x, h, eps: float = 1e-6):
# Compute y = inv_sig(w @ sig(a * u @ x + b))
# h.shape = (batch_size, n_parameters)

@@ -183,7 +184,11 @@ def forward_1d(self, x, h):

ux = torch.einsum('beoi,bei->beo', u, x) # (batch_size, event_size, output_size)
c = a * ux + b # (batch_size, event_size, output_size)
-d = torch.einsum('beij,bej->bei', w, torch.sigmoid(c))  # Softmax -> (batch_size, event_size, output_size)
+d = torch.clip(
+    torch.einsum('beij,bej->bei', w, torch.sigmoid(c)),
+    eps,
+    1 - eps
+)  # Softmax -> (batch_size, event_size, output_size)
x = inverse_sigmoid(d) # Inverse sigmoid (batch_size, event_size, output_size)
# The problem with NAF: we map c to sigmoid(c), alter it a bit, then map it back with the inverse.
# Precision gets lost when mapping back.
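For context, and not part of the commit itself: the precision loss described in the comment above comes from float32 saturation. torch.sigmoid(c) rounds to exactly 1.0 once c exceeds roughly 17, so applying the logit afterwards produces inf. The sketch below reproduces that failure and shows how clipping the weighted sigmoid output into [eps, 1 - eps], as the patched lines do, keeps the inverted value finite; the inverse_sigmoid helper here is a stand-in assumed to match the repository's own function.

import torch

def inverse_sigmoid(p: torch.Tensor) -> torch.Tensor:
    # Logit, i.e. the inverse of torch.sigmoid.
    return torch.log(p) - torch.log(1 - p)

c = torch.tensor([20.0])           # large pre-activation, e.g. from a * x + b
d_raw = torch.sigmoid(c)           # saturates to exactly 1.0 in float32
print(inverse_sigmoid(d_raw))      # tensor([inf]) -- the unpatched code path

eps = 1e-6
d_clipped = torch.clip(d_raw, eps, 1 - eps)
print(inverse_sigmoid(d_clipped))  # tensor([13.8155]) -- finite, as in the patched code path
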
16 changes: 13 additions & 3 deletions test/test_sigmoid_transformer.py
@@ -78,7 +78,7 @@ def test_deep_sigmoid_coupling(event_shape, batch_shape):
forward_layer = DSCoupling(torch.Size(event_shape))
inverse_layer = invert(DSCoupling(torch.Size(event_shape)))

-x = torch.randn(size=(*batch_shape, *event_shape))
+x = torch.randn(size=(*batch_shape, *event_shape))  # Reduce magnitude for stability
y, log_det_forward = forward_layer.forward(x)

assert y.shape == x.shape
@@ -99,8 +99,18 @@ def test_deep_sigmoid_coupling(event_shape, batch_shape):
assert torch.all(~torch.isinf(log_det_inverse))


-@pytest.mark.parametrize('batch_shape', [(7,), (25,), (13,), (2, 37)])
-@pytest.mark.parametrize('n_dim', [2, 5, 100, 1000])
+@pytest.mark.parametrize('batch_shape', [
+    (2, 37),
+    (7,),
+    (13,),
+    (25,),
+])
+@pytest.mark.parametrize('n_dim', [
+    1000,
+    2,
+    5,
+    100,
+])
def test_deep_sigmoid_coupling_flow(n_dim, batch_shape):
torch.manual_seed(0)

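As an illustrative usage note, not taken from the commit: with the clipping in place, a forward pass through the coupling layer should stay finite even for inputs large enough to saturate the sigmoid, which is what the isnan/isinf assertions above verify. A minimal check, assuming the same DSCoupling import this test file already uses:

x_large = 10 * torch.randn(25, 5)   # deliberately large inputs that push the sigmoid toward saturation
layer = DSCoupling(torch.Size((5,)))
y, log_det = layer.forward(x_large)
assert torch.all(torch.isfinite(y))
assert torch.all(torch.isfinite(log_det))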
