diff --git a/normalizing_flows/flows.py b/normalizing_flows/flows.py index 43eb35d..00aa946 100644 --- a/normalizing_flows/flows.py +++ b/normalizing_flows/flows.py @@ -148,6 +148,11 @@ def compute_batch_loss(batch_, reduction: callable = torch.mean): if show_progress: if val_loss is None: iterator.set_postfix_str(f'Training loss (batch): {train_loss:.4f}') + elif early_stopping: + iterator.set_postfix_str( + f'Training loss (batch): {train_loss:.4f}, ' + f'Validation loss: {val_loss:.4f} [best: {best_val_loss:.4f} @ {best_epoch}]' + ) else: iterator.set_postfix_str( f'Training loss (batch): {train_loss:.4f}, ' @@ -189,7 +194,7 @@ def variational_fit(self, target_log_prob: callable, n_epochs: int = 500, lr: float = 0.05, - n_samples: int = 1000, + n_samples: int = 1, early_stopping: bool = False, early_stopping_threshold: int = 50, keep_best_weights: bool = True,