Skip to content

Commit

Permalink
Update default residual flow hyperparameters
Browse files Browse the repository at this point in the history
  • Loading branch information
davidnabergoj committed Aug 23, 2024
1 parent 26166e5 commit e436979
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 6 deletions.
6 changes: 3 additions & 3 deletions torchflows/bijections/finite/residual/architectures.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ class InvertibleResNet(ResidualComposition):
Reference: Behrmann et al. "Invertible Residual Networks" (2019); https://arxiv.org/abs/1811.00995.
"""
- def __init__(self, event_shape, context_shape=None, n_layers: int = 16, **kwargs):
+ def __init__(self, event_shape, context_shape=None, n_layers: int = 2, **kwargs):
blocks = [
InvertibleResNetBlock(event_shape=event_shape, context_shape=context_shape, **kwargs)
for _ in range(n_layers)
Expand All @@ -30,7 +30,7 @@ class ResFlow(ResidualComposition):
Reference: Chen et al. "Residual Flows for Invertible Generative Modeling" (2020); https://arxiv.org/abs/1906.02735.
"""
- def __init__(self, event_shape, context_shape=None, n_layers: int = 16, **kwargs):
+ def __init__(self, event_shape, context_shape=None, n_layers: int = 2, **kwargs):
blocks = [
ResFlowBlock(event_shape=event_shape, context_shape=context_shape, **kwargs)
for _ in range(n_layers)
Expand All @@ -43,7 +43,7 @@ class ProximalResFlow(ResidualComposition):
Reference: Hertrich "Proximal Residual Flows for Bayesian Inverse Problems" (2022); https://arxiv.org/abs/2211.17158.
"""
- def __init__(self, event_shape, context_shape=None, n_layers: int = 16, **kwargs):
+ def __init__(self, event_shape, context_shape=None, n_layers: int = 2, **kwargs):
blocks = [
ProximalResFlowBlock(event_shape=event_shape, context_shape=context_shape, gamma=0.01, **kwargs)
for _ in range(n_layers)
Expand Down
6 changes: 3 additions & 3 deletions torchflows/bijections/finite/residual/proximal.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,8 +70,8 @@ def __init__(self, event_size: int, hidden_size: int, act: ProximityOperator):
# Initialize t_tilde close to identity

divisor = max(self.event_size ** 2, 100)
- self.b = nn.Parameter(torch.randn(self.hidden_size) / divisor)
- self.delta_t_tilde = nn.Parameter(torch.randn(self.hidden_size, self.event_size) / divisor)
+ self.b = nn.Parameter(torch.randn(size=(self.hidden_size,)) / divisor)
+ self.delta_t_tilde = nn.Parameter(torch.randn(size=(self.hidden_size, self.event_size)) / divisor)
self.act = act

@property
Expand Down Expand Up @@ -109,7 +109,7 @@ def __init__(self, event_size: int, n_layers: int = 1, hidden_size: int = None,
if act is None:
act = TanH()
if hidden_size is None:
- hidden_size = max(math.log(event_size), 4)
+ hidden_size = int(max(math.log(event_size), 4))
super().__init__(*[PNNBlock(event_size, hidden_size, act) for _ in range(n_layers)])
self.n_layers = n_layers
self.act = act
Expand Down

0 comments on commit e436979

Please sign in to comment.