
Remove Conditioner and NullConditioner classes
davidnabergoj committed Dec 25, 2023
1 parent 8ac96f8 · commit 8a8572b
Showing 4 changed files with 2 additions and 39 deletions.

File 1 of 4: this file was deleted.

File 2 of 4: this file was deleted.

File 3 of 4: empty file.

File 4 of 4: modified; diff shown below.
@@ -3,7 +3,6 @@
 import torch
 import torch.nn as nn

-from normalizing_flows.bijections.finite.autoregressive.conditioners.base import Conditioner, NullConditioner
 from normalizing_flows.bijections.finite.autoregressive.conditioner_transforms import ConditionerTransform, Constant, \
     MADE
 from normalizing_flows.bijections.finite.autoregressive.conditioners.coupling_masks import CouplingMask
@@ -15,12 +14,10 @@
 class AutoregressiveBijection(Bijection):
     def __init__(self,
                  event_shape,
-                 conditioner: Optional[Conditioner],
                  transformer: Union[TensorTransformer, ScalarTransformer],
                  conditioner_transform: ConditionerTransform,
                  **kwargs):
         super().__init__(event_shape=event_shape)
-        self.conditioner = conditioner
         self.conditioner_transform = conditioner_transform
         self.transformer = transformer

@@ -58,7 +55,7 @@ def __init__(self,
                  coupling_mask: CouplingMask,
                  conditioner_transform: ConditionerTransform,
                  **kwargs):
-        super().__init__(coupling_mask.event_shape, None, transformer, conditioner_transform, **kwargs)
+        super().__init__(coupling_mask.event_shape, transformer, conditioner_transform, **kwargs)
         self.coupling_mask = coupling_mask

         assert conditioner_transform.input_event_shape == (coupling_mask.constant_event_size,)
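
(Not part of the diff.) The assertion at the end of this hunk requires the conditioner transform's input shape to match the size of the part of the event that the coupling mask holds constant. A minimal standalone illustration of that relationship, using a hypothetical mask convention and shapes that are not taken from the library:

import torch

# Hypothetical coupling setup: True marks the "constant" half of the event,
# which is the only part the conditioner transform gets to see.
event_shape = (6,)
mask = torch.tensor([True, True, True, False, False, False])
constant_event_size = int(mask.sum())  # 3

x = torch.randn(event_shape)
constant_part = x[mask]  # input to the conditioner transform, shape (3,)

# Mirrors the assert above: the conditioner transform's input event shape
# must equal (constant_event_size,).
assert constant_part.shape == (constant_event_size,)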
@@ -113,7 +110,7 @@ def __init__(self,
             context_shape=context_shape,
             **kwargs
         )
-        super().__init__(transformer.event_shape, None, transformer, conditioner_transform)
+        super().__init__(transformer.event_shape, transformer, conditioner_transform)

     def apply_conditioner_transformer(self, inputs, context, forward: bool = True):
         h = self.conditioner_transform(inputs, context)
@@ -161,7 +158,6 @@ class ElementwiseBijection(AutoregressiveBijection):
     def __init__(self, transformer: ScalarTransformer, fill_value: float = None):
         super().__init__(
             transformer.event_shape,
-            None,
             transformer,
             None
         )
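
(Not part of the diff.) Taken together, the hunks above amount to a single API change: AutoregressiveBijection no longer accepts a conditioner, and subclasses stop passing None for that slot. A hedged reconstruction of the post-commit constructor chain, with stub classes standing in for the real normalizing_flows types and a subclass name (CouplingBijection) that is assumed rather than read from the diff:

class Bijection:
    # Stub for the real base class; only event_shape handling is sketched.
    def __init__(self, event_shape):
        self.event_shape = event_shape


class AutoregressiveBijection(Bijection):
    # After this commit there is no `conditioner` parameter and no `self.conditioner`.
    def __init__(self, event_shape, transformer, conditioner_transform, **kwargs):
        super().__init__(event_shape)
        self.conditioner_transform = conditioner_transform
        self.transformer = transformer


class CouplingBijection(AutoregressiveBijection):  # name assumed for illustration
    def __init__(self, transformer, coupling_mask, conditioner_transform, **kwargs):
        # Before this commit: super().__init__(coupling_mask.event_shape, None, transformer, ...)
        super().__init__(coupling_mask.event_shape, transformer, conditioner_transform, **kwargs)
        self.coupling_mask = coupling_mask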
