-
Notifications
You must be signed in to change notification settings - Fork 0
/
E_MOA_2P.py
60 lines (43 loc) · 1.39 KB
/
E_MOA_2P.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
"""
Plot E2 - Visualize classification results for MOA streams
"""
import numpy as np
import matplotlib.pyplot as plt
import os
np.random.seed(1233)

# Complexity-measure groups: one heatmap row per entry (order must match axis 0 of the results).
measures = ["clustering",
            "complexity",
            "concept",
            "general",
            "info-theory",
            "itemset",
            "landmarking",
            "model-based",
            "statistical",
            ]

# Stream files are only listed for a sanity print. Hidden entries (e.g. macOS '.DS_Store')
# are filtered out rather than removed with list.remove(), which raised ValueError whenever
# that file was absent. listdir order is arbitrary, so sort for a deterministic printout.
streams = sorted(s for s in os.listdir('data/moa') if not s.startswith('.'))
print(streams)

# Base classifiers: one heatmap column per entry (order must match the last results axis).
base_clfs = ['GNB', 'KNN', 'SVM', 'DT', 'MLP']
# Stream generators, in the order assumed for axis 1 after the reshape below.
# NOTE(review): this order is hard-coded; confirm it matches how results/moa_clf.npy was written.
dataset_names = ['RBF', 'LED', 'HYPERPLANE', 'SEA']
n_drift_types = 3
stream_reps = 5  # not used below

res = np.load('results/moa_clf.npy')  # measures, datasets, reps, folds, clfs
print(res.shape)
# Regroup as (measures, generators, 3, 10, classifiers), then average over axes 2 and 3 to get one
# mean score per (measure, generator, classifier). Sizes are derived from the lists above so the
# reshape stays in sync with the plot labels (values are identical to the former 9,4,3,10,5).
# NOTE(review): axis 2 is presumably the drift type and axis 3 repetitions x folds -- confirm.
res = res.reshape(len(measures), len(dataset_names), n_drift_types, 10, len(base_clfs))
res_mean = np.mean(res, axis=(2, 3))

fig, ax = plt.subplots(2, 2, figsize=(8, 10), sharex=True, sharey=True)
ax = ax.ravel()
plt.suptitle('MOA', fontsize=18, y=0.99)
for dataset_id, dataset in enumerate(dataset_names):
    axx = ax[dataset_id]
    r = res_mean[:, dataset_id]  # (measures, classifiers) matrix for this generator
    axx.imshow(r, vmin=0.4, vmax=1., cmap='Blues')
    # Annotate each cell with its mean score; use white text on the darker (higher) cells for contrast.
    for row, col in np.ndindex(r.shape):
        score = r[row, col]
        axx.text(col, row, "%.3f" % score, va='center', ha='center',
                 c='black' if score < 0.75 else 'white', fontsize=11)
    axx.set_title(dataset)
    axx.set_xticks(np.arange(len(base_clfs)), base_clfs)
    axx.set_yticks(np.arange(len(measures)), measures)

plt.tight_layout()
# NOTE(review): 'foo.png' looks like a leftover debug copy in the working directory; kept to preserve behaviour.
plt.savefig('foo.png')
# savefig does not create missing directories, so make sure the target folder exists first.
os.makedirs('figures/fig_clf', exist_ok=True)
plt.savefig('figures/fig_clf/MOA.png')