import numpy as np
import torch
"""
joint_uncond:
Sample-based estimate of "joint, unconditional" causal effect, -I(alpha; Yhat).
Inputs:
- params['Nalpha'] monte-carlo samples per causal factor
- params['Nbeta'] monte-carlo samples per noncausal factor
- params['K'] number of causal factors
- params['L'] number of noncausal factors
- params['M'] number of classes (dimensionality of classifier output)
- decoder
- classifier
- device
Outputs:
- negCausalEffect (sample-based estimate of -I(alpha; Yhat))
- info['xhat']
- info['yhat']
"""
def joint_uncond(params, decoder, classifier, device):
    eps = 1e-8
    I = 0.0
    q = torch.zeros(params['M']).to(device)                   # running estimate of p(y)
    for i in range(0, params['Nalpha']):
        alpha = np.random.randn(params['K'])                  # sample causal factors alpha
        zs = np.zeros((params['Nbeta'], params['z_dim']))
        for j in range(0, params['Nbeta']):
            beta = np.random.randn(params['L'])               # sample noncausal factors beta
            zs[j, :params['K']] = alpha
            zs[j, params['K']:] = beta
        # decode and classify batch of Nbeta samples with the same alpha
        xhat = decoder(torch.from_numpy(zs).float().to(device))
        yhat = classifier(xhat)[0]
        p = 1./float(params['Nbeta']) * torch.sum(yhat, 0)    # estimate of p(y|alpha)
        I = I + 1./float(params['Nalpha']) * torch.sum(torch.mul(p, torch.log(p+eps)))
        q = q + 1./float(params['Nalpha']) * p                # accumulate estimate of p(y)
    I = I - torch.sum(torch.mul(q, torch.log(q+eps)))         # I = -E[H(Yhat|alpha)] + H(Yhat) = I(alpha; Yhat)
    negCausalEffect = -I
    info = {"xhat": xhat, "yhat": yhat}
    return negCausalEffect, info
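

# A minimal usage sketch (not part of the original file): the toy decoder and
# classifier below are assumptions chosen only to illustrate the expected call
# pattern -- a decoder mapping a z_dim latent vector to a batch of data, and a
# classifier returning a tuple whose first element is a batch of class
# probabilities, matching the classifier(xhat)[0] indexing above.
def _example_joint_uncond():
    import torch.nn as nn
    params = {'Nalpha': 16, 'Nbeta': 8, 'K': 2, 'L': 2, 'M': 3, 'z_dim': 4}
    device = torch.device('cpu')
    decoder = nn.Linear(params['z_dim'], 10)                  # toy decoder: z -> "data"
    clf_net = nn.Sequential(nn.Linear(10, params['M']), nn.Softmax(dim=1))
    def classifier(x):
        return (clf_net(x),)                                  # mimic a (probabilities, ...) output
    neg_ce, info = joint_uncond(params, decoder, classifier, device)
    print(neg_ce.item(), info['xhat'].shape, info['yhat'].shape)
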
"""
joint_uncond_singledim:
Sample-based estimate of "joint, unconditional" causal effect
for single latent factor, -I(z_i; Yhat). Note the interpretation
of params['Nalpha'] and params['Nbeta'] here: Nalpha is the number
of samples of z_i, and Nbeta is the number of samples of the other
latent factors.
Inputs:
- params['Nalpha']
- params['Nbeta']
- params['K']
- params['L']
- params['M']
- decoder
- classifier
- device
- dim (i : compute -I(z_i; Yhat) **note: i is zero-indexed!**)
Outputs:
- negCausalEffect (sample-based estimate of -I(z_i; Yhat))
- info['xhat']
- info['yhat']
"""
def joint_uncond_singledim(params, decoder, classifier, device, dim):
    eps = 1e-8
    I = 0.0
    q = torch.zeros(params['M']).to(device)                   # running estimate of p(y)
    for i in range(0, params['Nalpha']):
        z_fix = np.random.randn()                             # sample the single factor z_i to hold fixed
        zs = np.zeros((params['Nbeta'], params['z_dim']))
        for j in range(0, params['Nbeta']):
            zs[j, :] = np.random.randn(params['K']+params['L'])  # sample all other latent factors
            zs[j, dim] = z_fix
        # decode and classify batch of Nbeta samples with the same z_i
        xhat = decoder(torch.from_numpy(zs).float().to(device))
        yhat = classifier(xhat)[0]
        p = 1./float(params['Nbeta']) * torch.sum(yhat, 0)    # estimate of p(y|z_i)
        I = I + 1./float(params['Nalpha']) * torch.sum(torch.mul(p, torch.log(p+eps)))
        q = q + 1./float(params['Nalpha']) * p                # accumulate estimate of p(y)
    I = I - torch.sum(torch.mul(q, torch.log(q+eps)))         # I = -E[H(Yhat|z_i)] + H(Yhat) = I(z_i; Yhat)
    negCausalEffect = -I
    info = {"xhat": xhat, "yhat": yhat}
    return negCausalEffect, info
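

# A minimal sketch of how the single-dimension estimator might be used (an
# assumption, not part of the original file): scan every latent dimension with
# joint_uncond_singledim and report how much information each single factor
# carries about the classifier output, reusing the same kind of toy decoder
# and classifier as in _example_joint_uncond above.
if __name__ == "__main__":
    import torch.nn as nn
    params = {'Nalpha': 16, 'Nbeta': 8, 'K': 2, 'L': 2, 'M': 3, 'z_dim': 4}
    device = torch.device('cpu')
    decoder = nn.Linear(params['z_dim'], 10)                  # toy decoder: z -> "data"
    clf_net = nn.Sequential(nn.Linear(10, params['M']), nn.Softmax(dim=1))
    def classifier(x):
        return (clf_net(x),)                                  # mimic a (probabilities, ...) output
    for d in range(params['z_dim']):
        neg_ce, _ = joint_uncond_singledim(params, decoder, classifier, device, d)
        print("dim %d: estimated I(z_%d; Yhat) ~ %.4f" % (d, d, -neg_ce.item()))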