Skip to content

Commit

Permalink
Initialize flows to near identity
Browse files Browse the repository at this point in the history
  • Loading branch information
davidnabergoj committed Aug 10, 2024
1 parent 03fe60f commit e82627e
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -234,7 +234,7 @@ def __init__(self,
self.sequential = nn.Sequential(*layers)

def predict_theta_flat(self, x: torch.Tensor, context: torch.Tensor = None):
return self.sequential(self.context_combiner(x, context))
return self.sequential(self.context_combiner(x, context)) / 1000.0


class Linear(FeedForward):
Expand All @@ -257,7 +257,7 @@ def __init__(self, event_size: int, hidden_size: int, block_size: int, nonlinear
self.sequential = nn.Sequential(*layers)

def forward(self, x):
return x + self.sequential(x)
return x + self.sequential(x) / 1000.0

def __init__(self,
input_event_shape: torch.Size,
Expand Down Expand Up @@ -289,7 +289,7 @@ def __init__(self,
self.sequential = nn.Sequential(*layers)

def predict_theta_flat(self, x: torch.Tensor, context: torch.Tensor = None):
return self.sequential(self.context_combiner(x, context))
return self.sequential(self.context_combiner(x, context)) / 1000.0


class CombinedConditioner(nn.Module):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,7 @@ def __init__(self, transformer: ScalarTransformer, fill_value: float = None):
)

if fill_value is None:
self.value = nn.Parameter(torch.randn(*transformer.parameter_shape))
self.value = nn.Parameter(torch.randn(*transformer.parameter_shape)) / 1000.0
else:
self.value = nn.Parameter(torch.full(size=transformer.parameter_shape, fill_value=fill_value))

Expand Down

0 comments on commit e82627e

Please sign in to comment.