TST create instances from exotic estimators for docstring params check (scikit-learn#20243)

Co-authored-by: Olivier Grisel <[email protected]>
glemaitre and ogrisel authored Jun 11, 2021
1 parent 28bc843 commit 95afec8
Showing 11 changed files with 164 additions and 40 deletions.
3 changes: 3 additions & 0 deletions sklearn/cluster/_bicluster.py
@@ -400,6 +400,9 @@ class SpectralBiclustering(BaseSpectral):
column_labels_ : array-like of shape (n_cols,)
Column partition labels.
biclusters_ : tuple of two ndarrays
The tuple contains the `rows_` and `columns_` arrays.
n_features_in_ : int
Number of features seen during :term:`fit`.
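(Not part of the diff.) A quick sketch of how the newly documented `biclusters_` tuple behaves, reusing the small data set from the SpectralBiclustering docstring example; exact label values depend on the random state:

import numpy as np
from sklearn.cluster import SpectralBiclustering

X = np.array([[1, 1], [2, 1], [1, 0],
              [4, 7], [3, 5], [3, 6]])
model = SpectralBiclustering(n_clusters=2, random_state=0).fit(X)
rows, columns = model.biclusters_  # the same objects as (model.rows_, model.columns_)
rows.shape                         # (4, 6): one boolean row mask per bicluster
columns.shape                      # (4, 2): one boolean column mask per bicluster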
8 changes: 8 additions & 0 deletions sklearn/decomposition/_dict_learning.py
@@ -1033,6 +1033,14 @@ class SparseCoder(_BaseSparseCoding, BaseEstimator):
This attribute is deprecated in 0.24 and will be removed in
1.1 (renaming of 0.26). Use `dictionary` instead.
n_components_ : int
Number of atoms.
n_features_in_ : int
Number of features seen during :term:`fit`.
.. versionadded:: 0.24
Examples
--------
>>> import numpy as np
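(Not part of the diff.) A minimal sketch of the two attributes documented above; for SparseCoder both are derived from the shape of the user-supplied dictionary, assuming a recent scikit-learn (>= 0.24):

import numpy as np
from sklearn.decomposition import SparseCoder

dictionary = np.array([[0., 1., 0.],
                       [-1., -1., 2.],
                       [1., 1., 1.]])      # 3 atoms over 3 features
coder = SparseCoder(dictionary=dictionary)
coder.n_components_   # 3, the number of atoms (rows of `dictionary`)
coder.n_features_in_  # 3, the number of features (columns of `dictionary`)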
14 changes: 14 additions & 0 deletions sklearn/ensemble/_stacking.py
@@ -351,6 +351,12 @@ class StackingClassifier(ClassifierMixin, _BaseStacking):
named_estimators_ : :class:`~sklearn.utils.Bunch`
Attribute to access any fitted sub-estimators by name.
n_features_in_ : int
Number of features seen during :term:`fit`. Only defined if the
underlying classifier exposes such an attribute when fit.
.. versionadded:: 0.24
final_estimator_ : estimator
The classifier which predicts given the output of `estimators_`.
@@ -611,10 +617,18 @@ class StackingRegressor(RegressorMixin, _BaseStacking):
named_estimators_ : :class:`~sklearn.utils.Bunch`
Attribute to access any fitted sub-estimators by name.
n_features_in_ : int
Number of features seen during :term:`fit`. Only defined if the
underlying regressor exposes such an attribute when fit.
.. versionadded:: 0.24
final_estimator_ : estimator
The regressor which predicts given the output of `estimators_`.
stack_method_ : list of str
The method used by each base estimator.
References
----------
.. [1] Wolpert, David H. "Stacked generalization." Neural networks 5.2
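(Not part of the diff.) A short sketch of the `n_features_in_` behaviour documented for the stacking estimators; here the base estimators do expose the attribute when fit, so it is available on the ensemble as well (the data shapes are arbitrary):

from sklearn.datasets import make_classification
from sklearn.ensemble import StackingClassifier
from sklearn.linear_model import LogisticRegression
from sklearn.tree import DecisionTreeClassifier

X, y = make_classification(n_samples=50, n_features=4, random_state=0)
clf = StackingClassifier(
    estimators=[("lr", LogisticRegression()), ("tree", DecisionTreeClassifier())],
    final_estimator=LogisticRegression(),
).fit(X, y)
clf.n_features_in_    # 4, the number of columns the base estimators were fit on
clf.final_estimator_  # the fitted meta-classifier trained on their predictions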
18 changes: 17 additions & 1 deletion sklearn/ensemble/_voting.py
@@ -199,9 +199,19 @@ class VotingClassifier(ClassifierMixin, _BaseVoting):
.. versionadded:: 0.20
classes_ : array-like of shape (n_predictions,)
le_ : :class:`~sklearn.preprocessing.LabelEncoder`
Transformer used to encode the labels during fit and decode during
prediction.
classes_ : ndarray of shape (n_classes,)
The class labels.
n_features_in_ : int
Number of features seen during :term:`fit`. Only defined if the
underlying classifier exposes such an attribute when fit.
.. versionadded:: 0.24
See Also
--------
VotingRegressor : Prediction voting regressor.
@@ -431,6 +441,12 @@ class VotingRegressor(RegressorMixin, _BaseVoting):
.. versionadded:: 0.20
n_features_in_ : int
Number of features seen during :term:`fit`. Only defined if the
underlying regressor exposes such an attribute when fit.
.. versionadded:: 0.24
See Also
--------
VotingClassifier : Soft Voting/Majority Rule classifier.
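(Not part of the diff.) A sketch of the newly documented `le_` and `classes_` attributes on a fitted VotingClassifier (arbitrary toy data):

from sklearn.datasets import make_classification
from sklearn.ensemble import VotingClassifier
from sklearn.linear_model import LogisticRegression
from sklearn.tree import DecisionTreeClassifier

X, y = make_classification(n_samples=40, n_features=4, random_state=0)
vote = VotingClassifier(
    estimators=[("lr", LogisticRegression()), ("tree", DecisionTreeClassifier())]
).fit(X, y)
vote.classes_        # array([0, 1])
vote.le_.classes_    # the LabelEncoder fitted on y; same labels as above
vote.n_features_in_  # 4, available because LogisticRegression exposes it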
8 changes: 4 additions & 4 deletions sklearn/feature_extraction/text.py
@@ -949,9 +949,9 @@ class CountVectorizer(_VectorizerMixin, BaseEstimator):
vocabulary_ : dict
A mapping of terms to feature indices.
-    fixed_vocabulary_: boolean
+    fixed_vocabulary_ : bool
         True if a fixed vocabulary of term to indices mapping
-        is provided by the user
+        is provided by the user.
stop_words_ : set
Terms that were ignored because they either:
@@ -1684,9 +1684,9 @@ class TfidfVectorizer(CountVectorizer):
vocabulary_ : dict
A mapping of terms to feature indices.
-    fixed_vocabulary_: bool
+    fixed_vocabulary_ : bool
         True if a fixed vocabulary of term to indices mapping
-        is provided by the user
+        is provided by the user.
idf_ : array of shape (n_features,)
The inverse document frequency (IDF) vector; only defined
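(Not part of the diff.) The `fixed_vocabulary_` flag simply records whether the term-to-index mapping was supplied by the user or learned from the data; a quick sketch with arbitrary documents:

from sklearn.feature_extraction.text import CountVectorizer

docs = ["the cat sat", "the dog sat on the mat"]

learned = CountVectorizer().fit(docs)
learned.fixed_vocabulary_   # False: vocabulary_ was built from `docs`

fixed = CountVectorizer(vocabulary={"cat": 0, "dog": 1}).fit(docs)
fixed.fixed_vocabulary_     # True: the user supplied the mapping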
6 changes: 6 additions & 0 deletions sklearn/feature_selection/_from_model.py
@@ -127,6 +127,12 @@ class SelectFromModel(MetaEstimatorMixin, SelectorMixin, BaseEstimator):
This is stored only when a non-fitted estimator is passed to the
``SelectFromModel``, i.e. when prefit is False.
n_features_in_ : int
Number of features seen during :term:`fit`. Only defined if the
underlying estimator exposes such an attribute when fit.
.. versionadded:: 0.24
threshold_ : float
The threshold value used for feature selection.
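(Not part of the diff.) A sketch of the delegated `n_features_in_` described above; it is only available here because the wrapped LogisticRegression records it during fit (toy data, arbitrary shapes):

from sklearn.datasets import make_classification
from sklearn.feature_selection import SelectFromModel
from sklearn.linear_model import LogisticRegression

X, y = make_classification(n_samples=50, n_features=6, random_state=0)
selector = SelectFromModel(LogisticRegression()).fit(X, y)
selector.n_features_in_  # 6, mirrored from the fitted estimator
selector.threshold_      # the importance cutoff actually applied
selector.get_support()   # boolean mask of the retained features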
18 changes: 18 additions & 0 deletions sklearn/feature_selection/_rfe.py
@@ -99,12 +99,21 @@ class RFE(SelectorMixin, MetaEstimatorMixin, BaseEstimator):
Attributes
----------
classes_ : ndarray of shape (n_classes,)
The class labels. Only available when `estimator` is a classifier.
estimator_ : ``Estimator`` instance
The fitted estimator used to select features.
n_features_ : int
The number of selected features.
n_features_in_ : int
Number of features seen during :term:`fit`. Only defined if the
underlying estimator exposes such an attribute when fit.
.. versionadded:: 0.24
ranking_ : ndarray of shape (n_features,)
The feature ranking, such that ``ranking_[i]`` corresponds to the
ranking position of the i-th feature. Selected (i.e., estimated
@@ -464,6 +473,9 @@ class RFECV(RFE):
Attributes
----------
classes_ : ndarray of shape (n_classes,)
The class labels. Only available when `estimator` is a classifier.
estimator_ : ``Estimator`` instance
The fitted estimator used to select features.
@@ -475,6 +487,12 @@ class RFECV(RFE):
n_features_ : int
The number of selected features with cross-validation.
n_features_in_ : int
Number of features seen during :term:`fit`. Only defined if the
underlying estimator exposes such an attribute when fit.
.. versionadded:: 0.24
ranking_ : ndarray of shape (n_features,)
The feature ranking, such that `ranking_[i]`
corresponds to the ranking
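(Not part of the diff.) A sketch of the RFE attributes documented above; `classes_` is delegated to the fitted estimator and therefore only exists when that estimator is a classifier:

from sklearn.datasets import make_classification
from sklearn.feature_selection import RFE
from sklearn.linear_model import LogisticRegression

X, y = make_classification(n_samples=50, n_features=5, random_state=0)
rfe = RFE(LogisticRegression(), n_features_to_select=2).fit(X, y)
rfe.classes_     # array([0, 1]), taken from the wrapped classifier
rfe.n_features_  # 2, the number of selected features
rfe.ranking_     # rank 1 for selected features, larger ranks for dropped ones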
6 changes: 6 additions & 0 deletions sklearn/feature_selection/_sequential.py
@@ -76,6 +76,12 @@ class SequentialFeatureSelector(SelectorMixin, MetaEstimatorMixin,
Attributes
----------
n_features_in_ : int
Number of features seen during :term:`fit`. Only defined if the
underlying estimator exposes such an attribute when fit.
.. versionadded:: 0.24
n_features_to_select_ : int
The number of features that were selected.
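(Not part of the diff.) A sketch of the two fitted attributes documented for SequentialFeatureSelector (arbitrary toy data and estimator choice):

from sklearn.datasets import make_classification
from sklearn.feature_selection import SequentialFeatureSelector
from sklearn.neighbors import KNeighborsClassifier

X, y = make_classification(n_samples=60, n_features=5, random_state=0)
sfs = SequentialFeatureSelector(
    KNeighborsClassifier(), n_features_to_select=2
).fit(X, y)
sfs.n_features_in_         # 5, the number of columns in X
sfs.n_features_to_select_  # 2, the number of features actually kept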
4 changes: 4 additions & 0 deletions sklearn/pipeline.py
@@ -83,6 +83,10 @@ class Pipeline(_BaseComposition):
Read-only attribute to access any step parameter by user given name.
Keys are step names and values are steps parameters.
classes_ : ndarray of shape (n_classes,)
The class labels. Only exists if the last step of the pipeline is a
classifier.
n_features_in_ : int
Number of features seen during :term:`fit`. Only defined if the
underlying first estimator in `steps` exposes such an attribute
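(Not part of the diff.) The `classes_` entry added for Pipeline is forwarded from the final step, while `n_features_in_` comes from the first step; a sketch:

from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler

X, y = make_classification(n_samples=30, n_features=4, random_state=0)
pipe = Pipeline([("scale", StandardScaler()),
                 ("clf", LogisticRegression())]).fit(X, y)
pipe.classes_        # array([0, 1]), from the final classifier
pipe.n_features_in_  # 4, from the first step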
10 changes: 10 additions & 0 deletions sklearn/random_projection.py
@@ -459,6 +459,11 @@ class GaussianRandomProjection(BaseRandomProjection):
components_ : ndarray of shape (n_components, n_features)
Random matrix used for the projection.
n_features_in_ : int
Number of features seen during :term:`fit`.
.. versionadded:: 0.24
Examples
--------
>>> import numpy as np
@@ -586,6 +591,11 @@ class SparseRandomProjection(BaseRandomProjection):
density_ : float in range 0.0 - 1.0
Concrete density computed when density = "auto".
n_features_in_ : int
Number of features seen during :term:`fit`.
.. versionadded:: 0.24
Examples
--------
>>> import numpy as np
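(Not part of the diff.) A sketch mirroring the GaussianRandomProjection docstring example; with the default n_components="auto" the target dimensionality is chosen via the Johnson-Lindenstrauss lemma, which is also why the test below has to set n_components explicitly for its small X:

import numpy as np
from sklearn.random_projection import GaussianRandomProjection

rng = np.random.RandomState(42)
X = rng.rand(25, 3000)
grp = GaussianRandomProjection(random_state=42).fit(X)
grp.n_features_in_     # 3000, the input dimensionality
grp.components_.shape  # (grp.n_components_, 3000), the random projection matrix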
109 changes: 74 additions & 35 deletions sklearn/tests/test_docstring_parameters.py
@@ -24,6 +24,7 @@
from sklearn.externals._pep562 import Pep562
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.preprocessing import FunctionTransformer

import pytest

@@ -175,6 +176,27 @@ def _construct_searchcv_instance(SearchCV):
return SearchCV(LogisticRegression(), {"C": [0.1, 1]})


def _construct_compose_pipeline_instance(Estimator):
    # Minimal / degenerate instances: only useful to test the docstrings.
    if Estimator.__name__ == "ColumnTransformer":
        return Estimator(transformers=[("transformer", "passthrough", [0, 1])])
    elif Estimator.__name__ == "Pipeline":
        return Estimator(steps=[("clf", LogisticRegression())])
    elif Estimator.__name__ == "FeatureUnion":
        return Estimator(transformer_list=[
            ("transformer", FunctionTransformer())
        ])


def _construct_sparse_coder(Estimator):
    # XXX: hard-coded assumption that n_features=3
    dictionary = np.array(
        [[0, 1, 0], [-1, -1, 2], [1, 1, 1], [0, 1, 1], [0, 2, 1]],
        dtype=np.float64,
    )
    return Estimator(dictionary=dictionary)
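
(Not part of the diff.) A sketch of why the hard-coded dictionary above has three columns: it must line up with the 3-feature classification data the test generates further down, so that `fit` (a no-op for SparseCoder) and the attribute checks go through:

import numpy as np
from sklearn.datasets import make_classification
from sklearn.decomposition import SparseCoder

dictionary = np.array(
    [[0, 1, 0], [-1, -1, 2], [1, 1, 1], [0, 1, 1], [0, 2, 1]],
    dtype=np.float64,
)
X, y = make_classification(n_samples=20, n_features=3, n_redundant=0,
                           n_classes=2, random_state=2)
coder = SparseCoder(dictionary=dictionary).fit(X, y)  # fit is a no-op here
coder.n_features_in_  # 3, taken from dictionary.shape[1]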


N_FEATURES_MODULES_TO_IGNORE = {
'model_selection',
'multioutput',
Expand All @@ -190,56 +212,76 @@ def test_fit_docstring_attributes(name, Estimator):
doc = docscrape.ClassDoc(Estimator)
attributes = doc['Attributes']

IGNORED = {
'ClassifierChain',
'CountVectorizer', 'DictVectorizer',
'GaussianRandomProjection',
'MultiOutputClassifier', 'MultiOutputRegressor',
'NoSampleWeightWrapper', 'RFE', 'RFECV',
'RegressorChain', 'SelectFromModel',
'SparseCoder', 'SparseRandomProjection',
'SpectralBiclustering', 'StackingClassifier',
'StackingRegressor', 'TfidfVectorizer', 'VotingClassifier',
'VotingRegressor', 'SequentialFeatureSelector',
}

if Estimator.__name__ in IGNORED or Estimator.__name__.startswith('_'):
pytest.skip("Estimator cannot be fit easily to test fit attributes")

if Estimator.__name__ in (
"HalvingRandomSearchCV",
"RandomizedSearchCV",
"HalvingGridSearchCV",
"GridSearchCV",
):
est = _construct_searchcv_instance(Estimator)
elif Estimator.__name__ in (
"ColumnTransformer",
"Pipeline",
"FeatureUnion",
):
est = _construct_compose_pipeline_instance(Estimator)
elif Estimator.__name__ == "SparseCoder":
est = _construct_sparse_coder(Estimator)
else:
est = _construct_instance(Estimator)

    if Estimator.__name__ == 'SelectKBest':
-        est.k = 2
-
-    if Estimator.__name__ == 'DummyClassifier':
-        est.strategy = "stratified"
-
-    if 'PLS' in Estimator.__name__ or 'CCA' in Estimator.__name__:
-        est.n_components = 1  # default = 2 is invalid for single target.
+        est.set_params(k=2)
+    elif Estimator.__name__ == 'DummyClassifier':
+        est.set_params(strategy="stratified")
+    elif Estimator.__name__ == 'CCA' or Estimator.__name__.startswith('PLS'):
+        # default = 2 is invalid for single target
+        est.set_params(n_components=1)
+    elif Estimator.__name__ in (
+        "GaussianRandomProjection",
+        "SparseRandomProjection",
+    ):
+        # default="auto" raises an error with the shape of `X`
+        est.set_params(n_components=2)

    # FIXME: TO BE REMOVED for 1.1 (avoid FutureWarning)
    if Estimator.__name__ == 'NMF':
-        est.init = 'nndsvda'
+        est.set_params(init='nndsvda')

    # FIXME: TO BE REMOVED for 1.2 (avoid FutureWarning)
    if Estimator.__name__ == 'TSNE':
-        est.learning_rate = 200.0
-        est.init = 'random'
-
-    X, y = make_classification(n_samples=20, n_features=3,
-                               n_redundant=0, n_classes=2,
-                               random_state=2)
+        est.set_params(learning_rate=200.0, init='random')
+
+    # For PLS, TODO remove in 1.1
+    skipped_attributes = {"x_scores_", "y_scores_"}

if Estimator.__name__.endswith("Vectorizer"):
# Vectorizer require some specific input data
if Estimator.__name__ in (
"CountVectorizer",
"HashingVectorizer",
"TfidfVectorizer",
):
X = [
"This is the first document.",
"This document is the second document.",
"And this is the third one.",
"Is this the first document?",
]
elif Estimator.__name__ == "DictVectorizer":
X = [{"foo": 1, "bar": 2}, {"foo": 3, "baz": 1}]
y = None
else:
X, y = make_classification(
n_samples=20,
n_features=3,
n_redundant=0,
n_classes=2,
random_state=2,
)

y = _enforce_estimator_tags_y(est, y)
X = _enforce_estimator_tags_x(est, X)
y = _enforce_estimator_tags_y(est, y)
X = _enforce_estimator_tags_x(est, X)

if '1dlabels' in est._get_tags()['X_types']:
est.fit(y)
@@ -248,9 +290,6 @@ def test_fit_docstring_attributes(name, Estimator):
else:
est.fit(X, y)

-    skipped_attributes = {'x_scores_',  # For PLS, TODO remove in 1.1
-                          'y_scores_'}  # For PLS, TODO remove in 1.1

module = est.__module__.split(".")[1]
if module in N_FEATURES_MODULES_TO_IGNORE:
skipped_attributes.add("n_features_in_")
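(Not part of the diff.) For context, the core of the check that `test_fit_docstring_attributes` performs for each estimator, written out as a standalone sketch; it assumes numpydoc is installed, and the real test additionally skips a few known attributes:

from numpydoc import docscrape
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression

est = LogisticRegression()
X, y = make_classification(n_samples=20, n_features=3, n_redundant=0,
                           n_classes=2, random_state=2)
est.fit(X, y)

# Names documented under "Attributes" vs. trailing-underscore attributes set by fit.
documented = {attr.name for attr in docscrape.ClassDoc(type(est))["Attributes"]}
fitted = {a for a in dir(est) if a.endswith("_") and not a.startswith("_")}
missing_docs = fitted - documented  # should be empty for a fully documented estimator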
