[SPARK-50287][SQL] Merge options of table and relation when creating WriteBuilder in FileTable

### What changes were proposed in this pull request?

Merge `options` of table and relation when creating WriteBuilder in FileTable.

### Why are the changes needed?

Similar to SPARK-49519, which fixed the same issue on the read path.

### Does this PR introduce _any_ user-facing change?

FileTable's options are now taken into account on the V2 write path. Since the built-in file formats use the V1 write path by default, however, this has no practical effect for most users.
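
For readers who want to observe the new behaviour: it only applies when the built-in file sources go through their V2 implementations. Below is a minimal, hedged sketch of opting into that path from a Scala shell session; the config key is the one toggled by the updated test (`SQLConf.USE_V1_SOURCE_LIST`), while the output path and option value are illustrative only.

```scala
// Route the built-in file formats through DataSource V2 so that
// FileTable.newWriteBuilder -- and therefore the merged options -- is exercised.
// By default this list contains the built-in formats and the V1 write path is used.
spark.conf.set("spark.sql.sources.useV1SourceList", "")

// Illustrative V2 write: options supplied here are merged with the FileTable's
// own options when the WriteBuilder is created.
spark.range(10).write
  .format("parquet")
  .option("compression", "zstd")
  .save("/tmp/spark-50287-demo")
```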

### How was this patch tested?

The existing unit test is updated to cover this case.

### Was this patch authored or co-authored using generative AI tooling?

No.

Closes #48821 from pan3793/SPARK-50287.

Authored-by: Cheng Pan <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>
pan3793 authored and cloud-fan committed Nov 25, 2024
1 parent 976f887 commit 5dc65a5
Showing 9 changed files with 57 additions and 22 deletions.
AvroTable.scala
@@ -42,10 +42,12 @@ case class AvroTable(
   override def inferSchema(files: Seq[FileStatus]): Option[StructType] =
     AvroUtils.inferSchema(sparkSession, options.asScala.toMap, files)
 
-  override def newWriteBuilder(info: LogicalWriteInfo): WriteBuilder =
+  override def newWriteBuilder(info: LogicalWriteInfo): WriteBuilder = {
     new WriteBuilder {
-      override def build(): Write = AvroWrite(paths, formatName, supportsDataType, info)
+      override def build(): Write =
+        AvroWrite(paths, formatName, supportsDataType, mergedWriteInfo(info))
     }
+  }
 
   override def supportsDataType(dataType: DataType): Boolean = AvroUtils.supportsDataType(dataType)
FileTable.scala
@@ -26,6 +26,7 @@ import org.apache.spark.sql.SparkSession
 import org.apache.spark.sql.connector.catalog.{SupportsRead, SupportsWrite, Table, TableCapability}
 import org.apache.spark.sql.connector.catalog.TableCapability._
 import org.apache.spark.sql.connector.expressions.Transform
+import org.apache.spark.sql.connector.write.{LogicalWriteInfo, LogicalWriteInfoImpl}
 import org.apache.spark.sql.errors.QueryCompilationErrors
 import org.apache.spark.sql.execution.datasources._
 import org.apache.spark.sql.execution.streaming.{FileStreamSink, MetadataLogFileIndex}
@@ -159,6 +160,19 @@ abstract class FileTable(
       options.asCaseSensitiveMap().asScala
     new CaseInsensitiveStringMap(finalOptions.asJava)
   }
+
+  /**
+   * Merge the options of FileTable and the LogicalWriteInfo while respecting the
+   * keys of the options carried by LogicalWriteInfo.
+   */
+  protected def mergedWriteInfo(writeInfo: LogicalWriteInfo): LogicalWriteInfo = {
+    LogicalWriteInfoImpl(
+      writeInfo.queryId(),
+      writeInfo.schema(),
+      mergedOptions(writeInfo.options()),
+      writeInfo.rowIdSchema(),
+      writeInfo.metadataSchema())
+  }
 }
 
 object FileTable {
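
The new `mergedWriteInfo` delegates to the pre-existing `mergedOptions`, only the tail of which is visible in this hunk. Below is a self-contained sketch of the intended precedence, with the write's own options winning on key conflicts; `mergeOptions`, `tableOptions` and `writeOptions` are illustrative names, not identifiers from the patch.

```scala
import scala.jdk.CollectionConverters._

import org.apache.spark.sql.util.CaseInsensitiveStringMap

// Sketch of the merge semantics described in the mergedWriteInfo scaladoc:
// start from the table's options and overlay the options carried by the write,
// so any key supplied with the write (relation) takes precedence.
def mergeOptions(
    tableOptions: CaseInsensitiveStringMap,
    writeOptions: CaseInsensitiveStringMap): CaseInsensitiveStringMap = {
  val merged = tableOptions.asCaseSensitiveMap().asScala ++
    writeOptions.asCaseSensitiveMap().asScala
  new CaseInsensitiveStringMap(merged.asJava)
}

// tableOptions = {k1 -> v1, k2 -> v2}, writeOptions = {k2 -> table_v2, k3 -> v3}
// => merged    = {k1 -> v1, k2 -> table_v2, k3 -> v3}, matching the updated test.
```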
FileWrite.scala
@@ -49,7 +49,7 @@ trait FileWrite extends Write {
 
   private val schema = info.schema()
   private val queryId = info.queryId()
-  private val options = info.options()
+  val options = info.options()
 
   override def description(): String = formatName
 
CSVTable.scala
@@ -49,10 +49,12 @@ case class CSVTable(
     CSVDataSource(parsedOptions).inferSchema(sparkSession, files, parsedOptions)
   }
 
-  override def newWriteBuilder(info: LogicalWriteInfo): WriteBuilder =
+  override def newWriteBuilder(info: LogicalWriteInfo): WriteBuilder = {
     new WriteBuilder {
-      override def build(): Write = CSVWrite(paths, formatName, supportsDataType, info)
+      override def build(): Write =
+        CSVWrite(paths, formatName, supportsDataType, mergedWriteInfo(info))
     }
+  }
 
   override def supportsDataType(dataType: DataType): Boolean = dataType match {
     case _: AtomicType => true
JsonTable.scala
@@ -49,10 +49,12 @@ case class JsonTable(
       sparkSession, files, parsedOptions)
   }
 
-  override def newWriteBuilder(info: LogicalWriteInfo): WriteBuilder =
+  override def newWriteBuilder(info: LogicalWriteInfo): WriteBuilder = {
     new WriteBuilder {
-      override def build(): Write = JsonWrite(paths, formatName, supportsDataType, info)
+      override def build(): Write =
+        JsonWrite(paths, formatName, supportsDataType, mergedWriteInfo(info))
     }
+  }
 
   override def supportsDataType(dataType: DataType): Boolean = dataType match {
     case _: AtomicType => true
OrcTable.scala
@@ -43,10 +43,12 @@ case class OrcTable(
   override def inferSchema(files: Seq[FileStatus]): Option[StructType] =
     OrcUtils.inferSchema(sparkSession, files, options.asScala.toMap)
 
-  override def newWriteBuilder(info: LogicalWriteInfo): WriteBuilder =
+  override def newWriteBuilder(info: LogicalWriteInfo): WriteBuilder = {
     new WriteBuilder {
-      override def build(): Write = OrcWrite(paths, formatName, supportsDataType, info)
+      override def build(): Write =
+        OrcWrite(paths, formatName, supportsDataType, mergedWriteInfo(info))
     }
+  }
 
   override def supportsDataType(dataType: DataType): Boolean = dataType match {
     case _: AtomicType => true
ParquetTable.scala
@@ -43,10 +43,12 @@ case class ParquetTable(
   override def inferSchema(files: Seq[FileStatus]): Option[StructType] =
     ParquetUtils.inferSchema(sparkSession, options.asScala.toMap, files)
 
-  override def newWriteBuilder(info: LogicalWriteInfo): WriteBuilder =
+  override def newWriteBuilder(info: LogicalWriteInfo): WriteBuilder = {
     new WriteBuilder {
-      override def build(): Write = ParquetWrite(paths, formatName, supportsDataType, info)
+      override def build(): Write =
+        ParquetWrite(paths, formatName, supportsDataType, mergedWriteInfo(info))
     }
+  }
 
   override def supportsDataType(dataType: DataType): Boolean = dataType match {
     case _: AtomicType => true
TextTable.scala
@@ -39,10 +39,12 @@ case class TextTable(
   override def inferSchema(files: Seq[FileStatus]): Option[StructType] =
     Some(StructType(Array(StructField("value", StringType))))
 
-  override def newWriteBuilder(info: LogicalWriteInfo): WriteBuilder =
+  override def newWriteBuilder(info: LogicalWriteInfo): WriteBuilder = {
     new WriteBuilder {
-      override def build(): Write = TextWrite(paths, formatName, supportsDataType, info)
+      override def build(): Write =
+        TextWrite(paths, formatName, supportsDataType, mergedWriteInfo(info))
     }
+  }
 
   override def supportsDataType(dataType: DataType): Boolean = dataType == StringType
FileTableSuite.scala
@@ -22,7 +22,7 @@ import org.apache.hadoop.fs.FileStatus
 
 import org.apache.spark.sql.{QueryTest, SparkSession}
 import org.apache.spark.sql.connector.read.ScanBuilder
-import org.apache.spark.sql.connector.write.{LogicalWriteInfo, WriteBuilder}
+import org.apache.spark.sql.connector.write.{LogicalWriteInfo, LogicalWriteInfoImpl, WriteBuilder}
 import org.apache.spark.sql.execution.datasources.DataSource
 import org.apache.spark.sql.execution.datasources.FileFormat
 import org.apache.spark.sql.execution.datasources.text.TextFileFormat
@@ -96,8 +96,8 @@ class FileTableSuite extends QueryTest with SharedSparkSession {
   }
 
   allFileBasedDataSources.foreach { format =>
-    test(s"SPARK-49519: Merge options of table and relation when constructing FileScanBuilder" +
-      s" - $format") {
+    test("SPARK-49519, SPARK-50287: Merge options of table and relation when " +
+      s"constructing ScanBuilder and WriteBuilder in FileFormat - $format") {
       withSQLConf(SQLConf.USE_V1_SOURCE_LIST.key -> "") {
         val userSpecifiedSchema = StructType(Seq(StructField("c1", StringType)))
 
@@ -108,20 +108,29 @@ class FileTableSuite extends QueryTest with SharedSparkSession {
           val table = provider.getTable(
             userSpecifiedSchema,
             Array.empty,
-            dsOptions.asCaseSensitiveMap())
+            dsOptions.asCaseSensitiveMap()).asInstanceOf[FileTable]
           val tableOptions = new CaseInsensitiveStringMap(
             Map("k2" -> "table_v2", "k3" -> "v3").asJava)
-          val mergedOptions = table.asInstanceOf[FileTable].newScanBuilder(tableOptions) match {
+
+          val mergedReadOptions = table.newScanBuilder(tableOptions) match {
             case csv: CSVScanBuilder => csv.options
             case json: JsonScanBuilder => json.options
             case orc: OrcScanBuilder => orc.options
             case parquet: ParquetScanBuilder => parquet.options
             case text: TextScanBuilder => text.options
           }
-          assert(mergedOptions.size() == 3)
-          assert("v1".equals(mergedOptions.get("k1")))
-          assert("table_v2".equals(mergedOptions.get("k2")))
-          assert("v3".equals(mergedOptions.get("k3")))
+          assert(mergedReadOptions.size === 3)
+          assert(mergedReadOptions.get("k1") === "v1")
+          assert(mergedReadOptions.get("k2") === "table_v2")
+          assert(mergedReadOptions.get("k3") === "v3")
+
+          val writeInfo = LogicalWriteInfoImpl("query-id", userSpecifiedSchema, tableOptions)
+          val mergedWriteOptions = table.newWriteBuilder(writeInfo).build()
+            .asInstanceOf[FileWrite].options
+          assert(mergedWriteOptions.size === 3)
+          assert(mergedWriteOptions.get("k1") === "v1")
+          assert(mergedWriteOptions.get("k2") === "table_v2")
+          assert(mergedWriteOptions.get("k3") === "v3")
         case _ =>
           throw new IllegalArgumentException(s"Failed to get table provider for $format")
       }
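
As a condensed, hedged restatement of what the new assertions check at the connector API level (useful when reading the diff without the full test context): `demoTable`, the query id, and the option values below are illustrative, and the table is assumed to have been created with options `k1 -> v1, k2 -> v2`.

```scala
import scala.jdk.CollectionConverters._

import org.apache.spark.sql.connector.write.LogicalWriteInfoImpl
import org.apache.spark.sql.execution.datasources.v2.{FileTable, FileWrite}
import org.apache.spark.sql.types.{StringType, StructField, StructType}
import org.apache.spark.sql.util.CaseInsensitiveStringMap

val demoTable: FileTable = ???  // obtained from a TableProvider, as in the test
val schema = StructType(Seq(StructField("c1", StringType)))
val writeOptions = new CaseInsensitiveStringMap(
  Map("k2" -> "table_v2", "k3" -> "v3").asJava)

// With this patch, the Write built below sees the table's options merged with
// the per-write options, and the per-write keys win on conflicts.
val write = demoTable
  .newWriteBuilder(LogicalWriteInfoImpl("demo-query-id", schema, writeOptions))
  .build()
val merged = write.asInstanceOf[FileWrite].options
assert(merged.get("k1") == "v1")        // from the table's own options
assert(merged.get("k2") == "table_v2")  // per-write option overrides the table's
assert(merged.get("k3") == "v3")        // only present on the write
```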
