Skip to content

Commit

Permalink
Merge pull request nltk#1680 from vthorsteinsson/improvement/speedup_tnt_tagger
Browse files Browse the repository at this point in the history

Speed up TnT tagger; improve FreqDist and ConditionalFreqDist
  • Loading branch information
stevenbird authored Apr 13, 2017
2 parents b7c2aff + 38edf1e commit 9370fa3
Show file tree
Hide file tree
Showing 3 changed files with 50 additions and 22 deletions.
1 change: 1 addition & 0 deletions AUTHORS.md
Original file line number Diff line number Diff line change
Expand Up @@ -209,6 +209,7 @@
- Prasasto Adi
- Safwan Kamarrudin
- Arthur Tilley
- Vilhjalmur Thorsteinsson

## Others whose work we've taken and included in NLTK, but who didn't directly contribute it:
### Contributors to the Porter Stemmer
Expand Down
42 changes: 39 additions & 3 deletions nltk/probability.py
100644 → 100755
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,9 @@ def __init__(self, samples=None):
"""
Counter.__init__(self, samples)

# Cached number of samples in this FreqDist
self._N = None

def N(self):
    """
    Return the total number of sample outcomes that have been
    recorded by this FreqDist.

    :rtype: int
    """
    # The total is cached in self._N; every mutating override
    # (__setitem__, __delitem__, update, setdefault) resets the cache
    # to None, so a None cache means "recompute from the counts".
    # NOTE: the unconditional `return sum(self.values())` left over
    # from the pre-caching version has been removed — it made the
    # cache logic below unreachable dead code.
    if self._N is None:
        # Not already cached, or cache has been invalidated
        self._N = sum(self.values())
    return self._N

def __setitem__(self, key, val):
    """
    Set the count for sample ``key`` to ``val``.

    Overrides ``Counter.__setitem__()`` so that the cached total
    sample count is discarded whenever a count is assigned; ``N()``
    will lazily recompute it on the next call.
    """
    # Invalidate before delegating, so a subsequent N() never sees a
    # total that predates this assignment.
    self._N = None
    super(FreqDist, self).__setitem__(key, val)

def __delitem__(self, key):
    """
    Remove sample ``key`` (and its count) from the distribution.

    Overrides ``Counter.__delitem__()`` so that the cached total
    sample count is discarded whenever an entry is deleted; ``N()``
    will lazily recompute it on the next call.
    """
    # Invalidate first; the delete itself is handled by Counter.
    self._N = None
    super(FreqDist, self).__delitem__(key)

def update(self, *args, **kwargs):
    """
    Add counts from a mapping, an iterable of samples, or keyword
    arguments, exactly as ``Counter.update()`` does.

    Overridden so that the cached total sample count is discarded
    before delegating; ``N()`` will lazily recompute it.
    """
    # Any update can change the total, so drop the cache unconditionally.
    self._N = None
    super(FreqDist, self).update(*args, **kwargs)

def setdefault(self, key, val):
    """
    If ``key`` is in the distribution, return its count; otherwise
    insert it with count ``val`` and return ``val``.

    Overrides ``Counter.setdefault()`` to invalidate the cached N.

    Bug fix: ``dict.setdefault()`` returns the resulting value, but
    this override discarded the result of the ``super()`` call and so
    always returned ``None``. The return value is now propagated to
    preserve the standard mapping contract for callers.
    """
    self._N = None
    return super(FreqDist, self).setdefault(key, val)

def B(self):
"""
Expand Down Expand Up @@ -192,9 +226,10 @@ def freq(self, sample):
:type sample: any
:rtype: float
"""
if self.N() == 0:
n = self.N()
if n == 0:
return 0
return self[sample] / self.N()
return self[sample] / n

def max(self):
"""
Expand Down Expand Up @@ -1749,6 +1784,7 @@ def __init__(self, cond_samples=None):
:type cond_samples: Sequence of (condition, sample) tuples
"""
defaultdict.__init__(self, FreqDist)

if cond_samples:
for (cond, sample) in cond_samples:
self[cond][sample] += 1
Expand Down
29 changes: 10 additions & 19 deletions nltk/tag/tnt.py
Original file line number Diff line number Diff line change
Expand Up @@ -357,30 +357,24 @@ def _tagword(self, sent, current_states):
# if word is known
# compute the set of possible tags
# and their associated log probabilities
if word in self._wd.conditions():
if word in self._wd:
self.known += 1

for (history, curr_sent_logprob) in current_states:
logprobs = []

for t in self._wd[word].keys():
p_uni = self._uni.freq((t,C))
p_bi = self._bi[history[-1]].freq((t,C))
p_tri = self._tri[tuple(history[-2:])].freq((t,C))
p_wd = self._wd[word][t] / self._uni[(t,C)]
tC = (t,C)
p_uni = self._uni.freq(tC)
p_bi = self._bi[history[-1]].freq(tC)
p_tri = self._tri[tuple(history[-2:])].freq(tC)
p_wd = self._wd[word][t] / self._uni[tC]
p = self._l1 *p_uni + self._l2 *p_bi + self._l3 *p_tri
p2 = log(p, 2) + log(p_wd, 2)

logprobs.append(((t,C), p2))


# compute the result of appending each tag to this history
for (tag, logprob) in logprobs:
new_states.append((history + [tag],
curr_sent_logprob + logprob))



# compute the result of appending each tag to this history
new_states.append((history + [tC],
curr_sent_logprob + p2))

# otherwise a new word, set of possible tags is unknown
else:
Expand All @@ -398,7 +392,7 @@ def _tagword(self, sent, current_states):
tag = ('Unk',C)

# otherwise apply the unknown word tagger
else :
else:
[(_w, t)] = list(self._unk.tag([word]))
tag = (t,C)

Expand All @@ -407,8 +401,6 @@ def _tagword(self, sent, current_states):

new_states = current_states



# now have computed a set of possible new_states

# sort states by log prob
Expand All @@ -420,7 +412,6 @@ def _tagword(self, sent, current_states):
if len(new_states) > self._N:
new_states = new_states[:self._N]


# compute the tags for the rest of the sentence
# return the best list of tags for the sentence
return self._tagword(sent, new_states)
Expand Down

0 comments on commit 9370fa3

Please sign in to comment.