From 3dab74e7faafcf6d617c0b1b7a2838f35ed82702 Mon Sep 17 00:00:00 2001
From: David Nabergoj
Date: Fri, 22 Mar 2024 00:41:24 +0100
Subject: [PATCH] Add test for identity bijections with maximum regularization

---
 test/test_identity_bijections.py | 67 ++++++++++++++++++++++++++++++++
 1 file changed, 67 insertions(+)
 create mode 100644 test/test_identity_bijections.py

diff --git a/test/test_identity_bijections.py b/test/test_identity_bijections.py
new file mode 100644
index 0000000..fa88e93
--- /dev/null
+++ b/test/test_identity_bijections.py
@@ -0,0 +1,67 @@
+# Check that when all bijection parameters are set to 0, the bijections reduce to an identity map
+
+from normalizing_flows.bijections.finite.autoregressive.layers import (
+    AffineCoupling,
+    DSCoupling,
+    RQSCoupling,
+    InverseAffineCoupling,
+    LRSCoupling,
+    ShiftCoupling,
+    AffineForwardMaskedAutoregressive,
+    AffineInverseMaskedAutoregressive,
+    ElementwiseAffine,
+    ElementwiseRQSpline,
+    ElementwiseScale,
+    ElementwiseShift,
+    LinearAffineCoupling,
+    LinearLRSCoupling,
+    LinearRQSCoupling,
+    LinearShiftCoupling,
+    LRSForwardMaskedAutoregressive,
+    RQSForwardMaskedAutoregressive,
+    RQSInverseMaskedAutoregressive,
+    UMNNMaskedAutoregressive,
+
+)
+import torch
+import pytest
+
+
+@pytest.mark.parametrize(
+    'layer_class',
+    [
+        AffineCoupling,
+        DSCoupling,
+        RQSCoupling,
+        InverseAffineCoupling,
+        LRSCoupling,
+        ShiftCoupling,
+        AffineForwardMaskedAutoregressive,
+        AffineInverseMaskedAutoregressive,
+        ElementwiseAffine,
+        ElementwiseRQSpline,
+        ElementwiseScale,
+        ElementwiseShift,
+        LinearAffineCoupling,
+        LinearLRSCoupling,
+        LinearRQSCoupling,
+        LinearShiftCoupling,
+        LRSForwardMaskedAutoregressive,
+        RQSForwardMaskedAutoregressive,
+        RQSInverseMaskedAutoregressive,
+        # UMNNMaskedAutoregressive,  # Inexact due to numerics
+    ]
+)
+def test_basic(layer_class):
+    n_batch, n_dim = 2, 3
+
+    torch.manual_seed(0)
+    x = torch.randn(size=(n_batch, n_dim))
+    layer = layer_class(event_shape=torch.Size((n_dim,)))
+
+    # Set all conditioner parameters to 0
+    with torch.no_grad():
+        for p in layer.parameters():
+            p.data *= 0
+
+    assert torch.allclose(layer(x)[0], x, atol=1e-2)
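
Note: the test above only exercises the forward direction. A minimal sketch of the complementary checks follows, assuming the layer call returns an (output, log_det) pair (as the `layer(x)[0]` indexing above suggests) and that the bijection exposes an `inverse` method with the same return convention; that method name is an assumption, not confirmed by this patch.

    import torch
    from normalizing_flows.bijections.finite.autoregressive.layers import AffineCoupling

    torch.manual_seed(0)
    n_batch, n_dim = 2, 3
    x = torch.randn(size=(n_batch, n_dim))
    layer = AffineCoupling(event_shape=torch.Size((n_dim,)))

    # Zero all conditioner parameters, as in the test above
    with torch.no_grad():
        for p in layer.parameters():
            p.data *= 0

    z, log_det = layer(x)  # assumed to return (output, log|det J|)
    assert torch.allclose(z, x, atol=1e-2)
    # An identity map has zero log-determinant for every batch element
    assert torch.allclose(log_det, torch.zeros(n_batch), atol=1e-2)
    # `inverse` is assumed here; adjust to the library's actual API if it differs
    x_reconstructed, _ = layer.inverse(z)
    assert torch.allclose(x_reconstructed, x, atol=1e-2)

The loose atol=1e-2 mirrors the test above: some layers may be only approximately identity at zero parameters, as the commented-out UMNNMaskedAutoregressive case ("Inexact due to numerics") already notes.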