Add convolutional residual flows and refactor residual event shapes
davidnabergoj committed Sep 1, 2024
1 parent 3bc85b0 commit f3b22ca
Showing 9 changed files with 224 additions and 75 deletions.
21 changes: 20 additions & 1 deletion test/test_convolutional_architectures.py
@@ -14,6 +14,7 @@
import torch
import pytest
from test.constants import __test_constants
from torchflows.bijections.finite.residual.architectures import ConvolutionalInvertibleResNet, ConvolutionalResFlow


@pytest.mark.parametrize('architecture_class', [
@@ -52,7 +53,25 @@ def test_continuous(architecture_class, image_shape):
xr, ldi = bijection.inverse(z)
assert x.shape == xr.shape
assert ldf.shape == ldi.shape
assert torch.allclose(x, xr, atol=__test_constants['data_atol']), f'"{(x-xr).abs().max()}"'
assert torch.allclose(x, xr, atol=__test_constants['data_atol']), f'"{(x - xr).abs().max()}"'
assert torch.allclose(ldf, -ldi, atol=__test_constants['log_det_atol']) # 1e-2


@pytest.mark.skip('Unsupported/failing')
@pytest.mark.parametrize('architecture_class', [
ConvolutionalInvertibleResNet,
ConvolutionalResFlow
])
@pytest.mark.parametrize('image_shape', [(1, 28, 28), (3, 28, 28)])
def test_residual(architecture_class, image_shape):
torch.manual_seed(0)
x = torch.randn(size=(5, *image_shape))
bijection = architecture_class(image_shape)
z, ldf = bijection.forward(x)
xr, ldi = bijection.inverse(z)
assert x.shape == xr.shape
assert ldf.shape == ldi.shape
assert torch.allclose(x, xr, atol=__test_constants['data_atol']), f'"{(x - xr).abs().max()}"'
assert torch.allclose(ldf, -ldi, atol=__test_constants['log_det_atol']) # 1e-2


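As a reference point for what these tests assert: forward followed by inverse must reproduce the input, and the two log-determinants must negate each other. A minimal self-contained sketch of the same invariant, using a plain affine bijection rather than any torchflows block:

import torch

# Elementwise affine bijection z = a * x + b; its Jacobian is diag(a),
# so log|det J| = sum(log|a|) for every batch element.
torch.manual_seed(0)
a = torch.tensor([2.0, 0.5, -1.5])
b = torch.tensor([0.1, -0.3, 0.7])

x = torch.randn(5, 3)
z = a * x + b
ldf = torch.log(a.abs()).sum().expand(5)

xr = (z - b) / a
ldi = -torch.log(a.abs()).sum().expand(5)

assert torch.allclose(x, xr, atol=1e-6)
assert torch.allclose(ldf, -ldi)
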
9 changes: 8 additions & 1 deletion test/test_stochastic_log_det_estimation.py
@@ -48,6 +48,7 @@ def test_power_series_estimator(n_iterations, n_hutchinson_samples):
torch.manual_seed(0)
x = torch.randn(size=(n_data, n_dim))
g_value, log_det_f_estimated = log_det_power_series(
(n_dim,),
g,
x,
training=False,
@@ -76,7 +77,13 @@ def test_roulette_estimator(p):

torch.manual_seed(0)
x = torch.randn(size=(n_data, n_dim))
g_value, log_det_f = log_det_roulette(g, x, training=False, p=p)
g_value, log_det_f = log_det_roulette(
(n_dim,),
g,
x,
training=False,
p=p
)
log_det_f_true = test_data.log_det_jac_f(x).ravel()

print(f'{log_det_f = }')
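
Both estimators exercised here target log det(I + J_g) for the residual map z = x + g(x). The power-series estimator truncates the identity log det(I + J) = sum_{k>=1} (-1)^{k+1} tr(J^k) / k at a fixed number of terms; the roulette estimator instead draws the truncation point from a geometric distribution with parameter p and reweights terms so the truncation stays unbiased. A small sketch of the underlying series with exact traces (the library versions estimate tr(J^k) with Hutchinson samples rather than forming J explicitly):

import torch

torch.manual_seed(0)
n = 4
# Keep the spectral norm of J below 1 so the series converges.
J = 0.3 * torch.randn(n, n) / n ** 0.5

exact = torch.logdet(torch.eye(n) + J)

# Truncated power series: sum_{k=1}^{20} (-1)^{k+1} tr(J^k) / k
approx, Jk = 0.0, torch.eye(n)
for k in range(1, 21):
    Jk = Jk @ J
    approx += (-1) ** (k + 1) * torch.trace(Jk) / k

print(f'{exact.item():.6f} vs {approx.item():.6f}')  # near-identical
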
4 changes: 3 additions & 1 deletion torchflows/architectures.py
@@ -33,7 +33,9 @@
InvertibleResNet,
PlanarFlow,
RadialFlow,
SylvesterFlow
SylvesterFlow,
ConvolutionalInvertibleResNet,
ConvolutionalResFlow
)

from torchflows.bijections.finite.multiscale.architectures import (
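
With these exports in place, the convolutional variants are built exactly like the existing architectures. The residual test above is still skipped as unsupported/failing, so the snippet below mirrors the intended call pattern from the tests rather than a verified path:

import torch
from torchflows.architectures import ConvolutionalResFlow

torch.manual_seed(0)
image_shape = (3, 28, 28)  # channels-first, as in the tests
x = torch.randn(5, *image_shape)

bijection = ConvolutionalResFlow(image_shape)
z, log_det_forward = bijection.forward(x)   # z.shape == x.shape
xr, log_det_inverse = bijection.inverse(z)
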
27 changes: 26 additions & 1 deletion torchflows/bijections/finite/residual/architectures.py
@@ -3,7 +3,12 @@
import torch

from torchflows.bijections.finite.residual.base import ResidualArchitecture
from torchflows.bijections.finite.residual.iterative import InvertibleResNetBlock, ResFlowBlock
from torchflows.bijections.finite.residual.iterative import (
InvertibleResNetBlock,
ResFlowBlock,
ConvolutionalInvertibleResNetBlock,
ConvolutionalResFlowBlock
)
from torchflows.bijections.finite.residual.proximal import ProximalResFlowBlock
from torchflows.bijections.finite.residual.planar import Planar
from torchflows.bijections.finite.residual.radial import Radial
@@ -20,6 +25,16 @@ def __init__(self, event_shape, **kwargs):
super().__init__(event_shape, InvertibleResNetBlock, **kwargs)


class ConvolutionalInvertibleResNet(ResidualArchitecture):
"""Convolutional variant of i-ResNet.
Reference: Behrmann et al. "Invertible Residual Networks" (2019); https://arxiv.org/abs/1811.00995.
"""

def __init__(self, event_shape, **kwargs):
super().__init__(event_shape, ConvolutionalInvertibleResNetBlock, **kwargs)


class ResFlow(ResidualArchitecture):
"""Residual flow (ResFlow) architecture.
@@ -30,6 +45,16 @@ def __init__(self, event_shape, **kwargs):
super().__init__(event_shape, ResFlowBlock, **kwargs)


class ConvolutionalResFlow(ResidualArchitecture):
"""Convolutional variant of ResFlow.
    Reference: Chen et al. "Residual Flows for Invertible Generative Modeling" (2019); https://arxiv.org/abs/1906.02735.
"""

def __init__(self, event_shape, **kwargs):
super().__init__(event_shape, ConvolutionalResFlowBlock, **kwargs)


class ProximalResFlow(ResidualArchitecture):
"""Proximal residual flow architecture.
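
Each class in this file is a thin wrapper that forwards a block class to ResidualArchitecture, which stacks n_layers copies of it, so new variants are one-liners. A sketch under that pattern (DeepResFlow is hypothetical, not part of this commit):

from torchflows.bijections.finite.residual.base import ResidualArchitecture
from torchflows.bijections.finite.residual.iterative import ResFlowBlock

class DeepResFlow(ResidualArchitecture):
    """Hypothetical deeper ResFlow reusing the existing block class."""

    def __init__(self, event_shape, **kwargs):
        # n_layers is a ResidualArchitecture argument (default 2).
        super().__init__(event_shape, ResFlowBlock, n_layers=8, **kwargs)
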
26 changes: 12 additions & 14 deletions torchflows/bijections/finite/residual/base.py
@@ -17,14 +17,12 @@ def __init__(self,
self.invert()


class ResidualBijection(Bijection):
def __init__(self, event_shape: Union[torch.Size, Tuple[int, ...]], **kwargs):
"""
g maps from (*batch_shape, n_event_dims) to (*batch_shape, n_event_dims)
class IterativeResidualBijection(Bijection):
"""
g maps from (*batch_shape, *event_shape) to (*batch_shape, *event_shape)
"""

:param event_shape:
"""
def __init__(self, event_shape: Union[torch.Size, Tuple[int, ...]], **kwargs):
super().__init__(event_shape)
self.g: callable = None

@@ -36,16 +34,16 @@ def forward(self,
context: torch.Tensor = None,
skip_log_det: bool = False) -> Tuple[torch.Tensor, torch.Tensor]:
batch_shape = get_batch_shape(x, self.event_shape)
x_flat = flatten_batch(flatten_event(x, self.event_shape), batch_shape)
x_flat = flatten_batch(x, batch_shape)
g_flat = self.g(x_flat)
g = unflatten_event(unflatten_batch(g_flat, batch_shape), self.event_shape)
g = unflatten_batch(g_flat, batch_shape)

z = x + g

if skip_log_det:
log_det = torch.full(size=batch_shape, fill_value=torch.nan)
else:
x_flat = flatten_batch(flatten_event(x, self.event_shape).clone(), batch_shape)
x_flat = flatten_batch(x.clone(), batch_shape)
x_flat.requires_grad_(True)
log_det = -unflatten_batch(self.log_det(x_flat, training=self.training), batch_shape)

@@ -59,16 +57,16 @@ def inverse(self,
batch_shape = get_batch_shape(z, self.event_shape)
x = z
for _ in range(n_iterations):
x_flat = flatten_batch(flatten_event(x, self.event_shape), batch_shape)
x_flat = flatten_batch(x, batch_shape)
g_flat = self.g(x_flat)
g = unflatten_event(unflatten_batch(g_flat, batch_shape), self.event_shape)
g = unflatten_batch(g_flat, batch_shape)

x = z - g

if skip_log_det:
log_det = torch.full(size=batch_shape, fill_value=torch.nan)
else:
x_flat = flatten_batch(flatten_event(x, self.event_shape).clone(), batch_shape)
x_flat = flatten_batch(x.clone(), batch_shape)
x_flat.requires_grad_(True)
log_det = -unflatten_batch(self.log_det(x_flat, training=self.training), batch_shape)

@@ -78,7 +76,7 @@ def inverse(self,
class ResidualArchitecture(BijectiveComposition):
def __init__(self,
event_shape: Union[Tuple[int, ...], torch.Size],
layer_class: Type[Union[ResidualBijection, ClassicResidualBijection]],
layer_class: Type[Union[IterativeResidualBijection, ClassicResidualBijection]],
n_layers: int = 2,
layer_kwargs: dict = None,
**kwargs):
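
The inverse above has no closed form: for z = x + g(x) it runs the fixed-point iteration x <- z - g(x), which converges by the Banach fixed-point theorem whenever g is a contraction (Lipschitz constant below 1), the property the spectral normalization in iterative.py is meant to enforce. The same loop in isolation, with a toy contraction:

import torch

def g(x):
    return 0.5 * torch.tanh(x)  # Lipschitz constant 0.5 < 1

torch.manual_seed(0)
x_true = torch.randn(5, 3)
z = x_true + g(x_true)   # forward pass of the residual bijection

x = z                    # fixed-point iteration for the inverse
for _ in range(25):      # error shrinks by at least half per step
    x = z - g(x)

print(torch.allclose(x, x_true, atol=1e-6))  # True
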
133 changes: 104 additions & 29 deletions torchflows/bijections/finite/residual/iterative.py
@@ -4,19 +4,16 @@
import torch
import torch.nn as nn

from torchflows.bijections.finite.residual.base import ResidualBijection
from torchflows.bijections.finite.residual.base import IterativeResidualBijection
from torchflows.bijections.finite.residual.log_abs_det_estimators import log_det_power_series, log_det_roulette
from torchflows.utils import get_batch_shape


class SpectralLinear(nn.Module):
# https://arxiv.org/pdf/1811.00995.pdf

def __init__(self, n_inputs: int, n_outputs: int, c: float = 0.7, n_iterations: int = 5):
class SpectralMatrix(nn.Module):
def __init__(self, shape: Tuple[int, int], c: float = 0.7, n_iterations: int = 5):
super().__init__()
self.data = torch.randn(size=shape)
self.c = c
self.n_inputs = n_inputs
self.w = torch.randn(n_outputs, n_inputs)
self.bias = nn.Parameter(torch.randn(n_outputs))
self.n_iterations = n_iterations

@torch.no_grad()
@@ -25,7 +22,7 @@ def power_iteration(self, w):
# Spectral Normalization for Generative Adversarial Networks - Miyato et al. - 2018

# Get maximum singular value of rectangular matrix w
u = torch.randn(self.n_inputs, 1)
u = torch.randn(self.data.shape[1], 1)
v = None

w = w.T
@@ -40,50 +37,128 @@ def power_iteration(self, w):
factor = u.T @ w @ v
return factor

@property
def normalized_mat(self):
def normalized(self):
# Estimate sigma
sigma = self.power_iteration(self.w)
sigma = self.power_iteration(self.data)
# ratio = self.c / sigma
# return self.w * (ratio ** (ratio < 1))
return self.w / sigma
return self.data / sigma


class SpectralLinear(nn.Module):
# https://arxiv.org/pdf/1811.00995.pdf

def __init__(self, n_inputs: int, n_outputs: int, **kwargs):
super().__init__()
self.w = SpectralMatrix((n_outputs, n_inputs), **kwargs)
self.bias = nn.Parameter(torch.randn(n_outputs))

def forward(self, x):
return torch.nn.functional.linear(x, self.normalized_mat, self.bias)
return torch.nn.functional.linear(x, self.w.normalized(), self.bias)


class SpectralNeuralNetwork(nn.Sequential):
def __init__(self, n_dim: int, n_hidden: int = None, n_hidden_layers: int = 1, **kwargs):
class SpectralConv2d(nn.Module):
def __init__(self, n_channels: int, kernel_shape: Tuple[int, int] = (3, 3), **kwargs):
super().__init__()
self.n_channels = n_channels
self.kernel_shape = kernel_shape
self.weight = SpectralMatrix((n_channels * kernel_shape[0], n_channels * kernel_shape[1]), **kwargs)
self.bias = nn.Parameter(torch.randn(n_channels))

def forward(self, x):
w = self.weight.normalized().view(self.n_channels, self.n_channels, *self.kernel_shape)
return torch.conv2d(x, w, self.bias, padding='same')


class SpectralNeuralNetwork(nn.Module):
def __init__(self,
event_shape: Union[Tuple[int, ...], torch.Size],
n_hidden: int = None,
n_hidden_layers: int = 1,
**kwargs):
self.event_shape = event_shape
event_size = int(torch.prod(torch.as_tensor(event_shape)))
if n_hidden is None:
n_hidden = int(3 * max(math.log(n_dim), 4))
n_hidden = int(3 * max(math.log(event_size), 4))

layers = []
if n_hidden_layers == 0:
layers = [SpectralLinear(n_dim, n_dim, **kwargs)]
layers = [SpectralLinear(event_size, event_size, **kwargs)]
else:
layers.append(SpectralLinear(n_dim, n_hidden, **kwargs))
layers = [SpectralLinear(event_size, n_hidden, **kwargs)]
for _ in range(n_hidden):
layers.append(nn.Tanh())
layers.append(SpectralLinear(n_hidden, n_hidden, **kwargs))
layers.pop(-1)
layers.append(SpectralLinear(n_hidden, n_dim, **kwargs))
layers.append(SpectralLinear(n_hidden, event_size, **kwargs))
super().__init__()
self.layers = nn.ModuleList(layers)

def forward(self, x):
batch_shape = get_batch_shape(x, self.event_shape)
x_flat = x.view(*batch_shape, -1)
for layer in self.layers:
x_flat = layer(x_flat)
x = x_flat.view_as(x)
return x


class SpectralCNN(nn.Sequential):
def __init__(self, n_channels: int, n_layers: int = 2, **kwargs):
layers = []
for _ in range(n_layers):
layers.append(SpectralConv2d(n_channels, **kwargs))
layers.append(nn.Tanh())
layers.pop(-1)
super().__init__(*layers)
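
A note on SpectralConv2d above: the normalized matrix has shape (n_channels * kh, n_channels * kw) and is viewed as an (out_channels, in_channels, kh, kw) kernel for torch.conv2d with 'same' padding, so the layer preserves the (N, C, H, W) shape. The bookkeeping in isolation; note that normalizing the flattened matrix bounds that matrix's spectral norm, a convenient proxy rather than the exact Lipschitz constant of the convolution:

import torch

n_channels, kh, kw = 3, 3, 3
mat = torch.randn(n_channels * kh, n_channels * kw)
mat = mat / torch.linalg.matrix_norm(mat, ord=2)  # exact spectral normalization

kernel = mat.view(n_channels, n_channels, kh, kw)  # (out, in, kh, kw)
x = torch.randn(5, n_channels, 28, 28)
y = torch.conv2d(x, kernel, padding='same')
print(y.shape)  # torch.Size([5, 3, 28, 28])
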


class InvertibleResNetBlock(ResidualBijection):
def __init__(self, event_shape: Union[torch.Size, Tuple[int, ...]], context_shape=None, **kwargs):
class InvertibleResNetBlock(IterativeResidualBijection):
def __init__(self,
event_shape: Union[torch.Size, Tuple[int, ...]],
context_shape: Union[torch.Size, Tuple[int, ...]] = None,
g: nn.Module = None,
**kwargs):
# TODO add context
super().__init__(event_shape)
self.g = SpectralNeuralNetwork(n_dim=self.n_dim, **kwargs)
if g is None:
g = SpectralNeuralNetwork(event_shape, **kwargs)
self.g = g

def log_det(self, x: torch.Tensor, **kwargs):
return log_det_power_series(self.g, x, n_iterations=2, **kwargs)[1]
return log_det_power_series(self.event_shape, self.g, x, n_iterations=2, **kwargs)[1]


class ResFlowBlock(InvertibleResNetBlock):
def __init__(self, event_shape: Union[torch.Size, Tuple[int, ...]], context_shape=None, p: float = 0.5, **kwargs):
class ConvolutionalInvertibleResNetBlock(InvertibleResNetBlock):
def __init__(self,
event_shape: Union[torch.Size, Tuple[int, ...]],
context_shape: Union[torch.Size, Tuple[int, ...]] = None,
**kwargs):
# TODO add context
super().__init__(event_shape, g=SpectralCNN(n_channels=event_shape[0]), **kwargs)


class ResFlowBlock(IterativeResidualBijection):
def __init__(self,
event_shape: Union[torch.Size, Tuple[int, ...]],
context_shape: Union[torch.Size, Tuple[int, ...]] = None,
g: nn.Module = None,
p: float = 0.5,
**kwargs):
# TODO add context
super().__init__(event_shape)
if g is None:
g = SpectralNeuralNetwork(event_shape, **kwargs)
self.g = g
self.p = p
super().__init__(event_shape, **kwargs)

def log_det(self, x: torch.Tensor, **kwargs):
return log_det_roulette(self.g, x, p=self.p, **kwargs)[1]
return log_det_roulette(self.event_shape, self.g, x, p=self.p, **kwargs)[1]


class ConvolutionalResFlowBlock(ResFlowBlock):
def __init__(self,
event_shape: Union[torch.Size, Tuple[int, ...]],
context_shape: Union[torch.Size, Tuple[int, ...]] = None,
**kwargs):
# TODO add context
super().__init__(event_shape, g=SpectralCNN(n_channels=event_shape[0]), **kwargs)
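
For reference, the power_iteration method refactored above follows the spectral-norm estimator of Miyato et al. (2018): repeatedly push a random vector through the matrix and its transpose, normalizing each time, until the largest singular value can be read off a bilinear form. The library's loop body is elided in this diff, so here is a standard formulation of the same estimate, checked against an exact SVD:

import torch

torch.manual_seed(0)
a = torch.randn(10, 7)

# Power iteration on a^T a converges to the top right-singular vector.
v = torch.randn(7)
for _ in range(50):
    v = a.T @ (a @ v)
    v = v / v.norm()
sigma_est = (a @ v).norm().item()

sigma_true = torch.linalg.svdvals(a)[0].item()
print(f'{sigma_est:.4f} vs {sigma_true:.4f}')  # near-identical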
