Skip to content

Commit

Permalink
Add architectures.py, add flow speed test
Browse files Browse the repository at this point in the history
  • Loading branch information
davidnabergoj committed Nov 7, 2023
1 parent 3cdb1d0 commit 02ec3cb
Show file tree
Hide file tree
Showing 2 changed files with 112 additions and 0 deletions.
20 changes: 20 additions & 0 deletions normalizing_flows/architectures.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
from normalizing_flows.bijections.finite.autoregressive.architectures import (
NICE,
RealNVP,
MAF,
IAF,
CouplingRQNSF,
MaskedAutoregressiveRQNSF,
InverseAutoregressiveRQNSF,
CouplingLRS,
MaskedAutoregressiveLRS,
CouplingDSF,
UMNNMAF
)

from normalizing_flows.bijections.continuous.ddnf import DeepDiffeomorphicBijection
from normalizing_flows.bijections.continuous.rnode import RNODE
from normalizing_flows.bijections.continuous.ffjord import FFJORD
from normalizing_flows.bijections.continuous.otflow import OTFlow

from normalizing_flows.bijections.finite.residual.architectures import ResFlow, ProximalResFlow, InvertibleResNet
92 changes: 92 additions & 0 deletions speed_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
# Test the speed of standard NF operations

import torch
import timeit
import matplotlib.pyplot as plt

from normalizing_flows import Flow
from normalizing_flows.architectures import (
NICE,
RealNVP,
MAF,
IAF,
CouplingRQNSF,
MaskedAutoregressiveRQNSF,
InverseAutoregressiveRQNSF,
CouplingLRS,
MaskedAutoregressiveLRS,
CouplingDSF,
UMNNMAF,
DeepDiffeomorphicBijection,
RNODE,
FFJORD,
OTFlow,
ResFlow,
ProximalResFlow,
InvertibleResNet
)


def avg_eval_time(flow: Flow, n_repeats: int = 30):
total_time = timeit.timeit(lambda: flow.log_prob(x), number=n_repeats)
return total_time / n_repeats


def avg_sampling_time(flow: Flow, batch_size: int = 100, n_repeats: int = 30):
total_time = timeit.timeit(lambda: flow.sample(batch_size), number=n_repeats)
return total_time / n_repeats


if __name__ == '__main__':
torch.manual_seed(0)
batch_shape = (100,)
event_shape = (50,)
x = torch.randn(*batch_shape, *event_shape)

eval_times = {}
sample_times = {}
for bijection_class in [
NICE,
RealNVP,
MAF,
IAF,
CouplingRQNSF,
MaskedAutoregressiveRQNSF,
InverseAutoregressiveRQNSF,
CouplingLRS,
MaskedAutoregressiveLRS,
CouplingDSF,
# UMNNMAF, # Too slow
DeepDiffeomorphicBijection,
RNODE,
FFJORD,
OTFlow,
ResFlow,
ProximalResFlow,
InvertibleResNet
]:
f = Flow(bijection_class(event_shape))

name = bijection_class.__name__
e_avg = avg_eval_time(f)
s_avg = avg_sampling_time(f)

print(f'{name:<30}\t| e: {e_avg:.4f}\t| s: {s_avg:.4f}')
eval_times[name] = e_avg
sample_times[name] = s_avg

plt.figure()
plt.bar(list(eval_times.keys()), list(eval_times.values()))
plt.ylabel("log_prob time [s]")
plt.xlabel("Bijection")
plt.xticks(rotation=30)
plt.tight_layout()
plt.show()

plt.figure()
plt.bar(list(sample_times.keys()), list(sample_times.values()))
plt.ylabel("Sampling time [s]")
plt.xlabel("Bijection")
plt.xticks(rotation=30)
plt.tight_layout()
plt.show()

0 comments on commit 02ec3cb

Please sign in to comment.