diff --git a/normalizing_flows/bijections/finite/autoregressive/layers_base.py b/normalizing_flows/bijections/finite/autoregressive/layers_base.py index 83fa848..fab59fe 100644 --- a/normalizing_flows/bijections/finite/autoregressive/layers_base.py +++ b/normalizing_flows/bijections/finite/autoregressive/layers_base.py @@ -1,6 +1,7 @@ from typing import Tuple, Optional, Union import torch +import torch.nn as nn from normalizing_flows.bijections.finite.autoregressive.conditioners.base import Conditioner, NullConditioner from normalizing_flows.bijections.finite.autoregressive.conditioner_transforms import ConditionerTransform, Constant, \ @@ -160,7 +161,24 @@ class ElementwiseBijection(AutoregressiveBijection): def __init__(self, transformer: ScalarTransformer, fill_value: float = None): super().__init__( transformer.event_shape, - NullConditioner(), + None, transformer, - Constant(transformer.event_shape, transformer.parameter_shape, fill_value=fill_value) + None ) + + if fill_value is None: + self.value = nn.Parameter(torch.randn(*transformer.parameter_shape)) + else: + self.value = nn.Parameter(torch.full(size=transformer.parameter_shape, fill_value=fill_value)) + + def prepare_h(self, batch_shape): + tmp = self.value[[None] * len(batch_shape)] + return tmp.repeat(*batch_shape, *([1] * len(self.transformer.parameter_shape))) + + def forward(self, x: torch.Tensor, context: torch.Tensor = None) -> Tuple[torch.Tensor, torch.Tensor]: + h = self.prepare_h(get_batch_shape(x, self.event_shape)) + return self.transformer.forward(x, h) + + def inverse(self, z: torch.Tensor, context: torch.Tensor = None) -> Tuple[torch.Tensor, torch.Tensor]: + h = self.prepare_h(get_batch_shape(z, self.event_shape)) + return self.transformer.inverse(z, h)