From e82627e9b21352ec0a1c9d446fe51af149eb0645 Mon Sep 17 00:00:00 2001 From: David Nabergoj Date: Sat, 10 Aug 2024 21:21:34 +0200 Subject: [PATCH] Initialize flows to near identity --- .../finite/autoregressive/conditioning/transforms.py | 6 +++--- .../bijections/finite/autoregressive/layers_base.py | 2 +- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/normalizing_flows/bijections/finite/autoregressive/conditioning/transforms.py b/normalizing_flows/bijections/finite/autoregressive/conditioning/transforms.py index be50ef5..dd2c3a5 100644 --- a/normalizing_flows/bijections/finite/autoregressive/conditioning/transforms.py +++ b/normalizing_flows/bijections/finite/autoregressive/conditioning/transforms.py @@ -234,7 +234,7 @@ def __init__(self, self.sequential = nn.Sequential(*layers) def predict_theta_flat(self, x: torch.Tensor, context: torch.Tensor = None): - return self.sequential(self.context_combiner(x, context)) + return self.sequential(self.context_combiner(x, context)) / 1000.0 class Linear(FeedForward): @@ -257,7 +257,7 @@ def __init__(self, event_size: int, hidden_size: int, block_size: int, nonlinear self.sequential = nn.Sequential(*layers) def forward(self, x): - return x + self.sequential(x) + return x + self.sequential(x) / 1000.0 def __init__(self, input_event_shape: torch.Size, @@ -289,7 +289,7 @@ def __init__(self, self.sequential = nn.Sequential(*layers) def predict_theta_flat(self, x: torch.Tensor, context: torch.Tensor = None): - return self.sequential(self.context_combiner(x, context)) + return self.sequential(self.context_combiner(x, context)) / 1000.0 class CombinedConditioner(nn.Module): diff --git a/normalizing_flows/bijections/finite/autoregressive/layers_base.py b/normalizing_flows/bijections/finite/autoregressive/layers_base.py index fe20d30..477f6f2 100644 --- a/normalizing_flows/bijections/finite/autoregressive/layers_base.py +++ b/normalizing_flows/bijections/finite/autoregressive/layers_base.py @@ -180,7 +180,7 @@ def 
__init__(self, transformer: ScalarTransformer, fill_value: float = None): ) if fill_value is None: - self.value = nn.Parameter(torch.randn(*transformer.parameter_shape)) + self.value = nn.Parameter(torch.randn(*transformer.parameter_shape) / 1000.0) else: self.value = nn.Parameter(torch.full(size=transformer.parameter_shape, fill_value=fill_value))