Shuffle coalesce read supports split-and-retry #11598

Open · wants to merge 6 commits into branch-24.12

8 changes: 6 additions & 2 deletions integration_tests/src/main/python/hash_aggregate_test.py
@@ -1118,7 +1118,8 @@ def test_hash_groupby_typed_imperative_agg_without_gpu_implementation_fallback()
@disable_ansi_mode # https://github.com/NVIDIA/spark-rapids/issues/5114
@pytest.mark.parametrize('data_gen', _init_list, ids=idfn)
@pytest.mark.parametrize('conf', get_params(_confs, params_markers_for_confs), ids=idfn)
def test_hash_multiple_mode_query(data_gen, conf):
@pytest.mark.parametrize('shuffle_split', [True, False], ids=idfn)
def test_hash_multiple_mode_query(data_gen, conf, shuffle_split):
print_params(data_gen)
assert_gpu_and_cpu_are_equal_collect(
lambda spark: gen_df(spark, data_gen, length=100)
@@ -1132,7 +1133,10 @@ def test_hash_multiple_mode_query(data_gen, conf):
f.max('a'),
f.sumDistinct('b'),
f.countDistinct('c')
), conf=conf)
),
conf=copy_and_update(
conf,
{'spark.rapids.shuffle.splitRetryRead.enabled': shuffle_split}))


@approximate_float
8 changes: 6 additions & 2 deletions integration_tests/src/main/python/join_test.py
@@ -242,21 +242,25 @@ def test_hash_join_ridealong_non_sized(data_gen, join_type, sub_part_enabled):
@ignore_order(local=True)
@pytest.mark.parametrize('data_gen', basic_nested_gens + [decimal_gen_128bit], ids=idfn)
@pytest.mark.parametrize('join_type', all_symmetric_sized_join_types, ids=idfn)
@pytest.mark.parametrize('shuffle_split', [True, False], ids=idfn)
@allow_non_gpu(*non_utc_allow)
def test_hash_join_ridealong_symmetric(data_gen, join_type):
def test_hash_join_ridealong_symmetric(data_gen, join_type, shuffle_split):
confs = {
"spark.rapids.sql.join.useShuffledSymmetricHashJoin": "true",
"spark.rapids.shuffle.splitRetryRead.enabled": shuffle_split,
}
hash_join_ridealong(data_gen, join_type, confs)

@validate_execs_in_gpu_plan('GpuShuffledAsymmetricHashJoinExec')
@ignore_order(local=True)
@pytest.mark.parametrize('data_gen', basic_nested_gens + [decimal_gen_128bit], ids=idfn)
@pytest.mark.parametrize('join_type', all_asymmetric_sized_join_types, ids=idfn)
@pytest.mark.parametrize('shuffle_split', [True, False], ids=idfn)
@allow_non_gpu(*non_utc_allow)
def test_hash_join_ridealong_asymmetric(data_gen, join_type):
def test_hash_join_ridealong_asymmetric(data_gen, join_type, shuffle_split):
confs = {
"spark.rapids.sql.join.useShuffledAsymmetricHashJoin": "true",
"spark.rapids.shuffle.splitRetryRead.enabled": shuffle_split,
}
hash_join_ridealong(data_gen, join_type, confs)

@@ -18,15 +18,19 @@ package com.nvidia.spark.rapids

import java.util

import scala.collection.mutable
import scala.reflect.ClassTag

import ai.rapids.cudf.{JCudfSerialization, NvtxColor, NvtxRange}
import ai.rapids.cudf.JCudfSerialization.HostConcatResult
import com.nvidia.spark.rapids.Arm.{closeOnExcept, withResource}
import com.nvidia.spark.rapids.RapidsPluginImplicits.{AutoCloseableProducingSeq, AutoCloseableSeq}
import com.nvidia.spark.rapids.RmmRapidsRetryIterator.{splitTargetSizeInHalfGpu, withRetry}
import com.nvidia.spark.rapids.ScalableTaskCompletion.onTaskCompletion
import com.nvidia.spark.rapids.shims.ShimUnaryExecNode

import org.apache.spark.TaskContext
import org.apache.spark.internal.Logging
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.Attribute
@@ -77,13 +81,13 @@ case class GpuShuffleCoalesceExec(child: SparkPlan, targetBatchByteSize: Long)
}
}

/** A case class to pack some options. Now it has only one, but may have more in the future */
case class CoalesceReadOption private(kudoEnabled: Boolean)
/** A case class to pack some options. */
case class CoalesceReadOption private(kudoEnabled: Boolean, useSplitRetryRead: Boolean)

object CoalesceReadOption {
def apply(conf: RapidsConf): CoalesceReadOption = {
// TODO get the value from conf
CoalesceReadOption(false)
// TODO get the value of Kudo support from conf
CoalesceReadOption(kudoEnabled = false, conf.shuffleSplitRetryReadEnabled)
}
}

@@ -112,23 +116,39 @@ object GpuShuffleCoalesceUtils {
readOption: CoalesceReadOption,
metricsMap: Map[String, GpuMetric],
prefetchFirstBatch: Boolean = false): Iterator[ColumnarBatch] = {
val hostIter = if (readOption.kudoEnabled) {
// TODO replace with the actual Kudo host iterator
throw new UnsupportedOperationException("Kudo is not supported yet")
} else {
new HostShuffleCoalesceIterator(iter, targetSize, metricsMap)
}
val maybeBufferedIter = if (prefetchFirstBatch) {
val bufferedIter = new CloseableBufferedIterator(hostIter)
withResource(new NvtxRange("fetch first batch", NvtxColor.YELLOW)) { _ =>
// Force a coalesce of the first batch before we grab the GPU semaphore
bufferedIter.headOption
if (readOption.useSplitRetryRead) {
val reader = if (readOption.kudoEnabled) {
// TODO replace with the actual Kudo host iterator
throw new UnsupportedOperationException("Kudo is not supported yet")
} else {
new GpuShuffleCoalesceReader(iter, targetSize, dataTypes, metricsMap)
}
bufferedIter
if (prefetchFirstBatch) {
withResource(new NvtxRange("fetch first batch", NvtxColor.YELLOW)) { _ =>
// Force a coalesce of the first batch before we grab the GPU semaphore
reader.prefetchHeadOnHost()
}
}
reader.asIterator
} else {
hostIter
val hostIter = if (readOption.kudoEnabled) {
// TODO replace with the actual Kudo host iterator
throw new UnsupportedOperationException("Kudo is not supported yet")
} else {
new HostShuffleCoalesceIterator(iter, targetSize, metricsMap)
}
val maybeBufferedIter = if (prefetchFirstBatch) {
val bufferedIter = new CloseableBufferedIterator(hostIter)
withResource(new NvtxRange("fetch first batch", NvtxColor.YELLOW)) { _ =>
// Force a coalesce of the first batch before we grab the GPU semaphore
bufferedIter.headOption
}
bufferedIter
} else {
hostIter
}
new GpuShuffleCoalesceIterator(maybeBufferedIter, dataTypes, metricsMap)
}
new GpuShuffleCoalesceIterator(maybeBufferedIter, dataTypes, metricsMap)
}

/** Get the buffer size of a serialized batch just returned by the Shuffle deserializer */
Expand Down Expand Up @@ -194,6 +214,166 @@ class JCudfTableOperator extends SerializedTableOperator[SerializedTableColumn]
}
}

/**
 * A reader playing the same role as the combination of "HostCoalesceIteratorBase"
 * and "GpuShuffleCoalesceIterator". That is, it coalesces columnar batches expected
 * to contain only serialized tables T from shuffle. The serialized tables within are
 * collected up to the target batch size, concatenated on the host, and the
 * concatenated result is then sent to the GPU.
 *
 * When an OOM happens, it reduces the target size by half, concatenates only the
 * cached tables that fit within the halved size, and sends that result to the GPU
 * again.
 */
abstract class GpuShuffleCoalesceReaderBase[T <: AutoCloseable: ClassTag](
iter: Iterator[ColumnarBatch],
targetBatchSize: Long,
dataTypes: Array[DataType],
metricsMap: Map[String, GpuMetric]) extends AutoCloseable with Logging {
private[this] val opTimeMetric = metricsMap(GpuMetric.OP_TIME)
private[this] val concatTimeMetric = metricsMap(GpuMetric.CONCAT_TIME)
private[this] val inputBatchesMetric = metricsMap(GpuMetric.NUM_INPUT_BATCHES)
private[this] val inputRowsMetric = metricsMap(GpuMetric.NUM_INPUT_ROWS)
private[this] val outputBatchesMetric = metricsMap(GpuMetric.NUM_OUTPUT_BATCHES)
private[this] val outputRowsMetric = metricsMap(GpuMetric.NUM_OUTPUT_ROWS)

private[this] val serializedTables = new mutable.Queue[T]
private[this] var realBatchSize = math.max(targetBatchSize, 1)
private[this] var closed = false

protected def tableOperator: SerializedTableOperator[T]

// Don't install the callback if in a unit test
Option(TaskContext.get()).foreach { tc =>
onTaskCompletion(tc)(close())
}

override def close(): Unit = if (!closed) {
serializedTables.safeClose()
serializedTables.clear()
closed = true
}

/** Pull in batches from the input to make sure there is enough data in the cache. */
private def pullNextBatch(): Boolean = {
if (closed) return false
// Always make sure enough data has been cached for the next batch.
var curCachedSize = serializedTables.map(tableOperator.getDataLen).sum
var curCachedRows = serializedTables.map(tableOperator.getNumRows(_).toLong).sum
while (iter.hasNext && curCachedSize < realBatchSize && curCachedRows < Int.MaxValue) {
closeOnExcept(iter.next()) { batch =>
inputBatchesMetric += 1
inputRowsMetric += batch.numRows()
if (batch.numRows > 0) {
val tableCol = batch.column(0).asInstanceOf[T]
serializedTables.enqueue(tableCol)
curCachedSize += tableOperator.getDataLen(tableCol)
curCachedRows += tableOperator.getNumRows(tableCol)
} else {
batch.close()
}
}
}
serializedTables.nonEmpty
}

/** Collect tables from the cache up to the given total size for the next batch. */
private def collectTablesForNextBatch(targetSize: Long): Array[T] = {
var curSize = 0L
var curRows = 0L
val taken = serializedTables.takeWhile { tableCol =>
curSize += tableOperator.getDataLen(tableCol)
curRows += tableOperator.getNumRows(tableCol)
curSize <= targetSize && curRows < Int.MaxValue
}
if (taken.isEmpty) {
// The first table alone is bigger than targetSize, so always take it
Array(serializedTables.head)
} else {
taken.toArray
}
}

private val reduceBatchSizeByHalf: AutoCloseableTargetSize => Seq[AutoCloseableTargetSize] =
batchSize => {
val halfSize = splitTargetSizeInHalfGpu(batchSize)
assert(halfSize.length == 1)
// Remember the size for the following caching and collecting.
logDebug(s"Update target batch size from $realBatchSize to ${halfSize.head.targetSize}")
realBatchSize = halfSize.head.targetSize
halfSize
}

private def buildNextBatch(): ColumnarBatch = {
val closeableBatchSize = AutoCloseableTargetSize(realBatchSize, 1)
val iter = withRetry(closeableBatchSize, reduceBatchSizeByHalf) { attemptSize =>
val (concatRet, numTables) = withResource(new MetricRange(opTimeMetric)) { _ =>
// Retry steps:
// 1) Collect tables from cache for the next batch according to the target size.
// 2) Concatenate the collected tables
// 3) Move the concatenated result to GPU
// We have to re-collect the tables and re-concatenate them, because the
// coalesced result cannot be split into smaller pieces.
val curTables = collectTablesForNextBatch(attemptSize.targetSize)
val concatHostBatch = withResource(new MetricRange(concatTimeMetric)) { _ =>
tableOperator.concatOnHost(curTables)
}
(concatHostBatch, curTables.length)
}
withResource(concatRet) { _ =>
// Begin to use GPU
GpuSemaphore.acquireIfNecessary(TaskContext.get())
withResource(new MetricRange(opTimeMetric)) { _ =>
(concatRet.toGpuBatch(dataTypes), numTables)
}
}
}
// Expect only one batch
val (batch, numTables) = iter.next()
closeOnExcept(batch) { _ =>
assert(iter.isEmpty)
// Now it is OK to remove the first numTables tables from the cache.
(0 until numTables).safeMap(_ => serializedTables.dequeue()).safeClose()
batch
}
}

/**
 * Prefetch the first bundle of serialized batches, with a total size up to
 * "targetSize". The prefetched batches will be cached on the host until "next()"
 * is called. This is used by some join optimizations.
 */
def prefetchHeadOnHost(): this.type = {
if (serializedTables.isEmpty) {
pullNextBatch()
}
this
}

def asIterator: Iterator[ColumnarBatch] = new Iterator[ColumnarBatch] {
override def hasNext: Boolean = pullNextBatch()
override def next(): ColumnarBatch = {
if (!hasNext) {
throw new NoSuchElementException("No more host batches to read")
}
val batch = buildNextBatch()
outputBatchesMetric += 1
outputRowsMetric += batch.numRows()
batch
}
}
}

class GpuShuffleCoalesceReader(
iter: Iterator[ColumnarBatch],
targetBatchSize: Long,
dataTypes: Array[DataType],
metricsMap: Map[String, GpuMetric])
extends GpuShuffleCoalesceReaderBase[SerializedTableColumn](iter, targetBatchSize,
dataTypes, metricsMap) {

override protected val tableOperator = new JCudfTableOperator
}

/**
* Iterator that coalesces columnar batches that are expected to only contain
* serialized tables from shuffle. The serialized tables within are collected up
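
The reader above implements split-and-retry around the host concatenation: collect cached tables up to a target size, concatenate them on the host, and move the result to the GPU; if that attempt OOMs, halve the target size and retry with fewer tables. The following is a minimal standalone sketch of that loop; HostTable, concatAndSend, and the use of OutOfMemoryError as the retry signal are illustrative stand-ins, not the plugin's actual types (the plugin routes retries through RmmRapidsRetryIterator.withRetry).

import scala.collection.mutable

// Hypothetical stand-in for a serialized shuffle table cached on the host.
case class HostTable(dataLen: Long)

object SplitRetrySketch {
  // Take tables from the front of the cache while the running size stays within
  // targetSize; always take at least the first table, even when it alone exceeds it.
  def collectUpTo(cache: mutable.Queue[HostTable], targetSize: Long): Seq[HostTable] = {
    require(cache.nonEmpty, "cache must hold at least one table")
    var curSize = 0L
    val taken = cache.takeWhile { t => curSize += t.dataLen; curSize <= targetSize }
    if (taken.isEmpty) Seq(cache.head) else taken.toSeq
  }

  // Build one output batch: on OOM, halve the target and re-collect, because the
  // concatenated result cannot be split after the fact.
  def buildWithRetry(cache: mutable.Queue[HostTable], startSize: Long)
      (concatAndSend: Seq[HostTable] => Unit): Unit = {
    var targetSize = math.max(startSize, 1L)
    var done = false
    while (!done) {
      val tables = collectUpTo(cache, targetSize)
      try {
        concatAndSend(tables) // concatenate on the host, then copy to the GPU
        tables.indices.foreach(_ => cache.dequeue()) // only now drop them from the cache
        done = true
      } catch {
        case _: OutOfMemoryError if targetSize > 1 =>
          targetSize /= 2 // re-collect a smaller set of tables and try again
      }
    }
  }
}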
@@ -880,9 +880,17 @@ object GpuShuffledAsymmetricHashJoinExec {
exprs.buildSideNeedsNullFilter, metrics)
JoinInfo(joinType, buildSide, buildIter, buildSize, None, streamIter, exprs)
} else {
val buildBatch = getSingleBuildBatch(baseBuildIter, exprs, metrics)
val buildIter = new SingleGpuColumnarBatchIterator(buildBatch)
val buildStats = JoinBuildSideStats.fromBatch(buildBatch, exprs.boundBuildKeys)
val nullFilteredBuildIter = addNullFilterIfNecessary(baseBuildIter,
exprs.boundBuildKeys, exprs.buildSideNeedsNullFilter, metrics)
val buildQueue = mutable.Queue.empty[SpillableColumnarBatch]
val buildStats = closeOnExcept(buildQueue) { _ =>
while (nullFilteredBuildIter.hasNext) {
buildQueue += SpillableColumnarBatch(nullFilteredBuildIter.next(),
SpillPriorities.ACTIVE_ON_DECK_PRIORITY)
}
JoinBuildSideStats.fromBatches(buildQueue.toSeq, exprs.boundBuildKeys)
}
val buildIter = new SpillableColumnarBatchQueueIterator(buildQueue, Iterator.empty)
if (buildStats.streamMagnificationFactor < magnificationThreshold) {
metrics(BUILD_DATA_SIZE).set(buildSize)
JoinInfo(joinType, buildSide, buildIter, buildSize, Some(buildStats), streamIter,
@@ -1016,18 +1024,6 @@ object GpuShuffledAsymmetricHashJoinExec {
buildIter
}
}

private def getSingleBuildBatch(
baseIter: Iterator[ColumnarBatch],
exprs: BoundJoinExprs,
metrics: Map[String, GpuMetric]): ColumnarBatch = {
val iter = addNullFilterIfNecessary(baseIter, exprs.boundBuildKeys,
exprs.buildSideNeedsNullFilter, metrics)
closeOnExcept(iter.next()) { batch =>
assert(!iter.hasNext)
batch
}
}
}

class HostHostAsymmetricJoinSizer(
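
The join-side change above exists because the split-and-retry read may now emit more than one build-side batch, so the removed getSingleBuildBatch (which asserted the iterator held exactly one batch) no longer holds; instead the build side is drained into a queue of spillable batches before the stats are computed. A sketch of that drain-with-cleanup pattern, using a hypothetical Spillable trait in place of SpillableColumnarBatch:

import scala.collection.mutable

// Hypothetical stand-in for SpillableColumnarBatch.
trait Spillable extends AutoCloseable

object DrainSketch {
  // Drain an iterator into a queue; if anything throws mid-drain, close whatever
  // was queued so far before rethrowing (the role closeOnExcept plays above).
  def drainToQueue[T <: Spillable](iter: Iterator[T]): mutable.Queue[T] = {
    val queue = mutable.Queue.empty[T]
    try {
      while (iter.hasNext) queue += iter.next()
      queue
    } catch {
      case t: Throwable =>
        queue.foreach(b => try b.close() catch { case _: Throwable => () })
        throw t
    }
  }
}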
@@ -1932,6 +1932,13 @@ val SHUFFLE_COMPRESSION_LZ4_CHUNK_SIZE = conf("spark.rapids.shuffle.compression.
.integerConf
.createWithDefault(20)

val SHUFFLE_SPLITRETRY_READ = conf("spark.rapids.shuffle.splitRetryRead.enabled")
.doc("When set to true, use the resizeable shuffle reader who will reduce the " +
"target batch size by half when getting OOM when doing coalescing shuffle read.")
.internal()
.booleanConf
.createWithDefault(true)

// ALLUXIO CONFIGS
val ALLUXIO_MASTER = conf("spark.rapids.alluxio.master")
.doc("The Alluxio master hostname. If not set, read Alluxio master URL from " +
@@ -3233,6 +3240,8 @@ class RapidsConf(conf: Map[String, String]) extends Logging {

lazy val caseWhenFuseEnabled: Boolean = get(CASE_WHEN_FUSE)

lazy val shuffleSplitRetryReadEnabled: Boolean = get(SHUFFLE_SPLITRETRY_READ)

private val optimizerDefaults = Map(
// this is not accurate because CPU projections do have a cost due to appending values
// to each row that is produced, but this needs to be a really small number because
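
The new flag is internal and defaults to true, so the split-and-retry path is on unless explicitly disabled. When debugging, it can be toggled like any other plugin setting; for example, in a Spark session (a hypothetical snippet, assuming the RAPIDS plugin is already loaded):

// Disable the split-and-retry coalescing shuffle read for this session.
spark.conf.set("spark.rapids.shuffle.splitRetryRead.enabled", "false")

The same can be done at submit time with --conf spark.rapids.shuffle.splitRetryRead.enabled=false.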
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2023, NVIDIA CORPORATION.
* Copyright (c) 2023-2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -631,7 +631,7 @@ object RmmRapidsRetryIterator extends Logging {
clearInjectedOOMIfNeeded()

// make sure we add any prior exceptions to this one as causes
if (lastException != null) {
if (lastException != null && lastException != ex) {
ex.addSuppressed(lastException)
}
lastException = ex
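
The lastException != ex guard above is needed because the same exception instance can surface across retry attempts, and the JVM rejects an exception suppressing itself. A tiny self-contained demonstration of the failure the guard prevents (plain Scala, independent of the plugin):

object SelfSuppressionDemo {
  def main(args: Array[String]): Unit = {
    val ex = new RuntimeException("attempt failed")
    // Suppressing a different prior exception works as intended.
    ex.addSuppressed(new RuntimeException("earlier attempt"))
    // Suppressing the same instance throws:
    //   java.lang.IllegalArgumentException: Self-suppression not permitted
    try ex.addSuppressed(ex)
    catch { case iae: IllegalArgumentException => println(iae.getMessage) }
  }
}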