Skip to content

Commit

Permalink
Use maximum resolution in multiscale coupling
Browse files Browse the repository at this point in the history
  • Loading branch information
davidnabergoj committed Jun 24, 2024
1 parent 72e67fd commit 9c2620b
Showing 1 changed file with 26 additions and 7 deletions.
33 changes: 26 additions & 7 deletions normalizing_flows/bijections/finite/multiscale/base.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
from typing import Type, Union, Tuple

import torch
import torch.nn as nn

from normalizing_flows.bijections import BijectiveComposition, CouplingBijection
from normalizing_flows.bijections.finite.autoregressive.conditioning.transforms import FeedForward
from normalizing_flows.bijections.finite.autoregressive.conditioning.transforms import FeedForward, ResidualFeedForward
from normalizing_flows.bijections.base import Bijection
from normalizing_flows.bijections.finite.autoregressive.transformers.base import TensorTransformer
from normalizing_flows.bijections.finite.multiscale.coupling import make_image_coupling
Expand All @@ -21,9 +22,10 @@ def __init__(self,
coupling_type='checkerboard' if not alternate else 'checkerboard_inverted'
)
transformer = transformer_class(event_shape=torch.Size((coupling.target_event_size,)))
conditioner_transform = FeedForward(
conditioner_transform = ResidualFeedForward(
input_event_shape=torch.Size((coupling.source_event_size,)),
parameter_shape=torch.Size(transformer.parameter_shape),
nonlinearity=nn.Tanh,
**kwargs
)
super().__init__(transformer, coupling, conditioner_transform, **kwargs)
Expand All @@ -40,9 +42,10 @@ def __init__(self,
coupling_type='channel_wise' if not alternate else 'channel_wise_inverted'
)
transformer = transformer_class(event_shape=torch.Size((coupling.target_event_size,)))
conditioner_transform = FeedForward(
conditioner_transform = ResidualFeedForward(
input_event_shape=torch.Size((coupling.source_event_size,)),
parameter_shape=torch.Size(transformer.parameter_shape),
nonlinearity=nn.Tanh,
**kwargs
)
super().__init__(transformer, coupling, conditioner_transform, **kwargs)
Expand Down Expand Up @@ -115,16 +118,32 @@ def __init__(self,
transformer_class: Type[TensorTransformer],
n_checkerboard_layers: int = 3,
n_channel_wise_layers: int = 3,
use_squeeze_layer: bool = True,
**kwargs):
channels, height, width = input_event_shape[-3:]
resolution = min(width, height) // 2
checkerboard_layers = [
CheckerboardCoupling(input_event_shape, transformer_class, alternate=i % 2 == 1)
CheckerboardCoupling(
input_event_shape,
transformer_class,
alternate=i % 2 == 1,
resolution=resolution
)
for i in range(n_checkerboard_layers)
]
squeeze_layer = Squeeze(input_event_shape)
channel_wise_layers = [
ChannelWiseCoupling(squeeze_layer.transformed_event_shape, transformer_class, alternate=i % 2 == 1)
ChannelWiseCoupling(
squeeze_layer.transformed_event_shape,
transformer_class,
alternate=i % 2 == 1,
resolution=resolution
)
for i in range(n_channel_wise_layers)
]
layers = [*checkerboard_layers, squeeze_layer, *channel_wise_layers]
if use_squeeze_layer:
layers = [*checkerboard_layers, squeeze_layer, *channel_wise_layers]
else:
layers = [*checkerboard_layers, *channel_wise_layers]
super().__init__(input_event_shape, layers, **kwargs)
self.transformed_shape = squeeze_layer.transformed_event_shape
self.transformed_shape = squeeze_layer.transformed_event_shape if use_squeeze_layer else input_event_shape

0 comments on commit 9c2620b

Please sign in to comment.