From df810dae1d729530299925bdaad2d97d7f7b1153 Mon Sep 17 00:00:00 2001 From: David Nabergoj Date: Fri, 9 Feb 2024 12:43:37 +0100 Subject: [PATCH] Expand graphical coupling tests --- test/test_graphical_normalizing_flow.py | 44 ++++++++++++++----------- 1 file changed, 25 insertions(+), 19 deletions(-) diff --git a/test/test_graphical_normalizing_flow.py b/test/test_graphical_normalizing_flow.py index 2b92025..0c465f8 100644 --- a/test/test_graphical_normalizing_flow.py +++ b/test/test_graphical_normalizing_flow.py @@ -1,69 +1,75 @@ +import pytest import torch -from normalizing_flows.architectures import RealNVP +from normalizing_flows.architectures import RealNVP, NICE, CouplingRQNSF -def test_basic_2d(): +@pytest.mark.parametrize('architecture', [RealNVP, NICE, CouplingRQNSF]) +def test_basic_2d(architecture): torch.manual_seed(0) n_data = 100 n_dim = 2 x = torch.randn(size=(n_data, n_dim)) - bijection = RealNVP(event_shape=(n_dim,), edge_list=[(0, 1)]) + bijection = architecture(event_shape=(n_dim,), edge_list=[(0, 1)]) z, log_det_forward = bijection.forward(x) x_reconstructed, log_det_inverse = bijection.inverse(z) - assert torch.allclose(x, x_reconstructed) + assert torch.allclose(x, x_reconstructed, atol=1e-4), f"{torch.linalg.norm(x - x_reconstructed)}" assert torch.allclose(log_det_forward, -log_det_inverse) -def test_basic_5d(): +@pytest.mark.parametrize('architecture', [RealNVP, NICE, CouplingRQNSF]) +def test_basic_5d(architecture): torch.manual_seed(0) n_data = 100 n_dim = 5 x = torch.randn(size=(n_data, n_dim)) - bijection = RealNVP(event_shape=(n_dim,), edge_list=[(0, 1), (0, 2), (0, 3), (0, 4)]) + bijection = architecture(event_shape=(n_dim,), edge_list=[(0, 1), (0, 2), (0, 3), (0, 4)]) z, log_det_forward = bijection.forward(x) x_reconstructed, log_det_inverse = bijection.inverse(z) - assert torch.allclose(x, x_reconstructed) + assert torch.allclose(x, x_reconstructed, atol=1e-4) assert torch.allclose(log_det_forward, -log_det_inverse) -def test_basic_5d_2(): +@pytest.mark.parametrize('architecture', [RealNVP, NICE, CouplingRQNSF]) +def test_basic_5d_2(architecture): torch.manual_seed(0) n_data = 100 n_dim = 5 x = torch.randn(size=(n_data, n_dim)) - bijection = RealNVP(event_shape=(n_dim,), edge_list=[(0, 1)]) + bijection = architecture(event_shape=(n_dim,), edge_list=[(0, 1)]) z, log_det_forward = bijection.forward(x) x_reconstructed, log_det_inverse = bijection.inverse(z) - assert torch.allclose(x, x_reconstructed) + assert torch.allclose(x, x_reconstructed, atol=1e-4) assert torch.allclose(log_det_forward, -log_det_inverse) -def test_basic_5d_3(): +@pytest.mark.parametrize('architecture', [RealNVP, NICE, CouplingRQNSF]) +def test_basic_5d_3(architecture): torch.manual_seed(0) n_data = 100 n_dim = 5 x = torch.randn(size=(n_data, n_dim)) - bijection = RealNVP(event_shape=(n_dim,), edge_list=[(0, 2), (1, 3), (1, 4)]) + bijection = architecture(event_shape=(n_dim,), edge_list=[(0, 2), (1, 3), (1, 4)]) z, log_det_forward = bijection.forward(x) x_reconstructed, log_det_inverse = bijection.inverse(z) - assert torch.allclose(x, x_reconstructed, atol=1e-5), f"{torch.linalg.norm(x - x_reconstructed)}" + assert torch.allclose(x, x_reconstructed, atol=1e-4), f"{torch.linalg.norm(x - x_reconstructed)}" assert torch.allclose(log_det_forward, -log_det_inverse, - atol=1e-5), f"{torch.linalg.norm(log_det_forward + log_det_inverse)}" + atol=1e-4), f"{torch.linalg.norm(log_det_forward + log_det_inverse)}" -def test_random(): +@pytest.mark.parametrize('architecture', [RealNVP, NICE, CouplingRQNSF]) +def test_random(architecture): torch.manual_seed(0) n_data = 100 - n_dim = 30 + n_dim = 50 x = torch.randn(size=(n_data, n_dim)) interacting_dimensions = torch.unique(torch.randint(low=0, high=n_dim, size=(n_dim,))) @@ -76,10 +82,10 @@ def test_random(): for t in target_dimensions: edge_list.append((s, t)) - bijection = RealNVP(event_shape=(n_dim,), edge_list=edge_list) + bijection = architecture(event_shape=(n_dim,), edge_list=edge_list) z, log_det_forward = bijection.forward(x) x_reconstructed, log_det_inverse = bijection.inverse(z) - assert torch.allclose(x, x_reconstructed, atol=1e-5), f"{torch.linalg.norm(x - x_reconstructed)}" + assert torch.allclose(x, x_reconstructed, atol=1e-4), f"{torch.linalg.norm(x - x_reconstructed)}" assert torch.allclose(log_det_forward, -log_det_inverse, - atol=1e-5), f"{torch.linalg.norm(log_det_forward + log_det_inverse)}" + atol=1e-4), f"{torch.linalg.norm(log_det_forward + log_det_inverse)}"