From 271e7f33bba3ddcc2cfe649634ee82bf05bb6e54 Mon Sep 17 00:00:00 2001
From: David Nabergoj <davidnabergoj4@gmail.com>
Date: Thu, 29 Aug 2024 21:03:50 +0200
Subject: [PATCH] Fix residual flows when working with complex batch shapes

---
 test/test_autograd_bijections.py              | 11 +++++-----
 torchflows/bijections/finite/residual/base.py | 22 +++++++++++++------
 .../bijections/finite/residual/iterative.py   |  1 +
 3 files changed, 21 insertions(+), 13 deletions(-)

diff --git a/test/test_autograd_bijections.py b/test/test_autograd_bijections.py
index e609a6e..6b4069f 100644
--- a/test/test_autograd_bijections.py
+++ b/test/test_autograd_bijections.py
@@ -10,8 +10,10 @@
 from torchflows.bijections.finite.autoregressive.layers import ElementwiseScale, ElementwiseAffine, ElementwiseShift, \
     LRSCoupling, LinearRQSCoupling
 from torchflows.bijections.finite.linear import LU, ReversePermutation, LowerTriangular, Orthogonal, QR
+from torchflows.bijections.finite.residual.architectures import InvertibleResNet, ResFlow, ProximalResFlow
 from torchflows.bijections.finite.residual.iterative import InvertibleResNetBlock, ResFlowBlock
 from torchflows.bijections.finite.residual.planar import Planar
+from torchflows.bijections.finite.residual.proximal import ProximalResFlowBlock
 from torchflows.bijections.finite.residual.radial import Radial
 from torchflows.bijections.finite.residual.sylvester import Sylvester
 from torchflows.utils import get_batch_shape
@@ -93,13 +95,10 @@ def test_masked_autoregressive(bijection_class: Bijection, batch_shape: Tuple, e
     assert_valid_log_probability_gradient(bijection, x, context)
 
 
-@pytest.mark.skip(reason="Computation takes too long")
 @pytest.mark.parametrize('bijection_class', [
-    InvertibleResNetBlock,
-    ResFlowBlock,
-    Planar,
-    Radial,
-    Sylvester
+    InvertibleResNet,
+    ResFlow,
+    ProximalResFlow,
 ])
 @pytest.mark.parametrize('batch_shape', __test_constants['batch_shape'])
 @pytest.mark.parametrize('event_shape', __test_constants['event_shape'])
diff --git a/torchflows/bijections/finite/residual/base.py b/torchflows/bijections/finite/residual/base.py
index 2598a62..55b0e61 100644
--- a/torchflows/bijections/finite/residual/base.py
+++ b/torchflows/bijections/finite/residual/base.py
@@ -4,7 +4,7 @@
 
 from torchflows.bijections.finite.autoregressive.layers import ElementwiseAffine
 from torchflows.bijections.base import Bijection, BijectiveComposition
-from torchflows.utils import get_batch_shape, unflatten_event, flatten_event
+from torchflows.utils import get_batch_shape, unflatten_event, flatten_event, flatten_batch, unflatten_batch
 
 
 class ResidualBijection(Bijection):
@@ -26,14 +26,18 @@ def forward(self,
                 context: torch.Tensor = None,
                 skip_log_det: bool = False) -> Tuple[torch.Tensor, torch.Tensor]:
         batch_shape = get_batch_shape(x, self.event_shape)
-        z = x + unflatten_event(self.g(flatten_event(x, self.event_shape)), self.event_shape)
+        x_flat = flatten_batch(flatten_event(x, self.event_shape), batch_shape)
+        g_flat = self.g(x_flat)
+        g = unflatten_event(unflatten_batch(g_flat, batch_shape), self.event_shape)
+
+        z = x + g
 
         if skip_log_det:
             log_det = torch.full(size=batch_shape, fill_value=torch.nan)
         else:
-            x_flat = flatten_event(x, self.event_shape).clone()
+            x_flat = flatten_batch(flatten_event(x, self.event_shape).clone(), batch_shape)
             x_flat.requires_grad_(True)
-            log_det = -self.log_det(x_flat, training=self.training)
+            log_det = -unflatten_batch(self.log_det(x_flat, training=self.training), batch_shape)
 
         return z, log_det
 
@@ -45,14 +49,18 @@ def inverse(self,
         batch_shape = get_batch_shape(z, self.event_shape)
         x = z
         for _ in range(n_iterations):
-            x = z - unflatten_event(self.g(flatten_event(x, self.event_shape)), self.event_shape)
+            x_flat = flatten_batch(flatten_event(x, self.event_shape), batch_shape)
+            g_flat = self.g(x_flat)
+            g = unflatten_event(unflatten_batch(g_flat, batch_shape), self.event_shape)
+
+            x = z - g
 
         if skip_log_det:
             log_det = torch.full(size=batch_shape, fill_value=torch.nan)
         else:
-            x_flat = flatten_event(x, self.event_shape).clone()
+            x_flat = flatten_batch(flatten_event(x, self.event_shape).clone(), batch_shape)
             x_flat.requires_grad_(True)
-            log_det = -self.log_det(x_flat, training=self.training)
+            log_det = -unflatten_batch(self.log_det(x_flat, training=self.training), batch_shape)
 
         return x, log_det
 
diff --git a/torchflows/bijections/finite/residual/iterative.py b/torchflows/bijections/finite/residual/iterative.py
index 9e7bbd2..b294465 100644
--- a/torchflows/bijections/finite/residual/iterative.py
+++ b/torchflows/bijections/finite/residual/iterative.py
@@ -4,6 +4,7 @@
 import torch
 import torch.nn as nn
 
+from torchflows.utils import get_batch_shape
 from torchflows.bijections.finite.residual.base import ResidualBijection
 from torchflows.bijections.finite.residual.log_abs_det_estimators import log_det_power_series, log_det_roulette
 
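
The core of the fix in base.py is a round trip through a single flattened batch
dimension, so that self.g and the log-det estimator only ever see inputs of
shape (n, event_size) regardless of how many batch dimensions the caller uses.
The sketch below illustrates that round trip; it is not part of the patch. It
assumes flatten_batch and unflatten_batch behave as plain reshapes over the
leading batch dimensions (stand-ins for the torchflows.utils helpers), uses a
Linear layer in place of the residual branch self.g, and skips flatten_event
since the event is already one-dimensional here.

    import torch

    def flatten_batch(x: torch.Tensor, batch_shape: torch.Size) -> torch.Tensor:
        # Collapse the leading batch dimensions, e.g. (2, 3, 5) -> (6, 5).
        return x.reshape(-1, *x.shape[len(batch_shape):])

    def unflatten_batch(x: torch.Tensor, batch_shape: torch.Size) -> torch.Tensor:
        # Restore the original batch dimensions, e.g. (6, 5) -> (2, 3, 5).
        return x.reshape(*batch_shape, *x.shape[1:])

    batch_shape, event_shape = torch.Size([2, 3]), torch.Size([5])
    x = torch.randn(*batch_shape, *event_shape)
    g = torch.nn.Linear(5, 5)  # stand-in for the residual branch self.g

    # Before this patch, g received the (2, 3, 5) tensor directly; flattening
    # the batch first means g only ever sees a single batch dimension.
    z = x + unflatten_batch(g(flatten_batch(x, batch_shape)), batch_shape)
    assert z.shape == x.shape  # (2, 3, 5)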