Skip to content

Commit

Permalink
Add FFJORD test
Browse files Browse the repository at this point in the history
  • Loading branch information
davidnabergoj committed Sep 11, 2023
1 parent 92a44d6 commit 3c94a9b
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 1 deletion.
2 changes: 1 addition & 1 deletion normalizing_flows/bijections/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,4 +10,4 @@
HouseholderSylvester, Sylvester
from normalizing_flows.bijections.finite.residual.iterative import InvertibleResNet, ResFlow, \
QuasiAutoregressiveFlow, ProximalResidualFlow
from normalizing_flows.bijections.finite.linear import LowerTriangular, Orthogonal, LU, QR
from normalizing_flows.bijections.finite.linear import LowerTriangular, Orthogonal, LU, QR
13 changes: 13 additions & 0 deletions test/test_reconstruction_bijections.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from normalizing_flows.bijections import RealNVP, MAF, CouplingRQNSF, MaskedAutoregressiveRQNSF, ResFlow, \
InvertibleResNet, \
ElementwiseAffine, ElementwiseShift, InverseAutoregressiveRQNSF, IAF, NICE
from normalizing_flows.bijections import FFJORD
from normalizing_flows.bijections.finite.base import ConditionalBijection
from normalizing_flows.bijections.finite.residual.planar import Planar
from normalizing_flows.bijections.finite.residual.radial import Radial
Expand Down Expand Up @@ -120,3 +121,15 @@ def test_masked_autoregressive(bijection_class: ConditionalBijection, batch_shap
def test_residual(bijection_class: ConditionalBijection, batch_shape: Tuple, event_shape: Tuple, context_shape: Tuple):
bijection, x, context = setup_data(bijection_class, batch_shape, event_shape, context_shape)
assert_valid_reconstruction(bijection, x, context)


@pytest.mark.parametrize('bijection_class', [
FFJORD
])
@pytest.mark.parametrize('batch_shape', __test_constants['batch_shape'])
@pytest.mark.parametrize('event_shape', __test_constants['event_shape'])
@pytest.mark.parametrize('context_shape', __test_constants['context_shape'])
def test_continuous(bijection_class: ConditionalBijection, batch_shape: Tuple, event_shape: Tuple,
context_shape: Tuple):
bijection, x, context = setup_data(bijection_class, batch_shape, event_shape, context_shape)
assert_valid_reconstruction(bijection, x, context)

0 comments on commit 3c94a9b

Please sign in to comment.