-
Notifications
You must be signed in to change notification settings - Fork 14
/
interpolation_comparison.py
69 lines (42 loc) · 1.68 KB
/
interpolation_comparison.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
import numpy as np
from matplotlib import pyplot as plt
import src.utilities as U
from latent_plots import get_models
def create_interpolation_dataset(x, y, encoder, decoder, n=100, t=0.5):
    """Build a dataset of pixel-space and latent-space interpolations.

    Picks ``n`` random samples from ``x`` and pairs each with a random
    partner whose class (``argmax`` of its ``y`` row) differs. Each pair is
    then blended two ways:

    * in pixel space: ``t * x_1 + (1 - t) * x_2``;
    * in latent space: ``decoder.predict(t * z_1 + (1 - t) * z_2)`` with
      ``z_i = encoder.predict(x_i)``.

    Args:
        x: Array of inputs; the first axis indexes samples. Any rank is
            supported.
        y: Label array aligned with ``x``; the class of a row is its
            ``argmax`` along axis 1 (presumably one-hot — confirm with caller).
        encoder: Object whose ``predict`` maps a batch of inputs to a batch
            of latent codes.
        decoder: Object whose ``predict`` maps a batch of latent codes to a
            batch with the same trailing shape as ``x`` (required so the two
            halves can be concatenated).
        n: Number of pairs to generate; must satisfy ``0 <= n <= len(x)``.
        t: Interpolation weight on the first sample of each pair
            (default 0.5, the midpoint).

    Returns:
        ``(xs, labels)``: ``xs`` holds the ``n`` pixel-space interpolations
        followed by the ``n`` latent-space ones; ``labels`` is 1.0 for the
        pixel-space rows and 0.0 for the latent-space rows.

    Raises:
        ValueError: if ``n`` is out of range, or a chosen sample has no
            partner of a different class.
    """
    num_samples = x.shape[0]
    if n < 0 or n > num_samples:
        raise ValueError(
            f"n={n} must be between 0 and the number of samples ({num_samples})")

    # Index only the first n of the permutation instead of copying all of x.
    chosen = np.random.permutation(num_samples)[:n]
    x_1 = x[chosen]

    # Class of every sample, computed once rather than once per pair.
    classes = y.argmax(axis=1)
    # Cache, per class, the ordered indices of samples with a *different*
    # class, so the O(N) mask is built once per class instead of per pair.
    partners = {}
    partner_idx = np.empty(n, dtype=int)
    for i, c in enumerate(classes[chosen]):
        if c not in partners:
            partners[c] = np.flatnonzero(classes != c)
        candidates = partners[c]
        if candidates.size == 0:
            raise ValueError(
                f"no sample with a class other than {c} to pair with")
        partner_idx[i] = candidates[np.random.randint(candidates.size)]
    x_2 = x[partner_idx]

    # Weights are a float64 column shaped (n, 1, ..., 1) so they broadcast
    # over every non-sample axis, whatever the rank of the array.
    w = np.full((n,) + (1,) * (np.ndim(x_1) - 1), t, dtype=float)
    pixel_interp = w * x_1 + (1 - w) * x_2

    z_1 = encoder.predict(x_1)
    z_2 = encoder.predict(x_2)
    w = np.full((n,) + (1,) * (np.ndim(z_1) - 1), t, dtype=float)
    latent_interp = decoder.predict(w * z_1 + (1 - w) * z_2)

    labels = np.concatenate([np.ones(n), np.zeros(n)])
    xs = np.concatenate([pixel_interp, latent_interp], axis=0)
    return xs, labels
if __name__ == '__main__':
    model, encoder, decoder = get_models()

    # Interpolate between random pairs of differently-labelled MNIST samples,
    # once in pixel space (label 1) and once in latent space (label 0).
    _, _, mnist, label = U.get_mnist()
    x, y = create_interpolation_dataset(mnist, label, encoder, decoder, n=3000)
    _, entropy, mi = model.get_results(x)

    fig, ax = plt.subplots(2, 1)
    panels = ((ax[0], entropy, 'Entropy'),
              (ax[1], mi, 'Mutual information'))
    for axis, values, title in panels:
        # Share one set of bin edges between both series in a panel; with
        # independent default bins the overlaid bars would not line up and
        # the distributions could not be compared directly.
        bins = np.histogram_bin_edges(values, bins='auto')
        axis.hist(values[y == 0], bins=bins, color='r', alpha=0.5,
                  label='Latent Space')
        axis.hist(values[y == 1], bins=bins, color='b', alpha=0.5,
                  label='Pixel Space')
        axis.set_title(title)
        axis.legend()
    fig.tight_layout()
    # NOTE(review): placeholder output path — set a real destination.
    plt.savefig('path-to-my-figure.png')
    plt.show()