Fixing DenseSigmoid reconstruction
davidnabergoj committed Oct 13, 2023
1 parent 3a8298e commit 8ee40d0
Showing 2 changed files with 68 additions and 60 deletions.
@@ -111,10 +111,24 @@ def inverse(self, z: torch.Tensor, h: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:


class DenseSigmoidInnerTransform(nn.Module):
def __init__(self, input_size, output_size, min_scale: float = 1e-3):
"""
The function inv_sigmoid(w @ sigmoid(a * u @ x + b)).
Shapes:
x.shape = (batch_size, event_size, input_size)
u.shape = (batch_size, event_size, output_size, input_size)
a.shape = (batch_size, event_size, output_size)
b.shape = (batch_size, event_size, output_size)
w.shape = (batch_size, event_size, output_size, output_size)
Because this function is applied to vectors and not scalars, we cannot merge batch & event sizes into a single dim.
"""

def __init__(self, input_size, output_size, event_size, min_scale: float = 1e-3):
super().__init__()
self.input_size = input_size
self.output_size = output_size
self.event_size = event_size

self.min_scale = min_scale
self.const = 1000
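
For readers skimming the diff, a minimal standalone sketch of the mapping named in the docstring, inv_sigmoid(w @ sigmoid(a * u @ x + b)), for a single event dimension. It assumes a is positive and that the rows of u and w are softmax-normalized, as extract_parameters below suggests; it is illustrative only, not the repository's implementation.

import torch

def dense_sigmoid_1d_sketch(x, a, b, u, w, eps=1e-6):
    # x: (input_size,), a and b: (output_size,), u: (output_size, input_size), w: (output_size, output_size)
    c = a * (u @ x) + b                     # affine pre-activation per output unit
    d = w @ torch.sigmoid(c)                # row-stochastic w keeps d inside (0, 1)
    d = d.clamp(eps, 1 - eps)               # hypothetical safeguard against logit overflow
    return torch.log(d) - torch.log(1 - d)  # inverse sigmoid (logit)
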
@@ -126,18 +140,16 @@ def n_parameters(self):

def extract_parameters(self, h: torch.Tensor):
"""
h.shape = (batch_size, self.n_parameters)
h.shape = (batch_size, event_size, self.n_parameters)
"""
assert len(h.shape) == 2
batch_size = len(h)

da = h[:, :self.output_size]
db = h[:, self.output_size:self.output_size * 2]
dw = h[:, self.output_size * 2:self.output_size * 2 + self.output_size ** 2]
du = h[:, self.output_size * 2 + self.output_size ** 2:]
da = h[..., :self.output_size]
db = h[..., self.output_size:self.output_size * 2]
dw = h[..., self.output_size * 2:self.output_size * 2 + self.output_size ** 2]
du = h[..., self.output_size * 2 + self.output_size ** 2:]

du = du.view(batch_size, self.output_size, self.input_size)
dw = dw.view(batch_size, self.output_size, self.output_size)
du = du.view(*h.shape[:2], self.output_size, self.input_size)
dw = dw.view(*h.shape[:2], self.output_size, self.output_size)

u_pre = 0.0 + du / self.const
u = torch.softmax(u_pre, dim=-1)
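
The slicing above fixes the layout of the trailing dimension of h; a small helper makes the implied parameter count explicit (a sketch, assuming n_parameters returns exactly this total).

def expected_n_parameters(input_size: int, output_size: int) -> int:
    # da and db each take output_size entries, dw takes output_size ** 2,
    # and du takes output_size * input_size, matching the slices in extract_parameters.
    return 2 * output_size + output_size ** 2 + output_size * input_size
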
@@ -155,51 +167,64 @@ def forward_1d(self, x, h):
# h.shape = (batch_size, n_parameters)

# Within the function:
# x.shape = (batch_size, input_size)
# a.shape = (batch_size, output_size)
# b.shape = (batch_size, output_size)
# w.shape = (batch_size, output_size, output_size)
# log_w.shape = (batch_size, output_size, output_size)
# u.shape = (batch_size, output_size, input_size)
# log_u.shape = (batch_size, output_size, input_size)
# x.shape = (batch_size, event_size, input_size)
# a.shape = (batch_size, event_size, output_size)
# b.shape = (batch_size, event_size, output_size)
# w.shape = (batch_size, event_size, output_size, output_size)
# log_w.shape = (batch_size, event_size, output_size, output_size)
# u.shape = (batch_size, event_size, output_size, input_size)
# log_u.shape = (batch_size, event_size, output_size, input_size)

# Return
# y.shape = (batch_size, output_size)
# y.shape = (batch_size, event_size, output_size)
# log_det.shape = (batch_size,)

a, b, w, log_w, u, log_u = self.extract_parameters(h)

ux = torch.einsum('boi,bi->bo', u, x) # (batch_size, output_size)
c = a * ux + b # (batch_size, output_size)
d = torch.einsum('bij,bj->bi', w, torch.sigmoid(c)) # Softmax weighing -> (batch_size, output_size)
x = inverse_sigmoid(d) # Inverse sigmoid (batch_size, output_size)
ux = torch.einsum('beoi,bei->beo', u, x) # (batch_size, event_size, output_size)
c = a * ux + b # (batch_size, event_size, output_size)
d = torch.einsum('beij,bej->bei', w, torch.sigmoid(c)) # Softmax -> (batch_size, event_size, output_size)
x = inverse_sigmoid(d) # Inverse sigmoid (batch_size, event_size, output_size)
# The problem with NAF: we map c to sigmoid(c), alter it a bit, then map it back with the inverse.
# Precision gets lost when mapping back.

log_t1 = (torch.log(d) - torch.log(1 - d))[:, :, None] # (batch_size, output_size, 1)
log_t2 = log_w # (batch_size, output_size, output_size)
log_t3 = (log_sigmoid(c) + log_sigmoid(-c))[:, None, :] # (batch_size, 1, output_size)
log_t4 = torch.log(a)[:, None, :] # (batch_size, 1, output_size)
d_ = d.flatten(0, 1)
c_ = c.flatten(0, 1)
a_ = a.flatten(0, 1)
log_w_ = log_w.flatten(0, 1)
log_u_ = log_u.flatten(0, 1)

log_t1 = (torch.log(d_) - torch.log(1 - d_))[:, :, None] # (batch_size, output_size, 1)
log_t2 = log_w_ # (batch_size, output_size, output_size)
log_t3 = (log_sigmoid(c_) + log_sigmoid(-c_))[:, None, :] # (batch_size, 1, output_size)
log_t4 = torch.log(a_)[:, None, :] # (batch_size, 1, output_size)

m1 = (log_t1 + log_t2 + log_t3 + log_t4)[:, :, :, None] # (batch_size, output_size, output_size, 1)
m2 = log_u[:, None, :, :] # (batch_size, 1, output_size, input_size)
m2 = log_u_[:, None, :, :] # (batch_size, 1, output_size, input_size)
log_det = torch.sum(log_dot(m1, m2), dim=(1, 2, 3)) # (batch_size,)

z = x
return z, log_det
return z, log_det.view(*x.shape[:2])
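
inverse_sigmoid and log_sigmoid are imported from elsewhere in the repository and are not part of this diff; plausible minimal versions are sketched here purely for readability (an assumption, not the commit's code). log_dot is presumably a log-space matrix product, but its exact reduction is not visible in this hunk, so it is not reproduced.

import torch

def inverse_sigmoid(p: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # Logit with clamping so values near 0 or 1 do not produce infinities.
    p = p.clamp(eps, 1 - eps)
    return torch.log(p) - torch.log(1 - p)

def log_sigmoid(x: torch.Tensor) -> torch.Tensor:
    # Thin wrapper around the numerically stable PyTorch primitive.
    return torch.nn.functional.logsigmoid(x)
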


class DenseSigmoid(Transformer):
"""
Apply y = f1 \\circ f2 \\circ ... \\circ fn (x) where
* f1 is a dense sigmoid inner transform which maps from 1 to h dimensions;
* fn is a dense sigmoid inner transform which maps from h to 1 dimensions;
* fi (for all other i) is a dense sigmoid inner transform which maps from h to h dimensions.
"""

def __init__(self,
event_shape: Union[torch.Size, Tuple[int, ...]],
n_dense_layers: int = 1,
hidden_size: int = 8):
super().__init__(event_shape)
self.n_dense_layers = n_dense_layers
layers = [
DenseSigmoidInnerTransform(self.n_dim, hidden_size),
*[DenseSigmoidInnerTransform(hidden_size, hidden_size) for _ in range(n_dense_layers)],
DenseSigmoidInnerTransform(hidden_size, self.n_dim),
DenseSigmoidInnerTransform(1, hidden_size, self.n_dim),
*[DenseSigmoidInnerTransform(hidden_size, hidden_size, self.n_dim) for _ in range(n_dense_layers)],
DenseSigmoidInnerTransform(hidden_size, 1, self.n_dim),
]
self.layers = nn.ModuleList(layers)
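
With the widened constructor, the stack maps each event coordinate from 1 feature up to hidden_size features and back down to 1. A hypothetical instantiation (event shape and sizes chosen only for illustration):

# Assumed example: a 3-dimensional event, one hidden dense sigmoid layer of width 8.
transformer = DenseSigmoid(event_shape=(3,), n_dense_layers=1, hidden_size=8)
print([(layer.input_size, layer.output_size) for layer in transformer.layers])
# Expected widths per coordinate: [(1, 8), (8, 8), (8, 1)]
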

@@ -219,28 +244,22 @@ def split_parameters(self, h):

def forward_1d(self, x_flat, h_split_flat: List[torch.Tensor]):
log_det_flat = None
x_flat = x_flat[..., None]
for i in range(len(self.layers)):
x_flat, log_det_flat_inc = self.layers[i].forward_1d(x_flat, h_split_flat[i])
if log_det_flat is None:
log_det_flat = log_det_flat_inc
else:
log_det_flat += log_det_flat_inc
z_flat = x_flat
z_flat = x_flat[..., 0]
log_det_flat = log_det_flat.sum(dim=1) # Sum over event
return z_flat, log_det_flat

def forward(self, x: torch.Tensor, h: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
# x.shape == (*batch_shape, *event_shape)
# h.shape == (*batch_shape, n_parameters)
# h.shape == (*batch_shape, *event_shape, n_parameters)
h_split = self.split_parameters(h)
event_size = self.n_dim
batch_size = int(torch.prod(torch.as_tensor(get_batch_shape(x, self.event_shape))))
x_flat = x.view(batch_size, event_size)
h_split_flat = [h_split[i].view(batch_size, -1) for i in range(len(h_split))]

z_flat, log_det_flat = self.forward_1d(x_flat, h_split_flat)

log_det = log_det_flat.view(batch_size)
z = z_flat.view_as(x)
z, log_det = self.forward_1d(x, h_split)
return z, log_det
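
A hedged usage sketch of the new calling convention, mirroring the updated test below: h now carries one parameter vector per event dimension rather than a single vector per event.

batch_shape, event_shape = (4,), (3,)
transformer = DenseSigmoid(event_shape)
x = torch.randn(*batch_shape, *event_shape)
h = torch.randn(*batch_shape, *event_shape, transformer.n_parameters)
z, log_det = transformer.forward(x, h)
# Expected: z.shape == (4, 3) and log_det.shape == (4,)
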

def inverse_1d(self, z, h):
@@ -251,18 +270,10 @@ def f(inputs):
return x, log_det

def inverse(self, z: torch.Tensor, h: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
# x.shape == (*batch_shape, *event_shape)
# h.shape == (*batch_shape, n_parameters)
# z.shape == (*batch_shape, *event_shape)
# h.shape == (*batch_shape, *event_shape, n_parameters)
h_split = self.split_parameters(h)
event_size = self.n_dim
batch_size = int(torch.prod(torch.as_tensor(get_batch_shape(z, self.event_shape))))
z_flat = z.view(batch_size, event_size)
h_split_flat = [h_split[i].view(batch_size, -1) for i in range(len(h_split))]

x_flat, log_det_flat = self.inverse_1d(z_flat, h_split_flat)

log_det = log_det_flat.view(batch_size)
z = x_flat.view_as(z)
x, log_det = self.inverse_1d(z, h_split)
return x, log_det
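
Most of inverse_1d lies outside this hunk; judging from the f(inputs) closure in the hunk header, it wraps the forward map for a numerical root-finding routine. A generic bisection-style sketch of that idea, purely illustrative and not the repository's implementation:

def invert_monotone(f, z, lower=-100.0, upper=100.0, n_iterations=60):
    # Elementwise bisection for a strictly increasing map f; assumes the root lies in [lower, upper].
    low = torch.full_like(z, lower)
    high = torch.full_like(z, upper)
    for _ in range(n_iterations):
        mid = (low + high) / 2
        too_high = f(mid) > z
        high = torch.where(too_high, mid, high)
        low = torch.where(too_high, low, mid)
    return (low + high) / 2
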


11 changes: 4 additions & 7 deletions test/test_reconstruction_transformers.py
@@ -20,15 +20,12 @@
from test.constants import __test_constants


def setup_transformer_data(transformer_class: Transformer, batch_shape, event_shape, vector_to_vector: bool = False):
def setup_transformer_data(transformer_class: Transformer, batch_shape, event_shape):
# vector_to_vector: does the transformer map a vector to vector? Otherwise, it maps a scalar to scalar.
torch.manual_seed(0)
transformer = transformer_class(event_shape)
x = torch.randn(*batch_shape, *event_shape)
if vector_to_vector:
h = torch.randn(*batch_shape, transformer.n_parameters)
else:
h = torch.randn(*batch_shape, *event_shape, transformer.n_parameters)
h = torch.randn(*batch_shape, *event_shape, transformer.n_parameters)
return transformer, x, h


@@ -110,5 +107,5 @@ def test_combination_basic(transformer_class: Transformer, batch_shape: Tuple, event_shape: Tuple):
@pytest.mark.parametrize('batch_shape', __test_constants['batch_shape'])
@pytest.mark.parametrize('event_shape', __test_constants['event_shape'])
def test_combination_vector_to_vector(transformer_class: Transformer, batch_shape: Tuple, event_shape: Tuple):
transformer, x, h = setup_transformer_data(transformer_class, batch_shape, event_shape, vector_to_vector=True)
assert_valid_reconstruction(transformer, x, h)
transformer, x, h = setup_transformer_data(transformer_class, batch_shape, event_shape)
assert_valid_reconstruction(transformer, x, h, reconstruction_eps=1e-2)
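
For context, assert_valid_reconstruction is defined in the test helpers and is not shown in this diff; it presumably performs a round trip through the transformer and compares against the tolerance. A hedged sketch of such a check:

def assert_valid_reconstruction_sketch(transformer, x, h, reconstruction_eps=1e-2):
    z, log_det_forward = transformer.forward(x, h)
    x_reconstructed, log_det_inverse = transformer.inverse(z, h)
    assert torch.allclose(x, x_reconstructed, atol=reconstruction_eps)
    # For a bijection, the inverse log-determinant should be the negation of the forward one.
    assert torch.allclose(log_det_forward, -log_det_inverse, atol=reconstruction_eps)
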
