From 7a67ce57ae9e16e95eee15e9549517ffaf82dc7f Mon Sep 17 00:00:00 2001
From: James Bourbeau
Date: Sat, 14 Oct 2017 12:58:25 -0500
Subject: [PATCH] Adds fit_params to StackingCVClassifier fit method

---
NOTE(review): what this patch does
  * fit() gains **fit_params. Keys are routed to the base classifiers by the
    names produced by _name_estimators(self.clfs_), and to the
    meta-classifier by the same name prefixed with 'meta-'. The
    '<name>__' part is stripped before the value is forwarded to .fit().
  * The same per-classifier fit params are used for every CV fold and for
    the final refit on the full training set.

NOTE(review): points to confirm before merging
  * Routing is substring based: `name in key` and
    `key.replace(name+'__', '')`. A key is therefore forwarded to every
    classifier whose name occurs anywhere in the key, not only as the
    prefix before '__' -- presumably a problem when one estimator name is
    contained in another; verify.
  * `'meta-' in meta_clf_name` is always true, because the keys of
    self.named_meta_clf_ are built as 'meta-%s' % key in this same patch.
  * The base classifiers are now iterated via
    six.iteritems(self.named_clfs_), a dict, instead of the list
    self.clfs_. The columns of all_model_predictions follow that iteration
    order -- TODO confirm it always matches the order used at predict
    time, which is not part of this patch.
  * Line breaks and indentation of this file were restored from the hunk
    headers' line counts and the visible code structure; run
    `git apply --check` before relying on it.

 .../classifier/stacking_cv_classification.py  | 42 +++++++++++++++----
 .../tests/test_stacking_cv_classifier.py      | 25 ++++++++++-
 2 files changed, 57 insertions(+), 10 deletions(-)

diff --git a/mlxtend/classifier/stacking_cv_classification.py b/mlxtend/classifier/stacking_cv_classification.py
index 330a672a0..5138488a8 100644
--- a/mlxtend/classifier/stacking_cv_classification.py
+++ b/mlxtend/classifier/stacking_cv_classification.py
@@ -111,7 +111,7 @@ def __init__(self, classifiers, meta_classifier,
         self.stratify = stratify
         self.shuffle = shuffle
 
-    def fit(self, X, y, groups=None):
+    def fit(self, X, y, groups=None, **fit_params):
         """ Fit ensemble classifers and the meta-classifier.
 
         Parameters
@@ -119,13 +119,16 @@ def fit(self, X, y, groups=None):
         X : numpy array, shape = [n_samples, n_features]
             Training vectors, where n_samples is the number of samples and
             n_features is the number of features.
-
         y : numpy array, shape = [n_samples]
             Target values.
-
         groups : numpy array/None, shape = [n_samples]
             The group that each sample belongs to. This is used by specific
             folding strategies such as GroupKFold()
+        fit_params : dict, optional
+            Parameters to pass to the fit methods of `classifiers` and
+            `meta_classifier`. Note that only fit parameters for `classifiers`
+            that are the same for each cross-validation split are supported
+            (e.g. `sample_weight` is not currently supported).
 
         Returns
         -------
@@ -133,7 +136,11 @@ def fit(self, X, y, groups=None):
 
         """
         self.clfs_ = [clone(clf) for clf in self.classifiers]
+        self.named_clfs_ = {key: value for key, value in
+                            _name_estimators(self.clfs_)}
         self.meta_clf_ = clone(self.meta_classifier)
+        self.named_meta_clf_ = {'meta-%s' % key: value for key, value in
+                                _name_estimators([self.meta_clf_])}
         if self.verbose > 0:
             print("Fitting %d classifiers..." % (len(self.classifiers)))
 
@@ -144,8 +151,23 @@ def fit(self, X, y, groups=None):
             final_cv.shuffle = self.shuffle
         skf = list(final_cv.split(X, y, groups))
 
+        # Get fit_params for each classifier in self.named_clfs_
+        named_clfs_fit_params = {}
+        for name, clf in six.iteritems(self.named_clfs_):
+            clf_fit_params = {}
+            for key, value in six.iteritems(fit_params):
+                if name in key and 'meta-' not in key:
+                    clf_fit_params[key.replace(name+'__', '')] = value
+            named_clfs_fit_params[name] = clf_fit_params
+        # Get fit_params for self.named_meta_clf_
+        meta_fit_params = {}
+        meta_clf_name = list(self.named_meta_clf_.keys())[0]
+        for key, value in six.iteritems(fit_params):
+            if meta_clf_name in key and 'meta-' in meta_clf_name:
+                meta_fit_params[key.replace(meta_clf_name+'__', '')] = value
+
         all_model_predictions = np.array([]).reshape(len(y), 0)
-        for model in self.clfs_:
+        for name, model in six.iteritems(self.named_clfs_):
 
             if self.verbose > 0:
                 i = self.clfs_.index(model) + 1
@@ -172,7 +194,8 @@ def fit(self, X, y, groups=None):
                           ((num + 1), final_cv.get_n_splits()))
 
                 try:
