From 6232722ea476c40daac852c864a102be68369ca9 Mon Sep 17 00:00:00 2001
From: David Nabergoj
Date: Tue, 3 Sep 2024 22:28:04 +0200
Subject: [PATCH] Add maximum training time option for NF training

---
 torchflows/flows.py | 20 ++++++++++++++++++--
 1 file changed, 18 insertions(+), 2 deletions(-)

diff --git a/torchflows/flows.py b/torchflows/flows.py
index 79fc656..0c0f5a3 100644
--- a/torchflows/flows.py
+++ b/torchflows/flows.py
@@ -1,3 +1,4 @@
+import time
 from copy import deepcopy
 from typing import Union, Tuple, List
 
@@ -80,7 +81,8 @@ def fit(self,
             keep_best_weights: bool = True,
             early_stopping: bool = False,
             early_stopping_threshold: int = 50,
-            max_batch_size_mb: int = None):
+            max_batch_size_mb: int = None,
+            time_limit_seconds: Union[float, int] = None):
         """Fit the normalizing flow to a dataset.
 
         Fitting the flow means finding the parameters of the bijection that maximize the probability of training data.
@@ -102,7 +104,10 @@ def fit(self,
         :param early_stopping: if True and validation data is provided, stop the training procedure early once validation loss stops improving for a specified number of consecutive epochs.
         :param early_stopping_threshold: if early_stopping is True, fitting stops after no improvement in validation loss for this many epochs.
         :param int max_batch_size_mb: maximum batch size in megabytes.
+        :param Union[float, int] time_limit_seconds: maximum allowed training time in seconds.
         """
+        t0 = time.time()
+
         if len(list(self.parameters())) == 0:
             # If the flow has no trainable parameters, do nothing
             return
@@ -167,6 +172,10 @@ def compute_batch_loss(batch_, reduction: callable = torch.mean):
         val_loss = None
 
         for epoch in (pbar := tqdm(range(n_epochs), desc='Fitting NF', disable=not show_progress)):
+            if time_limit_seconds is not None and time.time() - t0 >= time_limit_seconds:
+                print("Training time limit exceeded")
+                break
+
             if (
                     adaptive_batch_size
                     and epoch % batch_size_adaptation_interval == batch_size_adaptation_interval - 1
@@ -263,7 +272,8 @@ def variational_fit(self,
                         early_stopping_threshold: int = 50,
                         keep_best_weights: bool = True,
                         show_progress: bool = False,
-                        check_for_divergences: bool = False):
+                        check_for_divergences: bool = False,
+                        time_limit_seconds: Union[float, int] = None):
         """Train the normalizing flow to fit a target log probability.
 
         Stochastic variational inference lets us train a distribution using the unnormalized target log density instead of a fixed dataset.
@@ -277,6 +287,9 @@ def variational_fit(self,
         :param float n_samples: number of samples to estimate the variational loss in each training step.
         :param bool show_progress: if True, show a progress bar during training.
+        :param Union[float, int] time_limit_seconds: maximum allowed training time in seconds.
         """
+        t0 = time.time()
+
         if len(list(self.parameters())) == 0:
             # If the flow has no trainable parameters, do nothing
             return
@@ -292,6 +306,9 @@ def variational_fit(self,
         n_divergences = 0
 
         for epoch in (pbar := tqdm(range(n_epochs), desc='Fitting with SVI', disable=not show_progress)):
+            if time_limit_seconds is not None and time.time() - t0 >= time_limit_seconds:
+                print("Training time limit exceeded")
+                break
             if check_for_divergences and not all([torch.isfinite(p).all() for p in self.parameters()]):
                 flow_training_diverged = True
                 print('Flow training diverged')
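
Usage sketch for review (illustrative only, not part of the diff): the snippet below shows how the new time_limit_seconds option would be passed to both training entry points. It assumes the Flow and RealNVP names from the torchflows README, and that variational_fit takes the unnormalized target log density as its first argument, as its docstring suggests.

    import torch
    from torchflows import Flow
    from torchflows.architectures import RealNVP

    torch.manual_seed(0)
    event_shape = (10,)
    x = torch.randn(size=(1000, *event_shape))
    flow = Flow(RealNVP(event_shape))

    # Dataset fitting with a 30-second wall-clock budget; the check runs
    # at the start of each epoch, so training stops at the first epoch
    # boundary past the limit.
    flow.fit(x, show_progress=True, time_limit_seconds=30)

    # Variational fitting against an unnormalized target log density
    # (a standard Gaussian, up to an additive constant), same budget.
    flow.variational_fit(
        lambda z: -0.5 * z.square().sum(dim=-1),
        show_progress=True,
        time_limit_seconds=30,
    )

Because the budget is checked once per epoch, an epoch that runs longer than time_limit_seconds is not interrupted mid-epoch; the option bounds when a new epoch may start rather than total runtime.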