forked from picopalette/phishing-detection-plugin
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathdump.py
48 lines (34 loc) · 1.27 KB
/
dump.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
# coding: utf-8
# In[9]:
from sklearn.tree import _tree
# In[10]:
def tree_to_json(tree, feature_names=None):
    """Convert a fitted scikit-learn tree estimator into a nested dict.

    Internal (split) nodes become::

        {'type': 'split', 'threshold': '<feature> <= <value>',
         'left': <subtree>, 'right': <subtree>}

    and leaves become ``{'type': 'leaf', 'value': <list>}``, where ``value``
    is the node's entry of ``tree.tree_.value`` converted with ``tolist()``.
    Its exact meaning depends on the estimator type and scikit-learn version.
    Confirm against the installed version if consumers rely on it.

    Args:
        tree: A fitted estimator exposing the low-level ``tree_`` structure,
            e.g. a decision tree or one element of ``forest.estimators_``.
        feature_names: Optional indexable sequence with one display name per
            feature index. Defaults to the integer indices
            ``0 .. tree_.n_features - 1``.

    Returns:
        dict: JSON-serialisable representation of the tree, rooted at node 0.
    """
    tree_ = tree.tree_
    if feature_names is None:
        # Previously hard-coded to range(30), which raised IndexError for any
        # tree fit on more than 30 features; derive the count from the tree.
        feature_names = range(tree_.n_features)
    # tree_.feature has one entry per node. Leaves carry TREE_UNDEFINED there,
    # so they must not be used as an index into feature_names.
    feature_name = [
        feature_names[i] if i != _tree.TREE_UNDEFINED else "undefined!"
        for i in tree_.feature
    ]

    def recurse(node):
        tree_json = dict()
        if tree_.feature[node] != _tree.TREE_UNDEFINED:
            tree_json['type'] = 'split'
            threshold = tree_.threshold[node]
            tree_json['threshold'] = "{} <= {}".format(feature_name[node], threshold)
            # scikit-learn routes samples satisfying the "<=" test to children_left.
            tree_json['left'] = recurse(tree_.children_left[node])
            tree_json['right'] = recurse(tree_.children_right[node])
        else:
            tree_json['type'] = 'leaf'
            tree_json['value'] = tree_.value[node].tolist()
        return tree_json

    # Node 0 is the root of scikit-learn's flat node arrays.
    return recurse(0)
# In[11]:
def forest_to_json(forest):
    """Convert a fitted scikit-learn forest into a JSON-serialisable dict.

    Args:
        forest: A fitted forest estimator exposing ``n_classes_``,
            ``classes_``, ``n_outputs_``, ``n_estimators``, ``estimators_``
            and either ``n_features_in_`` (current scikit-learn) or
            ``n_features_`` (older releases).

    Returns:
        dict: Keys ``'n_features'``, ``'n_classes'``, ``'classes'``,
        ``'n_outputs'``, ``'n_estimators'`` and ``'estimators'``. The last
        holds one ``tree_to_json`` dict per entry of ``forest.estimators_``.
    """
    # ``n_features_`` was deprecated in scikit-learn 1.0 in favour of
    # ``n_features_in_`` and later removed from forests, so reading it
    # unconditionally raises AttributeError on current releases.
    n_features = getattr(forest, 'n_features_in_', None)
    if n_features is None:
        n_features = forest.n_features_
    forest_json = {
        'n_features': n_features,
        'n_classes': forest.n_classes_,
        'classes': forest.classes_.tolist(),
        'n_outputs': forest.n_outputs_,
        'n_estimators': forest.n_estimators,
        'estimators': [tree_to_json(estimator) for estimator in forest.estimators_],
    }
    return forest_json