# NOTE(review): this file arrived as a git diff collapsed onto a few physical
# lines. Reconstructed below are the two logical pieces the patch produces:
# (1) the post-patch GraphicalCoupling.__init__ from
#     normalizing_flows/bijections/finite/autoregressive/conditioning/coupling_masks.py
# (2) the new test module test/test_graphical_normalizing_flow.py
from typing import List, Tuple

import torch

from normalizing_flows.architectures import RealNVP


class GraphicalCoupling(PartialCoupling):
    """Coupling whose source/target partition is derived from a dependency graph.

    Each edge ``(s, t)`` in ``edge_list`` declares that target dimension ``t``
    is conditioned on source dimension ``s``. Boolean masks over the flat
    event are built from the union of all source and all target dimensions.
    """

    def __init__(self, event_shape, edge_list: List[Tuple[int, int]]):
        """
        :param event_shape: shape of the event; must be one-dimensional,
            since edge indices address positions in a flat vector.
        :param edge_list: list of (source_dim, target_dim) index pairs.
        :raises ValueError: if the event is not a vector or edge_list is empty.
        """
        if len(event_shape) != 1:
            raise ValueError("GraphicalCoupling is currently only implemented for vector data")
        if not edge_list:
            raise ValueError("edge_list must contain at least one edge")

        # Cast indices to int so tensor-valued edge entries behave identically
        # to plain ints: set() on 0-d tensors deduplicates by object identity,
        # not by value, which would silently keep duplicate dimensions.
        source_dims = torch.tensor(sorted({int(e[0]) for e in edge_list}))
        target_dims = torch.tensor(sorted({int(e[1]) for e in edge_list}))

        # Turn the index lists into boolean masks over the flattened event.
        event_size = int(torch.prod(torch.as_tensor(event_shape)))
        source_mask = torch.isin(torch.arange(event_size), source_dims)
        target_mask = torch.isin(torch.arange(event_size), target_dims)

        super().__init__(event_shape, source_mask, target_mask)


def _assert_roundtrip(bijection, x, atol=None):
    """Run forward then inverse and assert reconstruction + log-det symmetry.

    Shared helper for all tests below; `atol=None` means exact-tolerance
    allclose, matching the original per-test assertions.
    """
    z, log_det_forward = bijection.forward(x)
    x_reconstructed, log_det_inverse = bijection.inverse(z)

    if atol is None:
        assert torch.allclose(x, x_reconstructed)
        assert torch.allclose(log_det_forward, -log_det_inverse)
    else:
        assert torch.allclose(x, x_reconstructed, atol=atol), \
            f"{torch.linalg.norm(x - x_reconstructed)}"
        assert torch.allclose(log_det_forward, -log_det_inverse, atol=atol), \
            f"{torch.linalg.norm(log_det_forward + log_det_inverse)}"


def test_basic_2d():
    """Single edge on a 2D event: forward/inverse must round-trip."""
    torch.manual_seed(0)
    x = torch.randn(size=(100, 2))
    bijection = RealNVP(event_shape=(2,), edge_list=[(0, 1)])
    _assert_roundtrip(bijection, x)


def test_basic_5d():
    """One source dimension conditioning all four remaining dimensions."""
    torch.manual_seed(0)
    x = torch.randn(size=(100, 5))
    bijection = RealNVP(event_shape=(5,), edge_list=[(0, 1), (0, 2), (0, 3), (0, 4)])
    _assert_roundtrip(bijection, x)


def test_basic_5d_2():
    """A single edge in a 5D event; dimensions 2-4 are untouched by coupling."""
    torch.manual_seed(0)
    x = torch.randn(size=(100, 5))
    bijection = RealNVP(event_shape=(5,), edge_list=[(0, 1)])
    _assert_roundtrip(bijection, x)


def test_basic_5d_3():
    """Two source dimensions with partially overlapping targets."""
    torch.manual_seed(0)
    x = torch.randn(size=(100, 5))
    bijection = RealNVP(event_shape=(5,), edge_list=[(0, 2), (1, 3), (1, 4)])
    _assert_roundtrip(bijection, x, atol=1e-5)


def test_random():
    """Random bipartite source/target split over a 30D event."""
    torch.manual_seed(0)

    n_dim = 30
    x = torch.randn(size=(100, n_dim))

    interacting_dimensions = torch.unique(torch.randint(low=0, high=n_dim, size=(n_dim,)))
    interacting_dimensions = interacting_dimensions[torch.randperm(len(interacting_dimensions))]
    source_dimensions = interacting_dimensions[:len(interacting_dimensions) // 2]
    target_dimensions = interacting_dimensions[len(interacting_dimensions) // 2:]

    # Cast to int: the GraphicalCoupling interface declares List[Tuple[int, int]],
    # and 0-d tensors do not deduplicate by value inside a set().
    edge_list = [
        (int(s), int(t))
        for s in source_dimensions
        for t in target_dimensions
    ]

    bijection = RealNVP(event_shape=(n_dim,), edge_list=edge_list)
    _assert_roundtrip(bijection, x, atol=1e-5)