diff --git a/python-package/xgboost/compat.py b/python-package/xgboost/compat.py
index a01eeef09fbb..c40dea5fd9b2 100644
--- a/python-package/xgboost/compat.py
+++ b/python-package/xgboost/compat.py
@@ -88,6 +88,18 @@ def is_cudf_available() -> bool:
         return False


+def is_cupy_available() -> bool:
+    """Check whether the cupy package is available."""
+    if importlib.util.find_spec("cupy") is None:
+        return False
+    try:
+        import cupy
+
+        return True
+    except ImportError:
+        return False
+
+
 try:
     import scipy.sparse as scipy_sparse
     from scipy.sparse import csr_matrix as scipy_csr
diff --git a/python-package/xgboost/spark/core.py b/python-package/xgboost/spark/core.py
index d6667ad89713..6b1d2faaacd1 100644
--- a/python-package/xgboost/spark/core.py
+++ b/python-package/xgboost/spark/core.py
@@ -59,7 +59,7 @@

 import xgboost
 from xgboost import XGBClassifier
-from xgboost.compat import is_cudf_available
+from xgboost.compat import is_cudf_available, is_cupy_available
 from xgboost.core import Booster, _check_distributed_params
 from xgboost.sklearn import DEFAULT_N_ESTIMATORS, XGBModel, _can_use_qdm
 from xgboost.training import train as worker_train
@@ -242,6 +242,13 @@ class _SparkXGBParams(
         TypeConverters.toList,
     )

+    def set_device(self, value: str) -> "_SparkXGBParams":
+        """Set the device parameter. Valid values: "cpu", "cuda", "gpu"."""
+        _check_distributed_params({"device": value})
+        assert value in ("cpu", "cuda", "gpu")
+        self.set(self.device, value)
+        return self
+
     @classmethod
     def _xgb_cls(cls) -> Type[XGBModel]:
         """
@@ -1193,6 +1200,31 @@ def _post_transform(self, dataset: DataFrame, pred_col: Column) -> DataFrame:
             dataset = dataset.drop(pred_struct_col)
         return dataset

+    def _gpu_transform(self) -> bool:
+        """Return True if the transform (prediction) should run on GPUs."""
+
+        if _is_local(_get_spark_session().sparkContext):
+            # In local mode, just follow the internal "device" param.
+            return use_cuda(self.getOrDefault(self.device))
+
+        gpu_per_task = (
+            _get_spark_session()
+            .sparkContext.getConf()
+            .get("spark.task.resource.gpu.amount")
+        )
+
+        # The user hasn't configured any GPU task resources, so fall back to CPU.
+        if gpu_per_task is None:
+            if use_cuda(self.getOrDefault(self.device)):
+                get_logger("XGBoost-PySpark").warning(
+                    "Running prediction on CPUs since "
+                    "no GPU task resources are configured"
+                )
+            return False
+
+        # GPU task resources are configured, so follow the internal "device" param.
+        return use_cuda(self.getOrDefault(self.device))
+
     def _transform(self, dataset: DataFrame) -> DataFrame:
         # pylint: disable=too-many-statements, too-many-locals
         # Save xgb_sklearn_model and predict_params to be local variable
@@ -1216,21 +1248,77 @@ def _transform(self, dataset: DataFrame) -> DataFrame:

         _, schema = self._out_schema()

+        is_local = _is_local(_get_spark_session().sparkContext)
+        run_on_gpu = self._gpu_transform()
+
         @pandas_udf(schema)  # type: ignore
         def predict_udf(iterator: Iterator[pd.DataFrame]) -> Iterator[pd.Series]:
             assert xgb_sklearn_model is not None
             model = xgb_sklearn_model
+
+            from pyspark import TaskContext
+
+            context = TaskContext.get()
+            assert context is not None
+
+            dev_ordinal = -1
+
+            if is_cudf_available():
+                if is_local:
+                    if run_on_gpu and is_cupy_available():
+                        import cupy as cp  # pylint: disable=import-error
+
+                        total_gpus = cp.cuda.runtime.getDeviceCount()
+                        if total_gpus > 0:
+                            partition_id = context.partitionId()
+                            # In local mode, default dev_ordinal to
+                            # (partition id) % (total number of GPUs).
+                            dev_ordinal = partition_id % total_gpus
+                elif run_on_gpu:
+                    dev_ordinal = _get_gpu_id(context)
+
+                if dev_ordinal >= 0:
+                    device = "cuda:" + str(dev_ordinal)
+                    get_logger("XGBoost-PySpark").info(
+                        "Running inference with device: %s", device
+                    )
+                    model.set_params(device=device)
+                else:
+                    get_logger("XGBoost-PySpark").info("Running inference on CPUs")
+            else:
+                msg = (
+                    "cuDF is unavailable, falling back to inference on CPUs"
+                    if run_on_gpu
+                    else "Running inference on CPUs"
+                )
+                get_logger("XGBoost-PySpark").info(msg)
+
+            def to_gpu_if_possible(data: ArrayLike) -> ArrayLike:
+                """Move the data to GPU if possible."""
+                if dev_ordinal >= 0:
+                    import cudf  # pylint: disable=import-error
+                    import cupy as cp  # pylint: disable=import-error
+
+                    # The device must be set after importing cudf, which changes the device id to 0.
+                    # See https://github.com/rapidsai/cudf/issues/11386
+                    cp.cuda.runtime.setDevice(dev_ordinal)  # pylint: disable=I1101
+                    df = cudf.DataFrame(data)
+                    del data
+                    return df
+                return data
+
             for data in iterator:
                 if enable_sparse_data_optim:
                     X = _read_csr_matrix_from_unwrapped_spark_vec(data)
                 else:
                     if feature_col_names is not None:
-                        X = data[feature_col_names]
+                        tmp = data[feature_col_names]
                     else:
-                        X = stack_series(data[alias.data])
+                        tmp = stack_series(data[alias.data])
+                    X = to_gpu_if_possible(tmp)

                 if has_base_margin:
-                    base_margin = data[alias.margin].to_numpy()
+                    base_margin = to_gpu_if_possible(data[alias.margin])
                 else:
                     base_margin = None
diff --git a/python-package/xgboost/spark/utils.py b/python-package/xgboost/spark/utils.py
index 33a45a90ef91..66d7ca4548ca 100644
--- a/python-package/xgboost/spark/utils.py
+++ b/python-package/xgboost/spark/utils.py
@@ -10,7 +10,7 @@
 from typing import Any, Callable, Dict, Optional, Set, Type

 import pyspark
-from pyspark import BarrierTaskContext, SparkContext, SparkFiles
+from pyspark import BarrierTaskContext, SparkContext, SparkFiles, TaskContext
 from pyspark.sql.session import SparkSession

 from xgboost import Booster, XGBModel, collective
@@ -129,7 +129,7 @@ def _is_local(spark_context: SparkContext) -> bool:
     return spark_context._jsc.sc().isLocal()


-def _get_gpu_id(task_context: BarrierTaskContext) -> int:
+def _get_gpu_id(task_context: TaskContext) -> int:
     """Get the gpu id from the task resources"""
     if task_context is None:
         # This is a safety check.
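
Note: widening `_get_gpu_id` to accept a plain `TaskContext` (rather than only a `BarrierTaskContext`) lets the non-barrier pandas UDF used during transform reuse it. For reference, a minimal sketch of reading a task's GPU assignment via the PySpark resource API might look like the following; this is an illustration under assumed semantics, not the upstream implementation, and `first_gpu_ordinal` is a hypothetical helper name.

from pyspark import TaskContext


def first_gpu_ordinal(task_context: TaskContext) -> int:
    """Return the first GPU address assigned to this task (hypothetical helper)."""
    if task_context is None:
        # Safety check: task resources are only available inside an executor task.
        raise RuntimeError("first_gpu_ordinal must be called from an executor task")
    resources = task_context.resources()
    if "gpu" not in resources:
        raise RuntimeError(
            "No GPU assigned to this task; check spark.task.resource.gpu.amount"
        )
    # ResourceInformation.addresses holds the GPU indices as strings, e.g. ["0"].
    return int(resources["gpu"].addresses[0])
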
diff --git a/tests/test_distributed/test_gpu_with_spark/test_gpu_spark.py b/tests/test_distributed/test_gpu_with_spark/test_gpu_spark.py
index a954d9d6c253..513554e430c7 100644
--- a/tests/test_distributed/test_gpu_with_spark/test_gpu_spark.py
+++ b/tests/test_distributed/test_gpu_with_spark/test_gpu_spark.py
@@ -2,6 +2,7 @@
 import logging
 import subprocess

+import numpy as np
 import pytest
 import sklearn

@@ -13,7 +14,7 @@
 from pyspark.ml.tuning import CrossValidator, ParamGridBuilder
 from pyspark.sql import SparkSession

-from xgboost.spark import SparkXGBClassifier, SparkXGBRegressor
+from xgboost.spark import SparkXGBClassifier, SparkXGBRegressor, SparkXGBRegressorModel

 gpu_discovery_script_path = "tests/test_distributed/test_gpu_with_spark/discover_gpu.sh"

@@ -242,3 +243,33 @@ def test_sparkxgb_regressor_feature_cols_with_gpu(spark_diabetes_dataset_feature
     evaluator = RegressionEvaluator(metricName="rmse")
     rmse = evaluator.evaluate(pred_result_df)
     assert rmse <= 65.0
+
+
+def test_gpu_transform(spark_diabetes_dataset) -> None:
+    regressor = SparkXGBRegressor(device="cuda", num_workers=num_workers)
+    train_df, test_df = spark_diabetes_dataset
+    model: SparkXGBRegressorModel = regressor.fit(train_df)
+
+    # The model was trained on GPUs; transform follows the GPU configurations.
+    assert model._gpu_transform()
+
+    model.set_device("cpu")
+    assert not model._gpu_transform()
+    # Should run without error.
+    cpu_rows = model.transform(test_df).select("prediction").collect()
+
+    regressor = SparkXGBRegressor(device="cpu", num_workers=num_workers)
+    model = regressor.fit(train_df)
+
+    # The model was trained on CPUs. Even with GPU configurations present,
+    # transform still prefers CPUs.
+    assert not model._gpu_transform()
+
+    # Enable the GPU transform explicitly.
+    model.set_device("cuda")
+    assert model._gpu_transform()
+    # Should run without error.
+    gpu_rows = model.transform(test_df).select("prediction").collect()
+
+    for cpu, gpu in zip(cpu_rows, gpu_rows):
+        np.testing.assert_allclose(cpu.prediction, gpu.prediction, atol=1e-3)
diff --git a/tests/test_distributed/test_with_spark/test_spark_local.py b/tests/test_distributed/test_with_spark/test_spark_local.py
index e323a36066cf..861e67a75331 100644
--- a/tests/test_distributed/test_with_spark/test_spark_local.py
+++ b/tests/test_distributed/test_with_spark/test_spark_local.py
@@ -888,6 +888,34 @@ def test_device_param(self, reg_data: RegData, clf_data: ClfData) -> None:
         clf = SparkXGBClassifier(device="cuda")
         clf._validate_params()

+    def test_gpu_transform(self, clf_data: ClfData) -> None:
+        """Test the GPU transform in local mode."""
+        classifier = SparkXGBClassifier(device="cpu")
+        model: SparkXGBClassifierModel = classifier.fit(clf_data.cls_df_train)
+
+        with tempfile.TemporaryDirectory() as tmpdir:
+            path = "file:" + tmpdir
+            model.write().overwrite().save(path)
+
+            # The model was trained on CPU; transform defaults to CPU.
+            assert not model._gpu_transform()
+
+            # Should run without error.
+            model.transform(clf_data.cls_df_test).collect()
+
+            model.set_device("cuda")
+            assert model._gpu_transform()
+
+            model_loaded = SparkXGBClassifierModel.load(path)
+
+            # The loaded model was trained on CPU; transform defaults to CPU.
+            assert not model_loaded._gpu_transform()
+            # Should run without error.
+            model_loaded.transform(clf_data.cls_df_test).collect()
+
+            model_loaded.set_device("cuda")
+            assert model_loaded._gpu_transform()
+

 class XgboostLocalTest(SparkTestCase):
     def setUp(self):
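
Taken together, the patch lets an already trained model switch its inference device without retraining. A hypothetical end-to-end sketch, assuming prepared `train_df`/`test_df` DataFrames (the `features`/`label` column names and the DataFrames themselves are illustrative):

from xgboost.spark import SparkXGBRegressor, SparkXGBRegressorModel

# Train with GPUs; transform follows the "device" param by default.
regressor = SparkXGBRegressor(device="cuda", features_col="features", label_col="label")
model: SparkXGBRegressorModel = regressor.fit(train_df)
assert model._gpu_transform()

# Switch the same model to CPU inference.
model.set_device("cpu")
cpu_preds = model.transform(test_df).select("prediction").collect()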