-
Notifications
You must be signed in to change notification settings - Fork 371
/
Copy pathtreeplotter.py
85 lines (76 loc) · 3.3 KB
/
treeplotter.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
import matplotlib.pyplot as plt
import trees
# Shared matplotlib styling used by the plotting helpers below:
#   decisionNode / leafNode -- box properties passed as ``bbox`` when drawing
#                              a decision node or a leaf node respectively;
#   arrow_args              -- arrow properties passed as ``arrowprops`` for
#                              the edge between a node and its parent.
decisionNode = {'boxstyle': 'sawtooth', 'fc': '0.8'}
leafNode = {'boxstyle': 'round4', 'fc': '0.8'}
arrow_args = {'arrowstyle': "<-"}
def plotNode(nodeTxt, centerPt, parentPt, nodeType):
    """Draw one boxed text node with an arrow linking it to ``parentPt``.

    The drawing is done by calling ``annotate`` on the axes stored in
    ``createPlot.axl``, so that attribute must already be set.

    Args:
        nodeTxt: Text shown inside the node's box.
        centerPt: (x, y) position of the text, in axes-fraction coordinates.
        parentPt: (x, y) position of the annotated point (the parent end of
            the arrow), in axes-fraction coordinates.
        nodeType: Dict of box properties, passed through as ``bbox``.
    """
    annotation_options = dict(
        xy=parentPt,
        xycoords='axes fraction',
        xytext=centerPt,
        textcoords='axes fraction',
        va='center',
        ha='center',
        bbox=nodeType,
        arrowprops=arrow_args,
    )
    createPlot.axl.annotate(nodeTxt, **annotation_options)
def createPlot():
    """Show a fixed demo figure containing one decision node and one leaf node.

    Clears figure 1, creates a frameless subplot stored on ``createPlot.axl``
    (the axes ``plotNode`` draws on), plots the two sample nodes and displays
    the window.

    NOTE(review): a second ``createPlot(inTree)`` is defined further down in
    this file and rebinds the name, so once the module has been loaded this
    zero-argument version is no longer what ``createPlot`` refers to.
    """
    plt.figure(1, facecolor='white').clf()
    createPlot.axl = plt.subplot(111, frameon=False)
    # (label, text position, parent position, box style) for each demo node.
    demo_nodes = (
        ('a decision node', (0.5, 0.1), (0.1, 0.5), decisionNode),
        ('a leaf node', (0.8, 0.1), (0.3, 0.8), leafNode),
    )
    for label, center, parent, box_style in demo_nodes:
        plotNode(label, center, parent, box_style)
    plt.show()
def getNumLeafs(myTree):
    """Count the leaf nodes of a nested-dict decision tree.

    The tree has the form ``{feature: {value: subtree_or_label, ...}}``: the
    first key of ``myTree`` names the split feature and maps to a dict of
    branches. A branch whose value is itself a dict is a subtree; any other
    value is a leaf.

    Args:
        myTree: Nested-dict decision tree.

    Returns:
        The number of leaves, or 0 when ``myTree`` is empty.
    """
    if not myTree:
        return 0
    # dict.keys() is a non-indexable view on Python 3, so ``keys()[0]`` raises
    # TypeError there; taking the first key from the iterator works everywhere.
    firstStr = next(iter(myTree))
    numLeafs = 0
    for branch in myTree[firstStr].values():
        if isinstance(branch, dict):
            numLeafs += getNumLeafs(branch)
        else:
            numLeafs += 1
    return numLeafs
def getTreeDepth(myTree):
    """Return the number of decision levels on the longest root-to-leaf path.

    The tree has the form ``{feature: {value: subtree_or_label, ...}}``. A
    branch that leads straight to a leaf contributes depth 1; a branch that
    leads to a subtree contributes 1 plus that subtree's depth.

    Args:
        myTree: Nested-dict decision tree.

    Returns:
        The maximum depth over all branches, or 0 when ``myTree`` is empty.
    """
    if not myTree:
        return 0
    # dict.keys() is a non-indexable view on Python 3, so ``keys()[0]`` raises
    # TypeError there; taking the first key from the iterator works everywhere.
    firstStr = next(iter(myTree))
    maxDepth = 0
    for branch in myTree[firstStr].values():
        if isinstance(branch, dict):
            thisDepth = 1 + getTreeDepth(branch)
        else:
            thisDepth = 1
        if thisDepth > maxDepth:
            maxDepth = thisDepth
    return maxDepth
