Skip to content

Commit

Permalink
Add ResNet conditioner, change default kernels for convnet conditioner
Browse files Browse the repository at this point in the history
  • Loading branch information
davidnabergoj committed Jul 7, 2024
1 parent da6fd75 commit 5e430ea
Show file tree
Hide file tree
Showing 4 changed files with 81 additions and 30 deletions.
38 changes: 26 additions & 12 deletions normalizing_flows/bijections/finite/multiscale/architectures.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,8 @@ def automatically_determine_n_layers(event_shape):

def make_factored_image_layers(event_shape,
transformer_class,
n_layers: int = None):
n_layers: int = None,
**kwargs):
"""
Creates a list of image transformations consisting of coupling layers and squeeze layers.
After each coupling, squeeze, coupling mapping, half of the channels are kept as is (not transformed anymore).
Expand All @@ -63,7 +64,8 @@ def make_factored_image_layers(event_shape,
def recursive_layer_builder(event_shape_, n_layers_):
msb = MultiscaleBijection(
input_event_shape=event_shape_,
transformer_class=transformer_class
transformer_class=transformer_class,
**kwargs
)
if n_layers_ == 1:
return msb
Expand Down Expand Up @@ -95,7 +97,8 @@ def recursive_layer_builder(event_shape_, n_layers_):

def make_image_layers_non_factored(event_shape,
transformer_class,
n_layers: int = None):
n_layers: int = None,
**kwargs):
"""
Returns a list of bijections for transformations of images with multiple channels.
Expand All @@ -112,7 +115,8 @@ def make_image_layers_non_factored(event_shape,
bijections.append(
MultiscaleBijection(
input_event_shape=bijections[-1].transformed_shape,
transformer_class=transformer_class
transformer_class=transformer_class,
**kwargs
)
)
bijections.append(
Expand All @@ -121,7 +125,8 @@ def make_image_layers_non_factored(event_shape,
transformer_class=transformer_class,
n_checkerboard_layers=4,
squeeze_layer=False,
n_channel_wise_layers=0
n_channel_wise_layers=0,
**kwargs
)
)
bijections.append(ElementwiseAffine(event_shape=bijections[-1].transformed_shape))
Expand All @@ -140,10 +145,11 @@ def __init__(self,
event_shape,
n_layers: int = None,
factored: bool = False,
use_resnet: bool = False,
**kwargs):
if isinstance(event_shape, int):
event_shape = (event_shape,)
bijections = make_image_layers(event_shape, Affine, n_layers, factored=factored)
bijections = make_image_layers(event_shape, Affine, n_layers, factored=factored, use_resnet=use_resnet)
super().__init__(event_shape, bijections, **kwargs)
self.transformed_shape = bijections[-1].transformed_shape

Expand All @@ -153,10 +159,11 @@ def __init__(self,
event_shape,
n_layers: int = None,
factored: bool = False,
use_resnet: bool = False,
**kwargs):
if isinstance(event_shape, int):
event_shape = (event_shape,)
bijections = make_image_layers(event_shape, Shift, n_layers, factored=factored)
bijections = make_image_layers(event_shape, Shift, n_layers, factored=factored, use_resnet=use_resnet)
super().__init__(event_shape, bijections, **kwargs)
self.transformed_shape = bijections[-1].transformed_shape

Expand All @@ -166,10 +173,12 @@ def __init__(self,
event_shape,
n_layers: int = None,
factored: bool = False,
use_resnet: bool = False,
**kwargs):
if isinstance(event_shape, int):
event_shape = (event_shape,)
bijections = make_image_layers(event_shape, RationalQuadratic, n_layers, factored=factored)
bijections = make_image_layers(event_shape, RationalQuadratic, n_layers, factored=factored,
use_resnet=use_resnet)
super().__init__(event_shape, bijections, **kwargs)
self.transformed_shape = bijections[-1].transformed_shape

Expand All @@ -179,10 +188,11 @@ def __init__(self,
event_shape,
n_layers: int = None,
factored: bool = False,
use_resnet: bool = False,
**kwargs):
if isinstance(event_shape, int):
event_shape = (event_shape,)
bijections = make_image_layers(event_shape, LinearRational, n_layers, factored=factored)
bijections = make_image_layers(event_shape, LinearRational, n_layers, factored=factored, use_resnet=use_resnet)
super().__init__(event_shape, bijections, **kwargs)
self.transformed_shape = bijections[-1].transformed_shape

Expand All @@ -192,10 +202,11 @@ def __init__(self,
event_shape,
n_layers: int = None,
factored: bool = False,
use_resnet: bool = False,
**kwargs):
if isinstance(event_shape, int):
event_shape = (event_shape,)
bijections = make_image_layers(event_shape, DeepSigmoid, n_layers, factored=factored)
bijections = make_image_layers(event_shape, DeepSigmoid, n_layers, factored=factored, use_resnet=use_resnet)
super().__init__(event_shape, bijections, **kwargs)
self.transformed_shape = bijections[-1].transformed_shape

Expand All @@ -205,10 +216,12 @@ def __init__(self,
event_shape,
n_layers: int = None,
factored: bool = False,
use_resnet: bool = False,
**kwargs):
if isinstance(event_shape, int):
event_shape = (event_shape,)
bijections = make_image_layers(event_shape, DeepDenseSigmoid, n_layers, factored=factored)
bijections = make_image_layers(event_shape, DeepDenseSigmoid, n_layers, factored=factored,
use_resnet=use_resnet)
super().__init__(event_shape, bijections, **kwargs)
self.transformed_shape = bijections[-1].transformed_shape

Expand All @@ -218,9 +231,10 @@ def __init__(self,
event_shape,
n_layers: int = None,
factored: bool = False,
use_resnet: bool = False,
**kwargs):
if isinstance(event_shape, int):
event_shape = (event_shape,)
bijections = make_image_layers(event_shape, DenseSigmoid, n_layers, factored=factored)
bijections = make_image_layers(event_shape, DenseSigmoid, n_layers, factored=factored, use_resnet=use_resnet)
super().__init__(event_shape, bijections, **kwargs)
self.transformed_shape = bijections[-1].transformed_shape
48 changes: 41 additions & 7 deletions normalizing_flows/bijections/finite/multiscale/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from normalizing_flows.bijections.finite.multiscale.coupling import make_image_coupling, Checkerboard, \
ChannelWiseHalfSplit
from normalizing_flows.neural_networks.convnet import ConvNet
from normalizing_flows.neural_networks.resnet import make_resnet18
from normalizing_flows.utils import get_batch_shape


Expand Down Expand Up @@ -89,16 +90,46 @@ def predict_theta_flat(self, x: torch.Tensor, context: torch.Tensor = None) -> t
return self.network(x)


class ResNetConditioner(ConditionerTransform):
    """Conditioner transform that predicts transformer parameters with a ResNet-18.

    Mirrors ConvNetConditioner, but uses the deeper residual network from
    normalizing_flows.neural_networks.resnet as the parameter-prediction backbone.
    """

    def __init__(self,
                 input_event_shape: torch.Size,
                 parameter_shape: torch.Size,
                 **kwargs):
        """
        :param input_event_shape: shape of the (image-shaped) conditioning input.
        :param parameter_shape: shape of the transformer parameter tensor to predict.
        :param kwargs: forwarded to ConditionerTransform.
        """
        super().__init__(
            input_event_shape=input_event_shape,
            context_shape=None,  # context conditioning is not supported by this conditioner
            parameter_shape=parameter_shape,
            **kwargs
        )
        # n_transformer_parameters is computed by the ConditionerTransform base class.
        self.network = make_resnet18(
            image_shape=input_event_shape,
            n_outputs=self.n_transformer_parameters
        )

    def predict_theta_flat(self, x: torch.Tensor, context: torch.Tensor = None) -> torch.Tensor:
        # Context is ignored (context_shape=None above); the ResNet maps the
        # constant part of the coupling input directly to flat parameters.
        return self.network(x)


class ConvolutionalCouplingBijection(CouplingBijection):
    """Coupling bijection for image data with a convolutional conditioner.

    The conditioner network (plain ConvNet or ResNet-18) maps the constant half
    of the coupling split to the parameters of the given transformer.
    """

    def __init__(self,
                 transformer: TensorTransformer,
                 coupling: Union[Checkerboard, ChannelWiseHalfSplit],
                 conditioner='convnet',
                 **kwargs):
        """
        :param transformer: transformer applied to the non-constant half of the input.
        :param coupling: coupling scheme (checkerboard or channel-wise split).
        :param conditioner: which conditioner network to use: 'convnet' or 'resnet'.
        :param kwargs: forwarded to the conditioner transform and to CouplingBijection.
        :raises ValueError: if `conditioner` is not one of the supported options.
        """
        # NOTE: the stale unconditional ConvNetConditioner assignment has been
        # removed; the conditioner is now selected exclusively by this dispatch.
        if conditioner == 'convnet':
            conditioner_transform = ConvNetConditioner(
                input_event_shape=coupling.constant_shape,
                parameter_shape=transformer.parameter_shape,
                **kwargs
            )
        elif conditioner == 'resnet':
            conditioner_transform = ResNetConditioner(
                input_event_shape=coupling.constant_shape,
                parameter_shape=transformer.parameter_shape,
                **kwargs
            )
        else:
            raise ValueError(f'Unknown conditioner: {conditioner}')
        super().__init__(transformer, coupling, conditioner_transform, **kwargs)
        self.coupling = coupling

Expand Down Expand Up @@ -232,12 +263,14 @@ def __init__(self,
n_checkerboard_layers: int = 3,
n_channel_wise_layers: int = 3,
use_squeeze_layer: bool = True,
use_resnet: bool = False,
**kwargs):
checkerboard_layers = [
CheckerboardCoupling(
input_event_shape,
transformer_class,
alternate=i % 2 == 1
alternate=i % 2 == 1,
conditioner='resnet' if use_resnet else 'convnet'
)
for i in range(n_checkerboard_layers)
]
Expand All @@ -246,7 +279,8 @@ def __init__(self,
ChannelWiseCoupling(
squeeze_layer.transformed_event_shape,
transformer_class,
alternate=i % 2 == 1
alternate=i % 2 == 1,
conditioner='resnet' if use_resnet else 'convnet'
)
for i in range(n_channel_wise_layers)
]
Expand Down
2 changes: 1 addition & 1 deletion normalizing_flows/neural_networks/convnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ def __init__(self, input_shape, n_outputs: int, kernels: Tuple[int, ...] = None)
channels, height, width = input_shape

if kernels is None:
kernels = (4, 8, 16, 24)
kernels = (64, 64, 32, 4)
else:
assert len(kernels) >= 1

Expand Down
23 changes: 13 additions & 10 deletions normalizing_flows/neural_networks/resnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,24 +135,27 @@ def forward(self, x):
return out


def make_resnet18(image_shape, n_outputs):
    """Build a ResNet-18 (BasicBlock, [2, 2, 2, 2]) over images of `image_shape`
    producing `n_outputs` values. The stale duplicate definition using the old
    `out` parameter name has been removed."""
    return ResNet(*image_shape, BasicBlock, num_blocks=[2, 2, 2, 2], n_outputs=n_outputs)


def make_resnet34(image_shape, n_outputs):
    """Build a ResNet-34 (BasicBlock, [3, 4, 6, 3]) over images of `image_shape`
    producing `n_outputs` values. The stale duplicate definition using the old
    `out` parameter name has been removed."""
    return ResNet(*image_shape, BasicBlock, num_blocks=[3, 4, 6, 3], n_outputs=n_outputs)


def make_resnet50(image_shape, n_outputs):
    """Build a ResNet-50 (Bottleneck, [3, 4, 6, 3]) over images of `image_shape`
    producing `n_outputs` values. The stale duplicate definition using the old
    `out` parameter name has been removed."""
    # TODO fix error regarding image shape
    return ResNet(*image_shape, Bottleneck, num_blocks=[3, 4, 6, 3], n_outputs=n_outputs)


def make_resnet101(image_shape, n_outputs):
    """Build a ResNet-101 (Bottleneck, [3, 4, 23, 3]) over images of `image_shape`
    producing `n_outputs` values. The stale duplicate definition using the old
    `out` parameter name has been removed."""
    # TODO fix error regarding image shape
    return ResNet(*image_shape, Bottleneck, num_blocks=[3, 4, 23, 3], n_outputs=n_outputs)


def make_resnet152(image_shape, n_outputs):
    """Build a ResNet-152 (Bottleneck, [3, 8, 36, 3]) over images of `image_shape`
    producing `n_outputs` values. The stale duplicate definition using the old
    `out` parameter name has been removed."""
    # TODO fix error regarding image shape
    return ResNet(*image_shape, Bottleneck, num_blocks=[3, 8, 36, 3], n_outputs=n_outputs)


if __name__ == '__main__':
Expand Down

0 comments on commit 5e430ea

Please sign in to comment.