Skip to content

Commit

Permalink
Skip variational fit if the normalizing flow (NF) has no trainable parameters
Browse files Browse the repository at this point in the history
  • Loading branch information
davidnabergoj committed Aug 9, 2024
1 parent bb6f31f commit 03fe60f
Showing 1 changed file with 8 additions and 0 deletions.
8 changes: 8 additions & 0 deletions normalizing_flows/flows.py
Original file line number Diff line number Diff line change
Expand Up @@ -263,6 +263,12 @@ def variational_fit(self,
:param float n_samples: number of samples to estimate the variational loss in each training step.
:param bool show_progress: if True, show a progress bar during training.
"""
if len(list(self.parameters())) == 0:
# If the flow has no trainable parameters, do nothing
return

self.train()

optimizer = torch.optim.AdamW(self.parameters(), lr=lr)
best_loss = torch.inf
best_epoch = 0
Expand Down Expand Up @@ -290,6 +296,8 @@ def variational_fit(self,
if keep_best_weights:
self.load_state_dict(best_weights)

self.eval()


class Flow(BaseFlow):
"""
Expand Down

0 comments on commit 03fe60f

Please sign in to comment.