From 1453281ffddcb4cfc41188dca88682727ed8e121 Mon Sep 17 00:00:00 2001 From: Alessandro Bellina Date: Tue, 25 Jul 2023 17:41:43 +0000 Subject: [PATCH 1/2] Fix shuffling an empty Struct() column with UCX Signed-off-by: Alessandro Bellina --- .../spark/rapids/shuffle/RapidsShuffleClient.scala | 5 ++++- .../sql/rapids/RapidsShuffleInternalManagerBase.scala | 10 +++++++++- 2 files changed, 13 insertions(+), 2 deletions(-) diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/shuffle/RapidsShuffleClient.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/shuffle/RapidsShuffleClient.scala index 0f8bc5b1b13..9aecedf064f 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/shuffle/RapidsShuffleClient.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/shuffle/RapidsShuffleClient.scala @@ -329,7 +329,10 @@ class RapidsShuffleClient( val ptrs = new ArrayBuffer[PendingTransferRequest](allTables) (0 until allTables).foreach { i => val tableMeta = ShuffleMetadata.copyTableMetaToHeap(metaResponse.tableMetas(i)) - if (tableMeta.bufferMeta() != null) { + + // We check the uncompressedSize to make sure we don't request a 0-sized buffer + // from a peer. 
We treat such a corner case as a degenerate batch. + if (tableMeta.bufferMeta() != null && tableMeta.bufferMeta().uncompressedSize() > 0) { ptrs += PendingTransferRequest( this, tableMeta, diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/RapidsShuffleInternalManagerBase.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/RapidsShuffleInternalManagerBase.scala index 6d967ae2ec7..125087fe79c 100644 --- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/RapidsShuffleInternalManagerBase.scala +++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/RapidsShuffleInternalManagerBase.scala @@ -956,7 +956,15 @@ class RapidsCachingWriter[K, V]( throw new IllegalStateException(s"Unexpected column type: ${c.getClass}") } bytesWritten += partSize - sizes(partId) += partSize + // if the size is 0 and we have rows, we are in a case where there are columns + // but the type is such that there isn't a buffer in the GPU backing it. + // For example, a Struct column without any members. We treat such a case as if it + // were a degenerate table. 
+ if (partSize == 0 && batch.numRows() > 0) { + sizes(partId) += 100 + } else { + sizes(partId) += partSize + } handle } else { // no device data, tracking only metadata From b15987077d278ca2573463dd996ff20c961bc42a Mon Sep 17 00:00:00 2001 From: Alessandro Bellina Date: Wed, 26 Jul 2023 08:29:51 -0500 Subject: [PATCH 2/2] Clarify the significance of 100 as a degenerate partition size --- .../RapidsShuffleInternalManagerBase.scala | 17 ++++++++++------- 1 file changed, 10 insertions(+), 7 deletions(-) diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/RapidsShuffleInternalManagerBase.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/RapidsShuffleInternalManagerBase.scala index 125087fe79c..b32e0d3959d 100644 --- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/RapidsShuffleInternalManagerBase.scala +++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/RapidsShuffleInternalManagerBase.scala @@ -910,6 +910,12 @@ class RapidsCachingWriter[K, V]( private val uncompressedMetric: SQLMetric = metrics("dataSize") + // This is here for the special case where we have no columns, as with the .count + // case, or when we have 0-byte columns. We pick 100 as an arbitrary number so that + // we can shuffle these degenerate batches, which have valid metadata and should be + // used on the reducer side for computation. + private val DEGENERATE_PARTITION_BYTE_SIZE_DEFAULT: Long = 100L + override def write(records: Iterator[Product2[K, V]]): Unit = { // NOTE: This MUST NOT CLOSE the incoming batches because they are // closed by the input iterator generated by GpuShuffleExchangeExec @@ -961,7 +967,7 @@ class RapidsCachingWriter[K, V]( // For example, a Struct column without any members. We treat such a case as if it // were a degenerate table. 
if (partSize == 0 && batch.numRows() > 0) { - sizes(partId) += 100 + sizes(partId) += DEGENERATE_PARTITION_BYTE_SIZE_DEFAULT } else { sizes(partId) += partSize } @@ -974,13 +980,10 @@ class RapidsCachingWriter[K, V]( blockId, tableMeta) - // The size of the data is really only used to tell if the data should be shuffled or not - // a 0 indicates that we should not shuffle anything. This is here for the special case - // where we have no columns, because of predicate push down, but we have a row count as - // metadata. We still want to shuffle it. The 100 is an arbitrary number and can be - // any non-zero number that is not too large. + // ensure that we set the partition size to the default in this case if + // we have non-zero rows, so this degenerate batch is shuffled. if (batch.numRows > 0) { - sizes(partId) += 100 + sizes(partId) += DEGENERATE_PARTITION_BYTE_SIZE_DEFAULT } handle }