Skip to content

Commit

Permalink
Fix residual flows when working with complex batch shapes
Browse files Browse the repository at this point in the history
  • Loading branch information
davidnabergoj committed Aug 29, 2024
1 parent cdef039 commit 271e7f3
Show file tree
Hide file tree
Showing 3 changed files with 21 additions and 13 deletions.
11 changes: 5 additions & 6 deletions test/test_autograd_bijections.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,10 @@
from torchflows.bijections.finite.autoregressive.layers import ElementwiseScale, ElementwiseAffine, ElementwiseShift, \
LRSCoupling, LinearRQSCoupling
from torchflows.bijections.finite.linear import LU, ReversePermutation, LowerTriangular, Orthogonal, QR
from torchflows.bijections.finite.residual.architectures import InvertibleResNet, ResFlow, ProximalResFlow
from torchflows.bijections.finite.residual.iterative import InvertibleResNetBlock, ResFlowBlock
from torchflows.bijections.finite.residual.planar import Planar
from torchflows.bijections.finite.residual.proximal import ProximalResFlowBlock
from torchflows.bijections.finite.residual.radial import Radial
from torchflows.bijections.finite.residual.sylvester import Sylvester
from torchflows.utils import get_batch_shape
Expand Down Expand Up @@ -93,13 +95,10 @@ def test_masked_autoregressive(bijection_class: Bijection, batch_shape: Tuple, e
assert_valid_log_probability_gradient(bijection, x, context)


@pytest.mark.skip(reason="Computation takes too long")
@pytest.mark.parametrize('bijection_class', [
InvertibleResNetBlock,
ResFlowBlock,
Planar,
Radial,
Sylvester
InvertibleResNet,
ResFlow,
ProximalResFlow,
])
@pytest.mark.parametrize('batch_shape', __test_constants['batch_shape'])
@pytest.mark.parametrize('event_shape', __test_constants['event_shape'])
Expand Down
22 changes: 15 additions & 7 deletions torchflows/bijections/finite/residual/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

from torchflows.bijections.finite.autoregressive.layers import ElementwiseAffine
from torchflows.bijections.base import Bijection, BijectiveComposition
from torchflows.utils import get_batch_shape, unflatten_event, flatten_event
from torchflows.utils import get_batch_shape, unflatten_event, flatten_event, flatten_batch, unflatten_batch


class ResidualBijection(Bijection):
Expand All @@ -26,14 +26,18 @@ def forward(self,
context: torch.Tensor = None,
skip_log_det: bool = False) -> Tuple[torch.Tensor, torch.Tensor]:
batch_shape = get_batch_shape(x, self.event_shape)
z = x + unflatten_event(self.g(flatten_event(x, self.event_shape)), self.event_shape)
x_flat = flatten_batch(flatten_event(x, self.event_shape), batch_shape)
g_flat = self.g(x_flat)
g = unflatten_event(unflatten_batch(g_flat, batch_shape), self.event_shape)

z = x + g

if skip_log_det:
log_det = torch.full(size=batch_shape, fill_value=torch.nan)
else:
x_flat = flatten_event(x, self.event_shape).clone()
x_flat = flatten_batch(flatten_event(x, self.event_shape).clone(), batch_shape)
x_flat.requires_grad_(True)
log_det = -self.log_det(x_flat, training=self.training)
log_det = -unflatten_batch(self.log_det(x_flat, training=self.training), batch_shape)

return z, log_det

Expand All @@ -45,14 +49,18 @@ def inverse(self,
batch_shape = get_batch_shape(z, self.event_shape)
x = z
for _ in range(n_iterations):
x = z - unflatten_event(self.g(flatten_event(x, self.event_shape)), self.event_shape)
x_flat = flatten_batch(flatten_event(x, self.event_shape), batch_shape)
g_flat = self.g(x_flat)
g = unflatten_event(unflatten_batch(g_flat, batch_shape), self.event_shape)

x = z - g

if skip_log_det:
log_det = torch.full(size=batch_shape, fill_value=torch.nan)
else:
x_flat = flatten_event(x, self.event_shape).clone()
x_flat = flatten_batch(flatten_event(x, self.event_shape).clone(), batch_shape)
x_flat.requires_grad_(True)
log_det = -self.log_det(x_flat, training=self.training)
log_det = -unflatten_batch(self.log_det(x_flat, training=self.training), batch_shape)

return x, log_det

Expand Down
1 change: 1 addition & 0 deletions torchflows/bijections/finite/residual/iterative.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import torch
import torch.nn as nn

from potentials.utils import get_batch_shape
from torchflows.bijections.finite.residual.base import ResidualBijection
from torchflows.bijections.finite.residual.log_abs_det_estimators import log_det_power_series, log_det_roulette

Expand Down

0 comments on commit 271e7f3

Please sign in to comment.