Fix race condition with Parquet filter pushdown modifying shared hadoop Configuration #11676

Merged 1 commit on Oct 31, 2024
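
There is no PR description in this capture, so for context: during Parquet filter pushdown the reader can write per-file settings into the Hadoop Configuration while other threads are filtering other files against the same object. The sketch below is illustrative only, not the plugin's code; the key name, class name, and pool size are made up, and it assumes hadoop-common is on the classpath. It contrasts the shared-mutation hazard with the per-invocation copy (new Configuration(conf)) that this change applies.

```scala
import java.util.concurrent.Executors

import org.apache.hadoop.conf.Configuration

object SharedConfRaceSketch {
  def main(args: Array[String]): Unit = {
    val pool = Executors.newFixedThreadPool(4)
    val sharedConf = new Configuration()

    // Racy pattern (before the fix): every file's filter step mutates the one
    // shared Configuration, so a thread can observe a value written for a
    // different file. "example.per.file.setting" is a hypothetical key.
    def filterRacy(file: String): Unit = {
      sharedConf.set("example.per.file.setting", file)
      Thread.sleep(1) // widen the race window for the demo
      val seen = sharedConf.get("example.per.file.setting")
      if (seen != file) println(s"race: $file saw the setting meant for $seen")
    }

    // Fixed pattern (what this change does): each invocation gets its own copy,
    // so mutations stay local to the thread doing the filtering.
    def filterSafe(file: String): Unit = {
      val localConf = new Configuration(sharedConf)
      localConf.set("example.per.file.setting", file)
      assert(localConf.get("example.per.file.setting") == file)
    }

    val files = (1 to 16).map(i => s"file$i")
    files.foreach(f => pool.submit(new Runnable { def run(): Unit = filterRacy(f) }))
    files.foreach(f => pool.submit(new Runnable { def run(): Unit = filterSafe(f) }))
    pool.shutdown()
  }
}
```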
34 changes: 33 additions & 1 deletion integration_tests/src/main/python/parquet_test.py
@@ -18,7 +18,7 @@
from asserts import *
from conftest import is_not_utc
from data_gen import *
from parquet_write_test import parquet_nested_datetime_gen, parquet_ts_write_options
from parquet_write_test import parquet_datetime_gen_simple, parquet_nested_datetime_gen, parquet_ts_write_options
from marks import *
import pyarrow as pa
import pyarrow.parquet as pa_pq
@@ -361,6 +361,38 @@ def test_parquet_read_roundtrip_datetime_with_legacy_rebase(spark_tmp_path, parq
lambda spark: spark.read.parquet(data_path),
conf=read_confs)


@pytest.mark.skipif(is_not_utc(), reason="LEGACY datetime rebase mode is only supported for UTC timezone")
@pytest.mark.parametrize('parquet_gens', [parquet_datetime_gen_simple], ids=idfn)
@pytest.mark.parametrize('reader_confs', reader_opt_confs)
@pytest.mark.parametrize('v1_enabled_list', ["", "parquet"])
def test_parquet_read_roundtrip_datetime_with_legacy_rebase_mismatch_files(spark_tmp_path, parquet_gens,
reader_confs, v1_enabled_list):
gen_list = [('_c' + str(i), gen) for i, gen in enumerate(parquet_gens)]
data_path = spark_tmp_path + '/PARQUET_DATA'
data_path2 = spark_tmp_path + '/PARQUET_DATA2'
write_confs = {'spark.sql.parquet.datetimeRebaseModeInWrite': 'LEGACY',
'spark.sql.parquet.int96RebaseModeInWrite': 'LEGACY'}
with_cpu_session(
lambda spark: gen_df(spark, gen_list).write.parquet(data_path),
conf=write_confs)
# we want to test having multiple files that have the same column with different
# types - INT96 and INT64 (TIMESTAMP_MICROS)
write_confs2 = {'spark.sql.parquet.datetimeRebaseModeInWrite': 'CORRECTED',
'spark.sql.parquet.int96RebaseModeInWrite': 'CORRECTED',
'spark.sql.parquet.outputTimestampType': 'TIMESTAMP_MICROS'}
with_cpu_session(
lambda spark: gen_df(spark, gen_list).write.parquet(data_path2),
conf=write_confs2)

read_confs = copy_and_update(reader_confs,
{'spark.sql.sources.useV1SourceList': v1_enabled_list,
'spark.sql.parquet.datetimeRebaseModeInRead': 'LEGACY',
'spark.sql.parquet.int96RebaseModeInRead': 'LEGACY'})
assert_gpu_and_cpu_are_equal_collect(
lambda spark: spark.read.parquet(data_path, data_path2).filter("_c0 is not null and _c1 is not null"),
conf=read_confs)

# This is legacy format, which is totally different from datetime legacy rebase mode.
@pytest.mark.parametrize('parquet_gens', [[byte_gen, short_gen, decimal_gen_32bit], decimal_gens,
[ArrayGen(decimal_gen_32bit, max_length=10)],
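
The new test above drives the race by reading two datasets whose files need different per-file handling: one written with LEGACY rebase and INT96 timestamps, the other with CORRECTED rebase and INT64 TIMESTAMP_MICROS. As a rough standalone illustration of that read scenario outside the integration-test harness (the paths, single column, and local-mode session are placeholders; this is not the test itself):

```scala
import java.sql.Timestamp

import org.apache.spark.sql.SparkSession

object MixedTimestampParquetRead {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("mixed-parquet").getOrCreate()
    import spark.implicits._

    val df = Seq(Timestamp.valueOf("1500-01-01 00:00:00")).toDF("_c0")

    // First directory: legacy rebase, INT96 timestamps (Spark's default output type).
    spark.conf.set("spark.sql.parquet.datetimeRebaseModeInWrite", "LEGACY")
    spark.conf.set("spark.sql.parquet.int96RebaseModeInWrite", "LEGACY")
    df.write.mode("overwrite").parquet("/tmp/PARQUET_DATA")

    // Second directory: corrected rebase, INT64 TIMESTAMP_MICROS.
    spark.conf.set("spark.sql.parquet.datetimeRebaseModeInWrite", "CORRECTED")
    spark.conf.set("spark.sql.parquet.int96RebaseModeInWrite", "CORRECTED")
    spark.conf.set("spark.sql.parquet.outputTimestampType", "TIMESTAMP_MICROS")
    df.write.mode("overwrite").parquet("/tmp/PARQUET_DATA2")

    // Reading both directories together with a filter forces pushdown to treat
    // each file differently, which is the situation that exposed the shared-conf race.
    spark.conf.set("spark.sql.parquet.datetimeRebaseModeInRead", "LEGACY")
    spark.conf.set("spark.sql.parquet.int96RebaseModeInRead", "LEGACY")
    spark.read.parquet("/tmp/PARQUET_DATA", "/tmp/PARQUET_DATA2")
      .filter("_c0 is not null")
      .show()

    spark.stop()
  }
}
```

The diff hunks that follow are in the Scala reader factories (GpuParquetMultiFilePartitionReaderFactory and GpuParquetPartitionReaderFactory), which apply the Configuration copy in the multi-file and single-file read paths.
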
@@ -1141,7 +1141,9 @@ case class GpuParquetMultiFilePartitionReaderFactory(
files: Array[PartitionedFile],
conf: Configuration): PartitionReader[ColumnarBatch] = {
val filterFunc = (file: PartitionedFile) => {
filterHandler.filterBlocks(footerReadType, file, conf,
// we need to copy the Hadoop Configuration because filter push down can mutate it,
// which can affect other threads.
filterHandler.filterBlocks(footerReadType, file, new Configuration(conf),
filters, readDataSchema)
}
val combineConf = CombineConf(combineThresholdSize, combineWaitTime)
@@ -1234,12 +1236,20 @@ case class GpuParquetMultiFilePartitionReaderFactory(
val tc = TaskContext.get()
val threadPool = MultiFileReaderThreadPool.getOrCreateThreadPool(numThreads)
files.grouped(numFilesFilterParallel).map { fileGroup =>
// we need to copy the Hadoop Configuration because filter push down can mutate it,
// which can affect other threads.
threadPool.submit(
new CoalescingFilterRunner(footerReadType, tc, fileGroup, conf, filters, readDataSchema))
new CoalescingFilterRunner(footerReadType, tc, fileGroup, new Configuration(conf),
filters, readDataSchema))
}.toArray.flatMap(_.get())
} else {
// We need to copy the Hadoop Configuration because filter push down can mutate it. In
// this case we are serially iterating through the files so each one mutating it serially
// doesn't affect the filter of the other files. We just need to make sure it's copied
// once so other tasks don't modify the same conf.
val hadoopConf = new Configuration(conf)
files.map { file =>
filterBlocksForCoalescingReader(footerReadType, file, conf, filters, readDataSchema)
filterBlocksForCoalescingReader(footerReadType, file, hadoopConf, filters, readDataSchema)
}
}
metaAndFilesArr.foreach { metaAndFile =>
@@ -1326,7 +1336,9 @@ case class GpuParquetPartitionReaderFactory(

private def buildBaseColumnarParquetReader(
file: PartitionedFile): PartitionReader[ColumnarBatch] = {
val conf = broadcastedConf.value.value
// we need to copy the Hadoop Configuration because filter push down can mutate it,
// which can affect other tasks.
val conf = new Configuration(broadcastedConf.value.value)
val startTime = System.nanoTime()
val singleFileInfo = filterHandler.filterBlocks(footerReadType, file, conf, filters,
readDataSchema)