def retrieveTree(i):
    """Return one of the two hard-coded sample decision trees.

    The trees are rebuilt on every call, so each caller receives fresh dicts.

    Args:
        i: Index into the two-element list of sample trees (normal list
            indexing rules apply).

    Returns:
        The selected nested-dict tree.
    """
    # Tree 0: two levels, with an extra 'maybe' branch on the root.
    flippers_split = {'flippers': {0: 'no', 1: 'yes'}}
    first_tree = {'no surfacing': {0: 'no', 1: flippers_split, 3: 'maybe'}}

    # Tree 1: three levels, with a 'head' split under 'flippers'.
    head_split = {'head': {0: 'no', 1: 'yes'}}
    second_tree = {
        'no surfacing': {0: 'no', 1: {'flippers': {0: head_split, 1: 'no'}}},
    }

    return [first_tree, second_tree][i]
def plotMidText(cntrPt, parentPt, txtString):
    """Write ``txtString`` halfway between a node and its parent.

    The text is drawn on the axes stored in ``createPlot.axl``, so that
    attribute must already be set.

    Args:
        cntrPt: (x, y) position of the child node.
        parentPt: (x, y) position of the parent node.
        txtString: Label to place on the connecting edge.
    """
    # Midpoint = child + (parent - child) / 2. The difference has to be
    # parenthesised before halving; without the parentheses only cntrPt is
    # halved and the label is not placed on the edge.
    xMid = (parentPt[0] - cntrPt[0]) / 2.0 + cntrPt[0]
    yMid = (parentPt[1] - cntrPt[1]) / 2.0 + cntrPt[1]
    createPlot.axl.text(xMid, yMid, txtString, va="center")
def plotTree(myTree, parentPt, nodeTxt):
    """Recursively draw a nested-dict decision tree below ``parentPt``.

    Layout state is kept in function attributes that the caller must set
    before the first call:

    * ``plotTree.totalW`` -- total number of leaves (width of the plot);
    * ``plotTree.totalD`` -- depth of the tree (height of the plot);
    * ``plotTree.xOff``   -- x position of the most recently drawn leaf;
    * ``plotTree.yOff``   -- y position of the level currently being drawn.

    Args:
        myTree: Tree of the form ``{feature: {value: subtree_or_label}}``.
        parentPt: (x, y) position of this subtree's parent node.
        nodeTxt: Label for the edge from the parent to this subtree's root.
    """
    numLeafs = getNumLeafs(myTree)  # determines the x width of this subtree
    # The first key names the feature that was split on and is this node's
    # label. dict.keys() is not indexable on Python 3, so use the iterator.
    firstStr = next(iter(myTree))
    # Centre the node horizontally over the leaves it spans.
    cntrPt = (
        plotTree.xOff + (1.0 + float(numLeafs)) / 2.0 / plotTree.totalW,
        plotTree.yOff,
    )
    plotMidText(cntrPt, parentPt, nodeTxt)
    plotNode(firstStr, cntrPt, parentPt, decisionNode)
    secondDict = myTree[firstStr]
    # Step down one level before drawing the children.
    plotTree.yOff = plotTree.yOff - 1.0 / plotTree.totalD
    for key, child in secondDict.items():
        if isinstance(child, dict):
            # A dict is a subtree: recurse with this node as the parent.
            plotTree(child, cntrPt, str(key))
        else:
            # Anything else is a leaf: advance one leaf slot and draw it.
            plotTree.xOff = plotTree.xOff + 1.0 / plotTree.totalW
            plotNode(child, (plotTree.xOff, plotTree.yOff), cntrPt, leafNode)
            plotMidText((plotTree.xOff, plotTree.yOff), cntrPt, str(key))
    # Step back up so siblings of this subtree are drawn at the right level.
    plotTree.yOff = plotTree.yOff + 1.0 / plotTree.totalD
def createPlot(inTree):
    """Draw a nested-dict decision tree in matplotlib figure 1 and show it.

    Clears the figure, creates a frameless subplot without ticks (stored on
    ``createPlot.axl`` for the drawing helpers), initialises the layout state
    held on ``plotTree`` and then draws the tree from the top centre.

    Args:
        inTree: Tree of the form ``{feature: {value: subtree_or_label}}``.
    """
    fig = plt.figure(1, facecolor='white')
    fig.clf()
    axprops = dict(xticks=[], yticks=[])
    createPlot.axl = plt.subplot(111, frameon=False, **axprops)
    plotTree.totalW = float(getNumLeafs(inTree))
    # The depth must go in totalD. Assigning it to totalW again would clobber
    # the leaf count and leave totalD undefined, which plotTree reads.
    plotTree.totalD = float(getTreeDepth(inTree))
    # Start half a leaf slot left of the axes so the first leaf, drawn after
    # one full-slot advance, lands half a slot in from the left edge.
    plotTree.xOff = -0.5 / plotTree.totalW
    plotTree.yOff = 1.0
    plotTree(inTree, (0.5, 1.0), '')
    plt.show()