Skip to content

Commit

Permalink
Add multiscale bijections and supporting classes
Browse files Browse the repository at this point in the history
  • Loading branch information
davidnabergoj committed Jun 23, 2024
1 parent 0b883d9 commit 416525e
Show file tree
Hide file tree
Showing 14 changed files with 291 additions and 261 deletions.
2 changes: 2 additions & 0 deletions normalizing_flows/architectures.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,3 +25,5 @@
Radial,
Sylvester
)

from normalizing_flows.bijections.finite.multiscale.architectures import MultiscaleRealNVP
1 change: 1 addition & 0 deletions normalizing_flows/bijections/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ def __init__(self,
self.event_shape = event_shape
self.n_dim = int(torch.prod(torch.as_tensor(event_shape)))
self.context_shape = context_shape
self.transformed_shape = self.event_shape # Overwritten in multiscale flows TODO make into property

def forward(self, x: torch.Tensor, context: torch.Tensor = None) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Expand Down
120 changes: 6 additions & 114 deletions normalizing_flows/bijections/finite/autoregressive/architectures.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,24 +21,6 @@
from normalizing_flows.bijections.finite.linear import ReversePermutation


def make_layers(base_bijection: Type[
    Union[CouplingBijection, MaskedAutoregressiveBijection, InverseMaskedAutoregressiveBijection]],
                event_shape,
                n_layers: int = 2,
                edge_list: List[Tuple[int, int]] = None,
                image_coupling: bool = False):
    """
    Build the list of bijections for a flow, dispatching on the kind of input.

    :param base_bijection: bijection class handed to the selected layer builder.
    :param event_shape: shape of a single event; its number of axes selects the image builder.
    :param n_layers: number of layers, forwarded to the selected builder.
    :param edge_list: forwarded to make_basic_layers; ignored when image_coupling is True.
    :param image_coupling: if True, use the image-specific builders (two axes: single channel,
     three axes: multichannel); otherwise use make_basic_layers.
    :return: the list of bijections produced by the selected builder.
    :raises ValueError: if image_coupling is True and event_shape has neither two nor three axes.
    """
    if not image_coupling:
        return make_basic_layers(base_bijection, event_shape, n_layers, edge_list)

    n_axes = len(event_shape)
    if n_axes == 2:
        return make_image_layers_single_channel(base_bijection, event_shape, n_layers)
    if n_axes == 3:
        return make_image_layers_multichannel(base_bijection, event_shape, n_layers)
    # Same exception type as before, but now the caller can see what went wrong.
    raise ValueError(
        f"Image coupling requires an event shape with two or three axes, but got {n_axes}"
    )


def make_basic_layers(base_bijection: Type[
Union[CouplingBijection, MaskedAutoregressiveBijection, InverseMaskedAutoregressiveBijection]],
event_shape,
Expand All @@ -56,100 +38,15 @@ def make_basic_layers(base_bijection: Type[
return bijections


def make_image_layers_single_channel(base_bijection: Type[
    Union[CouplingBijection, MaskedAutoregressiveBijection, InverseMaskedAutoregressiveBijection]],
                                     event_shape,
                                     n_layers: int = 2,
                                     checkerboard_resolution: int = 2):
    """
    Create the bijections that transform images with a single channel.

    The returned list starts and ends with an ElementwiseAffine bijection. In between, every layer
    contributes two coupling transforms, in this order:
        1. checkerboard,
        2. checkerboard_inverted.

    :param base_bijection: bijection class instantiated for every coupling transform.
    :param event_shape: image shape; must have exactly two axes.
    :param n_layers: number of (checkerboard, checkerboard_inverted) pairs.
    :param checkerboard_resolution: passed to each coupling as the 'resolution' coupling kwarg.
    :return: list of bijections.
    :raises ValueError: if event_shape does not have exactly two axes.
    """
    if len(event_shape) != 2:
        raise ValueError("Single-channel image transformation are only possible for inputs with two axes.")

    mask_types = ('checkerboard', 'checkerboard_inverted')
    layers = [ElementwiseAffine(event_shape=event_shape)]
    for _ in range(n_layers):
        for mask_type in mask_types:
            # A fresh kwargs dict is built for every bijection, as in the unrolled version.
            layers.append(base_bijection(
                event_shape=event_shape,
                coupling_kwargs={
                    'coupling_type': mask_type,
                    'resolution': checkerboard_resolution,
                }
            ))
    layers.append(ElementwiseAffine(event_shape=event_shape))
    return layers


def make_image_layers_multichannel(base_bijection: Type[
    Union[CouplingBijection, MaskedAutoregressiveBijection, InverseMaskedAutoregressiveBijection]],
                                   event_shape,
                                   n_layers: int = 2,
                                   checkerboard_resolution: int = 2):
    """
    Create the bijections that transform images with multiple channels.

    The returned list starts and ends with an ElementwiseAffine bijection. In between, every layer
    contributes four coupling transforms, in this order:
        1. checkerboard,
        2. channel_wise,
        3. checkerboard_inverted,
        4. channel_wise_inverted.

    :param base_bijection: bijection class instantiated for every coupling transform.
    :param event_shape: image shape; must have exactly three axes.
    :param n_layers: number of four-transform groups.
    :param checkerboard_resolution: passed to the checkerboard couplings as the 'resolution'
     coupling kwarg; the channel-wise couplings do not receive it.
    :return: list of bijections.
    :raises ValueError: if event_shape does not have exactly three axes.
    """
    if len(event_shape) != 3:
        raise ValueError("Multichannel image transformation are only possible for inputs with three axes.")

    kwargs_templates = (
        {'coupling_type': 'checkerboard', 'resolution': checkerboard_resolution},
        {'coupling_type': 'channel_wise'},
        {'coupling_type': 'checkerboard_inverted', 'resolution': checkerboard_resolution},
        {'coupling_type': 'channel_wise_inverted'},
    )

    layers = [ElementwiseAffine(event_shape=event_shape)]
    for _ in range(n_layers):
        for template in kwargs_templates:
            # Copy the template so every bijection receives its own kwargs dict.
            layers.append(base_bijection(
                event_shape=event_shape,
                coupling_kwargs=dict(template)
            ))
    layers.append(ElementwiseAffine(event_shape=event_shape))
    return layers


class NICE(BijectiveComposition):
def __init__(self,
event_shape,
n_layers: int = 2,
edge_list: List[Tuple[int, int]] = None,
image_coupling: bool = False,
**kwargs):
if isinstance(event_shape, int):
event_shape = (event_shape,)
bijections = make_layers(ShiftCoupling, event_shape, n_layers, edge_list, image_coupling)
bijections = make_basic_layers(ShiftCoupling, event_shape, n_layers, edge_list)
super().__init__(event_shape, bijections, **kwargs)


Expand All @@ -158,11 +55,10 @@ def __init__(self,
event_shape,
n_layers: int = 2,
edge_list: List[Tuple[int, int]] = None,
image_coupling: bool = False,
**kwargs):
if isinstance(event_shape, int):
event_shape = (event_shape,)
bijections = make_layers(AffineCoupling, event_shape, n_layers, edge_list, image_coupling)
bijections = make_basic_layers(AffineCoupling, event_shape, n_layers, edge_list)
super().__init__(event_shape, bijections, **kwargs)


Expand All @@ -171,11 +67,10 @@ def __init__(self,
event_shape,
n_layers: int = 2,
edge_list: List[Tuple[int, int]] = None,
image_coupling: bool = False,
**kwargs):
if isinstance(event_shape, int):
event_shape = (event_shape,)
bijections = make_layers(InverseAffineCoupling, event_shape, n_layers, edge_list, image_coupling)
bijections = make_basic_layers(InverseAffineCoupling, event_shape, n_layers, edge_list)
super().__init__(event_shape, bijections, **kwargs)


Expand Down Expand Up @@ -204,11 +99,10 @@ def __init__(self,
event_shape,
n_layers: int = 2,
edge_list: List[Tuple[int, int]] = None,
image_coupling: bool = False,
**kwargs):
if isinstance(event_shape, int):
event_shape = (event_shape,)
bijections = make_layers(RQSCoupling, event_shape, n_layers, edge_list, image_coupling)
bijections = make_basic_layers(RQSCoupling, event_shape, n_layers, edge_list)
super().__init__(event_shape, bijections, **kwargs)


Expand All @@ -229,11 +123,10 @@ def __init__(self,
event_shape,
n_layers: int = 2,
edge_list: List[Tuple[int, int]] = None,
image_coupling: bool = False,
**kwargs):
if isinstance(event_shape, int):
event_shape = (event_shape,)
bijections = make_layers(LRSCoupling, event_shape, n_layers, edge_list, image_coupling)
bijections = make_basic_layers(LRSCoupling, event_shape, n_layers, edge_list)
super().__init__(event_shape, bijections, **kwargs)


Expand All @@ -258,11 +151,10 @@ def __init__(self,
event_shape,
n_layers: int = 2,
edge_list: List[Tuple[int, int]] = None,
image_coupling: bool = False,
**kwargs):
if isinstance(event_shape, int):
event_shape = (event_shape,)
bijections = make_layers(DSCoupling, event_shape, n_layers, edge_list, image_coupling)
bijections = make_basic_layers(DSCoupling, event_shape, n_layers, edge_list)
super().__init__(event_shape, bijections, **kwargs)


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -73,51 +73,6 @@ def __init__(self, event_shape):
super().__init__(event_shape, mask=torch.less(torch.arange(event_size).view(*event_shape), event_size // 2))


class Checkerboard(Coupling):
    """
    Checkerboard coupling for image data.
    """

    def __init__(self, event_shape, resolution: int = 2, invert: bool = False):
        """
        :param event_shape: image shape whose last two axes are (height, width). Both must be divisible by
         resolution; they no longer have to be equal.
        :param resolution: resolution of the checkerboard along one axis - the number of squares. Must be a
         positive even number.
        :param invert: invert the checkerboard mask.
        :raises ValueError: if resolution is not a positive even number, or if height or width is not
         divisible by resolution.
        """
        height, width = event_shape[-2:]
        # Explicit exceptions instead of assert, so validation survives `python -O`.
        if resolution <= 0 or resolution % 2 != 0:
            raise ValueError(f"Resolution must be a positive even number, but got {resolution}")
        if height % resolution != 0 or width % resolution != 0:
            raise ValueError(
                f"Image height and width must be divisible by resolution {resolution}, "
                f"but got height {height} and width {width}"
            )

        half_resolution = resolution // 2
        # (resolution, resolution) pattern of alternating ones and zeros; the top-left cell is one.
        pattern = torch.tensor([[1, 0] * half_resolution, [0, 1] * half_resolution] * half_resolution)
        # Each pattern cell is blown up to a tile. Tile sides are computed per axis so that the mask
        # has shape (height, width) for non-square images as well.
        tile = torch.ones((height // resolution, width // resolution))
        mask = torch.kron(pattern, tile).bool()
        if invert:
            mask = ~mask
        super().__init__(event_shape, mask)


class ChannelWiseHalfSplit(Coupling):
    """
    Channel-wise coupling for image data: the mask is constant within each channel.
    """

    def __init__(self, event_shape, invert: bool = False):
        """
        :param event_shape: image shape with the form (n_channels, height, width).
        :param invert: invert the channel-wise mask.
        """
        n_channels, height, width = event_shape
        # True for the first n_channels // 2 channels, False for the remaining ones.
        mask = torch.zeros(n_channels, height, width, dtype=torch.bool)
        mask[:n_channels // 2] = True
        if invert:
            mask = torch.logical_not(mask)
        super().__init__(event_shape, mask)


def make_coupling(event_shape, edge_list: List[Tuple[int, int]] = None, coupling_type: str = 'half_split', **kwargs):
"""
Expand All @@ -129,16 +84,7 @@ def make_coupling(event_shape, edge_list: List[Tuple[int, int]] = None, coupling
"""
if edge_list is not None:
return GraphicalCoupling(event_shape, edge_list)
elif coupling_type == 'half_split':
return HalfSplit(event_shape)
else:
if coupling_type == 'half_split':
return HalfSplit(event_shape)
elif coupling_type == 'checkerboard':
return Checkerboard(event_shape, invert=False, **kwargs)
elif coupling_type == 'checkerboard_inverted':
return Checkerboard(event_shape, invert=True, **kwargs)
elif coupling_type == 'channel_wise':
return ChannelWiseHalfSplit(event_shape, invert=False)
elif coupling_type == 'channel_wise_inverted':
return ChannelWiseHalfSplit(event_shape, invert=True)
else:
raise ValueError
raise ValueError
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Tuple, Union
from typing import Tuple, Union, Type

import torch
import torch.nn as nn
Expand Down
83 changes: 0 additions & 83 deletions normalizing_flows/bijections/finite/linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,89 +14,6 @@
from normalizing_flows.utils import get_batch_shape, flatten_event, unflatten_event, flatten_batch, unflatten_batch


class Squeeze(Bijection):
    """
    Squeeze a batch of tensors with shape (*batch_shape, channels, height, width) into shape
    (*batch_shape, 4 * channels, height / 2, width / 2).

    Every non-overlapping 2x2 spatial square of an input channel is spread over four consecutive output
    channels: output channel 4 * c + 2 * dh + dw at position (r, s) holds input channel c at position
    (2 * r + dh, 2 * s + dw). The map only rearranges elements, so its log determinant is zero.
    """

    def __init__(self, event_shape: Union[torch.Size, Tuple[int, ...]], **kwargs):
        """
        :param event_shape: image shape with the form (channels, height, width); height and width must be
         divisible by two.
        :param kwargs: forwarded to the Bijection constructor.
        :raises ValueError: if event_shape does not have three components or if height or width is odd.
        """
        # Check shape length
        if len(event_shape) != 3:
            raise ValueError(f"Event shape must have three components, but got {len(event_shape)}")
        # Check that height and width are divisible by two
        if event_shape[1] % 2 != 0:
            raise ValueError(f"Event dimension 1 must be divisible by 2, but got {event_shape[1]}")
        if event_shape[2] % 2 != 0:
            raise ValueError(f"Event dimension 2 must be divisible by 2, but got {event_shape[2]}")
        super().__init__(event_shape, **kwargs)
        c, h, w = event_shape
        self.squeezed_event_shape = torch.Size((4 * c, h // 2, w // 2))

    def forward(self, x: torch.Tensor, context: torch.Tensor = None) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Squeeze tensor with shape (*batch_shape, channels, height, width) into tensor with shape
        (*batch_shape, 4 * channels, height // 2, width // 2).

        :param x: tensor whose last three axes are (channels, height, width).
        :param context: unused.
        :return: squeezed tensor and a zero log determinant with shape batch_shape.
        :raises ValueError: if height or width of x is odd.
        """
        batch_shape = get_batch_shape(x, self.event_shape)
        # Passing the shape as a tuple (not unpacked) also covers an empty batch shape.
        log_det = torch.zeros(tuple(batch_shape), device=x.device, dtype=x.dtype)

        leading_shape = x.shape[:-3]
        channels, height, width = x.shape[-3:]
        if height % 2 != 0 or width % 2 != 0:
            raise ValueError(f"Height and width must be divisible by 2, but got {height} and {width}")
        n_rows = height // 2
        n_cols = width // 2

        # Split both spatial axes into (square index, offset within square):
        # (..., C, H, W) -> (..., C, n_rows, 2, n_cols, 2)
        n_lead = len(leading_shape)
        blocks = x.reshape(*leading_shape, channels, n_rows, 2, n_cols, 2)
        # Move the two offsets next to the channel axis: (..., C, 2, 2, n_rows, n_cols).
        blocks = blocks.permute(*range(n_lead), n_lead, n_lead + 2, n_lead + 4, n_lead + 1, n_lead + 3)
        # Merge (C, dh, dw) into a single axis, giving output channel index 4 * c + 2 * dh + dw.
        out = blocks.reshape(*leading_shape, 4 * channels, n_rows, n_cols)

        return out, log_det

    def inverse(self, z: torch.Tensor, context: torch.Tensor = None) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Unsqueeze tensor with shape (*batch_shape, 4 * channels, height // 2, width // 2) into tensor with
        shape (*batch_shape, channels, height, width). This undoes forward.

        :param z: tensor whose last three axes are (4 * channels, height // 2, width // 2).
        :param context: unused.
        :return: unsqueezed tensor and a zero log determinant with shape batch_shape.
        :raises ValueError: if the number of channels of z is not divisible by four.
        """
        batch_shape = get_batch_shape(z, self.squeezed_event_shape)
        log_det = torch.zeros(tuple(batch_shape), device=z.device, dtype=z.dtype)

        leading_shape = z.shape[:-3]
        four_channels, half_height, half_width = z.shape[-3:]
        if four_channels % 4 != 0:
            raise ValueError(f"Number of channels must be divisible by 4, but got {four_channels}")
        channels = four_channels // 4

        # Split the channel axis back into (C, dh, dw): (..., C, 2, 2, n_rows, n_cols)
        n_lead = len(leading_shape)
        blocks = z.reshape(*leading_shape, channels, 2, 2, half_height, half_width)
        # Interleave offsets with square indices again: (..., C, n_rows, 2, n_cols, 2).
        blocks = blocks.permute(*range(n_lead), n_lead, n_lead + 3, n_lead + 1, n_lead + 4, n_lead + 2)
        out = blocks.reshape(*leading_shape, channels, 2 * half_height, 2 * half_width)

        return out, log_det


class LinearBijection(Bijection):
def __init__(self, event_shape: Union[torch.Size, Tuple[int, ...]], matrix: InvertibleMatrix):
super().__init__(event_shape)
Expand Down
Empty file.
Loading

0 comments on commit 416525e

Please sign in to comment.