Skip to content

Commit

Permalink
Add maximum training time option for NF training
Browse files Browse the repository at this point in the history
  • Loading branch information
davidnabergoj committed Sep 3, 2024
1 parent d1f43a8 commit 6232722
Showing 1 changed file with 17 additions and 2 deletions.
19 changes: 17 additions & 2 deletions torchflows/flows.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import time
from copy import deepcopy
from typing import Union, Tuple, List

Expand Down Expand Up @@ -80,7 +81,8 @@ def fit(self,
keep_best_weights: bool = True,
early_stopping: bool = False,
early_stopping_threshold: int = 50,
max_batch_size_mb: int = None):
max_batch_size_mb: int = None,
time_limit_seconds: Union[float, int] = None):
"""Fit the normalizing flow to a dataset.
Fitting the flow means finding the parameters of the bijection that maximize the probability of training data.
Expand All @@ -102,7 +104,10 @@ def fit(self,
:param early_stopping: if True and validation data is provided, stop the training procedure early once validation loss stops improving for a specified number of consecutive epochs.
:param early_stopping_threshold: if early_stopping is True, fitting stops after no improvement in validation loss for this many epochs.
:param int max_batch_size_mb: maximum batch size in megabytes.
:param Union[float, int] time_limit_seconds: maximum allowed time for training in seconds.
"""
t0 = time.time()

if len(list(self.parameters())) == 0:
# If the flow has no trainable parameters, do nothing
return
Expand Down Expand Up @@ -167,6 +172,10 @@ def compute_batch_loss(batch_, reduction: callable = torch.mean):
val_loss = None

for epoch in (pbar := tqdm(range(n_epochs), desc='Fitting NF', disable=not show_progress)):
if time_limit_seconds is not None and time.time() - t0 >= time_limit_seconds:
print("Training time limit exceeded")
break

if (
adaptive_batch_size
and epoch % batch_size_adaptation_interval == batch_size_adaptation_interval - 1
Expand Down Expand Up @@ -263,7 +272,8 @@ def variational_fit(self,
early_stopping_threshold: int = 50,
keep_best_weights: bool = True,
show_progress: bool = False,
check_for_divergences: bool = False):
check_for_divergences: bool = False,
time_limit_seconds: Union[float, int] = None):
"""Train the normalizing flow to fit a target log probability.
Stochastic variational inference lets us train a distribution using the unnormalized target log density instead of a fixed dataset.
Expand All @@ -277,6 +287,8 @@ def variational_fit(self,
:param float n_samples: number of samples to estimate the variational loss in each training step.
:param bool show_progress: if True, show a progress bar during training.
"""
t0 = time.time()

if len(list(self.parameters())) == 0:
# If the flow has no trainable parameters, do nothing
return
Expand All @@ -292,6 +304,9 @@ def variational_fit(self,
n_divergences = 0

for epoch in (pbar := tqdm(range(n_epochs), desc='Fitting with SVI', disable=not show_progress)):
if time_limit_seconds is not None and time.time() - t0 >= time_limit_seconds:
print("Training time limit exceeded")
break
if check_for_divergences and not all([torch.isfinite(p).all() for p in self.parameters()]):
flow_training_diverged = True
print('Flow training diverged')
Expand Down

0 comments on commit 6232722

Please sign in to comment.