From cefda395d820a30d33177709e2033204a99892d8 Mon Sep 17 00:00:00 2001 From: David Nabergoj Date: Sun, 12 Nov 2023 12:21:07 -0800 Subject: [PATCH 01/40] Add invertible 1x1 convolution transformer --- .../transformers/convolution.py | 162 ++++++++++++++++++ 1 file changed, 162 insertions(+) create mode 100644 normalizing_flows/bijections/finite/autoregressive/transformers/convolution.py diff --git a/normalizing_flows/bijections/finite/autoregressive/transformers/convolution.py b/normalizing_flows/bijections/finite/autoregressive/transformers/convolution.py new file mode 100644 index 0000000..48c6a8e --- /dev/null +++ b/normalizing_flows/bijections/finite/autoregressive/transformers/convolution.py @@ -0,0 +1,162 @@ +from typing import Union, Tuple +import torch + +from normalizing_flows.bijections.finite.autoregressive.transformers.base import Transformer + + +def construct_kernels_plu( + lower_elements: torch.Tensor, + upper_elements: torch.Tensor, + log_abs_diag: torch.Tensor, + sign_diag: torch.Tensor, + permutation: torch.Tensor, + k: int, + inverse: bool = False +): + """ + :param lower_elements: (b, (k ** 2 - k) // 2) + :param upper_elements: (b, (k ** 2 - k) // 2) + :param log_abs_diag: (b, k) + :param sign_diag: (b, k) + :param permutation: (k, k) + :param k: kernel length + :param inverse: + :return: kernels with shape (b, k, k) + """ + + assert lower_elements.shape == upper_elements.shape + assert log_abs_diag.shape == sign_diag.shape + assert permutation.shape == (k, k) + assert len(log_abs_diag.shape) == 2 + assert len(lower_elements.shape) == 2 + assert lower_elements.shape[1] == (k ** 2 - k) // 2 + assert log_abs_diag.shape[1] == k + + batch_size = len(lower_elements) + + lower = torch.eye(k)[None].repeat(batch_size) + lower_row_idx, lower_col_idx = torch.tril_indices(k, k, offset=-1) + lower[:, lower_row_idx, lower_col_idx] = lower_elements + + upper = torch.einsum("ij,bj->bij", torch.eye(k), log_abs_diag.exp() * sign_diag) + upper_row_idx, upper_col_idx = torch.triu_indices(k, k, offset=1) + upper[:, upper_row_idx, upper_col_idx] = upper_elements + + if inverse: + if log_abs_diag.dtype == torch.float64: + lower_inv = torch.inverse(lower) + upper_inv = torch.inverse(upper) + else: + lower_inv = torch.inverse(lower.double()).type(log_abs_diag.dtype) + upper_inv = torch.inverse(upper.double()).type(log_abs_diag.dtype) + kernels = torch.einsum("bij,bjk,kl->bil", upper_inv, lower_inv, permutation.T) + else: + kernels = torch.einsum("ij,bjk,bkl->bil", permutation, lower, upper) + return kernels + + +class Invertible1x1Convolution(Transformer): + """ + Invertible 1x1 convolution. + + TODO permutation may be unnecessary, maybe remove. + """ + + def __init__(self, event_shape: Union[torch.Size, Tuple[int, ...]], kernel_length: int = 3): + if len(event_shape) != 3: + raise ValueError( + f"InvertibleConvolution transformer only supports events with shape (height, width, channels)." + ) + self.h, self.w, self.c = event_shape + if kernel_length <= 0: + raise ValueError(f"Expected kernel length to be positive, but got {kernel_length}") + self.k = kernel_length + self.sign_diag = torch.sign(torch.randn(self.k)) + self.permutation = torch.eye(self.k)[torch.randperm(self.k)] + super().__init__(event_shape) + + @property + def n_parameters(self) -> int: + return self.k ** 2 + + @property + def default_parameters(self) -> torch.Tensor: + # Kernel matrix is identity (p=0,u=0,log_diag=0). + # Some diagonal elements are negated according to self.sign_diag. + # The matrix is then permuted. 
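The following standalone sketch is not part of the patch; it uses only `torch`, with an illustrative kernel size, to mirror what `construct_kernels_plu` builds and to check the identity the transformer relies on for its log-determinant: P is a permutation (|det| = 1) and L is unit lower-triangular (det = 1), so log|det K| reduces to the sum of the log-absolute diagonal entries of U, and applying K as a 1x1 convolution at every one of the H * W spatial positions scales that by H * W.

```python
import torch

# Standalone illustration: build one k x k kernel the way construct_kernels_plu does
# and verify the log-determinant identity used by the transformer.
k = 4
log_abs_diag = torch.randn(k)
sign_diag = torch.sign(torch.randn(k))

lower = torch.eye(k)                                # unit lower-triangular, det = 1
rows, cols = torch.tril_indices(k, k, offset=-1)
lower[rows, cols] = torch.randn(k * (k - 1) // 2)

upper = torch.diag(sign_diag * log_abs_diag.exp())  # |det| = exp(sum(log_abs_diag))
rows, cols = torch.triu_indices(k, k, offset=1)
upper[rows, cols] = torch.randn(k * (k - 1) // 2)

permutation = torch.eye(k)[torch.randperm(k)]       # |det| = 1
kernel = permutation @ lower @ upper

# log|det K| = sum(log_abs_diag); for an H x W image the 1x1 convolution applies K
# at every pixel, so the total log-determinant is H * W * sum(log_abs_diag),
# which is what the transformer's log_det computes.
assert torch.allclose(torch.linalg.slogdet(kernel).logabsdet, log_abs_diag.sum(), atol=1e-4)
```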
+ return torch.zeros(self.n_parameters) + + def forward(self, x: torch.Tensor, h: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + """ + We parametrize K = PLU. The parameters h contain elements of L and U. + There are k ** 2 such elements. + + x.shape == (batch_size, c, h, w) + h.shape == (batch_size, k * k) + + We expect each kernel to be invertible. + """ + if len(x.shape) != 4: + raise ValueError(f"Expected x to have shape (batch_size, channels, height, width), but got {x.shape}") + if len(h.shape) != 2: + raise ValueError(f"Expected h.shape to be of length 2, but got {h.shape} with length {len(h.shape)}") + if h.shape[1] != self.k * self.k: + raise ValueError( + f"Expected h to have shape (batch_size, kernel_height * kernel_width) = (batch_size, {self.k * self.k})," + f" but got {h.shape}" + ) + + n_p_elements = (self.k ** 2 - self.k) // 2 + p_elements = h[..., :n_p_elements] + u_elements = h[..., n_p_elements:n_p_elements * 2] + diag_elements = h[..., n_p_elements * 2:] + + kernels = construct_kernels_plu( + p_elements, + u_elements, + diag_elements, + self.sign_diag, + self.permutation, + self.k, + inverse=False + ) + log_det = self.h * self.w * torch.log(torch.abs(torch.linalg.det(kernels))) # (*batch_shape) + + # Reshape images to (1, b, c, h, w), reshape kernels to (b, 1, k, k) + # This lets us convolve each image with its own kernel + z = torch.conv2d(x[None], kernels[:, None], groups=self.c)[0] + + return z, log_det + + def inverse(self, x: torch.Tensor, h: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + if len(x.shape) != 4: + raise ValueError(f"Expected x to have shape (batch_size, height, width, channels), but got {x.shape}") + if len(h.shape) != 2: + raise ValueError(f"Expected h.shape to be of length 2, but got {h.shape} with length {len(h.shape)}") + if h.shape[1] != self.k * self.k: + raise ValueError( + f"Expected h to have shape (batch_size, kernel_height * kernel_width) = (batch_size, {self.k * self.k})," + f" but got {h.shape}" + ) + + n_p_elements = (self.k ** 2 - self.k) // 2 + p_elements = h[..., :n_p_elements] + u_elements = h[..., n_p_elements:n_p_elements * 2] + diag_elements = h[..., n_p_elements * 2:] + + kernels = construct_kernels_plu( + p_elements, + u_elements, + diag_elements, + self.sign_diag, + self.permutation, + self.k, + inverse=True + ) + log_det = -self.h * self.w * torch.log(torch.abs(torch.linalg.det(kernels))) # (*batch_shape) + + # Reshape images to (1, b, c, h, w), reshape kernels to (b, 1, k, k) + # This lets us convolve each image with its own kernel + z = torch.conv2d(x[None], kernels[:, None], groups=self.c)[0] + + return z, log_det From acba4675ef519f49754efe3056b433fa8f6ee8ab Mon Sep 17 00:00:00 2001 From: David Nabergoj Date: Sun, 12 Nov 2023 12:33:07 -0800 Subject: [PATCH 02/40] Add invertible 1x1 convolution transformer test --- .../transformers/convolution.py | 7 +++++- test/constants.py | 1 + test/test_reconstruction_transformers.py | 25 +++++++++++++++++++ 3 files changed, 32 insertions(+), 1 deletion(-) diff --git a/normalizing_flows/bijections/finite/autoregressive/transformers/convolution.py b/normalizing_flows/bijections/finite/autoregressive/transformers/convolution.py index 48c6a8e..5a30e78 100644 --- a/normalizing_flows/bijections/finite/autoregressive/transformers/convolution.py +++ b/normalizing_flows/bijections/finite/autoregressive/transformers/convolution.py @@ -67,12 +67,13 @@ def __init__(self, event_shape: Union[torch.Size, Tuple[int, ...]], kernel_lengt raise ValueError( f"InvertibleConvolution 
transformer only supports events with shape (height, width, channels)." ) - self.h, self.w, self.c = event_shape + self.c, self.h, self.w = event_shape if kernel_length <= 0: raise ValueError(f"Expected kernel length to be positive, but got {kernel_length}") self.k = kernel_length self.sign_diag = torch.sign(torch.randn(self.k)) self.permutation = torch.eye(self.k)[torch.randperm(self.k)] + self.const = 1000 super().__init__(event_shape) @property @@ -106,6 +107,8 @@ def forward(self, x: torch.Tensor, h: torch.Tensor) -> Tuple[torch.Tensor, torch f" but got {h.shape}" ) + h = self.default_parameters + h / self.const + n_p_elements = (self.k ** 2 - self.k) // 2 p_elements = h[..., :n_p_elements] u_elements = h[..., n_p_elements:n_p_elements * 2] @@ -139,6 +142,8 @@ def inverse(self, x: torch.Tensor, h: torch.Tensor) -> Tuple[torch.Tensor, torch f" but got {h.shape}" ) + h = self.default_parameters + h / self.const + n_p_elements = (self.k ** 2 - self.k) // 2 p_elements = h[..., :n_p_elements] u_elements = h[..., n_p_elements:n_p_elements * 2] diff --git a/test/constants.py b/test/constants.py index a4a1b96..2664eb4 100644 --- a/test/constants.py +++ b/test/constants.py @@ -1,6 +1,7 @@ __test_constants = { 'batch_shape': [(1,), (2,), (5,), (2, 4), (5, 2, 3)], 'event_shape': [(2,), (3,), (2, 4), (40,), (3, 5, 2)], + 'image_shape': [(1, 20, 20), (1, 10, 20), (3, 20, 20), (3, 10, 20), (3, 200, 200)], 'context_shape': [None, (2,), (3,), (2, 4), (5,)], 'input_event_shape': [(2,), (3,), (2, 4), (40,), (3, 5, 2)], 'output_event_shape': [(2,), (3,), (2, 4), (40,), (3, 5, 2)], diff --git a/test/test_reconstruction_transformers.py b/test/test_reconstruction_transformers.py index 3cdaf43..2aa890b 100644 --- a/test/test_reconstruction_transformers.py +++ b/test/test_reconstruction_transformers.py @@ -9,6 +9,7 @@ LinearRational as LinearRationalSpline from normalizing_flows.bijections.finite.autoregressive.transformers.spline.rational_quadratic import \ RationalQuadratic as RationalQuadraticSpline +from normalizing_flows.bijections.finite.autoregressive.transformers.convolution import Invertible1x1Convolution from normalizing_flows.bijections.finite.autoregressive.transformers.spline.cubic import Cubic as CubicSpline from normalizing_flows.bijections.finite.autoregressive.transformers.spline.basis import Basis as BasisSpline from normalizing_flows.bijections.finite.autoregressive.transformers.affine import Affine, Scale, Shift @@ -112,3 +113,27 @@ def test_combination_basic(transformer_class: Transformer, batch_shape: Tuple, e def test_combination_vector_to_vector(transformer_class: Transformer, batch_shape: Tuple, event_shape: Tuple): transformer, x, h = setup_transformer_data(transformer_class, batch_shape, event_shape) assert_valid_reconstruction(transformer, x, h) + + +@pytest.mark.parametrize('kernel_length', [1, 2, 3, 5]) +@pytest.mark.parametrize('image_shape', __test_constants['image_shape']) +def test_convolution(image_shape: Tuple, kernel_length): + torch.manual_seed(0) + batch_size = 10 + images = torch.randn(size=(batch_size, *image_shape)) + parameters = torch.randn(size=(batch_size, kernel_length ** 2)) + transformer = Invertible1x1Convolution(image_shape, kernel_length=kernel_length) + latent_images, log_det_forward = transformer.forward(images, parameters) + reconstructed_images, log_det_inverse = transformer.inverse(latent_images, parameters) + + assert latent_images.shape == images.shape + assert reconstructed_images.shape == images.shape + assert torch.isfinite(latent_images).all() + 
assert torch.isfinite(reconstructed_images).all() + assert torch.allclose(latent_images, reconstructed_images) + + assert log_det_forward.shape == (batch_size,) + assert log_det_inverse.shape == (batch_size,) + assert torch.isfinite(log_det_forward).all() + assert torch.isfinite(log_det_inverse).all() + assert torch.allclose(log_det_forward, -log_det_inverse) From cd9bd02a570ba0aa047ef8e8c188ebf264600892 Mon Sep 17 00:00:00 2001 From: David Nabergoj Date: Sun, 12 Nov 2023 15:10:40 -0800 Subject: [PATCH 03/40] Towards fixing invertible convolution --- .../transformers/convolution.py | 65 +++++++++---------- test/constants.py | 2 +- test/test_reconstruction_transformers.py | 25 +++---- 3 files changed, 44 insertions(+), 48 deletions(-) diff --git a/normalizing_flows/bijections/finite/autoregressive/transformers/convolution.py b/normalizing_flows/bijections/finite/autoregressive/transformers/convolution.py index 5a30e78..ff9ed4b 100644 --- a/normalizing_flows/bijections/finite/autoregressive/transformers/convolution.py +++ b/normalizing_flows/bijections/finite/autoregressive/transformers/convolution.py @@ -17,7 +17,7 @@ def construct_kernels_plu( :param lower_elements: (b, (k ** 2 - k) // 2) :param upper_elements: (b, (k ** 2 - k) // 2) :param log_abs_diag: (b, k) - :param sign_diag: (b, k) + :param sign_diag: (k,) :param permutation: (k, k) :param k: kernel length :param inverse: @@ -25,7 +25,7 @@ def construct_kernels_plu( """ assert lower_elements.shape == upper_elements.shape - assert log_abs_diag.shape == sign_diag.shape + assert log_abs_diag.shape[1] == sign_diag.shape[0] assert permutation.shape == (k, k) assert len(log_abs_diag.shape) == 2 assert len(lower_elements.shape) == 2 @@ -34,7 +34,7 @@ def construct_kernels_plu( batch_size = len(lower_elements) - lower = torch.eye(k)[None].repeat(batch_size) + lower = torch.eye(k)[None].repeat(batch_size, 1, 1) lower_row_idx, lower_col_idx = torch.tril_indices(k, k, offset=-1) lower[:, lower_row_idx, lower_col_idx] = lower_elements @@ -62,23 +62,20 @@ class Invertible1x1Convolution(Transformer): TODO permutation may be unnecessary, maybe remove. """ - def __init__(self, event_shape: Union[torch.Size, Tuple[int, ...]], kernel_length: int = 3): + def __init__(self, event_shape: Union[torch.Size, Tuple[int, ...]]): if len(event_shape) != 3: raise ValueError( f"InvertibleConvolution transformer only supports events with shape (height, width, channels)." 
) - self.c, self.h, self.w = event_shape - if kernel_length <= 0: - raise ValueError(f"Expected kernel length to be positive, but got {kernel_length}") - self.k = kernel_length - self.sign_diag = torch.sign(torch.randn(self.k)) - self.permutation = torch.eye(self.k)[torch.randperm(self.k)] + self.n_channels, self.h, self.w = event_shape + self.sign_diag = torch.sign(torch.randn(self.n_channels)) + self.permutation = torch.eye(self.n_channels)[torch.randperm(self.n_channels)] self.const = 1000 super().__init__(event_shape) @property def n_parameters(self) -> int: - return self.k ** 2 + return self.n_channels ** 2 @property def default_parameters(self) -> torch.Tensor: @@ -101,67 +98,65 @@ def forward(self, x: torch.Tensor, h: torch.Tensor) -> Tuple[torch.Tensor, torch raise ValueError(f"Expected x to have shape (batch_size, channels, height, width), but got {x.shape}") if len(h.shape) != 2: raise ValueError(f"Expected h.shape to be of length 2, but got {h.shape} with length {len(h.shape)}") - if h.shape[1] != self.k * self.k: + if h.shape[1] != self.n_channels * self.n_channels: raise ValueError( - f"Expected h to have shape (batch_size, kernel_height * kernel_width) = (batch_size, {self.k * self.k})," + f"Expected h to have shape (batch_size, kernel_height * kernel_width) = (batch_size, {self.n_channels * self.n_channels})," f" but got {h.shape}" ) h = self.default_parameters + h / self.const - n_p_elements = (self.k ** 2 - self.k) // 2 + n_p_elements = (self.n_channels ** 2 - self.n_channels) // 2 p_elements = h[..., :n_p_elements] u_elements = h[..., n_p_elements:n_p_elements * 2] - diag_elements = h[..., n_p_elements * 2:] + log_diag_elements = h[..., n_p_elements * 2:] kernels = construct_kernels_plu( p_elements, u_elements, - diag_elements, + log_diag_elements, self.sign_diag, self.permutation, - self.k, + self.n_channels, inverse=False - ) - log_det = self.h * self.w * torch.log(torch.abs(torch.linalg.det(kernels))) # (*batch_shape) + ) # (b, k, k) + log_det = self.h * self.w * torch.sum(log_diag_elements, dim=-1) # (*batch_shape) - # Reshape images to (1, b, c, h, w), reshape kernels to (b, 1, k, k) - # This lets us convolve each image with its own kernel - z = torch.conv2d(x[None], kernels[:, None], groups=self.c)[0] + z = torch.zeros_like(x) + for i in range(len(x)): + z[i] = torch.conv2d(x[i], kernels[i][:, :, None, None], groups=1, stride=1, padding="same") return z, log_det def inverse(self, x: torch.Tensor, h: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: if len(x.shape) != 4: - raise ValueError(f"Expected x to have shape (batch_size, height, width, channels), but got {x.shape}") + raise ValueError(f"Expected x to have shape (batch_size, channels, height, width), but got {x.shape}") if len(h.shape) != 2: raise ValueError(f"Expected h.shape to be of length 2, but got {h.shape} with length {len(h.shape)}") - if h.shape[1] != self.k * self.k: + if h.shape[1] != self.n_channels * self.n_channels: raise ValueError( - f"Expected h to have shape (batch_size, kernel_height * kernel_width) = (batch_size, {self.k * self.k})," + f"Expected h to have shape (batch_size, kernel_height * kernel_width) = (batch_size, {self.n_channels * self.n_channels})," f" but got {h.shape}" ) h = self.default_parameters + h / self.const - n_p_elements = (self.k ** 2 - self.k) // 2 + n_p_elements = (self.n_channels ** 2 - self.n_channels) // 2 p_elements = h[..., :n_p_elements] u_elements = h[..., n_p_elements:n_p_elements * 2] - diag_elements = h[..., n_p_elements * 2:] + log_diag_elements = h[..., 
n_p_elements * 2:] kernels = construct_kernels_plu( p_elements, u_elements, - diag_elements, + log_diag_elements, self.sign_diag, self.permutation, - self.k, + self.n_channels, inverse=True ) - log_det = -self.h * self.w * torch.log(torch.abs(torch.linalg.det(kernels))) # (*batch_shape) - - # Reshape images to (1, b, c, h, w), reshape kernels to (b, 1, k, k) - # This lets us convolve each image with its own kernel - z = torch.conv2d(x[None], kernels[:, None], groups=self.c)[0] - + log_det = -self.h * self.w * torch.sum(log_diag_elements, dim=-1) # (*batch_shape) + z = torch.zeros_like(x) + for i in range(len(x)): + z[i] = torch.conv2d(x[i], kernels[i][:, :, None, None], groups=1, stride=1, padding="same") return z, log_det diff --git a/test/constants.py b/test/constants.py index 2664eb4..181886f 100644 --- a/test/constants.py +++ b/test/constants.py @@ -1,7 +1,7 @@ __test_constants = { 'batch_shape': [(1,), (2,), (5,), (2, 4), (5, 2, 3)], 'event_shape': [(2,), (3,), (2, 4), (40,), (3, 5, 2)], - 'image_shape': [(1, 20, 20), (1, 10, 20), (3, 20, 20), (3, 10, 20), (3, 200, 200)], + 'image_shape': [(3, 4, 4), (3, 20, 20), (3, 10, 20), (3, 200, 200), (1, 20, 20), (1, 10, 20), ], 'context_shape': [None, (2,), (3,), (2, 4), (5,)], 'input_event_shape': [(2,), (3,), (2, 4), (40,), (3, 5, 2)], 'output_event_shape': [(2,), (3,), (2, 4), (40,), (3, 5, 2)], diff --git a/test/test_reconstruction_transformers.py b/test/test_reconstruction_transformers.py index 2aa890b..b026627 100644 --- a/test/test_reconstruction_transformers.py +++ b/test/test_reconstruction_transformers.py @@ -115,25 +115,26 @@ def test_combination_vector_to_vector(transformer_class: Transformer, batch_shap assert_valid_reconstruction(transformer, x, h) -@pytest.mark.parametrize('kernel_length', [1, 2, 3, 5]) +@pytest.mark.parametrize("batch_size", [2, 3, 5, 7, 1]) @pytest.mark.parametrize('image_shape', __test_constants['image_shape']) -def test_convolution(image_shape: Tuple, kernel_length): +def test_convolution(batch_size: int, image_shape: Tuple): torch.manual_seed(0) - batch_size = 10 + n_channels = image_shape[0] images = torch.randn(size=(batch_size, *image_shape)) - parameters = torch.randn(size=(batch_size, kernel_length ** 2)) - transformer = Invertible1x1Convolution(image_shape, kernel_length=kernel_length) + parameters = torch.randn(size=(batch_size, n_channels ** 2)) + transformer = Invertible1x1Convolution(image_shape) latent_images, log_det_forward = transformer.forward(images, parameters) reconstructed_images, log_det_inverse = transformer.inverse(latent_images, parameters) - assert latent_images.shape == images.shape - assert reconstructed_images.shape == images.shape - assert torch.isfinite(latent_images).all() - assert torch.isfinite(reconstructed_images).all() - assert torch.allclose(latent_images, reconstructed_images) - assert log_det_forward.shape == (batch_size,) assert log_det_inverse.shape == (batch_size,) assert torch.isfinite(log_det_forward).all() assert torch.isfinite(log_det_inverse).all() - assert torch.allclose(log_det_forward, -log_det_inverse) + assert torch.allclose(log_det_forward, -log_det_inverse, atol=1e-3) + + assert latent_images.shape == images.shape + assert reconstructed_images.shape == images.shape + assert torch.isfinite(latent_images).all() + assert torch.isfinite(reconstructed_images).all() + rec_err = torch.max(torch.abs(latent_images - reconstructed_images)) + assert torch.allclose(latent_images, reconstructed_images, atol=1e-2), f"{rec_err = }" From 
3e1e29ef4cfb276ddae6f7f8b426f50fb76a81ea Mon Sep 17 00:00:00 2001 From: David Nabergoj Date: Mon, 13 Nov 2023 11:17:33 -0800 Subject: [PATCH 04/40] Import torchdiffeq locally to reduce dependence --- normalizing_flows/bijections/continuous/base.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/normalizing_flows/bijections/continuous/base.py b/normalizing_flows/bijections/continuous/base.py index 12caf68..6751ee5 100644 --- a/normalizing_flows/bijections/continuous/base.py +++ b/normalizing_flows/bijections/continuous/base.py @@ -2,7 +2,6 @@ import torch import torch.nn as nn -from torchdiffeq import odeint from normalizing_flows.bijections.base import Bijection from normalizing_flows.bijections.continuous.layers import DiffEqLayer @@ -306,6 +305,10 @@ def inverse(self, :param kwargs: :return: """ + + # Import from torchdiffeq locally, so the package does not break if torchdiffeq not installed + from torchdiffeq import odeint + # Flatten everything to facilitate computations batch_shape = get_batch_shape(z, self.event_shape) batch_size = int(torch.prod(torch.as_tensor(batch_shape))) @@ -399,6 +402,9 @@ def inverse(self, :param kwargs: :return: """ + # Import from torchdiffeq locally, so the package does not break if torchdiffeq not installed + from torchdiffeq import odeint + # Flatten everything to facilitate computations batch_shape = get_batch_shape(z, self.event_shape) batch_size = int(torch.prod(torch.as_tensor(batch_shape))) From f4805e947049594319ccd48938fb9a9f9bb15268 Mon Sep 17 00:00:00 2001 From: David Nabergoj Date: Mon, 13 Nov 2023 14:56:58 -0800 Subject: [PATCH 05/40] Replace elementwise affine with elementwise shift in CouplingLRS flows for improved numerical stability --- .../bijections/finite/autoregressive/architectures.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/normalizing_flows/bijections/finite/autoregressive/architectures.py b/normalizing_flows/bijections/finite/autoregressive/architectures.py index 0e1331a..e242064 100644 --- a/normalizing_flows/bijections/finite/autoregressive/architectures.py +++ b/normalizing_flows/bijections/finite/autoregressive/architectures.py @@ -11,7 +11,7 @@ ElementwiseAffine, UMNNMaskedAutoregressive, LRSCoupling, - LRSForwardMaskedAutoregressive + LRSForwardMaskedAutoregressive, ElementwiseShift ) from normalizing_flows.bijections.base import BijectiveComposition from normalizing_flows.bijections.finite.linear import ReversePermutation @@ -127,13 +127,13 @@ class CouplingLRS(BijectiveComposition): def __init__(self, event_shape, n_layers: int = 2, **kwargs): if isinstance(event_shape, int): event_shape = (event_shape,) - bijections = [ElementwiseAffine(event_shape=event_shape)] + bijections = [ElementwiseShift(event_shape=event_shape)] for _ in range(n_layers): bijections.extend([ ReversePermutation(event_shape=event_shape), LRSCoupling(event_shape=event_shape) ]) - bijections.append(ElementwiseAffine(event_shape=event_shape)) + bijections.append(ElementwiseShift(event_shape=event_shape)) super().__init__(event_shape, bijections, **kwargs) From 02f88bab1ebaca813a5b9fc7b11b6ab060a5486f Mon Sep 17 00:00:00 2001 From: David Nabergoj Date: Mon, 13 Nov 2023 16:00:22 -0800 Subject: [PATCH 06/40] Use elementwise shift instead of elementwise affine in default LRS architecture --- .../bijections/finite/autoregressive/architectures.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/normalizing_flows/bijections/finite/autoregressive/architectures.py 
b/normalizing_flows/bijections/finite/autoregressive/architectures.py index e242064..b31c72d 100644 --- a/normalizing_flows/bijections/finite/autoregressive/architectures.py +++ b/normalizing_flows/bijections/finite/autoregressive/architectures.py @@ -141,13 +141,13 @@ class MaskedAutoregressiveLRS(BijectiveComposition): def __init__(self, event_shape, n_layers: int = 2, **kwargs): if isinstance(event_shape, int): event_shape = (event_shape,) - bijections = [ElementwiseAffine(event_shape=event_shape)] + bijections = [ElementwiseShift(event_shape=event_shape)] for _ in range(n_layers): bijections.extend([ ReversePermutation(event_shape=event_shape), LRSForwardMaskedAutoregressive(event_shape=event_shape) ]) - bijections.append(ElementwiseAffine(event_shape=event_shape)) + bijections.append(ElementwiseShift(event_shape=event_shape)) super().__init__(event_shape, bijections, **kwargs) From 869dda40c99914841b3aa7154ca4fe92a24ac54b Mon Sep 17 00:00:00 2001 From: David Nabergoj Date: Mon, 13 Nov 2023 16:01:30 -0800 Subject: [PATCH 07/40] Use small epsilon to handle numerics in LRS, move 2 into log as square for better stability --- .../transformers/spline/linear_rational.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/normalizing_flows/bijections/finite/autoregressive/transformers/spline/linear_rational.py b/normalizing_flows/bijections/finite/autoregressive/transformers/spline/linear_rational.py index 300dd9f..6e94ed1 100644 --- a/normalizing_flows/bijections/finite/autoregressive/transformers/spline/linear_rational.py +++ b/normalizing_flows/bijections/finite/autoregressive/transformers/spline/linear_rational.py @@ -20,6 +20,7 @@ def __init__(self, event_shape: Union[torch.Size, Tuple[int, ...]], boundary: fl self.min_bin_height = 1e-3 self.min_d = 1e-5 self.const = math.log(math.exp(1 - self.min_d) - 1) # to ensure identity initialization + self.eps = 1e-7 # Epsilon for numerical stability when computing forward/inverse @property def n_parameters(self) -> int: @@ -120,7 +121,7 @@ def forward_1d(self, x, h): ) log_det_phi_lt_lambda = ( torch.log(lambda_k * w_k * w_m * (y_m - y_k)) - - 2 * torch.log(w_k * (lambda_k - phi) + w_m * phi) + - torch.log((w_k * (lambda_k - phi) + w_m * phi) ** 2 + self.eps) - torch.log(x_kp1 - x_k) ) @@ -130,7 +131,7 @@ def forward_1d(self, x, h): ) log_det_phi_gt_lambda = ( torch.log((1 - lambda_k) * w_m * w_kp1 * (y_kp1 - y_m)) - - 2 * torch.log(w_m * (1 - phi) + w_kp1 * (phi - lambda_k)) + - torch.log((w_m * (1 - phi) + w_kp1 * (phi - lambda_k)) ** 2 + self.eps) - torch.log(x_kp1 - x_k) ) @@ -166,7 +167,7 @@ def inverse_1d(self, z, h): ) * (x_kp1 - x_k) + x_k log_det_y_lt_ym = ( torch.log(lambda_k * w_k * w_m * (y_m - y_k)) - - torch.log((w_k * (y_k - z) + w_m * (z - y_m)) ** 2) + - torch.log((w_k * (y_k - z) + w_m * (z - y_m)) ** 2 + self.eps) + torch.log(x_kp1 - x_k) ) @@ -176,7 +177,7 @@ def inverse_1d(self, z, h): ) * (x_kp1 - x_k) + x_k log_det_y_gt_ym = ( torch.log((1 - lambda_k) * w_m * w_kp1 * (y_kp1 - y_m)) - - 2 * torch.log(w_kp1 * (y_kp1 - z) + w_m * (z - y_m)) + - torch.log((w_kp1 * (y_kp1 - z) + w_m * (z - y_m)) ** 2 + self.eps) + torch.log(x_kp1 - x_k) ) From fdd4390d39834108762ec268d315fd3added4de4 Mon Sep 17 00:00:00 2001 From: David Nabergoj Date: Tue, 14 Nov 2023 13:04:13 -0800 Subject: [PATCH 08/40] Increase minimum bin size in LRS for improved stability --- .../autoregressive/transformers/spline/linear_rational.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git 
a/normalizing_flows/bijections/finite/autoregressive/transformers/spline/linear_rational.py b/normalizing_flows/bijections/finite/autoregressive/transformers/spline/linear_rational.py index 6e94ed1..2208535 100644 --- a/normalizing_flows/bijections/finite/autoregressive/transformers/spline/linear_rational.py +++ b/normalizing_flows/bijections/finite/autoregressive/transformers/spline/linear_rational.py @@ -16,8 +16,8 @@ def __init__(self, event_shape: Union[torch.Size, Tuple[int, ...]], boundary: fl max_output=boundary, **kwargs ) - self.min_bin_width = 1e-3 - self.min_bin_height = 1e-3 + self.min_bin_width = 1e-2 + self.min_bin_height = 1e-2 self.min_d = 1e-5 self.const = math.log(math.exp(1 - self.min_d) - 1) # to ensure identity initialization self.eps = 1e-7 # Epsilon for numerical stability when computing forward/inverse From fc490fcc342cd8e02f73abc4212f39046d77f420 Mon Sep 17 00:00:00 2001 From: David Nabergoj Date: Tue, 14 Nov 2023 13:10:33 -0800 Subject: [PATCH 09/40] Reduce LRS epsilon for lower reconstruction error --- .../autoregressive/transformers/spline/linear_rational.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/normalizing_flows/bijections/finite/autoregressive/transformers/spline/linear_rational.py b/normalizing_flows/bijections/finite/autoregressive/transformers/spline/linear_rational.py index 2208535..2c23edb 100644 --- a/normalizing_flows/bijections/finite/autoregressive/transformers/spline/linear_rational.py +++ b/normalizing_flows/bijections/finite/autoregressive/transformers/spline/linear_rational.py @@ -20,7 +20,7 @@ def __init__(self, event_shape: Union[torch.Size, Tuple[int, ...]], boundary: fl self.min_bin_height = 1e-2 self.min_d = 1e-5 self.const = math.log(math.exp(1 - self.min_d) - 1) # to ensure identity initialization - self.eps = 1e-7 # Epsilon for numerical stability when computing forward/inverse + self.eps = 5e-10 # Epsilon for numerical stability when computing forward/inverse @property def n_parameters(self) -> int: From 9e40e62e2a553e5d655f4414725aefe1eac63127 Mon Sep 17 00:00:00 2001 From: David Nabergoj Date: Tue, 14 Nov 2023 15:43:07 -0800 Subject: [PATCH 10/40] Add option to learn global parameters in ConditionerTransform --- .../autoregressive/conditioner_transforms.py | 119 ++++++++++++------ .../finite/autoregressive/layers.py | 24 ++-- test/test_conditioner_transforms.py | 2 +- 3 files changed, 95 insertions(+), 50 deletions(-) diff --git a/normalizing_flows/bijections/finite/autoregressive/conditioner_transforms.py b/normalizing_flows/bijections/finite/autoregressive/conditioner_transforms.py index 1bf15e5..0ac8240 100644 --- a/normalizing_flows/bijections/finite/autoregressive/conditioner_transforms.py +++ b/normalizing_flows/bijections/finite/autoregressive/conditioner_transforms.py @@ -9,12 +9,39 @@ class ConditionerTransform(nn.Module): + """ + Module which predicts transformer parameters for the transformation of a tensor y using an input tensor x and + possibly a corresponding context tensor c. + + In other words, a conditioner transform f predicts theta = f(x, c) to be used in transformer g with z = g(y; theta). + The transformation g is performed elementwise on tensor y. + Since g transforms each element of y with a parameter tensor of shape (n_transformer_parameters,), + the shape of theta is (*y.shape, n_transformer_parameters). 
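To make the shape contract above concrete, here is a purely illustrative example with made-up sizes (not code from the repository); note that patch 13 later in this series flattens the output to one parameter vector per batch element instead of one per element of y.

```python
import torch

# Illustrative shapes only: at this point in the series, the conditioner emits one
# parameter vector of length n_transformer_parameters per element of y.
batch_shape, event_shape, n_transformer_parameters = (32,), (3, 5), 7
y = torch.randn(*batch_shape, *event_shape)
theta = torch.randn(*batch_shape, *event_shape, n_transformer_parameters)  # conditioner output
assert theta.shape == (*y.shape, n_transformer_parameters)                 # g acts elementwise on y
```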
+ """ + def __init__(self, input_event_shape, context_shape, output_event_shape, - n_predicted_parameters: int, - context_combiner: ContextCombiner = None): + n_transformer_parameters: int, + context_combiner: ContextCombiner = None, + percent_globally_learned_parameters: float = 0.0, + initial_global_parameter_value: float = None): + """ + :param input_event_shape: shape of conditioner input tensor x. + :param context_shape: shape of conditioner context tensor c. + :param output_event_shape: shape of transformer input tensor y. + :param n_transformer_parameters: number of parameters required to transform a single element of y. + :param context_combiner: ContextCombiner class which defines how to combine x and c to predict theta. + :param percent_globally_learned_parameters: fraction of all parameters in theta that should be learned directly. + A value of 0 means the conditioner predicts n_transformer_parameters parameters based on x and c. + A value of 1 means the conditioner predicts no parameters based on x and c, but outputs globally learned theta. + A value of alpha means the conditioner outputs alpha * n_transformer_parameters parameters globally and + predicts the rest. In this case, the predicted parameters are the last alpha * n_transformer_parameters + elements in theta. + :param initial_global_parameter_value: the initial value for the entire globally learned part of theta. If None, + the global part of theta is initialized to samples from the standard normal distribution. + """ super().__init__() if context_shape is None: context_combiner = Bypass(input_event_shape) @@ -28,12 +55,38 @@ def __init__(self, self.context_shape = context_shape self.n_input_event_dims = self.context_combiner.n_output_dims self.n_output_event_dims = int(torch.prod(torch.as_tensor(output_event_shape))) - self.n_predicted_parameters = n_predicted_parameters + self.n_transformer_parameters = n_transformer_parameters + self.n_globally_learned_parameters = int(n_transformer_parameters * percent_globally_learned_parameters) + self.n_predicted_parameters = self.n_transformer_parameters - self.n_globally_learned_parameters + + if initial_global_parameter_value is None: + initial_global_theta = torch.randn(size=(*output_event_shape, self.n_globally_learned_parameters)) + else: + initial_global_theta = torch.full( + size=(*output_event_shape, self.n_globally_learned_parameters), + fill_value=initial_global_parameter_value + ) + self.global_theta = nn.Parameter(initial_global_theta) def forward(self, x: torch.Tensor, context: torch.Tensor = None): # x.shape = (*batch_shape, *input_event_shape) # context.shape = (*batch_shape, *context_shape) - # output.shape = (*batch_shape, *output_event_shape, n_predicted_parameters) + # output.shape = (*batch_shape, *output_event_shape, n_transformer_parameters) + if self.n_globally_learned_parameters == 0: + return self.predict_theta(x, context) + else: + n_batch_dims = len(x.shape) - len(self.output_event_shape) + n_event_dims = len(self.output_event_shape) + batch_shape = x.shape[:n_batch_dims] + batch_global_theta = pad_leading_dims(self.global_theta, n_batch_dims).repeat( + *batch_shape, *([1] * n_event_dims), 1 + ) + if self.n_globally_learned_parameters == self.n_transformer_parameters: + return batch_global_theta + else: + return torch.cat([batch_global_theta, self.predict_theta(x, context)], dim=-1) + + def predict_theta(self, x: torch.Tensor, context: torch.Tensor = None): raise NotImplementedError @@ -43,19 +96,10 @@ def __init__(self, output_event_shape, 
n_parameters: int, fill_value: float = No input_event_shape=None, context_shape=None, output_event_shape=output_event_shape, - n_predicted_parameters=n_parameters + n_transformer_parameters=n_parameters, + initial_global_parameter_value=fill_value, + percent_globally_learned_parameters=1.0 ) - if fill_value is None: - initial_theta = torch.randn(size=(*self.output_event_shape, n_parameters,)) - else: - initial_theta = torch.full(size=(*self.output_event_shape, n_parameters), fill_value=fill_value) - self.theta = nn.Parameter(initial_theta) - - def forward(self, x: torch.Tensor, context: torch.Tensor = None): - n_batch_dims = len(x.shape) - len(self.output_event_shape) - n_event_dims = len(self.output_event_shape) - batch_shape = x.shape[:n_batch_dims] - return pad_leading_dims(self.theta, n_batch_dims).repeat(*batch_shape, *([1] * n_event_dims), 1) class MADE(ConditionerTransform): @@ -70,7 +114,7 @@ def forward(self, x): def __init__(self, input_event_shape: torch.Size, output_event_shape: torch.Size, - n_predicted_parameters: int, + n_transformer_parameters: int, context_shape: torch.Size = None, n_hidden: int = None, n_layers: int = 2): @@ -78,7 +122,7 @@ def __init__(self, input_event_shape=input_event_shape, context_shape=context_shape, output_event_shape=output_event_shape, - n_predicted_parameters=n_predicted_parameters + n_transformer_parameters=n_transformer_parameters ) if n_hidden is None: @@ -103,10 +147,10 @@ def __init__(self, layers.extend([ self.MaskedLinear( masks[-1].shape[1], - masks[-1].shape[0] * n_predicted_parameters, - torch.repeat_interleave(masks[-1], n_predicted_parameters, dim=0) + masks[-1].shape[0] * self.n_predicted_parameters, + torch.repeat_interleave(masks[-1], self.n_predicted_parameters, dim=0) ), - nn.Unflatten(dim=-1, unflattened_size=(*output_event_shape, n_predicted_parameters)) + nn.Unflatten(dim=-1, unflattened_size=(*output_event_shape, self.n_predicted_parameters)) ]) self.sequential = nn.Sequential(*layers) @@ -123,23 +167,23 @@ def create_masks(n_layers, ms): masks.append(torch.as_tensor(xx >= yy, dtype=torch.float)) return masks - def forward(self, x: torch.Tensor, context: torch.Tensor = None): + def predict_theta(self, x: torch.Tensor, context: torch.Tensor = None): out = self.sequential(self.context_combiner(x, context)) batch_shape = get_batch_shape(x, self.input_event_shape) return out.view(*batch_shape, *self.output_event_shape, self.n_predicted_parameters) class LinearMADE(MADE): - def __init__(self, input_event_shape: torch.Size, output_event_shape: torch.Size, n_predicted_parameters: int, + def __init__(self, input_event_shape: torch.Size, output_event_shape: torch.Size, n_transformer_parameters: int, **kwargs): - super().__init__(input_event_shape, output_event_shape, n_predicted_parameters, n_layers=1, **kwargs) + super().__init__(input_event_shape, output_event_shape, n_transformer_parameters, n_layers=1, **kwargs) class FeedForward(ConditionerTransform): def __init__(self, input_event_shape: torch.Size, output_event_shape: torch.Size, - n_predicted_parameters: int, + n_transformer_parameters: int, context_shape: torch.Size = None, n_hidden: int = None, n_layers: int = 2): @@ -147,7 +191,7 @@ def __init__(self, input_event_shape=input_event_shape, context_shape=context_shape, output_event_shape=output_event_shape, - n_predicted_parameters=n_predicted_parameters + n_transformer_parameters=n_transformer_parameters ) if n_hidden is None: @@ -161,20 +205,20 @@ def __init__(self, # Check the one layer special case if n_layers == 1: - 
layers.append(nn.Linear(self.n_input_event_dims, self.n_output_event_dims * n_predicted_parameters)) + layers.append(nn.Linear(self.n_input_event_dims, self.n_output_event_dims * n_transformer_parameters)) elif n_layers > 1: layers.extend([nn.Linear(self.n_input_event_dims, n_hidden), nn.Tanh()]) for _ in range(n_layers - 2): layers.extend([nn.Linear(n_hidden, n_hidden), nn.Tanh()]) - layers.append(nn.Linear(n_hidden, self.n_output_event_dims * n_predicted_parameters)) + layers.append(nn.Linear(n_hidden, self.n_output_event_dims * self.n_predicted_parameters)) else: raise ValueError # Reshape the output - layers.append(nn.Unflatten(dim=-1, unflattened_size=(*output_event_shape, n_predicted_parameters))) + layers.append(nn.Unflatten(dim=-1, unflattened_size=(*output_event_shape, self.n_predicted_parameters))) self.sequential = nn.Sequential(*layers) - def forward(self, x: torch.Tensor, context: torch.Tensor = None): + def predict_theta(self, x: torch.Tensor, context: torch.Tensor = None): out = self.sequential(self.context_combiner(x, context)) batch_shape = get_batch_shape(x, self.input_event_shape) return out.view(*batch_shape, *self.output_event_shape, self.n_predicted_parameters) @@ -197,10 +241,11 @@ def forward(self, x): def __init__(self, input_event_shape: torch.Size, output_event_shape: torch.Size, - n_predicted_parameters: int, + n_transformer_parameters: int, context_shape: torch.Size = None, - n_layers: int = 2): - super().__init__(input_event_shape, context_shape, output_event_shape, n_predicted_parameters) + n_layers: int = 2, + **kwargs): + super().__init__(input_event_shape, context_shape, output_event_shape, n_transformer_parameters, **kwargs) # If context given, concatenate it to transform input if context_shape is not None: @@ -210,20 +255,20 @@ def __init__(self, # Check the one layer special case if n_layers == 1: - layers.append(nn.Linear(self.n_input_event_dims, self.n_output_event_dims * n_predicted_parameters)) + layers.append(nn.Linear(self.n_input_event_dims, self.n_output_event_dims * self.n_predicted_parameters)) elif n_layers > 1: layers.extend([self.ResidualLinear(self.n_input_event_dims, self.n_input_event_dims), nn.Tanh()]) for _ in range(n_layers - 2): layers.extend([self.ResidualLinear(self.n_input_event_dims, self.n_input_event_dims), nn.Tanh()]) - layers.append(nn.Linear(self.n_input_event_dims, self.n_output_event_dims * n_predicted_parameters)) + layers.append(nn.Linear(self.n_input_event_dims, self.n_output_event_dims * self.n_predicted_parameters)) else: raise ValueError # Reshape the output - layers.append(nn.Unflatten(dim=-1, unflattened_size=(*output_event_shape, n_predicted_parameters))) + layers.append(nn.Unflatten(dim=-1, unflattened_size=(*output_event_shape, self.n_predicted_parameters))) self.sequential = nn.Sequential(*layers) - def forward(self, x: torch.Tensor, context: torch.Tensor = None): + def predict_theta(self, x: torch.Tensor, context: torch.Tensor = None): out = self.sequential(self.context_combiner(x, context)) batch_shape = get_batch_shape(x, self.input_event_shape) return out.view(*batch_shape, *self.output_event_shape, self.n_predicted_parameters) diff --git a/normalizing_flows/bijections/finite/autoregressive/layers.py b/normalizing_flows/bijections/finite/autoregressive/layers.py index e9e203f..6ab4c79 100644 --- a/normalizing_flows/bijections/finite/autoregressive/layers.py +++ b/normalizing_flows/bijections/finite/autoregressive/layers.py @@ -52,7 +52,7 @@ def __init__(self, conditioner_transform = FeedForward( 
input_event_shape=conditioner.input_shape, output_event_shape=conditioner.output_shape, - n_predicted_parameters=transformer.n_parameters, + n_transformer_parameters=transformer.n_parameters, context_shape=context_shape, **kwargs ) @@ -71,7 +71,7 @@ def __init__(self, conditioner_transform = FeedForward( input_event_shape=conditioner.input_shape, output_event_shape=conditioner.output_shape, - n_predicted_parameters=transformer.n_parameters, + n_transformer_parameters=transformer.n_parameters, context_shape=context_shape, **kwargs ) @@ -88,7 +88,7 @@ def __init__(self, conditioner_transform = FeedForward( input_event_shape=conditioner.input_shape, output_event_shape=conditioner.output_shape, - n_predicted_parameters=transformer.n_parameters, + n_transformer_parameters=transformer.n_parameters, context_shape=context_shape, **kwargs ) @@ -107,7 +107,7 @@ def __init__(self, conditioner_transform = FeedForward( input_event_shape=conditioner.input_shape, output_event_shape=conditioner.output_shape, - n_predicted_parameters=transformer.n_parameters, + n_transformer_parameters=transformer.n_parameters, context_shape=context_shape, **kwargs ) @@ -125,7 +125,7 @@ def __init__(self, conditioner_transform = FeedForward( input_event_shape=conditioner.input_shape, output_event_shape=conditioner.output_shape, - n_predicted_parameters=transformer.n_parameters, + n_transformer_parameters=transformer.n_parameters, context_shape=context_shape, **kwargs ) @@ -145,7 +145,7 @@ def __init__(self, conditioner_transform = FeedForward( input_event_shape=conditioner.input_shape, output_event_shape=conditioner.output_shape, - n_predicted_parameters=transformer.n_parameters, + n_transformer_parameters=transformer.n_parameters, context_shape=context_shape, **kwargs ) @@ -181,7 +181,7 @@ def __init__(self, conditioner_transform = MADE( input_event_shape=event_shape, output_event_shape=event_shape, - n_predicted_parameters=transformer.n_parameters, + n_transformer_parameters=transformer.n_parameters, context_shape=context_shape, **kwargs ) @@ -203,7 +203,7 @@ def __init__(self, conditioner_transform = MADE( input_event_shape=event_shape, output_event_shape=event_shape, - n_predicted_parameters=transformer.n_parameters, + n_transformer_parameters=transformer.n_parameters, context_shape=context_shape, **kwargs ) @@ -224,7 +224,7 @@ def __init__(self, conditioner_transform = MADE( input_event_shape=event_shape, output_event_shape=event_shape, - n_predicted_parameters=transformer.n_parameters, + n_transformer_parameters=transformer.n_parameters, context_shape=context_shape, **kwargs ) @@ -245,7 +245,7 @@ def __init__(self, conditioner_transform = MADE( input_event_shape=event_shape, output_event_shape=event_shape, - n_predicted_parameters=transformer.n_parameters, + n_transformer_parameters=transformer.n_parameters, context_shape=context_shape, **kwargs ) @@ -268,7 +268,7 @@ def __init__(self, conditioner_transform = MADE( input_event_shape=event_shape, output_event_shape=event_shape, - n_predicted_parameters=transformer.n_parameters, + n_transformer_parameters=transformer.n_parameters, context_shape=context_shape, **kwargs ) @@ -295,7 +295,7 @@ def __init__(self, conditioner_transform = MADE( input_event_shape=event_shape, output_event_shape=event_shape, - n_predicted_parameters=transformer.n_parameters, + n_transformer_parameters=transformer.n_parameters, context_shape=context_shape, **kwargs ) diff --git a/test/test_conditioner_transforms.py b/test/test_conditioner_transforms.py index e68514d..c74f0b9 100644 --- 
a/test/test_conditioner_transforms.py +++ b/test/test_conditioner_transforms.py @@ -24,7 +24,7 @@ def test_shape(transform_class, batch_shape, input_event_shape, output_event_sha transform = transform_class( input_event_shape=input_event_shape, output_event_shape=output_event_shape, - n_predicted_parameters=n_predicted_parameters + n_transformer_parameters=n_predicted_parameters ) out = transform(x) assert out.shape == (*batch_shape, *output_event_shape, n_predicted_parameters) From c1cebf01154865c6061843f72748c558c041a784 Mon Sep 17 00:00:00 2001 From: David Nabergoj Date: Wed, 15 Nov 2023 08:45:10 -0800 Subject: [PATCH 11/40] Add keyword arguments to conditioner transforms --- .../autoregressive/conditioner_transforms.py | 28 +++++++++++++++---- 1 file changed, 22 insertions(+), 6 deletions(-) diff --git a/normalizing_flows/bijections/finite/autoregressive/conditioner_transforms.py b/normalizing_flows/bijections/finite/autoregressive/conditioner_transforms.py index 0ac8240..987304e 100644 --- a/normalizing_flows/bijections/finite/autoregressive/conditioner_transforms.py +++ b/normalizing_flows/bijections/finite/autoregressive/conditioner_transforms.py @@ -117,12 +117,14 @@ def __init__(self, n_transformer_parameters: int, context_shape: torch.Size = None, n_hidden: int = None, - n_layers: int = 2): + n_layers: int = 2, + **kwargs): super().__init__( input_event_shape=input_event_shape, context_shape=context_shape, output_event_shape=output_event_shape, - n_transformer_parameters=n_transformer_parameters + n_transformer_parameters=n_transformer_parameters, + **kwargs ) if n_hidden is None: @@ -176,7 +178,13 @@ def predict_theta(self, x: torch.Tensor, context: torch.Tensor = None): class LinearMADE(MADE): def __init__(self, input_event_shape: torch.Size, output_event_shape: torch.Size, n_transformer_parameters: int, **kwargs): - super().__init__(input_event_shape, output_event_shape, n_transformer_parameters, n_layers=1, **kwargs) + super().__init__( + input_event_shape, + output_event_shape, + n_transformer_parameters, + n_layers=1, + **kwargs + ) class FeedForward(ConditionerTransform): @@ -186,12 +194,14 @@ def __init__(self, n_transformer_parameters: int, context_shape: torch.Size = None, n_hidden: int = None, - n_layers: int = 2): + n_layers: int = 2, + **kwargs): super().__init__( input_event_shape=input_event_shape, context_shape=context_shape, output_event_shape=output_event_shape, - n_transformer_parameters=n_transformer_parameters + n_transformer_parameters=n_transformer_parameters, + **kwargs ) if n_hidden is None: @@ -245,7 +255,13 @@ def __init__(self, context_shape: torch.Size = None, n_layers: int = 2, **kwargs): - super().__init__(input_event_shape, context_shape, output_event_shape, n_transformer_parameters, **kwargs) + super().__init__( + input_event_shape, + context_shape, + output_event_shape, + n_transformer_parameters, + **kwargs + ) # If context given, concatenate it to transform input if context_shape is not None: From 453c13f44eb3646b31f0dc422121742ba3b68e49 Mon Sep 17 00:00:00 2001 From: David Nabergoj Date: Wed, 15 Nov 2023 09:21:38 -0800 Subject: [PATCH 12/40] Rename attributes in conditioner_transforms.py, set 0.8 global parameters in CouplingDSF --- .../finite/autoregressive/architectures.py | 4 ++-- .../autoregressive/conditioner_transforms.py | 18 +++++++++--------- 2 files changed, 11 insertions(+), 11 deletions(-) diff --git a/normalizing_flows/bijections/finite/autoregressive/architectures.py 
b/normalizing_flows/bijections/finite/autoregressive/architectures.py index b31c72d..35f56a4 100644 --- a/normalizing_flows/bijections/finite/autoregressive/architectures.py +++ b/normalizing_flows/bijections/finite/autoregressive/architectures.py @@ -166,14 +166,14 @@ def __init__(self, event_shape, n_layers: int = 2, **kwargs): class CouplingDSF(BijectiveComposition): - def __init__(self, event_shape, n_layers: int = 2, **kwargs): + def __init__(self, event_shape, n_layers: int = 2, percent_global_parameters: float = 0.8, **kwargs): if isinstance(event_shape, int): event_shape = (event_shape,) bijections = [ElementwiseAffine(event_shape=event_shape)] for _ in range(n_layers): bijections.extend([ ReversePermutation(event_shape=event_shape), - DSCoupling(event_shape=event_shape) + DSCoupling(event_shape=event_shape, percent_global_parameters=percent_global_parameters) ]) bijections.append(ElementwiseAffine(event_shape=event_shape)) super().__init__(event_shape, bijections, **kwargs) diff --git a/normalizing_flows/bijections/finite/autoregressive/conditioner_transforms.py b/normalizing_flows/bijections/finite/autoregressive/conditioner_transforms.py index 987304e..a3f00d6 100644 --- a/normalizing_flows/bijections/finite/autoregressive/conditioner_transforms.py +++ b/normalizing_flows/bijections/finite/autoregressive/conditioner_transforms.py @@ -25,7 +25,7 @@ def __init__(self, output_event_shape, n_transformer_parameters: int, context_combiner: ContextCombiner = None, - percent_globally_learned_parameters: float = 0.0, + percent_global_parameters: float = 0.0, initial_global_parameter_value: float = None): """ :param input_event_shape: shape of conditioner input tensor x. @@ -33,7 +33,7 @@ def __init__(self, :param output_event_shape: shape of transformer input tensor y. :param n_transformer_parameters: number of parameters required to transform a single element of y. :param context_combiner: ContextCombiner class which defines how to combine x and c to predict theta. - :param percent_globally_learned_parameters: fraction of all parameters in theta that should be learned directly. + :param percent_global_parameters: percent of all parameters in theta to be learned independent of x and c. A value of 0 means the conditioner predicts n_transformer_parameters parameters based on x and c. A value of 1 means the conditioner predicts no parameters based on x and c, but outputs globally learned theta. 
A value of alpha means the conditioner outputs alpha * n_transformer_parameters parameters globally and @@ -56,14 +56,14 @@ def __init__(self, self.n_input_event_dims = self.context_combiner.n_output_dims self.n_output_event_dims = int(torch.prod(torch.as_tensor(output_event_shape))) self.n_transformer_parameters = n_transformer_parameters - self.n_globally_learned_parameters = int(n_transformer_parameters * percent_globally_learned_parameters) - self.n_predicted_parameters = self.n_transformer_parameters - self.n_globally_learned_parameters + self.n_global_parameters = int(n_transformer_parameters * percent_global_parameters) + self.n_predicted_parameters = self.n_transformer_parameters - self.n_global_parameters if initial_global_parameter_value is None: - initial_global_theta = torch.randn(size=(*output_event_shape, self.n_globally_learned_parameters)) + initial_global_theta = torch.randn(size=(*output_event_shape, self.n_global_parameters)) else: initial_global_theta = torch.full( - size=(*output_event_shape, self.n_globally_learned_parameters), + size=(*output_event_shape, self.n_global_parameters), fill_value=initial_global_parameter_value ) self.global_theta = nn.Parameter(initial_global_theta) @@ -72,7 +72,7 @@ def forward(self, x: torch.Tensor, context: torch.Tensor = None): # x.shape = (*batch_shape, *input_event_shape) # context.shape = (*batch_shape, *context_shape) # output.shape = (*batch_shape, *output_event_shape, n_transformer_parameters) - if self.n_globally_learned_parameters == 0: + if self.n_global_parameters == 0: return self.predict_theta(x, context) else: n_batch_dims = len(x.shape) - len(self.output_event_shape) @@ -81,7 +81,7 @@ def forward(self, x: torch.Tensor, context: torch.Tensor = None): batch_global_theta = pad_leading_dims(self.global_theta, n_batch_dims).repeat( *batch_shape, *([1] * n_event_dims), 1 ) - if self.n_globally_learned_parameters == self.n_transformer_parameters: + if self.n_global_parameters == self.n_transformer_parameters: return batch_global_theta else: return torch.cat([batch_global_theta, self.predict_theta(x, context)], dim=-1) @@ -98,7 +98,7 @@ def __init__(self, output_event_shape, n_parameters: int, fill_value: float = No output_event_shape=output_event_shape, n_transformer_parameters=n_parameters, initial_global_parameter_value=fill_value, - percent_globally_learned_parameters=1.0 + percent_global_parameters=1.0 ) From dbd45ce3c49618a135878ec6108d6c183954fc22 Mon Sep 17 00:00:00 2001 From: David Nabergoj Date: Thu, 16 Nov 2023 09:43:10 -0800 Subject: [PATCH 13/40] Conditioner transforms predict a flattened parameter tensor for each input tensor in the batch --- .../autoregressive/conditioner_transforms.py | 93 +++++++++++-------- .../finite/autoregressive/layers.py | 13 +-- 2 files changed, 54 insertions(+), 52 deletions(-) diff --git a/normalizing_flows/bijections/finite/autoregressive/conditioner_transforms.py b/normalizing_flows/bijections/finite/autoregressive/conditioner_transforms.py index a3f00d6..378f845 100644 --- a/normalizing_flows/bijections/finite/autoregressive/conditioner_transforms.py +++ b/normalizing_flows/bijections/finite/autoregressive/conditioner_transforms.py @@ -22,7 +22,6 @@ class ConditionerTransform(nn.Module): def __init__(self, input_event_shape, context_shape, - output_event_shape, n_transformer_parameters: int, context_combiner: ContextCombiner = None, percent_global_parameters: float = 0.0, @@ -30,7 +29,6 @@ def __init__(self, """ :param input_event_shape: shape of conditioner input tensor x. 
:param context_shape: shape of conditioner context tensor c. - :param output_event_shape: shape of transformer input tensor y. :param n_transformer_parameters: number of parameters required to transform a single element of y. :param context_combiner: ContextCombiner class which defines how to combine x and c to predict theta. :param percent_global_parameters: percent of all parameters in theta to be learned independent of x and c. @@ -51,51 +49,62 @@ def __init__(self, # The conditioner transform receives as input the context combiner output self.input_event_shape = input_event_shape - self.output_event_shape = output_event_shape self.context_shape = context_shape self.n_input_event_dims = self.context_combiner.n_output_dims - self.n_output_event_dims = int(torch.prod(torch.as_tensor(output_event_shape))) self.n_transformer_parameters = n_transformer_parameters self.n_global_parameters = int(n_transformer_parameters * percent_global_parameters) self.n_predicted_parameters = self.n_transformer_parameters - self.n_global_parameters if initial_global_parameter_value is None: - initial_global_theta = torch.randn(size=(*output_event_shape, self.n_global_parameters)) + initial_global_theta = torch.randn(size=(self.n_global_parameters,)) else: initial_global_theta = torch.full( - size=(*output_event_shape, self.n_global_parameters), + size=(self.n_global_parameters,), fill_value=initial_global_parameter_value ) self.global_theta = nn.Parameter(initial_global_theta) def forward(self, x: torch.Tensor, context: torch.Tensor = None): - # x.shape = (*batch_shape, *input_event_shape) - # context.shape = (*batch_shape, *context_shape) - # output.shape = (*batch_shape, *output_event_shape, n_transformer_parameters) + """ + Compute parameters theta for each input tensor. + This includes globally learned parameters and parameters which are predicted based on x and context. + + :param x: batch of input tensors with x.shape = (*batch_shape, *self.input_event_shape). + :param context: batch of context tensors with context.shape = (*batch_shape, *self.context_shape). + :return: batch of parameter tensors theta with theta.shape = (*batch_shape, self.n_transformer_parameters). + """ if self.n_global_parameters == 0: return self.predict_theta(x, context) else: - n_batch_dims = len(x.shape) - len(self.output_event_shape) - n_event_dims = len(self.output_event_shape) - batch_shape = x.shape[:n_batch_dims] - batch_global_theta = pad_leading_dims(self.global_theta, n_batch_dims).repeat( - *batch_shape, *([1] * n_event_dims), 1 - ) + batch_shape = get_batch_shape(x, self.input_event_shape) + n_batch_dims = len(batch_shape) + batch_global_theta = pad_leading_dims(self.global_theta, n_batch_dims).repeat(*batch_shape, 1) if self.n_global_parameters == self.n_transformer_parameters: return batch_global_theta else: return torch.cat([batch_global_theta, self.predict_theta(x, context)], dim=-1) def predict_theta(self, x: torch.Tensor, context: torch.Tensor = None): + """ + Predict parameters theta for each input tensor. + Note: this method does not set any global parameters, but instead only predicts parameters from x and context. + + :param x: batch of input tensors with x.shape = (*batch_shape, *self.input_event_shape). + :param context: batch of context tensors with context.shape = (*batch_shape, *self.context_shape). + :return: batch of parameter tensors theta with theta.shape = (*batch_shape, self.n_predicted_parameters). 
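As a rough sketch of how forward combines the two parameter sources under the flattened convention, with illustrative names and sizes and a plain nn.Linear standing in for predict_theta (the patch itself repeats the global slice via pad_leading_dims and repeat; expand is just the simplest equivalent for a flat batch):

```python
import torch
import torch.nn as nn

# Toy stand-in for the global/predicted split: the first n_global entries of theta
# are learned directly, the rest are predicted from x, and forward() concatenates
# the two along the last dimension.
n_transformer_parameters, percent_global = 10, 0.6
n_global = int(n_transformer_parameters * percent_global)
global_theta = nn.Parameter(torch.randn(n_global))
predictor = nn.Linear(5, n_transformer_parameters - n_global)  # plays the role of predict_theta

x = torch.randn(32, 5)  # (batch_size, n_input_event_dims)
theta = torch.cat([global_theta.expand(len(x), -1), predictor(x)], dim=-1)
assert theta.shape == (32, n_transformer_parameters)
```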
+ """ raise NotImplementedError class Constant(ConditionerTransform): - def __init__(self, output_event_shape, n_parameters: int, fill_value: float = None): + """ + Constant conditioner transform, which only uses global parameters theta and no local parameters. + """ + + def __init__(self, input_event_shape, n_parameters: int, fill_value: float = None): super().__init__( - input_event_shape=None, + input_event_shape=input_event_shape, context_shape=None, - output_event_shape=output_event_shape, n_transformer_parameters=n_parameters, initial_global_parameter_value=fill_value, percent_global_parameters=1.0 @@ -103,6 +112,10 @@ def __init__(self, output_event_shape, n_parameters: int, fill_value: float = No class MADE(ConditionerTransform): + """ + Masked autoencoder for distribution estimation. + """ + class MaskedLinear(nn.Linear): def __init__(self, in_features: int, out_features: int, mask: torch.Tensor): super().__init__(in_features=in_features, out_features=out_features) @@ -113,7 +126,7 @@ def forward(self, x): def __init__(self, input_event_shape: torch.Size, - output_event_shape: torch.Size, + n_transformer_parameters: int, context_shape: torch.Size = None, n_hidden: int = None, @@ -122,7 +135,6 @@ def __init__(self, super().__init__( input_event_shape=input_event_shape, context_shape=context_shape, - output_event_shape=output_event_shape, n_transformer_parameters=n_transformer_parameters, **kwargs ) @@ -152,7 +164,7 @@ def __init__(self, masks[-1].shape[0] * self.n_predicted_parameters, torch.repeat_interleave(masks[-1], self.n_predicted_parameters, dim=0) ), - nn.Unflatten(dim=-1, unflattened_size=(*output_event_shape, self.n_predicted_parameters)) + nn.Unflatten(dim=-1, unflattened_size=(self.n_predicted_parameters,)) ]) self.sequential = nn.Sequential(*layers) @@ -170,17 +182,18 @@ def create_masks(n_layers, ms): return masks def predict_theta(self, x: torch.Tensor, context: torch.Tensor = None): - out = self.sequential(self.context_combiner(x, context)) - batch_shape = get_batch_shape(x, self.input_event_shape) - return out.view(*batch_shape, *self.output_event_shape, self.n_predicted_parameters) + return self.sequential(self.context_combiner(x, context)) class LinearMADE(MADE): - def __init__(self, input_event_shape: torch.Size, output_event_shape: torch.Size, n_transformer_parameters: int, + """ + Masked autoencoder for distribution estimation with a single layer. + """ + + def __init__(self, input_event_shape: torch.Size, n_transformer_parameters: int, **kwargs): super().__init__( input_event_shape, - output_event_shape, n_transformer_parameters, n_layers=1, **kwargs @@ -188,9 +201,12 @@ def __init__(self, input_event_shape: torch.Size, output_event_shape: torch.Size class FeedForward(ConditionerTransform): + """ + Feed-forward neural network conditioner transform. 
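To make the MADE construction concrete, the following self-contained check (with a toy strictly lower-triangular mask, not the degrees produced by create_masks) shows that a masked linear layer lets output i depend only on earlier inputs:

import torch
import torch.nn as nn

class ToyMaskedLinear(nn.Linear):
    # Same idea as MaskedLinear above: zero out weights wherever the mask is zero.
    def __init__(self, in_features, out_features, mask):
        super().__init__(in_features, out_features)
        self.register_buffer("mask", mask)

    def forward(self, x):
        return nn.functional.linear(x, self.weight * self.mask, self.bias)

d = 4
# Strictly lower-triangular mask: output i may only see inputs 0..i-1 (autoregressive structure).
mask = torch.tril(torch.ones(d, d), diagonal=-1)
layer = ToyMaskedLinear(d, d, mask)

x = torch.randn(1, d, requires_grad=True)
layer(x)[0, 2].backward()
print(x.grad)  # gradients w.r.t. inputs 2 and 3 are exactly zero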
+ """ + def __init__(self, input_event_shape: torch.Size, - output_event_shape: torch.Size, n_transformer_parameters: int, context_shape: torch.Size = None, n_hidden: int = None, @@ -199,7 +215,6 @@ def __init__(self, super().__init__( input_event_shape=input_event_shape, context_shape=context_shape, - output_event_shape=output_event_shape, n_transformer_parameters=n_transformer_parameters, **kwargs ) @@ -224,22 +239,26 @@ def __init__(self, else: raise ValueError - # Reshape the output - layers.append(nn.Unflatten(dim=-1, unflattened_size=(*output_event_shape, self.n_predicted_parameters))) self.sequential = nn.Sequential(*layers) def predict_theta(self, x: torch.Tensor, context: torch.Tensor = None): - out = self.sequential(self.context_combiner(x, context)) - batch_shape = get_batch_shape(x, self.input_event_shape) - return out.view(*batch_shape, *self.output_event_shape, self.n_predicted_parameters) + return self.sequential(self.context_combiner(x, context)) class Linear(FeedForward): + """ + Linear conditioner transform with the map: theta = a * combiner(x, context) + b. + """ + def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs, n_layers=1) class ResidualFeedForward(ConditionerTransform): + """ + Residual feed-forward neural network conditioner transform. + """ + class ResidualLinear(nn.Module): def __init__(self, n_in, n_out): super().__init__() @@ -250,7 +269,6 @@ def forward(self, x): def __init__(self, input_event_shape: torch.Size, - output_event_shape: torch.Size, n_transformer_parameters: int, context_shape: torch.Size = None, n_layers: int = 2, @@ -258,7 +276,6 @@ def __init__(self, super().__init__( input_event_shape, context_shape, - output_event_shape, n_transformer_parameters, **kwargs ) @@ -280,11 +297,7 @@ def __init__(self, else: raise ValueError - # Reshape the output - layers.append(nn.Unflatten(dim=-1, unflattened_size=(*output_event_shape, self.n_predicted_parameters))) self.sequential = nn.Sequential(*layers) def predict_theta(self, x: torch.Tensor, context: torch.Tensor = None): - out = self.sequential(self.context_combiner(x, context)) - batch_shape = get_batch_shape(x, self.input_event_shape) - return out.view(*batch_shape, *self.output_event_shape, self.n_predicted_parameters) + return self.sequential(self.context_combiner(x, context)) diff --git a/normalizing_flows/bijections/finite/autoregressive/layers.py b/normalizing_flows/bijections/finite/autoregressive/layers.py index 6ab4c79..61df510 100644 --- a/normalizing_flows/bijections/finite/autoregressive/layers.py +++ b/normalizing_flows/bijections/finite/autoregressive/layers.py @@ -51,7 +51,6 @@ def __init__(self, conditioner = Coupling(constants=transformer.default_parameters, event_shape=event_shape) conditioner_transform = FeedForward( input_event_shape=conditioner.input_shape, - output_event_shape=conditioner.output_shape, n_transformer_parameters=transformer.n_parameters, context_shape=context_shape, **kwargs @@ -70,7 +69,6 @@ def __init__(self, conditioner = Coupling(constants=transformer.default_parameters, event_shape=event_shape) conditioner_transform = FeedForward( input_event_shape=conditioner.input_shape, - output_event_shape=conditioner.output_shape, n_transformer_parameters=transformer.n_parameters, context_shape=context_shape, **kwargs @@ -87,7 +85,6 @@ def __init__(self, conditioner = Coupling(constants=transformer.default_parameters, event_shape=event_shape) conditioner_transform = FeedForward( input_event_shape=conditioner.input_shape, - 
output_event_shape=conditioner.output_shape, n_transformer_parameters=transformer.n_parameters, context_shape=context_shape, **kwargs @@ -106,7 +103,6 @@ def __init__(self, conditioner = Coupling(constants=transformer.default_parameters, event_shape=event_shape) conditioner_transform = FeedForward( input_event_shape=conditioner.input_shape, - output_event_shape=conditioner.output_shape, n_transformer_parameters=transformer.n_parameters, context_shape=context_shape, **kwargs @@ -124,7 +120,6 @@ def __init__(self, conditioner = Coupling(constants=transformer.default_parameters, event_shape=event_shape) conditioner_transform = FeedForward( input_event_shape=conditioner.input_shape, - output_event_shape=conditioner.output_shape, n_transformer_parameters=transformer.n_parameters, context_shape=context_shape, **kwargs @@ -144,7 +139,6 @@ def __init__(self, # Each component has parameter order [a_unc, b, w_unc] conditioner_transform = FeedForward( input_event_shape=conditioner.input_shape, - output_event_shape=conditioner.output_shape, n_transformer_parameters=transformer.n_parameters, context_shape=context_shape, **kwargs @@ -180,7 +174,6 @@ def __init__(self, transformer = Affine(event_shape=event_shape) conditioner_transform = MADE( input_event_shape=event_shape, - output_event_shape=event_shape, n_transformer_parameters=transformer.n_parameters, context_shape=context_shape, **kwargs @@ -202,7 +195,6 @@ def __init__(self, transformer = RationalQuadratic(event_shape=event_shape, n_bins=n_bins) conditioner_transform = MADE( input_event_shape=event_shape, - output_event_shape=event_shape, n_transformer_parameters=transformer.n_parameters, context_shape=context_shape, **kwargs @@ -214,6 +206,7 @@ def __init__(self, conditioner_transform=conditioner_transform ) + class LRSForwardMaskedAutoregressive(ForwardMaskedAutoregressiveBijection): def __init__(self, event_shape: torch.Size, @@ -223,7 +216,6 @@ def __init__(self, transformer = LinearRational(event_shape=event_shape, n_bins=n_bins) conditioner_transform = MADE( input_event_shape=event_shape, - output_event_shape=event_shape, n_transformer_parameters=transformer.n_parameters, context_shape=context_shape, **kwargs @@ -244,7 +236,6 @@ def __init__(self, transformer = invert(Affine(event_shape=event_shape)) conditioner_transform = MADE( input_event_shape=event_shape, - output_event_shape=event_shape, n_transformer_parameters=transformer.n_parameters, context_shape=context_shape, **kwargs @@ -267,7 +258,6 @@ def __init__(self, transformer = RationalQuadratic(event_shape=event_shape, n_bins=n_bins) conditioner_transform = MADE( input_event_shape=event_shape, - output_event_shape=event_shape, n_transformer_parameters=transformer.n_parameters, context_shape=context_shape, **kwargs @@ -294,7 +284,6 @@ def __init__(self, ) conditioner_transform = MADE( input_event_shape=event_shape, - output_event_shape=event_shape, n_transformer_parameters=transformer.n_parameters, context_shape=context_shape, **kwargs From 48a14f561a009d56aee53ae02e2fb1ea983feedb Mon Sep 17 00:00:00 2001 From: David Nabergoj Date: Thu, 16 Nov 2023 10:11:48 -0800 Subject: [PATCH 14/40] Add intermediate parameter reshape method to transformers, modify affine transformer parameter count --- .../autoregressive/transformers/affine.py | 89 ++++--------------- .../autoregressive/transformers/base.py | 74 +++++++++++++-- 2 files changed, 83 insertions(+), 80 deletions(-) diff --git a/normalizing_flows/bijections/finite/autoregressive/transformers/affine.py 
b/normalizing_flows/bijections/finite/autoregressive/transformers/affine.py index 4e0b2ba..747c253 100644 --- a/normalizing_flows/bijections/finite/autoregressive/transformers/affine.py +++ b/normalizing_flows/bijections/finite/autoregressive/transformers/affine.py @@ -24,15 +24,12 @@ def __init__(self, event_shape: torch.Size, min_scale: float = 1e-3): @property def n_parameters(self) -> int: - return 2 + return 2 * self.n_dim - @property - def default_parameters(self) -> torch.Tensor: - default_u_alpha = torch.zeros(size=(1,)) - default_u_beta = torch.zeros(size=(1,)) - return torch.cat([default_u_alpha, default_u_beta], dim=0) + def unflatten_conditioner_parameters(self, h: torch.Tensor): + return torch.unflatten(h, dim=-1, sizes=(*self.event_shape, 2)) - def forward(self, x: torch.Tensor, h: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + def forward_base(self, x: torch.Tensor, h: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: u_alpha = h[..., 0] alpha = torch.exp(self.identity_unconstrained_alpha + u_alpha / self.const) + self.m log_alpha = torch.log(alpha) @@ -43,7 +40,7 @@ def forward(self, x: torch.Tensor, h: torch.Tensor) -> Tuple[torch.Tensor, torch log_det = sum_except_batch(log_alpha, self.event_shape) return alpha * x + beta, log_det - def inverse(self, z: torch.Tensor, h: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + def inverse_base(self, z: torch.Tensor, h: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: u_alpha = h[..., 0] alpha = torch.exp(self.identity_unconstrained_alpha + u_alpha / self.const) + self.m log_alpha = torch.log(alpha) @@ -55,80 +52,24 @@ def inverse(self, z: torch.Tensor, h: torch.Tensor) -> Tuple[torch.Tensor, torch return (z - beta) / alpha, log_det -class Affine2(Transformer): - """ - Affine transformer with near-identity initialization. - - Computes z = alpha * x + beta, where alpha > 0 and -inf < beta < inf. - Alpha and beta have the same shape as x, i.e. the computation is performed elementwise. - - In this implementation, we compute alpha and beta such that the initial map is near identity. - We also use a minimum permitted scale m, 0 < m <= alpha, for numerical stability - This means setting alpha = 1 + d(u_alpha) where -1 + m < d(u_alpha) < inf. - - We can verify that the following construction for function d is suitable: - * g(u) = u / c + log(-log(m)); c>=1, 0 0 is enough, but c >= 1 desirably reduces input magnitude - * f(u) = exp(g(u)) + log(m) - * d(u) = exp(f(u)) - 1 - - A change in u implies a change in log(log(d)), we may need a bigger step size when optimizing parameters of - overarching bijections that use this transformer. 
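As an aside, the "minimum permitted scale with near-identity initialization" idea can be sketched with a simplified parametrization (this is an illustrative variant; the constants and exact construction in the code above differ):

import torch

min_scale = 1e-3

def scale_and_shift(h):
    # h[..., 0] is the unconstrained scale, h[..., 1] the shift.
    # exp(u) * (1 - m) + m is always greater than m and equals 1 at u = 0,
    # so all-zero parameters give a (near-)identity map.
    alpha = torch.exp(h[..., 0]) * (1 - min_scale) + min_scale
    beta = h[..., 1]
    return alpha, beta

h = torch.zeros(5, 2)                       # "default" parameters
alpha, beta = scale_and_shift(h)
x = torch.randn(5)
z = alpha * x + beta
assert torch.allclose(z, x, atol=1e-5)      # zero parameters act (almost) as the identity
log_det = torch.log(alpha).sum()            # log|det J| of the elementwise affine map, ~0 here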
- """ - - def __init__(self, event_shape: torch.Size, min_scale: float = 1e-3, **kwargs): - super().__init__(event_shape=event_shape) - assert 0 < min_scale < 1 - self.m = min_scale - self.log_m = math.log(self.m) - self.log_neg_log_m = math.log(-self.log_m) - self.c = 100.0 - - def compute_scale_and_shift(self, h): - u_alpha = h[..., 0] - g_alpha = u_alpha / self.c + self.log_neg_log_m - f_alpha = torch.exp(g_alpha) + self.log_m - d_alpha = torch.exp(f_alpha) - 1 - # d_alpha = self.m ** (1 - torch.exp(u_alpha / self.c)) - 1 # Rewritten - # d_alpha = self.m * self.m ** (-torch.exp(u_alpha / self.c)) - 1 # Rewritten again - - alpha = 1 + d_alpha - - u_beta = h[..., 1] - d_beta = u_beta / self.c - beta = 0 + d_beta - - return alpha, beta - - def forward(self, x: torch.Tensor, h: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: - alpha, beta = self.compute_scale_and_shift(h) - log_det = sum_except_batch(torch.log(alpha), self.event_shape) - return alpha * x + beta, log_det - - def inverse(self, z: torch.Tensor, h: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: - alpha, beta = self.compute_scale_and_shift(h) - log_det = -sum_except_batch(torch.log(alpha), self.event_shape) - return (z - beta) / alpha, log_det - - class Shift(Transformer): def __init__(self, event_shape: torch.Size): super().__init__(event_shape=event_shape) @property def n_parameters(self) -> int: - return 1 + return 1 * self.n_dim - @property - def default_parameters(self) -> torch.Tensor: - return torch.zeros(size=(1,)) + def unflatten_conditioner_parameters(self, h: torch.Tensor): + return torch.unflatten(h, dim=-1, sizes=(*self.event_shape, 1)) - def forward(self, x: torch.Tensor, h: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + def forward_base(self, x: torch.Tensor, h: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: beta = h[..., 0] batch_shape = get_batch_shape(x, self.event_shape) log_det = torch.zeros(batch_shape, device=x.device) return x + beta, log_det - def inverse(self, z: torch.Tensor, h: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + def inverse_base(self, z: torch.Tensor, h: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: beta = h[..., 0] batch_shape = get_batch_shape(z, self.event_shape) log_det = torch.zeros(batch_shape, device=z.device) @@ -151,12 +92,12 @@ def __init__(self, event_shape: torch.Size, min_scale: float = 1e-3): @property def n_parameters(self) -> int: - return 1 + return 1 * self.n_dim - def default_parameters(self) -> torch.Tensor: - return torch.zeros(size=(1,)) + def unflatten_conditioner_parameters(self, h: torch.Tensor): + return torch.unflatten(h, dim=-1, sizes=(*self.event_shape, 1)) - def forward(self, x: torch.Tensor, h: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + def forward_base(self, x: torch.Tensor, h: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: u_alpha = h[..., 0] alpha = torch.exp(self.u_alpha_1 + u_alpha / self.const) + self.m log_alpha = torch.log(alpha) @@ -164,7 +105,7 @@ def forward(self, x: torch.Tensor, h: torch.Tensor) -> Tuple[torch.Tensor, torch log_det = sum_except_batch(log_alpha, self.event_shape) return alpha * x, log_det - def inverse(self, z: torch.Tensor, h: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + def inverse_base(self, z: torch.Tensor, h: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: u_alpha = h[..., 0] alpha = torch.exp(self.u_alpha_1 + u_alpha / self.const) + self.m log_alpha = torch.log(alpha) diff --git a/normalizing_flows/bijections/finite/autoregressive/transformers/base.py 
b/normalizing_flows/bijections/finite/autoregressive/transformers/base.py index b23ecdf..da2d377 100644 --- a/normalizing_flows/bijections/finite/autoregressive/transformers/base.py +++ b/normalizing_flows/bijections/finite/autoregressive/transformers/base.py @@ -6,27 +6,89 @@ class Transformer(Bijection): + """ + Base transformer class. + + A transformer receives as input a batch of tensors x with x.shape = (*batch_shape, *event_shape) and a + corresponding batch of parameter tensors h with h.shape = (*batch_shape, self.n_parameters). It outputs transformed + tensors z with z.shape = (*batch_shape, *event_shape). + Given parameters h, a transformer is bijective. Transformers apply bijections of the same kind to a batch of inputs, + but use different parameters for each input. + + When implementing new transformers, consider the self.unflatten_conditioner_parameters method, which is used to + optionally reshape transformer parameters into a suitable shape. + """ + def __init__(self, event_shape: Union[torch.Size, Tuple[int, ...]]): super().__init__(event_shape=event_shape) + def forward_base(self, x: torch.Tensor, h: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Apply forward bijection to a batch of inputs x, parameterizing each bijection with the corresponding parameter + tensor in h. + + :param x: input tensor with x.shape = (*batch_shape, *event_shape). + :param h: parameter tensor with h.shape = (*batch_shape, *parameter_shape). + """ + raise NotImplementedError + def forward(self, x: torch.Tensor, h: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Apply forward bijection to a batch of inputs x, parameterizing each bijection with the corresponding parameter + tensor in h. + + :param x: input tensor with x.shape = (*batch_shape, *event_shape). + :param h: parameter tensor with h.shape = (*batch_shape, self.n_parameters). + """ + return self.forward_base(x, self.unflatten_conditioner_parameters(h)) + + def inverse_base(self, x: torch.Tensor, h: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Apply inverse bijection to a batch of inputs x, parameterizing each bijection with the corresponding parameter + tensor in h. + + :param x: input tensor with x.shape = (*batch_shape, *event_shape). + :param h: parameter tensor with h.shape = (*batch_shape, *parameter_shape). + """ raise NotImplementedError def inverse(self, x: torch.Tensor, h: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: - raise NotImplementedError + """ + Apply inverse bijection to a batch of inputs x, parameterizing each bijection with the corresponding parameter + tensor in h. + + :param x: input tensor with x.shape = (*batch_shape, *event_shape). + :param h: parameter tensor with h.shape = (*batch_shape, self.n_parameters). + """ + return self.inverse_base(x, self.unflatten_conditioner_parameters(h)) + + def unflatten_conditioner_parameters(self, h: torch.Tensor): + """ + Reshapes parameter tensors as predicted by the conditioner. + The new shape facilitates operations in the transformer and facilitates transformer operations. + If this method is not overwritten, the default parameter tensor shape is kept. + + :param h: batch of parameter tensors for each input event with shape (*batch_shape, self.n_parameters). + :return: batch of parameter tensors for each input event with shape (*batch_shape, *new_shape) + """ + return h @property def n_parameters(self) -> int: """ - Number of parameters that parametrize this transformer. 
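The flattened-parameter convention introduced here can be illustrated with a short standalone sketch for an affine-style scalar transformer (toy shapes, plain torch.unflatten; not the package's classes):

import torch

event_shape = (3, 4)                               # toy event shape
d = 3 * 4                                          # number of event elements
params_per_element = 2                             # e.g. unconstrained scale and shift
n_parameters = params_per_element * d              # flat size the conditioner must output per event

batch_shape = (16,)
h_flat = torch.zeros(*batch_shape, n_parameters)   # default parameters: all zeros
# unflatten_conditioner_parameters: (*batch, n_parameters) -> (*batch, *event_shape, params_per_element)
h = torch.unflatten(h_flat, dim=-1, sizes=(*event_shape, params_per_element))

x = torch.randn(*batch_shape, *event_shape)
alpha = torch.exp(h[..., 0])                        # exp(0) = 1
beta = h[..., 1]                                    # 0
z = alpha * x + beta                                # identity map for the default parameters
log_det = torch.log(alpha).flatten(start_dim=len(batch_shape)).sum(dim=-1)  # sum over event dims
assert z.shape == x.shape and log_det.shape == batch_shape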
Example: rational quadratic splines require 3*b-1 where - b is the number of bins. An affine transformation requires 2 (typically corresponding to the unconstrained scale - and shift). + Number of parameters that parametrize this transformer. + + Examples: + * Rational quadratic splines require (3 * b - 1) * d where b is the number of bins and d is the + dimensionality, equal to the product of all dimensions of the transformer input tensor. + * An affine transformation requires 2 * d (typically corresponding to the unconstrained scale and shift). """ raise NotImplementedError @property def default_parameters(self) -> torch.Tensor: """ - Set of parameters which ensures an identity transformation. + Set of parameters which yields a close-to-identity transformation. + These are set to 0 by default. """ - raise NotImplementedError + return torch.zeros(size=(self.n_parameters,)) From f5c593e0e06834bdadc951f8be4496eb5f5d2286 Mon Sep 17 00:00:00 2001 From: David Nabergoj Date: Thu, 16 Nov 2023 10:37:06 -0800 Subject: [PATCH 15/40] Add forward_base and inverse_base methods that receive reshaped parameter tensors as input --- .../autoregressive/conditioner_transforms.py | 17 +++++++++++------ .../autoregressive/conditioners/coupling.py | 2 ++ .../finite/autoregressive/layers_base.py | 6 ++++-- .../autoregressive/transformers/spline/base.py | 4 ++-- 4 files changed, 19 insertions(+), 10 deletions(-) diff --git a/normalizing_flows/bijections/finite/autoregressive/conditioner_transforms.py b/normalizing_flows/bijections/finite/autoregressive/conditioner_transforms.py index 378f845..7398cd0 100644 --- a/normalizing_flows/bijections/finite/autoregressive/conditioner_transforms.py +++ b/normalizing_flows/bijections/finite/autoregressive/conditioner_transforms.py @@ -126,7 +126,7 @@ def forward(self, x): def __init__(self, input_event_shape: torch.Size, - + output_event_shape: torch.Size, n_transformer_parameters: int, context_shape: torch.Size = None, n_hidden: int = None, @@ -138,6 +138,7 @@ def __init__(self, n_transformer_parameters=n_transformer_parameters, **kwargs ) + self.n_output_event_dims = int(torch.prod(torch.as_tensor(output_event_shape))) if n_hidden is None: n_hidden = max(int(3 * math.log10(self.n_input_event_dims)), 4) @@ -190,10 +191,14 @@ class LinearMADE(MADE): Masked autoencoder for distribution estimation with a single layer. 
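A small example of the per-output parameter layout used in MADE's final layer (toy mask, not the actual degree-based masks): each output element predicts several parameters, and all of them reuse that element's autoregressive mask row via repeat_interleave.

import torch

d_out, d_hidden, params_per_output = 3, 5, 2
# Toy mask for the last layer: row i marks which hidden units output element i may use.
last_mask = torch.tril(torch.ones(d_out, d_hidden))
# Duplicate each row params_per_output times so every parameter of element i shares its mask.
expanded = torch.repeat_interleave(last_mask, params_per_output, dim=0)
assert expanded.shape == (d_out * params_per_output, d_hidden)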
""" - def __init__(self, input_event_shape: torch.Size, n_transformer_parameters: int, + def __init__(self, + input_event_shape: torch.Size, + output_event_shape: torch.Size, + n_transformer_parameters: int, **kwargs): super().__init__( input_event_shape, + output_event_shape, n_transformer_parameters, n_layers=1, **kwargs @@ -230,12 +235,12 @@ def __init__(self, # Check the one layer special case if n_layers == 1: - layers.append(nn.Linear(self.n_input_event_dims, self.n_output_event_dims * n_transformer_parameters)) + layers.append(nn.Linear(self.n_input_event_dims, n_transformer_parameters)) elif n_layers > 1: layers.extend([nn.Linear(self.n_input_event_dims, n_hidden), nn.Tanh()]) for _ in range(n_layers - 2): layers.extend([nn.Linear(n_hidden, n_hidden), nn.Tanh()]) - layers.append(nn.Linear(n_hidden, self.n_output_event_dims * self.n_predicted_parameters)) + layers.append(nn.Linear(n_hidden, self.n_predicted_parameters)) else: raise ValueError @@ -288,12 +293,12 @@ def __init__(self, # Check the one layer special case if n_layers == 1: - layers.append(nn.Linear(self.n_input_event_dims, self.n_output_event_dims * self.n_predicted_parameters)) + layers.append(nn.Linear(self.n_input_event_dims, self.n_predicted_parameters)) elif n_layers > 1: layers.extend([self.ResidualLinear(self.n_input_event_dims, self.n_input_event_dims), nn.Tanh()]) for _ in range(n_layers - 2): layers.extend([self.ResidualLinear(self.n_input_event_dims, self.n_input_event_dims), nn.Tanh()]) - layers.append(nn.Linear(self.n_input_event_dims, self.n_output_event_dims * self.n_predicted_parameters)) + layers.append(nn.Linear(self.n_input_event_dims, self.n_predicted_parameters)) else: raise ValueError diff --git a/normalizing_flows/bijections/finite/autoregressive/conditioners/coupling.py b/normalizing_flows/bijections/finite/autoregressive/conditioners/coupling.py index 2805573..65df293 100644 --- a/normalizing_flows/bijections/finite/autoregressive/conditioners/coupling.py +++ b/normalizing_flows/bijections/finite/autoregressive/conditioners/coupling.py @@ -12,6 +12,8 @@ def __init__(self, event_shape: torch.Size, constants: torch.Tensor): """ Coupling conditioner. + Warning: at the moment, this only works for elementwise conditioners (which predict a set of parameters for each + element in the event). Note: Always treats the first n_dim // 2 dimensions as constant. Shuffling is handled in Permutation bijections. 
diff --git a/normalizing_flows/bijections/finite/autoregressive/layers_base.py b/normalizing_flows/bijections/finite/autoregressive/layers_base.py index beddbb6..e4369bb 100644 --- a/normalizing_flows/bijections/finite/autoregressive/layers_base.py +++ b/normalizing_flows/bijections/finite/autoregressive/layers_base.py @@ -40,13 +40,15 @@ def __init__(self, conditioner: Coupling, transformer: Transformer, conditioner_ def forward(self, x: torch.Tensor, context: torch.Tensor = None) -> Tuple[torch.Tensor, torch.Tensor]: z = x.clone() h, mask = self.conditioner(x, self.conditioner_transform, context, return_mask=True) - z[..., ~mask], log_det = self.transformer.forward(x[..., ~mask], h[..., ~mask, :]) + z[..., ~mask], log_det = self.transformer.forward_base(x[..., ~mask], h[..., ~mask, :]) + # TODO make this work with self.transformer.forward return z, log_det def inverse(self, z: torch.Tensor, context: torch.Tensor = None) -> Tuple[torch.Tensor, torch.Tensor]: x = z.clone() h, mask = self.conditioner(z, self.conditioner_transform, context, return_mask=True) - x[..., ~mask], log_det = self.transformer.inverse(z[..., ~mask], h[..., ~mask, :]) + x[..., ~mask], log_det = self.transformer.inverse_base(z[..., ~mask], h[..., ~mask, :]) + # TODO make this work with self.transformer.inverse return x, log_det diff --git a/normalizing_flows/bijections/finite/autoregressive/transformers/spline/base.py b/normalizing_flows/bijections/finite/autoregressive/transformers/spline/base.py index 2caa2dc..f4dc654 100644 --- a/normalizing_flows/bijections/finite/autoregressive/transformers/spline/base.py +++ b/normalizing_flows/bijections/finite/autoregressive/transformers/spline/base.py @@ -50,7 +50,7 @@ def compute_knots(self, u_x, u_y): def forward_1d(self, x, h): raise NotImplementedError - def forward(self, x: torch.Tensor, h: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + def forward_base(self, x: torch.Tensor, h: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: z = torch.clone(x) # Remain the same out of bounds log_det = torch.zeros_like(z) # y = x means gradient = 1 or log gradient = 0 out of bounds mask = self.forward_inputs_inside_bounds_mask(x) @@ -62,7 +62,7 @@ def forward(self, x: torch.Tensor, h: torch.Tensor) -> Tuple[torch.Tensor, torch def inverse_1d(self, z, h): raise NotImplementedError - def inverse(self, z: torch.Tensor, h: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + def inverse_base(self, z: torch.Tensor, h: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: x = torch.clone(z) # Remain the same out of bounds log_det = torch.zeros_like(x) # y = x means gradient = 1 or log gradient = 0 out of bounds mask = self.inverse_inputs_inside_bounds_mask(z) From 323f52e5d282ed171f1c989b685b22198f0223e3 Mon Sep 17 00:00:00 2001 From: David Nabergoj Date: Thu, 16 Nov 2023 13:57:23 -0800 Subject: [PATCH 16/40] Split CouplingConditioner functionality into CouplingBijection and CouplingMask. Many changes. 
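The half-split partition used by the coupling layers can be sketched in a few lines (standalone code mirroring the idea of the HalfSplit mask added below, not importing it):

import torch

event_shape = (2, 3)                                    # toy event shape
event_size = 2 * 3
constant_size = event_size // 2                         # first half of the flattened event stays fixed

# Boolean mask over the event: True marks the constant part x_A, False the transformed part x_B.
mask = torch.less(torch.arange(event_size).view(*event_shape), constant_size)

x = torch.randn(4, *event_shape)                        # batch of 4 events
x_a = x[..., mask]                                      # (4, constant_size): input to the conditioner
x_b = x[..., ~mask]                                     # (4, event_size - constant_size): gets transformed
assert x_a.shape == (4, constant_size) and x_b.shape == (4, event_size - constant_size)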
--- .../autoregressive/conditioners/coupling.py | 66 ---------------- .../conditioners/coupling_masks.py | 44 +++++++++++ .../finite/autoregressive/layers.py | 56 ++++++++------ .../finite/autoregressive/layers_base.py | 77 ++++++++++++++----- .../transformers/spline/base.py | 2 +- .../transformers/spline/cubic.py | 7 +- .../transformers/spline/linear.py | 7 +- .../transformers/spline/linear_rational.py | 12 +-- .../transformers/spline/rational_quadratic.py | 10 +-- 9 files changed, 146 insertions(+), 135 deletions(-) delete mode 100644 normalizing_flows/bijections/finite/autoregressive/conditioners/coupling.py create mode 100644 normalizing_flows/bijections/finite/autoregressive/conditioners/coupling_masks.py diff --git a/normalizing_flows/bijections/finite/autoregressive/conditioners/coupling.py b/normalizing_flows/bijections/finite/autoregressive/conditioners/coupling.py deleted file mode 100644 index 65df293..0000000 --- a/normalizing_flows/bijections/finite/autoregressive/conditioners/coupling.py +++ /dev/null @@ -1,66 +0,0 @@ -from typing import Union, Tuple - -import torch - -from normalizing_flows.bijections.finite.autoregressive.conditioners.base import Conditioner -from normalizing_flows.bijections.finite.autoregressive.conditioner_transforms import ConditionerTransform -from normalizing_flows.utils import get_batch_shape - - -class Coupling(Conditioner): - def __init__(self, event_shape: torch.Size, constants: torch.Tensor): - """ - Coupling conditioner. - - Warning: at the moment, this only works for elementwise conditioners (which predict a set of parameters for each - element in the event). - - Note: Always treats the first n_dim // 2 dimensions as constant. - Shuffling is handled in Permutation bijections. - - :param constants: - """ - super().__init__() - self.event_shape = event_shape - - # TODO add support for other kinds of masks - n_total_dims = int(torch.prod(torch.tensor(event_shape))) - self.n_constant_dims = n_total_dims // 2 - self.n_changed_dims = n_total_dims - self.n_constant_dims - - self.constant_mask = torch.less(torch.arange(n_total_dims).view(*event_shape), self.n_constant_dims) - self.register_buffer('constants', constants) # Takes care of torch devices - - @property - @torch.no_grad() - def input_shape(self): - return (int(torch.sum(self.constant_mask)),) - - @property - @torch.no_grad() - def output_shape(self): - return (int(torch.sum(~self.constant_mask)),) - - def forward(self, - x: torch.Tensor, - transform: ConditionerTransform, - context: torch.Tensor = None, - return_mask: bool = False) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: - # Predict transformer parameters for output dimensions - batch_shape = get_batch_shape(x, self.event_shape) - x_const = x.view(*batch_shape, *self.event_shape)[..., self.constant_mask] - tmp = transform(x_const, context=context) - n_parameters = tmp.shape[-1] - - # Create full parameter tensor - h = torch.empty(size=(*batch_shape, *self.event_shape, n_parameters), dtype=x.dtype, device=x.device) - - # Fill the parameter tensor with predicted values - h[..., ~self.constant_mask, :] = tmp - h[..., self.constant_mask, :] = self.constants - - if return_mask: - # Return the parameters for the to-be-transformed partition and the partition mask itself - return h, self.constant_mask - - return h diff --git a/normalizing_flows/bijections/finite/autoregressive/conditioners/coupling_masks.py b/normalizing_flows/bijections/finite/autoregressive/conditioners/coupling_masks.py new file mode 100644 index 
0000000..04ddc8f --- /dev/null +++ b/normalizing_flows/bijections/finite/autoregressive/conditioners/coupling_masks.py @@ -0,0 +1,44 @@ +import torch + + +class CouplingMask: + """ + Base object which holds coupling partition mask information. + """ + + def __init__(self, event_shape): + self.event_shape = event_shape + self.event_size = int(torch.prod(torch.as_tensor(self.event_shape))) + + @property + def mask(self): + raise NotImplementedError + + @property + def constant_event_size(self): + raise NotImplementedError + + @property + def transformed_event_size(self): + raise NotImplementedError + + +class HalfSplit(CouplingMask): + def __init__(self, event_shape): + super().__init__(event_shape) + self.event_partition_mask = torch.less( + torch.arange(self.event_size).view(*self.event_shape), + self.constant_event_size + ) + + @property + def constant_event_size(self): + return self.event_size // 2 + + @property + def transformed_event_size(self): + return self.event_size - self.constant_event_size + + @property + def mask(self): + return self.event_partition_mask diff --git a/normalizing_flows/bijections/finite/autoregressive/layers.py b/normalizing_flows/bijections/finite/autoregressive/layers.py index 61df510..6caa572 100644 --- a/normalizing_flows/bijections/finite/autoregressive/layers.py +++ b/normalizing_flows/bijections/finite/autoregressive/layers.py @@ -1,7 +1,7 @@ import torch from normalizing_flows.bijections.finite.autoregressive.conditioner_transforms import MADE, FeedForward -from normalizing_flows.bijections.finite.autoregressive.conditioners.coupling import Coupling +from normalizing_flows.bijections.finite.autoregressive.conditioners.coupling_masks import HalfSplit from normalizing_flows.bijections.finite.autoregressive.conditioners.masked import MaskedAutoregressive from normalizing_flows.bijections.finite.autoregressive.layers_base import ForwardMaskedAutoregressiveBijection, \ InverseMaskedAutoregressiveBijection, ElementwiseBijection, CouplingBijection @@ -47,15 +47,15 @@ def __init__(self, **kwargs): if event_shape == (1,): raise ValueError - transformer = Affine(event_shape=event_shape) - conditioner = Coupling(constants=transformer.default_parameters, event_shape=event_shape) + coupling_mask = HalfSplit(event_shape) + transformer = Affine(event_shape=(coupling_mask.transformed_event_size,)) conditioner_transform = FeedForward( - input_event_shape=conditioner.input_shape, + input_event_shape=torch.Size((coupling_mask.constant_event_size,)), n_transformer_parameters=transformer.n_parameters, context_shape=context_shape, **kwargs ) - super().__init__(conditioner, transformer, conditioner_transform) + super().__init__(transformer, coupling_mask, conditioner_transform) class InverseAffineCoupling(CouplingBijection): @@ -65,15 +65,15 @@ def __init__(self, **kwargs): if event_shape == (1,): raise ValueError - transformer = Affine(event_shape=event_shape).invert() - conditioner = Coupling(constants=transformer.default_parameters, event_shape=event_shape) + coupling_mask = HalfSplit(event_shape) + transformer = Affine(event_shape=(coupling_mask.transformed_event_size,)).invert() conditioner_transform = FeedForward( - input_event_shape=conditioner.input_shape, + input_event_shape=torch.Size((coupling_mask.constant_event_size,)), n_transformer_parameters=transformer.n_parameters, context_shape=context_shape, **kwargs ) - super().__init__(conditioner, transformer, conditioner_transform) + super().__init__(transformer, coupling_mask, conditioner_transform) class 
ShiftCoupling(CouplingBijection): @@ -81,15 +81,15 @@ def __init__(self, event_shape: torch.Size, context_shape: torch.Size = None, **kwargs): - transformer = Shift(event_shape=event_shape) - conditioner = Coupling(constants=transformer.default_parameters, event_shape=event_shape) + coupling_mask = HalfSplit(event_shape) + transformer = Shift(event_shape=(coupling_mask.transformed_event_size,)) conditioner_transform = FeedForward( - input_event_shape=conditioner.input_shape, + input_event_shape=torch.Size((coupling_mask.constant_event_size,)), n_transformer_parameters=transformer.n_parameters, context_shape=context_shape, **kwargs ) - super().__init__(conditioner, transformer, conditioner_transform) + super().__init__(transformer, coupling_mask, conditioner_transform) class LRSCoupling(CouplingBijection): @@ -99,15 +99,15 @@ def __init__(self, n_bins: int = 8, **kwargs): assert n_bins >= 1 - transformer = LinearRational(event_shape=event_shape, n_bins=n_bins) - conditioner = Coupling(constants=transformer.default_parameters, event_shape=event_shape) + coupling_mask = HalfSplit(event_shape) + transformer = LinearRational(event_shape=(coupling_mask.transformed_event_size,), n_bins=n_bins) conditioner_transform = FeedForward( - input_event_shape=conditioner.input_shape, + input_event_shape=torch.Size((coupling_mask.constant_event_size,)), n_transformer_parameters=transformer.n_parameters, context_shape=context_shape, **kwargs ) - super().__init__(conditioner, transformer, conditioner_transform) + super().__init__(transformer, coupling_mask, conditioner_transform) class RQSCoupling(CouplingBijection): @@ -116,15 +116,15 @@ def __init__(self, context_shape: torch.Size = None, n_bins: int = 8, **kwargs): - transformer = RationalQuadratic(event_shape=event_shape, n_bins=n_bins) - conditioner = Coupling(constants=transformer.default_parameters, event_shape=event_shape) + coupling_mask = HalfSplit(event_shape) + transformer = RationalQuadratic(event_shape=(coupling_mask.transformed_event_size,), n_bins=n_bins) conditioner_transform = FeedForward( - input_event_shape=conditioner.input_shape, + input_event_shape=torch.Size((coupling_mask.constant_event_size,)), n_transformer_parameters=transformer.n_parameters, context_shape=context_shape, **kwargs ) - super().__init__(conditioner, transformer, conditioner_transform) + super().__init__(transformer, coupling_mask, conditioner_transform) class DSCoupling(CouplingBijection): @@ -133,17 +133,17 @@ def __init__(self, context_shape: torch.Size = None, n_hidden_layers: int = 2, **kwargs): - transformer = DeepSigmoid(event_shape=event_shape, n_hidden_layers=n_hidden_layers) - conditioner = Coupling(constants=transformer.default_parameters, event_shape=event_shape) + coupling_mask = HalfSplit(event_shape) + transformer = DeepSigmoid(event_shape=(coupling_mask.transformed_event_size,), n_hidden_layers=n_hidden_layers) # Parameter order: [c1, c2, c3, c4, ..., ck] for all components # Each component has parameter order [a_unc, b, w_unc] conditioner_transform = FeedForward( - input_event_shape=conditioner.input_shape, + input_event_shape=torch.Size((coupling_mask.constant_event_size,)), n_transformer_parameters=transformer.n_parameters, context_shape=context_shape, **kwargs ) - super().__init__(conditioner, transformer, conditioner_transform) + super().__init__(transformer, coupling_mask, conditioner_transform) class LinearAffineCoupling(AffineCoupling): @@ -174,6 +174,7 @@ def __init__(self, transformer = Affine(event_shape=event_shape) conditioner_transform = 
MADE( input_event_shape=event_shape, + output_event_shape=event_shape, n_transformer_parameters=transformer.n_parameters, context_shape=context_shape, **kwargs @@ -195,6 +196,7 @@ def __init__(self, transformer = RationalQuadratic(event_shape=event_shape, n_bins=n_bins) conditioner_transform = MADE( input_event_shape=event_shape, + output_event_shape=event_shape, n_transformer_parameters=transformer.n_parameters, context_shape=context_shape, **kwargs @@ -216,6 +218,7 @@ def __init__(self, transformer = LinearRational(event_shape=event_shape, n_bins=n_bins) conditioner_transform = MADE( input_event_shape=event_shape, + output_event_shape=event_shape, n_transformer_parameters=transformer.n_parameters, context_shape=context_shape, **kwargs @@ -236,6 +239,7 @@ def __init__(self, transformer = invert(Affine(event_shape=event_shape)) conditioner_transform = MADE( input_event_shape=event_shape, + output_event_shape=event_shape, n_transformer_parameters=transformer.n_parameters, context_shape=context_shape, **kwargs @@ -258,6 +262,7 @@ def __init__(self, transformer = RationalQuadratic(event_shape=event_shape, n_bins=n_bins) conditioner_transform = MADE( input_event_shape=event_shape, + output_event_shape=event_shape, n_transformer_parameters=transformer.n_parameters, context_shape=context_shape, **kwargs @@ -284,6 +289,7 @@ def __init__(self, ) conditioner_transform = MADE( input_event_shape=event_shape, + output_event_shape=event_shape, n_transformer_parameters=transformer.n_parameters, context_shape=context_shape, **kwargs diff --git a/normalizing_flows/bijections/finite/autoregressive/layers_base.py b/normalizing_flows/bijections/finite/autoregressive/layers_base.py index e4369bb..0e3312c 100644 --- a/normalizing_flows/bijections/finite/autoregressive/layers_base.py +++ b/normalizing_flows/bijections/finite/autoregressive/layers_base.py @@ -1,18 +1,23 @@ -from typing import Tuple +from typing import Tuple, Optional import torch from normalizing_flows.bijections.finite.autoregressive.conditioners.base import Conditioner, NullConditioner from normalizing_flows.bijections.finite.autoregressive.conditioner_transforms import ConditionerTransform, Constant -from normalizing_flows.bijections.finite.autoregressive.conditioners.coupling import Coupling +from normalizing_flows.bijections.finite.autoregressive.conditioners.coupling_masks import CouplingMask from normalizing_flows.bijections.finite.autoregressive.transformers.base import Transformer from normalizing_flows.bijections.base import Bijection from normalizing_flows.utils import flatten_event, unflatten_event, get_batch_shape class AutoregressiveBijection(Bijection): - def __init__(self, conditioner: Conditioner, transformer: Transformer, conditioner_transform: ConditionerTransform): - super().__init__(event_shape=transformer.event_shape) + def __init__(self, + event_shape, + conditioner: Optional[Conditioner], + transformer: Transformer, + conditioner_transform: ConditionerTransform, + **kwargs): + super().__init__(event_shape=event_shape) self.conditioner = conditioner self.conditioner_transform = conditioner_transform self.transformer = transformer @@ -29,32 +34,65 @@ def inverse(self, z: torch.Tensor, context: torch.Tensor = None) -> Tuple[torch. class CouplingBijection(AutoregressiveBijection): - def __init__(self, conditioner: Coupling, transformer: Transformer, conditioner_transform: ConditionerTransform, - **kwargs): - super().__init__(conditioner, transformer, conditioner_transform, **kwargs) + """ + Base coupling bijection object. 
+ + A coupling bijection is defined using a transformer, conditioner transform, and always a coupling conditioner. - # We need to change the transformer event shape because it will no longer accept full-shaped events, but only - # a flattened selection of event dimensions. - self.transformer.event_shape = torch.Size((self.conditioner.n_changed_dims,)) + The coupling conditioner receives as input an event tensor x. + It then partitions an input event tensor x into a constant part x_A and a modifiable part x_B. + For x_A, the conditioner outputs a set of parameters which is always the same. + For x_B, the conditioner outputs a set of parameters which are predicted from x_A. + + Coupling conditioners differ in the partitioning method. By default, the event is flattened; the first half is x_A + and the second half is x_B. When using this in a normalizing flow, permutation layers can shuffle event dimensions. + + For improved performance, this implementation does not use a standalone coupling conditioner. It instead implements + a method to partition x into x_A and x_B and then predict parameters for x_B. + """ + + def __init__(self, + transformer: Transformer, + coupling_mask: CouplingMask, + conditioner_transform: ConditionerTransform, + **kwargs): + super().__init__(coupling_mask.event_shape, None, transformer, conditioner_transform, **kwargs) + self.coupling_mask = coupling_mask + + assert conditioner_transform.input_event_shape == (coupling_mask.constant_event_size,) + assert transformer.event_shape == (self.coupling_mask.transformed_event_size,) + + def partition_and_predict_parameters(self, x: torch.Tensor, context: torch.Tensor): + """ + Partition tensor x and compute transformer parameters. + + :param x: input tensor with x.shape = (*batch_shape, *event_shape) to be partitioned into x_A and x_B. + :param context: context tensor with context.shape = (*batch_shape, *context.shape). + :return: parameter tensor h with h.shape = (*batch_shape, n_transformer_parameters). If return_mask is True, + also return the event partition mask with shape = event_shape and the constant parameter mask with shape + (n_transformer_parameters,). 
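A minimal end-to-end sketch of this coupling step, with a toy half-split mask, a linear conditioner, and an inline affine transform (names and shapes are illustrative, not the package API):

import torch
import torch.nn as nn

torch.manual_seed(0)
d = 6
mask = torch.arange(d) < d // 2                      # x_A = first half (constant), x_B = second half
conditioner = nn.Linear(d // 2, (d - d // 2) * 2)    # predicts (scale, shift) for every element of x_B

def forward(x):
    h = conditioner(x[..., mask]).unflatten(-1, (d - d // 2, 2))
    alpha, beta = torch.exp(h[..., 0]), h[..., 1]
    z = x.clone()
    z[..., ~mask] = alpha * x[..., ~mask] + beta
    return z, torch.log(alpha).sum(-1)               # log|det J|

def inverse(z):
    # z_A equals x_A, so the same parameters can be recomputed from the constant part.
    h = conditioner(z[..., mask]).unflatten(-1, (d - d // 2, 2))
    alpha, beta = torch.exp(h[..., 0]), h[..., 1]
    x = z.clone()
    x[..., ~mask] = (z[..., ~mask] - beta) / alpha
    return x, -torch.log(alpha).sum(-1)

x = torch.randn(10, d)
z, log_det = forward(x)
x_rec, _ = inverse(z)
assert torch.allclose(x, x_rec, atol=1e-5)           # the coupling map is invertible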
+ """ + # Predict transformer parameters for output dimensions + x_a = x[..., self.coupling_mask.mask] # (*b, *e_A) + h_b = self.conditioner_transform(x_a, context=context) # (*b, p) + return h_b def forward(self, x: torch.Tensor, context: torch.Tensor = None) -> Tuple[torch.Tensor, torch.Tensor]: z = x.clone() - h, mask = self.conditioner(x, self.conditioner_transform, context, return_mask=True) - z[..., ~mask], log_det = self.transformer.forward_base(x[..., ~mask], h[..., ~mask, :]) - # TODO make this work with self.transformer.forward + h_b = self.partition_and_predict_parameters(x, context) + z[..., ~self.coupling_mask.mask], log_det = self.transformer.forward(x[..., ~self.coupling_mask.mask], h_b) return z, log_det def inverse(self, z: torch.Tensor, context: torch.Tensor = None) -> Tuple[torch.Tensor, torch.Tensor]: x = z.clone() - h, mask = self.conditioner(z, self.conditioner_transform, context, return_mask=True) - x[..., ~mask], log_det = self.transformer.inverse_base(z[..., ~mask], h[..., ~mask, :]) - # TODO make this work with self.transformer.inverse + h_b = self.partition_and_predict_parameters(x, context) + x[..., ~self.coupling_mask.mask], log_det = self.transformer.inverse(z[..., ~self.coupling_mask.mask], h_b) return x, log_det class ForwardMaskedAutoregressiveBijection(AutoregressiveBijection): def __init__(self, conditioner: Conditioner, transformer: Transformer, conditioner_transform: ConditionerTransform): - super().__init__(conditioner, transformer, conditioner_transform) + super().__init__(transformer.event_shape, conditioner, transformer, conditioner_transform) def inverse(self, z: torch.Tensor, context: torch.Tensor = None) -> Tuple[torch.Tensor, torch.Tensor]: batch_shape = get_batch_shape(z, self.event_shape) @@ -76,7 +114,7 @@ def inverse(self, z: torch.Tensor, context: torch.Tensor = None) -> Tuple[torch. class InverseMaskedAutoregressiveBijection(AutoregressiveBijection): def __init__(self, conditioner: Conditioner, transformer: Transformer, conditioner_transform: ConditionerTransform): - super().__init__(conditioner, transformer, conditioner_transform) + super().__init__(transformer.event_shape, conditioner, transformer, conditioner_transform) self.forward_layer = ForwardMaskedAutoregressiveBijection( conditioner, transformer, @@ -92,5 +130,6 @@ def inverse(self, z: torch.Tensor, context: torch.Tensor = None) -> Tuple[torch. 
class ElementwiseBijection(AutoregressiveBijection): def __init__(self, transformer: Transformer, n_transformer_parameters: int): - super().__init__(NullConditioner(), transformer, Constant(transformer.event_shape, n_transformer_parameters)) + super().__init__(transformer.event_shape, NullConditioner(), transformer, + Constant(transformer.event_shape, n_transformer_parameters)) # TODO override forward and inverse to save on space diff --git a/normalizing_flows/bijections/finite/autoregressive/transformers/spline/base.py b/normalizing_flows/bijections/finite/autoregressive/transformers/spline/base.py index f4dc654..cad47f4 100644 --- a/normalizing_flows/bijections/finite/autoregressive/transformers/spline/base.py +++ b/normalizing_flows/bijections/finite/autoregressive/transformers/spline/base.py @@ -3,7 +3,7 @@ import torch from normalizing_flows.bijections.finite.autoregressive.transformers.base import Transformer -from normalizing_flows.utils import sum_except_batch +from normalizing_flows.utils import sum_except_batch, get_batch_shape class MonotonicSpline(Transformer): diff --git a/normalizing_flows/bijections/finite/autoregressive/transformers/spline/cubic.py b/normalizing_flows/bijections/finite/autoregressive/transformers/spline/cubic.py index e7d03c0..9dd335b 100644 --- a/normalizing_flows/bijections/finite/autoregressive/transformers/spline/cubic.py +++ b/normalizing_flows/bijections/finite/autoregressive/transformers/spline/cubic.py @@ -14,11 +14,10 @@ def __init__(self, event_shape: Union[torch.Size, Tuple[int, ...]], n_bins: int @property def n_parameters(self) -> int: - return 2 * self.n_bins + 2 + return (2 * self.n_bins + 2) * self.n_dim - @property - def default_parameters(self) -> torch.Tensor: - return torch.zeros(size=(self.n_parameters,)) + def unflatten_conditioner_parameters(self, h: torch.Tensor): + return torch.unflatten(h, dim=-1, sizes=(*self.event_shape, 2 * self.n_bins + 2)) def compute_spline_parameters(self, knots_x: torch.Tensor, knots_y: torch.Tensor, idx: torch.Tensor): # knots_x.shape == (n, n_knots) diff --git a/normalizing_flows/bijections/finite/autoregressive/transformers/spline/linear.py b/normalizing_flows/bijections/finite/autoregressive/transformers/spline/linear.py index d2ad3d0..ee9964b 100644 --- a/normalizing_flows/bijections/finite/autoregressive/transformers/spline/linear.py +++ b/normalizing_flows/bijections/finite/autoregressive/transformers/spline/linear.py @@ -36,11 +36,10 @@ def compute_bin_y(self, delta): @property def n_parameters(self) -> int: - return self.n_bins + return self.n_bins * self.n_dim - @property - def default_parameters(self) -> torch.Tensor: - return torch.zeros(size=(self.n_bins,)) + def unflatten_conditioner_parameters(self, h: torch.Tensor): + return torch.unflatten(h, dim=-1, sizes=(*self.event_shape, self.n_bins)) def forward_1d(self, x, h): assert len(x.shape) == 1 diff --git a/normalizing_flows/bijections/finite/autoregressive/transformers/spline/linear_rational.py b/normalizing_flows/bijections/finite/autoregressive/transformers/spline/linear_rational.py index 2c23edb..07807dc 100644 --- a/normalizing_flows/bijections/finite/autoregressive/transformers/spline/linear_rational.py +++ b/normalizing_flows/bijections/finite/autoregressive/transformers/spline/linear_rational.py @@ -24,16 +24,10 @@ def __init__(self, event_shape: Union[torch.Size, Tuple[int, ...]], boundary: fl @property def n_parameters(self) -> int: - return 4 * self.n_bins + return 4 * self.n_bins * self.n_dim - @property - def 
default_parameters(self) -> torch.Tensor: - default_u_x = torch.zeros(size=(self.n_bins,)) - default_u_y = torch.zeros(size=(self.n_bins,)) - default_u_lambda = torch.zeros(size=(self.n_bins,)) - default_u_d = torch.zeros(size=(self.n_bins - 1,)) - default_u_w0 = torch.zeros(size=(1,)) - return torch.cat([default_u_x, default_u_y, default_u_lambda, default_u_d, default_u_w0], dim=0) + def unflatten_conditioner_parameters(self, h: torch.Tensor): + return torch.unflatten(h, dim=-1, sizes=(*self.event_shape, 4 * self.n_bins)) def compute_parameters(self, idx, knots_x, knots_y, knots_d, knots_lambda, u_w0): assert knots_x.shape == knots_y.shape == knots_d.shape diff --git a/normalizing_flows/bijections/finite/autoregressive/transformers/spline/rational_quadratic.py b/normalizing_flows/bijections/finite/autoregressive/transformers/spline/rational_quadratic.py index cf5ec86..f41ccff 100644 --- a/normalizing_flows/bijections/finite/autoregressive/transformers/spline/rational_quadratic.py +++ b/normalizing_flows/bijections/finite/autoregressive/transformers/spline/rational_quadratic.py @@ -39,14 +39,10 @@ def __init__(self, @property def n_parameters(self) -> int: - return 3 * self.n_bins - 1 + return (3 * self.n_bins - 1) * self.n_dim - @property - def default_parameters(self) -> torch.Tensor: - default_u_x = torch.zeros(size=(self.n_bins,)) - default_u_y = torch.zeros(size=(self.n_bins,)) - default_u_d = torch.zeros(size=(self.n_bins - 1,)) - return torch.cat([default_u_x, default_u_y, default_u_d], dim=0) + def unflatten_conditioner_parameters(self, h: torch.Tensor): + return torch.unflatten(h, dim=-1, sizes=(*self.event_shape, 3 * self.n_bins - 1)) def compute_bins(self, u, minimum, maximum): bin_sizes = torch.softmax(u, dim=-1) From 8b6870fc7854f93061cddc3c017d37088ddb1f6f Mon Sep 17 00:00:00 2001 From: David Nabergoj Date: Fri, 17 Nov 2023 09:56:59 -0800 Subject: [PATCH 17/40] Specify correct output size in MADE --- .../autoregressive/conditioner_transforms.py | 7 +++---- .../bijections/finite/autoregressive/layers.py | 18 ++++++++++++------ 2 files changed, 15 insertions(+), 10 deletions(-) diff --git a/normalizing_flows/bijections/finite/autoregressive/conditioner_transforms.py b/normalizing_flows/bijections/finite/autoregressive/conditioner_transforms.py index 7398cd0..c8da4c8 100644 --- a/normalizing_flows/bijections/finite/autoregressive/conditioner_transforms.py +++ b/normalizing_flows/bijections/finite/autoregressive/conditioner_transforms.py @@ -159,14 +159,13 @@ def __init__(self, layers.extend([self.MaskedLinear(n_layer_inputs, n_layer_outputs, mask), nn.Tanh()]) # Final linear layer - layers.extend([ + layers.append( self.MaskedLinear( masks[-1].shape[1], masks[-1].shape[0] * self.n_predicted_parameters, torch.repeat_interleave(masks[-1], self.n_predicted_parameters, dim=0) - ), - nn.Unflatten(dim=-1, unflattened_size=(self.n_predicted_parameters,)) - ]) + ) + ) self.sequential = nn.Sequential(*layers) @staticmethod diff --git a/normalizing_flows/bijections/finite/autoregressive/layers.py b/normalizing_flows/bijections/finite/autoregressive/layers.py index 6caa572..2cb3f5f 100644 --- a/normalizing_flows/bijections/finite/autoregressive/layers.py +++ b/normalizing_flows/bijections/finite/autoregressive/layers.py @@ -171,11 +171,12 @@ def __init__(self, event_shape: torch.Size, context_shape: torch.Size = None, **kwargs): + event_size = int(torch.prod(torch.as_tensor(event_shape))) transformer = Affine(event_shape=event_shape) conditioner_transform = MADE( 
input_event_shape=event_shape, output_event_shape=event_shape, - n_transformer_parameters=transformer.n_parameters, + n_transformer_parameters=transformer.n_parameters // event_size, context_shape=context_shape, **kwargs ) @@ -193,11 +194,12 @@ def __init__(self, context_shape: torch.Size = None, n_bins: int = 8, **kwargs): + event_size = int(torch.prod(torch.as_tensor(event_shape))) transformer = RationalQuadratic(event_shape=event_shape, n_bins=n_bins) conditioner_transform = MADE( input_event_shape=event_shape, output_event_shape=event_shape, - n_transformer_parameters=transformer.n_parameters, + n_transformer_parameters=transformer.n_parameters // event_size, context_shape=context_shape, **kwargs ) @@ -215,11 +217,12 @@ def __init__(self, context_shape: torch.Size = None, n_bins: int = 8, **kwargs): + event_size = int(torch.prod(torch.as_tensor(event_shape))) transformer = LinearRational(event_shape=event_shape, n_bins=n_bins) conditioner_transform = MADE( input_event_shape=event_shape, output_event_shape=event_shape, - n_transformer_parameters=transformer.n_parameters, + n_transformer_parameters=transformer.n_parameters // event_size, context_shape=context_shape, **kwargs ) @@ -236,11 +239,12 @@ def __init__(self, event_shape: torch.Size, context_shape: torch.Size = None, **kwargs): + event_size = int(torch.prod(torch.as_tensor(event_shape))) transformer = invert(Affine(event_shape=event_shape)) conditioner_transform = MADE( input_event_shape=event_shape, output_event_shape=event_shape, - n_transformer_parameters=transformer.n_parameters, + n_transformer_parameters=transformer.n_parameters // event_size, context_shape=context_shape, **kwargs ) @@ -259,11 +263,12 @@ def __init__(self, n_bins: int = 8, **kwargs): assert n_bins >= 1 + event_size = int(torch.prod(torch.as_tensor(event_shape))) transformer = RationalQuadratic(event_shape=event_shape, n_bins=n_bins) conditioner_transform = MADE( input_event_shape=event_shape, output_event_shape=event_shape, - n_transformer_parameters=transformer.n_parameters, + n_transformer_parameters=transformer.n_parameters // event_size, context_shape=context_shape, **kwargs ) @@ -282,6 +287,7 @@ def __init__(self, n_hidden_layers: int = 1, hidden_dim: int = 5, **kwargs): + event_size = int(torch.prod(torch.as_tensor(event_shape))) transformer = UnconstrainedMonotonicNeuralNetwork( event_shape=event_shape, n_hidden_layers=n_hidden_layers, @@ -290,7 +296,7 @@ def __init__(self, conditioner_transform = MADE( input_event_shape=event_shape, output_event_shape=event_shape, - n_transformer_parameters=transformer.n_parameters, + n_transformer_parameters=transformer.n_parameters // event_size, context_shape=context_shape, **kwargs ) From d7adf49e17993aefc549895b205e6c5960dad06b Mon Sep 17 00:00:00 2001 From: David Nabergoj Date: Sat, 18 Nov 2023 15:54:22 -0800 Subject: [PATCH 18/40] Generalize Transformer to TensorTransformer and create specialized ScalarTransformer --- .../finite/autoregressive/layers_base.py | 12 +++--- .../autoregressive/transformers/affine.py | 10 ++--- .../autoregressive/transformers/base.py | 43 ++++++++++++++++--- .../transformers/combination/base.py | 6 +-- .../transformers/combination/sigmoid.py | 6 +-- .../transformers/convolution.py | 4 +- .../transformers/integration/base.py | 4 +- .../transformers/spline/base.py | 4 +- test/test_reconstruction_transformers.py | 16 +++---- 9 files changed, 68 insertions(+), 37 deletions(-) diff --git a/normalizing_flows/bijections/finite/autoregressive/layers_base.py 
b/normalizing_flows/bijections/finite/autoregressive/layers_base.py index beddbb6..2464735 100644 --- a/normalizing_flows/bijections/finite/autoregressive/layers_base.py +++ b/normalizing_flows/bijections/finite/autoregressive/layers_base.py @@ -5,13 +5,13 @@ from normalizing_flows.bijections.finite.autoregressive.conditioners.base import Conditioner, NullConditioner from normalizing_flows.bijections.finite.autoregressive.conditioner_transforms import ConditionerTransform, Constant from normalizing_flows.bijections.finite.autoregressive.conditioners.coupling import Coupling -from normalizing_flows.bijections.finite.autoregressive.transformers.base import Transformer +from normalizing_flows.bijections.finite.autoregressive.transformers.base import ScalarTransformer from normalizing_flows.bijections.base import Bijection from normalizing_flows.utils import flatten_event, unflatten_event, get_batch_shape class AutoregressiveBijection(Bijection): - def __init__(self, conditioner: Conditioner, transformer: Transformer, conditioner_transform: ConditionerTransform): + def __init__(self, conditioner: Conditioner, transformer: ScalarTransformer, conditioner_transform: ConditionerTransform): super().__init__(event_shape=transformer.event_shape) self.conditioner = conditioner self.conditioner_transform = conditioner_transform @@ -29,7 +29,7 @@ def inverse(self, z: torch.Tensor, context: torch.Tensor = None) -> Tuple[torch. class CouplingBijection(AutoregressiveBijection): - def __init__(self, conditioner: Coupling, transformer: Transformer, conditioner_transform: ConditionerTransform, + def __init__(self, conditioner: Coupling, transformer: ScalarTransformer, conditioner_transform: ConditionerTransform, **kwargs): super().__init__(conditioner, transformer, conditioner_transform, **kwargs) @@ -51,7 +51,7 @@ def inverse(self, z: torch.Tensor, context: torch.Tensor = None) -> Tuple[torch. class ForwardMaskedAutoregressiveBijection(AutoregressiveBijection): - def __init__(self, conditioner: Conditioner, transformer: Transformer, conditioner_transform: ConditionerTransform): + def __init__(self, conditioner: Conditioner, transformer: ScalarTransformer, conditioner_transform: ConditionerTransform): super().__init__(conditioner, transformer, conditioner_transform) def inverse(self, z: torch.Tensor, context: torch.Tensor = None) -> Tuple[torch.Tensor, torch.Tensor]: @@ -73,7 +73,7 @@ def inverse(self, z: torch.Tensor, context: torch.Tensor = None) -> Tuple[torch. class InverseMaskedAutoregressiveBijection(AutoregressiveBijection): - def __init__(self, conditioner: Conditioner, transformer: Transformer, conditioner_transform: ConditionerTransform): + def __init__(self, conditioner: Conditioner, transformer: ScalarTransformer, conditioner_transform: ConditionerTransform): super().__init__(conditioner, transformer, conditioner_transform) self.forward_layer = ForwardMaskedAutoregressiveBijection( conditioner, @@ -89,6 +89,6 @@ def inverse(self, z: torch.Tensor, context: torch.Tensor = None) -> Tuple[torch. 
class ElementwiseBijection(AutoregressiveBijection): - def __init__(self, transformer: Transformer, n_transformer_parameters: int): + def __init__(self, transformer: ScalarTransformer, n_transformer_parameters: int): super().__init__(NullConditioner(), transformer, Constant(transformer.event_shape, n_transformer_parameters)) # TODO override forward and inverse to save on space diff --git a/normalizing_flows/bijections/finite/autoregressive/transformers/affine.py b/normalizing_flows/bijections/finite/autoregressive/transformers/affine.py index 4e0b2ba..6f80fbe 100644 --- a/normalizing_flows/bijections/finite/autoregressive/transformers/affine.py +++ b/normalizing_flows/bijections/finite/autoregressive/transformers/affine.py @@ -3,11 +3,11 @@ import torch -from normalizing_flows.bijections.finite.autoregressive.transformers.base import Transformer +from normalizing_flows.bijections.finite.autoregressive.transformers.base import ScalarTransformer from normalizing_flows.utils import get_batch_shape, sum_except_batch -class Affine(Transformer): +class Affine(ScalarTransformer): """ Affine transformer. @@ -55,7 +55,7 @@ def inverse(self, z: torch.Tensor, h: torch.Tensor) -> Tuple[torch.Tensor, torch return (z - beta) / alpha, log_det -class Affine2(Transformer): +class Affine2(ScalarTransformer): """ Affine transformer with near-identity initialization. @@ -110,7 +110,7 @@ def inverse(self, z: torch.Tensor, h: torch.Tensor) -> Tuple[torch.Tensor, torch return (z - beta) / alpha, log_det -class Shift(Transformer): +class Shift(ScalarTransformer): def __init__(self, event_shape: torch.Size): super().__init__(event_shape=event_shape) @@ -135,7 +135,7 @@ def inverse(self, z: torch.Tensor, h: torch.Tensor) -> Tuple[torch.Tensor, torch return z - beta, log_det -class Scale(Transformer): +class Scale(ScalarTransformer): """ Scaling transformer. diff --git a/normalizing_flows/bijections/finite/autoregressive/transformers/base.py b/normalizing_flows/bijections/finite/autoregressive/transformers/base.py index b23ecdf..645b82d 100644 --- a/normalizing_flows/bijections/finite/autoregressive/transformers/base.py +++ b/normalizing_flows/bijections/finite/autoregressive/transformers/base.py @@ -5,7 +5,17 @@ from normalizing_flows.bijections.base import Bijection -class Transformer(Bijection): +class TensorTransformer(Bijection): + """ + Base transformer class. + + A transformer receives as input a tensor x with x.shape = (*batch_shape, *event_shape) and parameters h + with h.shape = (*batch_shape, *parameter_shape). It applies a bijective map to each tensor in the batch + with its corresponding parameter set. In general, the parameters are used to transform the entire tensor at + once. As a special case, the subclass ScalarTransformer transforms each element of an input event + individually. + """ + def __init__(self, event_shape: Union[torch.Size, Tuple[int, ...]]): super().__init__(event_shape=event_shape) @@ -15,18 +25,39 @@ def forward(self, x: torch.Tensor, h: torch.Tensor) -> Tuple[torch.Tensor, torch def inverse(self, x: torch.Tensor, h: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: raise NotImplementedError + @property + def parameter_shape(self) -> Union[torch.Size, Tuple[int, ...]]: + raise NotImplementedError + @property def n_parameters(self) -> int: + return int(torch.prod(torch.as_tensor(self.parameter_shape))) + + @property + def default_parameters(self) -> torch.Tensor: """ - Number of parameters that parametrize this transformer. 
Example: rational quadratic splines require 3*b-1 where - b is the number of bins. An affine transformation requires 2 (typically corresponding to the unconstrained scale - and shift). + Set of parameters which ensures an identity transformation. """ raise NotImplementedError + +class ScalarTransformer(TensorTransformer): + def __init__(self, event_shape: Union[torch.Size, Tuple[int, ...]]): + super().__init__(event_shape) + @property - def default_parameters(self) -> torch.Tensor: + def parameter_shape_per_element(self): """ - Set of parameters which ensures an identity transformation. + The shape of parameters that transform a single element of an input tensor. + + Example: + * if using an affine transformer, this is equal to (2,) (corresponding to scale and shift). + * if using a rational quadratic spline transformer, this is equal to (3 * b - 1,) where b is the + number of bins. """ raise NotImplementedError + + @property + def parameter_shape(self) -> Union[torch.Size, Tuple[int, ...]]: + # Scalar transformers map each element individually, so the first dimensions are the event shape + return torch.Size((*self.event_shape, *self.parameter_shape_per_element)) diff --git a/normalizing_flows/bijections/finite/autoregressive/transformers/combination/base.py b/normalizing_flows/bijections/finite/autoregressive/transformers/combination/base.py index fb431ed..871de2b 100644 --- a/normalizing_flows/bijections/finite/autoregressive/transformers/combination/base.py +++ b/normalizing_flows/bijections/finite/autoregressive/transformers/combination/base.py @@ -1,12 +1,12 @@ import torch from typing import Tuple, List -from normalizing_flows.bijections.finite.autoregressive.transformers.base import Transformer +from normalizing_flows.bijections.finite.autoregressive.transformers.base import ScalarTransformer from normalizing_flows.utils import get_batch_shape -class Combination(Transformer): - def __init__(self, event_shape: torch.Size, components: List[Transformer]): +class Combination(ScalarTransformer): + def __init__(self, event_shape: torch.Size, components: List[ScalarTransformer]): super().__init__(event_shape) self.components = components self.n_components = len(self.components) diff --git a/normalizing_flows/bijections/finite/autoregressive/transformers/combination/sigmoid.py b/normalizing_flows/bijections/finite/autoregressive/transformers/combination/sigmoid.py index 43f259b..38378f0 100644 --- a/normalizing_flows/bijections/finite/autoregressive/transformers/combination/sigmoid.py +++ b/normalizing_flows/bijections/finite/autoregressive/transformers/combination/sigmoid.py @@ -2,7 +2,7 @@ from typing import Tuple, Union, List import torch import torch.nn as nn -from normalizing_flows.bijections.finite.autoregressive.transformers.base import Transformer +from normalizing_flows.bijections.finite.autoregressive.transformers.base import ScalarTransformer from normalizing_flows.bijections.finite.autoregressive.transformers.combination.base import Combination from normalizing_flows.bijections.finite.autoregressive.transformers.combination.sigmoid_util import log_softmax, \ log_sigmoid, log_dot @@ -19,7 +19,7 @@ def inverse_sigmoid(p): return torch.log(p) - torch.log1p(-p) -class Sigmoid(Transformer): +class Sigmoid(ScalarTransformer): """ Applies z = inv_sigmoid(w.T @ sigmoid(a * x + b)) where a > 0, w > 0 and sum(w) = 1. Note: w, a, b are vectors, so multiplication a * x is broadcast. 
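# Editor's illustrative sketch (not part of the patch): how the ScalarTransformer
# parameter-shape convention introduced in transformers/base.py above composes shapes.
# The concrete shapes below are arbitrary example assumptions, not values from the library.
import torch

event_shape = torch.Size((3, 4))                      # an event with 3 * 4 = 12 scalar elements
parameter_shape_per_element = torch.Size((2,))        # e.g. an affine transformer: (scale, shift) per element

# ScalarTransformer.parameter_shape prepends the event shape, as in the property added above:
parameter_shape = torch.Size((*event_shape, *parameter_shape_per_element))
assert parameter_shape == (3, 4, 2)

# A conditioner therefore predicts h with shape (*batch_shape, *parameter_shape):
batch_shape = torch.Size((10,))
h = torch.zeros(*batch_shape, *parameter_shape)       # all-zero h corresponds to identity parameters for an affine map
assert h.shape == (10, 3, 4, 2)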
@@ -212,7 +212,7 @@ def forward_1d(self, x, h, eps: float = 1e-6): return z, log_det.view(*x.shape[:2]) -class DenseSigmoid(Transformer): +class DenseSigmoid(ScalarTransformer): """ Apply y = f1 \\circ f2 \\circ ... \\circ fn (x) where * f1 is a dense sigmoid inner transform which maps from 1 to h dimensions; diff --git a/normalizing_flows/bijections/finite/autoregressive/transformers/convolution.py b/normalizing_flows/bijections/finite/autoregressive/transformers/convolution.py index ff9ed4b..b65c76e 100644 --- a/normalizing_flows/bijections/finite/autoregressive/transformers/convolution.py +++ b/normalizing_flows/bijections/finite/autoregressive/transformers/convolution.py @@ -1,7 +1,7 @@ from typing import Union, Tuple import torch -from normalizing_flows.bijections.finite.autoregressive.transformers.base import Transformer +from normalizing_flows.bijections.finite.autoregressive.transformers.base import ScalarTransformer def construct_kernels_plu( @@ -55,7 +55,7 @@ def construct_kernels_plu( return kernels -class Invertible1x1Convolution(Transformer): +class Invertible1x1Convolution(ScalarTransformer): """ Invertible 1x1 convolution. diff --git a/normalizing_flows/bijections/finite/autoregressive/transformers/integration/base.py b/normalizing_flows/bijections/finite/autoregressive/transformers/integration/base.py index 56ab9fc..8f71343 100644 --- a/normalizing_flows/bijections/finite/autoregressive/transformers/integration/base.py +++ b/normalizing_flows/bijections/finite/autoregressive/transformers/integration/base.py @@ -3,12 +3,12 @@ import torch -from normalizing_flows.bijections.finite.autoregressive.transformers.base import Transformer +from normalizing_flows.bijections.finite.autoregressive.transformers.base import ScalarTransformer from normalizing_flows.bijections.numerical_inversion import bisection from normalizing_flows.utils import get_batch_shape, sum_except_batch -class Integration(Transformer): +class Integration(ScalarTransformer): def __init__(self, event_shape: Union[torch.Size, Tuple[int, ...]], bound: float = 100.0, eps: float = 1e-6): """ :param bound: specifies the initial interval [-bound, bound] where numerical inversion is performed. 
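# Editor's illustrative sketch (not part of the patch): the idea behind the numerical
# inversion that Integration transformers rely on, i.e. inverting a monotonically
# increasing map on the initial interval [-bound, bound] by bisection. The helper below
# is a self-contained assumption for illustration only and does not mirror the signature
# of the library's `bisection` utility imported above.
import torch

def bisection_inverse(f, z, bound: float = 100.0, n_iterations: int = 60):
    # Find x in [-bound, bound] with f(x) = z, assuming f is elementwise increasing.
    low = torch.full_like(z, -bound)
    high = torch.full_like(z, bound)
    for _ in range(n_iterations):
        mid = (low + high) / 2
        too_small = f(mid) < z
        low = torch.where(too_small, mid, low)    # root lies above mid
        high = torch.where(too_small, high, mid)  # root lies at or below mid
    return (low + high) / 2

# Example: invert the increasing map f(x) = x + sin(x) elementwise.
z = torch.tensor([0.5, -1.0, 2.0])
x = bisection_inverse(lambda t: t + torch.sin(t), z)
assert torch.allclose(x + torch.sin(x), z, atol=1e-4)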
diff --git a/normalizing_flows/bijections/finite/autoregressive/transformers/spline/base.py b/normalizing_flows/bijections/finite/autoregressive/transformers/spline/base.py index 2caa2dc..cec05a0 100644 --- a/normalizing_flows/bijections/finite/autoregressive/transformers/spline/base.py +++ b/normalizing_flows/bijections/finite/autoregressive/transformers/spline/base.py @@ -2,11 +2,11 @@ import torch -from normalizing_flows.bijections.finite.autoregressive.transformers.base import Transformer +from normalizing_flows.bijections.finite.autoregressive.transformers.base import ScalarTransformer from normalizing_flows.utils import sum_except_batch -class MonotonicSpline(Transformer): +class MonotonicSpline(ScalarTransformer): def __init__(self, event_shape: Union[torch.Size, Tuple[int, ...]], min_input: float = -1.0, diff --git a/test/test_reconstruction_transformers.py b/test/test_reconstruction_transformers.py index b026627..8da838f 100644 --- a/test/test_reconstruction_transformers.py +++ b/test/test_reconstruction_transformers.py @@ -3,7 +3,7 @@ import pytest import torch -from normalizing_flows.bijections.finite.autoregressive.transformers.base import Transformer +from normalizing_flows.bijections.finite.autoregressive.transformers.base import ScalarTransformer from normalizing_flows.bijections.finite.autoregressive.transformers.spline.linear import Linear as LinearSpline from normalizing_flows.bijections.finite.autoregressive.transformers.spline.linear_rational import \ LinearRational as LinearRationalSpline @@ -21,7 +21,7 @@ from test.constants import __test_constants -def setup_transformer_data(transformer_class: Transformer, batch_shape, event_shape): +def setup_transformer_data(transformer_class: ScalarTransformer, batch_shape, event_shape): # vector_to_vector: does the transformer map a vector to vector? Otherwise, it maps a scalar to scalar. 
torch.manual_seed(0) transformer = transformer_class(event_shape) @@ -30,7 +30,7 @@ def setup_transformer_data(transformer_class: Transformer, batch_shape, event_sh return transformer, x, h -def assert_valid_reconstruction(transformer: Transformer, +def assert_valid_reconstruction(transformer: ScalarTransformer, x: torch.Tensor, h: torch.Tensor, reconstruction_eps: float = 1e-3, @@ -67,7 +67,7 @@ def assert_valid_reconstruction(transformer: Transformer, ]) @pytest.mark.parametrize('batch_shape', __test_constants['batch_shape']) @pytest.mark.parametrize('event_shape', __test_constants['event_shape']) -def test_affine(transformer_class: Transformer, batch_shape: Tuple, event_shape: Tuple): +def test_affine(transformer_class: ScalarTransformer, batch_shape: Tuple, event_shape: Tuple): transformer, x, h = setup_transformer_data(transformer_class, batch_shape, event_shape) assert_valid_reconstruction(transformer, x, h) @@ -81,7 +81,7 @@ def test_affine(transformer_class: Transformer, batch_shape: Tuple, event_shape: ]) @pytest.mark.parametrize('batch_shape', __test_constants['batch_shape']) @pytest.mark.parametrize('event_shape', __test_constants['event_shape']) -def test_spline(transformer_class: Transformer, batch_shape: Tuple, event_shape: Tuple): +def test_spline(transformer_class: ScalarTransformer, batch_shape: Tuple, event_shape: Tuple): transformer, x, h = setup_transformer_data(transformer_class, batch_shape, event_shape) assert_valid_reconstruction(transformer, x, h) @@ -91,7 +91,7 @@ def test_spline(transformer_class: Transformer, batch_shape: Tuple, event_shape: ]) @pytest.mark.parametrize('batch_shape', __test_constants['batch_shape']) @pytest.mark.parametrize('event_shape', __test_constants['event_shape']) -def test_integration(transformer_class: Transformer, batch_shape: Tuple, event_shape: Tuple): +def test_integration(transformer_class: ScalarTransformer, batch_shape: Tuple, event_shape: Tuple): transformer, x, h = setup_transformer_data(transformer_class, batch_shape, event_shape) assert_valid_reconstruction(transformer, x, h) @@ -99,7 +99,7 @@ def test_integration(transformer_class: Transformer, batch_shape: Tuple, event_s @pytest.mark.parametrize('transformer_class', [Sigmoid, DeepSigmoid]) @pytest.mark.parametrize('batch_shape', __test_constants['batch_shape']) @pytest.mark.parametrize('event_shape', __test_constants['event_shape']) -def test_combination_basic(transformer_class: Transformer, batch_shape: Tuple, event_shape: Tuple): +def test_combination_basic(transformer_class: ScalarTransformer, batch_shape: Tuple, event_shape: Tuple): transformer, x, h = setup_transformer_data(transformer_class, batch_shape, event_shape) assert_valid_reconstruction(transformer, x, h) @@ -110,7 +110,7 @@ def test_combination_basic(transformer_class: Transformer, batch_shape: Tuple, e ]) @pytest.mark.parametrize('batch_shape', __test_constants['batch_shape']) @pytest.mark.parametrize('event_shape', __test_constants['event_shape']) -def test_combination_vector_to_vector(transformer_class: Transformer, batch_shape: Tuple, event_shape: Tuple): +def test_combination_vector_to_vector(transformer_class: ScalarTransformer, batch_shape: Tuple, event_shape: Tuple): transformer, x, h = setup_transformer_data(transformer_class, batch_shape, event_shape) assert_valid_reconstruction(transformer, x, h) From ed30a0341b15bc939d38c0828cd919a9765d0108 Mon Sep 17 00:00:00 2001 From: David Nabergoj Date: Sat, 18 Nov 2023 16:46:12 -0800 Subject: [PATCH 19/40] Work with arbitrary parameter shapes instead of flat 
events --- .../autoregressive/conditioner_transforms.py | 95 ++++++++++--------- .../finite/autoregressive/layers.py | 32 +++---- .../finite/autoregressive/layers_base.py | 4 +- .../autoregressive/transformers/affine.py | 21 ++-- test/test_conditioner_transforms.py | 2 +- 5 files changed, 79 insertions(+), 75 deletions(-) diff --git a/normalizing_flows/bijections/finite/autoregressive/conditioner_transforms.py b/normalizing_flows/bijections/finite/autoregressive/conditioner_transforms.py index a3f00d6..f961145 100644 --- a/normalizing_flows/bijections/finite/autoregressive/conditioner_transforms.py +++ b/normalizing_flows/bijections/finite/autoregressive/conditioner_transforms.py @@ -1,4 +1,5 @@ import math +from typing import Tuple, Union import torch import torch.nn as nn @@ -23,24 +24,21 @@ def __init__(self, input_event_shape, context_shape, output_event_shape, - n_transformer_parameters: int, + parameter_shape: Union[torch.Size, Tuple[int, ...]], context_combiner: ContextCombiner = None, - percent_global_parameters: float = 0.0, + global_parameter_mask: torch.Tensor = None, initial_global_parameter_value: float = None): """ :param input_event_shape: shape of conditioner input tensor x. :param context_shape: shape of conditioner context tensor c. :param output_event_shape: shape of transformer input tensor y. - :param n_transformer_parameters: number of parameters required to transform a single element of y. + :param parameter_shape: shape of parameter tensor required to transform transformer input y. :param context_combiner: ContextCombiner class which defines how to combine x and c to predict theta. - :param percent_global_parameters: percent of all parameters in theta to be learned independent of x and c. - A value of 0 means the conditioner predicts n_transformer_parameters parameters based on x and c. - A value of 1 means the conditioner predicts no parameters based on x and c, but outputs globally learned theta. - A value of alpha means the conditioner outputs alpha * n_transformer_parameters parameters globally and - predicts the rest. In this case, the predicted parameters are the last alpha * n_transformer_parameters - elements in theta. - :param initial_global_parameter_value: the initial value for the entire globally learned part of theta. If None, - the global part of theta is initialized to samples from the standard normal distribution. + :param global_parameter_mask: boolean tensor which determines which elements of parameter tensors should be + learned globally instead of predicted. If an element is set to 1, that element is learned globally. + We require that global_parameter_mask.shape = parameter_shape. + :param initial_global_parameter_value: initial global parameter value as a single scalar. If None, all initial + global parameters are independently drawn from the standard normal distribution. 
""" super().__init__() if context_shape is None: @@ -55,50 +53,57 @@ def __init__(self, self.context_shape = context_shape self.n_input_event_dims = self.context_combiner.n_output_dims self.n_output_event_dims = int(torch.prod(torch.as_tensor(output_event_shape))) - self.n_transformer_parameters = n_transformer_parameters - self.n_global_parameters = int(n_transformer_parameters * percent_global_parameters) + + # Setup output parameter attributes + self.parameter_shape = parameter_shape + self.global_parameter_mask = global_parameter_mask + self.n_transformer_parameters = int(torch.prod(torch.as_tensor(self.parameter_shape))) + self.n_global_parameters = 0 if global_parameter_mask is None else int(torch.sum(self.global_parameter_mask)) self.n_predicted_parameters = self.n_transformer_parameters - self.n_global_parameters if initial_global_parameter_value is None: - initial_global_theta = torch.randn(size=(*output_event_shape, self.n_global_parameters)) + initial_global_theta_flat = torch.randn(size=(self.n_global_parameters,)) else: - initial_global_theta = torch.full( - size=(*output_event_shape, self.n_global_parameters), + initial_global_theta_flat = torch.full( + size=(self.n_global_parameters,), fill_value=initial_global_parameter_value ) - self.global_theta = nn.Parameter(initial_global_theta) + self.global_theta_flat = nn.Parameter(initial_global_theta_flat) def forward(self, x: torch.Tensor, context: torch.Tensor = None): - # x.shape = (*batch_shape, *input_event_shape) - # context.shape = (*batch_shape, *context_shape) - # output.shape = (*batch_shape, *output_event_shape, n_transformer_parameters) + # x.shape = (*batch_shape, *self.input_event_shape) + # context.shape = (*batch_shape, *self.context_shape) + # output.shape = (*batch_shape, *self.parameter_shape) if self.n_global_parameters == 0: - return self.predict_theta(x, context) + # All parameters are predicted + return self.predict_theta_flat(x, context) else: - n_batch_dims = len(x.shape) - len(self.output_event_shape) - n_event_dims = len(self.output_event_shape) - batch_shape = x.shape[:n_batch_dims] - batch_global_theta = pad_leading_dims(self.global_theta, n_batch_dims).repeat( - *batch_shape, *([1] * n_event_dims), 1 - ) + batch_shape = x.shape[:-len(self.output_event_shape)] if self.n_global_parameters == self.n_transformer_parameters: - return batch_global_theta + # All transformer parameters are learned globally + output = torch.zeros(*batch_shape, *self.parameter_shape) + output[..., self.global_parameter_mask] = self.global_theta_flat + return output else: - return torch.cat([batch_global_theta, self.predict_theta(x, context)], dim=-1) + # Some transformer parameters are learned globally, some are predicted + output = torch.zeros(*batch_shape, *self.parameter_shape) + output[..., self.global_parameter_mask] = self.global_theta_flat + output[..., ~self.global_parameter_mask] = self.predict_theta_flat(x, context) + return output - def predict_theta(self, x: torch.Tensor, context: torch.Tensor = None): + def predict_theta_flat(self, x: torch.Tensor, context: torch.Tensor = None): raise NotImplementedError class Constant(ConditionerTransform): - def __init__(self, output_event_shape, n_parameters: int, fill_value: float = None): + def __init__(self, output_event_shape, parameter_shape, fill_value: float = None): super().__init__( input_event_shape=None, context_shape=None, output_event_shape=output_event_shape, - n_transformer_parameters=n_parameters, + parameter_shape=parameter_shape, 
initial_global_parameter_value=fill_value, - percent_global_parameters=1.0 + global_parameter_mask=torch.ones(parameter_shape, dtype=torch.bool) ) @@ -114,7 +119,7 @@ def forward(self, x): def __init__(self, input_event_shape: torch.Size, output_event_shape: torch.Size, - n_transformer_parameters: int, + parameter_shape: int, context_shape: torch.Size = None, n_hidden: int = None, n_layers: int = 2, @@ -123,7 +128,7 @@ def __init__(self, input_event_shape=input_event_shape, context_shape=context_shape, output_event_shape=output_event_shape, - n_transformer_parameters=n_transformer_parameters, + parameter_shape=parameter_shape, **kwargs ) @@ -169,19 +174,19 @@ def create_masks(n_layers, ms): masks.append(torch.as_tensor(xx >= yy, dtype=torch.float)) return masks - def predict_theta(self, x: torch.Tensor, context: torch.Tensor = None): + def predict_theta_flat(self, x: torch.Tensor, context: torch.Tensor = None): out = self.sequential(self.context_combiner(x, context)) batch_shape = get_batch_shape(x, self.input_event_shape) return out.view(*batch_shape, *self.output_event_shape, self.n_predicted_parameters) class LinearMADE(MADE): - def __init__(self, input_event_shape: torch.Size, output_event_shape: torch.Size, n_transformer_parameters: int, + def __init__(self, input_event_shape: torch.Size, output_event_shape: torch.Size, parameter_shape: int, **kwargs): super().__init__( input_event_shape, output_event_shape, - n_transformer_parameters, + parameter_shape, n_layers=1, **kwargs ) @@ -191,7 +196,7 @@ class FeedForward(ConditionerTransform): def __init__(self, input_event_shape: torch.Size, output_event_shape: torch.Size, - n_transformer_parameters: int, + parameter_shape: int, context_shape: torch.Size = None, n_hidden: int = None, n_layers: int = 2, @@ -200,7 +205,7 @@ def __init__(self, input_event_shape=input_event_shape, context_shape=context_shape, output_event_shape=output_event_shape, - n_transformer_parameters=n_transformer_parameters, + parameter_shape=parameter_shape, **kwargs ) @@ -215,7 +220,7 @@ def __init__(self, # Check the one layer special case if n_layers == 1: - layers.append(nn.Linear(self.n_input_event_dims, self.n_output_event_dims * n_transformer_parameters)) + layers.append(nn.Linear(self.n_input_event_dims, self.n_output_event_dims * parameter_shape)) elif n_layers > 1: layers.extend([nn.Linear(self.n_input_event_dims, n_hidden), nn.Tanh()]) for _ in range(n_layers - 2): @@ -228,7 +233,7 @@ def __init__(self, layers.append(nn.Unflatten(dim=-1, unflattened_size=(*output_event_shape, self.n_predicted_parameters))) self.sequential = nn.Sequential(*layers) - def predict_theta(self, x: torch.Tensor, context: torch.Tensor = None): + def predict_theta_flat(self, x: torch.Tensor, context: torch.Tensor = None): out = self.sequential(self.context_combiner(x, context)) batch_shape = get_batch_shape(x, self.input_event_shape) return out.view(*batch_shape, *self.output_event_shape, self.n_predicted_parameters) @@ -251,7 +256,7 @@ def forward(self, x): def __init__(self, input_event_shape: torch.Size, output_event_shape: torch.Size, - n_transformer_parameters: int, + parameter_shape: int, context_shape: torch.Size = None, n_layers: int = 2, **kwargs): @@ -259,7 +264,7 @@ def __init__(self, input_event_shape, context_shape, output_event_shape, - n_transformer_parameters, + parameter_shape, **kwargs ) @@ -284,7 +289,7 @@ def __init__(self, layers.append(nn.Unflatten(dim=-1, unflattened_size=(*output_event_shape, self.n_predicted_parameters))) self.sequential = 
nn.Sequential(*layers) - def predict_theta(self, x: torch.Tensor, context: torch.Tensor = None): + def predict_theta_flat(self, x: torch.Tensor, context: torch.Tensor = None): out = self.sequential(self.context_combiner(x, context)) batch_shape = get_batch_shape(x, self.input_event_shape) return out.view(*batch_shape, *self.output_event_shape, self.n_predicted_parameters) diff --git a/normalizing_flows/bijections/finite/autoregressive/layers.py b/normalizing_flows/bijections/finite/autoregressive/layers.py index 6ab4c79..56c4c15 100644 --- a/normalizing_flows/bijections/finite/autoregressive/layers.py +++ b/normalizing_flows/bijections/finite/autoregressive/layers.py @@ -19,25 +19,25 @@ class ElementwiseAffine(ElementwiseBijection): def __init__(self, event_shape, **kwargs): transformer = Affine(event_shape, **kwargs) - super().__init__(transformer, n_transformer_parameters=transformer.n_parameters) + super().__init__(transformer) class ElementwiseScale(ElementwiseBijection): def __init__(self, event_shape, **kwargs): transformer = Scale(event_shape, **kwargs) - super().__init__(transformer, n_transformer_parameters=transformer.n_parameters) + super().__init__(transformer) class ElementwiseShift(ElementwiseBijection): def __init__(self, event_shape): transformer = Shift(event_shape) - super().__init__(transformer, n_transformer_parameters=transformer.n_parameters) + super().__init__(transformer) class ElementwiseRQSpline(ElementwiseBijection): def __init__(self, event_shape, **kwargs): transformer = RationalQuadratic(event_shape, **kwargs) - super().__init__(transformer, n_transformer_parameters=transformer.n_parameters) + super().__init__(transformer) class AffineCoupling(CouplingBijection): @@ -52,7 +52,7 @@ def __init__(self, conditioner_transform = FeedForward( input_event_shape=conditioner.input_shape, output_event_shape=conditioner.output_shape, - n_transformer_parameters=transformer.n_parameters, + parameter_shape=transformer.n_parameters, context_shape=context_shape, **kwargs ) @@ -71,7 +71,7 @@ def __init__(self, conditioner_transform = FeedForward( input_event_shape=conditioner.input_shape, output_event_shape=conditioner.output_shape, - n_transformer_parameters=transformer.n_parameters, + parameter_shape=transformer.n_parameters, context_shape=context_shape, **kwargs ) @@ -88,7 +88,7 @@ def __init__(self, conditioner_transform = FeedForward( input_event_shape=conditioner.input_shape, output_event_shape=conditioner.output_shape, - n_transformer_parameters=transformer.n_parameters, + parameter_shape=transformer.n_parameters, context_shape=context_shape, **kwargs ) @@ -107,7 +107,7 @@ def __init__(self, conditioner_transform = FeedForward( input_event_shape=conditioner.input_shape, output_event_shape=conditioner.output_shape, - n_transformer_parameters=transformer.n_parameters, + parameter_shape=transformer.n_parameters, context_shape=context_shape, **kwargs ) @@ -125,7 +125,7 @@ def __init__(self, conditioner_transform = FeedForward( input_event_shape=conditioner.input_shape, output_event_shape=conditioner.output_shape, - n_transformer_parameters=transformer.n_parameters, + parameter_shape=transformer.n_parameters, context_shape=context_shape, **kwargs ) @@ -145,7 +145,7 @@ def __init__(self, conditioner_transform = FeedForward( input_event_shape=conditioner.input_shape, output_event_shape=conditioner.output_shape, - n_transformer_parameters=transformer.n_parameters, + parameter_shape=transformer.n_parameters, context_shape=context_shape, **kwargs ) @@ -181,7 +181,7 @@ def 
__init__(self, conditioner_transform = MADE( input_event_shape=event_shape, output_event_shape=event_shape, - n_transformer_parameters=transformer.n_parameters, + parameter_shape=transformer.n_parameters, context_shape=context_shape, **kwargs ) @@ -203,7 +203,7 @@ def __init__(self, conditioner_transform = MADE( input_event_shape=event_shape, output_event_shape=event_shape, - n_transformer_parameters=transformer.n_parameters, + parameter_shape=transformer.n_parameters, context_shape=context_shape, **kwargs ) @@ -224,7 +224,7 @@ def __init__(self, conditioner_transform = MADE( input_event_shape=event_shape, output_event_shape=event_shape, - n_transformer_parameters=transformer.n_parameters, + parameter_shape=transformer.n_parameters, context_shape=context_shape, **kwargs ) @@ -245,7 +245,7 @@ def __init__(self, conditioner_transform = MADE( input_event_shape=event_shape, output_event_shape=event_shape, - n_transformer_parameters=transformer.n_parameters, + parameter_shape=transformer.n_parameters, context_shape=context_shape, **kwargs ) @@ -268,7 +268,7 @@ def __init__(self, conditioner_transform = MADE( input_event_shape=event_shape, output_event_shape=event_shape, - n_transformer_parameters=transformer.n_parameters, + parameter_shape=transformer.n_parameters, context_shape=context_shape, **kwargs ) @@ -295,7 +295,7 @@ def __init__(self, conditioner_transform = MADE( input_event_shape=event_shape, output_event_shape=event_shape, - n_transformer_parameters=transformer.n_parameters, + parameter_shape=transformer.n_parameters, context_shape=context_shape, **kwargs ) diff --git a/normalizing_flows/bijections/finite/autoregressive/layers_base.py b/normalizing_flows/bijections/finite/autoregressive/layers_base.py index 2464735..35c0fc3 100644 --- a/normalizing_flows/bijections/finite/autoregressive/layers_base.py +++ b/normalizing_flows/bijections/finite/autoregressive/layers_base.py @@ -89,6 +89,6 @@ def inverse(self, z: torch.Tensor, context: torch.Tensor = None) -> Tuple[torch. 
class ElementwiseBijection(AutoregressiveBijection): - def __init__(self, transformer: ScalarTransformer, n_transformer_parameters: int): - super().__init__(NullConditioner(), transformer, Constant(transformer.event_shape, n_transformer_parameters)) + def __init__(self, transformer: ScalarTransformer): + super().__init__(NullConditioner(), transformer, Constant(transformer.event_shape, transformer.parameter_shape)) # TODO override forward and inverse to save on space diff --git a/normalizing_flows/bijections/finite/autoregressive/transformers/affine.py b/normalizing_flows/bijections/finite/autoregressive/transformers/affine.py index 6f80fbe..9b2bb2b 100644 --- a/normalizing_flows/bijections/finite/autoregressive/transformers/affine.py +++ b/normalizing_flows/bijections/finite/autoregressive/transformers/affine.py @@ -23,14 +23,12 @@ def __init__(self, event_shape: torch.Size, min_scale: float = 1e-3): self.const = 2 @property - def n_parameters(self) -> int: - return 2 + def parameter_shape_per_element(self): + return (2,) @property def default_parameters(self) -> torch.Tensor: - default_u_alpha = torch.zeros(size=(1,)) - default_u_beta = torch.zeros(size=(1,)) - return torch.cat([default_u_alpha, default_u_beta], dim=0) + return torch.zeros(self.parameter_shape) def forward(self, x: torch.Tensor, h: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: u_alpha = h[..., 0] @@ -115,12 +113,12 @@ def __init__(self, event_shape: torch.Size): super().__init__(event_shape=event_shape) @property - def n_parameters(self) -> int: - return 1 + def parameter_shape_per_element(self): + return (1,) @property def default_parameters(self) -> torch.Tensor: - return torch.zeros(size=(1,)) + return torch.zeros(self.parameter_shape) def forward(self, x: torch.Tensor, h: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: beta = h[..., 0] @@ -150,11 +148,12 @@ def __init__(self, event_shape: torch.Size, min_scale: float = 1e-3): self.u_alpha_1 = math.log(1 - self.m) @property - def n_parameters(self) -> int: - return 1 + def parameter_shape_per_element(self): + return (1,) + @property def default_parameters(self) -> torch.Tensor: - return torch.zeros(size=(1,)) + return torch.zeros(self.parameter_shape) def forward(self, x: torch.Tensor, h: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: u_alpha = h[..., 0] diff --git a/test/test_conditioner_transforms.py b/test/test_conditioner_transforms.py index c74f0b9..b7d6edd 100644 --- a/test/test_conditioner_transforms.py +++ b/test/test_conditioner_transforms.py @@ -24,7 +24,7 @@ def test_shape(transform_class, batch_shape, input_event_shape, output_event_sha transform = transform_class( input_event_shape=input_event_shape, output_event_shape=output_event_shape, - n_transformer_parameters=n_predicted_parameters + parameter_shape=n_predicted_parameters ) out = transform(x) assert out.shape == (*batch_shape, *output_event_shape, n_predicted_parameters) From 519bce93636d51a7662c15e9cdede091a23467f2 Mon Sep 17 00:00:00 2001 From: David Nabergoj Date: Sat, 18 Nov 2023 19:30:43 -0800 Subject: [PATCH 20/40] Fixing some merge errors --- .../autoregressive/conditioner_transforms.py | 92 +++++-------------- .../finite/autoregressive/layers.py | 47 ++++------ .../finite/autoregressive/layers_base.py | 27 +++--- 3 files changed, 53 insertions(+), 113 deletions(-) diff --git a/normalizing_flows/bijections/finite/autoregressive/conditioner_transforms.py b/normalizing_flows/bijections/finite/autoregressive/conditioner_transforms.py index f961145..8e12705 100644 --- 
a/normalizing_flows/bijections/finite/autoregressive/conditioner_transforms.py +++ b/normalizing_flows/bijections/finite/autoregressive/conditioner_transforms.py @@ -1,5 +1,5 @@ import math -from typing import Tuple, Union +from typing import Tuple, Union, Type import torch import torch.nn as nn @@ -23,7 +23,6 @@ class ConditionerTransform(nn.Module): def __init__(self, input_event_shape, context_shape, - output_event_shape, parameter_shape: Union[torch.Size, Tuple[int, ...]], context_combiner: ContextCombiner = None, global_parameter_mask: torch.Tensor = None, @@ -31,7 +30,6 @@ def __init__(self, """ :param input_event_shape: shape of conditioner input tensor x. :param context_shape: shape of conditioner context tensor c. - :param output_event_shape: shape of transformer input tensor y. :param parameter_shape: shape of parameter tensor required to transform transformer input y. :param context_combiner: ContextCombiner class which defines how to combine x and c to predict theta. :param global_parameter_mask: boolean tensor which determines which elements of parameter tensors should be @@ -49,10 +47,8 @@ def __init__(self, # The conditioner transform receives as input the context combiner output self.input_event_shape = input_event_shape - self.output_event_shape = output_event_shape self.context_shape = context_shape self.n_input_event_dims = self.context_combiner.n_output_dims - self.n_output_event_dims = int(torch.prod(torch.as_tensor(output_event_shape))) # Setup output parameter attributes self.parameter_shape = parameter_shape @@ -78,7 +74,7 @@ def forward(self, x: torch.Tensor, context: torch.Tensor = None): # All parameters are predicted return self.predict_theta_flat(x, context) else: - batch_shape = x.shape[:-len(self.output_event_shape)] + batch_shape = get_batch_shape(x, self.input_event_shape) if self.n_global_parameters == self.n_transformer_parameters: # All transformer parameters are learned globally output = torch.zeros(*batch_shape, *self.parameter_shape) @@ -96,11 +92,10 @@ def predict_theta_flat(self, x: torch.Tensor, context: torch.Tensor = None): class Constant(ConditionerTransform): - def __init__(self, output_event_shape, parameter_shape, fill_value: float = None): + def __init__(self, parameter_shape, fill_value: float = None): super().__init__( input_event_shape=None, context_shape=None, - output_event_shape=output_event_shape, parameter_shape=parameter_shape, initial_global_parameter_value=fill_value, global_parameter_mask=torch.ones(parameter_shape, dtype=torch.bool) @@ -119,7 +114,7 @@ def forward(self, x): def __init__(self, input_event_shape: torch.Size, output_event_shape: torch.Size, - parameter_shape: int, + parameter_shape: torch.Size, context_shape: torch.Size = None, n_hidden: int = None, n_layers: int = 2, @@ -127,10 +122,10 @@ def __init__(self, super().__init__( input_event_shape=input_event_shape, context_shape=context_shape, - output_event_shape=output_event_shape, parameter_shape=parameter_shape, **kwargs ) + n_output_event_dims = int(torch.prod(torch.as_tensor(output_event_shape))) if n_hidden is None: n_hidden = max(int(3 * math.log10(self.n_input_event_dims)), 4) @@ -139,7 +134,7 @@ def __init__(self, ms = [ torch.arange(self.n_input_event_dims) + 1, *[(torch.arange(n_hidden) % (self.n_input_event_dims - 1)) + 1 for _ in range(n_layers - 1)], - torch.arange(self.n_output_event_dims) + 1 + torch.arange(n_output_event_dims) + 1 ] # Create autoencoder masks @@ -156,8 +151,7 @@ def __init__(self, masks[-1].shape[1], masks[-1].shape[0] * 
self.n_predicted_parameters, torch.repeat_interleave(masks[-1], self.n_predicted_parameters, dim=0) - ), - nn.Unflatten(dim=-1, unflattened_size=(*output_event_shape, self.n_predicted_parameters)) + ) ]) self.sequential = nn.Sequential(*layers) @@ -175,13 +169,14 @@ def create_masks(n_layers, ms): return masks def predict_theta_flat(self, x: torch.Tensor, context: torch.Tensor = None): - out = self.sequential(self.context_combiner(x, context)) - batch_shape = get_batch_shape(x, self.input_event_shape) - return out.view(*batch_shape, *self.output_event_shape, self.n_predicted_parameters) + return self.sequential(self.context_combiner(x, context)) class LinearMADE(MADE): - def __init__(self, input_event_shape: torch.Size, output_event_shape: torch.Size, parameter_shape: int, + def __init__(self, + input_event_shape: torch.Size, + output_event_shape: torch.Size, + parameter_shape: torch.Size, **kwargs): super().__init__( input_event_shape, @@ -195,16 +190,15 @@ def __init__(self, input_event_shape: torch.Size, output_event_shape: torch.Size class FeedForward(ConditionerTransform): def __init__(self, input_event_shape: torch.Size, - output_event_shape: torch.Size, - parameter_shape: int, + parameter_shape: torch.Size, context_shape: torch.Size = None, n_hidden: int = None, n_layers: int = 2, + hidden_linear_module: Type[nn.Module] = nn.Linear, **kwargs): super().__init__( input_event_shape=input_event_shape, context_shape=context_shape, - output_event_shape=output_event_shape, parameter_shape=parameter_shape, **kwargs ) @@ -220,23 +214,19 @@ def __init__(self, # Check the one layer special case if n_layers == 1: - layers.append(nn.Linear(self.n_input_event_dims, self.n_output_event_dims * parameter_shape)) + layers.append(nn.Linear(self.n_input_event_dims, self.n_predicted_parameters)) elif n_layers > 1: - layers.extend([nn.Linear(self.n_input_event_dims, n_hidden), nn.Tanh()]) + layers.extend([hidden_linear_module(self.n_input_event_dims, n_hidden), nn.Tanh()]) for _ in range(n_layers - 2): - layers.extend([nn.Linear(n_hidden, n_hidden), nn.Tanh()]) - layers.append(nn.Linear(n_hidden, self.n_output_event_dims * self.n_predicted_parameters)) + layers.extend([hidden_linear_module(n_hidden, n_hidden), nn.Tanh()]) + layers.append(nn.Linear(n_hidden, self.n_predicted_parameters)) else: raise ValueError - # Reshape the output - layers.append(nn.Unflatten(dim=-1, unflattened_size=(*output_event_shape, self.n_predicted_parameters))) self.sequential = nn.Sequential(*layers) def predict_theta_flat(self, x: torch.Tensor, context: torch.Tensor = None): - out = self.sequential(self.context_combiner(x, context)) - batch_shape = get_batch_shape(x, self.input_event_shape) - return out.view(*batch_shape, *self.output_event_shape, self.n_predicted_parameters) + return self.sequential(self.context_combiner(x, context)) class Linear(FeedForward): @@ -244,8 +234,8 @@ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs, n_layers=1) -class ResidualFeedForward(ConditionerTransform): - class ResidualLinear(nn.Module): +class ResidualFeedForward(FeedForward): + class _ResidualLinearModule(nn.Module): def __init__(self, n_in, n_out): super().__init__() self.linear = nn.Linear(n_in, n_out) @@ -253,43 +243,9 @@ def __init__(self, n_in, n_out): def forward(self, x): return x + self.linear(x) - def __init__(self, - input_event_shape: torch.Size, - output_event_shape: torch.Size, - parameter_shape: int, - context_shape: torch.Size = None, - n_layers: int = 2, - **kwargs): + def __init__(self, *args, 
**kwargs): super().__init__( - input_event_shape, - context_shape, - output_event_shape, - parameter_shape, + *args, + hidden_linear_module=ResidualFeedForward._ResidualLinearModule, **kwargs ) - - # If context given, concatenate it to transform input - if context_shape is not None: - self.n_input_event_dims += self.n_context_dims - - layers = [] - - # Check the one layer special case - if n_layers == 1: - layers.append(nn.Linear(self.n_input_event_dims, self.n_output_event_dims * self.n_predicted_parameters)) - elif n_layers > 1: - layers.extend([self.ResidualLinear(self.n_input_event_dims, self.n_input_event_dims), nn.Tanh()]) - for _ in range(n_layers - 2): - layers.extend([self.ResidualLinear(self.n_input_event_dims, self.n_input_event_dims), nn.Tanh()]) - layers.append(nn.Linear(self.n_input_event_dims, self.n_output_event_dims * self.n_predicted_parameters)) - else: - raise ValueError - - # Reshape the output - layers.append(nn.Unflatten(dim=-1, unflattened_size=(*output_event_shape, self.n_predicted_parameters))) - self.sequential = nn.Sequential(*layers) - - def predict_theta_flat(self, x: torch.Tensor, context: torch.Tensor = None): - out = self.sequential(self.context_combiner(x, context)) - batch_shape = get_batch_shape(x, self.input_event_shape) - return out.view(*batch_shape, *self.output_event_shape, self.n_predicted_parameters) diff --git a/normalizing_flows/bijections/finite/autoregressive/layers.py b/normalizing_flows/bijections/finite/autoregressive/layers.py index 81912a2..6d50c65 100644 --- a/normalizing_flows/bijections/finite/autoregressive/layers.py +++ b/normalizing_flows/bijections/finite/autoregressive/layers.py @@ -48,13 +48,11 @@ def __init__(self, if event_shape == (1,): raise ValueError coupling_mask = HalfSplit(event_shape) - transformer = Affine(event_shape=(coupling_mask.transformed_event_size,)) + transformer = Affine(event_shape=torch.Size((coupling_mask.transformed_event_size,))) conditioner_transform = FeedForward( input_event_shape=torch.Size((coupling_mask.constant_event_size,)), n_transformer_parameters=transformer.n_parameters, - input_event_shape=conditioner.input_shape, - output_event_shape=conditioner.output_shape, - parameter_shape=transformer.n_parameters, + parameter_shape=torch.Size((transformer.n_parameters,)), context_shape=context_shape, **kwargs ) @@ -69,13 +67,11 @@ def __init__(self, if event_shape == (1,): raise ValueError coupling_mask = HalfSplit(event_shape) - transformer = Affine(event_shape=(coupling_mask.transformed_event_size,)).invert() + transformer = Affine(event_shape=torch.Size((coupling_mask.transformed_event_size,))).invert() conditioner_transform = FeedForward( input_event_shape=torch.Size((coupling_mask.constant_event_size,)), n_transformer_parameters=transformer.n_parameters, - input_event_shape=conditioner.input_shape, - output_event_shape=conditioner.output_shape, - parameter_shape=transformer.n_parameters, + parameter_shape=torch.Size((transformer.n_parameters,)), context_shape=context_shape, **kwargs ) @@ -88,13 +84,10 @@ def __init__(self, context_shape: torch.Size = None, **kwargs): coupling_mask = HalfSplit(event_shape) - transformer = Shift(event_shape=(coupling_mask.transformed_event_size,)) + transformer = Shift(event_shape=torch.Size((coupling_mask.transformed_event_size,))) conditioner_transform = FeedForward( input_event_shape=torch.Size((coupling_mask.constant_event_size,)), - n_transformer_parameters=transformer.n_parameters, - input_event_shape=conditioner.input_shape, - 
output_event_shape=conditioner.output_shape, - parameter_shape=transformer.n_parameters, + parameter_shape=torch.Size((transformer.n_parameters,)), context_shape=context_shape, **kwargs ) @@ -109,13 +102,10 @@ def __init__(self, **kwargs): assert n_bins >= 1 coupling_mask = HalfSplit(event_shape) - transformer = LinearRational(event_shape=(coupling_mask.transformed_event_size,), n_bins=n_bins) + transformer = LinearRational(event_shape=torch.Size((coupling_mask.transformed_event_size,)), n_bins=n_bins) conditioner_transform = FeedForward( input_event_shape=torch.Size((coupling_mask.constant_event_size,)), - n_transformer_parameters=transformer.n_parameters, - input_event_shape=conditioner.input_shape, - output_event_shape=conditioner.output_shape, - parameter_shape=transformer.n_parameters, + parameter_shape=torch.Size((transformer.n_parameters,)), context_shape=context_shape, **kwargs ) @@ -129,13 +119,10 @@ def __init__(self, n_bins: int = 8, **kwargs): coupling_mask = HalfSplit(event_shape) - transformer = RationalQuadratic(event_shape=(coupling_mask.transformed_event_size,), n_bins=n_bins) + transformer = RationalQuadratic(event_shape=torch.Size((coupling_mask.transformed_event_size,)), n_bins=n_bins) conditioner_transform = FeedForward( input_event_shape=torch.Size((coupling_mask.constant_event_size,)), - n_transformer_parameters=transformer.n_parameters, - input_event_shape=conditioner.input_shape, - output_event_shape=conditioner.output_shape, - parameter_shape=transformer.n_parameters, + parameter_shape=torch.Size((transformer.n_parameters,)), context_shape=context_shape, **kwargs ) @@ -149,15 +136,15 @@ def __init__(self, n_hidden_layers: int = 2, **kwargs): coupling_mask = HalfSplit(event_shape) - transformer = DeepSigmoid(event_shape=(coupling_mask.transformed_event_size,), n_hidden_layers=n_hidden_layers) + transformer = DeepSigmoid( + event_shape=torch.Size((coupling_mask.transformed_event_size,)), + n_hidden_layers=n_hidden_layers + ) # Parameter order: [c1, c2, c3, c4, ..., ck] for all components # Each component has parameter order [a_unc, b, w_unc] conditioner_transform = FeedForward( input_event_shape=torch.Size((coupling_mask.constant_event_size,)), - n_transformer_parameters=transformer.n_parameters, - input_event_shape=conditioner.input_shape, - output_event_shape=conditioner.output_shape, - parameter_shape=transformer.n_parameters, + parameter_shape=torch.Size((transformer.n_parameters,)), context_shape=context_shape, **kwargs ) @@ -189,13 +176,11 @@ def __init__(self, event_shape: torch.Size, context_shape: torch.Size = None, **kwargs): - event_size = int(torch.prod(torch.as_tensor(event_shape))) transformer = Affine(event_shape=event_shape) conditioner_transform = MADE( input_event_shape=event_shape, output_event_shape=event_shape, - n_transformer_parameters=transformer.n_parameters // event_size, - parameter_shape=transformer.n_parameters, + parameter_shape=torch.Size((transformer.n_parameters,)), context_shape=context_shape, **kwargs ) diff --git a/normalizing_flows/bijections/finite/autoregressive/layers_base.py b/normalizing_flows/bijections/finite/autoregressive/layers_base.py index b7dc250..a71d589 100644 --- a/normalizing_flows/bijections/finite/autoregressive/layers_base.py +++ b/normalizing_flows/bijections/finite/autoregressive/layers_base.py @@ -5,9 +5,7 @@ from normalizing_flows.bijections.finite.autoregressive.conditioners.base import Conditioner, NullConditioner from normalizing_flows.bijections.finite.autoregressive.conditioner_transforms import 
ConditionerTransform, Constant from normalizing_flows.bijections.finite.autoregressive.conditioners.coupling_masks import CouplingMask -from normalizing_flows.bijections.finite.autoregressive.transformers.base import Transformer -from normalizing_flows.bijections.finite.autoregressive.conditioners.coupling import Coupling -from normalizing_flows.bijections.finite.autoregressive.transformers.base import ScalarTransformer +from normalizing_flows.bijections.finite.autoregressive.transformers.base import TensorTransformer, ScalarTransformer from normalizing_flows.bijections.base import Bijection from normalizing_flows.utils import flatten_event, unflatten_event, get_batch_shape @@ -16,7 +14,7 @@ class AutoregressiveBijection(Bijection): def __init__(self, event_shape, conditioner: Optional[Conditioner], - transformer: Transformer, + transformer: TensorTransformer, conditioner_transform: ConditionerTransform, **kwargs): super().__init__(event_shape=event_shape) @@ -54,7 +52,7 @@ class CouplingBijection(AutoregressiveBijection): """ def __init__(self, - transformer: Transformer, + transformer: TensorTransformer, coupling_mask: CouplingMask, conditioner_transform: ConditionerTransform, **kwargs): @@ -70,13 +68,11 @@ def partition_and_predict_parameters(self, x: torch.Tensor, context: torch.Tenso :param x: input tensor with x.shape = (*batch_shape, *event_shape) to be partitioned into x_A and x_B. :param context: context tensor with context.shape = (*batch_shape, *context.shape). - :return: parameter tensor h with h.shape = (*batch_shape, n_transformer_parameters). If return_mask is True, - also return the event partition mask with shape = event_shape and the constant parameter mask with shape - (n_transformer_parameters,). + :return: parameter tensor h with h.shape = (*batch_shape, *parameter_shape). """ # Predict transformer parameters for output dimensions - x_a = x[..., self.coupling_mask.mask] # (*b, *e_A) - h_b = self.conditioner_transform(x_a, context=context) # (*b, p) + x_a = x[..., self.coupling_mask.mask] # (*b, constant_event_size) + h_b = self.conditioner_transform(x_a, context=context) # (*b, *p) return h_b def forward(self, x: torch.Tensor, context: torch.Tensor = None) -> Tuple[torch.Tensor, torch.Tensor]: @@ -93,7 +89,8 @@ def inverse(self, z: torch.Tensor, context: torch.Tensor = None) -> Tuple[torch. class ForwardMaskedAutoregressiveBijection(AutoregressiveBijection): - def __init__(self, conditioner: Conditioner, transformer: ScalarTransformer, conditioner_transform: ConditionerTransform): + def __init__(self, conditioner: Conditioner, transformer: ScalarTransformer, + conditioner_transform: ConditionerTransform): super().__init__(conditioner, transformer, conditioner_transform) def inverse(self, z: torch.Tensor, context: torch.Tensor = None) -> Tuple[torch.Tensor, torch.Tensor]: @@ -115,7 +112,8 @@ def inverse(self, z: torch.Tensor, context: torch.Tensor = None) -> Tuple[torch. class InverseMaskedAutoregressiveBijection(AutoregressiveBijection): - def __init__(self, conditioner: Conditioner, transformer: ScalarTransformer, conditioner_transform: ConditionerTransform): + def __init__(self, conditioner: Conditioner, transformer: ScalarTransformer, + conditioner_transform: ConditionerTransform): super().__init__(conditioner, transformer, conditioner_transform) self.forward_layer = ForwardMaskedAutoregressiveBijection( conditioner, @@ -131,6 +129,7 @@ def inverse(self, z: torch.Tensor, context: torch.Tensor = None) -> Tuple[torch. 
class ElementwiseBijection(AutoregressiveBijection): - def __init__(self, transformer: ScalarTransformer): - super().__init__(NullConditioner(), transformer, Constant(transformer.event_shape, transformer.parameter_shape)) + def __init__(self, transformer: ScalarTransformer, fill_value: float = None): + super().__init__(transformer.event_shape, NullConditioner(), transformer, + Constant(transformer.event_shape, fill_value=fill_value)) # TODO override forward and inverse to save on space From 235a93281146cf7a95928adb93e46aa9b45cf6c0 Mon Sep 17 00:00:00 2001 From: David Nabergoj Date: Sat, 18 Nov 2023 19:39:48 -0800 Subject: [PATCH 21/40] Fix elementwise bijection --- .../autoregressive/conditioner_transforms.py | 4 ++-- .../finite/autoregressive/layers_base.py | 16 +++++++++++++--- 2 files changed, 15 insertions(+), 5 deletions(-) diff --git a/normalizing_flows/bijections/finite/autoregressive/conditioner_transforms.py b/normalizing_flows/bijections/finite/autoregressive/conditioner_transforms.py index 8e12705..7469222 100644 --- a/normalizing_flows/bijections/finite/autoregressive/conditioner_transforms.py +++ b/normalizing_flows/bijections/finite/autoregressive/conditioner_transforms.py @@ -92,9 +92,9 @@ def predict_theta_flat(self, x: torch.Tensor, context: torch.Tensor = None): class Constant(ConditionerTransform): - def __init__(self, parameter_shape, fill_value: float = None): + def __init__(self, event_shape, parameter_shape, fill_value: float = None): super().__init__( - input_event_shape=None, + input_event_shape=event_shape, context_shape=None, parameter_shape=parameter_shape, initial_global_parameter_value=fill_value, diff --git a/normalizing_flows/bijections/finite/autoregressive/layers_base.py b/normalizing_flows/bijections/finite/autoregressive/layers_base.py index a71d589..8c31196 100644 --- a/normalizing_flows/bijections/finite/autoregressive/layers_base.py +++ b/normalizing_flows/bijections/finite/autoregressive/layers_base.py @@ -129,7 +129,17 @@ def inverse(self, z: torch.Tensor, context: torch.Tensor = None) -> Tuple[torch. class ElementwiseBijection(AutoregressiveBijection): + """ + Base elementwise bijection class. + + Applies a bijective transformation to each element of the input tensor. + The bijection for each element has its own set of globally learned parameters. 
+ """ + def __init__(self, transformer: ScalarTransformer, fill_value: float = None): + super().__init__( + transformer.event_shape, + NullConditioner(), + transformer, + Constant(transformer.event_shape, transformer.parameter_shape, fill_value=fill_value) + ) From 0cd67e27e3a1ac553e5f5a838f4f580e38aa0eb8 Mon Sep 17 00:00:00 2001 From: David Nabergoj Date: Sat, 18 Nov 2023 19:47:29 -0800 Subject: [PATCH 22/40] Implement abstract parameter shape property in splines --- .../bijections/finite/autoregressive/layers.py | 2 -- .../finite/autoregressive/transformers/spline/base.py | 2 +- .../finite/autoregressive/transformers/spline/cubic.py | 7 ++----- .../finite/autoregressive/transformers/spline/linear.py | 7 ++----- .../autoregressive/transformers/spline/linear_rational.py | 6 ++---- .../transformers/spline/rational_quadratic.py | 6 ++---- 6 files changed, 9 insertions(+), 21 deletions(-) diff --git a/normalizing_flows/bijections/finite/autoregressive/layers.py b/normalizing_flows/bijections/finite/autoregressive/layers.py index 6d50c65..fa7d2a2 100644 --- a/normalizing_flows/bijections/finite/autoregressive/layers.py +++ b/normalizing_flows/bijections/finite/autoregressive/layers.py @@ -51,7 +51,6 @@ def __init__(self, transformer = Affine(event_shape=torch.Size((coupling_mask.transformed_event_size,))) conditioner_transform = FeedForward( input_event_shape=torch.Size((coupling_mask.constant_event_size,)), - n_transformer_parameters=transformer.n_parameters, parameter_shape=torch.Size((transformer.n_parameters,)), context_shape=context_shape, **kwargs @@ -70,7 +69,6 @@ def __init__(self, transformer = Affine(event_shape=torch.Size((coupling_mask.transformed_event_size,))).invert() conditioner_transform = FeedForward( input_event_shape=torch.Size((coupling_mask.constant_event_size,)), - n_transformer_parameters=transformer.n_parameters, parameter_shape=torch.Size((transformer.n_parameters,)), context_shape=context_shape, **kwargs ) diff --git a/normalizing_flows/bijections/finite/autoregressive/transformers/spline/base.py b/normalizing_flows/bijections/finite/autoregressive/transformers/spline/base.py index cec05a0..90dadac 100644 --- a/normalizing_flows/bijections/finite/autoregressive/transformers/spline/base.py +++ b/normalizing_flows/bijections/finite/autoregressive/transformers/spline/base.py @@ -23,7 +23,7 @@ def __init__(self, self.n_knots = n_bins + 1 @property - def n_parameters(self) -> int: + def parameter_shape_per_element(self) -> int: raise NotImplementedError def forward_inputs_inside_bounds_mask(self, x): diff --git a/normalizing_flows/bijections/finite/autoregressive/transformers/spline/cubic.py b/normalizing_flows/bijections/finite/autoregressive/transformers/spline/cubic.py index 9dd335b..12f7cca 100644 --- a/normalizing_flows/bijections/finite/autoregressive/transformers/spline/cubic.py +++ b/normalizing_flows/bijections/finite/autoregressive/transformers/spline/cubic.py @@ -13,11 +13,8 @@ def __init__(self, event_shape: Union[torch.Size, Tuple[int, ...]], n_bins: int self.const = 1000 @property - def n_parameters(self) -> int: - return (2 * self.n_bins + 2) * self.n_dim - - def unflatten_conditioner_parameters(self, h: torch.Tensor): - return torch.unflatten(h, dim=-1, sizes=(*self.event_shape, 2 * self.n_bins + 2)) + def parameter_shape_per_element(self) -> torch.Size: + return torch.Size((2 *
self.n_bins + 2,)) def compute_spline_parameters(self, knots_x: torch.Tensor, knots_y: torch.Tensor, idx: torch.Tensor): # knots_x.shape == (n, n_knots) diff --git a/normalizing_flows/bijections/finite/autoregressive/transformers/spline/linear.py b/normalizing_flows/bijections/finite/autoregressive/transformers/spline/linear.py index ee9964b..e9b0d17 100644 --- a/normalizing_flows/bijections/finite/autoregressive/transformers/spline/linear.py +++ b/normalizing_flows/bijections/finite/autoregressive/transformers/spline/linear.py @@ -35,11 +35,8 @@ def compute_bin_y(self, delta): return cs * (self.max_output - self.min_output) + self.min_output @property - def n_parameters(self) -> int: - return self.n_bins * self.n_dim - - def unflatten_conditioner_parameters(self, h: torch.Tensor): - return torch.unflatten(h, dim=-1, sizes=(*self.event_shape, self.n_bins)) + def parameter_shape_per_element(self) -> torch.Size: + return torch.Size((self.n_bins,)) def forward_1d(self, x, h): assert len(x.shape) == 1 diff --git a/normalizing_flows/bijections/finite/autoregressive/transformers/spline/linear_rational.py b/normalizing_flows/bijections/finite/autoregressive/transformers/spline/linear_rational.py index 07807dc..a80e887 100644 --- a/normalizing_flows/bijections/finite/autoregressive/transformers/spline/linear_rational.py +++ b/normalizing_flows/bijections/finite/autoregressive/transformers/spline/linear_rational.py @@ -23,11 +23,9 @@ def __init__(self, event_shape: Union[torch.Size, Tuple[int, ...]], boundary: fl self.eps = 5e-10 # Epsilon for numerical stability when computing forward/inverse @property - def n_parameters(self) -> int: - return 4 * self.n_bins * self.n_dim + def parameter_shape_per_element(self) -> torch.Size: + return torch.Size((4 * self.n_bins,)) - def unflatten_conditioner_parameters(self, h: torch.Tensor): - return torch.unflatten(h, dim=-1, sizes=(*self.event_shape, 4 * self.n_bins)) def compute_parameters(self, idx, knots_x, knots_y, knots_d, knots_lambda, u_w0): assert knots_x.shape == knots_y.shape == knots_d.shape diff --git a/normalizing_flows/bijections/finite/autoregressive/transformers/spline/rational_quadratic.py b/normalizing_flows/bijections/finite/autoregressive/transformers/spline/rational_quadratic.py index f41ccff..9718197 100644 --- a/normalizing_flows/bijections/finite/autoregressive/transformers/spline/rational_quadratic.py +++ b/normalizing_flows/bijections/finite/autoregressive/transformers/spline/rational_quadratic.py @@ -38,11 +38,9 @@ def __init__(self, self.boundary_u_delta = math.log(math.expm1(1 - self.min_delta)) @property - def n_parameters(self) -> int: - return (3 * self.n_bins - 1) * self.n_dim + def parameter_shape_per_element(self) -> torch.Size: + return torch.Size((3 * self.n_bins - 1,)) - def unflatten_conditioner_parameters(self, h: torch.Tensor): - return torch.unflatten(h, dim=-1, sizes=(*self.event_shape, 3 * self.n_bins - 1)) def compute_bins(self, u, minimum, maximum): bin_sizes = torch.softmax(u, dim=-1) From 521646ed7369d246e5ce18eaf57a56f1c2c0f949 Mon Sep 17 00:00:00 2001 From: David Nabergoj Date: Sat, 18 Nov 2023 19:58:54 -0800 Subject: [PATCH 23/40] Fixed coupling bijections --- .../finite/autoregressive/conditioner_transforms.py | 1 + .../bijections/finite/autoregressive/layers.py | 12 ++++++------ 2 files changed, 7 insertions(+), 6 deletions(-) diff --git a/normalizing_flows/bijections/finite/autoregressive/conditioner_transforms.py b/normalizing_flows/bijections/finite/autoregressive/conditioner_transforms.py index 
7469222..54ebb6c 100644 --- a/normalizing_flows/bijections/finite/autoregressive/conditioner_transforms.py +++ b/normalizing_flows/bijections/finite/autoregressive/conditioner_transforms.py @@ -223,6 +223,7 @@ def __init__(self, else: raise ValueError + layers.append(nn.Unflatten(dim=-1, unflattened_size=self.parameter_shape)) self.sequential = nn.Sequential(*layers) def predict_theta_flat(self, x: torch.Tensor, context: torch.Tensor = None): diff --git a/normalizing_flows/bijections/finite/autoregressive/layers.py b/normalizing_flows/bijections/finite/autoregressive/layers.py index fa7d2a2..b8b8185 100644 --- a/normalizing_flows/bijections/finite/autoregressive/layers.py +++ b/normalizing_flows/bijections/finite/autoregressive/layers.py @@ -51,7 +51,7 @@ def __init__(self, transformer = Affine(event_shape=torch.Size((coupling_mask.transformed_event_size,))) conditioner_transform = FeedForward( input_event_shape=torch.Size((coupling_mask.constant_event_size,)), - parameter_shape=torch.Size((transformer.n_parameters,)), + parameter_shape=torch.Size(transformer.parameter_shape), context_shape=context_shape, **kwargs ) @@ -69,7 +69,7 @@ def __init__(self, transformer = Affine(event_shape=torch.Size((coupling_mask.transformed_event_size,))).invert() conditioner_transform = FeedForward( input_event_shape=torch.Size((coupling_mask.constant_event_size,)), - parameter_shape=torch.Size((transformer.n_parameters,)), + parameter_shape=torch.Size(transformer.parameter_shape), context_shape=context_shape, **kwargs ) @@ -85,7 +85,7 @@ def __init__(self, transformer = Shift(event_shape=torch.Size((coupling_mask.transformed_event_size,))) conditioner_transform = FeedForward( input_event_shape=torch.Size((coupling_mask.constant_event_size,)), - parameter_shape=torch.Size((transformer.n_parameters,)), + parameter_shape=torch.Size(transformer.parameter_shape), context_shape=context_shape, **kwargs ) @@ -103,7 +103,7 @@ def __init__(self, transformer = LinearRational(event_shape=torch.Size((coupling_mask.transformed_event_size,)), n_bins=n_bins) conditioner_transform = FeedForward( input_event_shape=torch.Size((coupling_mask.constant_event_size,)), - parameter_shape=torch.Size((transformer.n_parameters,)), + parameter_shape=torch.Size(transformer.parameter_shape), context_shape=context_shape, **kwargs ) @@ -120,7 +120,7 @@ def __init__(self, transformer = RationalQuadratic(event_shape=torch.Size((coupling_mask.transformed_event_size,)), n_bins=n_bins) conditioner_transform = FeedForward( input_event_shape=torch.Size((coupling_mask.constant_event_size,)), - parameter_shape=torch.Size((transformer.n_parameters,)), + parameter_shape=torch.Size(transformer.parameter_shape), context_shape=context_shape, **kwargs ) @@ -142,7 +142,7 @@ def __init__(self, # Each component has parameter order [a_unc, b, w_unc] conditioner_transform = FeedForward( input_event_shape=torch.Size((coupling_mask.constant_event_size,)), - parameter_shape=torch.Size((transformer.n_parameters,)), + parameter_shape=torch.Size(transformer.parameter_shape), context_shape=context_shape, **kwargs ) From 96af8373c8b89209eb2ca8a001c744a5505a1ee1 Mon Sep 17 00:00:00 2001 From: David Nabergoj Date: Sat, 18 Nov 2023 19:59:33 -0800 Subject: [PATCH 24/40] Add TODO --- normalizing_flows/bijections/finite/autoregressive/layers.py | 1 + 1 file changed, 1 insertion(+) diff --git a/normalizing_flows/bijections/finite/autoregressive/layers.py b/normalizing_flows/bijections/finite/autoregressive/layers.py index b8b8185..ea35df9 100644 --- 
a/normalizing_flows/bijections/finite/autoregressive/layers.py +++ b/normalizing_flows/bijections/finite/autoregressive/layers.py @@ -16,6 +16,7 @@ from normalizing_flows.bijections.base import invert +# TODO move elementwise bijections, coupling bijections, and masked autoregressive bijections into separate files. class ElementwiseAffine(ElementwiseBijection): def __init__(self, event_shape, **kwargs): transformer = Affine(event_shape, **kwargs) From 459aafc11ced07f8ff0f49fe488d37603ce89d03 Mon Sep 17 00:00:00 2001 From: David Nabergoj Date: Sat, 18 Nov 2023 20:23:38 -0800 Subject: [PATCH 25/40] Towards fixing masked autoregressive bijections --- normalizing_flows/bijections/base.py | 2 +- .../finite/autoregressive/layers.py | 35 +++++++------------ .../finite/autoregressive/layers_base.py | 22 +++++++----- 3 files changed, 27 insertions(+), 32 deletions(-) diff --git a/normalizing_flows/bijections/base.py b/normalizing_flows/bijections/base.py index 64eb863..a05fe5a 100644 --- a/normalizing_flows/bijections/base.py +++ b/normalizing_flows/bijections/base.py @@ -81,7 +81,7 @@ def batch_inverse(self, x: torch.Tensor, batch_size: int, context: torch.Tensor return outputs, log_dets -def invert(bijection: Bijection) -> Bijection: +def invert(bijection): """ Swap the forward and inverse methods of the input bijection. """ diff --git a/normalizing_flows/bijections/finite/autoregressive/layers.py b/normalizing_flows/bijections/finite/autoregressive/layers.py index ea35df9..1c7e96e 100644 --- a/normalizing_flows/bijections/finite/autoregressive/layers.py +++ b/normalizing_flows/bijections/finite/autoregressive/layers.py @@ -6,6 +6,7 @@ from normalizing_flows.bijections.finite.autoregressive.layers_base import ForwardMaskedAutoregressiveBijection, \ InverseMaskedAutoregressiveBijection, ElementwiseBijection, CouplingBijection from normalizing_flows.bijections.finite.autoregressive.transformers.affine import Scale, Affine, Shift +from normalizing_flows.bijections.finite.autoregressive.transformers.base import ScalarTransformer from normalizing_flows.bijections.finite.autoregressive.transformers.integration.unconstrained_monotonic_neural_network import \ UnconstrainedMonotonicNeuralNetwork from normalizing_flows.bijections.finite.autoregressive.transformers.spline.linear_rational import LinearRational @@ -175,11 +176,11 @@ def __init__(self, event_shape: torch.Size, context_shape: torch.Size = None, **kwargs): - transformer = Affine(event_shape=event_shape) + transformer: ScalarTransformer = Affine(event_shape=event_shape) conditioner_transform = MADE( input_event_shape=event_shape, output_event_shape=event_shape, - parameter_shape=torch.Size((transformer.n_parameters,)), + parameter_shape=transformer.parameter_shape_per_element, context_shape=context_shape, **kwargs ) @@ -197,13 +198,11 @@ def __init__(self, context_shape: torch.Size = None, n_bins: int = 8, **kwargs): - event_size = int(torch.prod(torch.as_tensor(event_shape))) - transformer = RationalQuadratic(event_shape=event_shape, n_bins=n_bins) + transformer: ScalarTransformer = RationalQuadratic(event_shape=event_shape, n_bins=n_bins) conditioner_transform = MADE( input_event_shape=event_shape, output_event_shape=event_shape, - n_transformer_parameters=transformer.n_parameters // event_size, - parameter_shape=transformer.n_parameters, + parameter_shape=transformer.parameter_shape_per_element, context_shape=context_shape, **kwargs ) @@ -221,13 +220,11 @@ def __init__(self, context_shape: torch.Size = None, n_bins: int = 8, **kwargs): - 
event_size = int(torch.prod(torch.as_tensor(event_shape))) - transformer = LinearRational(event_shape=event_shape, n_bins=n_bins) + transformer: ScalarTransformer = LinearRational(event_shape=event_shape, n_bins=n_bins) conditioner_transform = MADE( input_event_shape=event_shape, output_event_shape=event_shape, - parameter_shape=transformer.n_parameters, - n_transformer_parameters=transformer.n_parameters // event_size, + parameter_shape=transformer.parameter_shape_per_element, context_shape=context_shape, **kwargs ) @@ -244,13 +241,11 @@ def __init__(self, event_shape: torch.Size, context_shape: torch.Size = None, **kwargs): - event_size = int(torch.prod(torch.as_tensor(event_shape))) - transformer = invert(Affine(event_shape=event_shape)) + transformer: ScalarTransformer = invert(Affine(event_shape=event_shape)) conditioner_transform = MADE( input_event_shape=event_shape, output_event_shape=event_shape, - n_transformer_parameters=transformer.n_parameters // event_size, - parameter_shape=transformer.n_parameters, + parameter_shape=transformer.parameter_shape_per_element, context_shape=context_shape, **kwargs ) @@ -269,13 +264,11 @@ def __init__(self, n_bins: int = 8, **kwargs): assert n_bins >= 1 - event_size = int(torch.prod(torch.as_tensor(event_shape))) - transformer = RationalQuadratic(event_shape=event_shape, n_bins=n_bins) + transformer: ScalarTransformer = RationalQuadratic(event_shape=event_shape, n_bins=n_bins) conditioner_transform = MADE( input_event_shape=event_shape, output_event_shape=event_shape, - n_transformer_parameters=transformer.n_parameters // event_size, - parameter_shape=transformer.n_parameters, + parameter_shape=transformer.parameter_shape_per_element, context_shape=context_shape, **kwargs ) @@ -294,8 +287,7 @@ def __init__(self, n_hidden_layers: int = 1, hidden_dim: int = 5, **kwargs): - event_size = int(torch.prod(torch.as_tensor(event_shape))) - transformer = UnconstrainedMonotonicNeuralNetwork( + transformer: ScalarTransformer = UnconstrainedMonotonicNeuralNetwork( event_shape=event_shape, n_hidden_layers=n_hidden_layers, hidden_dim=hidden_dim @@ -303,8 +295,7 @@ def __init__(self, conditioner_transform = MADE( input_event_shape=event_shape, output_event_shape=event_shape, - n_transformer_parameters=transformer.n_parameters // event_size, - parameter_shape=transformer.n_parameters, + parameter_shape=transformer.parameter_shape_per_element, context_shape=context_shape, **kwargs ) diff --git a/normalizing_flows/bijections/finite/autoregressive/layers_base.py b/normalizing_flows/bijections/finite/autoregressive/layers_base.py index 8c31196..6010c75 100644 --- a/normalizing_flows/bijections/finite/autoregressive/layers_base.py +++ b/normalizing_flows/bijections/finite/autoregressive/layers_base.py @@ -1,4 +1,4 @@ -from typing import Tuple, Optional +from typing import Tuple, Optional, Union import torch @@ -14,7 +14,7 @@ class AutoregressiveBijection(Bijection): def __init__(self, event_shape, conditioner: Optional[Conditioner], - transformer: TensorTransformer, + transformer: Union[TensorTransformer, ScalarTransformer], conditioner_transform: ConditionerTransform, **kwargs): super().__init__(event_shape=event_shape) @@ -89,9 +89,11 @@ def inverse(self, z: torch.Tensor, context: torch.Tensor = None) -> Tuple[torch. 
class ForwardMaskedAutoregressiveBijection(AutoregressiveBijection): - def __init__(self, conditioner: Conditioner, transformer: ScalarTransformer, + def __init__(self, + conditioner: Conditioner, + transformer: ScalarTransformer, conditioner_transform: ConditionerTransform): - super().__init__(conditioner, transformer, conditioner_transform) + super().__init__(transformer.event_shape, conditioner, transformer, conditioner_transform) def inverse(self, z: torch.Tensor, context: torch.Tensor = None) -> Tuple[torch.Tensor, torch.Tensor]: batch_shape = get_batch_shape(z, self.event_shape) @@ -112,13 +114,15 @@ def inverse(self, z: torch.Tensor, context: torch.Tensor = None) -> Tuple[torch. class InverseMaskedAutoregressiveBijection(AutoregressiveBijection): - def __init__(self, conditioner: Conditioner, transformer: ScalarTransformer, + def __init__(self, + conditioner: Conditioner, + transformer: ScalarTransformer, conditioner_transform: ConditionerTransform): - super().__init__(conditioner, transformer, conditioner_transform) + super().__init__(transformer.event_shape, conditioner, transformer, conditioner_transform) self.forward_layer = ForwardMaskedAutoregressiveBijection( - conditioner, - transformer, - conditioner_transform + self.conditioner, + self.transformer, + self.conditioner_transform ) def forward(self, x: torch.Tensor, context: torch.Tensor = None) -> Tuple[torch.Tensor, torch.Tensor]: From 55ab1fa866c35e441080468503d3a08471669816 Mon Sep 17 00:00:00 2001 From: David Nabergoj Date: Sat, 18 Nov 2023 21:08:57 -0800 Subject: [PATCH 26/40] Towards fixing masked autoregressive bijections --- .../finite/autoregressive/conditioner_transforms.py | 9 +++++---- .../bijections/finite/autoregressive/layers.py | 12 ++++++------ 2 files changed, 11 insertions(+), 10 deletions(-) diff --git a/normalizing_flows/bijections/finite/autoregressive/conditioner_transforms.py b/normalizing_flows/bijections/finite/autoregressive/conditioner_transforms.py index 54ebb6c..5d98cb5 100644 --- a/normalizing_flows/bijections/finite/autoregressive/conditioner_transforms.py +++ b/normalizing_flows/bijections/finite/autoregressive/conditioner_transforms.py @@ -114,7 +114,7 @@ def forward(self, x): def __init__(self, input_event_shape: torch.Size, output_event_shape: torch.Size, - parameter_shape: torch.Size, + parameter_shape_per_element: torch.Size, context_shape: torch.Size = None, n_hidden: int = None, n_layers: int = 2, @@ -122,9 +122,10 @@ def __init__(self, super().__init__( input_event_shape=input_event_shape, context_shape=context_shape, - parameter_shape=parameter_shape, + parameter_shape=(*output_event_shape, *parameter_shape_per_element), **kwargs ) + n_predicted_parameters_per_element = int(torch.prod(torch.as_tensor(parameter_shape_per_element))) n_output_event_dims = int(torch.prod(torch.as_tensor(output_event_shape))) if n_hidden is None: @@ -149,8 +150,8 @@ def __init__(self, layers.extend([ self.MaskedLinear( masks[-1].shape[1], - masks[-1].shape[0] * self.n_predicted_parameters, - torch.repeat_interleave(masks[-1], self.n_predicted_parameters, dim=0) + masks[-1].shape[0] * n_predicted_parameters_per_element, + torch.repeat_interleave(masks[-1], n_predicted_parameters_per_element, dim=0) ) ]) self.sequential = nn.Sequential(*layers) diff --git a/normalizing_flows/bijections/finite/autoregressive/layers.py b/normalizing_flows/bijections/finite/autoregressive/layers.py index 1c7e96e..5e1e479 100644 --- a/normalizing_flows/bijections/finite/autoregressive/layers.py +++ 
b/normalizing_flows/bijections/finite/autoregressive/layers.py @@ -180,7 +180,7 @@ def __init__(self, conditioner_transform = MADE( input_event_shape=event_shape, output_event_shape=event_shape, - parameter_shape=transformer.parameter_shape_per_element, + parameter_shape_per_element=transformer.parameter_shape_per_element, context_shape=context_shape, **kwargs ) @@ -202,7 +202,7 @@ def __init__(self, conditioner_transform = MADE( input_event_shape=event_shape, output_event_shape=event_shape, - parameter_shape=transformer.parameter_shape_per_element, + parameter_shape_per_element=transformer.parameter_shape_per_element, context_shape=context_shape, **kwargs ) @@ -224,7 +224,7 @@ def __init__(self, conditioner_transform = MADE( input_event_shape=event_shape, output_event_shape=event_shape, - parameter_shape=transformer.parameter_shape_per_element, + parameter_shape_per_element=transformer.parameter_shape_per_element, context_shape=context_shape, **kwargs ) @@ -245,7 +245,7 @@ def __init__(self, conditioner_transform = MADE( input_event_shape=event_shape, output_event_shape=event_shape, - parameter_shape=transformer.parameter_shape_per_element, + parameter_shape_per_element=transformer.parameter_shape_per_element, context_shape=context_shape, **kwargs ) @@ -268,7 +268,7 @@ def __init__(self, conditioner_transform = MADE( input_event_shape=event_shape, output_event_shape=event_shape, - parameter_shape=transformer.parameter_shape_per_element, + parameter_shape_per_element=transformer.parameter_shape_per_element, context_shape=context_shape, **kwargs ) @@ -295,7 +295,7 @@ def __init__(self, conditioner_transform = MADE( input_event_shape=event_shape, output_event_shape=event_shape, - parameter_shape=transformer.parameter_shape_per_element, + parameter_shape_per_element=transformer.parameter_shape_per_element, context_shape=context_shape, **kwargs ) From 2fc6934679124828139177258c79c5395ac90b1c Mon Sep 17 00:00:00 2001 From: David Nabergoj Date: Sat, 18 Nov 2023 21:35:32 -0800 Subject: [PATCH 27/40] Simplify masked autoregressive bijections --- .../autoregressive/conditioner_transforms.py | 12 ++- .../autoregressive/conditioners/masked.py | 11 --- .../finite/autoregressive/layers.py | 95 +++---------------- .../finite/autoregressive/layers_base.py | 60 ++++++++---- 4 files changed, 61 insertions(+), 117 deletions(-) delete mode 100644 normalizing_flows/bijections/finite/autoregressive/conditioners/masked.py diff --git a/normalizing_flows/bijections/finite/autoregressive/conditioner_transforms.py b/normalizing_flows/bijections/finite/autoregressive/conditioner_transforms.py index 5d98cb5..fe6fda1 100644 --- a/normalizing_flows/bijections/finite/autoregressive/conditioner_transforms.py +++ b/normalizing_flows/bijections/finite/autoregressive/conditioner_transforms.py @@ -103,6 +103,13 @@ def __init__(self, event_shape, parameter_shape, fill_value: float = None): class MADE(ConditionerTransform): + """ + Masked autoencoder for distribution estimation (MADE). + + MADE is a conditioner transform that receives as input a tensor x. It predicts parameters for the + transformer such that each dimension only depends on the previous ones. 
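A minimal sketch of that contract (shapes follow the conditioner-transform tests added later in this series; the import path is assumed):

    import torch
    from normalizing_flows.bijections.finite.autoregressive.conditioner_transforms import MADE

    made = MADE(
        input_event_shape=(5,),
        output_event_shape=(5,),
        parameter_shape_per_element=(2,),  # e.g. two transformer parameters per element
    )
    x = torch.randn(16, 5)
    h = made(x)  # h.shape == (16, 5, 2); parameters for element i depend only on earlier elements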
+ """ + class MaskedLinear(nn.Linear): def __init__(self, in_features: int, out_features: int, mask: torch.Tensor): super().__init__(in_features=in_features, out_features=out_features) @@ -169,9 +176,12 @@ def create_masks(n_layers, ms): masks.append(torch.as_tensor(xx >= yy, dtype=torch.float)) return masks - def predict_theta_flat(self, x: torch.Tensor, context: torch.Tensor = None): + def predict_theta(self, x: torch.Tensor, context: torch.Tensor = None): return self.sequential(self.context_combiner(x, context)) + def predict_theta_flat(self, x: torch.Tensor, context: torch.Tensor = None): + raise NotImplementedError + class LinearMADE(MADE): def __init__(self, diff --git a/normalizing_flows/bijections/finite/autoregressive/conditioners/masked.py b/normalizing_flows/bijections/finite/autoregressive/conditioners/masked.py deleted file mode 100644 index ee6a74f..0000000 --- a/normalizing_flows/bijections/finite/autoregressive/conditioners/masked.py +++ /dev/null @@ -1,11 +0,0 @@ -import torch - -from normalizing_flows.bijections.finite.autoregressive.conditioners.base import Conditioner - - -class MaskedAutoregressive(Conditioner): - def __init__(self): - super().__init__() - - def forward(self, x: torch.Tensor, transform, context: torch.Tensor = None) -> torch.Tensor: - return transform(x, context=context) diff --git a/normalizing_flows/bijections/finite/autoregressive/layers.py b/normalizing_flows/bijections/finite/autoregressive/layers.py index 5e1e479..957a0d6 100644 --- a/normalizing_flows/bijections/finite/autoregressive/layers.py +++ b/normalizing_flows/bijections/finite/autoregressive/layers.py @@ -2,8 +2,7 @@ from normalizing_flows.bijections.finite.autoregressive.conditioner_transforms import MADE, FeedForward from normalizing_flows.bijections.finite.autoregressive.conditioners.coupling_masks import HalfSplit -from normalizing_flows.bijections.finite.autoregressive.conditioners.masked import MaskedAutoregressive -from normalizing_flows.bijections.finite.autoregressive.layers_base import ForwardMaskedAutoregressiveBijection, \ +from normalizing_flows.bijections.finite.autoregressive.layers_base import MaskedAutoregressiveBijection, \ InverseMaskedAutoregressiveBijection, ElementwiseBijection, CouplingBijection from normalizing_flows.bijections.finite.autoregressive.transformers.affine import Scale, Affine, Shift from normalizing_flows.bijections.finite.autoregressive.transformers.base import ScalarTransformer @@ -171,69 +170,33 @@ def __init__(self, event_shape: torch.Size, **kwargs): super().__init__(event_shape, **kwargs, n_layers=1) -class AffineForwardMaskedAutoregressive(ForwardMaskedAutoregressiveBijection): +class AffineForwardMaskedAutoregressive(MaskedAutoregressiveBijection): def __init__(self, event_shape: torch.Size, context_shape: torch.Size = None, **kwargs): transformer: ScalarTransformer = Affine(event_shape=event_shape) - conditioner_transform = MADE( - input_event_shape=event_shape, - output_event_shape=event_shape, - parameter_shape_per_element=transformer.parameter_shape_per_element, - context_shape=context_shape, - **kwargs - ) - conditioner = MaskedAutoregressive() - super().__init__( - conditioner=conditioner, - transformer=transformer, - conditioner_transform=conditioner_transform - ) + super().__init__(event_shape, context_shape, transformer=transformer, **kwargs) -class RQSForwardMaskedAutoregressive(ForwardMaskedAutoregressiveBijection): +class RQSForwardMaskedAutoregressive(MaskedAutoregressiveBijection): def __init__(self, event_shape: torch.Size, 
context_shape: torch.Size = None, n_bins: int = 8, **kwargs): transformer: ScalarTransformer = RationalQuadratic(event_shape=event_shape, n_bins=n_bins) - conditioner_transform = MADE( - input_event_shape=event_shape, - output_event_shape=event_shape, - parameter_shape_per_element=transformer.parameter_shape_per_element, - context_shape=context_shape, - **kwargs - ) - conditioner = MaskedAutoregressive() - super().__init__( - conditioner=conditioner, - transformer=transformer, - conditioner_transform=conditioner_transform - ) + super().__init__(event_shape, context_shape, transformer=transformer, **kwargs) -class LRSForwardMaskedAutoregressive(ForwardMaskedAutoregressiveBijection): +class LRSForwardMaskedAutoregressive(MaskedAutoregressiveBijection): def __init__(self, event_shape: torch.Size, context_shape: torch.Size = None, n_bins: int = 8, **kwargs): transformer: ScalarTransformer = LinearRational(event_shape=event_shape, n_bins=n_bins) - conditioner_transform = MADE( - input_event_shape=event_shape, - output_event_shape=event_shape, - parameter_shape_per_element=transformer.parameter_shape_per_element, - context_shape=context_shape, - **kwargs - ) - conditioner = MaskedAutoregressive() - super().__init__( - conditioner=conditioner, - transformer=transformer, - conditioner_transform=conditioner_transform - ) + super().__init__(event_shape, context_shape, transformer=transformer, **kwargs) class AffineInverseMaskedAutoregressive(InverseMaskedAutoregressiveBijection): @@ -242,19 +205,7 @@ def __init__(self, context_shape: torch.Size = None, **kwargs): transformer: ScalarTransformer = invert(Affine(event_shape=event_shape)) - conditioner_transform = MADE( - input_event_shape=event_shape, - output_event_shape=event_shape, - parameter_shape_per_element=transformer.parameter_shape_per_element, - context_shape=context_shape, - **kwargs - ) - conditioner = MaskedAutoregressive() - super().__init__( - conditioner=conditioner, - transformer=transformer, - conditioner_transform=conditioner_transform - ) + super().__init__(event_shape, context_shape, transformer=transformer, **kwargs) class RQSInverseMaskedAutoregressive(InverseMaskedAutoregressiveBijection): @@ -265,22 +216,10 @@ def __init__(self, **kwargs): assert n_bins >= 1 transformer: ScalarTransformer = RationalQuadratic(event_shape=event_shape, n_bins=n_bins) - conditioner_transform = MADE( - input_event_shape=event_shape, - output_event_shape=event_shape, - parameter_shape_per_element=transformer.parameter_shape_per_element, - context_shape=context_shape, - **kwargs - ) - conditioner = MaskedAutoregressive() - super().__init__( - conditioner=conditioner, - transformer=transformer, - conditioner_transform=conditioner_transform - ) + super().__init__(event_shape, context_shape, transformer=transformer, **kwargs) -class UMNNMaskedAutoregressive(ForwardMaskedAutoregressiveBijection): +class UMNNMaskedAutoregressive(MaskedAutoregressiveBijection): def __init__(self, event_shape: torch.Size, context_shape: torch.Size = None, @@ -292,16 +231,4 @@ def __init__(self, n_hidden_layers=n_hidden_layers, hidden_dim=hidden_dim ) - conditioner_transform = MADE( - input_event_shape=event_shape, - output_event_shape=event_shape, - parameter_shape_per_element=transformer.parameter_shape_per_element, - context_shape=context_shape, - **kwargs - ) - conditioner = MaskedAutoregressive() - super().__init__( - conditioner=conditioner, - transformer=transformer, - conditioner_transform=conditioner_transform - ) + super().__init__(event_shape, context_shape, 
transformer=transformer, **kwargs) diff --git a/normalizing_flows/bijections/finite/autoregressive/layers_base.py b/normalizing_flows/bijections/finite/autoregressive/layers_base.py index 6010c75..aa937e1 100644 --- a/normalizing_flows/bijections/finite/autoregressive/layers_base.py +++ b/normalizing_flows/bijections/finite/autoregressive/layers_base.py @@ -3,7 +3,8 @@ import torch from normalizing_flows.bijections.finite.autoregressive.conditioners.base import Conditioner, NullConditioner -from normalizing_flows.bijections.finite.autoregressive.conditioner_transforms import ConditionerTransform, Constant +from normalizing_flows.bijections.finite.autoregressive.conditioner_transforms import ConditionerTransform, Constant, \ + MADE from normalizing_flows.bijections.finite.autoregressive.conditioners.coupling_masks import CouplingMask from normalizing_flows.bijections.finite.autoregressive.transformers.base import TensorTransformer, ScalarTransformer from normalizing_flows.bijections.base import Bijection @@ -88,12 +89,41 @@ def inverse(self, z: torch.Tensor, context: torch.Tensor = None) -> Tuple[torch. return x, log_det -class ForwardMaskedAutoregressiveBijection(AutoregressiveBijection): +class MaskedAutoregressiveBijection(AutoregressiveBijection): + """ + Masked autoregressive bijection class. + + This bijection is specified with a scalar transformer. + Its conditioner is always MADE, which receives as input a tensor x with x.shape = (*batch_shape, *event_shape). + MADE outputs parameters h for the scalar transformer with + h.shape = (*batch_shape, *event_shape, *parameter_shape_per_element). + The transformer then applies the bijection elementwise. + """ + def __init__(self, - conditioner: Conditioner, + event_shape, + context_shape, transformer: ScalarTransformer, - conditioner_transform: ConditionerTransform): - super().__init__(transformer.event_shape, conditioner, transformer, conditioner_transform) + **kwargs): + conditioner_transform = MADE( + input_event_shape=event_shape, + output_event_shape=event_shape, + parameter_shape_per_element=transformer.parameter_shape_per_element, + context_shape=context_shape, + **kwargs + ) + super().__init__(transformer.event_shape, None, transformer, conditioner_transform) + + def apply_conditioner_transformer(self, inputs, context, forward: bool = True): + h = self.conditioner_transform(inputs, context) + if forward: + outputs, log_det = self.transformer.forward(inputs, h) + else: + outputs, log_det = self.transformer.inverse(inputs, h) + return outputs, log_det + + def forward(self, x: torch.Tensor, context: torch.Tensor = None) -> Tuple[torch.Tensor, torch.Tensor]: + return self.apply_conditioner_transformer(x, context, True) def inverse(self, z: torch.Tensor, context: torch.Tensor = None) -> Tuple[torch.Tensor, torch.Tensor]: batch_shape = get_batch_shape(z, self.event_shape) @@ -102,28 +132,16 @@ def inverse(self, z: torch.Tensor, context: torch.Tensor = None) -> Tuple[torch. 
x_flat = flatten_event(torch.clone(z), self.event_shape) for i in torch.arange(n_event_dims): x_clone = unflatten_event(torch.clone(x_flat), self.event_shape) - h = self.conditioner( - x_clone, - transform=self.conditioner_transform, - context=context - ) - tmp, log_det = self.transformer.inverse(x_clone, h) + tmp, log_det = self.apply_conditioner_transformer(x_clone, context, False) x_flat[..., i] = flatten_event(tmp, self.event_shape)[..., i] x = unflatten_event(x_flat, self.event_shape) return x, log_det class InverseMaskedAutoregressiveBijection(AutoregressiveBijection): - def __init__(self, - conditioner: Conditioner, - transformer: ScalarTransformer, - conditioner_transform: ConditionerTransform): - super().__init__(transformer.event_shape, conditioner, transformer, conditioner_transform) - self.forward_layer = ForwardMaskedAutoregressiveBijection( - self.conditioner, - self.transformer, - self.conditioner_transform - ) + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.forward_layer = MaskedAutoregressiveBijection(self.event_shape, self.context_shape, self.transformer) def forward(self, x: torch.Tensor, context: torch.Tensor = None) -> Tuple[torch.Tensor, torch.Tensor]: return self.forward_layer.inverse(x, context=context) From 044d17f5ab76c919de8c89317f4437979ceed16d Mon Sep 17 00:00:00 2001 From: David Nabergoj Date: Sat, 18 Nov 2023 21:47:35 -0800 Subject: [PATCH 28/40] Fix flat theta prediction in MADE, simplify inverse masked autoregressive bijection --- .../autoregressive/conditioner_transforms.py | 17 +++++++++++++---- .../finite/autoregressive/layers_base.py | 7 +++---- 2 files changed, 16 insertions(+), 8 deletions(-) diff --git a/normalizing_flows/bijections/finite/autoregressive/conditioner_transforms.py b/normalizing_flows/bijections/finite/autoregressive/conditioner_transforms.py index fe6fda1..bb36dfa 100644 --- a/normalizing_flows/bijections/finite/autoregressive/conditioner_transforms.py +++ b/normalizing_flows/bijections/finite/autoregressive/conditioner_transforms.py @@ -39,6 +39,12 @@ def __init__(self, global parameters are independently drawn from the standard normal distribution. 
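For context, a small sketch of the shape contract that the check added just below enforces (names follow the surrounding diff; the boolean convention is inferred from how the mask is indexed in predict_theta_flat):

    import torch

    parameter_shape = (5, 2)  # (*output_event_shape, *parameter_shape_per_element)
    global_parameter_mask = torch.zeros(parameter_shape, dtype=torch.bool)
    global_parameter_mask[:, 1] = True  # True entries are learned globally rather than predicted
    assert global_parameter_mask.shape == parameter_shape  # otherwise the new ValueError is raised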
""" super().__init__() + if global_parameter_mask is not None and global_parameter_mask.shape != parameter_shape: + raise ValueError( + f"Global parameter mask must have shape equal to the output parameter shape {parameter_shape}, " + f"but found {global_parameter_mask.shape}" + ) + if context_shape is None: context_combiner = Bypass(input_event_shape) elif context_shape is not None and context_combiner is None: @@ -176,11 +182,14 @@ def create_masks(n_layers, ms): masks.append(torch.as_tensor(xx >= yy, dtype=torch.float)) return masks - def predict_theta(self, x: torch.Tensor, context: torch.Tensor = None): - return self.sequential(self.context_combiner(x, context)) - def predict_theta_flat(self, x: torch.Tensor, context: torch.Tensor = None): - raise NotImplementedError + theta = self.sequential(self.context_combiner(x, context)) + # (*b, *e, *pe) + + if self.global_parameter_mask is None: + return torch.flatten(theta, start_dim=-len(self.parameter_shape)) + else: + return theta[..., ~self.global_parameter_mask] class LinearMADE(MADE): diff --git a/normalizing_flows/bijections/finite/autoregressive/layers_base.py b/normalizing_flows/bijections/finite/autoregressive/layers_base.py index aa937e1..83fa848 100644 --- a/normalizing_flows/bijections/finite/autoregressive/layers_base.py +++ b/normalizing_flows/bijections/finite/autoregressive/layers_base.py @@ -138,16 +138,15 @@ def inverse(self, z: torch.Tensor, context: torch.Tensor = None) -> Tuple[torch. return x, log_det -class InverseMaskedAutoregressiveBijection(AutoregressiveBijection): +class InverseMaskedAutoregressiveBijection(MaskedAutoregressiveBijection): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - self.forward_layer = MaskedAutoregressiveBijection(self.event_shape, self.context_shape, self.transformer) def forward(self, x: torch.Tensor, context: torch.Tensor = None) -> Tuple[torch.Tensor, torch.Tensor]: - return self.forward_layer.inverse(x, context=context) + return super().inverse(x, context=context) def inverse(self, z: torch.Tensor, context: torch.Tensor = None) -> Tuple[torch.Tensor, torch.Tensor]: - return self.forward_layer.forward(z, context=context) + return super().forward(z, context=context) class ElementwiseBijection(AutoregressiveBijection): From cc141851d2e6bcaed7b7cb0281d8b423e2e0b679 Mon Sep 17 00:00:00 2001 From: David Nabergoj Date: Sat, 18 Nov 2023 21:59:12 -0800 Subject: [PATCH 29/40] Fix theta shape --- .../finite/autoregressive/conditioner_transforms.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/normalizing_flows/bijections/finite/autoregressive/conditioner_transforms.py b/normalizing_flows/bijections/finite/autoregressive/conditioner_transforms.py index bb36dfa..7444b33 100644 --- a/normalizing_flows/bijections/finite/autoregressive/conditioner_transforms.py +++ b/normalizing_flows/bijections/finite/autoregressive/conditioner_transforms.py @@ -76,11 +76,11 @@ def forward(self, x: torch.Tensor, context: torch.Tensor = None): # x.shape = (*batch_shape, *self.input_event_shape) # context.shape = (*batch_shape, *self.context_shape) # output.shape = (*batch_shape, *self.parameter_shape) + batch_shape = get_batch_shape(x, self.input_event_shape) if self.n_global_parameters == 0: # All parameters are predicted - return self.predict_theta_flat(x, context) + return self.predict_theta_flat(x, context).view(*batch_shape, *self.parameter_shape) else: - batch_shape = get_batch_shape(x, self.input_event_shape) if self.n_global_parameters == 
self.n_transformer_parameters: # All transformer parameters are learned globally output = torch.zeros(*batch_shape, *self.parameter_shape) @@ -187,7 +187,7 @@ def predict_theta_flat(self, x: torch.Tensor, context: torch.Tensor = None): # (*b, *e, *pe) if self.global_parameter_mask is None: - return torch.flatten(theta, start_dim=-len(self.parameter_shape)) + return torch.flatten(theta, start_dim=len(theta.shape) - len(self.input_event_shape)) else: return theta[..., ~self.global_parameter_mask] From 93aa132dacd0cd402088f536ce913e6d43d829d8 Mon Sep 17 00:00:00 2001 From: David Nabergoj Date: Sat, 18 Nov 2023 23:05:20 -0800 Subject: [PATCH 30/40] Fix tests for conditioner transforms, fix LinearMade constructor --- .../autoregressive/conditioner_transforms.py | 84 +++++++++++-------- test/constants.py | 4 +- test/test_conditioner_transforms.py | 59 ++++++++++--- 3 files changed, 101 insertions(+), 46 deletions(-) diff --git a/normalizing_flows/bijections/finite/autoregressive/conditioner_transforms.py b/normalizing_flows/bijections/finite/autoregressive/conditioner_transforms.py index 7444b33..608bcc5 100644 --- a/normalizing_flows/bijections/finite/autoregressive/conditioner_transforms.py +++ b/normalizing_flows/bijections/finite/autoregressive/conditioner_transforms.py @@ -1,5 +1,5 @@ import math -from typing import Tuple, Union, Type +from typing import Tuple, Union, Type, List import torch import torch.nn as nn @@ -125,10 +125,10 @@ def forward(self, x): return nn.functional.linear(x, self.weight * self.mask, self.bias) def __init__(self, - input_event_shape: torch.Size, - output_event_shape: torch.Size, - parameter_shape_per_element: torch.Size, - context_shape: torch.Size = None, + input_event_shape: Union[torch.Size, Tuple[int, ...]], + output_event_shape: Union[torch.Size, Tuple[int, ...]], + parameter_shape_per_element: Union[torch.Size, Tuple[int, ...]], + context_shape: Union[torch.Size, Tuple[int, ...]] = None, n_hidden: int = None, n_layers: int = 2, **kwargs): @@ -193,18 +193,8 @@ def predict_theta_flat(self, x: torch.Tensor, context: torch.Tensor = None): class LinearMADE(MADE): - def __init__(self, - input_event_shape: torch.Size, - output_event_shape: torch.Size, - parameter_shape: torch.Size, - **kwargs): - super().__init__( - input_event_shape, - output_event_shape, - parameter_shape, - n_layers=1, - **kwargs - ) + def __init__(self, *args, **kwargs): + super().__init__(*args, n_layers=1, **kwargs) class FeedForward(ConditionerTransform): @@ -214,7 +204,6 @@ def __init__(self, context_shape: torch.Size = None, n_hidden: int = None, n_layers: int = 2, - hidden_linear_module: Type[nn.Module] = nn.Linear, **kwargs): super().__init__( input_event_shape=input_event_shape, @@ -226,23 +215,16 @@ def __init__(self, if n_hidden is None: n_hidden = max(int(3 * math.log10(self.n_input_event_dims)), 4) - # If context given, concatenate it to transform input - if context_shape is not None: - self.n_input_event_dims += self.n_context_dims - layers = [] - - # Check the one layer special case if n_layers == 1: layers.append(nn.Linear(self.n_input_event_dims, self.n_predicted_parameters)) elif n_layers > 1: - layers.extend([hidden_linear_module(self.n_input_event_dims, n_hidden), nn.Tanh()]) + layers.extend([nn.Linear(self.n_input_event_dims, n_hidden), nn.Tanh()]) for _ in range(n_layers - 2): - layers.extend([hidden_linear_module(n_hidden, n_hidden), nn.Tanh()]) + layers.extend([nn.Linear(n_hidden, n_hidden), nn.Tanh()]) layers.append(nn.Linear(n_hidden, self.n_predicted_parameters)) 
else: raise ValueError - layers.append(nn.Unflatten(dim=-1, unflattened_size=self.parameter_shape)) self.sequential = nn.Sequential(*layers) @@ -255,18 +237,50 @@ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs, n_layers=1) -class ResidualFeedForward(FeedForward): - class _ResidualLinearModule(nn.Module): - def __init__(self, n_in, n_out): +class ResidualFeedForward(ConditionerTransform): + class ResidualBlock(nn.Module): + def __init__(self, event_size: int, hidden_size: int, block_size: int): super().__init__() - self.linear = nn.Linear(n_in, n_out) + if block_size < 2: + raise ValueError(f"block_size must be at least 2 but found {block_size}. " + f"For block_size = 1, use the FeedForward class instead.") + layers = [] + layers.extend([nn.Linear(event_size, hidden_size), nn.ReLU()]) + for _ in range(block_size - 2): + layers.extend([nn.Linear(hidden_size, hidden_size), nn.ReLU()]) + layers.extend([nn.Linear(hidden_size, event_size)]) + self.sequential = nn.Sequential(*layers) def forward(self, x): - return x + self.linear(x) + return x + self.sequential(x) - def __init__(self, *args, **kwargs): + def __init__(self, + input_event_shape: torch.Size, + parameter_shape: torch.Size, + context_shape: torch.Size = None, + n_hidden: int = None, + n_layers: int = 3, + block_size: int = 2, + **kwargs): super().__init__( - *args, - hidden_linear_module=ResidualFeedForward._ResidualLinearModule, + input_event_shape=input_event_shape, + context_shape=context_shape, + parameter_shape=parameter_shape, **kwargs ) + + if n_hidden is None: + n_hidden = max(int(3 * math.log10(self.n_input_event_dims)), 4) + + if n_layers <= 2: + raise ValueError(f"Number of layers in ResidualFeedForward must be at least 3, but found {n_layers}") + + layers = [nn.Linear(self.n_input_event_dims, n_hidden), nn.ReLU()] + for _ in range(n_layers - 2): + layers.append(self.ResidualBlock(n_hidden, n_hidden, block_size)) + layers.append(nn.Linear(n_hidden, self.n_predicted_parameters)) + layers.append(nn.Unflatten(dim=-1, unflattened_size=self.parameter_shape)) + self.sequential = nn.Sequential(*layers) + + def predict_theta_flat(self, x: torch.Tensor, context: torch.Tensor = None): + return self.sequential(self.context_combiner(x, context)) diff --git a/test/constants.py b/test/constants.py index 181886f..dc4c158 100644 --- a/test/constants.py +++ b/test/constants.py @@ -5,5 +5,7 @@ 'context_shape': [None, (2,), (3,), (2, 4), (5,)], 'input_event_shape': [(2,), (3,), (2, 4), (40,), (3, 5, 2)], 'output_event_shape': [(2,), (3,), (2, 4), (40,), (3, 5, 2)], - 'n_predicted_parameters': [1, 2, 10, 50, 100] + 'n_predicted_parameters': [1, 2, 10, 50, 100], + 'predicted_parameter_shape': [(1,), (2,), (5,), (2, 4), (5, 2, 3)], + 'parameter_shape_per_element': [(1,), (2,), (5,), (2, 4), (5, 2, 3)], } diff --git a/test/test_conditioner_transforms.py b/test/test_conditioner_transforms.py index b7d6edd..79ac5fc 100644 --- a/test/test_conditioner_transforms.py +++ b/test/test_conditioner_transforms.py @@ -2,29 +2,68 @@ import torch from normalizing_flows.bijections.finite.autoregressive.conditioner_transforms import ( - MADE, FeedForward, LinearMADE, ResidualFeedForward, Constant, Linear + MADE, FeedForward, LinearMADE, ResidualFeedForward, Constant, Linear, ConditionerTransform ) from test.constants import __test_constants @pytest.mark.parametrize('transform_class', [ MADE, - FeedForward, LinearMADE, - ResidualFeedForward, - Linear ]) @pytest.mark.parametrize('batch_shape', __test_constants['batch_shape']) 
@pytest.mark.parametrize('input_event_shape', __test_constants['input_event_shape']) @pytest.mark.parametrize('output_event_shape', __test_constants['output_event_shape']) -@pytest.mark.parametrize('n_predicted_parameters', __test_constants['n_predicted_parameters']) -def test_shape(transform_class, batch_shape, input_event_shape, output_event_shape, n_predicted_parameters): +@pytest.mark.parametrize('parameter_shape_per_element', __test_constants['parameter_shape_per_element']) +@pytest.mark.parametrize('context_shape', __test_constants['context_shape']) +def test_autoregressive(transform_class, + batch_shape, + input_event_shape, + output_event_shape, + parameter_shape_per_element, + context_shape): torch.manual_seed(0) x = torch.randn(size=(*batch_shape, *input_event_shape)) - transform = transform_class( + transform: ConditionerTransform = transform_class( input_event_shape=input_event_shape, output_event_shape=output_event_shape, - parameter_shape=n_predicted_parameters + parameter_shape_per_element=parameter_shape_per_element, + context_shape=context_shape, + ) + + if context_shape is not None: + c = torch.randn(size=(*batch_shape, *context_shape)) + out = transform(x, c) + else: + out = transform(x) + assert out.shape == (*batch_shape, *output_event_shape, *parameter_shape_per_element) + + +@pytest.mark.parametrize('transform_class', [ + FeedForward, + ResidualFeedForward, + Linear +]) +@pytest.mark.parametrize('batch_shape', __test_constants['batch_shape']) +@pytest.mark.parametrize('input_event_shape', __test_constants['input_event_shape']) +@pytest.mark.parametrize('context_shape', __test_constants['context_shape']) +@pytest.mark.parametrize('predicted_parameter_shape', __test_constants['predicted_parameter_shape']) +def test_neural_network(transform_class, + batch_shape, + input_event_shape, + context_shape, + predicted_parameter_shape): + torch.manual_seed(0) + x = torch.randn(size=(*batch_shape, *input_event_shape)) + transform: ConditionerTransform = transform_class( + input_event_shape=input_event_shape, + context_shape=context_shape, + parameter_shape=predicted_parameter_shape ) - out = transform(x) - assert out.shape == (*batch_shape, *output_event_shape, n_predicted_parameters) + + if context_shape is not None: + c = torch.randn(size=(*batch_shape, *context_shape)) + out = transform(x, c) + else: + out = transform(x) + assert out.shape == (*batch_shape, *predicted_parameter_shape) From c8f6bb35ba59756780958c14e0a974296061c1f9 Mon Sep 17 00:00:00 2001 From: David Nabergoj Date: Sat, 18 Nov 2023 23:51:12 -0800 Subject: [PATCH 31/40] Implement parameter shape method --- .../transformers/combination/base.py | 8 ++++---- .../transformers/combination/sigmoid.py | 14 +++++--------- .../unconstrained_monotonic_neural_network.py | 4 ++-- test/test_reconstruction_transformers.py | 2 +- 4 files changed, 12 insertions(+), 16 deletions(-) diff --git a/normalizing_flows/bijections/finite/autoregressive/transformers/combination/base.py b/normalizing_flows/bijections/finite/autoregressive/transformers/combination/base.py index 871de2b..10b78ea 100644 --- a/normalizing_flows/bijections/finite/autoregressive/transformers/combination/base.py +++ b/normalizing_flows/bijections/finite/autoregressive/transformers/combination/base.py @@ -1,5 +1,5 @@ import torch -from typing import Tuple, List +from typing import Tuple, List, Union from normalizing_flows.bijections.finite.autoregressive.transformers.base import ScalarTransformer from normalizing_flows.utils import get_batch_shape @@ -12,12 
+12,12 @@ def __init__(self, event_shape: torch.Size, components: List[ScalarTransformer]) self.n_components = len(self.components) @property - def n_parameters(self) -> int: - return sum([c.n_parameters for c in self.components]) + def parameter_shape_per_element(self) -> Union[torch.Size, Tuple[int, ...]]: + return (sum([c.n_parameters for c in self.components]),) @property def default_parameters(self) -> torch.Tensor: - return torch.cat([c.default_parameters for c in self.components], dim=0) + return torch.cat([c.default_parameters.ravel() for c in self.components], dim=0) def forward(self, x: torch.Tensor, h: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: # h.shape = (*batch_size, *event_shape, n_components * n_output_parameters) diff --git a/normalizing_flows/bijections/finite/autoregressive/transformers/combination/sigmoid.py b/normalizing_flows/bijections/finite/autoregressive/transformers/combination/sigmoid.py index 38378f0..e470def 100644 --- a/normalizing_flows/bijections/finite/autoregressive/transformers/combination/sigmoid.py +++ b/normalizing_flows/bijections/finite/autoregressive/transformers/combination/sigmoid.py @@ -40,12 +40,8 @@ def __init__(self, event_shape: Union[torch.Size, Tuple[int, ...]], super().__init__(event_shape) @property - def n_parameters(self) -> int: - return 3 * self.hidden_dim - - @property - def default_parameters(self) -> torch.Tensor: - return torch.zeros(size=(self.n_parameters,)) + def parameter_shape_per_element(self) -> Union[torch.Size, Tuple[int, ...]]: + return (3 * self.hidden_dim,) def extract_parameters(self, h: torch.Tensor): """ @@ -234,12 +230,12 @@ def __init__(self, self.layers = nn.ModuleList(layers) @property - def n_parameters(self) -> int: - return sum([layer.n_parameters for layer in self.layers]) + def parameter_shape_per_element(self) -> Union[torch.Size, Tuple[int, ...]]: + return (sum([layer.n_parameters for layer in self.layers]),) @property def default_parameters(self) -> torch.Tensor: - return torch.zeros(size=(self.n_parameters,)) # TODO set up parametrization with deltas so this holds + return torch.zeros(size=self.parameter_shape) # TODO set up parametrization with deltas so this holds def split_parameters(self, h): # split parameters h into parameters for several layers diff --git a/normalizing_flows/bijections/finite/autoregressive/transformers/integration/unconstrained_monotonic_neural_network.py b/normalizing_flows/bijections/finite/autoregressive/transformers/integration/unconstrained_monotonic_neural_network.py index c0239b8..c5de755 100644 --- a/normalizing_flows/bijections/finite/autoregressive/transformers/integration/unconstrained_monotonic_neural_network.py +++ b/normalizing_flows/bijections/finite/autoregressive/transformers/integration/unconstrained_monotonic_neural_network.py @@ -47,8 +47,8 @@ def __init__(self, self._sampled_default_params = torch.randn(size=(self.n_parameters,)) / 1000 @property - def n_parameters(self) -> int: - return self.n_input_params + self.n_output_params + self.n_hidden_params + def parameter_shape_per_element(self) -> Union[torch.Size, Tuple]: + return (self.n_input_params + self.n_output_params + self.n_hidden_params,) @property def default_parameters(self) -> torch.Tensor: diff --git a/test/test_reconstruction_transformers.py b/test/test_reconstruction_transformers.py index 8da838f..65d808f 100644 --- a/test/test_reconstruction_transformers.py +++ b/test/test_reconstruction_transformers.py @@ -26,7 +26,7 @@ def setup_transformer_data(transformer_class: 
ScalarTransformer, batch_shape, ev torch.manual_seed(0) transformer = transformer_class(event_shape) x = torch.randn(*batch_shape, *event_shape) - h = torch.randn(*batch_shape, *event_shape, transformer.n_parameters) + h = torch.randn(*batch_shape, *transformer.parameter_shape) return transformer, x, h From d3d2f0d4fb28eef3f8d2431c86d0788308fac43d Mon Sep 17 00:00:00 2001 From: David Nabergoj Date: Sun, 19 Nov 2023 13:40:24 -0800 Subject: [PATCH 32/40] Fix shapes in unconstrained monotonic neural network transformer --- .../autoregressive/transformers/base.py | 4 ++++ .../transformers/integration/base.py | 4 ++-- .../unconstrained_monotonic_neural_network.py | 22 +++++-------------- 3 files changed, 12 insertions(+), 18 deletions(-) diff --git a/normalizing_flows/bijections/finite/autoregressive/transformers/base.py b/normalizing_flows/bijections/finite/autoregressive/transformers/base.py index 645b82d..72a96ad 100644 --- a/normalizing_flows/bijections/finite/autoregressive/transformers/base.py +++ b/normalizing_flows/bijections/finite/autoregressive/transformers/base.py @@ -57,6 +57,10 @@ def parameter_shape_per_element(self): """ raise NotImplementedError + @property + def n_parameters_per_element(self): + return int(torch.prod(torch.as_tensor(self.parameter_shape_per_element))) + @property def parameter_shape(self) -> Union[torch.Size, Tuple[int, ...]]: # Scalar transformers map each element individually, so the first dimensions are the event shape diff --git a/normalizing_flows/bijections/finite/autoregressive/transformers/integration/base.py b/normalizing_flows/bijections/finite/autoregressive/transformers/integration/base.py index 8f71343..060900b 100644 --- a/normalizing_flows/bijections/finite/autoregressive/transformers/integration/base.py +++ b/normalizing_flows/bijections/finite/autoregressive/transformers/integration/base.py @@ -63,14 +63,14 @@ def forward(self, x: torch.Tensor, h: torch.Tensor) -> Tuple[torch.Tensor, torch x.shape = (*batch_shape, *event_shape) h.shape = (*batch_shape, *event_shape, n_parameters) """ - z_flat, log_det_flat = self.forward_1d(x.view(-1), h.view(-1, self.n_parameters)) + z_flat, log_det_flat = self.forward_1d(x.view(-1), h.view(-1, self.n_parameters_per_element)) z = z_flat.view_as(x) batch_shape = get_batch_shape(x, self.event_shape) log_det = sum_except_batch(log_det_flat.view(*batch_shape, *self.event_shape), self.event_shape) return z, log_det def inverse(self, z: torch.Tensor, h: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: - x_flat, log_det_flat = self.inverse_1d(z.view(-1), h.view(-1, self.n_parameters)) + x_flat, log_det_flat = self.inverse_1d(z.view(-1), h.view(-1, self.n_parameters_per_element)) x = x_flat.view_as(z) batch_shape = get_batch_shape(z, self.event_shape) log_det = sum_except_batch(log_det_flat.view(*batch_shape, *self.event_shape), self.event_shape) diff --git a/normalizing_flows/bijections/finite/autoregressive/transformers/integration/unconstrained_monotonic_neural_network.py b/normalizing_flows/bijections/finite/autoregressive/transformers/integration/unconstrained_monotonic_neural_network.py index c5de755..4aa43d2 100644 --- a/normalizing_flows/bijections/finite/autoregressive/transformers/integration/unconstrained_monotonic_neural_network.py +++ b/normalizing_flows/bijections/finite/autoregressive/transformers/integration/unconstrained_monotonic_neural_network.py @@ -112,28 +112,18 @@ def neural_network_forward(inputs, parameters: List[torch.Tensor]): out = 1 + torch.nn.functional.elu(out) return out - 
@staticmethod - def reshape_tensors(x: torch.Tensor, h: List[torch.Tensor]): - # batch_shape = get_batch_shape(x, self.event_shape) - # batch_dims = int(torch.as_tensor(batch_shape).prod()) - # event_dims = int(torch.as_tensor(self.event_shape).prod()) - flattened_dim = int(torch.as_tensor(x.shape).prod()) - x_r = x.view(flattened_dim, 1, 1) - h_r = [p.view(flattened_dim, *p.shape[-2:]) for p in h] - return x_r, h_r - def base_forward_1d(self, x: torch.Tensor, params: List[torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]: - x_r, p_r = self.reshape_tensors(x, params) - integral_flat = self.integral(x_r, p_r) - log_det_flat = self.g(x_r, p_r).log() # We can apply log since g is always positive + x_r = x.view(-1, 1, 1) + integral_flat = self.integral(x_r, params) + log_det_flat = self.g(x_r, params).log() # We can apply log since g is always positive output = integral_flat.view_as(x) log_det = log_det_flat.view_as(x) return output, log_det def inverse_1d(self, z: torch.Tensor, h: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: params = self.compute_parameters(h) - z_r, p_r = self.reshape_tensors(z, params) - x_flat = self.inverse_1d_without_log_det(z_r, p_r) + z_r = z.view(-1, 1, 1) + x_flat = self.inverse_1d_without_log_det(z_r, params) outputs = x_flat.view_as(z) - log_det = -self.g(x_flat, p_r).log().view_as(z) + log_det = -self.g(x_flat, params).log().view_as(z) return outputs, log_det From 7393e12ab19b0213c8644c98fc155ac74aca873e Mon Sep 17 00:00:00 2001 From: David Nabergoj Date: Sun, 19 Nov 2023 13:45:35 -0800 Subject: [PATCH 33/40] Fix deep sigmoid transformer --- .../autoregressive/transformers/combination/base.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/normalizing_flows/bijections/finite/autoregressive/transformers/combination/base.py b/normalizing_flows/bijections/finite/autoregressive/transformers/combination/base.py index 10b78ea..661fda2 100644 --- a/normalizing_flows/bijections/finite/autoregressive/transformers/combination/base.py +++ b/normalizing_flows/bijections/finite/autoregressive/transformers/combination/base.py @@ -13,7 +13,7 @@ def __init__(self, event_shape: torch.Size, components: List[ScalarTransformer]) @property def parameter_shape_per_element(self) -> Union[torch.Size, Tuple[int, ...]]: - return (sum([c.n_parameters for c in self.components]),) + return (sum([c.n_parameters_per_element for c in self.components]),) @property def default_parameters(self) -> torch.Tensor: @@ -27,9 +27,9 @@ def forward(self, x: torch.Tensor, h: torch.Tensor) -> Tuple[torch.Tensor, torch start_index = 0 for i in range(self.n_components): component = self.components[i] - x, log_det_increment = component.forward(x, h[..., start_index:start_index + component.n_parameters]) + x, log_det_increment = component.forward(x, h[..., start_index:start_index + component.n_parameters_per_element]) log_det += log_det_increment - start_index += component.n_parameters + start_index += component.n_parameters_per_element z = x return z, log_det @@ -38,11 +38,11 @@ def inverse(self, z: torch.Tensor, h: torch.Tensor) -> Tuple[torch.Tensor, torch # We assume last dim is ordered as [c1, c2, ..., ck] i.e. sequence of parameter vectors, one for each component. 
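        # For example, with two components whose parameter_shape_per_element values are (2,) and (3,),
        # the forward pass reads h[..., 0:2] and then h[..., 2:5]; the loop below consumes the same
        # slices in reverse component order, starting from the end of the last dimension.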
batch_shape = get_batch_shape(z, self.event_shape) log_det = torch.zeros(size=batch_shape) - c = self.n_parameters + c = self.n_parameters_per_element for i in range(self.n_components): component = self.components[self.n_components - i - 1] - c -= component.n_parameters - z, log_det_increment = component.inverse(z, h[..., c:c + component.n_parameters]) + c -= component.n_parameters_per_element + z, log_det_increment = component.inverse(z, h[..., c:c + component.n_parameters_per_element]) log_det += log_det_increment x = z return x, log_det From d3bd33348dc0365d78f7c203dd9f7046669f0348 Mon Sep 17 00:00:00 2001 From: David Nabergoj Date: Sun, 19 Nov 2023 13:59:00 -0800 Subject: [PATCH 34/40] Fix spline tests --- test/test_spline.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/test/test_spline.py b/test/test_spline.py index eea1604..55dd02b 100644 --- a/test/test_spline.py +++ b/test/test_spline.py @@ -12,7 +12,7 @@ def test_linear_rational(): torch.manual_seed(0) x = torch.tensor([-3.0, -2.0, -1.0, 0.0, 1.0, 2.0, 3.0]) spline = LinearRational(event_shape=(1,)) - h = torch.randn(size=(len(x), spline.n_parameters)) + h = torch.randn(size=(len(x), *spline.parameter_shape_per_element)) z, log_det_forward = spline.forward(x, h) xr, log_det_inverse = spline.inverse(z, h) assert x.shape == z.shape == xr.shape @@ -36,7 +36,7 @@ def test_1d_spline(spline_class): [4.0], [-3.6] ]) - h = torch.randn(size=(3, 1, spline.n_parameters)) + h = torch.randn(size=(3, 1, *spline.parameter_shape_per_element)) z, log_det = spline(x, h) assert torch.all(~torch.isnan(z)) assert torch.all(~torch.isnan(log_det)) @@ -67,7 +67,7 @@ def test_2d_spline(spline_class): [4.0, 6.0], [-3.6, 0.7] ]) - h = torch.randn(size=(*batch_shape, *event_shape, spline.n_parameters)) + h = torch.randn(size=(*batch_shape, *spline.parameter_shape)) z, log_det = spline(x, h) assert torch.all(~torch.isnan(z)) assert torch.all(~torch.isnan(log_det)) @@ -95,7 +95,7 @@ def test_spline_exhaustive(spline_class, boundary: float, batch_shape, event_sha spline = spline_class(event_shape=event_shape, n_bins=8, boundary=boundary) x = torch.randn(size=(*batch_shape, *event_shape)) - h = torch.randn(size=(*batch_shape, *event_shape, spline.n_parameters)) + h = torch.randn(size=(*batch_shape, *spline.parameter_shape)) z, log_det = spline(x, h) assert torch.all(~torch.isnan(z)) assert torch.all(~torch.isnan(log_det)) @@ -119,7 +119,7 @@ def test_rq_spline(n_data, n_dim, n_bins, scale): spline = RationalQuadratic(event_shape=torch.Size((n_dim,)), n_bins=n_bins) x = torch.randn(n_data, n_dim) * scale - h = torch.randn(n_data, n_dim, spline.n_parameters) + h = torch.randn(n_data, n_dim, *spline.parameter_shape_per_element) z, log_det_forward = spline.forward(x, h) assert z.shape == x.shape From ea7b77b6e5b99f94a8ae31938d14fcba6ccd57d5 Mon Sep 17 00:00:00 2001 From: David Nabergoj Date: Sun, 19 Nov 2023 14:02:41 -0800 Subject: [PATCH 35/40] Fix deep sigmoid flow constructor --- .../bijections/finite/autoregressive/architectures.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/normalizing_flows/bijections/finite/autoregressive/architectures.py b/normalizing_flows/bijections/finite/autoregressive/architectures.py index 35f56a4..cca7e34 100644 --- a/normalizing_flows/bijections/finite/autoregressive/architectures.py +++ b/normalizing_flows/bijections/finite/autoregressive/architectures.py @@ -166,14 +166,14 @@ def __init__(self, event_shape, n_layers: int = 2, **kwargs): class 
CouplingDSF(BijectiveComposition): - def __init__(self, event_shape, n_layers: int = 2, percent_global_parameters: float = 0.8, **kwargs): + def __init__(self, event_shape, n_layers: int = 2, **kwargs): if isinstance(event_shape, int): event_shape = (event_shape,) bijections = [ElementwiseAffine(event_shape=event_shape)] for _ in range(n_layers): bijections.extend([ ReversePermutation(event_shape=event_shape), - DSCoupling(event_shape=event_shape, percent_global_parameters=percent_global_parameters) + DSCoupling(event_shape=event_shape) # TODO specify percent of global parameters ]) bijections.append(ElementwiseAffine(event_shape=event_shape)) super().__init__(event_shape, bijections, **kwargs) From d703aa6519785052c570cb4cd789c242a7ba4d24 Mon Sep 17 00:00:00 2001 From: David Nabergoj Date: Sun, 19 Nov 2023 20:22:04 -0800 Subject: [PATCH 36/40] Fix UMNN and tests --- .../transformers/integration.py | 0 .../transformers/integration/base.py | 2 +- .../unconstrained_monotonic_neural_network.py | 26 +++++++++++-------- test/test_umnn.py | 14 +++++----- 4 files changed, 23 insertions(+), 19 deletions(-) delete mode 100644 normalizing_flows/bijections/finite/autoregressive/transformers/integration.py diff --git a/normalizing_flows/bijections/finite/autoregressive/transformers/integration.py b/normalizing_flows/bijections/finite/autoregressive/transformers/integration.py deleted file mode 100644 index e69de29..0000000 diff --git a/normalizing_flows/bijections/finite/autoregressive/transformers/integration/base.py b/normalizing_flows/bijections/finite/autoregressive/transformers/integration/base.py index 060900b..af2cc78 100644 --- a/normalizing_flows/bijections/finite/autoregressive/transformers/integration/base.py +++ b/normalizing_flows/bijections/finite/autoregressive/transformers/integration/base.py @@ -61,7 +61,7 @@ def inverse_1d(self, z: torch.Tensor, h: torch.Tensor) -> Tuple[torch.Tensor, to def forward(self, x: torch.Tensor, h: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: """ x.shape = (*batch_shape, *event_shape) - h.shape = (*batch_shape, *event_shape, n_parameters) + h.shape = (*batch_shape, *parameter_shape) """ z_flat, log_det_flat = self.forward_1d(x.view(-1), h.view(-1, self.n_parameters_per_element)) z = z_flat.view_as(x) diff --git a/normalizing_flows/bijections/finite/autoregressive/transformers/integration/unconstrained_monotonic_neural_network.py b/normalizing_flows/bijections/finite/autoregressive/transformers/integration/unconstrained_monotonic_neural_network.py index 4aa43d2..3036dc7 100644 --- a/normalizing_flows/bijections/finite/autoregressive/transformers/integration/unconstrained_monotonic_neural_network.py +++ b/normalizing_flows/bijections/finite/autoregressive/transformers/integration/unconstrained_monotonic_neural_network.py @@ -24,6 +24,12 @@ def base_forward_1d(self, x: torch.Tensor, params: List[torch.Tensor]) -> Tuple[ class UnconstrainedMonotonicNeuralNetwork(UnconstrainedMonotonicTransformer): + """ + Unconstrained monotonic neural network transformer. + + The unconstrained monotonic neural network is a neural network with positive weights and positive activation + function derivatives. These two conditions ensure its invertibility. 
+ """ def __init__(self, event_shape: Union[torch.Size, Tuple[int, ...]], n_hidden_layers: int = 2, @@ -44,7 +50,7 @@ def __init__(self, # weight is a square matrix, bias is a vector self.n_hidden_params = (self.hidden_dim ** 2 + self.hidden_dim) * self.n_hidden_layers - self._sampled_default_params = torch.randn(size=(self.n_parameters,)) / 1000 + self._sampled_default_params = torch.randn(size=(self.n_dim, *self.parameter_shape_per_element)) / 1000 @property def parameter_shape_per_element(self) -> Union[torch.Size, Tuple]: @@ -55,31 +61,29 @@ def default_parameters(self) -> torch.Tensor: return self._sampled_default_params def compute_parameters(self, h: torch.Tensor): - batch_shape = h.shape[:-1] p0 = self.default_parameters + batch_size = h.shape[0] + n_events = batch_size // p0.shape[0] # Input layer - input_layer_defaults = pad_leading_dims(p0[:self.n_input_params], len(h.shape) - 1) + input_layer_defaults = p0[..., :self.n_input_params].repeat(n_events, 1) input_layer_deltas = h[..., :self.n_input_params] / self.const input_layer_params = input_layer_defaults + input_layer_deltas - input_layer_params = input_layer_params.view(*batch_shape, self.hidden_dim, 2) + input_layer_params = input_layer_params.view(batch_size, self.hidden_dim, 2) # Output layer - output_layer_defaults = pad_leading_dims(p0[-self.n_output_params:], len(h.shape) - 1) + output_layer_defaults = p0[..., -self.n_output_params:].repeat(n_events, 1) output_layer_deltas = h[..., -self.n_output_params:] / self.const output_layer_params = output_layer_defaults + output_layer_deltas - output_layer_params = output_layer_params.view(*batch_shape, 1, self.hidden_dim + 1) + output_layer_params = output_layer_params.view(batch_size, 1, self.hidden_dim + 1) # Hidden layers - hidden_layer_defaults = pad_leading_dims( - p0[self.n_input_params:self.n_input_params + self.n_hidden_params], - len(h.shape) - 1 - ) + hidden_layer_defaults = p0[..., self.n_input_params:self.n_input_params + self.n_hidden_params].repeat(n_events, 1) hidden_layer_deltas = h[..., self.n_input_params:self.n_input_params + self.n_hidden_params] / self.const hidden_layer_params = hidden_layer_defaults + hidden_layer_deltas hidden_layer_params = torch.chunk(hidden_layer_params, chunks=self.n_hidden_layers, dim=-1) hidden_layer_params = [ - layer.view(*batch_shape, self.hidden_dim, self.hidden_dim + 1) + layer.view(batch_size, self.hidden_dim, self.hidden_dim + 1) for layer in hidden_layer_params ] return [input_layer_params, *hidden_layer_params, output_layer_params] diff --git a/test/test_umnn.py b/test/test_umnn.py index 4aaa322..c6fe481 100644 --- a/test/test_umnn.py +++ b/test/test_umnn.py @@ -15,7 +15,7 @@ def test_umnn(batch_shape: Tuple, event_shape: Tuple): x = torch.randn(*batch_shape, *event_shape) / 100 transformer = UnconstrainedMonotonicNeuralNetwork(event_shape=event_shape, n_hidden_layers=2, hidden_dim=20) - h = torch.randn(*batch_shape, *event_shape, len(transformer.default_parameters)) + h = torch.randn(size=(*batch_shape, *transformer.parameter_shape)) z, log_det_forward = transformer.forward(x, h) xr, log_det_inverse = transformer.inverse(z, h) @@ -36,7 +36,7 @@ def test_umnn_forward(): x = torch.cat([x0, x1]).view(2, 1) transformer = UnconstrainedMonotonicNeuralNetwork(event_shape=event_shape, n_hidden_layers=2, hidden_dim=20) - h = torch.randn(*event_shape, len(transformer.default_parameters)) + h = torch.randn(*event_shape, *transformer.parameter_shape_per_element) z0, log_det_forward0 = transformer.forward(x0, h) z1, log_det_forward1 = 
transformer.forward(x1, h) @@ -55,7 +55,7 @@ def test_umnn_inverse(): x = torch.cat([x0, x1]).view(2, 1) transformer = UnconstrainedMonotonicNeuralNetwork(event_shape=event_shape, n_hidden_layers=2, hidden_dim=20) - h = torch.randn(*event_shape, len(transformer.default_parameters)) + h = torch.randn(*event_shape, *transformer.parameter_shape_per_element) z0, log_det_inverse0 = transformer.inverse(x0, h) z1, log_det_inverse1 = transformer.inverse(x1, h) @@ -74,7 +74,7 @@ def test_umnn_reconstruction(): x = torch.cat([x0, x1]).view(2, 1) transformer = UnconstrainedMonotonicNeuralNetwork(event_shape=event_shape, n_hidden_layers=2, hidden_dim=20) - h = torch.randn(*event_shape, len(transformer.default_parameters)) + h = torch.randn(*event_shape, *transformer.parameter_shape_per_element) z0, log_det_forward0 = transformer.forward(x0, h) z1, log_det_forward1 = transformer.forward(x1, h) @@ -106,7 +106,7 @@ def test_umnn_forward_large_event(): x = torch.cat([x0, x1, x2]).view(3, 2) transformer = UnconstrainedMonotonicNeuralNetwork(event_shape=event_shape, n_hidden_layers=2, hidden_dim=20) - h = torch.randn(*event_shape, len(transformer.default_parameters)) + h = torch.randn(*event_shape, *transformer.parameter_shape_per_element) z0, log_det_forward0 = transformer.forward(x0, h) z1, log_det_forward1 = transformer.forward(x1, h) @@ -127,7 +127,7 @@ def test_umnn_inverse_large_event(): x = torch.cat([x0, x1, x2]).view(3, 2) transformer = UnconstrainedMonotonicNeuralNetwork(event_shape=event_shape, n_hidden_layers=2, hidden_dim=20) - h = torch.randn(*event_shape, len(transformer.default_parameters)) + h = torch.randn(*event_shape, *transformer.parameter_shape_per_element) z0, log_det_inverse0 = transformer.inverse(x0, h) z1, log_det_inverse1 = transformer.inverse(x1, h) @@ -148,7 +148,7 @@ def test_umnn_reconstruction_large_event(): x = torch.cat([x0, x1, x2]).view(3, 2) transformer = UnconstrainedMonotonicNeuralNetwork(event_shape=event_shape, n_hidden_layers=2, hidden_dim=20) - h = torch.randn(*event_shape, len(transformer.default_parameters)) + h = torch.randn(*event_shape, *transformer.parameter_shape_per_element) z0, log_det_forward0 = transformer.forward(x0, h) z1, log_det_forward1 = transformer.forward(x1, h) From 011e667689ab2fd79f33146aab49568521ff7d2b Mon Sep 17 00:00:00 2001 From: David Nabergoj Date: Sun, 19 Nov 2023 21:06:50 -0800 Subject: [PATCH 37/40] Fixing invertible convolution --- .../transformers/convolution.py | 171 ++++-------------- test/constants.py | 2 +- test/test_reconstruction_transformers.py | 8 +- 3 files changed, 41 insertions(+), 140 deletions(-) diff --git a/normalizing_flows/bijections/finite/autoregressive/transformers/convolution.py b/normalizing_flows/bijections/finite/autoregressive/transformers/convolution.py index b65c76e..3ea7b97 100644 --- a/normalizing_flows/bijections/finite/autoregressive/transformers/convolution.py +++ b/normalizing_flows/bijections/finite/autoregressive/transformers/convolution.py @@ -1,65 +1,21 @@ from typing import Union, Tuple import torch -from normalizing_flows.bijections.finite.autoregressive.transformers.base import ScalarTransformer +from normalizing_flows.bijections import LU +from normalizing_flows.bijections.finite.autoregressive.transformers.base import TensorTransformer +from normalizing_flows.utils import sum_except_batch, get_batch_shape -def construct_kernels_plu( - lower_elements: torch.Tensor, - upper_elements: torch.Tensor, - log_abs_diag: torch.Tensor, - sign_diag: torch.Tensor, - permutation: torch.Tensor, - k: int, 
- inverse: bool = False -): - """ - :param lower_elements: (b, (k ** 2 - k) // 2) - :param upper_elements: (b, (k ** 2 - k) // 2) - :param log_abs_diag: (b, k) - :param sign_diag: (k,) - :param permutation: (k, k) - :param k: kernel length - :param inverse: - :return: kernels with shape (b, k, k) - """ - - assert lower_elements.shape == upper_elements.shape - assert log_abs_diag.shape[1] == sign_diag.shape[0] - assert permutation.shape == (k, k) - assert len(log_abs_diag.shape) == 2 - assert len(lower_elements.shape) == 2 - assert lower_elements.shape[1] == (k ** 2 - k) // 2 - assert log_abs_diag.shape[1] == k - - batch_size = len(lower_elements) - - lower = torch.eye(k)[None].repeat(batch_size, 1, 1) - lower_row_idx, lower_col_idx = torch.tril_indices(k, k, offset=-1) - lower[:, lower_row_idx, lower_col_idx] = lower_elements - - upper = torch.einsum("ij,bj->bij", torch.eye(k), log_abs_diag.exp() * sign_diag) - upper_row_idx, upper_col_idx = torch.triu_indices(k, k, offset=1) - upper[:, upper_row_idx, upper_col_idx] = upper_elements - - if inverse: - if log_abs_diag.dtype == torch.float64: - lower_inv = torch.inverse(lower) - upper_inv = torch.inverse(upper) - else: - lower_inv = torch.inverse(lower.double()).type(log_abs_diag.dtype) - upper_inv = torch.inverse(upper.double()).type(log_abs_diag.dtype) - kernels = torch.einsum("bij,bjk,kl->bil", upper_inv, lower_inv, permutation.T) - else: - kernels = torch.einsum("ij,bjk,bkl->bil", permutation, lower, upper) - return kernels - - -class Invertible1x1Convolution(ScalarTransformer): +class Invertible1x1Convolution(TensorTransformer): """ Invertible 1x1 convolution. - TODO permutation may be unnecessary, maybe remove. + This transformer receives as input a batch of images x with x.shape (*batch_shape, *image_dimensions, channels) and + parameters h for an invertible linear transform of the channels + with h.shape = (*batch_shape, *image_dimensions, *parameter_shape). + Note that image_dimensions can be a shape with arbitrarily ordered dimensions (height, width). + In fact, it is not required that the image is two-dimensional. Voxels with shape (height, width, depth, channels) + are also supported, as well as tensors with more general shapes. """ def __init__(self, event_shape: Union[torch.Size, Tuple[int, ...]]): @@ -67,96 +23,39 @@ def __init__(self, event_shape: Union[torch.Size, Tuple[int, ...]]): raise ValueError( f"InvertibleConvolution transformer only supports events with shape (height, width, channels)." ) - self.n_channels, self.h, self.w = event_shape - self.sign_diag = torch.sign(torch.randn(self.n_channels)) - self.permutation = torch.eye(self.n_channels)[torch.randperm(self.n_channels)] - self.const = 1000 + *self.image_dimensions, self.n_channels = event_shape + self.invertible_linear: TensorTransformer = LU(event_shape=(self.n_channels,)) super().__init__(event_shape) @property - def n_parameters(self) -> int: - return self.n_channels ** 2 + def parameter_shape(self) -> Union[torch.Size, Tuple[int, ...]]: + return self.invertible_linear.parameter_shape @property def default_parameters(self) -> torch.Tensor: - # Kernel matrix is identity (p=0,u=0,log_diag=0). - # Some diagonal elements are negated according to self.sign_diag. - # The matrix is then permuted. - return torch.zeros(self.n_parameters) - - def forward(self, x: torch.Tensor, h: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: - """ - We parametrize K = PLU. The parameters h contain elements of L and U. - There are k ** 2 such elements. 
- - x.shape == (batch_size, c, h, w) - h.shape == (batch_size, k * k) + return torch.zeros(size=self.parameter_shape) - We expect each kernel to be invertible. - """ - if len(x.shape) != 4: - raise ValueError(f"Expected x to have shape (batch_size, channels, height, width), but got {x.shape}") - if len(h.shape) != 2: - raise ValueError(f"Expected h.shape to be of length 2, but got {h.shape} with length {len(h.shape)}") - if h.shape[1] != self.n_channels * self.n_channels: - raise ValueError( - f"Expected h to have shape (batch_size, kernel_height * kernel_width) = (batch_size, {self.n_channels * self.n_channels})," - f" but got {h.shape}" - ) - - h = self.default_parameters + h / self.const - - n_p_elements = (self.n_channels ** 2 - self.n_channels) // 2 - p_elements = h[..., :n_p_elements] - u_elements = h[..., n_p_elements:n_p_elements * 2] - log_diag_elements = h[..., n_p_elements * 2:] + def apply_linear(self, inputs: torch.Tensor, h: torch.Tensor, forward: bool): + batch_shape = get_batch_shape(inputs, self.event_shape) - kernels = construct_kernels_plu( - p_elements, - u_elements, - log_diag_elements, - self.sign_diag, - self.permutation, - self.n_channels, - inverse=False - ) # (b, k, k) - log_det = self.h * self.w * torch.sum(log_diag_elements, dim=-1) # (*batch_shape) + h = h / self.const + self.default_parameters[[None] * len(batch_shape)] + h_flat = torch.flatten(h, start_dim=0, end_dim=len(batch_shape) + len(self.image_dimensions)) + inputs_flat = torch.flatten(inputs, start_dim=0, end_dim=len(batch_shape) + len(self.image_dimensions)) - z = torch.zeros_like(x) - for i in range(len(x)): - z[i] = torch.conv2d(x[i], kernels[i][:, :, None, None], groups=1, stride=1, padding="same") - - return z, log_det - - def inverse(self, x: torch.Tensor, h: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: - if len(x.shape) != 4: - raise ValueError(f"Expected x to have shape (batch_size, channels, height, width), but got {x.shape}") - if len(h.shape) != 2: - raise ValueError(f"Expected h.shape to be of length 2, but got {h.shape} with length {len(h.shape)}") - if h.shape[1] != self.n_channels * self.n_channels: - raise ValueError( - f"Expected h to have shape (batch_size, kernel_height * kernel_width) = (batch_size, {self.n_channels * self.n_channels})," - f" but got {h.shape}" - ) - - h = self.default_parameters + h / self.const + # Apply linear transformation along channel dimension + if forward: + outputs_flat, log_det_flat = self.invertible_linear.forward(inputs_flat, h_flat) + else: + outputs_flat, log_det_flat = self.invertible_linear.inverse(inputs_flat, h_flat) + outputs = outputs_flat.view_as(inputs) + log_det = sum_except_batch( + log_det_flat.view(*batch_shape, *self.image_dimensions), + event_shape=self.image_dimensions + ) + return outputs, log_det - n_p_elements = (self.n_channels ** 2 - self.n_channels) // 2 - p_elements = h[..., :n_p_elements] - u_elements = h[..., n_p_elements:n_p_elements * 2] - log_diag_elements = h[..., n_p_elements * 2:] + def forward(self, x: torch.Tensor, h: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + return self.apply_linear(x, h, forward=True) - kernels = construct_kernels_plu( - p_elements, - u_elements, - log_diag_elements, - self.sign_diag, - self.permutation, - self.n_channels, - inverse=True - ) - log_det = -self.h * self.w * torch.sum(log_diag_elements, dim=-1) # (*batch_shape) - z = torch.zeros_like(x) - for i in range(len(x)): - z[i] = torch.conv2d(x[i], kernels[i][:, :, None, None], groups=1, stride=1, padding="same") - return z, 
log_det + def inverse(self, z: torch.Tensor, h: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + return self.apply_linear(z, h, forward=False) diff --git a/test/constants.py b/test/constants.py index dc4c158..381c918 100644 --- a/test/constants.py +++ b/test/constants.py @@ -1,7 +1,7 @@ __test_constants = { 'batch_shape': [(1,), (2,), (5,), (2, 4), (5, 2, 3)], 'event_shape': [(2,), (3,), (2, 4), (40,), (3, 5, 2)], - 'image_shape': [(3, 4, 4), (3, 20, 20), (3, 10, 20), (3, 200, 200), (1, 20, 20), (1, 10, 20), ], + 'image_shape': [(4, 4, 3), (20, 20, 3), (10, 20, 3), (200, 200, 3), (20, 20, 1), (10, 20, 1)], 'context_shape': [None, (2,), (3,), (2, 4), (5,)], 'input_event_shape': [(2,), (3,), (2, 4), (40,), (3, 5, 2)], 'output_event_shape': [(2,), (3,), (2, 4), (40,), (3, 5, 2)], diff --git a/test/test_reconstruction_transformers.py b/test/test_reconstruction_transformers.py index 65d808f..2d31e59 100644 --- a/test/test_reconstruction_transformers.py +++ b/test/test_reconstruction_transformers.py @@ -119,10 +119,12 @@ def test_combination_vector_to_vector(transformer_class: ScalarTransformer, batc @pytest.mark.parametrize('image_shape', __test_constants['image_shape']) def test_convolution(batch_size: int, image_shape: Tuple): torch.manual_seed(0) - n_channels = image_shape[0] - images = torch.randn(size=(batch_size, *image_shape)) - parameters = torch.randn(size=(batch_size, n_channels ** 2)) transformer = Invertible1x1Convolution(image_shape) + + *image_dimensions, n_channels = image_shape + + images = torch.randn(size=(batch_size, *image_shape)) + parameters = torch.randn(size=(batch_size, *image_dimensions, *transformer.parameter_shape)) latent_images, log_det_forward = transformer.forward(images, parameters) reconstructed_images, log_det_inverse = transformer.inverse(latent_images, parameters) From 3fc949286d0f1ae912de7ec56eb5ededeee3b66d Mon Sep 17 00:00:00 2001 From: David Nabergoj Date: Fri, 1 Dec 2023 16:02:41 -0800 Subject: [PATCH 38/40] Fixing invertible convolution --- .../finite/autoregressive/layers.py | 4 +- .../autoregressive/transformers/base.py | 14 +++ .../transformers/linear/__init__.py | 0 .../transformers/{ => linear}/affine.py | 0 .../transformers/{ => linear}/convolution.py | 6 +- .../transformers/linear/matrix.py | 97 +++++++++++++++++++ test/test_lu_matrix_transformer.py | 21 ++++ test/test_reconstruction_transformers.py | 6 +- 8 files changed, 137 insertions(+), 11 deletions(-) create mode 100644 normalizing_flows/bijections/finite/autoregressive/transformers/linear/__init__.py rename normalizing_flows/bijections/finite/autoregressive/transformers/{ => linear}/affine.py (100%) rename normalizing_flows/bijections/finite/autoregressive/transformers/{ => linear}/convolution.py (93%) create mode 100644 normalizing_flows/bijections/finite/autoregressive/transformers/linear/matrix.py create mode 100644 test/test_lu_matrix_transformer.py diff --git a/normalizing_flows/bijections/finite/autoregressive/layers.py b/normalizing_flows/bijections/finite/autoregressive/layers.py index 957a0d6..b340cbe 100644 --- a/normalizing_flows/bijections/finite/autoregressive/layers.py +++ b/normalizing_flows/bijections/finite/autoregressive/layers.py @@ -1,10 +1,10 @@ import torch -from normalizing_flows.bijections.finite.autoregressive.conditioner_transforms import MADE, FeedForward +from normalizing_flows.bijections.finite.autoregressive.conditioner_transforms import FeedForward from normalizing_flows.bijections.finite.autoregressive.conditioners.coupling_masks import HalfSplit from 
normalizing_flows.bijections.finite.autoregressive.layers_base import MaskedAutoregressiveBijection, \ InverseMaskedAutoregressiveBijection, ElementwiseBijection, CouplingBijection -from normalizing_flows.bijections.finite.autoregressive.transformers.affine import Scale, Affine, Shift +from normalizing_flows.bijections.finite.autoregressive.transformers.linear.affine import Scale, Affine, Shift from normalizing_flows.bijections.finite.autoregressive.transformers.base import ScalarTransformer from normalizing_flows.bijections.finite.autoregressive.transformers.integration.unconstrained_monotonic_neural_network import \ UnconstrainedMonotonicNeuralNetwork diff --git a/normalizing_flows/bijections/finite/autoregressive/transformers/base.py b/normalizing_flows/bijections/finite/autoregressive/transformers/base.py index 72a96ad..a24701d 100644 --- a/normalizing_flows/bijections/finite/autoregressive/transformers/base.py +++ b/normalizing_flows/bijections/finite/autoregressive/transformers/base.py @@ -20,9 +20,23 @@ def __init__(self, event_shape: Union[torch.Size, Tuple[int, ...]]): super().__init__(event_shape=event_shape) def forward(self, x: torch.Tensor, h: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Applies the forward transformation. + + :param torch.Tensor x: input tensor with shape (*batch_shape, *event_shape). + :param torch.Tensor h: parameter tensor with shape (*batch_shape, *parameter_shape). + :returns: output tensor with shape (*batch_shape, *event_shape). + """ raise NotImplementedError def inverse(self, x: torch.Tensor, h: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Applies the inverse transformation. + + :param torch.Tensor x: input tensor with shape (*batch_shape, *event_shape). + :param torch.Tensor h: parameter tensor with shape (*batch_shape, *parameter_shape). + :returns: output tensor with shape (*batch_shape, *event_shape). + """ raise NotImplementedError @property diff --git a/normalizing_flows/bijections/finite/autoregressive/transformers/linear/__init__.py b/normalizing_flows/bijections/finite/autoregressive/transformers/linear/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/normalizing_flows/bijections/finite/autoregressive/transformers/affine.py b/normalizing_flows/bijections/finite/autoregressive/transformers/linear/affine.py similarity index 100% rename from normalizing_flows/bijections/finite/autoregressive/transformers/affine.py rename to normalizing_flows/bijections/finite/autoregressive/transformers/linear/affine.py diff --git a/normalizing_flows/bijections/finite/autoregressive/transformers/convolution.py b/normalizing_flows/bijections/finite/autoregressive/transformers/linear/convolution.py similarity index 93% rename from normalizing_flows/bijections/finite/autoregressive/transformers/convolution.py rename to normalizing_flows/bijections/finite/autoregressive/transformers/linear/convolution.py index 3ea7b97..e7d1f4c 100644 --- a/normalizing_flows/bijections/finite/autoregressive/transformers/convolution.py +++ b/normalizing_flows/bijections/finite/autoregressive/transformers/linear/convolution.py @@ -19,13 +19,9 @@ class Invertible1x1Convolution(TensorTransformer): """ def __init__(self, event_shape: Union[torch.Size, Tuple[int, ...]]): - if len(event_shape) != 3: - raise ValueError( - f"InvertibleConvolution transformer only supports events with shape (height, width, channels)." 
- ) + super().__init__(event_shape) *self.image_dimensions, self.n_channels = event_shape self.invertible_linear: TensorTransformer = LU(event_shape=(self.n_channels,)) - super().__init__(event_shape) @property def parameter_shape(self) -> Union[torch.Size, Tuple[int, ...]]: diff --git a/normalizing_flows/bijections/finite/autoregressive/transformers/linear/matrix.py b/normalizing_flows/bijections/finite/autoregressive/transformers/linear/matrix.py new file mode 100644 index 0000000..fda277f --- /dev/null +++ b/normalizing_flows/bijections/finite/autoregressive/transformers/linear/matrix.py @@ -0,0 +1,97 @@ +from typing import Union, Tuple + +import torch + +from normalizing_flows.bijections.finite.autoregressive.transformers.base import TensorTransformer +from normalizing_flows.utils import flatten_event, unflatten_event + + +# Matrix transformers that operate on vector inputs (Ax=b) + +class LUTransformer(TensorTransformer): + """Linear transformer with LUx = y. + + It is assumed that all diagonal elements of L are 1. + """ + + def __init__(self, event_shape: Union[torch.Size, Tuple[int, ...]]): + super().__init__(event_shape) + + def extract_matrices(self, h: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Extract matrices L, U from tensor h. + + :param torch.Tensor h: parameter tensor with shape (*batch_shape, *parameter_shape) + :returns: tuple with (L, U, log(diag(U))). L and U have shapes (*batch_shape, event_size, event_size), + log(diag(U)) has shape (*batch_shape, event_size). + """ + event_size = int(torch.prod(torch.as_tensor(self.event_shape))) + n_off_diag_el = (event_size ** 2 - event_size) // 2 + + u_log_diag = h[..., :event_size] + u_diag = torch.exp(u_log_diag) / 10 + 1 + u_off_diagonal_elements = h[..., event_size:event_size + n_off_diag_el] / 10 + l_off_diagonal_elements = h[..., -n_off_diag_el:] / 10 + + batch_shape = h.shape[:-len(self.parameter_shape)] + + upper = torch.zeros(size=(*batch_shape, event_size, event_size)) + upper_row_index, upper_col_index = torch.triu_indices(row=event_size, col=event_size, offset=1) + upper[..., upper_row_index, upper_col_index] = u_off_diagonal_elements + upper[..., range(event_size), range(event_size)] = u_diag + + lower = torch.zeros(size=(*batch_shape, event_size, event_size)) + lower_row_index, lower_col_index = torch.tril_indices(row=event_size, col=event_size, offset=-1) + lower[..., lower_row_index, lower_col_index] = l_off_diagonal_elements + lower[..., range(event_size), range(event_size)] = 1 # Unit diagonal + + return lower, upper, u_log_diag + + @staticmethod + def log_determinant(upper_log_diag: torch.Tensor): + """ + Computes the matrix log determinant of A = LU for each pair of matrices in a batch. + + Note: det(A) = det(LU) = det(L) * det(U) so log det(A) = log det(L) + log det(U). + We assume that L has unit diagonal, so log det(L) = 0 and can be skipped. + + :param torch.Tensor upper_log_diag: log diagonals of matrices U with shape (*batch_size, event_size). + :returns: log determinants of LU with shape (*batch_size,). 
+ """ + # Extract the diagonals + return torch.sum(upper_log_diag, dim=-1) + + def forward(self, x: torch.Tensor, h: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + lower, upper, upper_log_diag = self.extract_matrices(h) + + # Flatten inputs + x_flat = flatten_event(x, self.event_shape) # (*batch_shape, event_size) + y_flat = torch.einsum('...ij,...jk,...k->...i', lower, upper, x_flat) # y = LUx + + output = unflatten_event(y_flat, self.event_shape) + return output, self.log_determinant(upper_log_diag) + + def inverse(self, y: torch.Tensor, h: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + lower, upper, upper_log_diag = self.extract_matrices(h) + + # Flatten inputs + y_flat = flatten_event(y, self.event_shape)[..., None] # (*batch_shape, event_size) + z_flat = torch.linalg.solve_triangular(lower, y_flat, upper=False, unitriangular=True) # y = Lz => z = L^{-1}y + x_flat = torch.linalg.solve_triangular(upper, z_flat, upper=True, unitriangular=False) # z = Ux => x = U^{-1}z + x_flat = x_flat.squeeze(-1) + + output = unflatten_event(x_flat, self.event_shape) + return output, -self.log_determinant(upper_log_diag) + + @property + def parameter_shape(self) -> Union[torch.Size, Tuple[int, ...]]: + event_size = int(torch.prod(torch.as_tensor(self.event_shape))) + # Let n be the event size + # L will have (n^2 - n) / 2 parameters (we assume unit diagonal) + # U will have (n^2 - n) / 2 + n parameters + n_off_diag_el = (event_size ** 2 - event_size) // 2 + return (event_size + n_off_diag_el + n_off_diag_el,) + + @property + def default_parameters(self) -> torch.Tensor: + return torch.zeros(size=self.parameter_shape) diff --git a/test/test_lu_matrix_transformer.py b/test/test_lu_matrix_transformer.py new file mode 100644 index 0000000..ee531c0 --- /dev/null +++ b/test/test_lu_matrix_transformer.py @@ -0,0 +1,21 @@ +import torch + +from normalizing_flows.bijections.finite.autoregressive.transformers.linear.matrix import LUTransformer + + +def test_basic(): + torch.manual_seed(0) + + batch_shape = (2, 3) + event_shape = (5, 7) + + transformer = LUTransformer(event_shape) + + x = torch.randn(size=(*batch_shape, *event_shape)) + h = torch.randn(size=(*batch_shape, *transformer.parameter_shape)) + + z, log_det_forward = transformer.forward(x, h) + x_reconstructed, log_det_inverse = transformer.inverse(z, h) + + assert torch.allclose(x, x_reconstructed, atol=1e-3), f"{torch.linalg.norm(x-x_reconstructed)}" + assert torch.allclose(log_det_forward, -log_det_inverse, atol=1e-3) diff --git a/test/test_reconstruction_transformers.py b/test/test_reconstruction_transformers.py index 2d31e59..88f01b4 100644 --- a/test/test_reconstruction_transformers.py +++ b/test/test_reconstruction_transformers.py @@ -9,10 +9,8 @@ LinearRational as LinearRationalSpline from normalizing_flows.bijections.finite.autoregressive.transformers.spline.rational_quadratic import \ RationalQuadratic as RationalQuadraticSpline -from normalizing_flows.bijections.finite.autoregressive.transformers.convolution import Invertible1x1Convolution -from normalizing_flows.bijections.finite.autoregressive.transformers.spline.cubic import Cubic as CubicSpline -from normalizing_flows.bijections.finite.autoregressive.transformers.spline.basis import Basis as BasisSpline -from normalizing_flows.bijections.finite.autoregressive.transformers.affine import Affine, Scale, Shift +from normalizing_flows.bijections.finite.autoregressive.transformers.linear.convolution import Invertible1x1Convolution +from 
normalizing_flows.bijections.finite.autoregressive.transformers.linear.affine import Affine, Scale, Shift from normalizing_flows.bijections.finite.autoregressive.transformers.combination.sigmoid import Sigmoid, DeepSigmoid, \ DenseSigmoid, DeepDenseSigmoid from normalizing_flows.bijections.finite.autoregressive.transformers.integration.unconstrained_monotonic_neural_network import \ From 89f48c89471dfcd0dd5f2bbdcbdbd1bfa62bebbd Mon Sep 17 00:00:00 2001 From: David Nabergoj Date: Fri, 1 Dec 2023 18:29:40 -0800 Subject: [PATCH 39/40] Fix invertible convolution and test --- .../transformers/linear/convolution.py | 16 ++++++---------- .../autoregressive/transformers/linear/matrix.py | 8 +++++--- test/test_reconstruction_transformers.py | 4 ++-- 3 files changed, 13 insertions(+), 15 deletions(-) diff --git a/normalizing_flows/bijections/finite/autoregressive/transformers/linear/convolution.py b/normalizing_flows/bijections/finite/autoregressive/transformers/linear/convolution.py index e7d1f4c..c649a61 100644 --- a/normalizing_flows/bijections/finite/autoregressive/transformers/linear/convolution.py +++ b/normalizing_flows/bijections/finite/autoregressive/transformers/linear/convolution.py @@ -3,6 +3,7 @@ from normalizing_flows.bijections import LU from normalizing_flows.bijections.finite.autoregressive.transformers.base import TensorTransformer +from normalizing_flows.bijections.finite.autoregressive.transformers.linear.matrix import LUTransformer from normalizing_flows.utils import sum_except_batch, get_batch_shape @@ -21,7 +22,7 @@ class Invertible1x1Convolution(TensorTransformer): def __init__(self, event_shape: Union[torch.Size, Tuple[int, ...]]): super().__init__(event_shape) *self.image_dimensions, self.n_channels = event_shape - self.invertible_linear: TensorTransformer = LU(event_shape=(self.n_channels,)) + self.invertible_linear: TensorTransformer = LUTransformer(event_shape=(self.n_channels,)) @property def parameter_shape(self) -> Union[torch.Size, Tuple[int, ...]]: @@ -29,23 +30,18 @@ def parameter_shape(self) -> Union[torch.Size, Tuple[int, ...]]: @property def default_parameters(self) -> torch.Tensor: - return torch.zeros(size=self.parameter_shape) + return self.invertible_linear.default_parameters def apply_linear(self, inputs: torch.Tensor, h: torch.Tensor, forward: bool): batch_shape = get_batch_shape(inputs, self.event_shape) - h = h / self.const + self.default_parameters[[None] * len(batch_shape)] - h_flat = torch.flatten(h, start_dim=0, end_dim=len(batch_shape) + len(self.image_dimensions)) - inputs_flat = torch.flatten(inputs, start_dim=0, end_dim=len(batch_shape) + len(self.image_dimensions)) - # Apply linear transformation along channel dimension if forward: - outputs_flat, log_det_flat = self.invertible_linear.forward(inputs_flat, h_flat) + outputs, log_det = self.invertible_linear.forward(inputs, h) else: - outputs_flat, log_det_flat = self.invertible_linear.inverse(inputs_flat, h_flat) - outputs = outputs_flat.view_as(inputs) + outputs, log_det = self.invertible_linear.inverse(inputs, h) log_det = sum_except_batch( - log_det_flat.view(*batch_shape, *self.image_dimensions), + log_det.view(*batch_shape, *self.image_dimensions), event_shape=self.image_dimensions ) return outputs, log_det diff --git a/normalizing_flows/bijections/finite/autoregressive/transformers/linear/matrix.py b/normalizing_flows/bijections/finite/autoregressive/transformers/linear/matrix.py index fda277f..5b0b6f9 100644 --- 
a/normalizing_flows/bijections/finite/autoregressive/transformers/linear/matrix.py +++ b/normalizing_flows/bijections/finite/autoregressive/transformers/linear/matrix.py @@ -28,10 +28,12 @@ def extract_matrices(self, h: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, event_size = int(torch.prod(torch.as_tensor(self.event_shape))) n_off_diag_el = (event_size ** 2 - event_size) // 2 - u_log_diag = h[..., :event_size] - u_diag = torch.exp(u_log_diag) / 10 + 1 + u_unc_diag = h[..., :event_size] + u_diag = torch.exp(u_unc_diag) / 10 + 1 + u_log_diag = torch.log(u_diag) + u_off_diagonal_elements = h[..., event_size:event_size + n_off_diag_el] / 10 - l_off_diagonal_elements = h[..., -n_off_diag_el:] / 10 + l_off_diagonal_elements = h[..., -n_off_diag_el:] batch_shape = h.shape[:-len(self.parameter_shape)] diff --git a/test/test_reconstruction_transformers.py b/test/test_reconstruction_transformers.py index 88f01b4..352cd69 100644 --- a/test/test_reconstruction_transformers.py +++ b/test/test_reconstruction_transformers.py @@ -136,5 +136,5 @@ def test_convolution(batch_size: int, image_shape: Tuple): assert reconstructed_images.shape == images.shape assert torch.isfinite(latent_images).all() assert torch.isfinite(reconstructed_images).all() - rec_err = torch.max(torch.abs(latent_images - reconstructed_images)) - assert torch.allclose(latent_images, reconstructed_images, atol=1e-2), f"{rec_err = }" + rec_err = torch.max(torch.abs(images - reconstructed_images)) + assert torch.allclose(images, reconstructed_images, atol=1e-2), f"{rec_err = }" From 41ba047d96bc2cf69a77cb9636440d2ea5f0cd23 Mon Sep 17 00:00:00 2001 From: David Nabergoj Date: Fri, 1 Dec 2023 18:39:30 -0800 Subject: [PATCH 40/40] Reduce initial magnitudes of L elements --- .../finite/autoregressive/transformers/linear/matrix.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/normalizing_flows/bijections/finite/autoregressive/transformers/linear/matrix.py b/normalizing_flows/bijections/finite/autoregressive/transformers/linear/matrix.py index 5b0b6f9..a41c7e5 100644 --- a/normalizing_flows/bijections/finite/autoregressive/transformers/linear/matrix.py +++ b/normalizing_flows/bijections/finite/autoregressive/transformers/linear/matrix.py @@ -33,7 +33,7 @@ def extract_matrices(self, h: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, u_log_diag = torch.log(u_diag) u_off_diagonal_elements = h[..., event_size:event_size + n_off_diag_el] / 10 - l_off_diagonal_elements = h[..., -n_off_diag_el:] + l_off_diagonal_elements = h[..., -n_off_diag_el:] / 10 batch_shape = h.shape[:-len(self.parameter_shape)]
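
Aside (not part of the patch series): a minimal sketch of the LU parameter layout that LUTransformer uses after the patches above, assuming an event size of n = 4 for illustration. The layout follows extract_matrices: the first n entries of h parameterize diag(U) through exp(h) / 10 + 1, the next (n ** 2 - n) // 2 entries fill the strict upper triangle of U (scaled by 1/10), and the last (n ** 2 - n) // 2 entries fill the strict lower triangle of L (also scaled by 1/10 after the final patch). Because L has a unit diagonal, log|det(LU)| reduces to the sum of log diag(U).

import torch

n = 4                                    # assumed event size for illustration
n_off = (n ** 2 - n) // 2                # strictly triangular elements per factor
h = torch.randn(n + 2 * n_off)           # flat parameter vector, parameter_shape = (n + 2 * n_off,)

u_diag = torch.exp(h[:n]) / 10 + 1       # positive diagonal of U
u_off = h[n:n + n_off] / 10              # strict upper triangle of U
l_off = h[-n_off:] / 10                  # strict lower triangle of L

upper = torch.zeros(n, n)
row, col = torch.triu_indices(n, n, offset=1)
upper[row, col] = u_off
upper[range(n), range(n)] = u_diag

lower = torch.eye(n)                     # unit diagonal, so det(L) = 1
row, col = torch.tril_indices(n, n, offset=-1)
lower[row, col] = l_off

# log|det(LU)| equals sum(log diag(U)) since det(L) = 1
_, log_abs_det = torch.linalg.slogdet(lower @ upper)
assert torch.allclose(log_abs_det, torch.log(u_diag).sum(), atol=1e-5)

This mirrors the check added in test_lu_matrix_transformer.py, but spells out the slicing so the parameter ordering (diag(U), then upper off-diagonals, then lower off-diagonals) is explicit.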