From 03fe60fac49f93d6f40345cdb22ac59a27082be2 Mon Sep 17 00:00:00 2001 From: David Nabergoj Date: Fri, 9 Aug 2024 13:53:38 +0200 Subject: [PATCH] Skip variational fit if NF has no parameters --- normalizing_flows/flows.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/normalizing_flows/flows.py b/normalizing_flows/flows.py index a177ce9..e43a568 100644 --- a/normalizing_flows/flows.py +++ b/normalizing_flows/flows.py @@ -263,6 +263,12 @@ def variational_fit(self, :param float n_samples: number of samples to estimate the variational loss in each training step. :param bool show_progress: if True, show a progress bar during training. """ + if len(list(self.parameters())) == 0: + # If the flow has no trainable parameters, do nothing + return + + self.train() + optimizer = torch.optim.AdamW(self.parameters(), lr=lr) best_loss = torch.inf best_epoch = 0 @@ -290,6 +296,8 @@ def variational_fit(self, if keep_best_weights: self.load_state_dict(best_weights) + self.eval() + class Flow(BaseFlow): """