From 271e7f33bba3ddcc2cfe649634ee82bf05bb6e54 Mon Sep 17 00:00:00 2001
From: David Nabergoj <davidnabergoj4@gmail.com>
Date: Thu, 29 Aug 2024 21:03:50 +0200
Subject: [PATCH] Fix residual flows with multi-dimensional batch shapes

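ResidualBijection.forward and .inverse flattened only the event
dimensions before calling self.g and the log-determinant estimators, so
both could receive inputs that still carried an arbitrary batch shape,
which they were not written to handle. Flatten all batch dimensions
into one before these calls and unflatten the outputs afterwards, so
downstream code only ever sees a single batch dimension.

A minimal reproduction (shapes are illustrative; assumes the public
ResFlow constructor accepts an event shape, as in the updated test):

    import torch
    from torchflows.bijections.finite.residual.architectures import ResFlow

    bijection = ResFlow(event_shape=(10,))
    x = torch.randn(2, 3, 10)          # batch shape (2, 3)
    z, log_det = bijection.forward(x)  # previously failed or gave a wrongly shaped log_det
    assert z.shape == x.shape
    assert log_det.shape == (2, 3)

The residual tests are also re-enabled and now exercise the full
InvertibleResNet, ResFlow, and ProximalResFlow architectures rather
than individual blocks.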
---
 test/test_autograd_bijections.py              | 11 +++++-----
 torchflows/bijections/finite/residual/base.py | 22 +++++++++++++------
 .../bijections/finite/residual/iterative.py   |  1 +
 3 files changed, 21 insertions(+), 13 deletions(-)

diff --git a/test/test_autograd_bijections.py b/test/test_autograd_bijections.py
index e609a6e..6b4069f 100644
--- a/test/test_autograd_bijections.py
+++ b/test/test_autograd_bijections.py
@@ -10,8 +10,10 @@
 from torchflows.bijections.finite.autoregressive.layers import ElementwiseScale, ElementwiseAffine, ElementwiseShift, \
     LRSCoupling, LinearRQSCoupling
 from torchflows.bijections.finite.linear import LU, ReversePermutation, LowerTriangular, Orthogonal, QR
+from torchflows.bijections.finite.residual.architectures import InvertibleResNet, ResFlow, ProximalResFlow
 from torchflows.bijections.finite.residual.iterative import InvertibleResNetBlock, ResFlowBlock
 from torchflows.bijections.finite.residual.planar import Planar
+from torchflows.bijections.finite.residual.proximal import ProximalResFlowBlock
 from torchflows.bijections.finite.residual.radial import Radial
 from torchflows.bijections.finite.residual.sylvester import Sylvester
 from torchflows.utils import get_batch_shape
@@ -93,13 +95,10 @@ def test_masked_autoregressive(bijection_class: Bijection, batch_shape: Tuple, e
     assert_valid_log_probability_gradient(bijection, x, context)
 
 
-@pytest.mark.skip(reason="Computation takes too long")
 @pytest.mark.parametrize('bijection_class', [
-    InvertibleResNetBlock,
-    ResFlowBlock,
-    Planar,
-    Radial,
-    Sylvester
+    InvertibleResNet,
+    ResFlow,
+    ProximalResFlow,
 ])
 @pytest.mark.parametrize('batch_shape', __test_constants['batch_shape'])
 @pytest.mark.parametrize('event_shape', __test_constants['event_shape'])
diff --git a/torchflows/bijections/finite/residual/base.py b/torchflows/bijections/finite/residual/base.py
index 2598a62..55b0e61 100644
--- a/torchflows/bijections/finite/residual/base.py
+++ b/torchflows/bijections/finite/residual/base.py
@@ -4,7 +4,7 @@
 
 from torchflows.bijections.finite.autoregressive.layers import ElementwiseAffine
 from torchflows.bijections.base import Bijection, BijectiveComposition
-from torchflows.utils import get_batch_shape, unflatten_event, flatten_event
+from torchflows.utils import get_batch_shape, unflatten_event, flatten_event, flatten_batch, unflatten_batch
 
 
 class ResidualBijection(Bijection):
@@ -26,14 +26,18 @@ def forward(self,
                 context: torch.Tensor = None,
                 skip_log_det: bool = False) -> Tuple[torch.Tensor, torch.Tensor]:
         batch_shape = get_batch_shape(x, self.event_shape)
-        z = x + unflatten_event(self.g(flatten_event(x, self.event_shape)), self.event_shape)
+        x_flat = flatten_batch(flatten_event(x, self.event_shape), batch_shape)
+        g_flat = self.g(x_flat)
+        g = unflatten_event(unflatten_batch(g_flat, batch_shape), self.event_shape)
+
+        z = x + g
 
         if skip_log_det:
             log_det = torch.full(size=batch_shape, fill_value=torch.nan)
         else:
-            x_flat = flatten_event(x, self.event_shape).clone()
+            x_flat = flatten_batch(flatten_event(x, self.event_shape).clone(), batch_shape)
             x_flat.requires_grad_(True)
-            log_det = -self.log_det(x_flat, training=self.training)
+            log_det = -unflatten_batch(self.log_det(x_flat, training=self.training), batch_shape)
 
         return z, log_det
 
@@ -45,14 +49,18 @@ def inverse(self,
         batch_shape = get_batch_shape(z, self.event_shape)
         x = z
         for _ in range(n_iterations):
-            x = z - unflatten_event(self.g(flatten_event(x, self.event_shape)), self.event_shape)
+            x_flat = flatten_batch(flatten_event(x, self.event_shape), batch_shape)
+            g_flat = self.g(x_flat)
+            g = unflatten_event(unflatten_batch(g_flat, batch_shape), self.event_shape)
+
+            x = z - g
 
         if skip_log_det:
             log_det = torch.full(size=batch_shape, fill_value=torch.nan)
         else:
-            x_flat = flatten_event(x, self.event_shape).clone()
+            x_flat = flatten_batch(flatten_event(x, self.event_shape).clone(), batch_shape)
             x_flat.requires_grad_(True)
-            log_det = -self.log_det(x_flat, training=self.training)
+            log_det = -unflatten_batch(self.log_det(x_flat, training=self.training), batch_shape)
 
         return x, log_det
 
diff --git a/torchflows/bijections/finite/residual/iterative.py b/torchflows/bijections/finite/residual/iterative.py
index 9e7bbd2..b294465 100644
--- a/torchflows/bijections/finite/residual/iterative.py
+++ b/torchflows/bijections/finite/residual/iterative.py
@@ -4,6 +4,7 @@
 import torch
 import torch.nn as nn
 
+from torchflows.utils import get_batch_shape
 from torchflows.bijections.finite.residual.base import ResidualBijection
 from torchflows.bijections.finite.residual.log_abs_det_estimators import log_det_power_series, log_det_roulette
 
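
--
Note for reviewers: the fix relies on flatten_batch and unflatten_batch
from torchflows.utils. A minimal sketch of the semantics assumed here
(an illustrative re-implementation, not the library code):

    import torch

    def flatten_batch(x: torch.Tensor, batch_shape: torch.Size) -> torch.Tensor:
        # Collapse all leading batch dimensions into one:
        # (*batch_shape, *rest) -> (prod(batch_shape), *rest)
        return x.reshape(-1, *x.shape[len(batch_shape):])

    def unflatten_batch(x: torch.Tensor, batch_shape: torch.Size) -> torch.Tensor:
        # Inverse of flatten_batch:
        # (prod(batch_shape), *rest) -> (*batch_shape, *rest)
        return x.reshape(*batch_shape, *x.shape[1:])

With these semantics, a log-determinant computed on the flattened input
of shape (prod(batch_shape),) unflattens back to batch_shape, matching
the shapes returned by forward and inverse above.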