[breaking] Remove deprecated parameters in the skl interface. (#9986)
trivialfis authored Jan 15, 2024
1 parent 2de85d3 commit 0798e36
Showing 16 changed files with 418 additions and 462 deletions.
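All of the demo and documentation updates shown below follow the same migration: keyword arguments that used to be accepted by `fit()` now belong on the estimator itself. A minimal sketch of that pattern, written for this summary and assuming a recent XGBoost release where these settings are constructor arguments:

    import xgboost
    from sklearn.datasets import load_breast_cancer

    X, y = load_breast_cancer(return_X_y=True)

    # Before (the deprecated style removed by this change): configuration on fit().
    # clf = xgboost.XGBClassifier(n_estimators=100)
    # clf.fit(X, y, eval_set=[(X, y)], eval_metric="logloss", early_stopping_rounds=10)

    # After: configuration lives on the estimator; fit() only receives data.
    clf = xgboost.XGBClassifier(
        n_estimators=100,
        eval_metric="logloss",
        early_stopping_rounds=10,
    )
    clf.fit(X, y, eval_set=[(X, y)])

Passing `eval_set` at fit time is still supported; only the training-configuration keywords moved.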
31 changes: 18 additions & 13 deletions demo/guide-python/continuation.py
@@ -16,14 +16,14 @@ def training_continuation(tmpdir: str, use_pickle: bool) -> None:
     """Basic training continuation."""
     # Train 128 iterations in 1 session
     X, y = load_breast_cancer(return_X_y=True)
-    clf = xgboost.XGBClassifier(n_estimators=128)
-    clf.fit(X, y, eval_set=[(X, y)], eval_metric="logloss")
+    clf = xgboost.XGBClassifier(n_estimators=128, eval_metric="logloss")
+    clf.fit(X, y, eval_set=[(X, y)])
     print("Total boosted rounds:", clf.get_booster().num_boosted_rounds())

     # Train 128 iterations in 2 sessions, with the first one runs for 32 iterations and
     # the second one runs for 96 iterations
-    clf = xgboost.XGBClassifier(n_estimators=32)
-    clf.fit(X, y, eval_set=[(X, y)], eval_metric="logloss")
+    clf = xgboost.XGBClassifier(n_estimators=32, eval_metric="logloss")
+    clf.fit(X, y, eval_set=[(X, y)])
     assert clf.get_booster().num_boosted_rounds() == 32

     # load back the model, this could be a checkpoint
@@ -39,8 +39,8 @@ def training_continuation(tmpdir: str, use_pickle: bool) -> None:
     loaded = xgboost.XGBClassifier()
     loaded.load_model(path)

-    clf = xgboost.XGBClassifier(n_estimators=128 - 32)
-    clf.fit(X, y, eval_set=[(X, y)], eval_metric="logloss", xgb_model=loaded)
+    clf = xgboost.XGBClassifier(n_estimators=128 - 32, eval_metric="logloss")
+    clf.fit(X, y, eval_set=[(X, y)], xgb_model=loaded)

     print("Total boosted rounds:", clf.get_booster().num_boosted_rounds())

@@ -56,19 +56,24 @@ def training_continuation_early_stop(tmpdir: str, use_pickle: bool) -> None:
     n_estimators = 512

     X, y = load_breast_cancer(return_X_y=True)
-    clf = xgboost.XGBClassifier(n_estimators=n_estimators)
-    clf.fit(X, y, eval_set=[(X, y)], eval_metric="logloss", callbacks=[early_stop])
+    clf = xgboost.XGBClassifier(
+        n_estimators=n_estimators, eval_metric="logloss", callbacks=[early_stop]
+    )
+    clf.fit(X, y, eval_set=[(X, y)])
     print("Total boosted rounds:", clf.get_booster().num_boosted_rounds())
     best = clf.best_iteration

     # Train 512 iterations in 2 sessions, with the first one runs for 128 iterations and
     # the second one runs until early stop.
-    clf = xgboost.XGBClassifier(n_estimators=128)
+    clf = xgboost.XGBClassifier(
+        n_estimators=128, eval_metric="logloss", callbacks=[early_stop]
+    )
     # Reinitialize the early stop callback
     early_stop = xgboost.callback.EarlyStopping(
         rounds=early_stopping_rounds, save_best=True
     )
-    clf.fit(X, y, eval_set=[(X, y)], eval_metric="logloss", callbacks=[early_stop])
+    clf.set_params(callbacks=[early_stop])
+    clf.fit(X, y, eval_set=[(X, y)])
     assert clf.get_booster().num_boosted_rounds() == 128

     # load back the model, this could be a checkpoint
@@ -87,13 +92,13 @@ def training_continuation_early_stop(tmpdir: str, use_pickle: bool) -> None:
     early_stop = xgboost.callback.EarlyStopping(
         rounds=early_stopping_rounds, save_best=True
     )
-    clf = xgboost.XGBClassifier(n_estimators=n_estimators - 128)
+    clf = xgboost.XGBClassifier(
+        n_estimators=n_estimators - 128, eval_metric="logloss", callbacks=[early_stop]
+    )
     clf.fit(
         X,
         y,
         eval_set=[(X, y)],
-        eval_metric="logloss",
-        callbacks=[early_stop],
         xgb_model=loaded,
     )

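The early-stopping demo above also rebuilds its callback between sessions. A rough, self-contained sketch of that pattern, added here as an illustration rather than part of the commit, assuming a recent XGBoost with the callback and constructor API used above:

    import xgboost
    from sklearn.datasets import load_breast_cancer

    X, y = load_breast_cancer(return_X_y=True)

    clf = xgboost.XGBClassifier(n_estimators=128, eval_metric="logloss")
    clf.set_params(
        callbacks=[xgboost.callback.EarlyStopping(rounds=5, save_best=True)]
    )
    clf.fit(X, y, eval_set=[(X, y)])
    print("stopped at:", clf.best_iteration)

    # A later session should not reuse the old callback instance; rebuild it so
    # its internal counters start fresh, then fit again (optionally continuing
    # from a saved booster via xgb_model).
    clf.set_params(
        callbacks=[xgboost.callback.EarlyStopping(rounds=5, save_best=True)]
    )
    clf.fit(X, y, eval_set=[(X, y)])

Since callbacks carry state across boosting rounds, attaching a freshly constructed instance through `set_params()` before each `fit()` keeps the two training sessions independent.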
35 changes: 20 additions & 15 deletions demo/guide-python/sklearn_evals_result.py
@@ -16,30 +16,35 @@
 X_train, X_test = X[:1600], X[1600:]
 y_train, y_test = y[:1600], y[1600:]

-param_dist = {'objective':'binary:logistic', 'n_estimators':2}
+param_dist = {"objective": "binary:logistic", "n_estimators": 2}

-clf = xgb.XGBModel(**param_dist)
+clf = xgb.XGBModel(
+    **param_dist,
+    eval_metric="logloss",
+)
 # Or you can use: clf = xgb.XGBClassifier(**param_dist)

-clf.fit(X_train, y_train,
-        eval_set=[(X_train, y_train), (X_test, y_test)],
-        eval_metric='logloss',
-        verbose=True)
+clf.fit(
+    X_train,
+    y_train,
+    eval_set=[(X_train, y_train), (X_test, y_test)],
+    verbose=True,
+)

 # Load evals result by calling the evals_result() function
 evals_result = clf.evals_result()

-print('Access logloss metric directly from validation_0:')
-print(evals_result['validation_0']['logloss'])
+print("Access logloss metric directly from validation_0:")
+print(evals_result["validation_0"]["logloss"])

-print('')
-print('Access metrics through a loop:')
+print("")
+print("Access metrics through a loop:")
 for e_name, e_mtrs in evals_result.items():
-    print('- {}'.format(e_name))
+    print("- {}".format(e_name))
     for e_mtr_name, e_mtr_vals in e_mtrs.items():
-        print(' - {}'.format(e_mtr_name))
-        print(' - {}'.format(e_mtr_vals))
+        print(" - {}".format(e_mtr_name))
+        print(" - {}".format(e_mtr_vals))

-print('')
-print('Access complete dict:')
+print("")
+print("Access complete dict:")
 print(evals_result)
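As an optional follow-up not contained in the demo: the nested dict above plots naturally as learning curves. A short sketch, assuming matplotlib is installed and reusing the `evals_result` dict and the two evaluation sets from the snippet above:

    import matplotlib.pyplot as plt

    rounds = range(len(evals_result["validation_0"]["logloss"]))
    plt.plot(rounds, evals_result["validation_0"]["logloss"], label="train")
    plt.plot(rounds, evals_result["validation_1"]["logloss"], label="test")
    plt.xlabel("boosting round")
    plt.ylabel("logloss")
    plt.legend()
    plt.show()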
31 changes: 17 additions & 14 deletions demo/guide-python/sklearn_examples.py
@@ -1,4 +1,4 @@
-'''
+"""
 Collection of examples for using sklearn interface
 ==================================================
@@ -8,7 +8,7 @@
 Created on 1 Apr 2015
 @author: Jamie Hall
-'''
+"""
 import pickle

 import numpy as np
@@ -22,8 +22,8 @@

 print("Zeros and Ones from the Digits dataset: binary classification")
 digits = load_digits(n_class=2)
-y = digits['target']
-X = digits['data']
+y = digits["target"]
+X = digits["data"]
 kf = KFold(n_splits=2, shuffle=True, random_state=rng)
 for train_index, test_index in kf.split(X):
     xgb_model = xgb.XGBClassifier(n_jobs=1).fit(X[train_index], y[train_index])
@@ -33,8 +33,8 @@

 print("Iris: multiclass classification")
 iris = load_iris()
-y = iris['target']
-X = iris['data']
+y = iris["target"]
+X = iris["data"]
 kf = KFold(n_splits=2, shuffle=True, random_state=rng)
 for train_index, test_index in kf.split(X):
     xgb_model = xgb.XGBClassifier(n_jobs=1).fit(X[train_index], y[train_index])
@@ -53,9 +53,13 @@

 print("Parameter optimization")
 xgb_model = xgb.XGBRegressor(n_jobs=1)
-clf = GridSearchCV(xgb_model,
-                   {'max_depth': [2, 4],
-                    'n_estimators': [50, 100]}, verbose=1, n_jobs=1, cv=3)
+clf = GridSearchCV(
+    xgb_model,
+    {"max_depth": [2, 4], "n_estimators": [50, 100]},
+    verbose=1,
+    n_jobs=1,
+    cv=3,
+)
 clf.fit(X, y)
 print(clf.best_score_)
 print(clf.best_params_)
@@ -69,9 +73,8 @@

 # Early-stopping

-X = digits['data']
-y = digits['target']
+X = digits["data"]
+y = digits["target"]
 X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)
-clf = xgb.XGBClassifier(n_jobs=1)
-clf.fit(X_train, y_train, early_stopping_rounds=10, eval_metric="auc",
-        eval_set=[(X_test, y_test)])
+clf = xgb.XGBClassifier(n_jobs=1, early_stopping_rounds=10, eval_metric="auc")
+clf.fit(X_train, y_train, eval_set=[(X_test, y_test)])
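For completeness, a self-contained sketch, assuming a recent XGBoost release, of reading back the early-stopping results that the constructor-based configuration records on the fitted estimator:

    import xgboost as xgb
    from sklearn.datasets import load_digits
    from sklearn.model_selection import train_test_split

    X, y = load_digits(n_class=2, return_X_y=True)
    X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)

    clf = xgb.XGBClassifier(n_jobs=1, early_stopping_rounds=10, eval_metric="auc")
    clf.fit(X_train, y_train, eval_set=[(X_test, y_test)])

    print("best score:", clf.best_score)
    print("best iteration:", clf.best_iteration)
    # Predictions can also be limited explicitly to the best round.
    preds = clf.predict(X_test, iteration_range=(0, clf.best_iteration + 1))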
1 change: 1 addition & 0 deletions demo/guide-python/sklearn_parallel.py
@@ -12,6 +12,7 @@
 if __name__ == "__main__":
     print("Parallel Parameter optimization")
     X, y = fetch_california_housing(return_X_y=True)
+    # Make sure the number of threads is balanced.
     xgb_model = xgb.XGBRegressor(
         n_jobs=multiprocessing.cpu_count() // 2, tree_method="hist"
     )
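The added comment is the entire change for this file. As context, a rough sketch of one reading of "balanced", an illustration assumed for this page rather than taken from the commit: keep per-model threads times parallel grid-search workers at or below the number of available cores.

    import multiprocessing

    import xgboost as xgb
    from sklearn.datasets import fetch_california_housing
    from sklearn.model_selection import GridSearchCV

    if __name__ == "__main__":
        n_cores = multiprocessing.cpu_count()
        X, y = fetch_california_housing(return_X_y=True)
        # Half the cores for XGBoost itself, the other half for parallel CV fits,
        # so the two levels of parallelism do not oversubscribe the machine.
        model = xgb.XGBRegressor(n_jobs=max(1, n_cores // 2), tree_method="hist")
        search = GridSearchCV(
            model,
            {"max_depth": [2, 4], "n_estimators": [50, 100]},
            cv=2,
            n_jobs=2,
            verbose=1,
        )
        search.fit(X, y)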
12 changes: 6 additions & 6 deletions doc/tutorials/custom_metric_obj.rst
@@ -123,11 +123,11 @@ monitor our model's performance. As mentioned above, the default metric for ``S
         elements = np.power(np.log1p(y) - np.log1p(predt), 2)
         return 'PyRMSLE', float(np.sqrt(np.sum(elements) / len(y)))

-Since we are demonstrating in Python, the metric or objective need not be a function,
-any callable object should suffice. Similar to the objective function, our metric also
-accepts ``predt`` and ``dtrain`` as inputs, but returns the name of the metric itself and a
-floating point value as the result. After passing it into XGBoost as argument of ``feval``
-parameter:
+Since we are demonstrating in Python, the metric or objective need not be a function, any
+callable object should suffice. Similar to the objective function, our metric also
+accepts ``predt`` and ``dtrain`` as inputs, but returns the name of the metric itself and
+a floating point value as the result. After passing it into XGBoost as argument of
+``custom_metric`` parameter:

 .. code-block:: python
@@ -136,7 +136,7 @@ parameter:
               dtrain=dtrain,
               num_boost_round=10,
               obj=squared_log,
-              feval=rmsle,
+              custom_metric=rmsle,
               evals=[(dtrain, 'dtrain'), (dtest, 'dtest')],
               evals_result=results)
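To round off the rename, a self-contained sketch of calling ``xgb.train`` with the ``custom_metric`` argument. The ``rmsle`` body mirrors the tutorial above; the synthetic data and the remaining parameters are assumptions made only for this example:

    import numpy as np
    import xgboost as xgb


    def rmsle(predt: np.ndarray, dtrain: xgb.DMatrix):
        """Root mean squared log error, as defined in the tutorial."""
        y = dtrain.get_label()
        predt[predt < -1] = -1 + 1e-6
        elements = np.power(np.log1p(y) - np.log1p(predt), 2)
        return "PyRMSLE", float(np.sqrt(np.sum(elements) / len(y)))


    rng = np.random.default_rng(1994)
    X = rng.normal(size=(256, 8))
    y = np.abs(rng.normal(size=256))
    dtrain = xgb.DMatrix(X, label=y)

    results: dict = {}
    xgb.train(
        {"tree_method": "hist", "seed": 1994},
        dtrain=dtrain,
        num_boost_round=10,
        custom_metric=rmsle,
        evals=[(dtrain, "dtrain")],
        evals_result=results,
    )
    print(results["dtrain"]["PyRMSLE"])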