"""
make_fig18_fig19.py
Produces Figures 18 and 19 in O'Shaughnessy et al., 'Generative causal
explanations of black-box classifiers,' Proc. NeurIPS 2020: final
value of causal effect and data fidelity terms in objective for
various capacities of VAE.
Note: this script creates the file ./results/fig18.mat. The matlab script
make_fig18.m creates the final plots in the paper.
"""
import numpy as np
import torch
import util
import plotting
import matplotlib.pyplot as plt
from GCE import GenerativeCausalExplainer
import os
# --- parameters ---
# dataset
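# (Fashion-MNIST classes 0, 3, and 4: t-shirt/top, dress, and coat)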
data_classes = np.array([0,3,4])
# classifier
classifier_path = './pretrained_models/fmnist_034_classifier'
# vae
K = 2
L = 4
train_steps = 8000
Nalpha = 100
Nbeta = 25
lam = 0.05  # note: overridden below by the sweep over lambdas
batch_size = 32
lr = 1e-4
filts_per_layer = [4,8,16,32,48,64]
lambdas = np.logspace(-3,-1,10)
# other
randseed = 0
gce_path = './pretrained_models/vae_capacity'
retrain_gce = True
save_gce = True
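# The sweep below trains (or loads) one GCE per (capacity, lambda) pair:
# len(filts_per_layer) = 6 encoder/decoder widths times len(lambdas) = 10
# regularization weights. Set retrain_gce = False to load previously saved
# models from gce_path instead of retraining.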
# --- initialize ---
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
if randseed is not None:
    np.random.seed(randseed)
    torch.manual_seed(randseed)
# --- load data ---
from load_mnist import load_fashion_mnist_classSelect
X, Y, tridx = load_fashion_mnist_classSelect('train', data_classes,
                                              range(0, len(data_classes)))
vaX, vaY, vaidx = load_fashion_mnist_classSelect('val', data_classes,
                                                 range(0, len(data_classes)))
ntrain, nrow, ncol, c_dim = X.shape
x_dim = nrow*ncol
# --- load classifier ---
from models.CNN_classifier import CNN
classifier = CNN(len(data_classes)).to(device)
checkpoint = torch.load('%s/model.pt' % classifier_path,
                        map_location=device)
classifier.load_state_dict(checkpoint['model_state_dict_classifier'])
# --- initialize VAE and train GCE ---
from models.CVAE import Decoder, Encoder
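# Arrays filled in during the sweep: 'loss' is the total training objective at
# each step, 'loss_ce' the causal-effect term, and 'loss_nll' the data-fidelity
# (negative log-likelihood) term; 'Ijoint' and 'Is' record the final
# information flow through all K+L latent dimensions jointly and through each
# dimension individually.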
data = {
    'loss'     : np.zeros((len(filts_per_layer), len(lambdas), train_steps)),
    'loss_ce'  : np.zeros((len(filts_per_layer), len(lambdas), train_steps)),
    'loss_nll' : np.zeros((len(filts_per_layer), len(lambdas), train_steps)),
    'Ijoint'   : np.zeros((len(filts_per_layer), len(lambdas))),
    'Is'       : np.zeros((len(filts_per_layer), len(lambdas), K+L))}
for (i_f, nfilt) in enumerate(filts_per_layer):
    for (i_l, lam) in enumerate(lambdas):
        filename = 'model_%dfilters_lambda%g.pt' % (nfilt, lam)
        if retrain_gce:
            print('=== %d FILTERS PER LAYER, LAMBDA = %g ===' % (nfilt, lam))
            # initialize VAE
            encoder = Encoder(K+L, c_dim, x_dim,
                              filt_per_layer=nfilt).to(device)
            decoder = Decoder(K+L, c_dim, x_dim,
                              filt_per_layer=nfilt).to(device)
            encoder.apply(util.weights_init_normal)
            decoder.apply(util.weights_init_normal)
            # train GCE
            gce = GenerativeCausalExplainer(classifier, decoder, encoder,
                                            device, debug_print=False)
            traininfo = gce.train(X, K, L,
                                  steps=train_steps,
                                  Nalpha=Nalpha,
                                  Nbeta=Nbeta,
                                  lam=lam,
                                  batch_size=batch_size,
                                  lr=lr)
            if save_gce:
                if not os.path.exists(gce_path):
                    os.makedirs(gce_path)
                torch.save((gce, traininfo), os.path.join(gce_path, filename))
        else:  # load pretrained model
            gce, traininfo = torch.load(os.path.join(gce_path, filename))
        # get data
        gce.encoder.eval()
        gce.decoder.eval()
        torch.cuda.empty_cache()
        data['loss'][i_f,i_l,:] = traininfo['loss']
        data['loss_ce'][i_f,i_l,:] = traininfo['loss_ce']
        data['loss_nll'][i_f,i_l,:] = traininfo['loss_nll']
        data['Ijoint'][i_f,i_l] = gce.informationFlow()
        data['Is'][i_f,i_l,:] = gce.informationFlow_singledim(dims=range(K+L))
        # save figures for explanation
        sample_ind = np.concatenate((np.where(vaY == 0)[0][:3],
                                     np.where(vaY == 1)[0][:3],
                                     np.where(vaY == 2)[0][:2]))
        x = torch.from_numpy(vaX[sample_ind])
        zs_sweep = [-3., -2., -1., 0., 1., 2., 3.]
        Xhats, yhats = gce.explain(x, zs_sweep)
        if not os.path.exists('./figs/fig19/'):
            os.makedirs('./figs/fig19/')
        plotting.plotExplanation(1.-Xhats, yhats,
            save_path='./figs/fig19/%dfilters_lambda%g' % (nfilt, lam))
        plt.close('all')

# save all results to file
from scipy.io import savemat
if not os.path.exists('./results'):
    os.makedirs('./results')  # make sure ./results exists before saving
savemat('./results/fig18.mat', {'data' : data})
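
# The saved arrays can also be inspected directly in Python instead of with
# make_fig18.m. A minimal sketch (left commented out; assumes scipy >= 1.5
# for the simplify_cells option of loadmat):
#
#   from scipy.io import loadmat
#   results = loadmat('./results/fig18.mat', simplify_cells=True)['data']
#   print(results['Ijoint'].shape)   # (len(filts_per_layer), len(lambdas))
#   print(results['Is'].shape)       # (len(filts_per_layer), len(lambdas), K+L)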