Skip to content

Commit

Permalink
Merge pull request nltk#21 from larsmans/sklearn
Browse files: browse the repository at this point in the history
implement SklearnClassifier.batch_prob_classify
  • Loading branch information
stevenbird committed Jan 13, 2012
2 parents 024c00f + 6d74d57 commit 821a9c5
Showing 1 changed file with 15 additions and 2 deletions.
17 changes: 15 additions & 2 deletions nltk/classify/scikitlearn.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,11 @@
# Natural Language Toolkit: Interface to scikit-learn classifiers
#
# Author: Lars Buitinck <[email protected]>
# URL: <http://www.nltk.org/>
# For license information, see LICENSE.TXT

from nltk.classify.api import ClassifierI
from nltk.probability import DictionaryProbDist

import numpy as np

Expand All @@ -25,8 +32,10 @@ def batch_classify(self, featuresets):
y = self._clf.predict(X)
return [self._index_label[int(yi)] for yi in y]

def classify(self, featureset):
    """Return the single most appropriate label for ``featureset``.

    The featureset is wrapped in a one-element batch and handed to
    ``self.batch_classify``; the sole element of the resulting list is
    returned, so callers receive a label rather than a list of labels.

    :param featureset: the featureset to classify.
    :return: the label that ``batch_classify`` assigns to ``featureset``.
    """
    # batch_classify returns one label per input featureset; unwrap the
    # single result instead of leaking the list to the caller.
    return self.batch_classify([featureset])[0]
def batch_prob_classify(self, featuresets):
    """Return one probability distribution per featureset.

    The featuresets are converted with ``self._featuresets_to_array``,
    passed to ``predict_proba`` on the wrapped classifier ``self._clf``,
    and each resulting row is turned into a distribution by
    ``self._make_probdist``.

    :param featuresets: the featuresets to classify.
    :return: a list with one ``self._make_probdist`` result per row that
        ``predict_proba`` returned, in the same order.
    """
    X = self._featuresets_to_array(featuresets)
    y_proba = self._clf.predict_proba(X)
    # Iterate over the rows directly: equivalent to indexing with
    # xrange(len(y_proba)) but does not depend on the Python 2-only xrange.
    return [self._make_probdist(row) for row in y_proba]

def labels(self):
    """Return the labels known to this classifier as a list.

    The labels are the keys of ``self._label_index``.

    :return: a new list containing every key of ``self._label_index``.
    """
    # list(...) yields a real list on both Python 2 and Python 3, whereas
    # dict.keys() is only a view on Python 3.
    return list(self._label_index)
Expand Down Expand Up @@ -72,6 +81,10 @@ def _featuresets_to_array(self, featuresets):

return X

def _make_probdist(self, y_proba):
    """Wrap one row of probabilities in a ``DictionaryProbDist``.

    The value at position ``i`` of ``y_proba`` becomes the probability of
    the label ``self._index_label[i]``.

    :param y_proba: an iterable of probabilities, one per label index.
    :return: a ``DictionaryProbDist`` built from the label-to-probability
        mapping.
    """
    label_probs = {}
    for index, prob in enumerate(y_proba):
        label_probs[self._index_label[index]] = prob
    return DictionaryProbDist(label_probs)


if __name__ == "__main__":
from nltk.classify.util import names_demo, binary_names_demo_features
Expand Down

0 comments on commit 821a9c5

Please sign in to comment.