Skip to content

Commit

Permalink
Add bijection reconstruction test
Browse files Browse the repository at this point in the history
  • Loading branch information
davidnabergoj committed Oct 13, 2023
1 parent 8692cb6 commit 79a1201
Showing 1 changed file with 14 additions and 2 deletions.
16 changes: 14 additions & 2 deletions test/test_reconstruction_bijections.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,7 +196,19 @@ def test_residual(bijection_class: Bijection, batch_shape: Tuple, event_shape: T
@pytest.mark.parametrize('batch_shape', __test_constants['batch_shape'])
@pytest.mark.parametrize('event_shape', __test_constants['event_shape'])
@pytest.mark.parametrize('context_shape', __test_constants['context_shape'])
def test_continuous(bijection_class: ContinuousBijection, batch_shape: Tuple, event_shape: Tuple,
context_shape: Tuple):
def test_continuous(
bijection_class: ContinuousBijection,
batch_shape: Tuple,
event_shape: Tuple,
context_shape: Tuple
):
bijection, x, context = setup_data(bijection_class, batch_shape, event_shape, context_shape)
assert_valid_reconstruction_continuous(bijection, x, context)


def test_ot_flow_instability():
batch_shape = (5,)
event_shape = (2,)
context_shape = (2,)
bijection, x, context = setup_data(OTFlow, batch_shape, event_shape, context_shape)
assert_valid_reconstruction_continuous(bijection, x, context)

0 comments on commit 79a1201

Please sign in to comment.