Skip to content

Commit

Permalink
Added _N cache in FreqDist; removed freeze_N()
Browse files Browse the repository at this point in the history
  • Loading branch information
vthorsteinsson committed Apr 10, 2017
1 parent 957ec3b commit 38edf1e
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 29 deletions.
55 changes: 33 additions & 22 deletions nltk/probability.py
100644 → 100755
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,9 @@ def __init__(self, samples=None):
"""
Counter.__init__(self, samples)

# Cached number of samples in this FreqDist
self._N = None

def N(self):
"""
Return the total number of sample outcomes that have been
Expand All @@ -114,18 +117,38 @@ def N(self):
:rtype: int
"""
return sum(self.values())
if self._N is None:
# Not already cached, or cache has been invalidated
self._N = sum(self.values())
return self._N

def freeze_N(self):
def __setitem__(self, key, val):
"""
Set N permanently to its current value, making subsequent
calls to ``FreqDist.N()`` much faster. Use this for instance
after training a tagger, but before running it. After calling
``FreqDist.freeze_N()``, no more samples should be added to the
FreqDist object.
Override ``Counter.__setitem__()`` to invalidate the cached N
"""
n = self.N()
self.N = lambda: n
self._N = None
super(FreqDist, self).__setitem__(key, val)

def __delitem__(self, key):
    """
    Remove ``key`` through ``Counter.__delitem__()``, discarding the
    cached N beforehand.

    :param key: the sample whose count is being removed
    """
    # The total is about to change; force N() to recompute on next call.
    self._N = None
    parent = super(FreqDist, self)
    parent.__delitem__(key)

def update(self, *args, **kwargs):
    """
    Add counts through ``Counter.update()``, discarding the cached N
    beforehand.

    All positional and keyword arguments are forwarded unchanged to
    ``Counter.update()``.
    """
    # Any update may alter the counts, so the stored total is dropped first.
    self._N = None
    parent = super(FreqDist, self)
    parent.update(*args, **kwargs)

def setdefault(self, key, val):
    """
    Override ``Counter.setdefault()`` to invalidate the cached N.

    :param key: the sample to look up, or to insert if it is missing
    :param val: the count to store under ``key`` when it is not yet present
    :return: the count stored under ``key`` after the call, i.e. the
        existing count if ``key`` was present, otherwise ``val``
    """
    # Invalidate unconditionally: a missing key gets inserted, which
    # changes the total returned by N().
    self._N = None
    # Propagate the parent's result; dropping it would make every call
    # return None and break the dict.setdefault() contract.
    return super(FreqDist, self).setdefault(key, val)

def B(self):
"""
Expand Down Expand Up @@ -1761,6 +1784,7 @@ def __init__(self, cond_samples=None):
:type cond_samples: Sequence of (condition, sample) tuples
"""
defaultdict.__init__(self, FreqDist)

if cond_samples:
for (cond, sample) in cond_samples:
self[cond][sample] += 1
Expand Down Expand Up @@ -1790,19 +1814,6 @@ def N(self):
"""
return sum(fdist.N() for fdist in compat.itervalues(self))

def freeze_N(self):
"""
Set N permanently to its current value, making subsequent
calls to ``ConditionalFreqDist.N()`` much faster. Use this for
instance after training a tagger, but before running it. After
calling ``ConditionalFreqDist.freeze_N()``, no more samples
should be added to the ConditionalFreqDist object.
"""
for fdist in compat.itervalues(self):
fdist.freeze_N()
n = self.N()
self.N = lambda: n

def plot(self, *args, **kwargs):
"""
Plot the given samples from the conditional frequency distribution.
Expand Down
7 changes: 0 additions & 7 deletions nltk/tag/tnt.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,13 +176,6 @@ def train(self, data):
#print "lambdas"
#print i, self._l1, i, self._l2, i, self._l3

# after training, freeze the frequency distributions
# to make subsequent frequency queries faster
self._uni.freeze_N()
self._bi.freeze_N()
self._tri.freeze_N()
self._wd.freeze_N()


def _compute_lambda(self):
'''
Expand Down

0 comments on commit 38edf1e

Please sign in to comment.