diff --git a/normalizing_flows/flows.py b/normalizing_flows/flows.py index a177ce9..e43a568 100644 --- a/normalizing_flows/flows.py +++ b/normalizing_flows/flows.py @@ -263,6 +263,12 @@ def variational_fit(self, :param float n_samples: number of samples to estimate the variational loss in each training step. :param bool show_progress: if True, show a progress bar during training. """ + if len(list(self.parameters())) == 0: + # If the flow has no trainable parameters, do nothing + return + + self.train() + optimizer = torch.optim.AdamW(self.parameters(), lr=lr) best_loss = torch.inf best_epoch = 0 @@ -290,6 +296,8 @@ def variational_fit(self, if keep_best_weights: self.load_state_dict(best_weights) + self.eval() + class Flow(BaseFlow): """