Skip to content

Commit

Permalink
Check if validation data exists in Flow.fit, add tests
Browse files Browse the repository at this point in the history
  • Loading branch information
davidnabergoj committed Nov 9, 2023
1 parent e058ba7 commit 54574a8
Show file tree
Hide file tree
Showing 3 changed files with 75 additions and 11 deletions.
17 changes: 9 additions & 8 deletions normalizing_flows/flows.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,14 +164,15 @@ def fit(self,
)

# Process validation data
val_loader = create_data_loader(
x_val,
w_val,
context_val,
"validation",
batch_size=batch_size,
shuffle=shuffle
)
if x_val is not None:
val_loader = create_data_loader(
x_val,
w_val,
context_val,
"validation",
batch_size=batch_size,
shuffle=shuffle
)

def compute_batch_loss(batch_, reduction: callable = torch.mean):
batch_x, batch_weights = batch_[:2]
Expand Down
6 changes: 3 additions & 3 deletions normalizing_flows/utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Tuple, Union
from typing import Tuple, Union, Optional
import torch
from torch.utils.data import TensorDataset, DataLoader

Expand Down Expand Up @@ -187,8 +187,8 @@ def log_prob(self, value):


def create_data_loader(x: torch.Tensor,
weights: torch.Tensor,
context: torch.Tensor,
weights: Optional[torch.Tensor],
context: Optional[torch.Tensor],
label: str,
**kwargs):
"""
Expand Down
63 changes: 63 additions & 0 deletions test/test_fit.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from normalizing_flows import Flow
from normalizing_flows.bijections import NICE, RealNVP, MAF, ElementwiseAffine, ElementwiseShift, ElementwiseRQSpline, \
CouplingRQNSF, MaskedAutoregressiveRQNSF, LowerTriangular, ElementwiseScale, QR, LU
from test.constants import __test_constants


@pytest.mark.skip(reason='Takes too long, fit quality is architecture-dependent')
Expand Down Expand Up @@ -107,3 +108,65 @@ def test_diagonal_gaussian_1(bijection_class):
relative_error = max((x_std - sigma.ravel()).abs() / sigma.ravel())

assert relative_error < 0.1


@pytest.mark.parametrize("n_train", [1, 10, 2200])
@pytest.mark.parametrize("event_shape", __test_constants["event_shape"])
def test_fit_basic(n_train, event_shape):
torch.manual_seed(0)
x_train = torch.randn(size=(n_train, *event_shape))
flow = Flow(RealNVP(event_shape))
flow.fit(x_train, n_epochs=2)


@pytest.mark.parametrize("n_train", [1, 10, 4000])
@pytest.mark.parametrize("n_val", [1, 10, 2400])
def test_fit_with_validation_data(n_train, n_val):
torch.manual_seed(0)

event_shape = (2, 3)

x_train = torch.randn(size=(n_train, *event_shape))
x_val = torch.randn(size=(n_val, *event_shape))

flow = Flow(RealNVP(event_shape))
flow.fit(x_train, n_epochs=2, x_val=x_val)


@pytest.mark.parametrize("n_train", [1, 10, 2200])
@pytest.mark.parametrize("event_shape", __test_constants["event_shape"])
@pytest.mark.parametrize("context_shape", __test_constants["context_shape"])
def test_fit_with_training_context(n_train, event_shape, context_shape):
torch.manual_seed(0)
x_train = torch.randn(size=(n_train, *event_shape))
if context_shape is None:
c_train = None
else:
c_train = torch.randn(size=(n_train, *context_shape))
flow = Flow(RealNVP(event_shape))
flow.fit(x_train, n_epochs=2, context_train=c_train)


@pytest.mark.parametrize("n_train", [1, 10, 2200])
@pytest.mark.parametrize("n_val", [1, 10, 2200])
@pytest.mark.parametrize("event_shape", __test_constants["event_shape"])
@pytest.mark.parametrize("context_shape", __test_constants["context_shape"])
def test_fit_with_context_and_validation_data(n_train, n_val, event_shape, context_shape):
torch.manual_seed(0)

# Setup training data
x_train = torch.randn(size=(n_train, *event_shape))
if context_shape is None:
c_train = None
else:
c_train = torch.randn(size=(n_train, *context_shape))

# Setup validation data
x_val = torch.randn(size=(n_val, *event_shape))
if context_shape is None:
c_val = None
else:
c_val = torch.randn(size=(n_val, *context_shape))

flow = Flow(RealNVP(event_shape))
flow.fit(x_train, n_epochs=2, context_train=c_train, x_val=x_val, context_val=c_val)

0 comments on commit 54574a8

Please sign in to comment.