From 6fc9e7131ef3443f59a31026a570c2a170572950 Mon Sep 17 00:00:00 2001 From: David Nabergoj Date: Sat, 24 Aug 2024 20:29:38 +0200 Subject: [PATCH] Add ActNorm tests --- test/test_reconstruction_bijections.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/test/test_reconstruction_bijections.py b/test/test_reconstruction_bijections.py index 7d6dfca..ae9d5fc 100644 --- a/test/test_reconstruction_bijections.py +++ b/test/test_reconstruction_bijections.py @@ -11,7 +11,7 @@ from torchflows.bijections.finite.autoregressive.architectures import NICE, RealNVP, CouplingRQNSF, MAF, IAF, \ InverseAutoregressiveRQNSF, MaskedAutoregressiveRQNSF from torchflows.bijections.finite.autoregressive.layers import ElementwiseScale, ElementwiseAffine, ElementwiseShift, \ - LRSCoupling, LinearRQSCoupling + LRSCoupling, LinearRQSCoupling, ActNorm from torchflows.bijections.finite.linear import LU, ReversePermutation, LowerTriangular, Orthogonal, QR from torchflows.bijections.finite.residual.architectures import ResFlow, InvertibleResNet, ProximalResFlow from torchflows.bijections.finite.residual.iterative import InvertibleResNetBlock, ResFlowBlock @@ -129,7 +129,8 @@ def assert_valid_reconstruction_continuous(bijection: ContinuousBijection, Orthogonal, QR, ElementwiseAffine, - ElementwiseShift + ElementwiseShift, + ActNorm ]) @pytest.mark.parametrize('batch_shape', __test_constants['batch_shape']) @pytest.mark.parametrize('event_shape', __test_constants['event_shape'])