Skip to content

Commit

Permalink
[python] add return_cvbooster flag to cv func and publish _CVBooster (#…
Browse files Browse the repository at this point in the history
…283,#2105,#1445) (#3204)

* [python] add return_cvbooster flag to cv function and rename _CVBooster to make public (#283,#2105)

* [python] Reduce expected metric of unit testing

* [docs] add the CVBooster to the documentation

* [python] reflect the review comments

- Add some clarifications to the documentation
- Rename CVBooster.append to make private
- Decrease iteration rounds of testing to save CI time
- Use CVBooster as root member of lgb

* [python] add more checks in testing for cv

Co-authored-by: Nikita Titov <[email protected]>

* [python] add docstring for instance attributes of CVBooster

Co-authored-by: Nikita Titov <[email protected]>

* [python] fix docstring

Co-authored-by: Nikita Titov <[email protected]>

Co-authored-by: Nikita Titov <[email protected]>
  • Loading branch information
momijiame and StrikerRUS authored Aug 2, 2020
1 parent 66600b2 commit 1d59a04
Show file tree
Hide file tree
Showing 4 changed files with 79 additions and 10 deletions.
1 change: 1 addition & 0 deletions docs/Python-API.rst
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ Data Structure API

Dataset
Booster
CVBooster

Training API
------------
Expand Down
4 changes: 2 additions & 2 deletions python-package/lightgbm/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from .basic import Booster, Dataset
from .callback import (early_stopping, print_evaluation, record_evaluation,
reset_parameter)
from .engine import cv, train
from .engine import cv, train, CVBooster

import os

Expand All @@ -29,7 +29,7 @@
with open(os.path.join(dir_path, 'VERSION.txt')) as version_file:
__version__ = version_file.read().strip()

__all__ = ['Dataset', 'Booster',
__all__ = ['Dataset', 'Booster', 'CVBooster',
'train', 'cv',
'LGBMModel', 'LGBMRegressor', 'LGBMClassifier', 'LGBMRanker',
'print_evaluation', 'record_evaluation', 'reset_parameter', 'early_stopping',
Expand Down
40 changes: 32 additions & 8 deletions python-package/lightgbm/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -276,19 +276,35 @@ def train(params, train_set, num_boost_round=100,
return booster


class _CVBooster(object):
"""Auxiliary data struct to hold all boosters of CV."""
class CVBooster(object):
"""CVBooster in LightGBM.
Auxiliary data structure to hold and redirect all boosters of ``cv`` function.
This class has the same methods as Booster class.
All method calls are actually performed for underlying Boosters and then all results are returned in a list.
Attributes
----------
boosters : list of Booster
The list of underlying fitted models.
best_iteration : int
The best iteration of fitted model.
"""

def __init__(self):
"""Initialize the CVBooster.
Generally, no need to instantiate manually.
"""
self.boosters = []
self.best_iteration = -1

def append(self, booster):
"""Add a booster to _CVBooster."""
def _append(self, booster):
"""Add a booster to CVBooster."""
self.boosters.append(booster)

def __getattr__(self, name):
"""Redirect methods call of _CVBooster."""
"""Redirect methods call of CVBooster."""
def handler_function(*args, **kwargs):
"""Call methods with each booster, and concatenate their results."""
ret = []
Expand Down Expand Up @@ -341,7 +357,7 @@ def _make_n_folds(full_data, folds, nfold, params, seed, fpreproc=None, stratifi
train_id = [np.concatenate([test_id[i] for i in range_(nfold) if k != i]) for k in range_(nfold)]
folds = zip_(train_id, test_id)

ret = _CVBooster()
ret = CVBooster()
for train_idx, test_idx in folds:
train_set = full_data.subset(sorted(train_idx))
valid_set = full_data.subset(sorted(test_idx))
Expand All @@ -354,7 +370,7 @@ def _make_n_folds(full_data, folds, nfold, params, seed, fpreproc=None, stratifi
if eval_train_metric:
cvbooster.add_valid(train_set, 'train')
cvbooster.add_valid(valid_set, 'valid')
ret.append(cvbooster)
ret._append(cvbooster)
return ret


Expand All @@ -380,7 +396,8 @@ def cv(params, train_set, num_boost_round=100,
feature_name='auto', categorical_feature='auto',
early_stopping_rounds=None, fpreproc=None,
verbose_eval=None, show_stdv=True, seed=0,
callbacks=None, eval_train_metric=False):
callbacks=None, eval_train_metric=False,
return_cvbooster=False):
"""Perform the cross-validation with given parameters.
Parameters
Expand Down Expand Up @@ -486,6 +503,8 @@ def cv(params, train_set, num_boost_round=100,
eval_train_metric : bool, optional (default=False)
Whether to display the train metric in progress.
The score of the metric is calculated again after each training step, so there is some impact on performance.
return_cvbooster : bool, optional (default=False)
Whether to return Booster models trained on each fold through ``CVBooster``.
Returns
-------
Expand All @@ -495,6 +514,7 @@ def cv(params, train_set, num_boost_round=100,
{'metric1-mean': [values], 'metric1-stdv': [values],
'metric2-mean': [values], 'metric2-stdv': [values],
...}.
If ``return_cvbooster=True``, also returns trained boosters via ``cvbooster`` key.
"""
if not isinstance(train_set, Dataset):
raise TypeError("Training only accepts Dataset object")
Expand Down Expand Up @@ -586,4 +606,8 @@ def cv(params, train_set, num_boost_round=100,
for k in results:
results[k] = results[k][:cvfolds.best_iteration]
break

if return_cvbooster:
results['cvbooster'] = cvfolds

return dict(results)
44 changes: 44 additions & 0 deletions tests/python_package_test/test_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -735,6 +735,50 @@ def test_cv(self):
verbose_eval=False)
np.testing.assert_allclose(cv_res_lambda['ndcg@3-mean'], cv_res_lambda_obj['ndcg@3-mean'])

def test_cvbooster(self):
    """Check that ``lgb.cv(..., return_cvbooster=True)`` returns a usable ``CVBooster``.

    Two scenarios are exercised on the breast cancer dataset with 3 folds:
    with early stopping (``best_iteration`` is expected to be positive) and
    without it (``best_iteration`` is expected to stay at ``-1``).
    In both cases the per-fold predictions are averaged and the resulting
    log loss on the held-out split must stay below a fixed threshold.
    """
    # ``return_X_y`` is passed by keyword: scikit-learn deprecated passing it
    # positionally, and newer releases reject ``load_breast_cancer(True)``.
    X, y = load_breast_cancer(return_X_y=True)
    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.1, random_state=42)
    params = {
        'objective': 'binary',
        'metric': 'binary_logloss',
        'verbose': -1,
    }
    lgb_train = lgb.Dataset(X_train, y_train)

    # with early stopping
    cv_res = lgb.cv(params, lgb_train,
                    num_boost_round=25,
                    early_stopping_rounds=5,
                    verbose_eval=False,
                    nfold=3,
                    return_cvbooster=True)
    self.assertIn('cvbooster', cv_res)
    cvb = cv_res['cvbooster']
    self.assertIsInstance(cvb, lgb.CVBooster)
    self.assertIsInstance(cvb.boosters, list)
    # one underlying Booster per fold
    self.assertEqual(len(cvb.boosters), 3)
    self.assertTrue(all(isinstance(bst, lgb.Booster) for bst in cvb.boosters))
    self.assertGreater(cvb.best_iteration, 0)
    # predict by each fold booster; the call is expected to yield one result per fold
    preds = cvb.predict(X_test, num_iteration=cvb.best_iteration)
    self.assertIsInstance(preds, list)
    self.assertEqual(len(preds), 3)
    # fold averaging
    avg_pred = np.mean(preds, axis=0)
    ret = log_loss(y_test, avg_pred)
    self.assertLess(ret, 0.13)

    # without early stopping
    cv_res = lgb.cv(params, lgb_train,
                    num_boost_round=20,
                    verbose_eval=False,
                    nfold=3,
                    return_cvbooster=True)
    cvb = cv_res['cvbooster']
    # no early stopping was requested, so best_iteration is expected to stay at -1
    self.assertEqual(cvb.best_iteration, -1)
    preds = cvb.predict(X_test)
    avg_pred = np.mean(preds, axis=0)
    ret = log_loss(y_test, avg_pred)
    self.assertLess(ret, 0.15)

def test_feature_name(self):
X_train, y_train = load_boston(True)
params = {'verbose': -1}
Expand Down

0 comments on commit 1d59a04

Please sign in to comment.