Skip to content

Commit

Permalink
add tests for stump
Browse files Browse the repository at this point in the history
  • Loading branch information
thatlittleboy committed Jul 9, 2023
1 parent c7e612a commit aa312fd
Showing 1 changed file with 19 additions and 4 deletions.
23 changes: 19 additions & 4 deletions tests/python_package_test/test_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -3597,24 +3597,39 @@ def test_reset_params_works_with_metric_num_class_and_boosting():
assert new_bst.params == expected_params


def test_dump_model():
@pytest.mark.parametrize("stump", [True, False])
def test_dump_model(stump):
X, y = load_breast_cancer(return_X_y=True)
if stump:
# intentionally create a stump (tree with only a root-node)
# using restricted # samples
subidx = random.sample(range(len(y)), 30)
X = X[subidx]
y = y[subidx]

train_data = lgb.Dataset(X, label=y)
params = {
"objective": "binary",
"verbose": -1
"verbose": -1,
}
bst = lgb.train(params, train_data, num_boost_round=5)
dumped_model_str = str(bst.dump_model(5, 0))
dumped_model = bst.dump_model(5, 0)
if stump:
assert len(dumped_model["tree_structure"]) == 1
dumped_model_str = str(dumped_model)
assert "leaf_features" not in dumped_model_str
assert "leaf_coeff" not in dumped_model_str
assert "leaf_const" not in dumped_model_str
assert "leaf_value" in dumped_model_str
assert "leaf_count" in dumped_model_str

params['linear_tree'] = True
train_data = lgb.Dataset(X, label=y)
bst = lgb.train(params, train_data, num_boost_round=5)
dumped_model_str = str(bst.dump_model(5, 0))
dumped_model = bst.dump_model(5, 0)
if stump:
assert len(dumped_model["tree_structure"]) == 1
dumped_model_str = str(dumped_model)
assert "leaf_features" in dumped_model_str
assert "leaf_coeff" in dumped_model_str
assert "leaf_const" in dumped_model_str
Expand Down

0 comments on commit aa312fd

Please sign in to comment.