Add option to check for divergences in variational fit
davidnabergoj committed Aug 24, 2024
1 parent 5bb8bb3 commit f65fd17
Showing 1 changed file with 44 additions and 11 deletions.
55 changes: 44 additions & 11 deletions torchflows/flows.py
@@ -256,7 +256,8 @@ def variational_fit(self,
                         early_stopping: bool = False,
                         early_stopping_threshold: int = 50,
                         keep_best_weights: bool = True,
-                        show_progress: bool = False):
+                        show_progress: bool = False,
+                        check_for_divergences: bool = False):
         """Train the normalizing flow to fit a target log probability.

         Stochastic variational inference lets us train a distribution using the unnormalized target log density instead of a fixed dataset.
@@ -276,31 +277,63 @@ def variational_fit(self,

         self.train()

+        flow_training_diverged = False
         optimizer = torch.optim.AdamW(self.parameters(), lr=lr)
         best_loss = torch.inf
         best_epoch = 0
+        initial_weights = deepcopy(self.state_dict())
         best_weights = deepcopy(self.state_dict())
+        n_divergences = 0

         for epoch in (pbar := tqdm(range(n_epochs), desc='Fitting with SVI', disable=not show_progress)):
+            if check_for_divergences and not all([torch.isfinite(p).all() for p in self.parameters()]):
+                flow_training_diverged = True
+                print('Flow training diverged')
+                print('Reverting to initial weights')
+                break
+
             optimizer.zero_grad()
             flow_x, flow_log_prob = self.sample(n_samples, return_log_prob=True)
-            loss = -torch.mean(target_log_prob(flow_x) + flow_log_prob)
+            target_log_prob_value = target_log_prob(flow_x)
+            loss = -torch.mean(target_log_prob_value + flow_log_prob)
             loss += self.regularization()
-            loss.backward()
-            optimizer.step()

-            if loss < best_loss:
-                best_loss = loss
-                best_epoch = epoch
-                if keep_best_weights:
-                    best_weights = deepcopy(self.state_dict())
+            epoch_diverged = False
+            if check_for_divergences:
+                if not torch.isfinite(loss):
+                    epoch_diverged = True
+                if torch.max(torch.abs(flow_x)) > 1e8:
+                    epoch_diverged = True
+                elif torch.max(torch.abs(flow_log_prob)) > 1e6:
+                    epoch_diverged = True
+                elif torch.any(~torch.isfinite(flow_x)):
+                    epoch_diverged = True
+                elif torch.any(~torch.isfinite(flow_log_prob)):
+                    epoch_diverged = True
+            n_divergences += epoch_diverged
+
+            if not epoch_diverged:
+                loss.backward()
+                optimizer.step()
+                if loss < best_loss:
+                    best_loss = loss
+                    best_epoch = epoch
+                    if keep_best_weights:
+                        best_weights = deepcopy(self.state_dict())
+            else:
+                loss = torch.nan

-            pbar.set_postfix_str(f'Loss: {loss:.4f} [best: {best_loss:.4f} @ {best_epoch}]')
+            pbar.set_postfix_str(f'Loss: {loss:.4f} [best: {best_loss:.4f} @ {best_epoch}], '
+                                 f'divergences: {n_divergences}, '
+                                 f'flow log_prob: {flow_log_prob.mean():.2f}, '
+                                 f'target log_prob: {target_log_prob_value.mean():.2f}')

             if epoch - best_epoch > early_stopping_threshold and early_stopping:
                 break

-        if keep_best_weights:
+        if flow_training_diverged:
+            self.load_state_dict(initial_weights)
+        elif keep_best_weights:
             self.load_state_dict(best_weights)

         self.eval()
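For reference, a minimal usage sketch of the new option (not part of this commit): it assumes the top-level Flow wrapper and the RealNVP architecture are importable as shown, and uses a toy standard-normal target log density in place of a real unnormalized posterior.

import torch

from torchflows import Flow
from torchflows.architectures import RealNVP

# Toy target (assumption): unnormalized log density of a standard normal in 10 dimensions.
def target_log_prob(x: torch.Tensor) -> torch.Tensor:
    return -0.5 * torch.sum(x ** 2, dim=-1)

torch.manual_seed(0)
flow = Flow(RealNVP(event_shape=(10,)))  # import paths and architecture choice are assumptions

# With check_for_divergences=True, epochs producing non-finite or extreme samples,
# log-probabilities, or losses are skipped and counted; if the flow parameters
# themselves become non-finite, training stops and the initial weights are restored.
flow.variational_fit(
    target_log_prob,
    n_epochs=500,
    n_samples=1000,
    show_progress=True,
    check_for_divergences=True
)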
