Skip to content

Commit

Permalink
Add default Planar, Radial, and Sylvester architectures
Browse files Browse the repository at this point in the history
  • Loading branch information
davidnabergoj committed Nov 9, 2023
1 parent 0c05698 commit c57ef12
Show file tree
Hide file tree
Showing 2 changed files with 50 additions and 1 deletion.
9 changes: 8 additions & 1 deletion normalizing_flows/architectures.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,4 +17,11 @@
from normalizing_flows.bijections.continuous.ffjord import FFJORD
from normalizing_flows.bijections.continuous.otflow import OTFlow

from normalizing_flows.bijections.finite.residual.architectures import ResFlow, ProximalResFlow, InvertibleResNet
from normalizing_flows.bijections.finite.residual.architectures import (
ResFlow,
ProximalResFlow,
InvertibleResNet,
Planar,
Radial,
Sylvester
)
42 changes: 42 additions & 0 deletions normalizing_flows/bijections/finite/residual/architectures.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,15 @@
from typing import Union, Tuple

import torch

from normalizing_flows.bijections import Affine
from normalizing_flows.bijections.base import BijectiveComposition
from normalizing_flows.bijections.finite.residual.base import ResidualComposition
from normalizing_flows.bijections.finite.residual.iterative import InvertibleResNetBlock, ResFlowBlock
from normalizing_flows.bijections.finite.residual.proximal import ProximalResFlowBlock
from normalizing_flows.bijections.finite.residual.planar import Planar
from normalizing_flows.bijections.finite.residual.radial import Radial
from normalizing_flows.bijections.finite.residual.sylvester import Sylvester


class InvertibleResNet(ResidualComposition):
Expand All @@ -22,3 +31,36 @@ def __init__(self, event_shape, context_shape=None, n_layers: int = 16, **kwargs
block = ProximalResFlowBlock(event_shape=event_shape, context_shape=context_shape, **kwargs)
blocks = [block for _ in range(n_layers)] # The same block
super().__init__(blocks)


class PlanarFlow(BijectiveComposition):
def __init__(self, event_shape: Union[torch.Size, Tuple[int, ...]], n_layers: int = 2):
if n_layers < 1:
raise ValueError(f"Flow needs at least one layer, but got {n_layers}")
super().__init__(event_shape, [
Affine(event_shape),
*[Planar(event_shape) for _ in range(n_layers)],
Affine(event_shape)
])


class RadialFlow(BijectiveComposition):
def __init__(self, event_shape: Union[torch.Size, Tuple[int, ...]], n_layers: int = 2):
if n_layers < 1:
raise ValueError(f"Flow needs at least one layer, but got {n_layers}")
super().__init__(event_shape, [
Affine(event_shape),
*[Radial(event_shape) for _ in range(n_layers)],
Affine(event_shape)
])


class SylvesterFlow(BijectiveComposition):
def __init__(self, event_shape: Union[torch.Size, Tuple[int, ...]], n_layers: int = 2, **kwargs):
if n_layers < 1:
raise ValueError(f"Flow needs at least one layer, but got {n_layers}")
super().__init__(event_shape, [
Affine(event_shape),
*[Sylvester(event_shape, **kwargs) for _ in range(n_layers)],
Affine(event_shape)
])

0 comments on commit c57ef12

Please sign in to comment.