Commit
Merge pull request #5 from davidnabergoj/dev
Transformer refactor
davidnabergoj authored Dec 2, 2023
2 parents 845f1f4 + 41ba047 commit e71c6f9
Showing 29 changed files with 788 additions and 499 deletions.
normalizing_flows/bijections/base.py (2 changes: 1 addition & 1 deletion)
@@ -81,7 +81,7 @@ def batch_inverse(self, x: torch.Tensor, batch_size: int, context: torch.Tensor
         return outputs, log_dets
 
 
-def invert(bijection: Bijection) -> Bijection:
+def invert(bijection):
     """
     Swap the forward and inverse methods of the input bijection.
     """
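
For context, the docstring above says that invert swaps a bijection's forward and inverse passes. A minimal standalone sketch of that pattern follows; the attribute swap and the _DummyBijection class are illustrative assumptions, not the library's actual Bijection implementation.

# Sketch only: swap a bijection's forward and inverse passes, as the
# docstring above describes. _DummyBijection is a stand-in for illustration.

class _DummyBijection:
    def forward(self, x):
        return x + 1.0

    def inverse(self, y):
        return y - 1.0


def invert(bijection):
    """Swap the forward and inverse methods of the input bijection."""
    bijection.forward, bijection.inverse = bijection.inverse, bijection.forward
    return bijection


b = invert(_DummyBijection())
assert b.forward(1.0) == 0.0  # forward now applies the former inverse
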
normalizing_flows/bijections/continuous/base.py (8 changes: 7 additions & 1 deletion)
@@ -2,7 +2,6 @@

 import torch
 import torch.nn as nn
-from torchdiffeq import odeint
 
 from normalizing_flows.bijections.base import Bijection
 from normalizing_flows.bijections.continuous.layers import DiffEqLayer
@@ -306,6 +305,10 @@ def inverse(self,
         :param kwargs:
         :return:
         """
+
+        # Import from torchdiffeq locally, so the package does not break if torchdiffeq not installed
+        from torchdiffeq import odeint
+
         # Flatten everything to facilitate computations
         batch_shape = get_batch_shape(z, self.event_shape)
         batch_size = int(torch.prod(torch.as_tensor(batch_shape)))
@@ -399,6 +402,9 @@ def inverse(self,
         :param kwargs:
         :return:
         """
+        # Import from torchdiffeq locally, so the package does not break if torchdiffeq not installed
+        from torchdiffeq import odeint
+
         # Flatten everything to facilitate computations
         batch_shape = get_batch_shape(z, self.event_shape)
         batch_size = int(torch.prod(torch.as_tensor(batch_shape)))
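
Both hunks above apply the same fix: the module-level import of odeint from torchdiffeq is removed and the import is deferred into the methods that actually solve the ODE, so importing the package no longer fails when torchdiffeq is not installed. A minimal sketch of this deferred-import pattern follows; solve_ode and its error message are illustrative, not the library's code.

# Sketch of the deferred-import pattern: the optional torchdiffeq dependency
# is only resolved when an ODE solve is actually requested.

def solve_ode(func, z0, t):
    try:
        from torchdiffeq import odeint  # optional dependency, imported lazily
    except ImportError as e:
        raise ImportError(
            "torchdiffeq is required for continuous bijections; "
            "install it with pip install torchdiffeq."
        ) from e
    return odeint(func, z0, t)
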
@@ -11,7 +11,7 @@
     ElementwiseAffine,
     UMNNMaskedAutoregressive,
     LRSCoupling,
-    LRSForwardMaskedAutoregressive
+    LRSForwardMaskedAutoregressive, ElementwiseShift
 )
 from normalizing_flows.bijections.base import BijectiveComposition
 from normalizing_flows.bijections.finite.linear import ReversePermutation
@@ -127,27 +127,27 @@ class CouplingLRS(BijectiveComposition):
     def __init__(self, event_shape, n_layers: int = 2, **kwargs):
         if isinstance(event_shape, int):
             event_shape = (event_shape,)
-        bijections = [ElementwiseAffine(event_shape=event_shape)]
+        bijections = [ElementwiseShift(event_shape=event_shape)]
         for _ in range(n_layers):
             bijections.extend([
                 ReversePermutation(event_shape=event_shape),
                 LRSCoupling(event_shape=event_shape)
             ])
-        bijections.append(ElementwiseAffine(event_shape=event_shape))
+        bijections.append(ElementwiseShift(event_shape=event_shape))
         super().__init__(event_shape, bijections, **kwargs)
 
 
 class MaskedAutoregressiveLRS(BijectiveComposition):
     def __init__(self, event_shape, n_layers: int = 2, **kwargs):
         if isinstance(event_shape, int):
             event_shape = (event_shape,)
-        bijections = [ElementwiseAffine(event_shape=event_shape)]
+        bijections = [ElementwiseShift(event_shape=event_shape)]
         for _ in range(n_layers):
             bijections.extend([
                 ReversePermutation(event_shape=event_shape),
                 LRSForwardMaskedAutoregressive(event_shape=event_shape)
             ])
-        bijections.append(ElementwiseAffine(event_shape=event_shape))
+        bijections.append(ElementwiseShift(event_shape=event_shape))
         super().__init__(event_shape, bijections, **kwargs)
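
The change above swaps the outer ElementwiseAffine layers of CouplingLRS and MaskedAutoregressiveLRS for ElementwiseShift layers; the permutation and LRS layers in between are untouched. For intuition, a standalone sketch of the two kinds of elementwise transformer follows. The library's actual implementations are not shown in this diff, so the classes below are illustrative: an affine layer scales and shifts every dimension, while a shift layer only translates, making it volume-preserving with a zero log-determinant.

import torch
import torch.nn as nn

# Illustrative sketches, not the library's ElementwiseAffine / ElementwiseShift.
# Affine: y = exp(log_scale) * x + shift, with log|det J| = sum(log_scale).
# Shift:  y = x + shift, with log|det J| = 0.

class ElementwiseAffineSketch(nn.Module):
    def __init__(self, event_shape):
        super().__init__()
        self.log_scale = nn.Parameter(torch.zeros(event_shape))
        self.shift = nn.Parameter(torch.zeros(event_shape))

    def forward(self, x):
        # x has shape (batch_size, *event_shape)
        y = x * torch.exp(self.log_scale) + self.shift
        log_det = self.log_scale.sum() * torch.ones(x.shape[0])
        return y, log_det


class ElementwiseShiftSketch(nn.Module):
    def __init__(self, event_shape):
        super().__init__()
        self.shift = nn.Parameter(torch.zeros(event_shape))

    def forward(self, x):
        y = x + self.shift
        log_det = torch.zeros(x.shape[0])  # identity Jacobian: volume-preserving
        return y, log_det

Given the constructor signature shown in the diff, the refactored architectures are still built the same way, for example CouplingLRS(event_shape=(10,)) or MaskedAutoregressiveLRS(event_shape=(10,), n_layers=3).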


@@ -173,7 +173,7 @@ def __init__(self, event_shape, n_layers: int = 2, **kwargs):
         for _ in range(n_layers):
             bijections.extend([
                 ReversePermutation(event_shape=event_shape),
-                DSCoupling(event_shape=event_shape)
+                DSCoupling(event_shape=event_shape)  # TODO specify percent of global parameters
             ])
         bijections.append(ElementwiseAffine(event_shape=event_shape))
         super().__init__(event_shape, bijections, **kwargs)