Add early stopping to variational_fit
davidnabergoj committed Jul 8, 2024
1 parent 5e430ea · commit ee6afea
Showing 1 changed file with 25 additions and 4 deletions.
normalizing_flows/flows.py (25 additions, 4 deletions)
@@ -190,6 +190,9 @@ def variational_fit(self,
                         n_epochs: int = 500,
                         lr: float = 0.05,
                         n_samples: int = 1000,
+                        early_stopping: bool = False,
+                        early_stopping_threshold: int = 50,
+                        keep_best_weights: bool = True,
                         show_progress: bool = False):
         """
         Train a distribution with stochastic variational inference.
@@ -210,15 +213,31 @@
         """
         iterator = tqdm(range(n_epochs), desc='Fitting with SVI', disable=not show_progress)
         optimizer = torch.optim.AdamW(self.parameters(), lr=lr)
+        best_loss = torch.inf
+        best_epoch = 0
+        best_weights = deepcopy(self.state_dict())
 
-        for _ in iterator:
+        for epoch in iterator:
             optimizer.zero_grad()
             flow_x, flow_log_prob = self.sample(n_samples, return_log_prob=True)
             loss = -torch.mean(target_log_prob(flow_x) + flow_log_prob)
             loss += self.regularization()
             loss.backward()
             optimizer.step()
-            iterator.set_postfix_str(f'Loss: {loss:.4f}')
+
+            if loss < best_loss:
+                best_loss = loss
+                best_epoch = epoch
+                if keep_best_weights:
+                    best_weights = deepcopy(self.state_dict())
+
+            iterator.set_postfix_str(f'Loss: {loss:.4f} [best: {best_loss:.4f} @ {best_epoch}]')
+
+            if epoch - best_epoch > early_stopping_threshold and early_stopping:
+                break
+
+        if keep_best_weights:
+            self.load_state_dict(best_weights)
 
 
 class Flow(BaseFlow):
@@ -288,9 +307,11 @@ def sample(self, n: int, context: torch.Tensor = None, no_grad: bool = False, re
         if no_grad:
             z = z.detach()
             with torch.no_grad():
-                x, log_det = self.bijection.inverse(z.view(*sample_shape, *self.bijection.transformed_shape), context=context)
+                x, log_det = self.bijection.inverse(z.view(*sample_shape, *self.bijection.transformed_shape),
+                                                    context=context)
         else:
-            x, log_det = self.bijection.inverse(z.view(*sample_shape, *self.bijection.transformed_shape), context=context)
+            x, log_det = self.bijection.inverse(z.view(*sample_shape, *self.bijection.transformed_shape),
+                                                context=context)
         x = x.to(self.get_device())
 
         if return_log_prob:
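
For reference, below is a minimal usage sketch of the new early-stopping options. It rests on assumptions beyond this diff: MyBijection is a placeholder for whichever bijection class the library provides (the diff only shows that the flow wraps self.bijection), and passing target_log_prob as the first positional argument is inferred from its use inside the training loop.

import torch
from normalizing_flows.flows import Flow

# Placeholder setup: MyBijection stands in for a bijection class from the
# library; it is not part of this commit.
flow = Flow(MyBijection(event_shape=(2,)))

# Illustrative target: unnormalized log-density of a standard Gaussian.
def target_log_prob(x: torch.Tensor) -> torch.Tensor:
    return -0.5 * x.pow(2).sum(dim=-1)

# New keyword arguments from this commit: stop once the best loss has gone
# early_stopping_threshold consecutive epochs without improving, and restore
# the best-performing weights afterwards.
flow.variational_fit(
    target_log_prob,
    n_epochs=500,
    lr=0.05,
    early_stopping=True,
    early_stopping_threshold=50,
    keep_best_weights=True,
    show_progress=True,
)

Note that keep_best_weights acts independently of early_stopping: even when training runs for all n_epochs, the weights with the lowest observed loss are loaded at the end.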
