diff --git a/normalizing_flows/bijections/__init__.py b/normalizing_flows/bijections/__init__.py index 38b441d..dbaf23f 100644 --- a/normalizing_flows/bijections/__init__.py +++ b/normalizing_flows/bijections/__init__.py @@ -10,4 +10,4 @@ HouseholderSylvester, Sylvester from normalizing_flows.bijections.finite.residual.iterative import InvertibleResNet, ResFlow, \ QuasiAutoregressiveFlow, ProximalResidualFlow -from normalizing_flows.bijections.finite.linear import LowerTriangular, Orthogonal, LU, QR +from normalizing_flows.bijections.finite.linear import LowerTriangular, Orthogonal, LU, QR \ No newline at end of file diff --git a/test/test_reconstruction_bijections.py b/test/test_reconstruction_bijections.py index e4c2160..4c30e6d 100644 --- a/test/test_reconstruction_bijections.py +++ b/test/test_reconstruction_bijections.py @@ -8,6 +8,7 @@ from normalizing_flows.bijections import RealNVP, MAF, CouplingRQNSF, MaskedAutoregressiveRQNSF, ResFlow, \ InvertibleResNet, \ ElementwiseAffine, ElementwiseShift, InverseAutoregressiveRQNSF, IAF, NICE +from normalizing_flows.bijections import FFJORD from normalizing_flows.bijections.finite.base import ConditionalBijection from normalizing_flows.bijections.finite.residual.planar import Planar from normalizing_flows.bijections.finite.residual.radial import Radial @@ -120,3 +121,15 @@ def test_masked_autoregressive(bijection_class: ConditionalBijection, batch_shap def test_residual(bijection_class: ConditionalBijection, batch_shape: Tuple, event_shape: Tuple, context_shape: Tuple): bijection, x, context = setup_data(bijection_class, batch_shape, event_shape, context_shape) assert_valid_reconstruction(bijection, x, context) + + +@pytest.mark.parametrize('bijection_class', [ + FFJORD +]) +@pytest.mark.parametrize('batch_shape', __test_constants['batch_shape']) +@pytest.mark.parametrize('event_shape', __test_constants['event_shape']) +@pytest.mark.parametrize('context_shape', __test_constants['context_shape']) +def test_continuous(bijection_class: ConditionalBijection, batch_shape: Tuple, event_shape: Tuple, + context_shape: Tuple): + bijection, x, context = setup_data(bijection_class, batch_shape, event_shape, context_shape) + assert_valid_reconstruction(bijection, x, context)