Skip to content

Commit

Permalink
Adds fit_params to StackingCVClassifier fit method
Browse files Browse the repository at this point in the history
  • Loading branch information
jrbourbeau committed Oct 20, 2017
1 parent 2b3668b commit 7a67ce5
Show file tree
Hide file tree
Showing 2 changed files with 57 additions and 10 deletions.
42 changes: 33 additions & 9 deletions mlxtend/classifier/stacking_cv_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,29 +111,36 @@ def __init__(self, classifiers, meta_classifier,
self.stratify = stratify
self.shuffle = shuffle

def fit(self, X, y, groups=None):
def fit(self, X, y, groups=None, **fit_params):
""" Fit ensemble classifiers and the meta-classifier.
Parameters
----------
X : numpy array, shape = [n_samples, n_features]
Training vectors, where n_samples is the number of samples and
n_features is the number of features.
y : numpy array, shape = [n_samples]
Target values.
groups : numpy array/None, shape = [n_samples]
The group that each sample belongs to. This is used by specific
folding strategies such as GroupKFold()
fit_params : dict, optional
Parameters to pass to the fit methods of `classifiers` and
`meta_classifier`. Note that only fit parameters for `classifiers`
that are the same for each cross-validation split are supported
(e.g. `sample_weight` is not currently supported).
Returns
-------
self : object
"""
self.clfs_ = [clone(clf) for clf in self.classifiers]
self.named_clfs_ = {key: value for key, value in
_name_estimators(self.clfs_)}
self.meta_clf_ = clone(self.meta_classifier)
self.named_meta_clf_ = {'meta-%s' % key: value for key, value in
_name_estimators([self.meta_clf_])}
if self.verbose > 0:
print("Fitting %d classifiers..." % (len(self.classifiers)))

Expand All @@ -144,8 +151,23 @@ def fit(self, X, y, groups=None):
final_cv.shuffle = self.shuffle
skf = list(final_cv.split(X, y, groups))

# Get fit_params for each classifier in self.named_clfs_
named_clfs_fit_params = {}
for name, clf in six.iteritems(self.named_clfs_):
clf_fit_params = {}
for key, value in six.iteritems(fit_params):
if name in key and 'meta-' not in key:
clf_fit_params[key.replace(name+'__', '')] = value
named_clfs_fit_params[name] = clf_fit_params
# Get fit_params for self.named_meta_clf_
meta_fit_params = {}
meta_clf_name = list(self.named_meta_clf_.keys())[0]
for key, value in six.iteritems(fit_params):
if meta_clf_name in key and 'meta-' in meta_clf_name:
meta_fit_params[key.replace(meta_clf_name+'__', '')] = value

all_model_predictions = np.array([]).reshape(len(y), 0)
for model in self.clfs_:
for name, model in six.iteritems(self.named_clfs_):

if self.verbose > 0:
i = self.clfs_.index(model) + 1
Expand All @@ -172,7 +194,8 @@ def fit(self, X, y, groups=None):
((num + 1), final_cv.get_n_splits()))

try:
model.fit(X[train_index], y[train_index])
model.fit(X[train_index], y[train_index],
**named_clfs_fit_params[name])
except TypeError as e:
raise TypeError(str(e) + '\nPlease check that X and y'
'are NumPy arrays. If X and y are lists'
Expand Down Expand Up @@ -215,16 +238,17 @@ def fit(self, X, y, groups=None):
X[test_index]))

# Fit the base models correctly this time using ALL the training set
for model in self.clfs_:
model.fit(X, y)
for name, model in six.iteritems(self.named_clfs_):
model.fit(X, y, **named_clfs_fit_params[name])

# Fit the secondary model
if not self.use_features_in_secondary:
self.meta_clf_.fit(all_model_predictions, reordered_labels)
self.meta_clf_.fit(all_model_predictions, reordered_labels,
**meta_fit_params)
else:
self.meta_clf_.fit(np.hstack((reordered_features,
all_model_predictions)),
reordered_labels)
reordered_labels, **meta_fit_params)

return self

Expand Down
25 changes: 24 additions & 1 deletion mlxtend/classifier/tests/test_stacking_cv_classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from mlxtend.classifier import StackingCVClassifier

import pandas as pd
from sklearn.linear_model import LogisticRegression
from sklearn.linear_model import LogisticRegression, SGDClassifier
from sklearn.naive_bayes import GaussianNB
from sklearn.ensemble import RandomForestClassifier
from sklearn.neighbors import KNeighborsClassifier
Expand Down Expand Up @@ -61,6 +61,29 @@ def test_StackingClassifier_proba():
assert scores_mean == 0.93


def test_StackingClassifier_fit_params():
    """Cross-validate a stacked ensemble while forwarding fit parameters.

    A random forest and an SGD classifier are stacked under a logistic
    regression meta-classifier. ``cross_val_score`` is given a
    ``fit_params`` dict holding one entry keyed for the SGD classifier
    (``intercept_init``) and one keyed for the meta-classifier
    (``sample_weight``). The mean accuracy over 5 folds, rounded to two
    decimals, must equal 0.86.

    ``X`` and ``y`` are module-level data defined elsewhere in this test
    module (not shown in this view).
    """
    np.random.seed(123)

    meta_estimator = LogisticRegression()
    base_estimators = [RandomForestClassifier(),
                       SGDClassifier(random_state=2)]
    stack = StackingCVClassifier(classifiers=base_estimators,
                                 meta_classifier=meta_estimator,
                                 shuffle=False)

    # Keys contain '-' and '__', so a dict literal is required here;
    # they cannot be spelled as keyword arguments.
    n_samples = X.shape[0]
    routed_params = {
        'sgdclassifier__intercept_init': np.unique(y),
        'meta-logisticregression__sample_weight': np.full(n_samples, 2),
    }

    fold_scores = cross_val_score(stack, X, y,
                                  cv=5,
                                  scoring='accuracy',
                                  fit_params=routed_params)

    mean_accuracy = round(fold_scores.mean(), 2)
    assert mean_accuracy == 0.86


def test_gridsearch():
np.random.seed(123)
meta = LogisticRegression()
Expand Down

0 comments on commit 7a67ce5

Please sign in to comment.