From 8a8572bbf257c9fd397b7c8ce6163ca505f4ada5 Mon Sep 17 00:00:00 2001 From: David Nabergoj Date: Mon, 25 Dec 2023 05:17:37 +0100 Subject: [PATCH] Remove Conditioner and NullConditioner classes --- .../autoregressive/conditioners/base.py | 21 ------------------- .../autoregressive/conditioners/graphical.py | 12 ----------- .../autoregressive/conditioners/recurrent.py | 0 .../finite/autoregressive/layers_base.py | 8 ++----- 4 files changed, 2 insertions(+), 39 deletions(-) delete mode 100644 normalizing_flows/bijections/finite/autoregressive/conditioners/base.py delete mode 100644 normalizing_flows/bijections/finite/autoregressive/conditioners/graphical.py delete mode 100644 normalizing_flows/bijections/finite/autoregressive/conditioners/recurrent.py diff --git a/normalizing_flows/bijections/finite/autoregressive/conditioners/base.py b/normalizing_flows/bijections/finite/autoregressive/conditioners/base.py deleted file mode 100644 index 7571640..0000000 --- a/normalizing_flows/bijections/finite/autoregressive/conditioners/base.py +++ /dev/null @@ -1,21 +0,0 @@ -import torch -import torch.nn as nn - -from normalizing_flows.bijections.finite.autoregressive.conditioner_transforms import ConditionerTransform - - -class Conditioner(nn.Module): - def __init__(self): - super().__init__() - - def forward(self, x: torch.Tensor, transform: ConditionerTransform, context: torch.Tensor = None, **kwargs) -> torch.Tensor: - raise NotImplementedError - - -class NullConditioner(Conditioner): - def __init__(self): - # Each dimension affects only itself - super().__init__() - - def forward(self, x: torch.Tensor, transform: ConditionerTransform, context: torch.Tensor = None) -> torch.Tensor: - return transform(x, context=context).to(x) # (*batch_shape, *event_shape, n_parameters) diff --git a/normalizing_flows/bijections/finite/autoregressive/conditioners/graphical.py b/normalizing_flows/bijections/finite/autoregressive/conditioners/graphical.py deleted file mode 100644 index 0fcee46..0000000 --- a/normalizing_flows/bijections/finite/autoregressive/conditioners/graphical.py +++ /dev/null @@ -1,12 +0,0 @@ -import torch - -from normalizing_flows.bijections.finite.autoregressive.conditioners.base import Conditioner -from normalizing_flows.bijections.finite.autoregressive.conditioner_transforms import ConditionerTransform - - -class GraphicalConditioner(Conditioner): - def __init__(self): - super().__init__() - - def forward(self, x: torch.Tensor, transform: ConditionerTransform, context: torch.Tensor = None) -> torch.Tensor: - pass diff --git a/normalizing_flows/bijections/finite/autoregressive/conditioners/recurrent.py b/normalizing_flows/bijections/finite/autoregressive/conditioners/recurrent.py deleted file mode 100644 index e69de29..0000000 diff --git a/normalizing_flows/bijections/finite/autoregressive/layers_base.py b/normalizing_flows/bijections/finite/autoregressive/layers_base.py index fab59fe..afc7a00 100644 --- a/normalizing_flows/bijections/finite/autoregressive/layers_base.py +++ b/normalizing_flows/bijections/finite/autoregressive/layers_base.py @@ -3,7 +3,6 @@ import torch import torch.nn as nn -from normalizing_flows.bijections.finite.autoregressive.conditioners.base import Conditioner, NullConditioner from normalizing_flows.bijections.finite.autoregressive.conditioner_transforms import ConditionerTransform, Constant, \ MADE from normalizing_flows.bijections.finite.autoregressive.conditioners.coupling_masks import CouplingMask @@ -15,12 +14,10 @@ class AutoregressiveBijection(Bijection): def __init__(self, event_shape, - conditioner: Optional[Conditioner], transformer: Union[TensorTransformer, ScalarTransformer], conditioner_transform: ConditionerTransform, **kwargs): super().__init__(event_shape=event_shape) - self.conditioner = conditioner self.conditioner_transform = conditioner_transform self.transformer = transformer @@ -58,7 +55,7 @@ def __init__(self, coupling_mask: CouplingMask, conditioner_transform: ConditionerTransform, **kwargs): - super().__init__(coupling_mask.event_shape, None, transformer, conditioner_transform, **kwargs) + super().__init__(coupling_mask.event_shape, transformer, conditioner_transform, **kwargs) self.coupling_mask = coupling_mask assert conditioner_transform.input_event_shape == (coupling_mask.constant_event_size,) @@ -113,7 +110,7 @@ def __init__(self, context_shape=context_shape, **kwargs ) - super().__init__(transformer.event_shape, None, transformer, conditioner_transform) + super().__init__(transformer.event_shape, transformer, conditioner_transform) def apply_conditioner_transformer(self, inputs, context, forward: bool = True): h = self.conditioner_transform(inputs, context) @@ -161,7 +158,6 @@ class ElementwiseBijection(AutoregressiveBijection): def __init__(self, transformer: ScalarTransformer, fill_value: float = None): super().__init__( transformer.event_shape, - None, transformer, None )