Skip to content

Commit

Permalink
Add Glow architectures
Browse files Browse the repository at this point in the history
  • Loading branch information
davidnabergoj committed Sep 2, 2024
1 parent f3b22ca commit b2c09c2
Showing 1 changed file with 51 additions and 0 deletions.
51 changes: 51 additions & 0 deletions torchflows/bijections/finite/multiscale/architectures.py
Original file line number Diff line number Diff line change
Expand Up @@ -244,3 +244,54 @@ def __init__(self, event_shape: Union[int, torch.Size, Tuple[int, ...]], n_layer
n_blocks=n_layers,
**kwargs
)


class DenseSigmoidGlow(MultiscaleBijection):
def __init__(self, event_shape: Union[int, torch.Size, Tuple[int, ...]], n_layers: int = None, **kwargs):
if isinstance(event_shape, int):
event_shape = (event_shape,)
if n_layers is None:
n_layers = automatically_determine_n_layers(event_shape)
check_image_shape_for_multiscale_flow(event_shape, n_layers)
super().__init__(
event_shape=event_shape,
transformer_class=DenseSigmoid,
checkerboard_class=GlowCheckerboardCoupling,
channel_wise_class=GlowChannelWiseCoupling,
n_blocks=n_layers,
**kwargs
)


class DeepSigmoidGlow(MultiscaleBijection):
def __init__(self, event_shape: Union[int, torch.Size, Tuple[int, ...]], n_layers: int = None, **kwargs):
if isinstance(event_shape, int):
event_shape = (event_shape,)
if n_layers is None:
n_layers = automatically_determine_n_layers(event_shape)
check_image_shape_for_multiscale_flow(event_shape, n_layers)
super().__init__(
event_shape=event_shape,
transformer_class=DeepSigmoid,
checkerboard_class=GlowCheckerboardCoupling,
channel_wise_class=GlowChannelWiseCoupling,
n_blocks=n_layers,
**kwargs
)


class DeepDenseSigmoidGlow(MultiscaleBijection):
def __init__(self, event_shape: Union[int, torch.Size, Tuple[int, ...]], n_layers: int = None, **kwargs):
if isinstance(event_shape, int):
event_shape = (event_shape,)
if n_layers is None:
n_layers = automatically_determine_n_layers(event_shape)
check_image_shape_for_multiscale_flow(event_shape, n_layers)
super().__init__(
event_shape=event_shape,
transformer_class=DeepDenseSigmoid,
checkerboard_class=GlowCheckerboardCoupling,
channel_wise_class=GlowChannelWiseCoupling,
n_blocks=n_layers,
**kwargs
)

0 comments on commit b2c09c2

Please sign in to comment.