-
Notifications
You must be signed in to change notification settings - Fork 15
/
tree.py
128 lines (102 loc) · 3.28 KB
/
tree.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
import collections
# Placeholder token for out-of-vocabulary words: the word map stores an id
# under this key, and leaves whose word is not in the map receive that id.
UNK = "UNK"
class Node:
    """One node of a binary parse tree.

    ``label`` is the node's class label (``Tree.parse`` stores it as a
    zero-based int). Leaves also carry ``word``; internal nodes have
    ``left``/``right`` children. ``parent`` links back up the tree.
    """

    def __init__(self, label, word=None):
        self.label = label
        self.word = word
        self.parent = None
        self.left = None
        self.right = None
        self.isLeaf = False
        # NOTE(review): fprop is only initialised here; presumably a
        # "forward-propagated" marker used by training code -- confirm.
        self.fprop = False


class Tree:
    """Binary parse tree read from a bracketed string.

    Expected input looks like ``(4 (2 word) (5 (2 other) (3 word)))``.
    Every node is ``<open><label> ... <close>``. A leaf holds a word, and an
    internal node holds exactly two child subtrees. Labels in the text are
    single 1-based characters and are stored zero-based on ``Node.label``.
    Leaf words are lower-cased.

    :param treeString: the bracketed tree text (surrounding whitespace ignored).
    :param openChar: single character that opens a node (default ``'('``).
    :param closeChar: single character that closes a node (default ``')'``).
    """

    def __init__(self, treeString, openChar='(', closeChar=')'):
        # Honour the caller's delimiters. Previously these parameters were
        # accepted but ignored in favour of hard-coded '(' and ')'.
        self.open = openChar
        self.close = closeChar
        # Tokenise into single characters with all whitespace dropped. A
        # label is therefore one character, and a leaf word is re-joined from
        # its characters in parse().
        tokens = []
        for chunk in treeString.strip().split():
            tokens.extend(chunk)
        self.root = self.parse(tokens)

    def parse(self, tokens, parent=None):
        """Recursively build the subtree spanned by ``tokens``.

        :param tokens: character tokens covering exactly one node, from its
            opening delimiter to its closing delimiter.
        :param parent: the ``Node`` that will own the returned subtree.
        :returns: the root ``Node`` of the parsed subtree.
        :raises AssertionError: if ``tokens`` is not delimited by
            ``self.open`` / ``self.close``.
        """
        assert tokens[0] == self.open, "Malformed tree"
        assert tokens[-1] == self.close, "Malformed tree"
        split = 2  # position after open and label
        countOpen = countClose = 0
        if tokens[split] == self.open:
            countOpen += 1
            split += 1
        # Scan until the left child's delimiters balance. `split` then points
        # at the first token of the right child.
        while countOpen != countClose:
            if tokens[split] == self.open:
                countOpen += 1
            if tokens[split] == self.close:
                countClose += 1
            split += 1
        node = Node(int(tokens[1]) - 1)  # zero-index labels
        node.parent = parent
        if countOpen == 0:
            # Leaf: everything between the label and the closing delimiter
            # is the word.
            node.word = ''.join(tokens[2:-1]).lower()
            node.isLeaf = True
            return node
        node.left = self.parse(tokens[2:split], parent=node)
        node.right = self.parse(tokens[split:-1], parent=node)
        return node
def leftTraverse(root, nodeFn=None, args=None):
    """Walk a tree in pre-order: the node, then its left subtree, then its right.

    ``nodeFn(node, args)`` is called once per node, parent before children
    and left child before right child. ``nodeFn`` must be supplied despite
    its ``None`` default.

    :param root: node to start from; its ``left``/``right`` attributes are
        followed when they are not ``None``.
    :param nodeFn: callback invoked as ``nodeFn(node, args)``.
    :param args: extra value passed unchanged to every ``nodeFn`` call.
    """
    nodeFn(root, args)
    for side in ('left', 'right'):
        child = getattr(root, side)
        if child is not None:
            leftTraverse(child, nodeFn, args)
