Skip to content

Commit

Permalink
Add test for identity bijections with maximum regularization
Browse files Browse the repository at this point in the history
  • Loading branch information
davidnabergoj committed Mar 21, 2024
1 parent c22d58b commit 3dab74e
Showing 1 changed file with 67 additions and 0 deletions.
67 changes: 67 additions & 0 deletions test/test_identity_bijections.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
# Check that when all bijection parameters are set to 0, the bijections reduce to an identity map

from normalizing_flows.bijections.finite.autoregressive.layers import (
AffineCoupling,
DSCoupling,
RQSCoupling,
InverseAffineCoupling,
LRSCoupling,
ShiftCoupling,
AffineForwardMaskedAutoregressive,
AffineInverseMaskedAutoregressive,
ElementwiseAffine,
ElementwiseRQSpline,
ElementwiseScale,
ElementwiseShift,
LinearAffineCoupling,
LinearLRSCoupling,
LinearRQSCoupling,
LinearShiftCoupling,
LRSForwardMaskedAutoregressive,
RQSForwardMaskedAutoregressive,
RQSInverseMaskedAutoregressive,
UMNNMaskedAutoregressive,

)
import torch
import pytest


@pytest.mark.parametrize(
'layer_class',
[
AffineCoupling,
DSCoupling,
RQSCoupling,
InverseAffineCoupling,
LRSCoupling,
ShiftCoupling,
AffineForwardMaskedAutoregressive,
AffineInverseMaskedAutoregressive,
ElementwiseAffine,
ElementwiseRQSpline,
ElementwiseScale,
ElementwiseShift,
LinearAffineCoupling,
LinearLRSCoupling,
LinearRQSCoupling,
LinearShiftCoupling,
LRSForwardMaskedAutoregressive,
RQSForwardMaskedAutoregressive,
RQSInverseMaskedAutoregressive,
# UMNNMaskedAutoregressive, # Inexact due to numerics
]
)
def test_basic(layer_class):
n_batch, n_dim = 2, 3

torch.manual_seed(0)
x = torch.randn(size=(n_batch, n_dim))
layer = layer_class(event_shape=torch.Size((n_dim,)))

# Set all conditioner parameters to 0
with torch.no_grad():
for p in layer.parameters():
p.data *= 0

assert torch.allclose(layer(x)[0], x, atol=1e-2)

0 comments on commit 3dab74e

Please sign in to comment.