Add "keep best weights" and "early stopping" options to Flow.fit and
davidnabergoj committed Nov 9, 2023
1 parent 54574a8 commit 46385b5
Showing 1 changed file with 33 additions and 3 deletions.
normalizing_flows/flows.py: 36 changes (33 additions, 3 deletions)
@@ -1,3 +1,4 @@
from copy import deepcopy
from typing import Union, Tuple

import torch
@@ -125,13 +126,15 @@ def fit(self,
            context_train: torch.Tensor = None,
            x_val: torch.Tensor = None,
            w_val: torch.Tensor = None,
            context_val: torch.Tensor = None):
            context_val: torch.Tensor = None,
            keep_best_weights: bool = True,
            early_stopping: bool = False,
            early_stopping_threshold: int = 50):
        """
        Fit the normalizing flow.
        Fitting the flow means finding the parameters of the bijection that maximize the probability of training data.
        Bijection parameters are iteratively updated for a specified number of epochs.
        If validation data is provided, we keep the bijection weights with the highest probability of validation data.
        If context data is provided, the normalizing flow learns the distribution of data conditional on context data.
        :param x_train: training data with shape (n_training_data, *event_shape).
@@ -145,6 +148,9 @@ def fit(self,
        :param x_val: validation data with shape (n_validation_data, *event_shape).
        :param w_val: validation data weights with shape (n_validation_data,).
        :param context_val: validation data context tensor with shape (n_validation_data, *context_shape).
        :param keep_best_weights: if True and validation data is provided, keep the bijection weights with the highest probability of validation data.
        :param early_stopping: if True and validation data is provided, stop the training procedure early once validation loss stops improving for a specified number of consecutive epochs.
        :param early_stopping_threshold: if early_stopping is True, fitting stops after no improvement in validation loss for this many epochs.
        """
        # Compute the number of event dimensions
        n_event_dims = int(torch.prod(torch.as_tensor(self.bijection.event_shape)))
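A minimal usage sketch of the new options (the three keyword arguments come from this diff; the Flow construction and my_bijection are assumptions for illustration, not part of the commit):

    import torch
    from normalizing_flows.flows import Flow

    flow = Flow(my_bijection)  # my_bijection: any bijection accepted by Flow (hypothetical)

    x_train = torch.randn(1000, 2)  # synthetic training data
    x_val = torch.randn(200, 2)     # synthetic validation data

    # With validation data provided, training stops after 50 epochs without
    # improvement in validation loss, and the best weights are restored at the end.
    flow.fit(
        x_train,
        x_val=x_val,
        keep_best_weights=True,
        early_stopping=True,
        early_stopping_threshold=50,
    )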
@@ -174,6 +180,10 @@ def fit(self,
            shuffle=shuffle
        )

        best_val_loss = torch.inf
        best_epoch = 0
        best_weights = deepcopy(self.state_dict())

        def compute_batch_loss(batch_, reduction: callable = torch.mean):
            batch_x, batch_weights = batch_[:2]
            batch_context = batch_[2] if len(batch_) == 3 else None
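The deepcopy of state_dict() above is load-bearing: state_dict() returns references to the live parameter tensors, so a snapshot taken without deepcopy would be silently mutated by later optimizer steps. A self-contained illustration with a toy model (not part of this diff):

    from copy import deepcopy

    import torch

    model = torch.nn.Linear(2, 2)
    snapshot = deepcopy(model.state_dict())  # detached copies of all tensors

    with torch.no_grad():
        model.weight += 1.0  # simulate a training update

    # The deep-copied snapshot keeps its old values; a plain state_dict()
    # reference would now reflect the update.
    print(torch.allclose(snapshot['weight'], model.weight))  # False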
@@ -188,7 +198,8 @@ def compute_batch_loss(batch_, reduction: callable = torch.mean):
        iterator = tqdm(range(n_epochs), desc='Fitting NF', disable=not show_progress)
        optimizer = torch.optim.AdamW(self.parameters(), lr=lr)
        val_loss = None
        for _ in iterator:

        for epoch in iterator:
            for train_batch in train_loader:
                optimizer.zero_grad()
                train_loss = compute_batch_loss(train_batch, reduction=torch.mean)
@@ -210,13 +221,32 @@ def compute_batch_loss(batch_, reduction: callable = torch.mean):
            # Validation loss will be displayed at the start of the next epoch
            if x_val is not None:
                with torch.no_grad():
                    # Compute validation loss
                    val_loss = 0.0
                    for val_batch in val_loader:
                        n_batch_data = len(val_batch[0])
                        val_loss += compute_batch_loss(val_batch, reduction=torch.sum) / n_batch_data
                    if hasattr(self.bijection, 'regularization'):
                        val_loss += self.bijection.regularization()

                    # Check if validation loss is the lowest so far
                    if val_loss < best_val_loss:
                        best_val_loss = val_loss
                        best_epoch = epoch

                    # Store current weights
                    if keep_best_weights:
                        if best_epoch == epoch:
                            best_weights = deepcopy(self.state_dict())

                    # Optionally stop training early
                    if early_stopping:
                        if epoch - best_epoch > early_stopping_threshold:
                            break

        if x_val is not None and keep_best_weights:
            self.load_state_dict(best_weights)
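The patience rule above, extracted into a standalone sketch with synthetic validation losses (semantics as in the diff: stop once more than early_stopping_threshold consecutive epochs pass without improvement):

    import torch

    early_stopping_threshold = 3  # patience, as in the new fit() parameter
    best_val_loss, best_epoch = torch.inf, 0

    # Synthetic validation losses: no improvement after epoch 2.
    val_losses = [1.0, 0.8, 0.7, 0.9, 0.95, 0.9, 0.91, 0.92]

    for epoch, val_loss in enumerate(val_losses):
        if val_loss < best_val_loss:
            best_val_loss, best_epoch = val_loss, epoch
        if epoch - best_epoch > early_stopping_threshold:
            break  # mirrors the break in fit()

    print(epoch, best_epoch, best_val_loss)  # 6 2 0.7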

    def variational_fit(self,
                        target,
                        n_epochs: int = 10,
