From 50c83ebe28a431b7eaa399368736db983230a62d Mon Sep 17 00:00:00 2001 From: Bobby Wang Date: Thu, 24 Aug 2023 10:53:52 +0800 Subject: [PATCH 1/5] [pyspark] Support stage-level scheduling for training --- python-package/xgboost/spark/core.py | 224 +++++++++++++++++++++------ 1 file changed, 173 insertions(+), 51 deletions(-) diff --git a/python-package/xgboost/spark/core.py b/python-package/xgboost/spark/core.py index 6b1d2faaacd1..a646e0df88c4 100644 --- a/python-package/xgboost/spark/core.py +++ b/python-package/xgboost/spark/core.py @@ -22,7 +22,7 @@ import numpy as np import pandas as pd -from pyspark import SparkContext, cloudpickle +from pyspark import RDD, SparkContext, cloudpickle from pyspark.ml import Estimator, Model from pyspark.ml.functions import array_to_vector, vector_to_array from pyspark.ml.linalg import VectorUDT @@ -44,6 +44,7 @@ MLWritable, MLWriter, ) +from pyspark.resource import ResourceProfileBuilder, TaskResourceRequests from pyspark.sql import Column, DataFrame from pyspark.sql.functions import col, countDistinct, pandas_udf, rand, struct from pyspark.sql.types import ( @@ -342,6 +343,54 @@ def _gen_predict_params_dict(self) -> Dict[str, Any]: predict_params[param.name] = self.getOrDefault(param) return predict_params + def _validate_gpu_params(self) -> None: + """Validate the gpu parameters and gpu configurations""" + + if use_cuda(self.getOrDefault(self.device)) or self.getOrDefault(self.use_gpu): + ss = _get_spark_session() + sc = ss.sparkContext + + if _is_local(sc): + # Support GPU training in Spark local mode is just for debugging + # purposes, so it's okay for printing the below warning instead of + # checking the real gpu numbers and raising the exception. + get_logger(self.__class__.__name__).warning( + "You have enabled GPU in spark local mode. Please make sure your" + " local node has at least %d GPUs", + self.getOrDefault(self.num_workers), + ) + else: + executor_gpus = sc.getConf().get("spark.executor.resource.gpu.amount") + if executor_gpus is None: + raise ValueError( + "The `spark.executor.resource.gpu.amount` is required for training" + " on GPU." + ) + + if ss.version < "3.4.0": + # We will enable stage-level scheduling in spark 3.4.0+ which doesn't + # require spark.task.resource.gpu.amount to be set explicitly + gpu_per_task = sc.getConf().get("spark.task.resource.gpu.amount") + if gpu_per_task is not None: + if float(gpu_per_task) < 1.0: + raise ValueError( + "XGBoost doesn't support GPU fractional configurations. " + "Please set `spark.task.resource.gpu.amount=spark.executor" + ".resource.gpu.amount`" + ) + + if float(gpu_per_task) > 1.0: + get_logger(self.__class__.__name__).warning( + "%s GPUs for each Spark task is configured, but each " + "XGBoost training task uses only 1 GPU.", + gpu_per_task, + ) + else: + raise ValueError( + "The `spark.task.resource.gpu.amount` is required for training" + " on GPU." + ) + def _validate_params(self) -> None: # pylint: disable=too-many-branches init_model = self.getOrDefault("xgb_model") @@ -421,53 +470,7 @@ def _validate_params(self) -> None: "`pyspark.ml.linalg.Vector` type." ) - if use_cuda(self.getOrDefault(self.device)) or self.getOrDefault(self.use_gpu): - gpu_per_task = ( - _get_spark_session() - .sparkContext.getConf() - .get("spark.task.resource.gpu.amount") - ) - - is_local = _is_local(_get_spark_session().sparkContext) - - if is_local: - # checking spark local mode. - if gpu_per_task is not None: - raise RuntimeError( - "The spark local mode does not support gpu configuration." 
- "Please remove spark.executor.resource.gpu.amount and " - "spark.task.resource.gpu.amount" - ) - - # Support GPU training in Spark local mode is just for debugging - # purposes, so it's okay for printing the below warning instead of - # checking the real gpu numbers and raising the exception. - get_logger(self.__class__.__name__).warning( - "You have enabled GPU in spark local mode. Please make sure your" - " local node has at least %d GPUs", - self.getOrDefault(self.num_workers), - ) - else: - # checking spark non-local mode. - if gpu_per_task is not None: - if float(gpu_per_task) < 1.0: - raise ValueError( - "XGBoost doesn't support GPU fractional configurations. " - "Please set `spark.task.resource.gpu.amount=spark.executor" - ".resource.gpu.amount`" - ) - - if float(gpu_per_task) > 1.0: - get_logger(self.__class__.__name__).warning( - "%s GPUs for each Spark task is configured, but each " - "XGBoost training task uses only 1 GPU.", - gpu_per_task, - ) - else: - raise ValueError( - "The `spark.task.resource.gpu.amount` is required for training" - " on GPU." - ) + self._validate_gpu_params() def _validate_and_convert_feature_col_as_float_col_list( @@ -592,6 +595,8 @@ def __init__(self) -> None: arbitrary_params_dict={}, ) + self.logger = get_logger(self.__class__.__name__) + def setParams(self, **kwargs: Any) -> None: # pylint: disable=invalid-name """ Set params for the estimator. @@ -894,6 +899,121 @@ def _get_xgb_parameters( return booster_params, train_call_kwargs_params, dmatrix_kwargs + def _skip_stage_level_scheduling(self) -> bool: + # pylint: disable=too-many-return-statements + """Check if stage-level scheduling is not needed, + return true to skip stage-level scheduling""" + + if use_cuda(self.getOrDefault(self.device)) or self.getOrDefault(self.use_gpu): + ss = _get_spark_session() + sc = ss.sparkContext + + if ss.version < "3.4.0": + self.logger.warning( + "Stage-level scheduling in xgboost requires spark version 3.4.0+" + ) + return True + + if _is_local(sc): + # Local mode doesn't support stage-level scheduling + return True + + executor_cores = sc.getConf().get("spark.executor.cores") + executor_gpus = sc.getConf().get("spark.executor.resource.gpu.amount") + if executor_cores is None or executor_gpus is None: + self.logger.warning( + "Stage-level scheduling in xgboost requires spark.executor.cores, " + "spark.executor.resource.gpu.amount to be set." + ) + return True + + if int(executor_cores) == 1: + # there will be only 1 task running at any time. + self.logger.warning( + "Stage-level scheduling in xgboost requires spark.executor.cores > 1 " + ) + return True + + if int(executor_gpus) > 1: + # For spark.executor.resource.gpu.amount > 1, we suppose user knows how to configure + # to make xgboost run successfully. + # + self.logger.warning( + "Stage-level scheduling in xgboost will not work " + "when spark.executor.resource.gpu.amount>1" + ) + return True + + task_gpu_amount = sc.getConf().get("spark.task.resource.gpu.amount") + + if task_gpu_amount is None: + # The ETL tasks will not grab a gpu when spark.task.resource.gpu.amount is not set, + # but with stage-level scheduling, we can make training task grab the gpu. + return False + + if float(task_gpu_amount) == float(executor_gpus): + self.logger.warning( + "The configuration of cores (exec=%s, task=%s, runnable tasks=1) will " + "result in wasted resources due to resource gpu limiting the number of " + "runnable tasks per executor to: 1. 
Please adjust your configuration.", + executor_gpus, + task_gpu_amount, + ) + return True + + # We can enable stage-level scheduling + return False + + # CPU training doesn't require stage-level scheduling + return True + + def _try_stage_level_scheduling(self, rdd: RDD) -> RDD: + """Try to enable stage-level scheduling""" + + if self._skip_stage_level_scheduling(): + return rdd + + ss = _get_spark_session() + + # executor_cores will not be None + executor_cores = ss.sparkContext.getConf().get("spark.executor.cores") + assert executor_cores is not None + + # Each training task requires cpu cores > total executor cores//2 + 1 which can + # make the tasks be sent to different executors. + # + # Please note that we can't set task_cpu_cores to the value which is smaller than + # total executor cores/2 because only task_gpu_amount can't make sure the tasks be + # sent to different executor even task_gpus=1.0 + + # Spark-rapids is a project to leverage GPUs to accelerate spark SQL and it can + # really help the performance. If spark-rapids is enabled. we don't allow other + # ETL gpu tasks running alongside training tasks to avoid OOM + spark_plugins = ss.conf.get("spark.plugins", " ") + assert spark_plugins is not None + spark_rapids_sql_enabled = ss.conf.get("spark.rapids.sql.enabled", "true") + task_cores = ( + int(executor_cores) + if "com.nvidia.spark.SQLPlugin" in spark_plugins + and "true" == spark_rapids_sql_enabled + else (int(executor_cores) // 2) + 1 + ) + + # task_gpus means how many gpu slots the task requires in a single GPU, + # it doesn't mean how many gpu shares it would like to require, so we + # can set it to any value of (0, 0.5] or 1. + task_gpus = 1.0 + + treqs = TaskResourceRequests().cpus(task_cores).resource("gpu", task_gpus) + rp = ResourceProfileBuilder().require(treqs).build + + self.logger.info( + "XGBoost training tasks require the resource(cores=%s, gpu=%s).", + task_cores, + task_gpus, + ) + return rdd.withResources(rp) + def _fit(self, dataset: DataFrame) -> "_SparkXGBModel": # pylint: disable=too-many-statements, too-many-locals self._validate_params() @@ -994,14 +1114,16 @@ def _train_booster( ) def _run_job() -> Tuple[str, str]: - ret = ( + rdd = ( dataset.mapInPandas( - _train_booster, schema="config string, booster string" # type: ignore + _train_booster, # type: ignore + schema="config string, booster string", ) .rdd.barrier() .mapPartitions(lambda x: x) - .collect()[0] ) + rdd_with_resource = self._try_stage_level_scheduling(rdd) + ret = rdd_with_resource.collect()[0] return ret[0], ret[1] get_logger("XGBoost-PySpark").info( From a2652367555b0c5bb97a260f23da3e969e715347 Mon Sep 17 00:00:00 2001 From: Bobby Wang Date: Wed, 30 Aug 2023 12:05:26 +0800 Subject: [PATCH 2/5] comments --- python-package/xgboost/spark/core.py | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/python-package/xgboost/spark/core.py b/python-package/xgboost/spark/core.py index a646e0df88c4..d7f647ce66a4 100644 --- a/python-package/xgboost/spark/core.py +++ b/python-package/xgboost/spark/core.py @@ -929,7 +929,7 @@ def _skip_stage_level_scheduling(self) -> bool: if int(executor_cores) == 1: # there will be only 1 task running at any time. 
- self.logger.warning( + self.logger.info( "Stage-level scheduling in xgboost requires spark.executor.cores > 1 " ) return True @@ -991,7 +991,9 @@ def _try_stage_level_scheduling(self, rdd: RDD) -> RDD: # ETL gpu tasks running alongside training tasks to avoid OOM spark_plugins = ss.conf.get("spark.plugins", " ") assert spark_plugins is not None - spark_rapids_sql_enabled = ss.conf.get("spark.rapids.sql.enabled", "true") + spark_rapids_sql_enabled = ss.conf.get( + "spark.rapids.sql.enabled", "true" + ).lower() task_cores = ( int(executor_cores) if "com.nvidia.spark.SQLPlugin" in spark_plugins @@ -1004,15 +1006,16 @@ def _try_stage_level_scheduling(self, rdd: RDD) -> RDD: # can set it to any value of (0, 0.5] or 1. task_gpus = 1.0 - treqs = TaskResourceRequests().cpus(task_cores).resource("gpu", task_gpus) - rp = ResourceProfileBuilder().require(treqs).build + # treqs = TaskResourceRequests().cpus(task_cores).resource("gpu", task_gpus) + # rp = ResourceProfileBuilder().require(treqs).build self.logger.info( "XGBoost training tasks require the resource(cores=%s, gpu=%s).", task_cores, task_gpus, ) - return rdd.withResources(rp) + # return rdd.withResources(rp) + return rdd def _fit(self, dataset: DataFrame) -> "_SparkXGBModel": # pylint: disable=too-many-statements, too-many-locals From c8102f13820f0c5fb573d9cced1cf1be22ed8b59 Mon Sep 17 00:00:00 2001 From: Bobby Wang Date: Wed, 30 Aug 2023 12:35:51 +0800 Subject: [PATCH 3/5] avoid stage-level scheduling on non spark standalone and localcluster mode --- python-package/xgboost/spark/core.py | 19 +++++++++++-------- python-package/xgboost/spark/utils.py | 7 +++++++ 2 files changed, 18 insertions(+), 8 deletions(-) diff --git a/python-package/xgboost/spark/core.py b/python-package/xgboost/spark/core.py index d7f647ce66a4..58096e3e5a81 100644 --- a/python-package/xgboost/spark/core.py +++ b/python-package/xgboost/spark/core.py @@ -44,7 +44,6 @@ MLWritable, MLWriter, ) -from pyspark.resource import ResourceProfileBuilder, TaskResourceRequests from pyspark.sql import Column, DataFrame from pyspark.sql.functions import col, countDistinct, pandas_udf, rand, struct from pyspark.sql.types import ( @@ -89,6 +88,7 @@ _get_rabit_args, _get_spark_session, _is_local, + _is_standalone_or_localcluster, deserialize_booster, deserialize_xgb_model, get_class_name, @@ -367,7 +367,7 @@ def _validate_gpu_params(self) -> None: " on GPU." 
) - if ss.version < "3.4.0": + if not (ss.version >= "3.4.0" and _is_standalone_or_localcluster(sc)): # We will enable stage-level scheduling in spark 3.4.0+ which doesn't # require spark.task.resource.gpu.amount to be set explicitly gpu_per_task = sc.getConf().get("spark.task.resource.gpu.amount") @@ -914,8 +914,11 @@ def _skip_stage_level_scheduling(self) -> bool: ) return True - if _is_local(sc): - # Local mode doesn't support stage-level scheduling + if not _is_standalone_or_localcluster(sc): + self.logger.warning( + "Stage-level scheduling in xgboost requires spark standalone or " + "local-cluster mode" + ) return True executor_cores = sc.getConf().get("spark.executor.cores") @@ -991,13 +994,13 @@ def _try_stage_level_scheduling(self, rdd: RDD) -> RDD: # ETL gpu tasks running alongside training tasks to avoid OOM spark_plugins = ss.conf.get("spark.plugins", " ") assert spark_plugins is not None - spark_rapids_sql_enabled = ss.conf.get( - "spark.rapids.sql.enabled", "true" - ).lower() + spark_rapids_sql_enabled = ss.conf.get("spark.rapids.sql.enabled", "true") + assert spark_rapids_sql_enabled is not None + task_cores = ( int(executor_cores) if "com.nvidia.spark.SQLPlugin" in spark_plugins - and "true" == spark_rapids_sql_enabled + and "true" == spark_rapids_sql_enabled.lower() else (int(executor_cores) // 2) + 1 ) diff --git a/python-package/xgboost/spark/utils.py b/python-package/xgboost/spark/utils.py index 66d7ca4548ca..395865386191 100644 --- a/python-package/xgboost/spark/utils.py +++ b/python-package/xgboost/spark/utils.py @@ -129,6 +129,13 @@ def _is_local(spark_context: SparkContext) -> bool: return spark_context._jsc.sc().isLocal() +def _is_standalone_or_localcluster(spark_context: SparkContext) -> bool: + master = spark_context.getConf().get("spark.master") + return master is not None and ( + master.startswith("spark://") or master.startswith("local-cluster") + ) + + def _get_gpu_id(task_context: TaskContext) -> int: """Get the gpu id from the task resources""" if task_context is None: From 2cb7538430bd779a8ef1e4ff09896c12ad894a4f Mon Sep 17 00:00:00 2001 From: Bobby Wang Date: Mon, 16 Oct 2023 17:35:40 +0800 Subject: [PATCH 4/5] resolve comments --- python-package/xgboost/spark/core.py | 31 ++++++++++++---------------- 1 file changed, 13 insertions(+), 18 deletions(-) diff --git a/python-package/xgboost/spark/core.py b/python-package/xgboost/spark/core.py index 58096e3e5a81..8be35ad26b9e 100644 --- a/python-package/xgboost/spark/core.py +++ b/python-package/xgboost/spark/core.py @@ -44,6 +44,7 @@ MLWritable, MLWriter, ) +from pyspark.resource import ResourceProfileBuilder, TaskResourceRequests from pyspark.sql import Column, DataFrame from pyspark.sql.functions import col, countDistinct, pandas_udf, rand, struct from pyspark.sql.types import ( @@ -982,16 +983,9 @@ def _try_stage_level_scheduling(self, rdd: RDD) -> RDD: executor_cores = ss.sparkContext.getConf().get("spark.executor.cores") assert executor_cores is not None - # Each training task requires cpu cores > total executor cores//2 + 1 which can - # make the tasks be sent to different executors. - # - # Please note that we can't set task_cpu_cores to the value which is smaller than - # total executor cores/2 because only task_gpu_amount can't make sure the tasks be - # sent to different executor even task_gpus=1.0 - - # Spark-rapids is a project to leverage GPUs to accelerate spark SQL and it can - # really help the performance. If spark-rapids is enabled. 
we don't allow other - # ETL gpu tasks running alongside training tasks to avoid OOM + # Spark-rapids is a project to leverage GPUs to accelerate spark SQL. + # If spark-rapids is enabled, to avoid GPU OOM, we don't allow other + # ETL gpu tasks running alongside training tasks. spark_plugins = ss.conf.get("spark.plugins", " ") assert spark_plugins is not None spark_rapids_sql_enabled = ss.conf.get("spark.rapids.sql.enabled", "true") @@ -1004,21 +998,22 @@ def _try_stage_level_scheduling(self, rdd: RDD) -> RDD: else (int(executor_cores) // 2) + 1 ) - # task_gpus means how many gpu slots the task requires in a single GPU, - # it doesn't mean how many gpu shares it would like to require, so we - # can set it to any value of (0, 0.5] or 1. - task_gpus = 1.0 + # Each training task requires cpu cores > total executor cores//2 + 1 which can + # make sure the tasks be sent to different executors. + # + # Please note that we can't use GPU to limit the concurrent tasks because of + # https://issues.apache.org/jira/browse/SPARK-45527. - # treqs = TaskResourceRequests().cpus(task_cores).resource("gpu", task_gpus) - # rp = ResourceProfileBuilder().require(treqs).build + task_gpus = 1.0 + treqs = TaskResourceRequests().cpus(task_cores).resource("gpu", task_gpus) + rp = ResourceProfileBuilder().require(treqs).build self.logger.info( "XGBoost training tasks require the resource(cores=%s, gpu=%s).", task_cores, task_gpus, ) - # return rdd.withResources(rp) - return rdd + return rdd.withResources(rp) def _fit(self, dataset: DataFrame) -> "_SparkXGBModel": # pylint: disable=too-many-statements, too-many-locals From 7b551086caf9bf492fdaf1c65fd5c0a88dc60187 Mon Sep 17 00:00:00 2001 From: Bobby Wang Date: Mon, 16 Oct 2023 20:42:23 +0800 Subject: [PATCH 5/5] comments --- python-package/xgboost/spark/core.py | 17 ++++++----------- 1 file changed, 6 insertions(+), 11 deletions(-) diff --git a/python-package/xgboost/spark/core.py b/python-package/xgboost/spark/core.py index 8be35ad26b9e..9fe73005a073 100644 --- a/python-package/xgboost/spark/core.py +++ b/python-package/xgboost/spark/core.py @@ -910,13 +910,13 @@ def _skip_stage_level_scheduling(self) -> bool: sc = ss.sparkContext if ss.version < "3.4.0": - self.logger.warning( + self.logger.info( "Stage-level scheduling in xgboost requires spark version 3.4.0+" ) return True if not _is_standalone_or_localcluster(sc): - self.logger.warning( + self.logger.info( "Stage-level scheduling in xgboost requires spark standalone or " "local-cluster mode" ) @@ -925,7 +925,7 @@ def _skip_stage_level_scheduling(self) -> bool: executor_cores = sc.getConf().get("spark.executor.cores") executor_gpus = sc.getConf().get("spark.executor.resource.gpu.amount") if executor_cores is None or executor_gpus is None: - self.logger.warning( + self.logger.info( "Stage-level scheduling in xgboost requires spark.executor.cores, " "spark.executor.resource.gpu.amount to be set." ) @@ -942,7 +942,7 @@ def _skip_stage_level_scheduling(self) -> bool: # For spark.executor.resource.gpu.amount > 1, we suppose user knows how to configure # to make xgboost run successfully. 
 #
- self.logger.warning(
+ self.logger.info(
 "Stage-level scheduling in xgboost will not work "
 "when spark.executor.resource.gpu.amount>1"
 )
 return True
@@ -956,13 +956,8 @@ def _skip_stage_level_scheduling(self) -> bool:
 return False

 if float(task_gpu_amount) == float(executor_gpus):
- self.logger.warning(
- "The configuration of cores (exec=%s, task=%s, runnable tasks=1) will "
- "result in wasted resources due to resource gpu limiting the number of "
- "runnable tasks per executor to: 1. Please adjust your configuration.",
- executor_gpus,
- task_gpu_amount,
- )
+ # spark.executor.resource.gpu.amount = spark.task.resource.gpu.amount
+ # results in only 1 task running at a time, which may cause a perf issue.
 return True

 # We can enable stage-level scheduling
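
For reference, the sketch below is not part of the patch; it shows one Spark session configuration under which the stage-level scheduling path added above is expected to activate: Spark 3.4.0+, a standalone (or local-cluster) master, spark.executor.cores > 1, and exactly one GPU per executor. The master URL, core/GPU amounts, discovery-script path, and the training DataFrame are illustrative assumptions, not values required by the patch.

from pyspark.sql import SparkSession
from xgboost.spark import SparkXGBClassifier

spark = (
    SparkSession.builder
    .master("spark://host:7077")  # assumed standalone cluster running Spark 3.4.0+
    .config("spark.executor.cores", "12")  # > 1 so training tasks can reserve extra cores
    .config("spark.executor.resource.gpu.amount", "1")  # exactly 1 GPU per executor
    .config("spark.task.resource.gpu.amount", "0.08")  # optional; lets ETL tasks share the GPU
    .config("spark.executor.resource.gpu.discoveryScript",
            "/opt/sparkRapidsPlugin/getGpusResources.sh")  # assumed path
    .getOrCreate()
)

# With the settings above, _skip_stage_level_scheduling() returns False, so the
# barrier training stage requests (executor_cores // 2) + 1 CPUs and 1 GPU per task
# through a ResourceProfile, limiting each executor to one concurrent training task.
clf = SparkXGBClassifier(device="cuda", num_workers=2, label_col="label")
model = clf.fit(train_df)  # train_df: an assumed, already-prepared DataFrame

Without stage-level scheduling (Spark older than 3.4.0, or a master other than standalone/local-cluster), the previous behaviour still applies: spark.task.resource.gpu.amount must be set explicitly and may not be fractional.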