-
Notifications
You must be signed in to change notification settings - Fork 77
/
holes.py
31 lines (28 loc) · 902 Bytes
/
holes.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
# Example from README
from topologylayer.nn import AlphaLayer, BarcodePolyFeature
import torch, numpy as np, matplotlib.pyplot as plt
# Random point cloud: 100 points drawn uniformly from [0, 1)^2, seeded so the
# example is reproducible.
np.random.seed(0)
data = np.random.rand(100, 2)

# Optimization to increase the size of holes.
# NOTE(review): AlphaLayer(maxdim=1) presumably computes persistence up to
# homology dimension 1 from an alpha-complex filtration of the points --
# confirm against the topologylayer documentation.
layer = AlphaLayer(maxdim=1)
# Leaf tensor of point coordinates that the optimizer moves. It is built
# directly with requires_grad=True because torch.autograd.Variable is
# deprecated (it now just returns a plain Tensor). The result is float32,
# exactly as the old `.type(torch.float)` cast produced.
x = torch.tensor(data, dtype=torch.float, requires_grad=True)
# NOTE(review): the arguments (1, 2, 0) presumably mean (dim, p, q) of a
# polynomial barcode feature on the dimension-1 bars -- TODO confirm.
f1 = BarcodePolyFeature(1, 2, 0)
optimizer = torch.optim.Adam([x], lr=1e-2)
for _ in range(100):
    optimizer.zero_grad()
    # Negate the feature: minimizing the loss maximizes f1, which is what
    # "increase size of holes" asks for.
    loss = -f1(layer(x))
    loss.backward()
    optimizer.step()
# Save a side-by-side figure: the original point cloud ("Before") and the
# optimized one ("After").
# detach() drops the autograd graph; cpu() makes numpy() safe even if x
# ever lives on another device.
y = x.detach().cpu().numpy()
fig, ax = plt.subplots(ncols=2, figsize=(10, 5))
for axis, points, title in zip(ax, (data, y), ("Before", "After")):
    axis.scatter(points[:, 0], points[:, 1])
    axis.set_title(title)
    # Hide tick labels and tick marks; only the point layout matters here.
    axis.set_yticklabels([])
    axis.set_xticklabels([])
    axis.tick_params(bottom=False, left=False)
fig.savefig('holes.png')
# Release the figure once it has been written to disk.
plt.close(fig)