diff --git a/test/test_reconstruction_bijections.py b/test/test_reconstruction_bijections.py index 7d6dfca..ae9d5fc 100644 --- a/test/test_reconstruction_bijections.py +++ b/test/test_reconstruction_bijections.py @@ -11,7 +11,7 @@ from torchflows.bijections.finite.autoregressive.architectures import NICE, RealNVP, CouplingRQNSF, MAF, IAF, \ InverseAutoregressiveRQNSF, MaskedAutoregressiveRQNSF from torchflows.bijections.finite.autoregressive.layers import ElementwiseScale, ElementwiseAffine, ElementwiseShift, \ - LRSCoupling, LinearRQSCoupling + LRSCoupling, LinearRQSCoupling, ActNorm from torchflows.bijections.finite.linear import LU, ReversePermutation, LowerTriangular, Orthogonal, QR from torchflows.bijections.finite.residual.architectures import ResFlow, InvertibleResNet, ProximalResFlow from torchflows.bijections.finite.residual.iterative import InvertibleResNetBlock, ResFlowBlock @@ -129,7 +129,8 @@ def assert_valid_reconstruction_continuous(bijection: ContinuousBijection, Orthogonal, QR, ElementwiseAffine, - ElementwiseShift + ElementwiseShift, + ActNorm ]) @pytest.mark.parametrize('batch_shape', __test_constants['batch_shape']) @pytest.mark.parametrize('event_shape', __test_constants['event_shape'])