Skip to content

Commit

Permalink
Expand graphical coupling tests
Browse files Browse the repository at this point in the history
  • Loading branch information
davidnabergoj committed Feb 9, 2024
1 parent 5a00e38 commit df810da
Showing 1 changed file with 25 additions and 19 deletions.
44 changes: 25 additions & 19 deletions test/test_graphical_normalizing_flow.py
Original file line number Diff line number Diff line change
@@ -1,69 +1,75 @@
import pytest
import torch
from normalizing_flows.architectures import RealNVP, NICE, CouplingRQNSF


@pytest.mark.parametrize('architecture', [RealNVP, NICE, CouplingRQNSF])
def test_basic_2d(architecture):
    """Forward/inverse consistency of a 2D graphical coupling flow with a single edge.

    Checks that inverse(forward(x)) reconstructs x and that the forward and
    inverse log-determinants are negatives of each other.
    """
    torch.manual_seed(0)  # fixed seed so the randomly initialized flow is deterministic

    n_data = 100
    n_dim = 2
    x = torch.randn(size=(n_data, n_dim))
    # Dimension 0 conditions the transformation of dimension 1.
    bijection = architecture(event_shape=(n_dim,), edge_list=[(0, 1)])
    z, log_det_forward = bijection.forward(x)
    x_reconstructed, log_det_inverse = bijection.inverse(z)

    # Loose tolerance (1e-4): spline/affine couplings accumulate float32 round-off.
    assert torch.allclose(x, x_reconstructed, atol=1e-4), f"{torch.linalg.norm(x - x_reconstructed)}"
    assert torch.allclose(log_det_forward, -log_det_inverse)


@pytest.mark.parametrize('architecture', [RealNVP, NICE, CouplingRQNSF])
def test_basic_5d(architecture):
    """Forward/inverse consistency of a 5D flow with a star-shaped conditioner graph.

    Dimension 0 conditions dimensions 1-4; checks reconstruction of x and
    antisymmetry of the forward/inverse log-determinants.
    """
    torch.manual_seed(0)  # fixed seed so the randomly initialized flow is deterministic

    n_data = 100
    n_dim = 5
    x = torch.randn(size=(n_data, n_dim))
    bijection = architecture(event_shape=(n_dim,), edge_list=[(0, 1), (0, 2), (0, 3), (0, 4)])
    z, log_det_forward = bijection.forward(x)
    x_reconstructed, log_det_inverse = bijection.inverse(z)

    # Loose tolerance (1e-4): float32 round-off in the coupling transforms.
    assert torch.allclose(x, x_reconstructed, atol=1e-4)
    assert torch.allclose(log_det_forward, -log_det_inverse)


@pytest.mark.parametrize('architecture', [RealNVP, NICE, CouplingRQNSF])
def test_basic_5d_2(architecture):
    """Forward/inverse consistency of a 5D flow where only one edge is present.

    Dimensions 2-4 participate in no edge, so they should pass through (or be
    transformed independently); checks reconstruction and log-det antisymmetry.
    """
    torch.manual_seed(0)  # fixed seed so the randomly initialized flow is deterministic

    n_data = 100
    n_dim = 5
    x = torch.randn(size=(n_data, n_dim))
    bijection = architecture(event_shape=(n_dim,), edge_list=[(0, 1)])
    z, log_det_forward = bijection.forward(x)
    x_reconstructed, log_det_inverse = bijection.inverse(z)

    # Loose tolerance (1e-4): float32 round-off in the coupling transforms.
    assert torch.allclose(x, x_reconstructed, atol=1e-4)
    assert torch.allclose(log_det_forward, -log_det_inverse)


@pytest.mark.parametrize('architecture', [RealNVP, NICE, CouplingRQNSF])
def test_basic_5d_3(architecture):
    """Forward/inverse consistency of a 5D flow with a forest-shaped conditioner graph.

    Two independent source dimensions (0 and 1) condition disjoint targets;
    checks reconstruction of x and antisymmetry of the log-determinants, with
    the offending norm reported on failure.
    """
    torch.manual_seed(0)  # fixed seed so the randomly initialized flow is deterministic

    n_data = 100
    n_dim = 5
    x = torch.randn(size=(n_data, n_dim))
    bijection = architecture(event_shape=(n_dim,), edge_list=[(0, 2), (1, 3), (1, 4)])
    z, log_det_forward = bijection.forward(x)
    x_reconstructed, log_det_inverse = bijection.inverse(z)

    # Loose tolerance (1e-4): float32 round-off in the coupling transforms.
    assert torch.allclose(x, x_reconstructed, atol=1e-4), f"{torch.linalg.norm(x - x_reconstructed)}"
    assert torch.allclose(log_det_forward, -log_det_inverse,
                          atol=1e-4), f"{torch.linalg.norm(log_det_forward + log_det_inverse)}"


@pytest.mark.parametrize('architecture', [RealNVP, NICE, CouplingRQNSF])
def test_random(architecture):
    """Forward/inverse consistency of a 50D flow with a randomly generated edge list.

    Builds a random conditioner graph over a subset of dimensions, then checks
    reconstruction of x and antisymmetry of the forward/inverse log-determinants.
    """
    torch.manual_seed(0)  # fixed seed: both the data and the random graph are deterministic

    n_data = 100
    n_dim = 50
    x = torch.randn(size=(n_data, n_dim))

    # Random subset of dimensions that participate in edges (unique sorts and dedups).
    interacting_dimensions = torch.unique(torch.randint(low=0, high=n_dim, size=(n_dim,)))
    # NOTE(review): the lines between here and the edge-list loop were collapsed
    # in the diff view this file was recovered from. Reconstructed below on the
    # assumption that the interacting dimensions are split into source/target
    # halves — TODO confirm against the repository version of this test.
    n_interacting = len(interacting_dimensions)
    source_dimensions = interacting_dimensions[:n_interacting // 2]
    target_dimensions = interacting_dimensions[n_interacting // 2:]
    edge_list = []
    for s in source_dimensions:
        for t in target_dimensions:
            edge_list.append((s, t))

    bijection = architecture(event_shape=(n_dim,), edge_list=edge_list)
    z, log_det_forward = bijection.forward(x)
    x_reconstructed, log_det_inverse = bijection.inverse(z)

    # Loose tolerance (1e-4): float32 round-off accumulates over many coupled dimensions.
    assert torch.allclose(x, x_reconstructed, atol=1e-4), f"{torch.linalg.norm(x - x_reconstructed)}"
    assert torch.allclose(log_det_forward, -log_det_inverse,
                          atol=1e-4), f"{torch.linalg.norm(log_det_forward + log_det_inverse)}"

0 comments on commit df810da

Please sign in to comment.