diff --git a/normalizing_flows/flows.py b/normalizing_flows/flows.py index 602801b..f50b1b8 100644 --- a/normalizing_flows/flows.py +++ b/normalizing_flows/flows.py @@ -161,8 +161,8 @@ def compute_batch_loss(batch_, reduction: callable = torch.mean): # Compute validation loss val_loss = 0.0 for val_batch in val_loader: - n_batch_data = len(val_batch[0]) - val_loss += compute_batch_loss(val_batch, reduction=torch.sum) / n_batch_data + val_loss += compute_batch_loss(val_batch, reduction=torch.sum) + val_loss /= len(x_val) val_loss += self.regularization() # Check if validation loss is the lowest so far