Skip to content

Commit

Permalink
implement SklearnClassifier.batch_prob_classify
Browse files Browse the repository at this point in the history
And remove SklearnClassifier.classify; already inherited from
ClassifierI.
  • Loading branch information
larsmans committed Jan 13, 2012
1 parent 024c00f commit fb470b2
Showing 1 changed file with 9 additions and 2 deletions.
11 changes: 9 additions & 2 deletions nltk/classify/scikitlearn.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from nltk.classify.api import ClassifierI
from nltk.probability import DictionaryProbDist

import numpy as np

Expand All @@ -25,8 +26,10 @@ def batch_classify(self, featuresets):
y = self._clf.predict(X)
return [self._index_label[int(yi)] for yi in y]

def classify(self, featureset):
    """Return the single label predicted for one featureset.

    The featureset is wrapped in a one-element list, passed to
    ``batch_classify``, and the sole entry of the returned list is
    unwrapped so that the caller receives a label rather than a list.
    """
    # batch_classify returns one label per input featureset; unwrap it.
    return self.batch_classify([featureset])[0]
def batch_prob_classify(self, featuresets):
    """Return one probability distribution per featureset.

    The featuresets are converted with ``_featuresets_to_array``, the
    wrapped estimator's ``predict_proba`` is called on the result, and
    each row of its output is passed to ``_make_probdist``.

    Returns a list with one entry per row returned by ``predict_proba``.
    """
    X = self._featuresets_to_array(featuresets)
    y_proba = self._clf.predict_proba(X)
    # Iterate over the rows directly instead of indexing via
    # xrange(len(...)); this gives the same result and does not depend on
    # the Python 2-only ``xrange`` builtin.
    return [self._make_probdist(row) for row in y_proba]

def labels(self):
    """Return the known labels as a list.

    The labels are the keys of ``self._label_index``. An explicit list
    is built so that callers get an independent list rather than a view
    tied to the underlying dict.
    """
    return list(self._label_index)
Expand Down Expand Up @@ -72,6 +75,10 @@ def _featuresets_to_array(self, featuresets):

return X

def _make_probdist(self, y_proba):
    """Wrap one sequence of probabilities in a DictionaryProbDist.

    Position ``i`` of ``y_proba`` is keyed by ``self._index_label[i]``
    in the mapping handed to ``DictionaryProbDist``.
    """
    index_to_label = self._index_label
    prob_by_label = {}
    for position, probability in enumerate(y_proba):
        prob_by_label[index_to_label[position]] = probability
    return DictionaryProbDist(prob_by_label)


if __name__ == "__main__":
from nltk.classify.util import names_demo, binary_names_demo_features
Expand Down

0 comments on commit fb470b2

Please sign in to comment.