From 5a428b8814c757f3939e4aff3e8df355e5dad386 Mon Sep 17 00:00:00 2001
From: David Nabergoj
Date: Tue, 9 Jul 2024 01:44:50 +0200
Subject: [PATCH] Add best epoch logging to SVI, change default number of samples in SVI to 1

---
 normalizing_flows/flows.py | 7 ++++++-
 1 file changed, 6 insertions(+), 1 deletion(-)

diff --git a/normalizing_flows/flows.py b/normalizing_flows/flows.py
index 43eb35d..00aa946 100644
--- a/normalizing_flows/flows.py
+++ b/normalizing_flows/flows.py
@@ -148,6 +148,11 @@ def compute_batch_loss(batch_, reduction: callable = torch.mean):
             if show_progress:
                 if val_loss is None:
                     iterator.set_postfix_str(f'Training loss (batch): {train_loss:.4f}')
+                elif early_stopping:
+                    iterator.set_postfix_str(
+                        f'Training loss (batch): {train_loss:.4f}, '
+                        f'Validation loss: {val_loss:.4f} [best: {best_val_loss:.4f} @ {best_epoch}]'
+                    )
                 else:
                     iterator.set_postfix_str(
                         f'Training loss (batch): {train_loss:.4f}, '
@@ -189,7 +194,7 @@ def variational_fit(self,
                         target_log_prob: callable,
                         n_epochs: int = 500,
                         lr: float = 0.05,
-                        n_samples: int = 1000,
+                        n_samples: int = 1,
                         early_stopping: bool = False,
                         early_stopping_threshold: int = 50,
                         keep_best_weights: bool = True,
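
Note: below is a minimal usage sketch of variational_fit after this patch. The Flow wrapper, the RealNVP bijection, and the import paths are assumptions for illustration (they are not part of the patch); only the variational_fit keyword arguments shown are taken from the signature above.

import torch

# Assumed import paths for this repository; adjust if the package layout differs.
from normalizing_flows import Flow
from normalizing_flows.bijections import RealNVP

# Target: log-density of a 10-dimensional standard normal (illustrative choice).
target = torch.distributions.MultivariateNormal(torch.zeros(10), torch.eye(10))

def target_log_prob(x: torch.Tensor) -> torch.Tensor:
    return target.log_prob(x)

# Flow construction is assumed for illustration; any bijection from the package should work.
flow = Flow(RealNVP(event_shape=(10,)))

# After this patch, n_samples defaults to 1, so each SVI step draws a single
# sample from the flow; pass n_samples explicitly to restore the previous behaviour.
flow.variational_fit(target_log_prob, n_epochs=500, lr=0.05, early_stopping=True)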