Skip to content

Commit

Permalink
Fix spark.
Browse files Browse the repository at this point in the history
  • Loading branch information
trivialfis committed Jan 14, 2024
1 parent 383342e commit 560adda
Showing 1 changed file with 2 additions and 6 deletions.
8 changes: 2 additions & 6 deletions tests/test_distributed/test_with_spark/test_spark_local.py
Original file line number Diff line number Diff line change
Expand Up @@ -311,24 +311,20 @@ def clf_with_weight(
y_val = np.array([0, 1])
w_train = np.array([1.0, 2.0])
w_val = np.array([1.0, 2.0])
cls2 = XGBClassifier()
cls2 = XGBClassifier(eval_metric="logloss", early_stopping_rounds=1)
cls2.fit(
X_train,
y_train,
eval_set=[(X_val, y_val)],
early_stopping_rounds=1,
eval_metric="logloss",
)

cls3 = XGBClassifier()
cls3 = XGBClassifier(eval_metric="logloss", early_stopping_rounds=1)
cls3.fit(
X_train,
y_train,
sample_weight=w_train,
eval_set=[(X_val, y_val)],
sample_weight_eval_set=[w_val],
early_stopping_rounds=1,
eval_metric="logloss",
)

cls_df_train_with_eval_weight = spark.createDataFrame(
Expand Down

0 comments on commit 560adda

Please sign in to comment.