Add test_predict_stump on dask
neNasko1 committed Oct 9, 2024
1 parent 04886c0 commit 15fc3bf
Showing 1 changed file with 46 additions and 1 deletion.
tests/python_package_test/test_dask.py (46 additions, 1 deletion)
@@ -1451,7 +1451,15 @@ def test_init_score(task, output, cluster, rng):

         model_factory = task_to_dask_factory[task]

-        params = {"n_estimators": 1, "num_leaves": 2, "time_out": 5}
+        params = {
+            "n_estimators": 1,
+            "num_leaves": 2,
+            "time_out": 5,
+            "seed": 708,
+            "deterministic": True,
+            "force_row_wise": True,
+            "num_thread": 1,
+        }
         num_classes = 1
         if task == "multiclass-classification":
             num_classes = 3
@@ -1533,6 +1541,43 @@ def test_predict_with_raw_score(task, output, cluster):
             assert_eq(raw_predictions, pred_proba_raw)


+@pytest.mark.parametrize("output", data_output)
+@pytest.mark.parametrize("use_init_score", [False, True])
+def test_predict_stump(output, use_init_score, cluster, rng):
+    with Client(cluster) as client:
+        task = "binary-classification"
+        n_samples = 1_000
+        _, _, _, _, dX, dy, _, dg = _create_data(objective=task, n_samples=n_samples, output=output)
+
+        model_factory = task_to_dask_factory[task]
+
+        params = {"objective": "binary", "n_estimators": 5, "min_data_in_leaf": n_samples}
+
+        if not use_init_score:
+            init_scores = None
+        elif output.startswith("dataframe"):
+            init_scores = dy.map_partitions(lambda x: pd.DataFrame(rng.uniform(size=x.size)))
+        else:
+            init_scores = dy.map_blocks(lambda x: rng.uniform(size=x.size))
+
+        model = model_factory(client=client, **params)
+        model.fit(dX, dy, group=dg, init_score=init_scores)
+        preds_1 = model.predict(dX, raw_score=True, num_iteration=1).compute()
+        preds_all = model.predict(dX, raw_score=True).compute()
+
+        if use_init_score:
+            # if init_score was provided, a model of stumps should predict all 0s
+            all_zeroes = np.full_like(preds_1, fill_value=0.0)
+            assert_eq(preds_1, all_zeroes)
+            assert_eq(preds_all, all_zeroes)
+        else:
+            # if init_score was not provided, prediction for a model of stumps should be
+            # the "average" of the labels
+            y_avg = np.log(dy.mean() / (1.0 - dy.mean()))
+            assert_eq(preds_1, np.full_like(preds_1, fill_value=y_avg))
+            assert_eq(preds_all, np.full_like(preds_all, fill_value=y_avg))
+
+
 def test_distributed_quantized_training(tmp_path, cluster):
     with Client(cluster) as client:
         X, y, w, _, dX, dy, dw, _ = _create_data(objective="regression", output="array")

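Note: the invariant exercised by test_predict_stump can also be checked outside Dask. The sketch below is illustrative only and is not part of the commit; it assumes a local lightgbm install and uses the scikit-learn API (LGBMClassifier), with min_child_samples standing in for min_data_in_leaf, and made-up data and constants. With the minimum leaf size set to the full sample count, no split is possible, so every boosted tree is a stump: raw predictions collapse to the initial score, which is the log-odds of the positive-class rate when no init_score is supplied, and zero when one is.

    # Minimal non-Dask sketch of the stump-prediction invariant; illustrative only.
    import numpy as np
    import lightgbm as lgb

    rng = np.random.default_rng(708)
    n_samples = 1_000
    X = rng.uniform(size=(n_samples, 4))
    y = rng.integers(0, 2, size=n_samples)

    # min_child_samples == n_samples forbids any split, so every tree is a stump.
    model = lgb.LGBMClassifier(n_estimators=5, min_child_samples=n_samples)
    model.fit(X, y)
    raw = model.predict(X, raw_score=True)

    # Without init_score, the raw prediction equals the boost-from-average
    # baseline: the log-odds of the positive-class rate.
    log_odds = np.log(y.mean() / (1.0 - y.mean()))
    np.testing.assert_allclose(raw, np.full_like(raw, log_odds), rtol=1e-6)

    # With init_score supplied at fit time, the baseline lives in the offset
    # instead, so the stump model's raw predictions are all zero.
    model_offset = lgb.LGBMClassifier(n_estimators=5, min_child_samples=n_samples)
    model_offset.fit(X, y, init_score=rng.uniform(size=n_samples))
    np.testing.assert_allclose(model_offset.predict(X, raw_score=True), 0.0, atol=1e-12)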