
Commit

Fix dense sigmoid transformer reconstruction and log determinant
davidnabergoj committed Oct 13, 2023
1 parent 8ee40d0 commit 98151a4
Showing 2 changed files with 42 additions and 5 deletions.
@@ -259,7 +259,25 @@ def forward(self, x: torch.Tensor, h: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
         # x.shape == (*batch_shape, *event_shape)
         # h.shape == (*batch_shape, *event_shape, n_parameters)
         h_split = self.split_parameters(h)
-        z, log_det = self.forward_1d(x, h_split)
+
+        n_event_dims = len(self.event_shape)
+        n_batch_dims = len(x.shape) - n_event_dims
+        batch_shape = x.shape[:n_batch_dims]
+
+        batch_size = int(torch.prod(torch.as_tensor(batch_shape)))
+        event_size = int(torch.prod(torch.as_tensor(self.event_shape)))
+
+        # Flatten event and batch
+        x_flat = x.view(batch_size, event_size)
+        h_split_flat = [h_element.view(batch_size, event_size, -1) for h_element in h_split]
+
+        z_flat, log_det_flat = self.forward_1d(x_flat, h_split_flat)
+
+        # Unflatten event and batch
+        z = z_flat.view(*batch_shape, *self.event_shape)
+        log_det = log_det_flat.view(*batch_shape)
+        # log_det = sum_except_batch(log_det, self.event_shape)
+
         return z, log_det

     def inverse_1d(z, h):
@@ -273,8 +291,24 @@ def inverse(self, z: torch.Tensor, h: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
         # z.shape == (*batch_shape, *event_shape)
         # h.shape == (*batch_shape, *event_shape, n_parameters)
         h_split = self.split_parameters(h)
-        x, log_det = self.inverse_1d(z, h_split)
-        return z, log_det
+
+        n_event_dims = len(self.event_shape)
+        n_batch_dims = len(z.shape) - n_event_dims
+        batch_shape = z.shape[:n_batch_dims]
+
+        batch_size = int(torch.prod(torch.as_tensor(batch_shape)))
+        event_size = int(torch.prod(torch.as_tensor(self.event_shape)))
+
+        # Flatten event and batch
+        z_flat = z.view(batch_size, event_size)
+        h_split_flat = [h_element.view(batch_size, event_size, -1) for h_element in h_split]
+        x_flat, log_det_flat = self.inverse_1d(z_flat, h_split_flat)
+
+        # Unflatten event and batch
+        x = x_flat.view(*batch_shape, *self.event_shape)
+        log_det = log_det_flat.view(*batch_shape)
+
+        return x, log_det


class DeepSigmoid(Combination):
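The pattern introduced above flattens arbitrary batch and event shapes into a single (batch_size, event_size) view, applies the element-wise 1D transform, and reshapes the result back, so the log determinant collapses to one value per batch element. A minimal standalone sketch of that flatten/unflatten pattern, separate from the commit, with a plain sigmoid standing in for forward_1d (the names toy_forward_1d and toy_forward are illustrative, not part of the repository):

import torch

def toy_forward_1d(x_flat: torch.Tensor):
    # Element-wise sigmoid; log |d sigmoid(x)/dx| = log z + log(1 - z) with z = sigmoid(x).
    z_flat = torch.sigmoid(x_flat)
    log_det_flat = (torch.log(z_flat) + torch.log1p(-z_flat)).sum(dim=-1)  # sum over the event axis
    return z_flat, log_det_flat

def toy_forward(x: torch.Tensor, event_shape: tuple):
    n_batch_dims = len(x.shape) - len(event_shape)
    batch_shape = x.shape[:n_batch_dims]
    batch_size = int(torch.prod(torch.as_tensor(batch_shape)))
    event_size = int(torch.prod(torch.as_tensor(event_shape)))

    # Flatten batch and event dimensions, mirroring the patched forward()
    x_flat = x.view(batch_size, event_size)
    z_flat, log_det_flat = toy_forward_1d(x_flat)

    # Unflatten: z recovers the full shape, log_det keeps one scalar per batch element
    z = z_flat.view(*batch_shape, *event_shape)
    log_det = log_det_flat.view(*batch_shape)
    return z, log_det

x = torch.randn(3, 4, 2, 5)      # batch_shape = (3, 4), event_shape = (2, 5)
z, log_det = toy_forward(x, (2, 5))
print(z.shape, log_det.shape)    # torch.Size([3, 4, 2, 5]) torch.Size([3, 4])
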
7 changes: 5 additions & 2 deletions test/test_reconstruction_transformers.py
@@ -103,9 +103,12 @@ def test_combination_basic(transformer_class: Transformer, batch_shape: Tuple, event_shape: Tuple):
     assert_valid_reconstruction(transformer, x, h)


-@pytest.mark.parametrize('transformer_class', [DenseSigmoid, DeepDenseSigmoid])
+@pytest.mark.parametrize('transformer_class', [
+    DenseSigmoid,
+    DeepDenseSigmoid
+])
 @pytest.mark.parametrize('batch_shape', __test_constants['batch_shape'])
 @pytest.mark.parametrize('event_shape', __test_constants['event_shape'])
 def test_combination_vector_to_vector(transformer_class: Transformer, batch_shape: Tuple, event_shape: Tuple):
     transformer, x, h = setup_transformer_data(transformer_class, batch_shape, event_shape)
-    assert_valid_reconstruction(transformer, x, h, reconstruction_eps=1e-2)
+    assert_valid_reconstruction(transformer, x, h)
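For reference, the round-trip property that assert_valid_reconstruction verifies can be sketched independently: forward followed by inverse should reproduce the input, and the two log determinants should cancel. The check below is a generic stand-in under assumed tolerances, not the project's helper, and the sigmoid/logit pair is only a toy transformer:

import torch

def check_reconstruction(forward_fn, inverse_fn, x: torch.Tensor, atol: float = 1e-4):
    z, log_det_fwd = forward_fn(x)
    x_rec, log_det_inv = inverse_fn(z)
    # The inverse must undo the forward map...
    assert torch.allclose(x, x_rec, atol=atol)
    # ...and the log determinants of mutually inverse maps are negatives of each other.
    assert torch.allclose(log_det_fwd, -log_det_inv, atol=atol)
    # One log-determinant value per batch element.
    assert log_det_fwd.shape == x.shape[:-1]

def fwd(x):
    z = torch.sigmoid(x)
    return z, (torch.log(z) + torch.log1p(-z)).sum(dim=-1)

def inv(z):
    x = torch.log(z) - torch.log1p(-z)   # logit
    return x, -(torch.log(z) + torch.log1p(-z)).sum(dim=-1)

check_reconstruction(fwd, inv, torch.randn(8, 5))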

0 comments on commit 98151a4
