Skip to content

Commit

Permalink
[backport][pyspark] Support stage-level scheduling (dmlc#9519) (dmlc#…
Browse files Browse the repository at this point in the history
…9686)

Co-authored-by: Bobby Wang <[email protected]>
  • Loading branch information
2 people authored and razdoburdin committed Oct 26, 2023
1 parent 6768332 commit c5f73d3
Show file tree
Hide file tree
Showing 2 changed files with 176 additions and 51 deletions.
220 changes: 169 additions & 51 deletions python-package/xgboost/spark/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@

import numpy as np
import pandas as pd
from pyspark import SparkContext, cloudpickle
from pyspark import RDD, SparkContext, cloudpickle
from pyspark.ml import Estimator, Model
from pyspark.ml.functions import array_to_vector, vector_to_array
from pyspark.ml.linalg import VectorUDT
Expand All @@ -44,6 +44,7 @@
MLWritable,
MLWriter,
)
from pyspark.resource import ResourceProfileBuilder, TaskResourceRequests
from pyspark.sql import Column, DataFrame
from pyspark.sql.functions import col, countDistinct, pandas_udf, rand, struct
from pyspark.sql.types import (
Expand Down Expand Up @@ -88,6 +89,7 @@
_get_rabit_args,
_get_spark_session,
_is_local,
_is_standalone_or_localcluster,
deserialize_booster,
deserialize_xgb_model,
get_class_name,
Expand Down Expand Up @@ -342,6 +344,54 @@ def _gen_predict_params_dict(self) -> Dict[str, Any]:
predict_params[param.name] = self.getOrDefault(param)
return predict_params

def _validate_gpu_params(self) -> None:
"""Validate the gpu parameters and gpu configurations"""

if use_cuda(self.getOrDefault(self.device)) or self.getOrDefault(self.use_gpu):
ss = _get_spark_session()
sc = ss.sparkContext

if _is_local(sc):
# Support GPU training in Spark local mode is just for debugging
# purposes, so it's okay for printing the below warning instead of
# checking the real gpu numbers and raising the exception.
get_logger(self.__class__.__name__).warning(
"You have enabled GPU in spark local mode. Please make sure your"
" local node has at least %d GPUs",
self.getOrDefault(self.num_workers),
)
else:
executor_gpus = sc.getConf().get("spark.executor.resource.gpu.amount")
if executor_gpus is None:
raise ValueError(
"The `spark.executor.resource.gpu.amount` is required for training"
" on GPU."
)

if not (ss.version >= "3.4.0" and _is_standalone_or_localcluster(sc)):
# We will enable stage-level scheduling in spark 3.4.0+ which doesn't
# require spark.task.resource.gpu.amount to be set explicitly
gpu_per_task = sc.getConf().get("spark.task.resource.gpu.amount")
if gpu_per_task is not None:
if float(gpu_per_task) < 1.0:
raise ValueError(
"XGBoost doesn't support GPU fractional configurations. "
"Please set `spark.task.resource.gpu.amount=spark.executor"
".resource.gpu.amount`"
)

if float(gpu_per_task) > 1.0:
get_logger(self.__class__.__name__).warning(
"%s GPUs for each Spark task is configured, but each "
"XGBoost training task uses only 1 GPU.",
gpu_per_task,
)
else:
raise ValueError(
"The `spark.task.resource.gpu.amount` is required for training"
" on GPU."
)

def _validate_params(self) -> None:
# pylint: disable=too-many-branches
init_model = self.getOrDefault("xgb_model")
Expand Down Expand Up @@ -421,53 +471,7 @@ def _validate_params(self) -> None:
"`pyspark.ml.linalg.Vector` type."
)

if use_cuda(self.getOrDefault(self.device)) or self.getOrDefault(self.use_gpu):
gpu_per_task = (
_get_spark_session()
.sparkContext.getConf()
.get("spark.task.resource.gpu.amount")
)

is_local = _is_local(_get_spark_session().sparkContext)

if is_local:
# checking spark local mode.
if gpu_per_task is not None:
raise RuntimeError(
"The spark local mode does not support gpu configuration."
"Please remove spark.executor.resource.gpu.amount and "
"spark.task.resource.gpu.amount"
)

# Support GPU training in Spark local mode is just for debugging
# purposes, so it's okay for printing the below warning instead of
# checking the real gpu numbers and raising the exception.
get_logger(self.__class__.__name__).warning(
"You have enabled GPU in spark local mode. Please make sure your"
" local node has at least %d GPUs",
self.getOrDefault(self.num_workers),
)
else:
# checking spark non-local mode.
if gpu_per_task is not None:
if float(gpu_per_task) < 1.0:
raise ValueError(
"XGBoost doesn't support GPU fractional configurations. "
"Please set `spark.task.resource.gpu.amount=spark.executor"
".resource.gpu.amount`"
)

if float(gpu_per_task) > 1.0:
get_logger(self.__class__.__name__).warning(
"%s GPUs for each Spark task is configured, but each "
"XGBoost training task uses only 1 GPU.",
gpu_per_task,
)
else:
raise ValueError(
"The `spark.task.resource.gpu.amount` is required for training"
" on GPU."
)
self._validate_gpu_params()


def _validate_and_convert_feature_col_as_float_col_list(
Expand Down Expand Up @@ -592,6 +596,8 @@ def __init__(self) -> None:
arbitrary_params_dict={},
)

