From f0fc2d3c7faffab38e3cb15edf5183ed263397ab Mon Sep 17 00:00:00 2001
From: David Nabergoj
Date: Thu, 5 Oct 2023 10:58:26 +0200
Subject: [PATCH] Add progress bar

---
 normalizing_flows/flows.py | 16 +++++++++++-----
 1 file changed, 11 insertions(+), 5 deletions(-)

diff --git a/normalizing_flows/flows.py b/normalizing_flows/flows.py
index 9bb5453..4c28642 100644
--- a/normalizing_flows/flows.py
+++ b/normalizing_flows/flows.py
@@ -29,8 +29,6 @@ def base_sample(self, sample_shape):
         z = unflatten_event(z_flat, self.bijection.event_shape)
         return z
 
-
-
     def forward_with_log_prob(self, x: torch.Tensor, context: torch.Tensor = None):
         if context is not None:
             assert context.shape[0] == x.shape[0]
@@ -128,16 +126,24 @@ def variational_fit(self,
                         target,
                         n_epochs: int = 10,
                         lr: float = 0.01,
-                        n_samples: int = 1000):
+                        n_samples: int = 1000,
+                        show_progress: bool = False):
         # target must have a .sample method that takes as input the batch shape
         optimizer = torch.optim.AdamW(self.parameters(), lr=lr)
-        for i in range(n_epochs):
+        if show_progress:
+            iterator = tqdm(range(n_epochs), desc='Variational NF fit')
+        else:
+            iterator = range(n_epochs)
+        for i in iterator:
             x_train = target.sample((n_samples,)).to(self.loc.device)  # TODO context!
             optimizer.zero_grad()
             loss = -self.log_prob(x_train).mean()
             loss.backward()
             optimizer.step()
+            if show_progress:
+                iterator.set_postfix_str(f'loss: {float(loss):.4f}')
+
 
 class DDNF(Flow):
     """
@@ -158,7 +164,7 @@ def fit(self,
             shuffle: bool = True,
             show_progress: bool = False,
             w_train: torch.Tensor = None,
-            rec_err_coef:float=1.0):
+            rec_err_coef: float = 1.0):
         """
 
         :param x_train:
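
Usage note (not part of the patch): a minimal sketch of the new show_progress flag, assuming the Flow class from normalizing_flows/flows.py, a RealNVP bijection, and a target exposing a .sample(batch_shape) method; the bijection import path and constructor signature are assumptions beyond what the diff shows.

    import torch
    from normalizing_flows.flows import Flow
    from normalizing_flows.bijections import RealNVP  # assumed import path

    torch.manual_seed(0)

    # Any object with a .sample(batch_shape) method works as the target,
    # e.g. a torch.distributions instance over 2-dimensional events.
    target = torch.distributions.MultivariateNormal(torch.zeros(2), torch.eye(2))

    flow = Flow(RealNVP(event_shape=(2,)))  # assumed constructor signature
    flow.variational_fit(target, n_epochs=50, show_progress=True)
    # With show_progress=True, a tqdm bar titled 'Variational NF fit' is shown,
    # updating the running loss as a postfix once per epoch.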