diff --git a/torchflows/flows.py b/torchflows/flows.py index 22bde3c..8bb31bb 100644 --- a/torchflows/flows.py +++ b/torchflows/flows.py @@ -79,7 +79,8 @@ def fit(self, context_val: torch.Tensor = None, keep_best_weights: bool = True, early_stopping: bool = False, - early_stopping_threshold: int = 50): + early_stopping_threshold: int = 50, + max_batch_size_mb: int = 2000): """Fit the normalizing flow to a dataset. Fitting the flow means finding the parameters of the bijection that maximize the probability of training data. @@ -100,6 +101,7 @@ def fit(self, :param keep_best_weights: if True and validation data is provided, keep the bijection weights with the highest probability of validation data. :param early_stopping: if True and validation data is provided, stop the training procedure early once validation loss stops improving for a specified number of consecutive epochs. :param early_stopping_threshold: if early_stopping is True, fitting stops after no improvement in validation loss for this many epochs. + :param int max_batch_size_mb: maximum batch size in megabytes. """ if len(list(self.parameters())) == 0: # If the flow has no trainable parameters, do nothing @@ -114,6 +116,10 @@ def fit(self, elif isinstance(batch_size, str) and batch_size == "adaptive": min_batch_size = max(32, min(1024, len(x_train) // 100)) max_batch_size = min(4096, len(x_train) // 10) + + event_size_mb = self.event_size / 2 ** 20 + max_batch_size = max(1, min(max_batch_size, int(max_batch_size_mb / event_size_mb))) + batch_size_adaptation_interval = 10 # double the batch size every 10 epochs adaptive_batch_size = True batch_size = min_batch_size @@ -290,42 +296,53 @@ def variational_fit(self, print('Flow training diverged') print('Reverting to initial weights') break - - optimizer.zero_grad() - flow_x, flow_log_prob = self.sample(n_samples, return_log_prob=True) - target_log_prob_value = target_log_prob(flow_x) - loss = -torch.mean(target_log_prob_value + flow_log_prob) - loss += self.regularization() - epoch_diverged = False - if check_for_divergences: - if not torch.isfinite(loss): - epoch_diverged = True - if torch.max(torch.abs(flow_x)) > 1e8: - epoch_diverged = True - elif torch.max(torch.abs(flow_log_prob)) > 1e6: - epoch_diverged = True - elif torch.any(~torch.isfinite(flow_x)): - epoch_diverged = True - elif torch.any(~torch.isfinite(flow_log_prob)): - epoch_diverged = True - n_divergences += epoch_diverged + optimizer.zero_grad() - if not epoch_diverged: - loss.backward() - optimizer.step() - if loss < best_loss: - best_loss = loss - best_epoch = epoch - if keep_best_weights: - best_weights = deepcopy(self.state_dict()) - else: + try: + flow_x, flow_log_prob = self.sample(n_samples, return_log_prob=True) + target_log_prob_value = target_log_prob(flow_x) + loss = -torch.mean(target_log_prob_value + flow_log_prob) + loss += self.regularization() + + if check_for_divergences: + if not torch.isfinite(loss): + epoch_diverged = True + if torch.max(torch.abs(flow_x)) > 1e8: + epoch_diverged = True + elif torch.max(torch.abs(flow_log_prob)) > 1e6: + epoch_diverged = True + elif torch.any(~torch.isfinite(flow_x)): + epoch_diverged = True + elif torch.any(~torch.isfinite(flow_log_prob)): + epoch_diverged = True + + if not epoch_diverged: + loss.backward() + optimizer.step() + if loss < best_loss: + best_loss = loss + best_epoch = epoch + if keep_best_weights: + best_weights = deepcopy(self.state_dict()) + mean_flow_log_prob = flow_log_prob.mean() + mean_target_log_prob = target_log_prob_value.mean() + else: + loss = torch.nan + mean_flow_log_prob = torch.nan + mean_target_log_prob = torch.nan + except ValueError: + epoch_diverged = True loss = torch.nan + mean_flow_log_prob = torch.nan + mean_target_log_prob = torch.nan + + n_divergences += epoch_diverged pbar.set_postfix_str(f'Loss: {loss:.4f} [best: {best_loss:.4f} @ {best_epoch}], ' f'divergences: {n_divergences}, ' - f'flow log_prob: {flow_log_prob.mean():.2f}, ' - f'target log_prob: {target_log_prob_value.mean():.2f}') + f'flow log_prob: {mean_flow_log_prob:.2f}, ' + f'target log_prob: {mean_target_log_prob:.2f}') if epoch - best_epoch > early_stopping_threshold and early_stopping: break