diff --git a/normalizing_flows/flows.py b/normalizing_flows/flows.py index f50b1b8..43eb35d 100644 --- a/normalizing_flows/flows.py +++ b/normalizing_flows/flows.py @@ -190,6 +190,9 @@ def variational_fit(self, n_epochs: int = 500, lr: float = 0.05, n_samples: int = 1000, + early_stopping: bool = False, + early_stopping_threshold: int = 50, + keep_best_weights: bool = True, show_progress: bool = False): """ Train a distribution with stochastic variational inference. @@ -210,15 +213,31 @@ def variational_fit(self, """ iterator = tqdm(range(n_epochs), desc='Fitting with SVI', disable=not show_progress) optimizer = torch.optim.AdamW(self.parameters(), lr=lr) + best_loss = torch.inf + best_epoch = 0 + best_weights = deepcopy(self.state_dict()) - for _ in iterator: + for epoch in iterator: optimizer.zero_grad() flow_x, flow_log_prob = self.sample(n_samples, return_log_prob=True) loss = -torch.mean(target_log_prob(flow_x) + flow_log_prob) loss += self.regularization() loss.backward() optimizer.step() - iterator.set_postfix_str(f'Loss: {loss:.4f}') + + if loss < best_loss: + best_loss = loss + best_epoch = epoch + if keep_best_weights: + best_weights = deepcopy(self.state_dict()) + + iterator.set_postfix_str(f'Loss: {loss:.4f} [best: {best_loss:.4f} @ {best_epoch}]') + + if epoch - best_epoch > early_stopping_threshold and early_stopping: + break + + if keep_best_weights: + self.load_state_dict(best_weights) class Flow(BaseFlow): @@ -288,9 +307,11 @@ def sample(self, n: int, context: torch.Tensor = None, no_grad: bool = False, re if no_grad: z = z.detach() with torch.no_grad(): - x, log_det = self.bijection.inverse(z.view(*sample_shape, *self.bijection.transformed_shape), context=context) + x, log_det = self.bijection.inverse(z.view(*sample_shape, *self.bijection.transformed_shape), + context=context) else: - x, log_det = self.bijection.inverse(z.view(*sample_shape, *self.bijection.transformed_shape), context=context) + x, log_det = self.bijection.inverse(z.view(*sample_shape, *self.bijection.transformed_shape), + context=context) x = x.to(self.get_device()) if return_log_prob: