From 23d319e26ca52dfc23f5b4a6c502490b98519acb Mon Sep 17 00:00:00 2001 From: David Nabergoj Date: Thu, 4 Jul 2024 15:49:52 +0200 Subject: [PATCH] Fix validation loss computation --- normalizing_flows/flows.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/normalizing_flows/flows.py b/normalizing_flows/flows.py index 602801b..f50b1b8 100644 --- a/normalizing_flows/flows.py +++ b/normalizing_flows/flows.py @@ -161,8 +161,8 @@ def compute_batch_loss(batch_, reduction: callable = torch.mean): # Compute validation loss val_loss = 0.0 for val_batch in val_loader: - n_batch_data = len(val_batch[0]) - val_loss += compute_batch_loss(val_batch, reduction=torch.sum) / n_batch_data + val_loss += compute_batch_loss(val_batch, reduction=torch.sum) + val_loss /= len(x_val) val_loss += self.regularization() # Check if validation loss is the lowest so far