From 98151a4b1252a9fc44fbdbb37997193d7a66ba9f Mon Sep 17 00:00:00 2001
From: David Nabergoj
Date: Fri, 13 Oct 2023 14:09:22 +0200
Subject: [PATCH] Fix dense sigmoid transformer reconstruction and log
 determinant

---
 .../transformers/combination/sigmoid.py  | 40 +++++++++++++++++--
 test/test_reconstruction_transformers.py |  7 +++-
 2 files changed, 42 insertions(+), 5 deletions(-)

diff --git a/normalizing_flows/bijections/finite/autoregressive/transformers/combination/sigmoid.py b/normalizing_flows/bijections/finite/autoregressive/transformers/combination/sigmoid.py
index 02f2bda..6398596 100644
--- a/normalizing_flows/bijections/finite/autoregressive/transformers/combination/sigmoid.py
+++ b/normalizing_flows/bijections/finite/autoregressive/transformers/combination/sigmoid.py
@@ -259,7 +259,25 @@ def forward(self, x: torch.Tensor, h: torch.Tensor) -> Tuple[torch.Tensor, torch
         # x.shape == (*batch_shape, *event_shape)
         # h.shape == (*batch_shape, *event_shape, n_parameters)
         h_split = self.split_parameters(h)
-        z, log_det = self.forward_1d(x, h_split)
+
+        n_event_dims = len(self.event_shape)
+        n_batch_dims = len(x.shape) - n_event_dims
+        batch_shape = x.shape[:n_batch_dims]
+
+        batch_size = int(torch.prod(torch.as_tensor(batch_shape)))
+        event_size = int(torch.prod(torch.as_tensor(self.event_shape)))
+
+        # Flatten event and batch
+        x_flat = x.view(batch_size, event_size)
+        h_split_flat = [h_element.view(batch_size, event_size, -1) for h_element in h_split]
+
+        z_flat, log_det_flat = self.forward_1d(x_flat, h_split_flat)
+
+        # Unflatten event and batch
+        z = z_flat.view(*batch_shape, *self.event_shape)
+        log_det = log_det_flat.view(*batch_shape)
+        # log_det = sum_except_batch(log_det, self.event_shape)
+
         return z, log_det
 
     def inverse_1d(self, z, h):
@@ -273,8 +291,24 @@ def inverse(self, z: torch.Tensor, h: torch.Tensor) -> Tuple[torch.Tensor, torch
         # z.shape == (*batch_shape, *event_shape)
         # h.shape == (*batch_shape, *event_shape, n_parameters)
         h_split = self.split_parameters(h)
-        x, log_det = self.inverse_1d(z, h_split)
-        return z, log_det
+
+        n_event_dims = len(self.event_shape)
+        n_batch_dims = len(z.shape) - n_event_dims
+        batch_shape = z.shape[:n_batch_dims]
+
+        batch_size = int(torch.prod(torch.as_tensor(batch_shape)))
+        event_size = int(torch.prod(torch.as_tensor(self.event_shape)))
+
+        # Flatten event and batch
+        z_flat = z.view(batch_size, event_size)
+        h_split_flat = [h_element.view(batch_size, event_size, -1) for h_element in h_split]
+        x_flat, log_det_flat = self.inverse_1d(z_flat, h_split_flat)
+
+        # Unflatten event and batch
+        x = x_flat.view(*batch_shape, *self.event_shape)
+        log_det = log_det_flat.view(*batch_shape)
+
+        return x, log_det
 
 
 class DeepSigmoid(Combination):
diff --git a/test/test_reconstruction_transformers.py b/test/test_reconstruction_transformers.py
index 83994af..3cdaf43 100644
--- a/test/test_reconstruction_transformers.py
+++ b/test/test_reconstruction_transformers.py
@@ -103,9 +103,12 @@ def test_combination_basic(transformer_class: Transformer, batch_shape: Tuple, e
     assert_valid_reconstruction(transformer, x, h)
 
 
-@pytest.mark.parametrize('transformer_class', [DenseSigmoid, DeepDenseSigmoid])
+@pytest.mark.parametrize('transformer_class', [
+    DenseSigmoid,
+    DeepDenseSigmoid
+])
 @pytest.mark.parametrize('batch_shape', __test_constants['batch_shape'])
 @pytest.mark.parametrize('event_shape', __test_constants['event_shape'])
 def test_combination_vector_to_vector(transformer_class: Transformer, batch_shape: Tuple, event_shape: Tuple):
     transformer, x, h = setup_transformer_data(transformer_class, batch_shape, event_shape)
-    assert_valid_reconstruction(transformer, x, h, reconstruction_eps=1e-2)
+    assert_valid_reconstruction(transformer, x, h)
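
For reviewers who want to see the flatten/unflatten pattern from both hunks in isolation, here is a minimal standalone sketch. It is not the library's API: forward_1d below is a hypothetical stand-in (a plain elementwise affine map) for the dense sigmoid transform, and the shapes are chosen only for illustration. The point is that flattening lets the 1D transform assume a fixed (batch_size, event_size) layout regardless of how many batch or event dimensions the caller supplies, and that the per-sample log determinant reshapes back to exactly batch_shape.

    import torch

    # Hypothetical stand-in for the 1D transform: an elementwise affine map
    # z = scale * x, with one log determinant entry per flattened batch element.
    def forward_1d(x_flat, scale):
        z_flat = scale * x_flat
        log_det_flat = torch.log(scale.abs()).sum(dim=1)  # sum over flattened event dim
        return z_flat, log_det_flat

    batch_shape = (5, 4)   # arbitrary, possibly multi-dimensional batch
    event_shape = (2, 3)   # arbitrary, possibly multi-dimensional event

    x = torch.randn(*batch_shape, *event_shape)

    # Flatten batch and event dimensions, as in the patched forward():
    n_event_dims = len(event_shape)
    n_batch_dims = x.dim() - n_event_dims
    batch_size = int(torch.prod(torch.as_tensor(x.shape[:n_batch_dims])))
    event_size = int(torch.prod(torch.as_tensor(event_shape)))

    x_flat = x.view(batch_size, event_size)
    scale = torch.full((batch_size, event_size), 2.0)
    z_flat, log_det_flat = forward_1d(x_flat, scale)

    # Unflatten: z recovers the input shape; log_det has one entry per batch element.
    z = z_flat.view(*batch_shape, *event_shape)
    log_det = log_det_flat.view(*batch_shape)

    assert z.shape == x.shape
    assert log_det.shape == batch_shape

The same round trip applied to inverse_1d is what lets the patched inverse() return x (rather than the pre-fix z) with a correctly shaped log determinant, which is exactly what the tightened reconstruction test above exercises.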