This repository has been archived by the owner on Nov 14, 2023. It is now read-only.
Since SLEP018, scikit-learn provides a global context setter through a simple set_config API. One of its use cases is to propagate transformed values through a pipeline as pandas DataFrames (with set_config(transform_output="pandas")).
I ran into this issue while running hyperparameter optimization (HPO) with TuneGridSearchCV, which does not preserve the previously set context. A minimal reproduction:
model.py
import sklearn
from sklearn.pipeline import Pipeline
from sklearn.linear_model import Ridge
from sklearn.preprocessing import StandardScaler
from sklearn.feature_selection import VarianceThreshold

# This sets the global context.
sklearn.set_config(transform_output="pandas")


class CustomVarianceThreshold(VarianceThreshold):
    def fit(self, X, y=None):
        assert X.columns is not None
        return super().fit(X, y)

    def transform(self, X):
        assert X.columns is not None
        return super().transform(X)


def MODEL():
    return Pipeline([
        ("scaler", StandardScaler()),
        ("selector", CustomVarianceThreshold()),
        ("regressor", Ridge()),
    ])
main.py
import numpy as np
import sys
from sklearn.datasets import load_iris
from sklearn.model_selection import GridSearchCV
from tune_sklearn import TuneGridSearchCV
from model import MODEL

PARAMS = {
    "regressor__alpha": np.linspace(0, 1, 100)
}


# This works fine.
def run_single():
    model = MODEL()
    data = load_iris(as_frame=True)
    model.fit(data.data, data.target)
    model.predict(data.data)


# This also works fine. HPO is enabled, but Ray is not the backend.
def run_hpo():
    model = MODEL()
    cv = GridSearchCV(model, PARAMS, n_jobs=-1)
    data = load_iris(as_frame=True)
    cv.fit(data.data, data.target)
    best_model = cv.best_estimator_
    best_model.predict(data.data)


# This breaks because `columns` is not an attribute of X.
def run_hpo_with_ray():
    model = MODEL()
    cv = TuneGridSearchCV(model, PARAMS, n_jobs=-1)
    data = load_iris(as_frame=True)
    cv.fit(data.data, data.target)
    best_model = cv.best_estimator_
    best_model.predict(data.data)
There is a workaround: configure the output on the estimator itself (replacing StandardScaler() with StandardScaler().set_output(transform="pandas")). Still, I think it would be nice if a global context set via set_config integrated well with tune_sklearn.
I have checked the issues but did not find any pre-existing issue/documentation. Please let me know if this is a duplicate and I apologize if that is the case.