diff --git a/normalizing_flows/bijections/finite/residual/iterative.py b/normalizing_flows/bijections/finite/residual/iterative.py index c778350..36009c2 100644 --- a/normalizing_flows/bijections/finite/residual/iterative.py +++ b/normalizing_flows/bijections/finite/residual/iterative.py @@ -1,3 +1,4 @@ +import math from typing import Union, Tuple import torch @@ -52,7 +53,10 @@ def forward(self, x): class SpectralNeuralNetwork(nn.Sequential): - def __init__(self, n_dim: int, n_hidden: int = 100, n_hidden_layers: int = 2, **kwargs): + def __init__(self, n_dim: int, n_hidden: int = None, n_hidden_layers: int = 2, **kwargs): + if n_hidden is None: + n_hidden = int(max(math.log(n_dim), 4)) + layers = [] if n_hidden_layers == 0: layers = [SpectralLinear(n_dim, n_dim, **kwargs)]