-
Notifications
You must be signed in to change notification settings - Fork 77
/
alg_comparison.py
62 lines (49 loc) · 1.33 KB
/
alg_comparison.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
from __future__ import print_function
from topologylayer.nn import LevelSetLayer1D
import matplotlib.pyplot as plt
import torch
import time
import numpy as np
def sum_finite(d):
    """Sum ``d[:, 0] - d[:, 1]`` over the rows where that difference is finite.

    Used as a scalar loss on a persistence diagram so that backward() can be
    timed. Rows with an infinite endpoint (essential bars) are excluded.

    Args:
        d: 2-column tensor of diagram points. Presumably column 0 is birth and
            column 1 is death (topologylayer convention) -- TODO confirm.

    Returns:
        0-dim tensor holding the sum of the finite differences (0 for an empty
        diagram). Differentiable with respect to ``d``.
    """
    diff = d[:, 0] - d[:, 1]
    # torch.isfinite rejects +inf, -inf and NaN. The earlier `diff < np.inf`
    # only rejected +inf (and NaN), so a bar whose difference is -inf leaked
    # into the sum and turned the whole result into -inf.
    return torch.sum(diff[torch.isfinite(diff)])
# Untimed warm-up: the first call to backward() reportedly carries a one-time
# setup overhead, so pay it here on a tiny input before any measurements start.
n = 20
y = torch.rand(n, dtype=torch.float, requires_grad=True)
layer1 = LevelSetLayer1D(n, False)
dgm, issublevel = layer1(y)
p = sum_finite(dgm[0])
p.backward()
# Algorithms to compare; each value is passed as LevelSetLayer1D's `alg` argument.
algs = ['hom', 'hom2', 'cohom']
# Per-algorithm timing logs (seconds), one entry appended per size in `ns`:
#   tcs -> layer construction, tfs -> forward pass, tbs -> backward pass.
# The comprehension evaluates `[]` once per key, so no lists are shared.
tcs = {alg: [] for alg in algs}
tfs = {alg: [] for alg in algs}
tbs = {alg: [] for alg in algs}
# Input sizes to sweep.
ns = [100, 200, 400, 1000, 2000, 4000, 8000, 16000]
# Time each algorithm at every size. Three phases are measured per (alg, n):
# layer construction -> tcs, forward pass -> tfs, backward pass -> tbs.
# perf_counter() is monotonic and high-resolution, unlike time.time(), which
# matters for the short timings at small n. Fall back to time.time() on
# interpreters that lack perf_counter.
_clock = getattr(time, "perf_counter", time.time)
for alg in algs:
    for n in ns:
        y = torch.rand(n, dtype=torch.float).requires_grad_(True)

        t0 = _clock()
        layer = LevelSetLayer1D(n, False, alg=alg)
        tcs[alg].append(_clock() - t0)

        t0 = _clock()
        dgm, issublevel = layer(y)
        tfs[alg].append(_clock() - t0)

        # Build the scalar loss outside the timed region so that only
        # backward() itself is measured.
        p = sum_finite(dgm[0])
        t0 = _clock()
        p.backward()
        tbs[alg].append(_clock() - t0)
# Draw forward-pass time against input size on log-log axes, one labelled curve
# per algorithm, then save the figure.
for alg_name in algs:
    plt.loglog(ns, tfs[alg_name], label=alg_name)
plt.xlabel("n")
plt.ylabel("forward time")
plt.legend()
plt.savefig("alg_time_forward.png")