Rewrite elementwise bijection to avoid null conditioner
davidnabergoj committed Dec 25, 2023
1 parent 9f22ac8 commit 8ac96f8
Showing 1 changed file with 20 additions and 2 deletions.
22 changes: 20 additions & 2 deletions normalizing_flows/bijections/finite/autoregressive/layers_base.py
@@ -1,6 +1,7 @@
 from typing import Tuple, Optional, Union

 import torch
+import torch.nn as nn

 from normalizing_flows.bijections.finite.autoregressive.conditioners.base import Conditioner, NullConditioner
 from normalizing_flows.bijections.finite.autoregressive.conditioner_transforms import ConditionerTransform, Constant, \
@@ -160,7 +161,24 @@ class ElementwiseBijection(AutoregressiveBijection):
     def __init__(self, transformer: ScalarTransformer, fill_value: float = None):
         super().__init__(
             transformer.event_shape,
-            NullConditioner(),
+            None,
             transformer,
-            Constant(transformer.event_shape, transformer.parameter_shape, fill_value=fill_value)
+            None
         )
+
+        if fill_value is None:
+            self.value = nn.Parameter(torch.randn(*transformer.parameter_shape))
+        else:
+            self.value = nn.Parameter(torch.full(size=transformer.parameter_shape, fill_value=fill_value))
+
+    def prepare_h(self, batch_shape):
+        tmp = self.value[[None] * len(batch_shape)]
+        return tmp.repeat(*batch_shape, *([1] * len(self.transformer.parameter_shape)))
+
+    def forward(self, x: torch.Tensor, context: torch.Tensor = None) -> Tuple[torch.Tensor, torch.Tensor]:
+        h = self.prepare_h(get_batch_shape(x, self.event_shape))
+        return self.transformer.forward(x, h)
+
+    def inverse(self, z: torch.Tensor, context: torch.Tensor = None) -> Tuple[torch.Tensor, torch.Tensor]:
+        h = self.prepare_h(get_batch_shape(z, self.event_shape))
+        return self.transformer.inverse(z, h)
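
The new prepare_h helper broadcasts the learned parameter tensor self.value (of shape transformer.parameter_shape) across the batch dimensions, so the transformer receives one copy of the parameters per batch element. The following standalone sketch illustrates just that broadcasting step in plain PyTorch; the shapes batch_shape and parameter_shape are hypothetical and not taken from the repository, and the unambiguous tuple index (None,) * n is used in place of the equivalent list index from the diff.

import torch

# Hypothetical shapes, for illustration only.
batch_shape = (8, 3)        # two batch dimensions
parameter_shape = (5, 2)    # per-element transformer parameter shape

# Learned parameter tensor, analogous to self.value in ElementwiseBijection.
value = torch.randn(*parameter_shape)

# Add one leading singleton dimension per batch dimension, then tile over the
# batch while leaving the parameter dimensions untouched (mirrors prepare_h).
tmp = value[(None,) * len(batch_shape)]                       # shape (1, 1, 5, 2)
h = tmp.repeat(*batch_shape, *([1] * len(parameter_shape)))   # shape (8, 3, 5, 2)

assert h.shape == (*batch_shape, *parameter_shape)

Compared with routing a Constant conditioner transform through a NullConditioner, storing the parameters directly as an nn.Parameter keeps the elementwise bijection self-contained: fill_value=None gives a randomly initialised parameter, an explicit fill_value gives a fixed starting value, and in both cases the tensor is trainable because it is registered as an nn.Parameter.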
