From 9d1a841ee6e22173d067126bf764475bd031bbdb Mon Sep 17 00:00:00 2001 From: David Nabergoj Date: Tue, 26 Nov 2024 16:36:52 +0100 Subject: [PATCH 1/4] Fixing Sylvester deepcopy problem --- test/test_deepcopy.py | 21 ++++--- .../bijections/finite/residual/sylvester.py | 55 ++++++++----------- torchflows/bijections/matrices.py | 3 +- torchflows/flows.py | 17 +++--- 4 files changed, 47 insertions(+), 49 deletions(-) diff --git a/test/test_deepcopy.py b/test/test_deepcopy.py index 9ca0663..2211dd7 100644 --- a/test/test_deepcopy.py +++ b/test/test_deepcopy.py @@ -1,26 +1,33 @@ from copy import deepcopy +import pytest import torch -from torchflows import RNODE, Flow +from torchflows import RNODE, Flow, Sylvester, RealNVP +from torchflows.bijections.base import invert -def test_basic(): +@pytest.mark.parametrize('flow_class', [RNODE, Sylvester, RealNVP]) +def test_basic(flow_class): torch.manual_seed(0) - b = RNODE(event_shape=(10,)) + b = flow_class(event_shape=(10,)) deepcopy(b) -def test_post_variational_fit(): +@pytest.mark.parametrize('flow_class', [RNODE, Sylvester, RealNVP]) +def test_post_variational_fit(flow_class): torch.manual_seed(0) - b = RNODE(event_shape=(10,)) + b = flow_class(event_shape=(10,)) f = Flow(b) f.variational_fit(lambda x: torch.sum(-x ** 2), n_epochs=2) deepcopy(b) -def test_post_fit(): +@pytest.mark.parametrize('flow_class', [RNODE, Sylvester, RealNVP]) +def test_post_fit(flow_class): torch.manual_seed(0) - b = RNODE(event_shape=(10,)) + b = flow_class(event_shape=(10,)) + if isinstance(b, Sylvester): + b = invert(b) f = Flow(b) f.fit(x_train=torch.randn(3, *b.event_shape), n_epochs=2) deepcopy(b) diff --git a/torchflows/bijections/finite/residual/sylvester.py b/torchflows/bijections/finite/residual/sylvester.py index d15c00a..22e3d21 100644 --- a/torchflows/bijections/finite/residual/sylvester.py +++ b/torchflows/bijections/finite/residual/sylvester.py @@ -16,60 +16,49 @@ def __init__(self, **kwargs): super().__init__(event_shape, **kwargs) self.n_dim = int(torch.prod(torch.as_tensor(event_shape))) - if m is None: m = self.n_dim // 2 if m > self.n_dim: raise ValueError - self.m = m - self.b = nn.Parameter(torch.randn(m)) - # q is implemented in subclasses + self.register_parameter('b', nn.Parameter(torch.randn(m))) self.register_module('r', UpperTriangularInvertibleMatrix(n_dim=self.m)) self.register_module('r_tilde', UpperTriangularInvertibleMatrix(n_dim=self.m)) - @property - def w(self): - r_tilde = self.r_tilde.mat() - q = self.q.mat()[:, :self.m] - return torch.einsum('...ij,...kj->...ik', r_tilde, q) - - @property - def u(self): - r = self.r.mat() - q = self.q.mat()[:, :self.m] - return torch.einsum('...ij,...jk->...ik', q, r) - - def h(self, x): - return torch.sigmoid(x) - - def h_deriv(self, x): - return torch.sigmoid(x) * (1 - torch.sigmoid(x)) - def forward(self, x: torch.Tensor, context: torch.Tensor = None) -> Tuple[torch.Tensor, torch.Tensor]: raise ValueError("Sylvester bijection does not support forward computation.") def inverse(self, z: torch.Tensor, context: torch.Tensor = None) -> Tuple[torch.Tensor, torch.Tensor]: batch_shape = get_batch_shape(z, self.event_shape) z_flat = torch.flatten(z, start_dim=len(batch_shape)) - u = self.u.view(*([1] * len(batch_shape)), *self.u.shape).to(z) - w = self.w.view(*([1] * len(batch_shape)), *self.w.shape).to(z) - b = self.b.view(*([1] * len(batch_shape)), *self.b.shape).to(z) - wzpb = torch.einsum('...ij,...j->...i', w, z_flat) + b # (..., m) + # Prepare parameters + q = self.q.mat()[:, :self.m] + r = 
self.r.mat() + r_tilde = self.r_tilde.mat() + + u = torch.einsum('...ij,...jk->...ik', q, r) + u = u.view(*([1] * len(batch_shape)), *u.shape).to(z) - x = z_flat + torch.einsum( - '...ij,...j->...i', - u, - self.h(wzpb) - ) + w = torch.einsum('...ij,...kj->...ik', r_tilde, q) + w = w.view(*([1] * len(batch_shape)), *w.shape).to(z) + b = self.b.view(*([1] * len(batch_shape)), *self.b.shape).to(z) + + # Intermediate computations + wzpb = torch.einsum('...ij,...j->...i', w, z_flat) + b # (..., m) + h = torch.sigmoid(wzpb) + h_deriv = h * (1 - h) wu = torch.einsum('...ij,...jk->...ik', w, u) # (..., m, m) + + # diag = torch.diag(h_deriv)[[None] * len(batch_shape)].repeat(*batch_shape, 1, 1) diag = torch.zeros(size=(*batch_shape, self.m, self.m)).to(z) - diag[..., range(self.m), range(self.m)] = self.h_deriv(wzpb) # (..., m, m) - _, log_det = torch.linalg.slogdet(torch.eye(self.m).to(z) + torch.einsum('...ij,...jk->...ik', diag, wu)) + diag[..., range(self.m), range(self.m)] = h_deriv # (..., m, m) + # Compute the log determinant and output + _, log_det = torch.linalg.slogdet(torch.eye(self.m).to(z) + torch.einsum('...ij,...jk->...ik', diag, wu)) + x = z_flat + torch.einsum('...ij,...j->...i', u, h) x = x.view(*batch_shape, *self.event_shape) return x, log_det diff --git a/torchflows/bijections/matrices.py b/torchflows/bijections/matrices.py index 4c0c980..6f9508b 100644 --- a/torchflows/bijections/matrices.py +++ b/torchflows/bijections/matrices.py @@ -44,9 +44,8 @@ def __init__(self, n_dim: int, unitriangular: bool = False, min_eigval: float = self.unc_diagonal_elements = None else: self.unc_diagonal_elements = nn.Parameter(torch.zeros(self.n_dim)) - self.off_diagonal_indices = torch.tril_indices(self.n_dim, self.n_dim, -1) self.min_eigval = min_eigval - + self.register_buffer('off_diagonal_indices', torch.tril_indices(self.n_dim, self.n_dim, -1)) self.register_buffer('mat_zeros', torch.zeros(size=(self.n_dim, self.n_dim))) def mat(self): diff --git a/torchflows/flows.py b/torchflows/flows.py index e6be04c..228d6db 100644 --- a/torchflows/flows.py +++ b/torchflows/flows.py @@ -341,13 +341,16 @@ def variational_fit(self, if not epoch_diverged: loss.backward() optimizer.step() - if loss < best_loss: - best_loss = loss - best_epoch = epoch - if keep_best_weights: - best_weights = deepcopy(self.state_dict()) - mean_flow_log_prob = flow_log_prob.mean() - mean_target_log_prob = target_log_prob_value.mean() + + if not epoch_diverged: + with torch.no_grad(): + if loss < best_loss: + best_loss = loss.detach() + best_epoch = epoch + if keep_best_weights: + best_weights = deepcopy(self.state_dict()) + mean_flow_log_prob = flow_log_prob.mean().detach() + mean_target_log_prob = target_log_prob_value.mean().detach() else: loss = torch.nan mean_flow_log_prob = torch.nan From 6eec3dce8cb9e075eb6a1e129e54bbdf648c30a2 Mon Sep 17 00:00:00 2001 From: David Nabergoj Date: Tue, 26 Nov 2024 23:36:00 +0100 Subject: [PATCH 2/4] Rework matrix bijections --- test/test_autograd_bijections.py | 34 ++- test/test_fit.py | 16 +- test/test_reconstruction_bijections.py | 13 +- torchflows/bijections/finite/__matrix.py | 235 ++++++++++++++++++ .../finite/autoregressive/architectures.py | 4 +- torchflows/bijections/finite/linear.py | 87 ------- .../bijections/finite/matrix/__init__.py | 7 + torchflows/bijections/finite/matrix/base.py | 54 ++++ .../bijections/finite/matrix/decomposition.py | 42 ++++ .../bijections/finite/matrix/identity.py | 18 ++ .../bijections/finite/matrix/orthogonal.py | 43 ++++ 
.../bijections/finite/matrix/permutation.py | 37 +++ .../bijections/finite/matrix/triangular.py | 87 +++++++ torchflows/bijections/finite/matrix/util.py | 10 + .../bijections/finite/residual/sylvester.py | 113 +++++++-- torchflows/bijections/matrices.py | 175 ------------- 16 files changed, 668 insertions(+), 307 deletions(-) create mode 100644 torchflows/bijections/finite/__matrix.py delete mode 100644 torchflows/bijections/finite/linear.py create mode 100644 torchflows/bijections/finite/matrix/__init__.py create mode 100644 torchflows/bijections/finite/matrix/base.py create mode 100644 torchflows/bijections/finite/matrix/decomposition.py create mode 100644 torchflows/bijections/finite/matrix/identity.py create mode 100644 torchflows/bijections/finite/matrix/orthogonal.py create mode 100644 torchflows/bijections/finite/matrix/permutation.py create mode 100644 torchflows/bijections/finite/matrix/triangular.py create mode 100644 torchflows/bijections/finite/matrix/util.py delete mode 100644 torchflows/bijections/matrices.py diff --git a/test/test_autograd_bijections.py b/test/test_autograd_bijections.py index 81ac954..a8ec84b 100644 --- a/test/test_autograd_bijections.py +++ b/test/test_autograd_bijections.py @@ -8,8 +8,10 @@ from torchflows.bijections.finite.autoregressive.architectures import NICE, RealNVP, CouplingRQNSF, MAF, IAF, \ InverseAutoregressiveRQNSF, MaskedAutoregressiveRQNSF from torchflows.bijections.finite.autoregressive.layers import ElementwiseScale, ElementwiseAffine, ElementwiseShift, \ - LRSCoupling, LinearRQSCoupling -from torchflows.bijections.finite.linear import LU, ReversePermutation, LowerTriangular, Orthogonal, QR + LRSCoupling, LinearRQSCoupling, ElementwiseRQSpline +from torchflows.bijections.finite.matrix import HouseholderOrthogonalMatrix, LowerTriangularInvertibleMatrix, \ + UpperTriangularInvertibleMatrix, IdentityMatrix, RandomPermutationMatrix, ReversePermutationMatrix, QRMatrix, \ + LUMatrix from torchflows.bijections.finite.residual.architectures import InvertibleResNet, ResFlow, ProximalResFlow from torchflows.utils import get_batch_shape from test.constants import __test_constants @@ -43,19 +45,33 @@ def assert_valid_log_probability_gradient(bijection: Bijection, x: torch.Tensor, @pytest.mark.parametrize('bijection_class', [ - LU, - ReversePermutation, ElementwiseScale, - LowerTriangular, - Orthogonal, - QR, ElementwiseAffine, - ElementwiseShift + ElementwiseShift, + ElementwiseRQSpline ]) @pytest.mark.parametrize('batch_shape', __test_constants['batch_shape']) @pytest.mark.parametrize('event_shape', __test_constants['event_shape']) @pytest.mark.parametrize('context_shape', __test_constants['context_shape']) -def test_linear(bijection_class: Bijection, batch_shape: Tuple, event_shape: Tuple, context_shape: Tuple): +def test_elementwise(bijection_class: Bijection, batch_shape: Tuple, event_shape: Tuple, context_shape: Tuple): + bijection, x, context = setup_data(bijection_class, batch_shape, event_shape, context_shape) + assert_valid_log_probability_gradient(bijection, x, context) + + +@pytest.mark.parametrize('bijection_class', [ + IdentityMatrix, + RandomPermutationMatrix, + ReversePermutationMatrix, + LowerTriangularInvertibleMatrix, + UpperTriangularInvertibleMatrix, + HouseholderOrthogonalMatrix, + QRMatrix, + LUMatrix, +]) +@pytest.mark.parametrize('batch_shape', __test_constants['batch_shape']) +@pytest.mark.parametrize('event_shape', __test_constants['event_shape']) +@pytest.mark.parametrize('context_shape', __test_constants['context_shape']) 
+def test_matrix(bijection_class: Bijection, batch_shape: Tuple, event_shape: Tuple, context_shape: Tuple): bijection, x, context = setup_data(bijection_class, batch_shape, event_shape, context_shape) assert_valid_log_probability_gradient(bijection, x, context) diff --git a/test/test_fit.py b/test/test_fit.py index 0a80100..3791a81 100644 --- a/test/test_fit.py +++ b/test/test_fit.py @@ -6,15 +6,15 @@ MaskedAutoregressiveRQNSF from torchflows.bijections.finite.autoregressive.layers import ElementwiseScale, ElementwiseAffine, ElementwiseShift, \ ElementwiseRQSpline -from torchflows.bijections.finite.linear import LowerTriangular, LU, QR +from torchflows.bijections.finite.matrix import LowerTriangularInvertibleMatrix, LUMatrix, QRMatrix @pytest.mark.skip(reason='Takes too long, fit quality is architecture-dependent') @pytest.mark.parametrize('bijection_class', [ - LowerTriangular, + LowerTriangularInvertibleMatrix, ElementwiseScale, - LU, - QR, + LUMatrix, + QRMatrix, ElementwiseAffine, ElementwiseShift, ElementwiseRQSpline, @@ -83,9 +83,9 @@ def test_diagonal_gaussian_elementwise_scale(): @pytest.mark.skip(reason='Takes too long, fit quality is architecture-dependent') @pytest.mark.parametrize('bijection_class', [ - LowerTriangular, - LU, - QR, + LowerTriangularInvertibleMatrix, + LUMatrix, + QRMatrix, MaskedAutoregressiveRQNSF, ElementwiseRQSpline, ElementwiseAffine, @@ -102,7 +102,7 @@ def test_diagonal_gaussian_1(bijection_class): x = torch.randn(size=(n_data, n_dim)) * sigma bijection = bijection_class(event_shape=(n_dim,)) flow = Flow(bijection) - if isinstance(bijection, LowerTriangular): + if isinstance(bijection, LowerTriangularInvertibleMatrix): flow.fit(x, n_epochs=100) else: flow.fit(x, n_epochs=25) diff --git a/test/test_reconstruction_bijections.py b/test/test_reconstruction_bijections.py index 2aaaffd..f14f706 100644 --- a/test/test_reconstruction_bijections.py +++ b/test/test_reconstruction_bijections.py @@ -12,7 +12,8 @@ InverseAutoregressiveRQNSF, MaskedAutoregressiveRQNSF from torchflows.bijections.finite.autoregressive.layers import ElementwiseScale, ElementwiseAffine, ElementwiseShift, \ LRSCoupling, LinearRQSCoupling, ActNorm, DenseSigmoidalCoupling, DeepDenseSigmoidalCoupling, DeepSigmoidalCoupling -from torchflows.bijections.finite.linear import LU, ReversePermutation, LowerTriangular, Orthogonal, QR +from torchflows.bijections.finite.matrix import LUMatrix, ReversePermutationMatrix, LowerTriangularInvertibleMatrix, \ + HouseholderOrthogonalMatrix, QRMatrix from torchflows.bijections.finite.residual.architectures import ResFlow, InvertibleResNet, ProximalResFlow from torchflows.bijections.finite.residual.iterative import InvertibleResNetBlock, ResFlowBlock from torchflows.bijections.finite.residual.planar import Planar @@ -122,12 +123,12 @@ def assert_valid_reconstruction_continuous(bijection: ContinuousBijection, @pytest.mark.parametrize('bijection_class', [ - LU, - ReversePermutation, + LUMatrix, + ReversePermutationMatrix, ElementwiseScale, - LowerTriangular, - Orthogonal, - QR, + LowerTriangularInvertibleMatrix, + HouseholderOrthogonalMatrix, + QRMatrix, ElementwiseAffine, ElementwiseShift, ActNorm diff --git a/torchflows/bijections/finite/__matrix.py b/torchflows/bijections/finite/__matrix.py new file mode 100644 index 0000000..6fc6635 --- /dev/null +++ b/torchflows/bijections/finite/__matrix.py @@ -0,0 +1,235 @@ +# import math +# from typing import Union, Tuple +# +# import torch +# import torch.nn as nn +# +# from torchflows.bijections.base import Bijection +# 
from torchflows.utils import get_batch_shape +# +# +# class InvertibleMatrix(Bijection): +# """ +# Invertible matrix bijection (currently ignores context). +# """ +# +# def __init__(self, event_shape: Union[torch.Size, Tuple[int, ...]], **kwargs): +# super().__init__(event_shape, **kwargs) +# self.register_buffer('device_buffer', torch.zeros(1)) +# +# def forward(self, x: torch.Tensor, context: torch.Tensor = None) -> Tuple[torch.Tensor, torch.Tensor]: +# batch_shape = get_batch_shape(x, self.event_shape) +# x_flat = x.view(*batch_shape, -1) +# context_flat = context.view(*batch_shape, -1) if context is not None else None +# z_flat = self.project_flat(x_flat, context_flat) +# z = z_flat.view_as(x) +# log_det = self.log_det_project()[[None] * len(batch_shape)].repeat(*batch_shape) +# return z, log_det +# +# def inverse(self, z: torch.Tensor, context: torch.Tensor = None) -> Tuple[torch.Tensor, torch.Tensor]: +# batch_shape = get_batch_shape(z, self.event_shape) +# z_flat = z.view(*batch_shape, -1) +# context_flat = context.view(*batch_shape, -1) if context is not None else None +# x_flat = self.solve_flat(z_flat, context_flat) +# x = x_flat.view_as(z) +# log_det = -self.log_det_project()[[None] * len(batch_shape)].repeat(*batch_shape) +# return x, log_det +# +# def project_flat(self, x_flat: torch.Tensor, context_flat: torch.Tensor = None) -> torch.Tensor: +# raise NotImplementedError +# +# def solve_flat(self, b_flat: torch.Tensor, context: torch.Tensor = None) -> torch.Tensor: +# """ +# Find x in Ax = b where b is given and A is this matrix. +# +# :param b_flat: shift tensor with shape (self.n_dim,) +# :param context: +# :return: +# """ +# raise NotImplementedError +# +# def log_det_project(self) -> torch.Tensor: +# """ +# +# :return: log abs det jac of f where f(x) = Ax and A is this matrix. +# """ +# raise NotImplementedError +# +# +# class LowerTriangularInvertibleMatrix(InvertibleMatrix): +# """ +# Lower triangular matrix with strictly positive diagonal values. 
+# """ +# +# def __init__(self, +# event_shape: Union[torch.Size, Tuple[int, ...]], +# unitriangular: bool = False, +# min_eigval: float = 1e-3, +# **kwargs): +# super().__init__(event_shape, **kwargs) +# self.unitriangular = unitriangular +# self.min_eigval = min_eigval +# +# self.min_eigval = min_eigval +# self.log_min_eigval = math.log(min_eigval) +# +# self.off_diagonal_indices = torch.tril_indices(self.n_dim, self.n_dim, -1) +# self.register_parameter( +# 'off_diagonal_elements', +# nn.Parameter( +# torch.randn((self.n_dim ** 2 - self.n_dim) // 2) / self.n_dim ** 2 +# ) +# ) +# if not unitriangular: +# self.register_parameter('unc_diagonal_elements', nn.Parameter(torch.zeros(self.n_dim))) +# +# def compute_tril_matrix(self): +# if self.unitriangular: +# mat = torch.eye(self.n_dim) +# else: +# mat = torch.diag(torch.exp(self.unc_diagonal_elements) + self.min_eigval) +# mat[self.off_diagonal_indices[0], self.off_diagonal_indices[1]] = self.off_diagonal_elements +# return mat.to(self.device_buffer.device) +# +# def project_flat(self, x_flat: torch.Tensor, context_flat: torch.Tensor = None) -> torch.Tensor: +# return torch.einsum('...ij,...j->...i', self.compute_tril_matrix(), x_flat) +# +# def solve_flat(self, b_flat: torch.Tensor, context: torch.Tensor = None) -> torch.Tensor: +# return torch.linalg.solve_triangular( +# self.compute_tril_matrix(), +# b_flat[None].T.to(self.device_buffer.device), +# upper=False, +# unitriangular=self.unitriangular +# ).T +# +# def log_det_project(self) -> torch.Tensor: +# return torch.logaddexp( +# self.unc_diagonal_elements, +# self.logmin_eigval * torch.ones_like(self.unc_diag_elements) +# ).sum() +# +# +# class UpperTriangularInvertibleMatrix(InvertibleMatrix): +# def __init__(self, event_shape: Union[torch.Size, Tuple[int, ...]], **kwargs): +# super().__init__(event_shape, **kwargs) +# self.lower = LowerTriangularInvertibleMatrix(event_shape=event_shape, **kwargs) +# +# def project_flat(self, x_flat: torch.Tensor, context_flat: torch.Tensor = None) -> torch.Tensor: +# return torch.einsum('...ij,...j->...i', self.lower.compute_tril_matrix().T, x_flat) +# +# def solve_flat(self, b_flat: torch.Tensor, context: torch.Tensor = None) -> torch.Tensor: +# return torch.linalg.solve_triangular( +# self.lower.compute_tril_matrix().T, +# b_flat[None].T.to(self.device_buffer.device), +# upper=True, +# unitriangular=self.unitriangular +# ).T +# +# def log_det_project(self) -> torch.Tensor: +# return self.lower.log_det_project() +# +# +# class HouseholderOrthogonalMatrix(InvertibleMatrix): +# def __init__(self, event_shape: Union[torch.Size, Tuple[int, ...]], n_factors: int = None, **kwargs): +# super().__init__(event_shape, **kwargs) +# if n_factors is None: +# n_factors = min(5, self.n_dim) +# assert 1 <= n_factors <= self.n_dim +# +# self.v = nn.Parameter(torch.randn(n_factors, self.n_dim) / self.n_dim ** 2 + torch.eye(n_factors, self.n_dim)) +# +# def project_flat(self, x_flat: torch.Tensor, context_flat: torch.Tensor = None) -> torch.Tensor: +# batch_shape = x_flat.shape[:-1] +# z_flat = x_flat.clone() # (*batch_shape, self.n_dim) +# for i in range(self.v.shape[0]): # Apply each Householder transformation in reverse order +# v = self.v[i] # (self.n_dim,) +# alpha = 2 * torch.einsum('i,...i->...', v, z_flat)[..., None] # (*batch_shape, 1) +# v = v[[None] * len(batch_shape)] # (1, ..., 1, self.n_dim) with len(v.shape) == len(batch_shape) + 1 +# z_flat = z_flat - alpha * v +# return z_flat +# +# def solve_flat(self, b_flat: torch.Tensor, context: torch.Tensor = 
None) -> torch.Tensor: +# # Same code as project, just the reverse matrix order +# batch_shape = b_flat.shape[:-1] +# x_flat = b_flat.clone() # (*batch_shape, self.n_dim) +# for i in range(self.v.shape[0] - 1, -1, -1): # Apply each Householder transformation in reverse order +# v = self.v[i] # (self.n_dim,) +# alpha = 2 * torch.einsum('i,...i->...', v, x_flat)[..., None] # (*batch_shape, 1) +# v = v[[None] * len(batch_shape)] # (1, ..., 1, self.n_dim) with len(v.shape) == len(batch_shape) + 1 +# x_flat = x_flat - alpha * v +# return x_flat +# +# def log_det_project(self) -> torch.Tensor: +# return torch.tensor(0.0).to(self.device_buffer.device) +# +# +# class IdentityMatrix(InvertibleMatrix): +# def __init__(self, event_shape: Union[torch.Size, Tuple[int, ...]], **kwargs): +# super().__init__(event_shape, **kwargs) +# +# def project_flat(self, x_flat: torch.Tensor, context_flat: torch.Tensor = None) -> torch.Tensor: +# return x_flat +# +# def solve_flat(self, b_flat: torch.Tensor, context: torch.Tensor = None) -> torch.Tensor: +# return b_flat +# +# def log_det_project(self): +# return torch.tensor(0.0).to(self.device_buffer.device) +# +# +# class PermutationMatrix(InvertibleMatrix): +# def __init__(self, +# event_shape: Union[torch.Size, Tuple[int, ...]], +# forward_permutation: torch.Tensor, +# **kwargs): +# super().__init__(event_shape, **kwargs) +# assert forward_permutation.shape == event_shape +# self.forward_permutation = forward_permutation.view(-1) +# self.inverse_permutation = torch.empty_like(self.forward_permutation) +# self.inverse_permutation[self.forward_permutation] = torch.arange(self.n_dim) +# +# +# class RandomPermutationMatrix(PermutationMatrix): +# def __init__(self, event_shape: Union[torch.Size, Tuple[int, ...]], **kwargs): +# n_dim = int(torch.prod(torch.as_tensor(event_shape))) +# super().__init__(event_shape, forward_permutation=torch.randperm(n_dim).view(*event_shape), **kwargs) +# +# +# class ReversePermutationMatrix(PermutationMatrix): +# def __init__(self, event_shape: Union[torch.Size, Tuple[int, ...]], **kwargs): +# n_dim = int(torch.prod(torch.as_tensor(event_shape))) +# super().__init__(event_shape, forward_permutation=torch.arange(n_dim)[::-1].view(*event_shape), **kwargs) +# +# +# class LUMatrix(InvertibleMatrix): +# def __init__(self, event_shape: Union[torch.Size, Tuple[int, ...]], **kwargs): +# super().__init__(event_shape, **kwargs) +# self.lower = LowerTriangularInvertibleMatrix(self.n_dim, unitriangular=True, **kwargs) +# self.upper = UpperTriangularInvertibleMatrix(self.n_dim, **kwargs) +# +# def project_flat(self, x_flat: torch.Tensor, context_flat: torch.Tensor = None) -> torch.Tensor: +# return self.lower.project_flat(self.upper.project_flat(x_flat)) +# +# def solve_flat(self, b_flat: torch.Tensor, context: torch.Tensor = None) -> torch.Tensor: +# return self.upper.solve_flat(self.lower.solve_flat(b_flat)) +# +# def log_det_project(self): +# return self.upper.logdet_project() +# +# +# class QRMatrix(InvertibleMatrix): +# def __init__(self, event_shape: Union[torch.Size, Tuple[int, ...]], **kwargs): +# super().__init__(event_shape, **kwargs) +# self.orthogonal = HouseholderOrthogonalMatrix(self.n_dim, **kwargs) +# self.upper = UpperTriangularInvertibleMatrix(self.n_dim, **kwargs) +# +# def project_flat(self, x_flat: torch.Tensor, context_flat: torch.Tensor = None) -> torch.Tensor: +# return self.orthogonal.project_flat(self.upper.project_flat(x_flat)) +# +# def solve_flat(self, b_flat: torch.Tensor, context: torch.Tensor = None) -> torch.Tensor: 
+# w_flat = self.orthogonal.project_inverse_flat(b_flat) +# x_flat = self.upper.solve_flat(w_flat) +# return x_flat +# +# def log_det_project(self): +# return self.upper.log_det_project() diff --git a/torchflows/bijections/finite/autoregressive/architectures.py b/torchflows/bijections/finite/autoregressive/architectures.py index 97cb0da..c5bb155 100644 --- a/torchflows/bijections/finite/autoregressive/architectures.py +++ b/torchflows/bijections/finite/autoregressive/architectures.py @@ -25,7 +25,7 @@ from torchflows.bijections.base import BijectiveComposition from torchflows.bijections.finite.autoregressive.layers_base import CouplingBijection, \ MaskedAutoregressiveBijection, InverseMaskedAutoregressiveBijection -from torchflows.bijections.finite.linear import ReversePermutation +from torchflows.bijections.finite.matrix.permutation import ReversePermutationMatrix class AutoregressiveArchitecture(BijectiveComposition): @@ -45,7 +45,7 @@ def __init__(self, bijections = [ElementwiseAffine(event_shape=event_shape)] for _ in range(n_layers): if 'edge_list' not in kwargs or kwargs['edge_list'] is None: - bijections.append(ReversePermutation(event_shape=event_shape)) + bijections.append(ReversePermutationMatrix(event_shape=event_shape)) bijections.append(base_bijection(event_shape=event_shape, **kwargs)) bijections.append(ActNorm(event_shape=event_shape)) bijections.append(ElementwiseAffine(event_shape=event_shape)) diff --git a/torchflows/bijections/finite/linear.py b/torchflows/bijections/finite/linear.py deleted file mode 100644 index 26b8fcd..0000000 --- a/torchflows/bijections/finite/linear.py +++ /dev/null @@ -1,87 +0,0 @@ -import torch - -from typing import Tuple, Union - -from torchflows.bijections.base import Bijection -from torchflows.bijections.matrices import ( - LowerTriangularInvertibleMatrix, - HouseholderOrthogonalMatrix, - InvertibleMatrix, - PermutationMatrix, - LUMatrix, - QRMatrix -) -from torchflows.utils import get_batch_shape, flatten_event, unflatten_event, flatten_batch, unflatten_batch - - -class Identity(Bijection): - def __init__(self, event_shape: Union[torch.Size, Tuple[int, ...]], **kwargs): - super().__init__(event_shape, **kwargs) - - def forward(self, x: torch.Tensor, context: torch.Tensor = None) -> Tuple[torch.Tensor, torch.Tensor]: - batch_shape = get_batch_shape(x, self.event_shape) - return x, torch.zeros(size=batch_shape, device=x.device) - - def inverse(self, z: torch.Tensor, context: torch.Tensor = None) -> Tuple[torch.Tensor, torch.Tensor]: - batch_shape = get_batch_shape(z, self.event_shape) - return z, torch.zeros(size=batch_shape, device=z.device) - - -class LinearBijection(Bijection): - def __init__(self, event_shape: Union[torch.Size, Tuple[int, ...]], matrix: InvertibleMatrix): - super().__init__(event_shape) - self.matrix = matrix - - def forward(self, x: torch.Tensor, context: torch.Tensor = None) -> Tuple[torch.Tensor, torch.Tensor]: - batch_shape = get_batch_shape(x, self.event_shape) - - x = flatten_batch(flatten_event(x, self.event_shape), batch_shape) # (n_batch_dims, n_event_dims) - z = self.matrix.project(x) - z = unflatten_batch(unflatten_event(z, self.event_shape), batch_shape) - - log_det = self.matrix.log_det() + torch.zeros(size=batch_shape, device=x.device).to(x) - return z, log_det - - def inverse(self, z: torch.Tensor, context: torch.Tensor = None) -> Tuple[torch.Tensor, torch.Tensor]: - batch_shape = get_batch_shape(z, self.event_shape) - - z = flatten_batch(flatten_event(z, self.event_shape), batch_shape) # (n_batch_dims, 
n_event_dims) - x = self.matrix.solve(z) - x = unflatten_batch(unflatten_event(x, self.event_shape), batch_shape) - - log_det = -self.matrix.log_det() + torch.zeros(size=batch_shape, device=z.device).to(z) - return x, log_det - - -class Permutation(LinearBijection): - def __init__(self, event_shape: Union[torch.Size, Tuple[int, ...]]): - super().__init__(event_shape, PermutationMatrix(int(torch.prod(torch.as_tensor(event_shape))))) - - -class ReversePermutation(LinearBijection): - def __init__(self, event_shape: Union[torch.Size, Tuple[int, ...]]): - matrix = PermutationMatrix(int(torch.prod(torch.as_tensor(event_shape)))) - matrix.forward_permutation = (matrix.n_dim - 1) - torch.arange(matrix.n_dim) - matrix.inverse_permutation = torch.empty_like(matrix.forward_permutation) - matrix.inverse_permutation[matrix.forward_permutation] = torch.arange(matrix.n_dim) - super().__init__(event_shape, matrix) - - -class LowerTriangular(LinearBijection): - def __init__(self, event_shape: Union[torch.Size, Tuple[int, ...]]): - super().__init__(event_shape, LowerTriangularInvertibleMatrix(int(torch.prod(torch.as_tensor(event_shape))))) - - -class LU(LinearBijection): - def __init__(self, event_shape: Union[torch.Size, Tuple[int, ...]]): - super().__init__(event_shape, LUMatrix(int(torch.prod(torch.as_tensor(event_shape))))) - - -class QR(LinearBijection): - def __init__(self, event_shape: Union[torch.Size, Tuple[int, ...]]): - super().__init__(event_shape, QRMatrix(int(torch.prod(torch.as_tensor(event_shape))))) - - -class Orthogonal(LinearBijection): - def __init__(self, event_shape: Union[torch.Size, Tuple[int, ...]]): - super().__init__(event_shape, HouseholderOrthogonalMatrix(int(torch.prod(torch.as_tensor(event_shape))))) diff --git a/torchflows/bijections/finite/matrix/__init__.py b/torchflows/bijections/finite/matrix/__init__.py new file mode 100644 index 0000000..01ee6f6 --- /dev/null +++ b/torchflows/bijections/finite/matrix/__init__.py @@ -0,0 +1,7 @@ +from torchflows.bijections.finite.matrix.identity import IdentityMatrix +from torchflows.bijections.finite.matrix.decomposition import LUMatrix, QRMatrix +from torchflows.bijections.finite.matrix.orthogonal import HouseholderOrthogonalMatrix +from torchflows.bijections.finite.matrix.permutation import ReversePermutationMatrix, RandomPermutationMatrix + +from torchflows.bijections.finite.matrix.triangular import UpperTriangularInvertibleMatrix, \ + LowerTriangularInvertibleMatrix diff --git a/torchflows/bijections/finite/matrix/base.py b/torchflows/bijections/finite/matrix/base.py new file mode 100644 index 0000000..286bc3c --- /dev/null +++ b/torchflows/bijections/finite/matrix/base.py @@ -0,0 +1,54 @@ +from typing import Union, Tuple + +import torch + +from torchflows.bijections.base import Bijection +from torchflows.utils import get_batch_shape + + +class InvertibleMatrix(Bijection): + """ + Invertible matrix bijection (currently ignores context). 
+ """ + + def __init__(self, event_shape: Union[torch.Size, Tuple[int, ...]], **kwargs): + super().__init__(event_shape, **kwargs) + self.register_buffer('device_buffer', torch.zeros(1)) + + def forward(self, x: torch.Tensor, context: torch.Tensor = None) -> Tuple[torch.Tensor, torch.Tensor]: + batch_shape = get_batch_shape(x, self.event_shape) + x_flat = x.view(*batch_shape, -1) + context_flat = context.view(*batch_shape, -1) if context is not None else None + z_flat = self.project_flat(x_flat, context_flat) + z = z_flat.view_as(x) + log_det = self.log_det_project()[[None] * len(batch_shape)].repeat(*batch_shape, 1).squeeze(-1) + return z, log_det + + def inverse(self, z: torch.Tensor, context: torch.Tensor = None) -> Tuple[torch.Tensor, torch.Tensor]: + batch_shape = get_batch_shape(z, self.event_shape) + z_flat = z.view(*batch_shape, -1) + context_flat = context.view(*batch_shape, -1) if context is not None else None + x_flat = self.solve_flat(z_flat, context_flat) + x = x_flat.view_as(z) + log_det = -self.log_det_project()[[None] * len(batch_shape)].repeat(*batch_shape, 1).squeeze(-1) + return x, log_det + + def project_flat(self, x_flat: torch.Tensor, context_flat: torch.Tensor = None) -> torch.Tensor: + raise NotImplementedError + + def solve_flat(self, b_flat: torch.Tensor, context: torch.Tensor = None) -> torch.Tensor: + """ + Find x in Ax = b where b is given and A is this matrix. + + :param b_flat: shift tensor with shape (self.n_dim,) + :param context: + :return: + """ + raise NotImplementedError + + def log_det_project(self) -> torch.Tensor: + """ + + :return: log abs det jac of f where f(x) = Ax and A is this matrix. + """ + raise NotImplementedError diff --git a/torchflows/bijections/finite/matrix/decomposition.py b/torchflows/bijections/finite/matrix/decomposition.py new file mode 100644 index 0000000..2343582 --- /dev/null +++ b/torchflows/bijections/finite/matrix/decomposition.py @@ -0,0 +1,42 @@ +from typing import Union, Tuple + +import torch + +from torchflows.bijections.finite.matrix.base import InvertibleMatrix +from torchflows.bijections.finite.matrix.orthogonal import HouseholderOrthogonalMatrix +from torchflows.bijections.finite.matrix.triangular import LowerTriangularInvertibleMatrix, \ + UpperTriangularInvertibleMatrix + + +class LUMatrix(InvertibleMatrix): + def __init__(self, event_shape: Union[torch.Size, Tuple[int, ...]], **kwargs): + super().__init__(event_shape, **kwargs) + self.lower = LowerTriangularInvertibleMatrix(self.n_dim, unitriangular=True, **kwargs) + self.upper = UpperTriangularInvertibleMatrix(self.n_dim, **kwargs) + + def project_flat(self, x_flat: torch.Tensor, context_flat: torch.Tensor = None) -> torch.Tensor: + return self.lower.project_flat(self.upper.project_flat(x_flat)) + + def solve_flat(self, b_flat: torch.Tensor, context: torch.Tensor = None) -> torch.Tensor: + return self.upper.solve_flat(self.lower.solve_flat(b_flat)) + + def log_det_project(self): + return self.upper.log_det_project() + + +class QRMatrix(InvertibleMatrix): + def __init__(self, event_shape: Union[torch.Size, Tuple[int, ...]], **kwargs): + super().__init__(event_shape, **kwargs) + self.orthogonal = HouseholderOrthogonalMatrix(self.n_dim, **kwargs) + self.upper = UpperTriangularInvertibleMatrix(self.n_dim, **kwargs) + + def project_flat(self, x_flat: torch.Tensor, context_flat: torch.Tensor = None) -> torch.Tensor: + return self.orthogonal.project_flat(self.upper.project_flat(x_flat)) + + def solve_flat(self, b_flat: torch.Tensor, context: torch.Tensor = None) -> 
torch.Tensor: + w_flat = self.orthogonal.solve_flat(b_flat) + x_flat = self.upper.solve_flat(w_flat) + return x_flat + + def log_det_project(self): + return self.upper.log_det_project() diff --git a/torchflows/bijections/finite/matrix/identity.py b/torchflows/bijections/finite/matrix/identity.py new file mode 100644 index 0000000..7c13baa --- /dev/null +++ b/torchflows/bijections/finite/matrix/identity.py @@ -0,0 +1,18 @@ +from typing import Union, Tuple + +import torch +from torchflows.bijections.finite.matrix.base import InvertibleMatrix + + +class IdentityMatrix(InvertibleMatrix): + def __init__(self, event_shape: Union[torch.Size, Tuple[int, ...]], **kwargs): + super().__init__(event_shape, **kwargs) + + def project_flat(self, x_flat: torch.Tensor, context_flat: torch.Tensor = None) -> torch.Tensor: + return x_flat + + def solve_flat(self, b_flat: torch.Tensor, context: torch.Tensor = None) -> torch.Tensor: + return b_flat + + def log_det_project(self): + return torch.zeros(1).to(self.device_buffer.device) \ No newline at end of file diff --git a/torchflows/bijections/finite/matrix/orthogonal.py b/torchflows/bijections/finite/matrix/orthogonal.py new file mode 100644 index 0000000..b0c9861 --- /dev/null +++ b/torchflows/bijections/finite/matrix/orthogonal.py @@ -0,0 +1,43 @@ +from typing import Union, Tuple + +import torch +import torch.nn as nn + +from torchflows.bijections.finite.matrix.base import InvertibleMatrix + + +class HouseholderOrthogonalMatrix(InvertibleMatrix): + def __init__(self, event_shape: Union[torch.Size, Tuple[int, ...]], n_factors: int = None, **kwargs): + super().__init__(event_shape, **kwargs) + if n_factors is None: + n_factors = min(5, self.n_dim) + assert 1 <= n_factors <= self.n_dim + + self.v = nn.Parameter(torch.randn(n_factors, self.n_dim) / self.n_dim ** 2 + torch.eye(n_factors, self.n_dim)) + + def project_flat(self, x_flat: torch.Tensor, context_flat: torch.Tensor = None) -> torch.Tensor: + batch_shape = x_flat.shape[:-1] + z_flat = x_flat.clone() # (*batch_shape, self.n_dim) + for i in range(self.v.shape[0]): # Apply each Householder transformation in reverse order + v = self.v[i] # (self.n_dim,) + alpha = 2 * torch.einsum('i,...i->...', v, z_flat)[..., None] # (*batch_shape, 1) + v = v[[None] * len(batch_shape)] # (1, ..., 1, self.n_dim) with len(v.shape) == len(batch_shape) + 1 + z_flat = z_flat - alpha * v + return z_flat + + def solve_flat(self, b_flat: torch.Tensor, context: torch.Tensor = None) -> torch.Tensor: + # Same code as project, just the reverse matrix order + batch_shape = b_flat.shape[:-1] + x_flat = b_flat.clone() # (*batch_shape, self.n_dim) + for i in range(self.v.shape[0] - 1, -1, -1): # Apply each Householder transformation in reverse order + v = self.v[i] # (self.n_dim,) + alpha = 2 * torch.einsum('i,...i->...', v, x_flat)[..., None] # (*batch_shape, 1) + v = v[[None] * len(batch_shape)] # (1, ..., 1, self.n_dim) with len(v.shape) == len(batch_shape) + 1 + x_flat = x_flat - alpha * v + return x_flat + + def log_det_project(self) -> torch.Tensor: + return torch.tensor(0.0).to(self.device_buffer.device) + + def __matmul__(self, other: torch.Tensor) -> torch.Tensor: + return self.project_flat(other) \ No newline at end of file diff --git a/torchflows/bijections/finite/matrix/permutation.py b/torchflows/bijections/finite/matrix/permutation.py new file mode 100644 index 0000000..cdce84c --- /dev/null +++ b/torchflows/bijections/finite/matrix/permutation.py @@ -0,0 +1,37 @@ +from typing import Union, Tuple + +import torch + +from 
torchflows.bijections.finite.matrix.base import InvertibleMatrix + + +class PermutationMatrix(InvertibleMatrix): + def __init__(self, + event_shape: Union[torch.Size, Tuple[int, ...]], + forward_permutation: torch.Tensor, + **kwargs): + super().__init__(event_shape, **kwargs) + assert forward_permutation.shape == event_shape + self.forward_permutation = forward_permutation.view(-1) + self.inverse_permutation = torch.empty_like(self.forward_permutation) + self.inverse_permutation[self.forward_permutation] = torch.arange(self.n_dim) + + def project_flat(self, x_flat: torch.Tensor, context_flat: torch.Tensor = None) -> torch.Tensor: + return x_flat[..., self.forward_permutation] + + def solve_flat(self, b_flat: torch.Tensor, context: torch.Tensor = None) -> torch.Tensor: + return b_flat[..., self.inverse_permutation] + + def log_det_project(self) -> torch.Tensor: + return torch.zeros(1).to(self.device_buffer.device) + +class RandomPermutationMatrix(PermutationMatrix): + def __init__(self, event_shape: Union[torch.Size, Tuple[int, ...]], **kwargs): + n_dim = int(torch.prod(torch.as_tensor(event_shape))) + super().__init__(event_shape, forward_permutation=torch.randperm(n_dim).view(*event_shape), **kwargs) + + +class ReversePermutationMatrix(PermutationMatrix): + def __init__(self, event_shape: Union[torch.Size, Tuple[int, ...]], **kwargs): + n_dim = int(torch.prod(torch.as_tensor(event_shape))) + super().__init__(event_shape, forward_permutation=torch.arange(n_dim - 1, -1, -1).view(*event_shape), **kwargs) diff --git a/torchflows/bijections/finite/matrix/triangular.py b/torchflows/bijections/finite/matrix/triangular.py new file mode 100644 index 0000000..1623689 --- /dev/null +++ b/torchflows/bijections/finite/matrix/triangular.py @@ -0,0 +1,87 @@ +from typing import Union, Tuple + +import torch +import math +import torch.nn as nn + +from torchflows.bijections.finite.matrix.base import InvertibleMatrix + + +class LowerTriangularInvertibleMatrix(InvertibleMatrix): + """ + Lower triangular matrix with strictly positive diagonal values. 
+ """ + + def __init__(self, + event_shape: Union[torch.Size, Tuple[int, ...]], + unitriangular: bool = False, + min_eigval: float = 1e-3, + **kwargs): + super().__init__(event_shape, **kwargs) + self.unitriangular = unitriangular + self.min_eigval = min_eigval + + self.min_eigval = min_eigval + self.log_min_eigval = math.log(min_eigval) + + self.off_diagonal_indices = torch.tril_indices(self.n_dim, self.n_dim, -1) + self.register_parameter( + 'off_diagonal_elements', + nn.Parameter( + torch.randn((self.n_dim ** 2 - self.n_dim) // 2) / self.n_dim ** 2 + ) + ) + if not unitriangular: + self.register_parameter('unc_diag_elements', nn.Parameter(torch.zeros(self.n_dim))) + + def compute_matrix(self): + if self.unitriangular: + mat = torch.eye(self.n_dim) + else: + mat = torch.diag(torch.exp(self.unc_diag_elements) + self.min_eigval) + mat[self.off_diagonal_indices[0], self.off_diagonal_indices[1]] = self.off_diagonal_elements + return mat.to(self.device_buffer.device) + + def project_flat(self, x_flat: torch.Tensor, context_flat: torch.Tensor = None) -> torch.Tensor: + return torch.einsum('...ij,...j->...i', self.compute_matrix(), x_flat) + + def solve_flat(self, b_flat: torch.Tensor, context: torch.Tensor = None) -> torch.Tensor: + b_flat_batch = b_flat.view(-1, b_flat.shape[-1]) + x_flat_batch = torch.linalg.solve_triangular( + self.compute_matrix(), + b_flat_batch.T.to(self.device_buffer.device), + upper=False, + unitriangular=self.unitriangular + ).T + return x_flat_batch.view_as(b_flat_batch) + + def log_det_project(self) -> torch.Tensor: + return torch.logaddexp( + self.unc_diag_elements, + self.log_min_eigval * torch.ones_like(self.unc_diag_elements) + ).sum() + + +class UpperTriangularInvertibleMatrix(InvertibleMatrix): + def __init__(self, event_shape: Union[torch.Size, Tuple[int, ...]], **kwargs): + super().__init__(event_shape, **kwargs) + self.lower = LowerTriangularInvertibleMatrix(event_shape=event_shape, **kwargs) + + def project_flat(self, x_flat: torch.Tensor, context_flat: torch.Tensor = None) -> torch.Tensor: + return torch.einsum('...ij,...j->...i', self.lower.compute_matrix().T, x_flat) + + def solve_flat(self, b_flat: torch.Tensor, context: torch.Tensor = None) -> torch.Tensor: + b_flat_batch = b_flat.view(-1, b_flat.shape[-1]) + x_flat_batch = torch.linalg.solve_triangular( + self.lower.compute_matrix().T, + b_flat_batch.T.to(self.device_buffer.device), + upper=True, + unitriangular=self.lower.unitriangular + ).T + return x_flat_batch.view_as(b_flat_batch) + + def compute_matrix(self): + return self.lower.compute_matrix().T + + def log_det_project(self) -> torch.Tensor: + return self.lower.log_det_project() \ No newline at end of file diff --git a/torchflows/bijections/finite/matrix/util.py b/torchflows/bijections/finite/matrix/util.py new file mode 100644 index 0000000..b815433 --- /dev/null +++ b/torchflows/bijections/finite/matrix/util.py @@ -0,0 +1,10 @@ +import torch + +from torchflows.bijections.finite.matrix import HouseholderOrthogonalMatrix + + +def matmul_with_householder(a: torch.Tensor, q: HouseholderOrthogonalMatrix): + product = a + for i in range(len(q.v)): + product = product - 2 * q.v[i][:, None] * torch.matmul(q.v[i][None], product) + return product diff --git a/torchflows/bijections/finite/residual/sylvester.py b/torchflows/bijections/finite/residual/sylvester.py index 22e3d21..9029967 100644 --- a/torchflows/bijections/finite/residual/sylvester.py +++ b/torchflows/bijections/finite/residual/sylvester.py @@ -2,10 +2,13 @@ import torch import torch.nn 
as nn +from Cython.Shadow import returns +from torchflows.bijections.finite.matrix import UpperTriangularInvertibleMatrix, IdentityMatrix, \ + HouseholderOrthogonalMatrix +from torchflows.bijections.finite.matrix.permutation import PermutationMatrix, RandomPermutationMatrix +from torchflows.bijections.finite.matrix.util import matmul_with_householder from torchflows.bijections.finite.residual.base import ClassicResidualBijection -from torchflows.bijections.matrices import UpperTriangularInvertibleMatrix, HouseholderOrthogonalMatrix, \ - IdentityMatrix, PermutationMatrix from torchflows.utils import get_batch_shape @@ -23,8 +26,16 @@ def __init__(self, self.m = m self.register_parameter('b', nn.Parameter(torch.randn(m))) - self.register_module('r', UpperTriangularInvertibleMatrix(n_dim=self.m)) - self.register_module('r_tilde', UpperTriangularInvertibleMatrix(n_dim=self.m)) + self.register_module('r', UpperTriangularInvertibleMatrix((m,))) + self.register_module('r_tilde', UpperTriangularInvertibleMatrix((m,))) + + def compute_u(self): + # u = Q * R + raise NotImplementedError + + def compute_w(self): + # w = R_tilde * Q.T + raise NotImplementedError def forward(self, x: torch.Tensor, context: torch.Tensor = None) -> Tuple[torch.Tensor, torch.Tensor]: raise ValueError("Sylvester bijection does not support forward computation.") @@ -34,20 +45,11 @@ def inverse(self, z: torch.Tensor, context: torch.Tensor = None) -> Tuple[torch. z_flat = torch.flatten(z, start_dim=len(batch_shape)) # Prepare parameters - q = self.q.mat()[:, :self.m] - r = self.r.mat() - r_tilde = self.r_tilde.mat() - - u = torch.einsum('...ij,...jk->...ik', q, r) - u = u.view(*([1] * len(batch_shape)), *u.shape).to(z) - - w = torch.einsum('...ij,...kj->...ik', r_tilde, q) - w = w.view(*([1] * len(batch_shape)), *w.shape).to(z) - - b = self.b.view(*([1] * len(batch_shape)), *self.b.shape).to(z) + u = self.compute_u() + w = self.compute_w() # Intermediate computations - wzpb = torch.einsum('...ij,...j->...i', w, z_flat) + b # (..., m) + wzpb = torch.einsum('ij,...j->...i', w, z_flat) + self.b[[None] * len(batch_shape)] # (..., m) h = torch.sigmoid(wzpb) h_deriv = h * (1 - h) wu = torch.einsum('...ij,...jk->...ik', w, u) # (..., m, m) @@ -67,19 +69,90 @@ def inverse(self, z: torch.Tensor, context: torch.Tensor = None) -> Tuple[torch. 
class HouseholderSylvester(BaseSylvester): def __init__(self, event_shape: Union[torch.Size, Tuple[int, ...]], **kwargs): super().__init__(event_shape, **kwargs) - self.register_module('q', HouseholderOrthogonalMatrix(n_dim=self.n_dim, n_factors=self.m)) + self.register_module('q', HouseholderOrthogonalMatrix(event_shape, n_factors=self.m)) + + def compute_u(self): + return self.q.project_flat(self.r.compute_matrix()) + + def compute_w(self): + # No need to transpose as Q is symmetric + return matmul_with_householder(self.r_tilde.compute_matrix(), self.q) class IdentitySylvester(BaseSylvester): def __init__(self, event_shape: Union[torch.Size, Tuple[int, ...]], **kwargs): super().__init__(event_shape, **kwargs) - self.register_module('q', IdentityMatrix(n_dim=self.n_dim)) + self.register_module('q', IdentityMatrix(event_shape)) + def compute_u(self): + r = self.r.compute_matrix() + return torch.concat([r, torch.zeros_like(r)], dim=-2) -Sylvester = IdentitySylvester + def compute_w(self): + rt = self.r_tilde.compute_matrix() + return torch.concat([rt, torch.zeros_like(rt)], dim=-1) class PermutationSylvester(BaseSylvester): def __init__(self, event_shape: Union[torch.Size, Tuple[int, ...]], **kwargs): super().__init__(event_shape, **kwargs) - self.register_module('q', PermutationMatrix(n_dim=self.n_dim)) + self.register_module('q', RandomPermutationMatrix(event_shape)) + + def compute_u(self): + return self.r.compute_matrix() + + def compute_w(self): + return self.r_tilde.compute_matrix() + + +# class IdentitySylvester(ClassicResidualBijection): +# def __init__(self, +# event_shape: Union[torch.Size, Tuple[int, ...]], +# m: int = None, +# **kwargs): +# super().__init__(event_shape, **kwargs) +# self.n_dim = int(torch.prod(torch.as_tensor(event_shape))) +# if m is None: +# m = self.n_dim // 2 +# if m > self.n_dim: +# raise ValueError +# self.m = m +# +# self.register_parameter('b', nn.Parameter(torch.randn(m))) +# self.register_module('r', UpperTriangularInvertibleMatrix(event_shape)) +# self.register_module('r_tilde', UpperTriangularInvertibleMatrix(event_shape)) +# +# def forward(self, x: torch.Tensor, context: torch.Tensor = None) -> Tuple[torch.Tensor, torch.Tensor]: +# raise ValueError("Sylvester bijection does not support forward computation.") +# +# def inverse(self, z: torch.Tensor, context: torch.Tensor = None) -> Tuple[torch.Tensor, torch.Tensor]: +# batch_shape = get_batch_shape(z, self.event_shape) +# z_flat = torch.flatten(z, start_dim=len(batch_shape)) +# +# # Prepare parameters +# q = torch.eye(self.n_dim, self.m) +# +# u = torch.einsum('...ij,...jk->...ik', q, self.r) +# u = u.view(*([1] * len(batch_shape)), *u.shape).to(z) +# w = torch.concat([self.r_tilde, torch.zeros_like(self.r_tilde)], dim=-1) +# w = w.view(*([1] * len(batch_shape)), *w.shape).to(z) +# b = self.b.view(*([1] * len(batch_shape)), *self.b.shape).to(z) +# +# # Intermediate computations +# wzpb = torch.einsum('...ij,...j->...i', w, z_flat) + b # (..., m) +# h = torch.sigmoid(wzpb) +# h_deriv = h * (1 - h) +# wu = torch.einsum('...ij,...jk->...ik', w, u) # (..., m, m) +# +# # diag = torch.diag(h_deriv)[[None] * len(batch_shape)].repeat(*batch_shape, 1, 1) +# diag = torch.zeros(size=(*batch_shape, self.m, self.m)).to(z) +# diag[..., range(self.m), range(self.m)] = h_deriv # (..., m, m) +# +# # Compute the log determinant and output +# _, log_det = torch.linalg.slogdet(torch.eye(self.m).to(z) + torch.einsum('...ij,...jk->...ik', diag, wu)) +# x = z_flat + torch.einsum('...ij,...j->...i', u, h) +# x = 
x.view(*batch_shape, *self.event_shape) +# +# return x, log_det + +Sylvester = IdentitySylvester diff --git a/torchflows/bijections/matrices.py b/torchflows/bijections/matrices.py deleted file mode 100644 index 6f9508b..0000000 --- a/torchflows/bijections/matrices.py +++ /dev/null @@ -1,175 +0,0 @@ -import torch -import torch.nn as nn - - -class InvertibleMatrix(nn.Module): - def __init__(self, n_dim: int, **kwargs): - super().__init__() - self.n_dim = n_dim - - def mat(self): - raise NotImplementedError - - def log_det(self): - raise NotImplementedError - - def project(self, x): - # Compute z = mat @ x - return torch.einsum('ij,...j->...i', self.mat(), x) - - def solve(self, z): - # Compute x = mat^-1 z - raise NotImplementedError - - -class LowerTriangularInvertibleMatrix(InvertibleMatrix): - """ - Lower triangular matrix with strictly positive diagonal values. - """ - - def __init__(self, n_dim: int, unitriangular: bool = False, min_eigval: float = 1e-3): - """ - - :param n_dim: - :param unitriangular: - :param min_eigval: minimum eigenvalue. This is added to - """ - super().__init__(n_dim) - self.unitriangular = unitriangular - - n_off_diagonal_elements = (self.n_dim ** 2 - self.n_dim) // 2 - initial_off_diagonal_elements = torch.randn(n_off_diagonal_elements) / self.n_dim ** 2 - self.off_diagonal_elements = nn.Parameter(initial_off_diagonal_elements) - if unitriangular: - self.unc_diagonal_elements = None - else: - self.unc_diagonal_elements = nn.Parameter(torch.zeros(self.n_dim)) - self.min_eigval = min_eigval - self.register_buffer('off_diagonal_indices', torch.tril_indices(self.n_dim, self.n_dim, -1)) - self.register_buffer('mat_zeros', torch.zeros(size=(self.n_dim, self.n_dim))) - - def mat(self): - mat = self.mat_zeros - mat[range(self.n_dim), range(self.n_dim)] = self.compute_diagonal_elements() - mat[self.off_diagonal_indices[0], self.off_diagonal_indices[1]] = self.off_diagonal_elements - return mat - - def compute_diagonal_elements(self): - if self.unitriangular: - return torch.ones(self.n_dim) - else: - return torch.exp(self.unc_diagonal_elements) + self.min_eigval - - def log_det(self): - return torch.sum(torch.log(self.compute_diagonal_elements())) - - def solve(self, z): - return torch.linalg.solve_triangular(self.mat(), z.T, upper=False, unitriangular=self.unitriangular).T - - -class UpperTriangularInvertibleMatrix(InvertibleMatrix): - def __init__(self, n_dim: int, **kwargs): - super().__init__(n_dim) - self.lower = LowerTriangularInvertibleMatrix(n_dim=n_dim, **kwargs) - - def mat(self): - return self.lower.mat().T - - def log_det(self): - return self.lower.log_det() - - def solve(self, z): - return torch.linalg.solve_triangular(self.mat(), z.T, upper=True, unitriangular=self.lower.unitriangular).T - - -class HouseholderOrthogonalMatrix(InvertibleMatrix): - # TODO more efficient project and solve? 
- def __init__(self, n_dim: int, n_factors: int = None): - super().__init__(n_dim=n_dim) - if n_factors is None: - n_factors = min(5, self.n_dim) - assert 1 <= n_factors <= self.n_dim - self.v = nn.Parameter(torch.randn(n_factors, self.n_dim) / self.n_dim ** 2 + torch.eye(n_factors, self.n_dim)) - self.tau = torch.full((n_factors,), fill_value=2.0) - - def mat(self): - v_outer = torch.einsum('fi,fj->fij', self.v, self.v) - v_norms_squared = torch.linalg.norm(self.v, dim=1).view(-1, 1, 1) ** 2 - h = (torch.eye(self.n_dim)[None].to(v_outer) - 2 * (v_outer / v_norms_squared)) - return torch.linalg.multi_dot(list(h)) - - def log_det(self): - return 0.0 - - def solve(self, z): - return (self.mat().T @ z.T).T - - -class IdentityMatrix(InvertibleMatrix): - def __init__(self, n_dim: int, **kwargs): - super().__init__(n_dim, **kwargs) - self.register_buffer('_mat', torch.eye(self.n_dim)) - - def mat(self): - return self._mat - - def log_det(self): - return 0.0 - - def project(self, x): - return x - - def solve(self, z): - return z - - -class PermutationMatrix(InvertibleMatrix): - def __init__(self, n_dim: int, **kwargs): - super().__init__(n_dim, **kwargs) - self.forward_permutation = torch.randperm(n_dim) - self.inverse_permutation = torch.empty_like(self.forward_permutation) - self.inverse_permutation[self.forward_permutation] = torch.arange(n_dim) - - def mat(self): - return torch.eye(self.n_dim)[self.forward_permutation] - - def log_det(self): - return 0.0 - - def project(self, x): - return x[..., self.forward_permutation] - - def solve(self, z): - return z[..., self.inverse_permutation] - - -class LUMatrix(InvertibleMatrix): - def __init__(self, n_dim: int, **kwargs): - super().__init__(n_dim) - self.lower = LowerTriangularInvertibleMatrix(n_dim, unitriangular=True, **kwargs) - self.upper = UpperTriangularInvertibleMatrix(n_dim, **kwargs) - - def mat(self): - return self.lower.mat() @ self.upper.mat() - - def log_det(self): - return self.upper.log_det() - - def solve(self, z): - return self.upper.solve(self.lower.solve(z)) - - -class QRMatrix(InvertibleMatrix): - def __init__(self, n_dim: int, **kwargs): - super().__init__(n_dim) - self.orthogonal = HouseholderOrthogonalMatrix(self.n_dim, **kwargs) - self.upper = UpperTriangularInvertibleMatrix(n_dim, **kwargs) - - def mat(self): - return self.orthogonal.mat() @ self.upper.mat() - - def solve(self, z): - return self.upper.solve(self.orthogonal.solve(z)) - - def log_det(self): - return self.upper.log_det() From 3273ce804db44f1f0466a808dc40d45cfe4d6a9b Mon Sep 17 00:00:00 2001 From: David Nabergoj Date: Wed, 27 Nov 2024 00:07:49 +0100 Subject: [PATCH 3/4] Rename Householder matrix object, fix project and solve methods --- test/test_autograd_bijections.py | 4 +- test/test_reconstruction_bijections.py | 4 +- .../bijections/finite/matrix/__init__.py | 2 +- .../bijections/finite/matrix/decomposition.py | 4 +- .../bijections/finite/matrix/orthogonal.py | 37 +++++++++---------- torchflows/bijections/finite/matrix/util.py | 4 +- .../bijections/finite/residual/sylvester.py | 4 +- 7 files changed, 29 insertions(+), 30 deletions(-) diff --git a/test/test_autograd_bijections.py b/test/test_autograd_bijections.py index a8ec84b..68ef9ed 100644 --- a/test/test_autograd_bijections.py +++ b/test/test_autograd_bijections.py @@ -9,7 +9,7 @@ InverseAutoregressiveRQNSF, MaskedAutoregressiveRQNSF from torchflows.bijections.finite.autoregressive.layers import ElementwiseScale, ElementwiseAffine, ElementwiseShift, \ LRSCoupling, LinearRQSCoupling, ElementwiseRQSpline 
-from torchflows.bijections.finite.matrix import HouseholderOrthogonalMatrix, LowerTriangularInvertibleMatrix, \ +from torchflows.bijections.finite.matrix import HouseholderProductMatrix, LowerTriangularInvertibleMatrix, \ UpperTriangularInvertibleMatrix, IdentityMatrix, RandomPermutationMatrix, ReversePermutationMatrix, QRMatrix, \ LUMatrix from torchflows.bijections.finite.residual.architectures import InvertibleResNet, ResFlow, ProximalResFlow @@ -64,7 +64,7 @@ def test_elementwise(bijection_class: Bijection, batch_shape: Tuple, event_shape ReversePermutationMatrix, LowerTriangularInvertibleMatrix, UpperTriangularInvertibleMatrix, - HouseholderOrthogonalMatrix, + HouseholderProductMatrix, QRMatrix, LUMatrix, ]) diff --git a/test/test_reconstruction_bijections.py b/test/test_reconstruction_bijections.py index f14f706..8fa8da3 100644 --- a/test/test_reconstruction_bijections.py +++ b/test/test_reconstruction_bijections.py @@ -13,7 +13,7 @@ from torchflows.bijections.finite.autoregressive.layers import ElementwiseScale, ElementwiseAffine, ElementwiseShift, \ LRSCoupling, LinearRQSCoupling, ActNorm, DenseSigmoidalCoupling, DeepDenseSigmoidalCoupling, DeepSigmoidalCoupling from torchflows.bijections.finite.matrix import LUMatrix, ReversePermutationMatrix, LowerTriangularInvertibleMatrix, \ - HouseholderOrthogonalMatrix, QRMatrix + HouseholderProductMatrix, QRMatrix from torchflows.bijections.finite.residual.architectures import ResFlow, InvertibleResNet, ProximalResFlow from torchflows.bijections.finite.residual.iterative import InvertibleResNetBlock, ResFlowBlock from torchflows.bijections.finite.residual.planar import Planar @@ -127,7 +127,7 @@ def assert_valid_reconstruction_continuous(bijection: ContinuousBijection, ReversePermutationMatrix, ElementwiseScale, LowerTriangularInvertibleMatrix, - HouseholderOrthogonalMatrix, + HouseholderProductMatrix, QRMatrix, ElementwiseAffine, ElementwiseShift, diff --git a/torchflows/bijections/finite/matrix/__init__.py b/torchflows/bijections/finite/matrix/__init__.py index 01ee6f6..455ae91 100644 --- a/torchflows/bijections/finite/matrix/__init__.py +++ b/torchflows/bijections/finite/matrix/__init__.py @@ -1,6 +1,6 @@ from torchflows.bijections.finite.matrix.identity import IdentityMatrix from torchflows.bijections.finite.matrix.decomposition import LUMatrix, QRMatrix -from torchflows.bijections.finite.matrix.orthogonal import HouseholderOrthogonalMatrix +from torchflows.bijections.finite.matrix.orthogonal import HouseholderProductMatrix from torchflows.bijections.finite.matrix.permutation import ReversePermutationMatrix, RandomPermutationMatrix from torchflows.bijections.finite.matrix.triangular import UpperTriangularInvertibleMatrix, \ diff --git a/torchflows/bijections/finite/matrix/decomposition.py b/torchflows/bijections/finite/matrix/decomposition.py index 2343582..35efd30 100644 --- a/torchflows/bijections/finite/matrix/decomposition.py +++ b/torchflows/bijections/finite/matrix/decomposition.py @@ -3,7 +3,7 @@ import torch from torchflows.bijections.finite.matrix.base import InvertibleMatrix -from torchflows.bijections.finite.matrix.orthogonal import HouseholderOrthogonalMatrix +from torchflows.bijections.finite.matrix.orthogonal import HouseholderProductMatrix from torchflows.bijections.finite.matrix.triangular import LowerTriangularInvertibleMatrix, \ UpperTriangularInvertibleMatrix @@ -27,7 +27,7 @@ def log_det_project(self): class QRMatrix(InvertibleMatrix): def __init__(self, event_shape: Union[torch.Size, Tuple[int, ...]], **kwargs): 
         super().__init__(event_shape, **kwargs)
-        self.orthogonal = HouseholderOrthogonalMatrix(self.n_dim, **kwargs)
+        self.orthogonal = HouseholderProductMatrix(self.n_dim, **kwargs)
         self.upper = UpperTriangularInvertibleMatrix(self.n_dim, **kwargs)
 
     def project_flat(self, x_flat: torch.Tensor, context_flat: torch.Tensor = None) -> torch.Tensor:
diff --git a/torchflows/bijections/finite/matrix/orthogonal.py b/torchflows/bijections/finite/matrix/orthogonal.py
index b0c9861..1612d11 100644
--- a/torchflows/bijections/finite/matrix/orthogonal.py
+++ b/torchflows/bijections/finite/matrix/orthogonal.py
@@ -6,38 +6,37 @@
 from torchflows.bijections.finite.matrix.base import InvertibleMatrix
 
 
-class HouseholderOrthogonalMatrix(InvertibleMatrix):
+class HouseholderProductMatrix(InvertibleMatrix):
     def __init__(self, event_shape: Union[torch.Size, Tuple[int, ...]], n_factors: int = None, **kwargs):
         super().__init__(event_shape, **kwargs)
         if n_factors is None:
-            n_factors = min(5, self.n_dim)
+            n_factors = min(5, self.n_dim // 2)
         assert 1 <= n_factors <= self.n_dim
         self.v = nn.Parameter(torch.randn(n_factors, self.n_dim) / self.n_dim ** 2 + torch.eye(n_factors, self.n_dim))
+        # self.v = nn.Parameter(torch.randn(n_factors, self.n_dim))
+        self.tau = 2
 
-    def project_flat(self, x_flat: torch.Tensor, context_flat: torch.Tensor = None) -> torch.Tensor:
+    def apply_flat_transformation(self, x_flat: torch.Tensor, factors: torch.Tensor) -> torch.Tensor:
         batch_shape = x_flat.shape[:-1]
         z_flat = x_flat.clone()  # (*batch_shape, self.n_dim)
-        for i in range(self.v.shape[0]):  # Apply each Householder transformation in reverse order
-            v = self.v[i]  # (self.n_dim,)
-            alpha = 2 * torch.einsum('i,...i->...', v, z_flat)[..., None]  # (*batch_shape, 1)
-            v = v[[None] * len(batch_shape)]  # (1, ..., 1, self.n_dim) with len(v.shape) == len(batch_shape) + 1
-            z_flat = z_flat - alpha * v
+        assert len(factors) == self.v.shape[0]
+        for v in factors:
+            # v.shape == (self.n_dim,)
+            dot = torch.einsum('i,...i->...', v, z_flat)[..., None]  # (*batch_shape, 1)
+            v_unsqueezed = v[[None] * len(batch_shape)]  # (1, ..., 1, self.n_dim), broadcastable to (*batch_shape, self.n_dim)
+            scalar = self.tau / torch.sum(torch.square(v))
+            z_flat = z_flat - scalar * (v_unsqueezed * dot).squeeze(-1)
         return z_flat
 
+    def project_flat(self, x_flat: torch.Tensor, context_flat: torch.Tensor = None) -> torch.Tensor:
+        return self.apply_flat_transformation(x_flat, self.v)
+
     def solve_flat(self, b_flat: torch.Tensor, context: torch.Tensor = None) -> torch.Tensor:
-        # Same code as project, just the reverse matrix order
-        batch_shape = b_flat.shape[:-1]
-        x_flat = b_flat.clone()  # (*batch_shape, self.n_dim)
-        for i in range(self.v.shape[0] - 1, -1, -1):  # Apply each Householder transformation in reverse order
-            v = self.v[i]  # (self.n_dim,)
-            alpha = 2 * torch.einsum('i,...i->...', v, x_flat)[..., None]  # (*batch_shape, 1)
-            v = v[[None] * len(batch_shape)]  # (1, ..., 1, self.n_dim) with len(v.shape) == len(batch_shape) + 1
-            x_flat = x_flat - alpha * v
-        return x_flat
+        return self.apply_flat_transformation(b_flat, self.v.flip(0))
 
     def log_det_project(self) -> torch.Tensor:
-        return torch.tensor(0.0).to(self.device_buffer.device)
+        return torch.zeros(1).to(self.device_buffer.device)
 
     def __matmul__(self, other: torch.Tensor) -> torch.Tensor:
-        return self.project_flat(other)
\ No newline at end of file
+        return self.project_flat(other)
diff --git a/torchflows/bijections/finite/matrix/util.py b/torchflows/bijections/finite/matrix/util.py
index b815433..6d28776 100644
--- 
a/torchflows/bijections/finite/matrix/util.py +++ b/torchflows/bijections/finite/matrix/util.py @@ -1,9 +1,9 @@ import torch -from torchflows.bijections.finite.matrix import HouseholderOrthogonalMatrix +from torchflows.bijections.finite.matrix import HouseholderProductMatrix -def matmul_with_householder(a: torch.Tensor, q: HouseholderOrthogonalMatrix): +def matmul_with_householder(a: torch.Tensor, q: HouseholderProductMatrix): product = a for i in range(len(q.v)): product = product - 2 * q.v[i][:, None] * torch.matmul(q.v[i][None], product) diff --git a/torchflows/bijections/finite/residual/sylvester.py b/torchflows/bijections/finite/residual/sylvester.py index 9029967..f5c2f0e 100644 --- a/torchflows/bijections/finite/residual/sylvester.py +++ b/torchflows/bijections/finite/residual/sylvester.py @@ -5,7 +5,7 @@ from Cython.Shadow import returns from torchflows.bijections.finite.matrix import UpperTriangularInvertibleMatrix, IdentityMatrix, \ - HouseholderOrthogonalMatrix + HouseholderProductMatrix from torchflows.bijections.finite.matrix.permutation import PermutationMatrix, RandomPermutationMatrix from torchflows.bijections.finite.matrix.util import matmul_with_householder from torchflows.bijections.finite.residual.base import ClassicResidualBijection @@ -69,7 +69,7 @@ def inverse(self, z: torch.Tensor, context: torch.Tensor = None) -> Tuple[torch. class HouseholderSylvester(BaseSylvester): def __init__(self, event_shape: Union[torch.Size, Tuple[int, ...]], **kwargs): super().__init__(event_shape, **kwargs) - self.register_module('q', HouseholderOrthogonalMatrix(event_shape, n_factors=self.m)) + self.register_module('q', HouseholderProductMatrix(event_shape, n_factors=self.m)) def compute_u(self): return self.q.project_flat(self.r.compute_matrix()) From 6cebc7b6cd539ab4ced2ce6c2293866c83b49f50 Mon Sep 17 00:00:00 2001 From: David Nabergoj Date: Wed, 27 Nov 2024 00:16:38 +0100 Subject: [PATCH 4/4] Remove bad import --- torchflows/bijections/finite/residual/sylvester.py | 1 - 1 file changed, 1 deletion(-) diff --git a/torchflows/bijections/finite/residual/sylvester.py b/torchflows/bijections/finite/residual/sylvester.py index f5c2f0e..1c76c6f 100644 --- a/torchflows/bijections/finite/residual/sylvester.py +++ b/torchflows/bijections/finite/residual/sylvester.py @@ -2,7 +2,6 @@ import torch import torch.nn as nn -from Cython.Shadow import returns from torchflows.bijections.finite.matrix import UpperTriangularInvertibleMatrix, IdentityMatrix, \ HouseholderProductMatrix
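
For reference, the transformation implemented by apply_flat_transformation in patch 3 is x -> H_K ... H_1 x along the last axis, with each factor H_i = I - (tau / ||v_i||^2) v_i v_i^T. With tau = 2 every H_i is an exact reflection (H_i^2 = I), so running the same loop over the factors in reverse order inverts the map, which is what solve_flat relies on and what the reconstruction tests exercise. The sketch below is a minimal standalone check in plain PyTorch; the helper name householder_product is invented for this note and is not part of torchflows.

    import torch


    def householder_product(x, v_factors, tau=2.0):
        # Apply H_K @ ... @ H_1 to the last axis of x, with H_i = I - (tau / ||v_i||^2) v_i v_i^T.
        z = x.clone()
        for v in v_factors:
            dot = torch.einsum('i,...i->...', v, z)[..., None]  # (*batch_shape, 1)
            z = z - (tau / v.square().sum()) * dot * v  # v broadcasts over the batch dims
        return z


    torch.manual_seed(0)
    v_factors = torch.randn(3, 8)  # three Householder vectors in 8 dimensions
    x = torch.randn(5, 8)          # batch of five 8-dimensional vectors

    y = householder_product(x, v_factors)              # analogue of project_flat
    x_rec = householder_product(y, v_factors.flip(0))  # reversed factor order, analogue of solve_flat

    assert torch.allclose(x, x_rec, atol=1e-6)  # reflections are involutions, so the round trip recovers x

Because each factor is orthogonal, the product preserves norms and its log absolute Jacobian determinant is zero, consistent with log_det_project returning zeros.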