From fb470b2c103f220fed396bd285424706c0c716f3 Mon Sep 17 00:00:00 2001 From: Lars Buitinck Date: Fri, 13 Jan 2012 22:43:30 +0100 Subject: [PATCH 1/2] implement SklearnClassifier.batch_prob_classify And remove SklearnClassifier.classify; already inherited from ClassifierI. --- nltk/classify/scikitlearn.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/nltk/classify/scikitlearn.py b/nltk/classify/scikitlearn.py index 1854038d4d..97bc33985e 100644 --- a/nltk/classify/scikitlearn.py +++ b/nltk/classify/scikitlearn.py @@ -1,4 +1,5 @@ from nltk.classify.api import ClassifierI +from nltk.probability import DictionaryProbDist import numpy as np @@ -25,8 +26,10 @@ def batch_classify(self, featuresets): y = self._clf.predict(X) return [self._index_label[int(yi)] for yi in y] - def classify(self, featureset): - return self.batch_classify([featureset]) + def batch_prob_classify(self, featuresets): + X = self._featuresets_to_array(featuresets) + y_proba = self._clf.predict_proba(X) + return [self._make_probdist(y_proba[i]) for i in xrange(len(y_proba))] def labels(self): return self._label_index.keys() @@ -72,6 +75,10 @@ def _featuresets_to_array(self, featuresets): return X + def _make_probdist(self, y_proba): + return DictionaryProbDist(dict((self._index_label[i], p) + for i, p in enumerate(y_proba))) + if __name__ == "__main__": from nltk.classify.util import names_demo, binary_names_demo_features From 6d74d5780f9183f1a5034d127589a3218987c21a Mon Sep 17 00:00:00 2001 From: Lars Buitinck Date: Fri, 13 Jan 2012 22:46:04 +0100 Subject: [PATCH 2/2] add credits to classify/scikitlearn.py --- nltk/classify/scikitlearn.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/nltk/classify/scikitlearn.py b/nltk/classify/scikitlearn.py index 97bc33985e..144d67003e 100644 --- a/nltk/classify/scikitlearn.py +++ b/nltk/classify/scikitlearn.py @@ -1,3 +1,9 @@ +# Natural Language Toolkit: Interface to scikit-learn classifiers +# +# Author: Lars Buitinck +# URL: +# For license information, see LICENSE.TXT + from nltk.classify.api import ClassifierI from nltk.probability import DictionaryProbDist