Skip to content

Commit

Permalink
Add progress bar
Browse files Browse the repository at this point in the history
  • Loading branch information
davidnabergoj committed Oct 5, 2023
1 parent d73111a commit f0fc2d3
Showing 1 changed file with 11 additions and 5 deletions.
16 changes: 11 additions & 5 deletions normalizing_flows/flows.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,6 @@ def base_sample(self, sample_shape):
z = unflatten_event(z_flat, self.bijection.event_shape)
return z



def forward_with_log_prob(self, x: torch.Tensor, context: torch.Tensor = None):
if context is not None:
assert context.shape[0] == x.shape[0]
Expand Down Expand Up @@ -128,16 +126,24 @@ def variational_fit(self,
target,
n_epochs: int = 10,
lr: float = 0.01,
n_samples: int = 1000):
n_samples: int = 1000,
show_progress: bool = False):
# target must have a .sample method that takes as input the batch shape
optimizer = torch.optim.AdamW(self.parameters(), lr=lr)
for i in range(n_epochs):
if show_progress:
iterator = tqdm(range(n_epochs), desc='Variational NF fit')
else:
iterator = range(n_epochs)
for i in iterator:
x_train = target.sample((n_samples,)).to(self.loc.device) # TODO context!
optimizer.zero_grad()
loss = -self.log_prob(x_train).mean()
loss.backward()
optimizer.step()

if show_progress:
iterator.set_postfix_str(f'loss: {float(loss):.4f}')


class DDNF(Flow):
"""
Expand All @@ -158,7 +164,7 @@ def fit(self,
shuffle: bool = True,
show_progress: bool = False,
w_train: torch.Tensor = None,
rec_err_coef:float=1.0):
rec_err_coef: float = 1.0):
"""
:param x_train:
Expand Down

0 comments on commit f0fc2d3

Please sign in to comment.