From e43697921054b33552429d2c9f0bf6ef227c6390 Mon Sep 17 00:00:00 2001 From: David Nabergoj Date: Fri, 23 Aug 2024 14:19:38 +0200 Subject: [PATCH] Update default residual flow hyperparameters --- torchflows/bijections/finite/residual/architectures.py | 6 +++--- torchflows/bijections/finite/residual/proximal.py | 6 +++--- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/torchflows/bijections/finite/residual/architectures.py b/torchflows/bijections/finite/residual/architectures.py index d44194d..acfb1a3 100644 --- a/torchflows/bijections/finite/residual/architectures.py +++ b/torchflows/bijections/finite/residual/architectures.py @@ -17,7 +17,7 @@ class InvertibleResNet(ResidualComposition): Reference: Behrmann et al. "Invertible Residual Networks" (2019); https://arxiv.org/abs/1811.00995. """ - def __init__(self, event_shape, context_shape=None, n_layers: int = 16, **kwargs): + def __init__(self, event_shape, context_shape=None, n_layers: int = 2, **kwargs): blocks = [ InvertibleResNetBlock(event_shape=event_shape, context_shape=context_shape, **kwargs) for _ in range(n_layers) @@ -30,7 +30,7 @@ class ResFlow(ResidualComposition): Reference: Chen et al. "Residual Flows for Invertible Generative Modeling" (2020); https://arxiv.org/abs/1906.02735. """ - def __init__(self, event_shape, context_shape=None, n_layers: int = 16, **kwargs): + def __init__(self, event_shape, context_shape=None, n_layers: int = 2, **kwargs): blocks = [ ResFlowBlock(event_shape=event_shape, context_shape=context_shape, **kwargs) for _ in range(n_layers) @@ -43,7 +43,7 @@ class ProximalResFlow(ResidualComposition): Reference: Hertrich "Proximal Residual Flows for Bayesian Inverse Problems" (2022); https://arxiv.org/abs/2211.17158. """ - def __init__(self, event_shape, context_shape=None, n_layers: int = 16, **kwargs): + def __init__(self, event_shape, context_shape=None, n_layers: int = 2, **kwargs): blocks = [ ProximalResFlowBlock(event_shape=event_shape, context_shape=context_shape, gamma=0.01, **kwargs) for _ in range(n_layers) diff --git a/torchflows/bijections/finite/residual/proximal.py b/torchflows/bijections/finite/residual/proximal.py index 916081a..7b71d44 100644 --- a/torchflows/bijections/finite/residual/proximal.py +++ b/torchflows/bijections/finite/residual/proximal.py @@ -70,8 +70,8 @@ def __init__(self, event_size: int, hidden_size: int, act: ProximityOperator): # Initialize t_tilde close to identity divisor = max(self.event_size ** 2, 100) - self.b = nn.Parameter(torch.randn(self.hidden_size) / divisor) - self.delta_t_tilde = nn.Parameter(torch.randn(self.hidden_size, self.event_size) / divisor) + self.b = nn.Parameter(torch.randn(size=(self.hidden_size,)) / divisor) + self.delta_t_tilde = nn.Parameter(torch.randn(size=(self.hidden_size, self.event_size)) / divisor) self.act = act @property @@ -109,7 +109,7 @@ def __init__(self, event_size: int, n_layers: int = 1, hidden_size: int = None, if act is None: act = TanH() if hidden_size is None: - hidden_size = max(math.log(event_size), 4) + hidden_size = int(max(math.log(event_size), 4)) super().__init__(*[PNNBlock(event_size, hidden_size, act) for _ in range(n_layers)]) self.n_layers = n_layers self.act = act