From 46385b5d1258be47776e9e07916611c6497d8739 Mon Sep 17 00:00:00 2001 From: David Nabergoj Date: Thu, 9 Nov 2023 12:48:36 -0800 Subject: [PATCH] Add "keep best weights" and "early stopping" options to Flow.fit and --- normalizing_flows/flows.py | 36 +++++++++++++++++++++++++++++++++--- 1 file changed, 33 insertions(+), 3 deletions(-) diff --git a/normalizing_flows/flows.py b/normalizing_flows/flows.py index f803971..ac5f5d5 100644 --- a/normalizing_flows/flows.py +++ b/normalizing_flows/flows.py @@ -1,3 +1,4 @@ +from copy import deepcopy from typing import Union, Tuple import torch @@ -125,13 +126,15 @@ def fit(self, context_train: torch.Tensor = None, x_val: torch.Tensor = None, w_val: torch.Tensor = None, - context_val: torch.Tensor = None): + context_val: torch.Tensor = None, + keep_best_weights: bool = True, + early_stopping: bool = False, + early_stopping_threshold: int = 50): """ Fit the normalizing flow. Fitting the flow means finding the parameters of the bijection that maximize the probability of training data. Bijection parameters are iteratively updated for a specified number of epochs. - If validation data is provided, we keep the bijection weights with the highest probability of validation data. If context data is provided, the normalizing flow learns the distribution of data conditional on context data. :param x_train: training data with shape (n_training_data, *event_shape). @@ -145,6 +148,9 @@ def fit(self, :param x_val: validation data with shape (n_validation_data, *event_shape). :param w_val: validation data weights with shape (n_validation_data,). :param context_val: validation data context tensor with shape (n_validation_data, *context_shape). + :param keep_best_weights: if True and validation data is provided, keep the bijection weights with the highest probability of validation data. + :param early_stopping: if True and validation data is provided, stop the training procedure early once validation loss stops improving for a specified number of consecutive epochs. + :param early_stopping_threshold: if early_stopping is True, fitting stops after no improvement in validation loss for this many epochs. """ # Compute the number of event dimensions n_event_dims = int(torch.prod(torch.as_tensor(self.bijection.event_shape))) @@ -174,6 +180,10 @@ def fit(self, shuffle=shuffle ) + best_val_loss = torch.inf + best_epoch = 0 + best_weights = deepcopy(self.state_dict()) + def compute_batch_loss(batch_, reduction: callable = torch.mean): batch_x, batch_weights = batch_[:2] batch_context = batch_[2] if len(batch_) == 3 else None @@ -188,7 +198,8 @@ def compute_batch_loss(batch_, reduction: callable = torch.mean): iterator = tqdm(range(n_epochs), desc='Fitting NF', disable=not show_progress) optimizer = torch.optim.AdamW(self.parameters(), lr=lr) val_loss = None - for _ in iterator: + + for epoch in iterator: for train_batch in train_loader: optimizer.zero_grad() train_loss = compute_batch_loss(train_batch, reduction=torch.mean) @@ -210,6 +221,7 @@ def compute_batch_loss(batch_, reduction: callable = torch.mean): # Validation loss will be displayed at the start of the next epoch if x_val is not None: with torch.no_grad(): + # Compute validation loss val_loss = 0.0 for val_batch in val_loader: n_batch_data = len(val_batch[0]) @@ -217,6 +229,24 @@ def compute_batch_loss(batch_, reduction: callable = torch.mean): if hasattr(self.bijection, 'regularization'): val_loss += self.bijection.regularization() + # Check if validation loss is the lowest so far + if val_loss < best_val_loss: + best_val_loss = val_loss + best_epoch = epoch + + # Store current weights + if keep_best_weights: + if best_epoch == epoch: + best_weights = deepcopy(self.state_dict()) + + # Optionally stop training early + if early_stopping: + if epoch - best_epoch > early_stopping_threshold: + break + + if x_val is not None and keep_best_weights: + self.load_state_dict(best_weights) + def variational_fit(self, target, n_epochs: int = 10,