From 79a12013af7b7f20d4f77f5d3c5cb70fc5f91a67 Mon Sep 17 00:00:00 2001 From: David Nabergoj Date: Fri, 13 Oct 2023 18:16:21 +0200 Subject: [PATCH] Add bijection reconstruction test --- test/test_reconstruction_bijections.py | 16 ++++++++++++++-- 1 file changed, 14 insertions(+), 2 deletions(-) diff --git a/test/test_reconstruction_bijections.py b/test/test_reconstruction_bijections.py index 67f7508..4aaa45e 100644 --- a/test/test_reconstruction_bijections.py +++ b/test/test_reconstruction_bijections.py @@ -196,7 +196,19 @@ def test_residual(bijection_class: Bijection, batch_shape: Tuple, event_shape: T @pytest.mark.parametrize('batch_shape', __test_constants['batch_shape']) @pytest.mark.parametrize('event_shape', __test_constants['event_shape']) @pytest.mark.parametrize('context_shape', __test_constants['context_shape']) -def test_continuous(bijection_class: ContinuousBijection, batch_shape: Tuple, event_shape: Tuple, - context_shape: Tuple): +def test_continuous( + bijection_class: ContinuousBijection, + batch_shape: Tuple, + event_shape: Tuple, + context_shape: Tuple +): bijection, x, context = setup_data(bijection_class, batch_shape, event_shape, context_shape) assert_valid_reconstruction_continuous(bijection, x, context) + + +def test_ot_flow_instability(): + batch_shape = (5,) + event_shape = (2,) + context_shape = (2,) + bijection, x, context = setup_data(OTFlow, batch_shape, event_shape, context_shape) + assert_valid_reconstruction_continuous(bijection, x, context)