-                    model.fit(X[train_index], y[train_index])
+                    model.fit(X[train_index], y[train_index],
+                              **named_clfs_fit_params[name])
                 except TypeError as e:
                     raise TypeError(str(e) + '\nPlease check that X and y'
                                     'are NumPy arrays. If X and y are lists'
@@ -215,16 +238,17 @@ def fit(self, X, y, groups=None):
                                                          X[test_index]))
 
         # Fit the base models correctly this time using ALL the training set
-        for model in self.clfs_:
-            model.fit(X, y)
+        for name, model in six.iteritems(self.named_clfs_):
+            model.fit(X, y, **named_clfs_fit_params[name])
 
         # Fit the secondary model
         if not self.use_features_in_secondary:
-            self.meta_clf_.fit(all_model_predictions, reordered_labels)
+            self.meta_clf_.fit(all_model_predictions, reordered_labels,
+                               **meta_fit_params)
         else:
             self.meta_clf_.fit(np.hstack((reordered_features,
                                           all_model_predictions)),
-                               reordered_labels)
+                               reordered_labels, **meta_fit_params)
 
         return self
 
diff --git a/mlxtend/classifier/tests/test_stacking_cv_classifier.py b/mlxtend/classifier/tests/test_stacking_cv_classifier.py
index a4f38acfb..e7e42af4d 100644
--- a/mlxtend/classifier/tests/test_stacking_cv_classifier.py
+++ b/mlxtend/classifier/tests/test_stacking_cv_classifier.py
@@ -8,7 +8,7 @@
 from mlxtend.classifier import StackingCVClassifier
 
 import pandas as pd
-from sklearn.linear_model import LogisticRegression
+from sklearn.linear_model import LogisticRegression, SGDClassifier
 from sklearn.naive_bayes import GaussianNB
 from sklearn.ensemble import RandomForestClassifier
 from sklearn.neighbors import KNeighborsClassifier
@@ -61,6 +61,29 @@ def test_StackingClassifier_proba():
     assert scores_mean == 0.93
 
 
+def test_StackingClassifier_fit_params():
+    np.random.seed(123)
+    meta = LogisticRegression()
+    clf1 = RandomForestClassifier()
+    clf2 = SGDClassifier(random_state=2)
+    sclf = StackingCVClassifier(classifiers=[clf1, clf2],
+                                meta_classifier=meta,
+                                shuffle=False)
+    fit_params = {
+        'sgdclassifier__intercept_init': np.unique(y),
+        'meta-logisticregression__sample_weight': np.full(X.shape[0], 2)
+    }
+
+    scores = cross_val_score(sclf,
+                             X,
+                             y,
+                             cv=5,
+                             scoring='accuracy',
+                             fit_params=fit_params)
+    scores_mean = (round(scores.mean(), 2))
+    assert scores_mean == 0.86
+
+
 def test_gridsearch():
     np.random.seed(123)
     meta = LogisticRegression()