
Commit

Remove separate DDNF class
davidnabergoj committed Feb 9, 2024
1 parent df810da commit 92b2518
Showing 1 changed file with 0 additions and 74 deletions.
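With this commit the DDNF wrapper below is removed. Its constructor only wrapped a DeepDiffeomorphicBijection in the generic Flow class, so an equivalent model can presumably still be built directly from those two components. A minimal sketch, assuming both classes remain importable from the package (the exact import path for the bijection is an assumption):

import torch
from normalizing_flows.flows import Flow
from normalizing_flows.bijections import DeepDiffeomorphicBijection  # import path is an assumption

# Mirror what the removed DDNF.__init__ did: wrap the bijection in the generic Flow class.
event_shape = torch.Size((2,))
flow = Flow(DeepDiffeomorphicBijection(event_shape=event_shape))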
74 changes: 0 additions & 74 deletions normalizing_flows/flows.py
@@ -291,77 +291,3 @@ def variational_fit(self,
            loss.backward()
            optimizer.step()
            iterator.set_postfix_str(f'Variational loss: {loss:.4f}')


class DDNF(Flow):
    """
    Deep diffeomorphic normalizing flow.
    Salman et al. Deep diffeomorphic normalizing flows (2018).
    """

    def __init__(self, event_shape: torch.Size, **kwargs):
        bijection = DeepDiffeomorphicBijection(event_shape=event_shape, **kwargs)
        super().__init__(bijection)

    def fit(self,
            x_train: torch.Tensor,
            n_epochs: int = 500,
            lr: float = 0.05,
            batch_size: int = 1024,
            shuffle: bool = True,
            show_progress: bool = False,
            w_train: torch.Tensor = None,
            rec_err_coef: float = 1.0):
        """
        :param x_train: training data with shape (*batch_shape, *event_shape).
        :param n_epochs: number of training epochs.
        :param lr: learning rate. In general, lower learning rates are recommended for bijections with many parameters.
        :param batch_size: training batch size.
        :param shuffle: if True, shuffle the training data every epoch.
        :param show_progress: if True, display a progress bar during training.
        :param w_train: training data weights.
        :param rec_err_coef: reconstruction error regularization coefficient.
        """
        if w_train is None:
            batch_shape = get_batch_shape(x_train, self.bijection.event_shape)
            w_train = torch.ones(batch_shape)
        if batch_size is None:
            batch_size = len(x_train)
        optimizer = torch.optim.AdamW(self.parameters(), lr=lr)
        dataset = TensorDataset(x_train, w_train)
        data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=shuffle)

        n_event_dims = int(torch.prod(torch.as_tensor(self.bijection.event_shape)))

        if show_progress:
            iterator = tqdm(range(n_epochs), desc='Fitting NF')
        else:
            iterator = range(n_epochs)

        for _ in iterator:
            for batch_x, batch_w in data_loader:
                optimizer.zero_grad()

                z, log_prob = self.forward_with_log_prob(batch_x.to(self.loc))  # TODO context!
                w = batch_w.to(self.loc)
                assert log_prob.shape == w.shape
                loss = -torch.mean(log_prob * w) / n_event_dims

                if hasattr(self.bijection, 'regularization'):
                    # Always true for DeepDiffeomorphicBijection, but we keep it for clarity
                    loss += self.bijection.regularization()

                # Inverse consistency regularization
                x_reconstructed = self.bijection.inverse(z)
                loss += reconstruction_error(batch_x, x_reconstructed, self.bijection.event_shape, rec_err_coef)

                # Geodesic regularization

                loss.backward()
                optimizer.step()

            if show_progress:
                iterator.set_postfix_str(f'Loss: {loss:.4f}')
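The removed fit loop adds an inverse-consistency penalty through a reconstruction_error helper that is not part of this diff. A minimal sketch of what such a helper might compute, assuming a coefficient-scaled mean squared error over flattened event dimensions (the actual helper in the repository may differ):

import torch


def reconstruction_error(x: torch.Tensor,
                         x_reconstructed: torch.Tensor,
                         event_shape: torch.Size,
                         coefficient: float = 1.0) -> torch.Tensor:
    # Flatten the event dimensions, average the squared reconstruction gap over
    # batch and event dimensions, then scale by the regularization coefficient.
    n_event_dims = int(torch.prod(torch.as_tensor(event_shape)))
    difference = (x - x_reconstructed).reshape(-1, n_event_dims)
    return coefficient * torch.mean(difference ** 2)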
