Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix shuffling an empty Struct() column with UCX #8801

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -329,7 +329,10 @@ class RapidsShuffleClient(
val ptrs = new ArrayBuffer[PendingTransferRequest](allTables)
(0 until allTables).foreach { i =>
val tableMeta = ShuffleMetadata.copyTableMetaToHeap(metaResponse.tableMetas(i))
if (tableMeta.bufferMeta() != null) {

// We check the uncompressedSize to make sure we don't request a 0-sized buffer
// from a peer. We treat such a corner case as a degenerate batch.
if (tableMeta.bufferMeta() != null && tableMeta.bufferMeta().uncompressedSize() > 0) {
ptrs += PendingTransferRequest(
this,
tableMeta,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -910,6 +910,12 @@ class RapidsCachingWriter[K, V](

private val uncompressedMetric: SQLMetric = metrics("dataSize")

// Placeholder partition size, in bytes, recorded for degenerate batches: batches
// with no columns (as in the .count case) or with columns that occupy 0 bytes.
// Such batches still have valid metadata that should be used on the reducer side
// for computation, so they need to be shuffled. The value 100 is an arbitrary
// number, picked so that these degenerate batches can be shuffled.
private val DEGENERATE_PARTITION_BYTE_SIZE_DEFAULT: Long = 100L

override def write(records: Iterator[Product2[K, V]]): Unit = {
// NOTE: This MUST NOT CLOSE the incoming batches because they are
// closed by the input iterator generated by GpuShuffleExchangeExec
Expand Down Expand Up @@ -956,7 +962,15 @@ class RapidsCachingWriter[K, V](
throw new IllegalStateException(s"Unexpected column type: ${c.getClass}")
}
bytesWritten += partSize
sizes(partId) += partSize
// If the size is 0 and we have rows, we are in a case where there are columns
// but the type is such that there isn't a buffer in the GPU backing it
// (for example, a Struct column without any members). We treat such a case as
// if it were a degenerate table.
if (partSize == 0 && batch.numRows() > 0) {
sizes(partId) += DEGENERATE_PARTITION_BYTE_SIZE_DEFAULT
} else {
sizes(partId) += partSize
}
handle
} else {
// no device data, tracking only metadata
Expand All @@ -966,13 +980,10 @@ class RapidsCachingWriter[K, V](
blockId,
tableMeta)

// The size of the data is really only used to tell if the data should be shuffled or not;
// a 0 indicates that we should not shuffle anything. This is here for the special case
// where we have no columns, because of predicate push down, but we have a row count as
// metadata. We still want to shuffle it. The 100 is an arbitrary number and can be
// any non-zero number that is not too large.
// Ensure that we set the partition size to the default in this case if
// we have non-zero rows, so this degenerate batch is shuffled.
if (batch.numRows > 0) {
sizes(partId) += 100
sizes(partId) += DEGENERATE_PARTITION_BYTE_SIZE_DEFAULT
}
handle
}
Expand Down