Skip to content

Commit

Permalink
Merge remote-tracking branch 'origin/dev' into dev
Browse files Browse the repository at this point in the history
  • Loading branch information
davidnabergoj committed Nov 27, 2024
2 parents 62b935d + 6cebc7b commit ed213bd
Show file tree
Hide file tree
Showing 18 changed files with 699 additions and 342 deletions.
34 changes: 25 additions & 9 deletions test/test_autograd_bijections.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,10 @@
from torchflows.bijections.finite.autoregressive.architectures import NICE, RealNVP, CouplingRQNSF, MAF, IAF, \
InverseAutoregressiveRQNSF, MaskedAutoregressiveRQNSF
from torchflows.bijections.finite.autoregressive.layers import ElementwiseScale, ElementwiseAffine, ElementwiseShift, \
LRSCoupling, LinearRQSCoupling
from torchflows.bijections.finite.linear import LU, ReversePermutation, LowerTriangular, Orthogonal, QR
LRSCoupling, LinearRQSCoupling, ElementwiseRQSpline
from torchflows.bijections.finite.matrix import HouseholderProductMatrix, LowerTriangularInvertibleMatrix, \
UpperTriangularInvertibleMatrix, IdentityMatrix, RandomPermutationMatrix, ReversePermutationMatrix, QRMatrix, \
LUMatrix
from torchflows.bijections.finite.residual.architectures import InvertibleResNet, ResFlow, ProximalResFlow
from torchflows.utils import get_batch_shape
from test.constants import __test_constants
Expand Down Expand Up @@ -43,19 +45,33 @@ def assert_valid_log_probability_gradient(bijection: Bijection, x: torch.Tensor,


@pytest.mark.parametrize('bijection_class', [
LU,
ReversePermutation,
ElementwiseScale,
LowerTriangular,
Orthogonal,
QR,
ElementwiseAffine,
ElementwiseShift
ElementwiseShift,
ElementwiseRQSpline
])
@pytest.mark.parametrize('batch_shape', __test_constants['batch_shape'])
@pytest.mark.parametrize('event_shape', __test_constants['event_shape'])
@pytest.mark.parametrize('context_shape', __test_constants['context_shape'])
def test_linear(bijection_class: Bijection, batch_shape: Tuple, event_shape: Tuple, context_shape: Tuple):
def test_elementwise(bijection_class: Bijection, batch_shape: Tuple, event_shape: Tuple, context_shape: Tuple):
bijection, x, context = setup_data(bijection_class, batch_shape, event_shape, context_shape)
assert_valid_log_probability_gradient(bijection, x, context)


@pytest.mark.parametrize('bijection_class', [
IdentityMatrix,
RandomPermutationMatrix,
ReversePermutationMatrix,
LowerTriangularInvertibleMatrix,
UpperTriangularInvertibleMatrix,
HouseholderProductMatrix,
QRMatrix,
LUMatrix,
])
@pytest.mark.parametrize('batch_shape', __test_constants['batch_shape'])
@pytest.mark.parametrize('event_shape', __test_constants['event_shape'])
@pytest.mark.parametrize('context_shape', __test_constants['context_shape'])
def test_matrix(bijection_class: Bijection, batch_shape: Tuple, event_shape: Tuple, context_shape: Tuple):
bijection, x, context = setup_data(bijection_class, batch_shape, event_shape, context_shape)
assert_valid_log_probability_gradient(bijection, x, context)

Expand Down
21 changes: 14 additions & 7 deletions test/test_deepcopy.py
Original file line number Diff line number Diff line change
@@ -1,26 +1,33 @@
from copy import deepcopy

import pytest
import torch

from torchflows import RNODE, Flow
from torchflows import RNODE, Flow, Sylvester, RealNVP
from torchflows.bijections.base import invert


def test_basic():
@pytest.mark.parametrize('flow_class', [RNODE, Sylvester, RealNVP])
def test_basic(flow_class):
torch.manual_seed(0)
b = RNODE(event_shape=(10,))
b = flow_class(event_shape=(10,))
deepcopy(b)


def test_post_variational_fit():
@pytest.mark.parametrize('flow_class', [RNODE, Sylvester, RealNVP])
def test_post_variational_fit(flow_class):
torch.manual_seed(0)
b = RNODE(event_shape=(10,))
b = flow_class(event_shape=(10,))
f = Flow(b)
f.variational_fit(lambda x: torch.sum(-x ** 2), n_epochs=2)
deepcopy(b)

def test_post_fit():
@pytest.mark.parametrize('flow_class', [RNODE, Sylvester, RealNVP])
def test_post_fit(flow_class):
torch.manual_seed(0)
b = RNODE(event_shape=(10,))
b = flow_class(event_shape=(10,))
if isinstance(b, Sylvester):
b = invert(b)
f = Flow(b)
f.fit(x_train=torch.randn(3, *b.event_shape), n_epochs=2)
deepcopy(b)
16 changes: 8 additions & 8 deletions test/test_fit.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,15 +6,15 @@
MaskedAutoregressiveRQNSF
from torchflows.bijections.finite.autoregressive.layers import ElementwiseScale, ElementwiseAffine, ElementwiseShift, \
ElementwiseRQSpline
from torchflows.bijections.finite.linear import LowerTriangular, LU, QR
from torchflows.bijections.finite.matrix import LowerTriangularInvertibleMatrix, LUMatrix, QRMatrix


@pytest.mark.skip(reason='Takes too long, fit quality is architecture-dependent')
@pytest.mark.parametrize('bijection_class', [
LowerTriangular,
LowerTriangularInvertibleMatrix,
ElementwiseScale,
LU,
QR,
LUMatrix,
QRMatrix,
ElementwiseAffine,
ElementwiseShift,
ElementwiseRQSpline,
Expand Down Expand Up @@ -83,9 +83,9 @@ def test_diagonal_gaussian_elementwise_scale():
@pytest.mark.skip(reason='Takes too long, fit quality is architecture-dependent')
@pytest.mark.parametrize('bijection_class',
[
LowerTriangular,
LU,
QR,
LowerTriangularInvertibleMatrix,
LUMatrix,
QRMatrix,
MaskedAutoregressiveRQNSF,
ElementwiseRQSpline,
ElementwiseAffine,
Expand All @@ -102,7 +102,7 @@ def test_diagonal_gaussian_1(bijection_class):
x = torch.randn(size=(n_data, n_dim)) * sigma
bijection = bijection_class(event_shape=(n_dim,))
flow = Flow(bijection)
if isinstance(bijection, LowerTriangular):
if isinstance(bijection, LowerTriangularInvertibleMatrix):
flow.fit(x, n_epochs=100)
else:
flow.fit(x, n_epochs=25)
Expand Down
13 changes: 7 additions & 6 deletions test/test_reconstruction_bijections.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,8 @@
InverseAutoregressiveRQNSF, MaskedAutoregressiveRQNSF
from torchflows.bijections.finite.autoregressive.layers import ElementwiseScale, ElementwiseAffine, ElementwiseShift, \
LRSCoupling, LinearRQSCoupling, ActNorm, DenseSigmoidalCoupling, DeepDenseSigmoidalCoupling, DeepSigmoidalCoupling
from torchflows.bijections.finite.linear import LU, ReversePermutation, LowerTriangular, Orthogonal, QR
from torchflows.bijections.finite.matrix import LUMatrix, ReversePermutationMatrix, LowerTriangularInvertibleMatrix, \
HouseholderProductMatrix, QRMatrix
from torchflows.bijections.finite.residual.architectures import ResFlow, InvertibleResNet, ProximalResFlow
from torchflows.bijections.finite.residual.iterative import InvertibleResNetBlock, ResFlowBlock
from torchflows.bijections.finite.residual.planar import Planar
Expand Down Expand Up @@ -122,12 +123,12 @@ def assert_valid_reconstruction_continuous(bijection: ContinuousBijection,


@pytest.mark.parametrize('bijection_class', [
LU,
ReversePermutation,
LUMatrix,
ReversePermutationMatrix,
ElementwiseScale,
LowerTriangular,
Orthogonal,
QR,
LowerTriangularInvertibleMatrix,
HouseholderProductMatrix,
QRMatrix,
ElementwiseAffine,
ElementwiseShift,
ActNorm
Expand Down
Loading

0 comments on commit ed213bd

Please sign in to comment.