From 962472566012489587adf2e25fb9dffbe89ca908 Mon Sep 17 00:00:00 2001 From: David Nabergoj Date: Wed, 27 Dec 2023 19:09:21 +0100 Subject: [PATCH] Handle training=True for residual bijections in Flow.fit --- normalizing_flows/flows.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/normalizing_flows/flows.py b/normalizing_flows/flows.py index 43f949b..4bf670b 100644 --- a/normalizing_flows/flows.py +++ b/normalizing_flows/flows.py @@ -152,6 +152,8 @@ def fit(self, :param early_stopping: if True and validation data is provided, stop the training procedure early once validation loss stops improving for a specified number of consecutive epochs. :param early_stopping_threshold: if early_stopping is True, fitting stops after no improvement in validation loss for this many epochs. """ + self.bijection.train() + # Compute the number of event dimensions n_event_dims = int(torch.prod(torch.as_tensor(self.bijection.event_shape))) @@ -249,6 +251,8 @@ def compute_batch_loss(batch_, reduction: callable = torch.mean): if x_val is not None and keep_best_weights: self.load_state_dict(best_weights) + self.bijection.eval() + def variational_fit(self, target, n_epochs: int = 10,