
Commit 9094cdb: Fix residual flow fits
davidnabergoj committed Sep 4, 2024 (1 parent: 6232722)
Showing 4 changed files with 40 additions and 29 deletions.
13 changes: 13 additions & 0 deletions test/test_fit_conv_residual_flow.py
@@ -0,0 +1,13 @@
+import pytest
+import torch
+
+from torchflows import Flow
+from torchflows.architectures import ConvolutionalResFlow, ConvolutionalInvertibleResNet
+
+
+@pytest.mark.parametrize('arch_cls', [ConvolutionalResFlow, ConvolutionalInvertibleResNet])
+def test_basic(arch_cls):
+    torch.manual_seed(0)
+    event_shape = (3, 20, 20)
+    flow = Flow(arch_cls(event_shape))
+    flow.fit(torch.randn(size=(5, *event_shape)), n_epochs=20)
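
As a usage note (not part of the diff): once fitted, a Flow is typically used for density evaluation and sampling. A minimal sketch, assuming Flow exposes log_prob and sample methods with the shapes indicated in the comments; those two names are assumptions, not confirmed by this diff.

import torch
from torchflows import Flow
from torchflows.architectures import ConvolutionalResFlow

torch.manual_seed(0)
event_shape = (3, 20, 20)
flow = Flow(ConvolutionalResFlow(event_shape))
flow.fit(torch.randn(size=(5, *event_shape)), n_epochs=20)
log_prob = flow.log_prob(torch.randn(size=(2, *event_shape)))  # assumed API; shape (2,)
samples = flow.sample(4)  # assumed API; shape (4, 3, 20, 20)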
4 changes: 2 additions & 2 deletions torchflows/bijections/finite/residual/base.py
@@ -45,7 +45,7 @@ def forward(self,
         else:
             x_flat = flatten_batch(x.clone(), batch_shape)
             x_flat.requires_grad_(True)
-            log_det = -unflatten_batch(self.log_det(x_flat, training=self.training), batch_shape)
+            log_det = unflatten_batch(self.log_det(x_flat, training=self.training), batch_shape)

         return z, log_det

@@ -68,7 +68,7 @@ def inverse(self,
         else:
             x_flat = flatten_batch(x.clone(), batch_shape)
             x_flat.requires_grad_(True)
-            log_det = -unflatten_batch(self.log_det(x_flat, training=self.training), batch_shape)
+            log_det = unflatten_batch(self.log_det(x_flat, training=self.training), batch_shape)

         return x, log_det

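
The two hunks above drop an extra negation of the log-determinant estimate in both directions; with the negation in place, the change-of-variables term entered the training loss with the wrong sign, which is presumably what broke the fits named in the commit title. A minimal autograd sanity check of the sign convention for a 1-D residual map f(x) = x + g(x), where log|det J| = log|1 + g'(x)| (illustrative only; the library itself uses a stochastic log-determinant estimator over flattened batches):

import torch

x = torch.tensor([0.3], requires_grad=True)
y = x + 0.5 * torch.tanh(x)          # contractive residual map: Lipschitz(g) = 0.5 < 1
(jac,) = torch.autograd.grad(y.sum(), x)
log_det = torch.log(torch.abs(jac))  # log(1 + 0.5 * (1 - tanh(0.3)^2)) ~ 0.38, positive
print(log_det)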
9 changes: 4 additions & 5 deletions torchflows/bijections/finite/residual/iterative.py
@@ -12,7 +12,7 @@
 class SpectralMatrix(nn.Module):
     def __init__(self, shape: Tuple[int, int], c: float = 0.7, n_iterations: int = 5):
         super().__init__()
-        self.data = torch.randn(size=shape)
+        self.data = nn.Parameter(torch.randn(size=shape))
         self.c = c
         self.n_iterations = n_iterations

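
Wrapping the matrix in nn.Parameter registers it with the module, so it is returned by parameters() (and thus actually updated by the optimizer) and moved along with .to(device). A plain tensor attribute is invisible to both mechanisms. A minimal illustration of the difference:

import torch
import torch.nn as nn

class Demo(nn.Module):
    def __init__(self):
        super().__init__()
        self.plain = torch.randn(3, 3)                # not registered: skipped by optimizers and .to()
        self.param = nn.Parameter(torch.randn(3, 3))  # registered: trained and moved with the module

print([name for name, _ in Demo().named_parameters()])  # ['param']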
@@ -22,7 +22,7 @@ def power_iteration(self, w):
         # Spectral Normalization for Generative Adversarial Networks - Miyato et al. - 2018

         # Get maximum singular value of rectangular matrix w
-        u = torch.randn(self.data.shape[1], 1)
+        u = torch.randn(self.data.shape[1], 1).to(w)
         v = None

         w = w.T
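
The .to(w) change puts the power-iteration vector on the same device and dtype as the weight matrix, which matters once the module lives on a GPU. For reference, a self-contained sketch of the power-iteration idea from Miyato et al. (2018), checked against an exact SVD; this mirrors the algorithm's structure, not the exact implementation above:

import torch

def largest_singular_value(w: torch.Tensor, n_iterations: int = 20) -> torch.Tensor:
    # Alternate v <- normalize(W^T u), u <- normalize(W v); then sigma ~ u^T W v.
    u = torch.randn(w.shape[0], device=w.device, dtype=w.dtype)
    for _ in range(n_iterations):
        v = torch.nn.functional.normalize(w.T @ u, dim=0)
        u = torch.nn.functional.normalize(w @ v, dim=0)
    return u @ w @ v

w = torch.randn(50, 30)
print(largest_singular_value(w))   # close to the true value below
print(torch.linalg.svdvals(w)[0])  # exact largest singular value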
@@ -39,9 +39,8 @@

     def normalized(self):
         # Estimate sigma
-        sigma = self.power_iteration(self.data)
-        # ratio = self.c / sigma
-        # return self.w * (ratio ** (ratio < 1))
+        with torch.no_grad():
+            sigma = self.power_iteration(self.data)
         return self.data / sigma


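
Two things change in normalized(): the power iteration now runs under torch.no_grad(), so gradients do not flow through the sigma estimate (only through self.data in the division), and the dead commented-out scaling code is removed. Dividing by sigma rescales the matrix to roughly unit spectral norm; keeping the residual block contractive (Lipschitz constant below one, via the coefficient c) is what guarantees invertibility, and where c is applied lies outside this hunk. A sketch of the rescaling effect:

import torch

w = torch.randn(50, 30)
sigma = torch.linalg.svdvals(w)[0]
print(torch.linalg.svdvals(w / sigma)[0])  # ~1.0 after rescaling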
43 changes: 21 additions & 22 deletions torchflows/flows.py
@@ -235,28 +235,27 @@ def compute_batch_loss(batch_, reduction: callable = torch.mean):
             # Compute validation loss at the end of each epoch
             # Validation loss will be displayed at the start of the next epoch
             if x_val is not None:
-                with torch.no_grad():
-                    # Compute validation loss
-                    val_loss = 0.0
-                    for val_batch in val_loader:
-                        val_loss += compute_batch_loss(val_batch, reduction=torch.sum)
-                    val_loss /= len(x_val)
-                    val_loss += self.regularization()
-
-                    # Check if validation loss is the lowest so far
-                    if val_loss < best_val_loss:
-                        best_val_loss = val_loss
-                        best_epoch = epoch
-
-                    # Store current weights
-                    if keep_best_weights:
-                        if best_epoch == epoch:
-                            best_weights = deepcopy(self.state_dict())
-
-                    # Optionally stop training early
-                    if early_stopping:
-                        if epoch - best_epoch > early_stopping_threshold:
-                            break
+                # Compute validation loss
+                val_loss = 0.0
+                for val_batch in val_loader:
+                    val_loss += compute_batch_loss(val_batch, reduction=torch.sum).detach()
+                val_loss /= len(x_val)
+                val_loss += self.regularization()
+
+                # Check if validation loss is the lowest so far
+                if val_loss < best_val_loss:
+                    best_val_loss = val_loss
+                    best_epoch = epoch
+
+                # Store current weights
+                if keep_best_weights:
+                    if best_epoch == epoch:
+                        best_weights = deepcopy(self.state_dict())
+
+                # Optionally stop training early
+                if early_stopping:
+                    if epoch - best_epoch > early_stopping_threshold:
+                        break

         if x_val is not None and keep_best_weights:
             self.load_state_dict(best_weights)
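
The key change here is dropping the torch.no_grad() wrapper around validation: residual-flow bijections estimate their log-determinant with autograd (note x_flat.requires_grad_(True) in base.py above), and no graph is recorded inside torch.no_grad(), so the validation loss presumably failed for these architectures. Computing normally and calling .detach() on each batch loss keeps validation working without retaining graphs across batches. A minimal reproduction of the failure mode, assuming the loss internally differentiates with torch.autograd.grad:

import torch

x = torch.randn(4, requires_grad=True)

with torch.no_grad():
    y = (x + 0.5 * torch.tanh(x)).sum()
try:
    torch.autograd.grad(y, x)  # fails: no graph was recorded under no_grad
except RuntimeError as e:
    print('failed as expected:', e)

y = (x + 0.5 * torch.tanh(x)).sum()  # same computation with grad enabled
(g,) = torch.autograd.grad(y, x)     # works
val_term = y.detach()                # detach afterwards to drop the graph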
