diff --git a/normalizing_flows/architectures.py b/normalizing_flows/architectures.py index 90c9561..b441edf 100644 --- a/normalizing_flows/architectures.py +++ b/normalizing_flows/architectures.py @@ -26,4 +26,9 @@ Sylvester ) -from normalizing_flows.bijections.finite.multiscale.architectures import MultiscaleRealNVP +from normalizing_flows.bijections.finite.multiscale.architectures import ( + MultiscaleRealNVP, + MultiscaleRQNSF, + MultiscaleLRSNSF, + MultiscaleNICE +) diff --git a/normalizing_flows/bijections/finite/multiscale/architectures.py b/normalizing_flows/bijections/finite/multiscale/architectures.py index 21f7d8f..8f5a49e 100644 --- a/normalizing_flows/bijections/finite/multiscale/architectures.py +++ b/normalizing_flows/bijections/finite/multiscale/architectures.py @@ -1,5 +1,7 @@ from normalizing_flows.bijections.finite.autoregressive.layers import ElementwiseAffine -from normalizing_flows.bijections.finite.autoregressive.transformers.linear.affine import Affine +from normalizing_flows.bijections.finite.autoregressive.transformers.linear.affine import Affine, Shift +from normalizing_flows.bijections.finite.autoregressive.transformers.spline.rational_quadratic import RationalQuadratic +from normalizing_flows.bijections.finite.autoregressive.transformers.spline.linear import Linear as LinearRational from normalizing_flows.bijections import BijectiveComposition from normalizing_flows.bijections.finite.multiscale.base import MultiscaleBijection @@ -46,3 +48,39 @@ def __init__(self, bijections = make_image_layers(event_shape, Affine, n_layers) super().__init__(event_shape, bijections, **kwargs) self.transformed_shape = bijections[-1].transformed_shape + + +class MultiscaleNICE(BijectiveComposition): + def __init__(self, + event_shape, + n_layers: int = 3, + **kwargs): + if isinstance(event_shape, int): + event_shape = (event_shape,) + bijections = make_image_layers(event_shape, Shift, n_layers) + super().__init__(event_shape, bijections, **kwargs) + self.transformed_shape = bijections[-1].transformed_shape + + +class MultiscaleRQNSF(BijectiveComposition): + def __init__(self, + event_shape, + n_layers: int = 3, + **kwargs): + if isinstance(event_shape, int): + event_shape = (event_shape,) + bijections = make_image_layers(event_shape, RationalQuadratic, n_layers) + super().__init__(event_shape, bijections, **kwargs) + self.transformed_shape = bijections[-1].transformed_shape + + +class MultiscaleLRSNSF(BijectiveComposition): + def __init__(self, + event_shape, + n_layers: int = 3, + **kwargs): + if isinstance(event_shape, int): + event_shape = (event_shape,) + bijections = make_image_layers(event_shape, LinearRational, n_layers) + super().__init__(event_shape, bijections, **kwargs) + self.transformed_shape = bijections[-1].transformed_shape