Skip to content

Commit

Permalink
Fix.
Browse files Browse the repository at this point in the history
  • Loading branch information
trivialfis committed Aug 31, 2023
1 parent bfe35cb commit 1c2f562
Showing 1 changed file with 11 additions and 8 deletions.
19 changes: 11 additions & 8 deletions tests/python-gpu/test_gpu_prediction.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,10 +191,12 @@ def test_inplace_predict_device_type(self, device: str) -> None:
np.testing.assert_allclose(predt_0, predt_3)
np.testing.assert_allclose(predt_0, predt_4)

def run_inplace_base_margin(self, booster, dtrain, X, base_margin):
def run_inplace_base_margin(
self, device: int, booster: xgb.Booster, dtrain: xgb.DMatrix, X, base_margin
) -> None:
import cupy as cp

booster.set_param({"device": "cuda"})
booster.set_param({"device": f"cuda:{device}"})
dtrain.set_info(base_margin=base_margin)
from_inplace = booster.inplace_predict(data=X, base_margin=base_margin)
from_dmatrix = booster.predict(dtrain)
Expand All @@ -208,11 +210,12 @@ def run_inplace_base_margin(self, booster, dtrain, X, base_margin):

booster = booster.copy() # clear prediction cache.
base_margin = cp.asnumpy(base_margin)
X = cp.asnumpy(X)
booster.set_param({"device": "cuda"})
if hasattr(X, "values"):
X = cp.asnumpy(X.values)
booster.set_param({"device": f"cuda:{device}"})
from_inplace = booster.inplace_predict(data=X, base_margin=base_margin)
from_dmatrix = booster.predict(dtrain)
np.testing.assert_allclose(from_inplace, from_dmatrix, rtol=1e-6)
cp.testing.assert_allclose(from_inplace, from_dmatrix, rtol=1e-6)

def run_inplace_predict_cupy(self, device: int) -> None:
import cupy as cp
Expand All @@ -233,7 +236,7 @@ def run_inplace_predict_cupy(self, device: int) -> None:
dtrain = xgb.DMatrix(X, y)

booster = xgb.train(
{"tree_method": "hist", "device": f"cpu"},
{"tree_method": "hist", "device": f"cuda:{device}"},
dtrain,
num_boost_round=10,
)
Expand All @@ -259,7 +262,7 @@ def predict_dense(x):
run_threaded_predict(X, rows, predict_dense)

base_margin = cp_rng.randn(rows)
self.run_inplace_base_margin(booster, dtrain, X, base_margin)
self.run_inplace_base_margin(device, booster, dtrain, X, base_margin)

# Create a wide dataset
X = cp_rng.randn(100, 10000)
Expand Down Expand Up @@ -333,7 +336,7 @@ def predict_df(x):
run_threaded_predict(X, rows, predict_df)

base_margin = cudf.Series(rng.randn(rows))
self.run_inplace_base_margin(booster, dtrain, X, base_margin)
self.run_inplace_base_margin(0, booster, dtrain, X, base_margin)

@given(
strategies.integers(1, 10), tm.make_dataset_strategy(), shap_parameter_strategy
Expand Down

0 comments on commit 1c2f562

Please sign in to comment.