-
Notifications
You must be signed in to change notification settings - Fork 0
/
test.py
205 lines (180 loc) · 5.28 KB
/
test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
import linecache
import math
import sys
sys.setrecursionlimit(100000) #例如这里设置为十万
from graphviz import Digraph
# Graphviz digraph that the decision tree is drawn into (source file: test.gv).
g = Digraph('G', filename='test.gv')
# Internal node id (cnt) -> index of the attribute column that node splits on.
cntToNode = {}
# Attribute column index (as a string) <-> display name of that column.
rowToName = {
    str(col): name
    for col, name in enumerate(('Outlook', 'Temper.', 'Humidity', 'Windy'))
}
nameToRow = {name: col for col, name in rowToName.items()}
# Edge u => v -> attribute value written on that edge (key format: see add_edge).
edgeVal = {}
# Node id -> text of the node: attribute name (internal node) or label (leaf).
nodeVal = {}
# Last node id handed out; create() increments it for every new node.
cnt = 0
# Node id -> set of child node ids (filled by add_edge).
# NOTE: this global shadows the builtin ``dict``; other functions use it by name.
dict = {}
def read(filename):
    """Load a space-separated data file into a list of rows.

    Each non-blank line is split on single spaces. The last field is reduced
    to its first character and stored as the row's last element. This strips
    the trailing newline (or carriage return) and yields the one-character
    class label. All other fields are kept verbatim.

    Args:
        filename: path of the data file.

    Returns:
        A list of rows, each a list of strings with the class label last.
    """
    ans = []
    # Read-only: the file is never written, and 'r+' fails on read-only files.
    with open(filename, 'r') as file:
        for raw in file:
            # Skip blank lines; they would otherwise yield rows such as ['\n'].
            if not raw.strip():
                continue
            fields = raw.split(' ')
            ans.append(fields[:-1] + [fields[-1][0]])
    return ans
def create(dataSet):
    """Recursively build the decision (sub)tree for ``dataSet``.

    Chooses a split column with ``select``. If that column takes at least two
    distinct values, this function:

    - registers a new internal node for it in ``g``, ``cntToNode`` and ``nodeVal``;
    - recurses into each subset produced by ``splitDataSet``;
    - connects the results with ``add_edge``.

    Subsets for which the recursion returns ``None`` get a new leaf node,
    labelled with the subset's most common label.

    Args:
        dataSet: list of rows, each a list of strings whose last element is
            the class label (the format produced by ``read``).

    Returns:
        The id of the new internal node, or ``None`` when no internal node is
        created. That happens when ``dataSet`` is empty, when all its labels
        are equal, or when the chosen column cannot split it. The caller is
        responsible for creating a leaf in that case.
    """
    global cnt
    from collections import Counter

    if not dataSet:
        return None
    if len({example[-1] for example in dataSet}) == 1:
        return None
    # Column with the lowest conditional entropy, i.e. the largest information gain.
    row = select(dataSet)
    # Sorted so children (and their node ids) come out in a stable order
    # instead of depending on string-hash randomisation.
    uniqueValue = sorted({example[row] for example in dataSet})
    # Check this BEFORE registering a node, so an unsplittable subset does not
    # leave a dangling, edge-less node behind in the graph and in nodeVal.
    if len(uniqueValue) <= 1:
        return None
    cnt += 1
    ans = cnt
    cntToNode[ans] = row
    g.node(str(ans), label=rowToName[str(row)])
    nodeVal[ans] = rowToName[str(row)]
    for value in uniqueValue:
        subDataSet = splitDataSet(dataSet, value, row)
        child = create(subDataSet)
        if child is None:
            # Leaf. Use a majority vote, so an impure subset that cannot be
            # split further gets its most common class rather than the label
            # of whichever row happens to come first.
            label = Counter(example[-1] for example in subDataSet).most_common(1)[0][0]
            cnt += 1
            child = cnt
            g.node(str(child), label='标签为:' + str(label))
            nodeVal[child] = str(label)
        # A non-None child already declared its own graph node in the recursive call.
        add_edge(ans, child, value)
    return ans
def cal(dataSet):
    """Return the Shannon entropy, in bits, of the class labels of ``dataSet``.

    The class label is the last element of each row. An empty ``dataSet``
    yields 0.
    """
    total = len(dataSet)
    counts = {}
    for row in dataSet:
        key = row[-1]
        counts[key] = counts.get(key, 0) + 1
    entropy = 0
    for n in counts.values():
        p = n / total
        entropy -= p * math.log(p, 2)
    return entropy