self.logger = get_logger(self.__class__.__name__)

def setParams(self, **kwargs: Any) -> None: # pylint: disable=invalid-name
"""
Set params for the estimator.
Expand Down Expand Up @@ -894,6 +900,116 @@ def _get_xgb_parameters(

return booster_params, train_call_kwargs_params, dmatrix_kwargs

def _skip_stage_level_scheduling(self) -> bool:
# pylint: disable=too-many-return-statements
"""Check if stage-level scheduling is not needed,
return true to skip stage-level scheduling"""

if use_cuda(self.getOrDefault(self.device)) or self.getOrDefault(self.use_gpu):
ss = _get_spark_session()
sc = ss.sparkContext

if ss.version < "3.4.0":
self.logger.info(
"Stage-level scheduling in xgboost requires spark version 3.4.0+"
)
return True

if not _is_standalone_or_localcluster(sc):
self.logger.info(
"Stage-level scheduling in xgboost requires spark standalone or "
"local-cluster mode"
)
return True

executor_cores = sc.getConf().get("spark.executor.cores")
executor_gpus = sc.getConf().get("spark.executor.resource.gpu.amount")
if executor_cores is None or executor_gpus is None:
self.logger.info(
"Stage-level scheduling in xgboost requires spark.executor.cores, "
"spark.executor.resource.gpu.amount to be set."
)
return True

if int(executor_cores) == 1:
# there will be only 1 task running at any time.
self.logger.info(
"Stage-level scheduling in xgboost requires spark.executor.cores > 1 "
)
return True

if int(executor_gpus) > 1:
# For spark.executor.resource.gpu.amount > 1, we suppose user knows how to configure
# to make xgboost run successfully.
#
self.logger.info(
"Stage-level scheduling in xgboost will not work "
"when spark.executor.resource.gpu.amount>1"
)
return True

task_gpu_amount = sc.getConf().get("spark.task.resource.gpu.amount")

if task_gpu_amount is None:
# The ETL tasks will not grab a gpu when spark.task.resource.gpu.amount is not set,
# but with stage-level scheduling, we can make training task grab the gpu.
return False

if float(task_gpu_amount) == float(executor_gpus):
# spark.executor.resource.gpu.amount=spark.task.resource.gpu.amount "
# results in only 1 task running at a time, which may cause perf issue.
return True

# We can enable stage-level scheduling
return False

# CPU training doesn't require stage-level scheduling
return True

def _try_stage_level_scheduling(self, rdd: RDD) -> RDD:
"""Try to enable stage-level scheduling"""

if self._skip_stage_level_scheduling():
return rdd

ss = _get_spark_session()

# executor_cores will not be None
executor_cores = ss.sparkContext.getConf().get("spark.executor.cores")
assert executor_cores is not None

# Spark-rapids is a project to leverage GPUs to accelerate spark SQL.
# If spark-rapids is enabled, to avoid GPU OOM, we don't allow other
# ETL gpu tasks running alongside training tasks.
spark_plugins = ss.conf.get("spark.plugins", " ")
assert spark_plugins is not None
spark_rapids_sql_enabled = ss.conf.get("spark.rapids.sql.enabled", "true")
assert spark_rapids_sql_enabled is not None

task_cores = (
int(executor_cores)
if "com.nvidia.spark.SQLPlugin" in spark_plugins
and "true" == spark_rapids_sql_enabled.lower()
else (int(executor_cores) // 2) + 1
)

# Each training task requires cpu cores > total executor cores//2 + 1 which can
# make sure the tasks be sent to different executors.
#
# Please note that we can't use GPU to limit the concurrent tasks because of
# https://issues.apache.org/jira/browse/SPARK-45527.

task_gpus = 1.0
treqs = TaskResourceRequests().cpus(task_cores).resource("gpu", task_gpus)
rp = ResourceProfileBuilder().require(treqs).build

self.logger.info(
"XGBoost training tasks require the resource(cores=%s, gpu=%s).",
task_cores,
task_gpus,
)
return rdd.withResources(rp)

def _fit(self, dataset: DataFrame) -> "_SparkXGBModel":
# pylint: disable=too-many-statements, too-many-locals
self._validate_params()
Expand Down Expand Up @@ -994,14 +1110,16 @@ def _train_booster(
)

def _run_job() -> Tuple[str, str]:
ret = (
rdd = (
dataset.mapInPandas(
_train_booster, schema="config string, booster string" # type: ignore
_train_booster, # type: ignore
schema="config string, booster string",
)
.rdd.barrier()
.mapPartitions(lambda x: x)
.collect()[0]
)
rdd_with_resource = self._try_stage_level_scheduling(rdd)
ret = rdd_with_resource.collect()[0]
return ret[0], ret[1]

get_logger("XGBoost-PySpark").info(
Expand Down
7 changes: 7 additions & 0 deletions python-package/xgboost/spark/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,13 @@ def _is_local(spark_context: SparkContext) -> bool:
return spark_context._jsc.sc().isLocal()


def _is_standalone_or_localcluster(spark_context: SparkContext) -> bool:
master = spark_context.getConf().get("spark.master")
return master is not None and (
master.startswith("spark://") or master.startswith("local-cluster")
)


def _get_gpu_id(task_context: TaskContext) -> int:
"""Get the gpu id from the task resources"""
if task_context is None:
Expand Down

0 comments on commit c5f73d3

Please sign in to comment.