# Patch summary (preamble text before the first "diff --git" header).
#
# Files touched:
#   normalizing_flows/flows.py
#     In the hunk labelled `def fit(self,`, the call that builds `val_loader`
#     via create_data_loader(x_val, w_val, context_val, "validation", ...) is
#     wrapped in `if x_val is not None:`, so it is skipped when no validation
#     data is passed.
#   normalizing_flows/utils.py
#     `Optional` is added to the `typing` import, and the `weights` and
#     `context` parameters of create_data_loader are re-annotated from
#     torch.Tensor to Optional[torch.Tensor].
#   test/test_fit.py
#     Imports `__test_constants` from test.constants and appends four
#     parametrized tests: test_fit_basic, test_fit_with_validation_data,
#     test_fit_with_training_context and
#     test_fit_with_context_and_validation_data. Each seeds torch with 0,
#     draws data with torch.randn, builds Flow(RealNVP(event_shape)) and calls
#     flow.fit(..., n_epochs=2). None of them contains an assert statement, so
#     they only exercise that fit() completes.
#
# NOTE(review): after this change `val_loader` is only bound when x_val is not
#   None. The rest of fit() is not visible in this patch -- confirm that every
#   later use of `val_loader` is guarded by the same condition.
# NOTE(review): the hunk lines below appear to have been joined with spaces
#   (original newlines and indentation lost) -- presumably an extraction
#   artifact; the patch would need its line structure restored before it can
#   be applied.
diff --git a/normalizing_flows/flows.py b/normalizing_flows/flows.py index dedf92e..f803971 100644 --- a/normalizing_flows/flows.py +++ b/normalizing_flows/flows.py @@ -164,14 +164,15 @@ def fit(self, ) # Process validation data - val_loader = create_data_loader( - x_val, - w_val, - context_val, - "validation", - batch_size=batch_size, - shuffle=shuffle - ) + if x_val is not None: + val_loader = create_data_loader( + x_val, + w_val, + context_val, + "validation", + batch_size=batch_size, + shuffle=shuffle + ) def compute_batch_loss(batch_, reduction: callable = torch.mean): batch_x, batch_weights = batch_[:2] diff --git a/normalizing_flows/utils.py b/normalizing_flows/utils.py index db96245..7a0a5e1 100644 --- a/normalizing_flows/utils.py +++ b/normalizing_flows/utils.py @@ -1,4 +1,4 @@ -from typing import Tuple, Union +from typing import Tuple, Union, Optional import torch from torch.utils.data import TensorDataset, DataLoader @@ -187,8 +187,8 @@ def log_prob(self, value): def create_data_loader(x: torch.Tensor, - weights: torch.Tensor, - context: torch.Tensor, + weights: Optional[torch.Tensor], + context: Optional[torch.Tensor], label: str, **kwargs): """ diff --git a/test/test_fit.py b/test/test_fit.py index 98c9fee..ebf89f3 100644 --- a/test/test_fit.py +++ b/test/test_fit.py @@ -3,6 +3,7 @@ from normalizing_flows import Flow from normalizing_flows.bijections import NICE, RealNVP, MAF, ElementwiseAffine, ElementwiseShift, ElementwiseRQSpline, \ CouplingRQNSF, MaskedAutoregressiveRQNSF, LowerTriangular, ElementwiseScale, QR, LU +from test.constants import __test_constants @pytest.mark.skip(reason='Takes too long, fit quality is architecture-dependent') @@ -107,3 +108,65 @@ def test_diagonal_gaussian_1(bijection_class): relative_error = max((x_std - sigma.ravel()).abs() / sigma.ravel()) assert relative_error < 0.1 + + +@pytest.mark.parametrize("n_train", [1, 10, 2200]) +@pytest.mark.parametrize("event_shape", __test_constants["event_shape"]) +def 
test_fit_basic(n_train, event_shape): + torch.manual_seed(0) + x_train = torch.randn(size=(n_train, *event_shape)) + flow = Flow(RealNVP(event_shape)) + flow.fit(x_train, n_epochs=2) + + +@pytest.mark.parametrize("n_train", [1, 10, 4000]) +@pytest.mark.parametrize("n_val", [1, 10, 2400]) +def test_fit_with_validation_data(n_train, n_val): + torch.manual_seed(0) + + event_shape = (2, 3) + + x_train = torch.randn(size=(n_train, *event_shape)) + x_val = torch.randn(size=(n_val, *event_shape)) + + flow = Flow(RealNVP(event_shape)) + flow.fit(x_train, n_epochs=2, x_val=x_val) + + +@pytest.mark.parametrize("n_train", [1, 10, 2200]) +@pytest.mark.parametrize("event_shape", __test_constants["event_shape"]) +@pytest.mark.parametrize("context_shape", __test_constants["context_shape"]) +def test_fit_with_training_context(n_train, event_shape, context_shape): + torch.manual_seed(0) + x_train = torch.randn(size=(n_train, *event_shape)) + if context_shape is None: + c_train = None + else: + c_train = torch.randn(size=(n_train, *context_shape)) + flow = Flow(RealNVP(event_shape)) + flow.fit(x_train, n_epochs=2, context_train=c_train) + + +@pytest.mark.parametrize("n_train", [1, 10, 2200]) +@pytest.mark.parametrize("n_val", [1, 10, 2200]) +@pytest.mark.parametrize("event_shape", __test_constants["event_shape"]) +@pytest.mark.parametrize("context_shape", __test_constants["context_shape"]) +def test_fit_with_context_and_validation_data(n_train, n_val, event_shape, context_shape): + torch.manual_seed(0) + + # Setup training data + x_train = torch.randn(size=(n_train, *event_shape)) + if context_shape is None: + c_train = None + else: + c_train = torch.randn(size=(n_train, *context_shape)) + + # Setup validation data + x_val = torch.randn(size=(n_val, *event_shape)) + if context_shape is None: + c_val = None + else: + c_val = torch.randn(size=(n_val, *context_shape)) + + flow = Flow(RealNVP(event_shape)) + flow.fit(x_train, n_epochs=2, context_train=c_train, x_val=x_val, 
context_val=c_val)