Modify default multiscale bijection
davidnabergoj committed Jul 8, 2024
1 parent ee6afea, commit 3d2620b

Showing 3 changed files with 68 additions and 17 deletions.
@@ -123,9 +123,9 @@ def make_image_layers_non_factored(event_shape,
         MultiscaleBijection(
             input_event_shape=bijections[-1].transformed_shape,
             transformer_class=transformer_class,
-            n_checkerboard_layers=4,
+            n_checkerboard_layers=0,
             squeeze_layer=False,
-            n_channel_wise_layers=0,
+            n_channel_wise_layers=2,
             **kwargs
         )
     )
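For context, the two layer types whose counts are swapped here follow the RealNVP convention: checkerboard couplings partition pixels by coordinate parity, while channel-wise couplings split the tensor along the channel axis. A minimal sketch of both partitions (illustrative only, not code from this repository):

import torch

def checkerboard_mask(h: int, w: int) -> torch.Tensor:
    # 1 where (row + col) is even, 0 elsewhere; broadcasts over channels.
    rows = torch.arange(h).unsqueeze(1)
    cols = torch.arange(w).unsqueeze(0)
    return ((rows + cols) % 2 == 0).float()

def channel_wise_split(x: torch.Tensor):
    # The first half of the channels conditions the second half.
    c = x.shape[-3]
    return x[..., :c // 2, :, :], x[..., c // 2:, :, :]

x = torch.randn(4, 32, 32)
mask = checkerboard_mask(32, 32)
x_a, x_b = x * mask, x * (1 - mask)  # checkerboard partition
h1, h2 = channel_wise_split(x)       # channel-wise partition

With this change, the final non-factored layer uses two channel-wise couplings and no checkerboard couplings.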
normalizing_flows/bijections/finite/multiscale/base.py: 6 changes (3 additions, 3 deletions)
@@ -83,7 +83,7 @@ def __init__(self,
         self.network = ConvNet(
             input_shape=input_event_shape,
             n_outputs=self.n_transformer_parameters,
-            kernels=kernels,
+            kernels=kernels
         )

     def predict_theta_flat(self, x: torch.Tensor, context: torch.Tensor = None) -> torch.Tensor:
@@ -260,8 +260,8 @@ class MultiscaleBijection(BijectiveComposition):
     def __init__(self,
                  input_event_shape,
                  transformer_class: Type[TensorTransformer],
-                 n_checkerboard_layers: int = 3,
-                 n_channel_wise_layers: int = 3,
+                 n_checkerboard_layers: int = 2,
+                 n_channel_wise_layers: int = 2,
                  use_squeeze_layer: bool = True,
                  use_resnet: bool = False,
                  **kwargs):
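Both defaults drop from 3 to 2. The use_squeeze_layer flag presumably toggles the standard RealNVP-style squeeze between the two coupling types; a generic sketch of that operation (not this repository's implementation):

import torch

def squeeze(x: torch.Tensor) -> torch.Tensor:
    # (n, c, h, w) -> (n, 4c, h // 2, w // 2): fold each 2x2 spatial patch into channels.
    n, c, h, w = x.shape
    x = x.view(n, c, h // 2, 2, w // 2, 2)
    x = x.permute(0, 1, 3, 5, 2, 4)
    return x.reshape(n, 4 * c, h // 2, w // 2)

x = torch.randn(11, 3, 32, 32)
print(squeeze(x).shape)  # torch.Size([11, 12, 16, 16])

Since squeezing quadruples the channel count, channel-wise couplings remain expressive even for inputs with few channels.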
normalizing_flows/neural_networks/convnet.py: 75 changes (63 additions, 12 deletions)
@@ -1,9 +1,46 @@
 import math
 from typing import Tuple

 import torch
 import torch.nn as nn


+class ConvModifier(nn.Module):
+    """
+    Convolutional layer that transforms an image with size (c, h, w) into an image with size (4, 32, 32).
+    """
+
+    def __init__(self,
+                 image_shape,
+                 c_target: int = 4,
+                 h_target: int = 32,
+                 w_target: int = 32):
+        super().__init__()
+        c, h, w = image_shape
+        if h >= h_target:
+            kernel_height = h - h_target + 1
+            padding_height = 0
+        else:
+            kernel_height = 1 if (h_target - h) % 2 == 0 else 2
+            padding_height = ((h_target - h) + kernel_height - 1) // 2
+        if w >= w_target:
+            kernel_width = w - w_target + 1
+            padding_width = 0
+        else:
+            kernel_width = 1 if (w_target - w) % 2 == 0 else 2
+            padding_width = ((w_target - w) + kernel_width - 1) // 2
+        self.conv = nn.Conv2d(
+            in_channels=c,
+            out_channels=c_target,
+            kernel_size=(kernel_height, kernel_width),
+            padding=(padding_height, padding_width)
+        )
+        self.output_shape = (c_target, h_target, w_target)
+
+    def forward(self, x):
+        return self.conv(x)
+
+
 class ConvNet(nn.Module):
     class ConvNetBlock(nn.Module):
         def __init__(self, in_channels, out_channels, input_height, input_width, use_pooling: bool = True):
@@ -27,20 +64,21 @@ def __init__(self, input_shape, n_outputs: int, kernels: Tuple[int, ...] = None)
         :param n_outputs:
         """
         super().__init__()
-        channels, height, width = input_shape

         if kernels is None:
-            kernels = (64, 64, 32, 4)
+            kernels = (8, 8, 4)
         else:
             assert len(kernels) >= 1

+        reducer = ConvModifier(input_shape)
+
         blocks = [
             self.ConvNetBlock(
-                in_channels=channels,
+                in_channels=reducer.output_shape[0],
                 out_channels=kernels[0],
-                input_height=height,
-                input_width=width,
-                use_pooling=min(height, width) >= 2
+                input_height=reducer.output_shape[1],
+                input_width=reducer.output_shape[2],
+                use_pooling=min(reducer.output_shape[1], reducer.output_shape[2]) >= 2
             )
         ]
         for i in range(len(kernels) - 1):
@@ -53,24 +91,37 @@ def __init__(self, input_shape, n_outputs: int, kernels: Tuple[int, ...] = None)
                     use_pooling=min(blocks[i].output_shape[1], blocks[i].output_shape[2]) >= 2
                 )
             )
-        self.blocks = nn.ModuleList(blocks)
-
+        self.blocks = nn.ModuleList([reducer] + blocks)
+
+        hidden_size_sqrt: int = 4
+        hidden_size = hidden_size_sqrt ** 2
+        self.blocks.append(
+            ConvModifier(
+                image_shape=blocks[-1].output_shape,
+                c_target=1,
+                h_target=hidden_size_sqrt,
+                w_target=hidden_size_sqrt
+            )
+        )
         self.linear = nn.Linear(
-            in_features=int(torch.prod(torch.as_tensor(self.blocks[-1].output_shape))),
+            in_features=hidden_size,
             out_features=n_outputs
         )

     def forward(self, x):
+        batch_shape = x.shape[:-3]
         for block in self.blocks:
             x = block(x)
-        x = x.flatten(start_dim=1, end_dim=-1)
+        x = x.view(*batch_shape, -1)
         x = self.linear(x)
         return x


 if __name__ == '__main__':
     torch.manual_seed(0)
-    image_shape = (1, 36, 29)
-    images = torch.randn(size=(11, *image_shape))
-    net = ConvNet(input_shape=image_shape, n_outputs=77)
+    im_shape = (1, 36, 29)
+    images = torch.randn(size=(11, *im_shape))
+    net = ConvNet(input_shape=im_shape, n_outputs=77)
     out = net(images)
     print(f'{out.shape = }')
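A quick check of ConvModifier's kernel/padding arithmetic on the (1, 36, 29) example above (a sketch; with stride 1, out = in + 2 * padding - kernel + 1):

# Height 36 >= 32: kernel_height = 36 - 32 + 1 = 5, padding_height = 0,
#                  so 36 - 5 + 1 = 32.
# Width 29 < 32: 32 - 29 = 3 is odd, so kernel_width = 2 and
#                padding_width = (3 + 2 - 1) // 2 = 2, so 29 + 4 - 2 + 1 = 32.
modifier = ConvModifier(image_shape=(1, 36, 29))
print(modifier(torch.randn(11, 1, 36, 29)).shape)  # torch.Size([11, 4, 32, 32])

The same layer type is reused at the output end (c_target=1, h_target=w_target=4), so the flattened feature size entering the linear head is always hidden_size = 16 regardless of the input image shape.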