def countWords(node, words):
    """Traversal callback: tally a leaf's word in ``words``.

    Non-leaf nodes are ignored. ``words`` must tolerate ``+= 1`` on unseen
    keys (``buildWordMap`` passes a ``collections.defaultdict(int)``).
    """
    if not node.isLeaf:
        return
    words[node.word] += 1
def mapWords(node, wordMap):
    """Traversal callback: replace a leaf's word string with its integer id.

    Words missing from ``wordMap`` receive the id stored under ``UNK``.
    Non-leaf nodes are left untouched.
    """
    if not node.isLeaf:
        return
    key = node.word if node.word in wordMap else UNK
    node.word = wordMap[key]
def loadWordMap(mapFile='wordMap.bin'):
    """Load the word -> integer id mapping pickled by ``buildWordMap``.

    :param mapFile: pickle file to read (default ``'wordMap.bin'``, relative
        to the current working directory).
    :returns: the unpickled mapping.
    """
    try:
        import cPickle as pickle  # Python 2 C implementation
    except ImportError:
        import pickle  # Python 3: cPickle was folded into pickle
    # Pickles are binary data. Open in 'rb' so no newline translation alters
    # the bytes and so Python 3 accepts the file object.
    # NOTE: pickle.load can execute arbitrary code -- only load trusted files.
    with open(mapFile, 'rb') as fid:
        return pickle.load(fid)
def buildWordMap(dataFile='trees/train.txt', mapFile='wordMap.bin'):
    """Build a map from every word in the training trees to an integer id.

    Reads one bracketed tree per line from ``dataFile`` and counts every leaf
    word. Each distinct word is assigned an id in ``0..V-1``, in the order
    the count dict yields them. Id ``V`` is reserved for ``UNK``. The
    resulting dict is pickled to ``mapFile``.

    :param dataFile: tree file to read (default ``'trees/train.txt'``).
    :param mapFile: output pickle path (default ``'wordMap.bin'``).
    """
    try:
        import cPickle as pickle  # Python 2 C implementation
    except ImportError:
        import pickle  # Python 3
    print("Reading trees..")
    with open(dataFile, 'r') as fid:
        # Skip blank lines (e.g. a trailing empty line); they are not trees
        # and would make Tree() fail.
        trees = [Tree(line) for line in fid if line.strip()]
    print("Counting words..")
    words = collections.defaultdict(int)
    for tree in trees:
        leftTraverse(tree.root, nodeFn=countWords, args=words)
    wordMap = {word: idx for idx, word in enumerate(words)}
    wordMap[UNK] = len(words)  # unknown word gets the id after all others
    # Binary mode: pickles are bytes, not text.
    with open(mapFile, 'wb') as fid:
        pickle.dump(wordMap, fid)
def loadTrees(dataSet='train'):
    """Load the trees of one data split, with leaf words mapped to word ids.

    Reads ``trees/<dataSet>.txt`` (one bracketed tree per line; blank lines
    are skipped) and parses each line into a ``Tree``. Every leaf word is
    then replaced with its id from the word map returned by
    ``loadWordMap()``. Words missing from the map get the ``UNK`` id.

    :param dataSet: split name, e.g. ``'train'`` (default).
    :returns: list of ``Tree`` objects.
    """
    wordMap = loadWordMap()
    dataFile = 'trees/%s.txt' % dataSet
    print("Reading trees..")
    with open(dataFile, 'r') as fid:
        # Blank lines are not trees and would make Tree() fail.
        trees = [Tree(line) for line in fid if line.strip()]
    for tree in trees:
        leftTraverse(tree.root, nodeFn=mapWords, args=wordMap)
    return trees
if __name__ == '__main__':
    # Script entry point: regenerate the vocabulary from the training split,
    # then load that split with its leaf words mapped to ids.
    buildWordMap()
    train = loadTrees(dataSet='train')