diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/shuffle/RapidsShuffleClient.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/shuffle/RapidsShuffleClient.scala
index 0f8bc5b1b13..9aecedf064f 100644
--- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/shuffle/RapidsShuffleClient.scala
+++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/shuffle/RapidsShuffleClient.scala
@@ -329,7 +329,10 @@ class RapidsShuffleClient(
       val ptrs = new ArrayBuffer[PendingTransferRequest](allTables)
       (0 until allTables).foreach { i =>
         val tableMeta = ShuffleMetadata.copyTableMetaToHeap(metaResponse.tableMetas(i))
-        if (tableMeta.bufferMeta() != null) {
+
+        // We check the uncompressedSize to make sure we don't request a 0-sized buffer
+        // from a peer. We treat such a corner case as a degenerate batch.
+        if (tableMeta.bufferMeta() != null && tableMeta.bufferMeta().uncompressedSize() > 0) {
           ptrs += PendingTransferRequest(
             this,
             tableMeta,
diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/RapidsShuffleInternalManagerBase.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/RapidsShuffleInternalManagerBase.scala
index 6d967ae2ec7..b32e0d3959d 100644
--- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/RapidsShuffleInternalManagerBase.scala
+++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/RapidsShuffleInternalManagerBase.scala
@@ -910,6 +910,12 @@ class RapidsCachingWriter[K, V](
 
   private val uncompressedMetric: SQLMetric = metrics("dataSize")
 
+  // This is here for the special case where we have no columns, as with .count,
+  // or when we have 0-byte columns. We pick 100 as an arbitrary number so that
+  // we can shuffle these degenerate batches, which have valid metadata and should
+  // be used on the reducer side for computation.
+  private val DEGENERATE_PARTITION_BYTE_SIZE_DEFAULT: Long = 100L
+
   override def write(records: Iterator[Product2[K, V]]): Unit = {
     // NOTE: This MUST NOT CLOSE the incoming batches because they are
     // closed by the input iterator generated by GpuShuffleExchangeExec
@@ -956,7 +962,15 @@ class RapidsCachingWriter[K, V](
             throw new IllegalStateException(s"Unexpected column type: ${c.getClass}")
         }
         bytesWritten += partSize
-        sizes(partId) += partSize
+        // If the size is 0 and we have rows, we are in a case where there are columns
+        // but their type is such that there isn't a GPU buffer backing them.
+        // For example, a Struct column without any members. We treat such a case as
+        // if it were a degenerate table.
+        if (partSize == 0 && batch.numRows() > 0) {
+          sizes(partId) += DEGENERATE_PARTITION_BYTE_SIZE_DEFAULT
+        } else {
+          sizes(partId) += partSize
+        }
         handle
       } else {
         // no device data, tracking only metadata
@@ -966,13 +980,10 @@ class RapidsCachingWriter[K, V](
           blockId,
           tableMeta)
 
-        // The size of the data is really only used to tell if the data should be shuffled or not
-        // a 0 indicates that we should not shuffle anything. This is here for the special case
-        // where we have no columns, because of predicate push down, but we have a row count as
-        // metadata. We still want to shuffle it. The 100 is an arbitrary number and can be
-        // any non-zero number that is not too large.
+        // Ensure we set the partition size to the non-zero default when we have rows,
+        // so this degenerate batch is still shuffled.
         if (batch.numRows > 0) {
-          sizes(partId) += DEGENERATE_PARTITION_BYTE_SIZE_DEFAULT
         }
         handle
       }
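
For reference, here is a minimal, self-contained sketch of the sizing rule the diff introduces. The object name `DegenerateBatchSizing` and the helper `partitionSizeToRecord` are hypothetical, chosen for illustration only; the 100-byte default and the two conditions (zero bytes written, non-zero rows) are the parts taken from the change above.

```scala
// Hypothetical standalone sketch of the degenerate-batch sizing rule; not
// part of the spark-rapids plugin. Only the constant value and the
// zero-bytes/non-zero-rows conditions mirror the diff above.
object DegenerateBatchSizing {
  // Mirrors DEGENERATE_PARTITION_BYTE_SIZE_DEFAULT: an arbitrary non-zero,
  // not-too-large size so Spark still schedules a fetch for the partition.
  val DegeneratePartitionByteSizeDefault: Long = 100L

  // A batch with rows but no backing device bytes (e.g. a memberless Struct
  // column, or a rows-only batch produced by a .count) must report a
  // non-zero size, or the reducer side would never fetch its metadata.
  def partitionSizeToRecord(partSizeBytes: Long, numRows: Int): Long =
    if (partSizeBytes == 0 && numRows > 0) DegeneratePartitionByteSizeDefault
    else partSizeBytes

  def main(args: Array[String]): Unit = {
    assert(partitionSizeToRecord(0L, 42) == DegeneratePartitionByteSizeDefault) // degenerate: rows, no bytes
    assert(partitionSizeToRecord(0L, 0) == 0L)        // truly empty: nothing to shuffle
    assert(partitionSizeToRecord(4096L, 42) == 4096L) // normal batch: size passes through
    println("degenerate-batch sizing: all checks passed")
  }
}
```

Any small non-zero constant would work here: as the removed comment notes, the recorded size is only consulted to decide whether a partition has anything worth shuffling, and a modest value like 100 avoids noticeably skewing the reported shuffle sizes.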