diff --git a/test/test_fit_conv_residual_flow.py b/test/test_fit_conv_residual_flow.py new file mode 100644 index 0000000..857700c --- /dev/null +++ b/test/test_fit_conv_residual_flow.py @@ -0,0 +1,13 @@ +import pytest +import torch + +from torchflows import Flow +from torchflows.architectures import ConvolutionalResFlow, ConvolutionalInvertibleResNet + + +@pytest.mark.parametrize('arch_cls', [ConvolutionalResFlow, ConvolutionalInvertibleResNet]) +def test_basic(arch_cls): + torch.manual_seed(0) + event_shape = (3, 20, 20) + flow = Flow(arch_cls(event_shape)) + flow.fit(torch.randn(size=(5, *event_shape)), n_epochs=20) diff --git a/torchflows/bijections/finite/residual/base.py b/torchflows/bijections/finite/residual/base.py index d664b48..d745be9 100644 --- a/torchflows/bijections/finite/residual/base.py +++ b/torchflows/bijections/finite/residual/base.py @@ -45,7 +45,7 @@ def forward(self, else: x_flat = flatten_batch(x.clone(), batch_shape) x_flat.requires_grad_(True) - log_det = -unflatten_batch(self.log_det(x_flat, training=self.training), batch_shape) + log_det = unflatten_batch(self.log_det(x_flat, training=self.training), batch_shape) return z, log_det @@ -68,7 +68,7 @@ def inverse(self, else: x_flat = flatten_batch(x.clone(), batch_shape) x_flat.requires_grad_(True) - log_det = -unflatten_batch(self.log_det(x_flat, training=self.training), batch_shape) + log_det = unflatten_batch(self.log_det(x_flat, training=self.training), batch_shape) return x, log_det diff --git a/torchflows/bijections/finite/residual/iterative.py b/torchflows/bijections/finite/residual/iterative.py index ef8e3fa..30843d4 100644 --- a/torchflows/bijections/finite/residual/iterative.py +++ b/torchflows/bijections/finite/residual/iterative.py @@ -12,7 +12,7 @@ class SpectralMatrix(nn.Module): def __init__(self, shape: Tuple[int, int], c: float = 0.7, n_iterations: int = 5): super().__init__() - self.data = torch.randn(size=shape) + self.data = nn.Parameter(torch.randn(size=shape)) self.c = c self.n_iterations = n_iterations @@ -22,7 +22,7 @@ def power_iteration(self, w): # Spectral Normalization for Generative Adversarial Networks - Miyato et al. - 2018 # Get maximum singular value of rectangular matrix w - u = torch.randn(self.data.shape[1], 1) + u = torch.randn(self.data.shape[1], 1).to(w) v = None w = w.T @@ -39,9 +39,8 @@ def power_iteration(self, w): def normalized(self): # Estimate sigma - sigma = self.power_iteration(self.data) - # ratio = self.c / sigma - # return self.w * (ratio ** (ratio < 1)) + with torch.no_grad(): + sigma = self.power_iteration(self.data) return self.data / sigma diff --git a/torchflows/flows.py b/torchflows/flows.py index 0c0f5a3..4713509 100644 --- a/torchflows/flows.py +++ b/torchflows/flows.py @@ -235,28 +235,27 @@ def compute_batch_loss(batch_, reduction: callable = torch.mean): # Compute validation loss at the end of each epoch # Validation loss will be displayed at the start of the next epoch if x_val is not None: - with torch.no_grad(): - # Compute validation loss - val_loss = 0.0 - for val_batch in val_loader: - val_loss += compute_batch_loss(val_batch, reduction=torch.sum) - val_loss /= len(x_val) - val_loss += self.regularization() - - # Check if validation loss is the lowest so far - if val_loss < best_val_loss: - best_val_loss = val_loss - best_epoch = epoch - - # Store current weights - if keep_best_weights: - if best_epoch == epoch: - best_weights = deepcopy(self.state_dict()) - - # Optionally stop training early - if early_stopping: - if epoch - best_epoch > early_stopping_threshold: - break + # Compute validation loss + val_loss = 0.0 + for val_batch in val_loader: + val_loss += compute_batch_loss(val_batch, reduction=torch.sum).detach() + val_loss /= len(x_val) + val_loss += self.regularization() + + # Check if validation loss is the lowest so far + if val_loss < best_val_loss: + best_val_loss = val_loss + best_epoch = epoch + + # Store current weights + if keep_best_weights: + if best_epoch == epoch: + best_weights = deepcopy(self.state_dict()) + + # Optionally stop training early + if early_stopping: + if epoch - best_epoch > early_stopping_threshold: + break if x_val is not None and keep_best_weights: self.load_state_dict(best_weights)