Skip to content

Commit

Permalink
Handle training=True for residual bijections in Flow.fit
Browse files Browse the repository at this point in the history
  • Loading branch information
davidnabergoj committed Dec 27, 2023
1 parent e82e79a commit 9624725
Showing 1 changed file with 4 additions and 0 deletions.
4 changes: 4 additions & 0 deletions normalizing_flows/flows.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,8 @@ def fit(self,
:param early_stopping: if True and validation data is provided, stop the training procedure early once validation loss stops improving for a specified number of consecutive epochs.
:param early_stopping_threshold: if early_stopping is True, fitting stops after no improvement in validation loss for this many epochs.
"""
self.bijection.train()

# Compute the number of event dimensions
n_event_dims = int(torch.prod(torch.as_tensor(self.bijection.event_shape)))

Expand Down Expand Up @@ -249,6 +251,8 @@ def compute_batch_loss(batch_, reduction: callable = torch.mean):
if x_val is not None and keep_best_weights:
self.load_state_dict(best_weights)

self.bijection.eval()

def variational_fit(self,
target,
n_epochs: int = 10,
Expand Down

0 comments on commit 9624725

Please sign in to comment.