diff --git a/normalizing_flows/bijections/finite/autoregressive/architectures.py b/normalizing_flows/bijections/finite/autoregressive/architectures.py
index cca7e34..ca31f0a 100644
--- a/normalizing_flows/bijections/finite/autoregressive/architectures.py
+++ b/normalizing_flows/bijections/finite/autoregressive/architectures.py
@@ -180,7 +180,7 @@ def __init__(self, event_shape, n_layers: int = 2, **kwargs):
 
 
 class UMNNMAF(BijectiveComposition):
-    def __init__(self, event_shape, n_layers: int = 2, **kwargs):
+    def __init__(self, event_shape, n_layers: int = 1, **kwargs):
         if isinstance(event_shape, int):
             event_shape = (event_shape,)
         bijections = [ElementwiseAffine(event_shape=event_shape)]
diff --git a/normalizing_flows/bijections/finite/autoregressive/layers.py b/normalizing_flows/bijections/finite/autoregressive/layers.py
index 15b74ff..f5ba2c7 100644
--- a/normalizing_flows/bijections/finite/autoregressive/layers.py
+++ b/normalizing_flows/bijections/finite/autoregressive/layers.py
@@ -223,8 +223,8 @@ class UMNNMaskedAutoregressive(MaskedAutoregressiveBijection):
     def __init__(self,
                  event_shape: torch.Size,
                  context_shape: torch.Size = None,
-                 n_hidden_layers: int = 1,
-                 hidden_dim: int = 5,
+                 n_hidden_layers: int = None,
+                 hidden_dim: int = None,
                  **kwargs):
         transformer: ScalarTransformer = UnconstrainedMonotonicNeuralNetwork(
             event_shape=event_shape,
diff --git a/normalizing_flows/bijections/finite/autoregressive/transformers/integration/unconstrained_monotonic_neural_network.py b/normalizing_flows/bijections/finite/autoregressive/transformers/integration/unconstrained_monotonic_neural_network.py
index 3036dc7..2290662 100644
--- a/normalizing_flows/bijections/finite/autoregressive/transformers/integration/unconstrained_monotonic_neural_network.py
+++ b/normalizing_flows/bijections/finite/autoregressive/transformers/integration/unconstrained_monotonic_neural_network.py
@@ -32,14 +32,19 @@ class UnconstrainedMonotonicNeuralNetwork(UnconstrainedMonotonicTransformer):
     """
     def __init__(self,
                  event_shape: Union[torch.Size, Tuple[int, ...]],
-                 n_hidden_layers: int = 2,
+                 n_hidden_layers: int = None,
                  hidden_dim: int = None):
         super().__init__(event_shape, g=self.neural_network_forward, c=torch.tensor(-100.0))
+
+        if n_hidden_layers is None:
+            n_hidden_layers = 1
         self.n_hidden_layers = n_hidden_layers
+
         if hidden_dim is None:
-            hidden_dim = max(5 * int(math.log(self.n_dim)), 4)
+            hidden_dim = max(int(math.log(self.n_dim)), 4)
         self.hidden_dim = hidden_dim
-        self.const = 1000  # for stability
+
+        self.const = 1  # for stability
 
         # weight, bias have self.hidden_dim elements
         self.n_input_params = 2 * self.hidden_dim
@@ -118,10 +123,8 @@ def neural_network_forward(inputs, parameters: List[torch.Tensor]):
 
     def base_forward_1d(self, x: torch.Tensor, params: List[torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]:
         x_r = x.view(-1, 1, 1)
-        integral_flat = self.integral(x_r, params)
-        log_det_flat = self.g(x_r, params).log()  # We can apply log since g is always positive
-        output = integral_flat.view_as(x)
-        log_det = log_det_flat.view_as(x)
+        output = self.integral(x_r, params).view_as(x)
+        log_det = self.g(x_r, params).log().view_as(x)  # We can apply log since g is always positive
         return output, log_det
 
     def inverse_1d(self, z: torch.Tensor, h: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: