[SPARK-50286][SQL] Correctly propagate SQL options to WriteBuilder
### What changes were proposed in this pull request?

SPARK-49098 introduced a SQL syntax that allows users to set table options for DSv2 writes, but unfortunately, the options set via SQL are not propagated correctly to the underlying DSv2 `WriteBuilder`.

```
INSERT INTO $t1 WITH (`write.split-size` = 10) SELECT ...
```

```
df.writeTo(t1).option("write.split-size", "10").append()
```

From the user's perspective, the two statements above are equivalent, but their internal implementations differ slightly. Both of them construct an

```
AppendData(r: DataSourceV2Relation, ..., writeOptions, ...)
```

but the SQL `options` are carried by `r.options`, while the `DataFrame` API `options` are carried by `writeOptions`. Currently, only the latter is propagated to the `WriteBuilder`; the former is silently dropped. This PR fixes the issue by merging the two sets of `options`.
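
In effect, the fix merges the two option sources before building the write. A condensed sketch of the helper added to `V2Writes` (the full version appears in the diff below):

```
// Condensed sketch of the merge: the two sources are either identical
// (DataFrame API), or at most one of them is non-empty (DataFrameV2 API
// and SQL), so a plain union is safe.
def mergeOptions(
    commandOptions: Map[String, String],  // options carried by the write command
    dsOptions: Map[String, String]        // options carried by DataSourceV2Relation
  ): Map[String, String] = {
  assert(commandOptions == dsOptions || commandOptions.isEmpty || dsOptions.isEmpty)
  commandOptions ++ dsOptions
}
```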

Currently, `options` propagation is inconsistent across the `DataFrame` API, the `DataFrameV2` API, and SQL:
- DataFrame API: the same `options` are carried by both `writeOptions` and the `DataSourceV2Relation`
- DataFrameV2 API: `options` are only carried by `writeOptions`
- SQL: `options` are only carried by the `DataSourceV2Relation`

Note that `SessionConfigSupport` only takes effect for the `DataFrame` and `DataFrameV2` APIs; it is not considered at all in the SQL read/write path in the current codebase.
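
For context, `SessionConfigSupport` lets a connector pick up session configurations named `spark.datasource.<keyPrefix>.<key>` as read/write options. A rough, simplified sketch of that extraction (not the exact Spark code; `keyPrefix` comes from `SessionConfigSupport.keyPrefix()`):

```
// Rough sketch of SessionConfigSupport-style option extraction: session confs
// named spark.datasource.<keyPrefix>.<key> become source options. Today this
// only happens on the DataFrame/DataFrameV2 paths, not on the SQL path.
def extractSessionConfigs(
    keyPrefix: String,
    sessionConf: Map[String, String]): Map[String, String] = {
  val prefix = s"spark.datasource.$keyPrefix."
  sessionConf.collect {
    case (key, value) if key.startsWith(prefix) => key.stripPrefix(prefix) -> value
  }
}
```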

### Why are the changes needed?

Correctly propagate SQL options to `WriteBuilder`, to complete the feature added in SPARK-49098, so that DSv2 implementations like Iceberg can benefit.

### Does this PR introduce _any_ user-facing change?

No, it's an unreleased feature.

### How was this patch tested?

The UTs added in SPARK-36680 and SPARK-49098 are also updated to check that SQL `options` are correctly propagated to the physical plan.
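
For illustration, a check along these lines can be written with the `withQueryExecutionsCaptured` helper introduced in `QueryTest` (see the diff below); the catalog, table, source, and assertion here are hypothetical, not the actual test code:

```
// Hypothetical sketch (test-side code, assuming an ambient SparkSession `spark`):
// capture query executions for a SQL insert that sets a WITH (...) option, then
// confirm the option appears in the executed physical plan.
import org.apache.spark.sql.QueryTest

val qes = QueryTest.withQueryExecutionsCaptured(spark) {
  spark.sql("INSERT INTO testcat.t WITH (`write.split-size` = 10) SELECT * FROM source")
}
assert(qes.exists(_.executedPlan.toString.contains("write.split-size")))
```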

### Was this patch authored or co-authored using generative AI tooling?

No.

Closes #48822 from pan3793/SPARK-50286.

Authored-by: Cheng Pan <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>
pan3793 authored and cloud-fan committed Nov 25, 2024
1 parent da4bcb7 commit 976f887
Showing 9 changed files with 406 additions and 146 deletions.
@@ -295,7 +295,7 @@ abstract class InMemoryBaseTable(
TableCapability.TRUNCATE)

override def newScanBuilder(options: CaseInsensitiveStringMap): ScanBuilder = {
new InMemoryScanBuilder(schema)
new InMemoryScanBuilder(schema, options)
}

private def canEvaluate(filter: Filter): Boolean = {
@@ -309,16 +309,18 @@
}
}

class InMemoryScanBuilder(tableSchema: StructType) extends ScanBuilder
with SupportsPushDownRequiredColumns with SupportsPushDownFilters {
class InMemoryScanBuilder(
tableSchema: StructType,
options: CaseInsensitiveStringMap) extends ScanBuilder
with SupportsPushDownRequiredColumns with SupportsPushDownFilters {
private var schema: StructType = tableSchema
private var postScanFilters: Array[Filter] = Array.empty
private var evaluableFilters: Array[Filter] = Array.empty
private var _pushedFilters: Array[Filter] = Array.empty

override def build: Scan = {
val scan = InMemoryBatchScan(
data.map(_.asInstanceOf[InputPartition]).toImmutableArraySeq, schema, tableSchema)
data.map(_.asInstanceOf[InputPartition]).toImmutableArraySeq, schema, tableSchema, options)
if (evaluableFilters.nonEmpty) {
scan.filter(evaluableFilters)
}
@@ -442,7 +444,8 @@ abstract class InMemoryBaseTable(
case class InMemoryBatchScan(
var _data: Seq[InputPartition],
readSchema: StructType,
tableSchema: StructType)
tableSchema: StructType,
options: CaseInsensitiveStringMap)
extends BatchScanBaseClass(_data, readSchema, tableSchema) with SupportsRuntimeFiltering {

override def filterAttributes(): Array[NamedReference] = {
@@ -474,17 +477,17 @@
}
}

abstract class InMemoryWriterBuilder() extends SupportsTruncate with SupportsDynamicOverwrite
with SupportsStreamingUpdateAsAppend {
abstract class InMemoryWriterBuilder(val info: LogicalWriteInfo)
extends SupportsTruncate with SupportsDynamicOverwrite with SupportsStreamingUpdateAsAppend {

protected var writer: BatchWrite = Append
protected var streamingWriter: StreamingWrite = StreamingAppend
protected var writer: BatchWrite = new Append(info)
protected var streamingWriter: StreamingWrite = new StreamingAppend(info)

override def overwriteDynamicPartitions(): WriteBuilder = {
if (writer != Append) {
if (!writer.isInstanceOf[Append]) {
throw new IllegalArgumentException(s"Unsupported writer type: $writer")
}
writer = DynamicOverwrite
writer = new DynamicOverwrite(info)
streamingWriter = new StreamingNotSupportedOperation("overwriteDynamicPartitions")
this
}
@@ -529,21 +532,21 @@ abstract class InMemoryBaseTable(
override def abort(messages: Array[WriterCommitMessage]): Unit = {}
}

protected object Append extends TestBatchWrite {
class Append(val info: LogicalWriteInfo) extends TestBatchWrite {
override def commit(messages: Array[WriterCommitMessage]): Unit = dataMap.synchronized {
withData(messages.map(_.asInstanceOf[BufferedRows]))
}
}

private object DynamicOverwrite extends TestBatchWrite {
class DynamicOverwrite(val info: LogicalWriteInfo) extends TestBatchWrite {
override def commit(messages: Array[WriterCommitMessage]): Unit = dataMap.synchronized {
val newData = messages.map(_.asInstanceOf[BufferedRows])
dataMap --= newData.flatMap(_.rows.map(getKey))
withData(newData)
}
}

protected object TruncateAndAppend extends TestBatchWrite {
class TruncateAndAppend(val info: LogicalWriteInfo) extends TestBatchWrite {
override def commit(messages: Array[WriterCommitMessage]): Unit = dataMap.synchronized {
dataMap.clear()
withData(messages.map(_.asInstanceOf[BufferedRows]))
@@ -572,15 +575,15 @@ abstract class InMemoryBaseTable(
s"${operation} isn't supported for streaming query.")
}

private object StreamingAppend extends TestStreamingWrite {
class StreamingAppend(val info: LogicalWriteInfo) extends TestStreamingWrite {
override def commit(epochId: Long, messages: Array[WriterCommitMessage]): Unit = {
dataMap.synchronized {
withData(messages.map(_.asInstanceOf[BufferedRows]))
}
}
}

protected object StreamingTruncateAndAppend extends TestStreamingWrite {
class StreamingTruncateAndAppend(val info: LogicalWriteInfo) extends TestStreamingWrite {
override def commit(epochId: Long, messages: Array[WriterCommitMessage]): Unit = {
dataMap.synchronized {
dataMap.clear()
@@ -59,7 +59,7 @@ class InMemoryRowLevelOperationTable(
}

override def newScanBuilder(options: CaseInsensitiveStringMap): ScanBuilder = {
new InMemoryScanBuilder(schema) {
new InMemoryScanBuilder(schema, options) {
override def build: Scan = {
val scan = super.build()
configuredScan = scan.asInstanceOf[InMemoryBatchScan]
@@ -115,7 +115,7 @@ class InMemoryRowLevelOperationTable(
override def rowId(): Array[NamedReference] = Array(PK_COLUMN_REF)

override def newScanBuilder(options: CaseInsensitiveStringMap): ScanBuilder = {
new InMemoryScanBuilder(schema)
new InMemoryScanBuilder(schema, options)
}

override def newWriteBuilder(info: LogicalWriteInfo): DeltaWriteBuilder =
@@ -84,23 +84,23 @@ class InMemoryTable(
InMemoryBaseTable.maybeSimulateFailedTableWrite(new CaseInsensitiveStringMap(properties))
InMemoryBaseTable.maybeSimulateFailedTableWrite(info.options)

new InMemoryWriterBuilderWithOverWrite()
new InMemoryWriterBuilderWithOverWrite(info)
}

private class InMemoryWriterBuilderWithOverWrite() extends InMemoryWriterBuilder
with SupportsOverwrite {
class InMemoryWriterBuilderWithOverWrite(override val info: LogicalWriteInfo)
extends InMemoryWriterBuilder(info) with SupportsOverwrite {

override def truncate(): WriteBuilder = {
if (writer != Append) {
if (!writer.isInstanceOf[Append]) {
throw new IllegalArgumentException(s"Unsupported writer type: $writer")
}
writer = TruncateAndAppend
streamingWriter = StreamingTruncateAndAppend
writer = new TruncateAndAppend(info)
streamingWriter = new StreamingTruncateAndAppend(info)
this
}

override def overwrite(filters: Array[Filter]): WriteBuilder = {
if (writer != Append) {
if (!writer.isInstanceOf[Append]) {
throw new IllegalArgumentException(s"Unsupported writer type: $writer")
}
writer = new Overwrite(filters)
@@ -47,19 +47,22 @@ class InMemoryTableWithV2Filter(
}

override def newScanBuilder(options: CaseInsensitiveStringMap): ScanBuilder = {
new InMemoryV2FilterScanBuilder(schema)
new InMemoryV2FilterScanBuilder(schema, options)
}

class InMemoryV2FilterScanBuilder(tableSchema: StructType)
extends InMemoryScanBuilder(tableSchema) {
class InMemoryV2FilterScanBuilder(
tableSchema: StructType,
options: CaseInsensitiveStringMap)
extends InMemoryScanBuilder(tableSchema, options) {
override def build: Scan = InMemoryV2FilterBatchScan(
data.map(_.asInstanceOf[InputPartition]).toImmutableArraySeq, schema, tableSchema)
data.map(_.asInstanceOf[InputPartition]).toImmutableArraySeq, schema, tableSchema, options)
}

case class InMemoryV2FilterBatchScan(
var _data: Seq[InputPartition],
readSchema: StructType,
tableSchema: StructType)
tableSchema: StructType,
options: CaseInsensitiveStringMap)
extends BatchScanBaseClass(_data, readSchema, tableSchema) with SupportsRuntimeV2Filtering {

override def filterAttributes(): Array[NamedReference] = {
@@ -93,21 +96,21 @@ class InMemoryTableWithV2Filter(
InMemoryBaseTable.maybeSimulateFailedTableWrite(new CaseInsensitiveStringMap(properties))
InMemoryBaseTable.maybeSimulateFailedTableWrite(info.options)

new InMemoryWriterBuilderWithOverWrite()
new InMemoryWriterBuilderWithOverWrite(info)
}

private class InMemoryWriterBuilderWithOverWrite() extends InMemoryWriterBuilder
with SupportsOverwriteV2 {
class InMemoryWriterBuilderWithOverWrite(override val info: LogicalWriteInfo)
extends InMemoryWriterBuilder(info) with SupportsOverwriteV2 {

override def truncate(): WriteBuilder = {
assert(writer == Append)
writer = TruncateAndAppend
streamingWriter = StreamingTruncateAndAppend
assert(writer.isInstanceOf[Append])
writer = new TruncateAndAppend(info)
streamingWriter = new StreamingTruncateAndAppend(info)
this
}

override def overwrite(predicates: Array[Predicate]): WriteBuilder = {
assert(writer == Append)
assert(writer.isInstanceOf[Append])
writer = new Overwrite(predicates)
streamingWriter = new StreamingNotSupportedOperation(
s"overwrite (${predicates.mkString("filters(", ", ", ")")})")
@@ -19,6 +19,8 @@ package org.apache.spark.sql.execution.datasources.v2

import java.util.{Optional, UUID}

import scala.jdk.CollectionConverters._

import org.apache.spark.sql.catalyst.expressions.PredicateHelper
import org.apache.spark.sql.catalyst.plans.logical.{AppendData, LogicalPlan, OverwriteByExpression, OverwritePartitionsDynamic, Project, ReplaceData, WriteDelta}
import org.apache.spark.sql.catalyst.rules.Rule
@@ -44,7 +46,8 @@ object V2Writes extends Rule[LogicalPlan] with PredicateHelper {

override def apply(plan: LogicalPlan): LogicalPlan = plan transformDown {
case a @ AppendData(r: DataSourceV2Relation, query, options, _, None, _) =>
val writeBuilder = newWriteBuilder(r.table, options, query.schema)
val writeOptions = mergeOptions(options, r.options.asScala.toMap)
val writeBuilder = newWriteBuilder(r.table, writeOptions, query.schema)
val write = writeBuilder.build()
val newQuery = DistributionAndOrderingUtils.prepareQuery(write, query, r.funCatalog)
a.copy(write = Some(write), query = newQuery)
@@ -61,7 +64,8 @@
}.toArray

val table = r.table
val writeBuilder = newWriteBuilder(table, options, query.schema)
val writeOptions = mergeOptions(options, r.options.asScala.toMap)
val writeBuilder = newWriteBuilder(table, writeOptions, query.schema)
val write = writeBuilder match {
case builder: SupportsTruncate if isTruncate(predicates) =>
builder.truncate().build()
@@ -76,7 +80,8 @@

case o @ OverwritePartitionsDynamic(r: DataSourceV2Relation, query, options, _, None) =>
val table = r.table
val writeBuilder = newWriteBuilder(table, options, query.schema)
val writeOptions = mergeOptions(options, r.options.asScala.toMap)
val writeBuilder = newWriteBuilder(table, writeOptions, query.schema)
val write = writeBuilder match {
case builder: SupportsDynamicOverwrite =>
builder.overwriteDynamicPartitions().build()
@@ -87,31 +92,44 @@
o.copy(write = Some(write), query = newQuery)

case WriteToMicroBatchDataSource(
relation, table, query, queryId, writeOptions, outputMode, Some(batchId)) =>

relationOpt, table, query, queryId, options, outputMode, Some(batchId)) =>
val writeOptions = mergeOptions(
options, relationOpt.map(r => r.options.asScala.toMap).getOrElse(Map.empty))
val writeBuilder = newWriteBuilder(table, writeOptions, query.schema, queryId)
val write = buildWriteForMicroBatch(table, writeBuilder, outputMode)
val microBatchWrite = new MicroBatchWrite(batchId, write.toStreaming)
val customMetrics = write.supportedCustomMetrics.toImmutableArraySeq
val funCatalogOpt = relation.flatMap(_.funCatalog)
val funCatalogOpt = relationOpt.flatMap(_.funCatalog)
val newQuery = DistributionAndOrderingUtils.prepareQuery(write, query, funCatalogOpt)
WriteToDataSourceV2(relation, microBatchWrite, newQuery, customMetrics)
WriteToDataSourceV2(relationOpt, microBatchWrite, newQuery, customMetrics)

case rd @ ReplaceData(r: DataSourceV2Relation, _, query, _, _, None) =>
val rowSchema = DataTypeUtils.fromAttributes(rd.dataInput)
val writeBuilder = newWriteBuilder(r.table, Map.empty, rowSchema)
val writeOptions = mergeOptions(Map.empty, r.options.asScala.toMap)
val writeBuilder = newWriteBuilder(r.table, writeOptions, rowSchema)
val write = writeBuilder.build()
val newQuery = DistributionAndOrderingUtils.prepareQuery(write, query, r.funCatalog)
// project away any metadata columns that could be used for distribution and ordering
rd.copy(write = Some(write), query = Project(rd.dataInput, newQuery))

case wd @ WriteDelta(r: DataSourceV2Relation, _, query, _, projections, None) =>
val deltaWriteBuilder = newDeltaWriteBuilder(r.table, Map.empty, projections)
val writeOptions = mergeOptions(Map.empty, r.options.asScala.toMap)
val deltaWriteBuilder = newDeltaWriteBuilder(r.table, writeOptions, projections)
val deltaWrite = deltaWriteBuilder.build()
val newQuery = DistributionAndOrderingUtils.prepareQuery(deltaWrite, query, r.funCatalog)
wd.copy(write = Some(deltaWrite), query = newQuery)
}

private def mergeOptions(
commandOptions: Map[String, String],
dsOptions: Map[String, String]): Map[String, String] = {
// for DataFrame API cases, same options are carried by both Command and DataSourceV2Relation
// for DataFrameV2 API cases, options are only carried by Command
// for SQL cases, options are only carried by DataSourceV2Relation
assert(commandOptions == dsOptions || commandOptions.isEmpty || dsOptions.isEmpty)
commandOptions ++ dsOptions
}

private def buildWriteForMicroBatch(
table: SupportsWrite,
writeBuilder: WriteBuilder,
10 changes: 5 additions & 5 deletions sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala
@@ -27,7 +27,7 @@ import org.scalatest.Assertions
import org.apache.spark.sql.catalyst.ExtendedAnalysisException
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.util._
import org.apache.spark.sql.execution.{QueryExecution, SparkPlan, SQLExecution}
import org.apache.spark.sql.execution.{QueryExecution, SQLExecution}
import org.apache.spark.sql.execution.columnar.InMemoryRelation
import org.apache.spark.sql.util.QueryExecutionListener
import org.apache.spark.storage.StorageLevel
@@ -449,12 +449,12 @@ object QueryTest extends Assertions {
}
}

def withPhysicalPlansCaptured(spark: SparkSession, thunk: => Unit): Seq[SparkPlan] = {
var capturedPlans = Seq.empty[SparkPlan]
def withQueryExecutionsCaptured(spark: SparkSession)(thunk: => Unit): Seq[QueryExecution] = {
var capturedQueryExecutions = Seq.empty[QueryExecution]

val listener = new QueryExecutionListener {
override def onSuccess(funcName: String, qe: QueryExecution, durationNs: Long): Unit = {
capturedPlans = capturedPlans :+ qe.executedPlan
capturedQueryExecutions = capturedQueryExecutions :+ qe
}
override def onFailure(funcName: String, qe: QueryExecution, exception: Exception): Unit = {}
}
@@ -468,7 +468,7 @@
spark.listenerManager.unregister(listener)
}

capturedPlans
capturedQueryExecutions
}
}