def select(dataSet):
    """Pick the attribute column to split ``dataSet`` on (ID3 information gain).

    For every attribute column (all but the last, which holds the label), the
    conditional entropy sum(|D_v| / |D| * H(D_v)) over the column's distinct
    values v is computed. The column with the smallest value, i.e. the largest
    information gain, wins; ties go to the lowest index.

    Columns holding a single distinct value cannot split the data. This
    happens to columns already used higher up the tree, because
    ``splitDataSet`` keeps every column. Such columns are skipped. Only if
    *every* column is single-valued is the first such column returned.

    Args:
        dataSet: non-empty list of rows; the last element of each row is the label.

    Returns:
        Index of the chosen column, or -1 when the rows have no attribute columns.
    """
    total = len(dataSet)
    best = float('inf')
    ans = -1
    fallback = -1
    for i in range(len(dataSet[0]) - 1):
        # Sorted so the floating-point sum below is accumulated in a fixed
        # order; set iteration order over strings varies between runs.
        values = sorted({example[i] for example in dataSet})
        if len(values) < 2:
            if fallback == -1:
                fallback = i
            continue
        curEntropy = 0
        for value in values:
            subDataSet = splitDataSet(dataSet, value, i)
            curEntropy += len(subDataSet) / total * cal(subDataSet)
        if curEntropy < best:
            best = curEntropy
            ans = i
    return ans if ans != -1 else fallback
def splitDataSet(dataSet, value, row):
    """Return the rows of ``dataSet`` whose column ``row`` equals ``value``.

    ``row`` is a column index. Matching rows are returned as the very same
    list objects, with the column kept, in their original order.
    """
    return [line for line in dataSet if line[row] == value]
def add_edge(u, v, value):  # u ==> v, with ``value`` written on the edge
    """Connect node ``u`` to child ``v`` via the branch for attribute value ``value``.

    Records ``v`` in ``dict[u]`` (the children of ``u``), draws the labelled
    edge in ``g`` and stores the edge value in ``edgeVal``.

    ``edgeVal`` is written under two keys: the tuple ``(u, v)`` and the legacy
    string ``str(u) + str(v)``. The string key is ambiguous; for example,
    (1, 112) and (11, 12) both become ``'1112'``. Readers should therefore
    prefer the tuple key. The string key is kept for code that still looks
    edges up that way.
    """
    dict.setdefault(u, set()).add(v)
    g.edge(str(u), str(v), label=str(value))
    edgeVal[(u, v)] = str(value)
    edgeVal[str(u) + str(v)] = str(value)
def test():
    """Train on mytrain.data, render the tree, then report the error on mytest.data."""
    trainSet = read("mytrain.data")
    root = create(trainSet)
    g.view()
    testSet = read("mytest.data")
    exam(root, testSet)
def exam(root, dataSet):
    """Print and return the fraction of ``dataSet`` rows the tree misclassifies.

    Args:
        root: id of the decision tree's root node, as returned by ``create``.
        dataSet: test rows in the format produced by ``read``.

    Returns:
        The number of rows for which ``not_ok`` is true, divided by the number
        of rows. An empty ``dataSet`` raises ZeroDivisionError.
    """
    errors = 0
    total = 0
    for total, sample in enumerate(dataSet, start=1):
        errors += 1 if not_ok(sample, root) else 0
    faultPro = errors / total
    print("错误率为:", faultPro)
    return faultPro
def not_ok(line, root):
    """Return True if the tree rooted at node ``root`` misclassifies ``line``.

    Starting at ``root``, the walk repeatedly:

    - looks up the attribute column of the current internal node
      (``nodeVal`` gives its attribute name, ``nameToRow`` its column);
    - follows the child edge whose value equals the sample's value in that
      column.

    Nodes with no entry in ``dict`` have no children, i.e. they are leaves,
    and their ``nodeVal`` is the predicted label. This works for any label
    value, not just '0'/'1'.

    Args:
        line: one sample row; its last element is the true label.
        root: id of the root node returned by ``create``.

    Returns:
        True when the predicted label differs from ``line[-1]``, or when the
        sample has an attribute value for which no branch exists. False when
        the prediction is correct.
    """
    node = root
    # Each step moves to a child node, so the walk always ends at a leaf.
    while node in dict:
        col = int(nameToRow[nodeVal[node]])
        value = line[col]
        for child in dict[node]:
            # Prefer the unambiguous tuple key; fall back to the string key.
            edge = edgeVal.get((node, child), edgeVal.get(str(node) + str(child)))
            if edge == value:
                node = child
                break
        else:
            # No branch for this value (it never occurred in the training
            # subset): the sample cannot be classified, so count it as an
            # error instead of retrying the same node forever.
            return True
    return nodeVal[node] != line[-1]
def isLabel(attr):
    """Return True when ``attr`` is one of the class-label strings '1' or '0'."""
    return attr in ('1', '0')
if __name__ == '__main__':
    # Train on mytrain.data, draw the tree and report the test error rate.
    test()