Skip to content

Commit

Permalink
Add more multiscale architectures
Browse files Browse the repository at this point in the history
  • Loading branch information
davidnabergoj committed Jun 27, 2024
1 parent 96122ba commit d9aba8c
Show file tree
Hide file tree
Showing 2 changed files with 45 additions and 2 deletions.
7 changes: 6 additions & 1 deletion normalizing_flows/architectures.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,4 +26,9 @@
Sylvester
)

from normalizing_flows.bijections.finite.multiscale.architectures import MultiscaleRealNVP
from normalizing_flows.bijections.finite.multiscale.architectures import (
MultiscaleRealNVP,
MultiscaleRQNSF,
MultiscaleLRSNSF,
MultiscaleNICE
)
40 changes: 39 additions & 1 deletion normalizing_flows/bijections/finite/multiscale/architectures.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from normalizing_flows.bijections.finite.autoregressive.layers import ElementwiseAffine
from normalizing_flows.bijections.finite.autoregressive.transformers.linear.affine import Affine
from normalizing_flows.bijections.finite.autoregressive.transformers.linear.affine import Affine, Shift
from normalizing_flows.bijections.finite.autoregressive.transformers.spline.rational_quadratic import RationalQuadratic
from normalizing_flows.bijections.finite.autoregressive.transformers.spline.linear import Linear as LinearRational
from normalizing_flows.bijections import BijectiveComposition
from normalizing_flows.bijections.finite.multiscale.base import MultiscaleBijection

Expand Down Expand Up @@ -46,3 +48,39 @@ def __init__(self,
bijections = make_image_layers(event_shape, Affine, n_layers)
super().__init__(event_shape, bijections, **kwargs)
self.transformed_shape = bijections[-1].transformed_shape


class MultiscaleNICE(BijectiveComposition):
def __init__(self,
event_shape,
n_layers: int = 3,
**kwargs):
if isinstance(event_shape, int):
event_shape = (event_shape,)
bijections = make_image_layers(event_shape, Shift, n_layers)
super().__init__(event_shape, bijections, **kwargs)
self.transformed_shape = bijections[-1].transformed_shape


class MultiscaleRQNSF(BijectiveComposition):
def __init__(self,
event_shape,
n_layers: int = 3,
**kwargs):
if isinstance(event_shape, int):
event_shape = (event_shape,)
bijections = make_image_layers(event_shape, RationalQuadratic, n_layers)
super().__init__(event_shape, bijections, **kwargs)
self.transformed_shape = bijections[-1].transformed_shape


class MultiscaleLRSNSF(BijectiveComposition):
def __init__(self,
event_shape,
n_layers: int = 3,
**kwargs):
if isinstance(event_shape, int):
event_shape = (event_shape,)
bijections = make_image_layers(event_shape, LinearRational, n_layers)
super().__init__(event_shape, bijections, **kwargs)
self.transformed_shape = bijections[-1].transformed_shape

0 comments on commit d9aba8c

Please sign in to comment.