diff --git a/normalizing_flows/bijections/finite/multiscale/architectures.py b/normalizing_flows/bijections/finite/multiscale/architectures.py
index ccfcef9..bba5d1b 100644
--- a/normalizing_flows/bijections/finite/multiscale/architectures.py
+++ b/normalizing_flows/bijections/finite/multiscale/architectures.py
@@ -123,9 +123,9 @@ def make_image_layers_non_factored(event_shape,
         MultiscaleBijection(
             input_event_shape=bijections[-1].transformed_shape,
             transformer_class=transformer_class,
-            n_checkerboard_layers=4,
+            n_checkerboard_layers=0,
             squeeze_layer=False,
-            n_channel_wise_layers=0,
+            n_channel_wise_layers=2,
             **kwargs
         )
     )
diff --git a/normalizing_flows/bijections/finite/multiscale/base.py b/normalizing_flows/bijections/finite/multiscale/base.py
index 5350e91..cdf03c0 100644
--- a/normalizing_flows/bijections/finite/multiscale/base.py
+++ b/normalizing_flows/bijections/finite/multiscale/base.py
@@ -83,7 +83,7 @@ def __init__(self,
         self.network = ConvNet(
             input_shape=input_event_shape,
             n_outputs=self.n_transformer_parameters,
-            kernels=kernels,
+            kernels=kernels
         )

     def predict_theta_flat(self, x: torch.Tensor, context: torch.Tensor = None) -> torch.Tensor:
@@ -260,8 +260,8 @@ class MultiscaleBijection(BijectiveComposition):
     def __init__(self,
                  input_event_shape,
                  transformer_class: Type[TensorTransformer],
-                 n_checkerboard_layers: int = 3,
-                 n_channel_wise_layers: int = 3,
+                 n_checkerboard_layers: int = 2,
+                 n_channel_wise_layers: int = 2,
                  use_squeeze_layer: bool = True,
                  use_resnet: bool = False,
                  **kwargs):
diff --git a/normalizing_flows/neural_networks/convnet.py b/normalizing_flows/neural_networks/convnet.py
index cff8cf1..3097566 100644
--- a/normalizing_flows/neural_networks/convnet.py
+++ b/normalizing_flows/neural_networks/convnet.py
@@ -1,9 +1,46 @@
+import math
 from typing import Tuple

 import torch
 import torch.nn as nn


+class ConvModifier(nn.Module):
+    """
+    Convolutional layer that transforms an image of shape (c, h, w) into an image of shape (c_target, h_target, w_target); (4, 32, 32) by default.
+    """
+
+    def __init__(self,
+                 image_shape,
+                 c_target: int = 4,
+                 h_target: int = 32,
+                 w_target: int = 32):
+        super().__init__()
+        c, h, w = image_shape
+        if h >= h_target:
+            kernel_height = h - h_target + 1
+            padding_height = 0
+        else:
+            kernel_height = 1 if (h_target - h) % 2 == 0 else 2
+            padding_height = ((h_target - h) + kernel_height - 1) // 2
+        if w >= w_target:
+            kernel_width = w - w_target + 1
+            padding_width = 0
+        else:
+            kernel_width = 1 if (w_target - w) % 2 == 0 else 2
+            padding_width = ((w_target - w) + kernel_width - 1) // 2
+        self.conv = nn.Conv2d(
+            in_channels=c,
+            out_channels=c_target,
+            kernel_size=(kernel_height, kernel_width),
+            padding=(padding_height, padding_width)
+        )
+        self.output_shape = (c_target, h_target, w_target)
+
+    def forward(self, x):
+        return self.conv(x)
+
+
 class ConvNet(nn.Module):
     class ConvNetBlock(nn.Module):
         def __init__(self, in_channels, out_channels, input_height, input_width, use_pooling: bool = True):
@@ -27,20 +64,21 @@ def __init__(self, input_shape, n_outputs: int, kernels: Tuple[int, ...] = None)
         :param n_outputs:
         """
         super().__init__()

-        channels, height, width = input_shape
         if kernels is None:
-            kernels = (64, 64, 32, 4)
+            kernels = (8, 8, 4)
         else:
             assert len(kernels) >= 1

+        reducer = ConvModifier(input_shape)
+
         blocks = [
             self.ConvNetBlock(
-                in_channels=channels,
+                in_channels=reducer.output_shape[0],
                 out_channels=kernels[0],
-                input_height=height,
-                input_width=width,
-                use_pooling=min(height, width) >= 2
+                input_height=reducer.output_shape[1],
+                input_width=reducer.output_shape[2],
+                use_pooling=min(reducer.output_shape[1], reducer.output_shape[2]) >= 2
             )
         ]
         for i in range(len(kernels) - 1):
@@ -53,24 +91,37 @@ def __init__(self, input_shape, n_outputs: int, kernels: Tuple[int, ...] = None)
                     use_pooling=min(blocks[i].output_shape[1], blocks[i].output_shape[2]) >= 2
                 )
             )
-        self.blocks = nn.ModuleList(blocks)
+
+        self.blocks = nn.ModuleList([reducer] + blocks)
+
+        hidden_size_sqrt: int = 4
+        hidden_size = hidden_size_sqrt ** 2
+        self.blocks.append(
+            ConvModifier(
+                image_shape=blocks[-1].output_shape,
+                c_target=1,
+                h_target=hidden_size_sqrt,
+                w_target=hidden_size_sqrt
+            )
+        )
         self.linear = nn.Linear(
-            in_features=int(torch.prod(torch.as_tensor(self.blocks[-1].output_shape))),
+            in_features=hidden_size,
             out_features=n_outputs
         )

     def forward(self, x):
+        batch_shape = x.shape[:-3]
         for block in self.blocks:
             x = block(x)
-        x = x.flatten(start_dim=1, end_dim=-1)
+        x = x.view(*batch_shape, -1)
         x = self.linear(x)
         return x


 if __name__ == '__main__':
     torch.manual_seed(0)
-    image_shape = (1, 36, 29)
-    images = torch.randn(size=(11, *image_shape))
-    net = ConvNet(input_shape=image_shape, n_outputs=77)
+    im_shape = (1, 36, 29)
+    images = torch.randn(size=(11, *im_shape))
+    net = ConvNet(input_shape=im_shape, n_outputs=77)
     out = net(images)
     print(f'{out.shape = }')
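
A quick standalone sanity check (a sketch, not part of the patch): for each spatial dimension, ConvModifier picks a kernel size and padding so that a stride-1 Conv2d maps an input size s to the target size t. Using the standard output-size formula out = s + 2*padding - kernel + 1, the branching logic from the patch can be verified exhaustively over a range of sizes:

def kernel_and_padding(s: int, t: int):
    # Shrinking (s >= t): a wide kernel, no padding.
    if s >= t:
        return s - t + 1, 0
    # Growing (s < t): a 1- or 2-wide kernel plus symmetric padding,
    # with the kernel width chosen so the padding divides evenly.
    kernel = 1 if (t - s) % 2 == 0 else 2
    return kernel, ((t - s) + kernel - 1) // 2

for s in range(1, 129):
    for t in (4, 32):
        kernel, padding = kernel_and_padding(s, t)
        assert s + 2 * padding - kernel + 1 == t  # stride-1 Conv2d output size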