diff --git a/normalizing_flows/flows.py b/normalizing_flows/flows.py index 43f949b..4bf670b 100644 --- a/normalizing_flows/flows.py +++ b/normalizing_flows/flows.py @@ -152,6 +152,8 @@ def fit(self, :param early_stopping: if True and validation data is provided, stop the training procedure early once validation loss stops improving for a specified number of consecutive epochs. :param early_stopping_threshold: if early_stopping is True, fitting stops after no improvement in validation loss for this many epochs. """ + self.bijection.train() + # Compute the number of event dimensions n_event_dims = int(torch.prod(torch.as_tensor(self.bijection.event_shape))) @@ -249,6 +251,8 @@ def compute_batch_loss(batch_, reduction: callable = torch.mean): if x_val is not None and keep_best_weights: self.load_state_dict(best_weights) + self.bijection.eval() + def variational_fit(self, target, n_epochs: int = 10,