diff --git a/nltk/classify/scikitlearn.py b/nltk/classify/scikitlearn.py index 1854038d4d..144d67003e 100644 --- a/nltk/classify/scikitlearn.py +++ b/nltk/classify/scikitlearn.py @@ -1,4 +1,11 @@ +# Natural Language Toolkit: Interface to scikit-learn classifiers +# +# Author: Lars Buitinck +# URL: +# For license information, see LICENSE.TXT + from nltk.classify.api import ClassifierI +from nltk.probability import DictionaryProbDist import numpy as np @@ -25,8 +32,10 @@ def batch_classify(self, featuresets): y = self._clf.predict(X) return [self._index_label[int(yi)] for yi in y] - def classify(self, featureset): - return self.batch_classify([featureset]) + def batch_prob_classify(self, featuresets): + X = self._featuresets_to_array(featuresets) + y_proba = self._clf.predict_proba(X) + return [self._make_probdist(y_proba[i]) for i in xrange(len(y_proba))] def labels(self): return self._label_index.keys() @@ -72,6 +81,10 @@ def _featuresets_to_array(self, featuresets): return X + def _make_probdist(self, y_proba): + return DictionaryProbDist(dict((self._index_label[i], p) + for i, p in enumerate(y_proba))) + if __name__ == "__main__": from nltk.classify.util import names_demo, binary_names_demo_features