-
Notifications
You must be signed in to change notification settings - Fork 0
/
prior_plot.py
88 lines (68 loc) · 2.06 KB
/
prior_plot.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import strlearn as sl
from scipy.ndimage import gaussian_filter1d

import config
# Experiment size constants.
reps, chunks, pe = 10, 100, 3  # repetitions, chunks per stream, `pe` (not referenced elsewhere in this file)

# Settings pulled from the project-local `config` module.
weights = config.str_weights()
weigths_names = config.str_weights_names()
borders = config.borders()
criteria = config.criteria()
base_clfs = config.base_clfs()
base_clfs_names = config.base_clf_names()

# Seed the global NumPy RNG, then draw one seed per repetition.
np.random.seed(1231)
random_states = np.random.randint(low=0, high=100000, size=reps)

# Stream parameters shared by every generated stream.
str_static = config.str_static()
# Generate one stream per (repetition seed, weight setting) and record, for
# every chunk, the fraction of samples labelled 0.
# Row order of `priors` is: repetition-major, weight-setting-minor.
priors = []
for r in random_states:
    for w in weights:
        # Named `stream_config` so the imported `config` module is not shadowed.
        stream_config = {
            **str_static,
            **weights[w],
            # Each repetition must use its own seed; reusing random_states[0]
            # made every repetition generate the exact same stream.
            'random_state': r,
        }
        stream_config['n_chunks'] = chunks
        stream = sl.streams.StreamGenerator(**stream_config)
        # get_chunk()[1] is treated as the chunk's label vector (the original
        # code ran np.unique on it and compared against class 0).
        # np.mean(labels == 0) is the exact class-0 fraction for any chunk
        # size, replacing the hard-coded division by 200; it also yields 1.0 /
        # 0.0 for all-0 / all-non-0 chunks without special-casing.
        s_priors = [np.mean(stream.get_chunk()[1] == 0) for _ in range(chunks)]
        priors.append(s_priors)
# Stack the per-stream prior curves into (reps, n_weight_settings, chunks).
# The number of weight settings is taken from `weights` rather than hard-coded,
# so the script keeps working if the configuration changes.
priors = np.array(priors).reshape(reps, len(weights), chunks)
# Drop the first chunk of every stream.
# NOTE(review): the reason is not visible here — presumably the first chunk is a
# warm-up/initial-training chunk; confirm.
priors = priors[:, :, 1:]
print(priors.shape)
# Average over repetitions -> (n_weight_settings, chunks - 1).
mean_priors = np.mean(priors, axis=0)
# np.save('priors', mean_priors)
# Plot the mean prior curves: one panel per title in `t`, and within each
# panel one curve per line style. Row c = i * len(lss) + j of `mean_priors`
# (and entry c of `weigths_names`) is drawn in panel i with style j.
fig, axx = plt.subplots(3, 1, figsize=(6, 6 * 1.618), sharex=True, sharey=True)
axx = axx.ravel()
t = ['SIS', 'CDIS', 'DDIS']
lss = ['-', '--', ':']

# `mean_priors` keeps only the trailing chunks (leading chunks are sliced off
# before averaging), so map each column back to its real chunk id instead of
# plotting against the 0-based column index (which shifted curves by one).
chunk_ids = np.arange(chunks - mean_priors.shape[1], chunks)

for i, (ax, title) in enumerate(zip(axx, t)):
    ax.set_title(title)
    ax.set_ylabel('prior probability')
    ax.set_xlabel('chunk id')
    for j, ls in enumerate(lss):
        c = i * len(lss) + j
        ax.plot(chunk_ids, mean_priors[c], ls=ls, color='black',
                label=weigths_names[c])
    ax.legend(loc=1, frameon=True)
    ax.grid(ls=":")
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)
    ax.set_xlim(0, chunks)
    ax.set_ylim(0, 1)

fig.tight_layout()
# savefig does not create missing directories.
Path('figures').mkdir(exist_ok=True)
fig.savefig('figures/priors.png')
fig.savefig('figures/priors.eps')
# NOTE(review): looks like a leftover debug export; kept to preserve behavior.
fig.savefig('foo.png')