From afb8d301f321f5daab75927f65f994dc64f4a95d Mon Sep 17 00:00:00 2001 From: Ole Sasse Date: Tue, 12 Nov 2024 22:20:47 +0100 Subject: [PATCH 01/17] [SPARK-50285] Metrics for commits to StagedTable instances --- .../catalog/StagedTableWithCommitMetrics.java | 38 +++++ .../catalog/StagingInMemoryTableCatalog.scala | 2 +- .../datasources/v2/ReplaceTableExec.scala | 23 ++- .../v2/WriteToDataSourceV2Exec.scala | 19 ++- .../connector/DataSourceV2MetricsSuite.scala | 133 ++++++++++++++++++ 5 files changed, 210 insertions(+), 5 deletions(-) create mode 100644 sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/StagedTableWithCommitMetrics.java create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2MetricsSuite.scala diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/StagedTableWithCommitMetrics.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/StagedTableWithCommitMetrics.java new file mode 100644 index 0000000000000..230312644e011 --- /dev/null +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/StagedTableWithCommitMetrics.java @@ -0,0 +1,38 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.connector.catalog; + +import java.util.Map; + +import org.apache.spark.annotation.Evolving; +import org.apache.spark.sql.types.DataType; + +/** + * An extension of the {@link StagedTable} interface that provides metrics after a commit. + */ +@Evolving +public interface StagedTableWithCommitMetrics extends StagedTable { + + /** + * Returns a map of commit metric values after a successful commit. Throws otherwise. + * + * @return a {@link Map} of commit metric values. The keys are the commit names, the values + * are the metrics values of type Long. 
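+ *
+ * <p>Illustrative usage only ({@code stagedTable} is a hypothetical variable name): after a
+ * successful {@code commitStagedChanges()}, a caller could look up a single metric with
+ * <pre>{@code
+ *   Long numOutputRows = stagedTable.getCommitMetrics().get("numOutputRows");
+ * }</pre>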
+ */ + Map getCommitMetrics() throws AssertionError; +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/StagingInMemoryTableCatalog.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/StagingInMemoryTableCatalog.scala index f3c7bc98cec09..2a207901b83f5 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/StagingInMemoryTableCatalog.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/StagingInMemoryTableCatalog.scala @@ -78,7 +78,7 @@ class StagingInMemoryTableCatalog extends InMemoryTableCatalog with StagingTable maybeSimulateFailedTableCreation(properties) } - private abstract class TestStagedTable( + protected abstract class TestStagedTable( ident: Identifier, delegateTable: InMemoryTable) extends StagedTable with SupportsWrite with SupportsRead { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/ReplaceTableExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/ReplaceTableExec.scala index 104d8a706efb7..f2150cfca6394 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/ReplaceTableExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/ReplaceTableExec.scala @@ -23,9 +23,11 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.NoSuchTableException import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.catalyst.plans.logical.TableSpec -import org.apache.spark.sql.connector.catalog.{CatalogV2Util, Column, Identifier, StagedTable, StagingTableCatalog, Table, TableCatalog} +import org.apache.spark.sql.connector.catalog.{CatalogV2Util, Column, Identifier, StagedTable, StagedTableWithCommitMetrics, StagingTableCatalog, Table, TableCatalog} import org.apache.spark.sql.connector.expressions.Transform import org.apache.spark.sql.errors.QueryCompilationErrors +import org.apache.spark.sql.execution.SQLExecution +import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics} import org.apache.spark.util.Utils case class ReplaceTableExec( @@ -65,6 +67,11 @@ case class AtomicReplaceTableExec( val tableProperties = CatalogV2Util.convertTableProperties(tableSpec) + override val metrics: Map[String, SQLMetric] = Map( + "numFiles" -> SQLMetrics.createMetric(sparkContext, "number of written files"), + "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"), + "numOutputBytes" -> SQLMetrics.createMetric(sparkContext, "written output")) + override protected def run(): Seq[InternalRow] = { if (catalog.tableExists(identifier)) { val table = catalog.loadTable(identifier) @@ -92,7 +99,19 @@ case class AtomicReplaceTableExec( private def commitOrAbortStagedChanges(staged: StagedTable): Unit = { Utils.tryWithSafeFinallyAndFailureCallbacks({ - staged.commitStagedChanges() + staged match { + case st: StagedTableWithCommitMetrics => + st.commitStagedChanges() + + st.getCommitMetrics.forEach { + case (name: String, value: java.lang.Long) => + metrics.get(name).foreach(_.set(value)) + } + + val executionId = sparkContext.getLocalProperty(SQLExecution.EXECUTION_ID_KEY) + SQLMetrics.postDriverMetricUpdates(sparkContext, executionId, metrics.values.toSeq) + case st: StagedTable => st.commitStagedChanges() + } })(catchBlock = { staged.abortStagedChanges() }) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala index 5885ec0afadcd..756ecf7095d07 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala @@ -28,12 +28,12 @@ import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.catalyst.plans.logical.{AppendData, LogicalPlan, TableSpec, UnaryNode} import org.apache.spark.sql.catalyst.util.{removeInternalMetadata, CharVarcharUtils, WriteDeltaProjections} import org.apache.spark.sql.catalyst.util.RowDeltaUtils.{DELETE_OPERATION, INSERT_OPERATION, UPDATE_OPERATION} -import org.apache.spark.sql.connector.catalog.{CatalogV2Util, Column, Identifier, StagedTable, StagingTableCatalog, Table, TableCatalog, TableWritePrivilege} +import org.apache.spark.sql.connector.catalog.{CatalogV2Util, Column, Identifier, StagedTable, StagedTableWithCommitMetrics, StagingTableCatalog, Table, TableCatalog, TableWritePrivilege} import org.apache.spark.sql.connector.expressions.Transform import org.apache.spark.sql.connector.metric.CustomMetric import org.apache.spark.sql.connector.write.{BatchWrite, DataWriter, DataWriterFactory, DeltaWrite, DeltaWriter, PhysicalWriteInfoImpl, Write, WriterCommitMessage} import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors} -import org.apache.spark.sql.execution.{SparkPlan, UnaryExecNode} +import org.apache.spark.sql.execution.{SparkPlan, SQLExecution, UnaryExecNode} import org.apache.spark.sql.execution.metric.{CustomMetrics, SQLMetric, SQLMetrics} import org.apache.spark.sql.types.StructType import org.apache.spark.util.{LongAccumulator, Utils} @@ -599,6 +599,11 @@ case class DeltaWithMetadataWritingSparkTask( private[v2] trait V2CreateTableAsSelectBaseExec extends LeafV2CommandExec { override def output: Seq[Attribute] = Nil + override val metrics: Map[String, SQLMetric] = Map( + "numFiles" -> SQLMetrics.createMetric(sparkContext, "number of written files"), + "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"), + "numOutputBytes" -> SQLMetrics.createMetric(sparkContext, "written output")) + protected def getV2Columns(schema: StructType, forceNullable: Boolean): Array[Column] = { val rawSchema = CharVarcharUtils.getRawSchema(removeInternalMetadata(schema), conf) val tableSchema = if (forceNullable) rawSchema.asNullable else rawSchema @@ -618,6 +623,16 @@ private[v2] trait V2CreateTableAsSelectBaseExec extends LeafV2CommandExec { qe.assertCommandExecuted() table match { + case st: StagedTableWithCommitMetrics => + st.commitStagedChanges() + + st.getCommitMetrics.forEach { + case (name: String, value: java.lang.Long) => + metrics.get(name).foreach(_.set(value)) + } + + val executionId = sparkContext.getLocalProperty(SQLExecution.EXECUTION_ID_KEY) + SQLMetrics.postDriverMetricUpdates(sparkContext, executionId, metrics.values.toSeq) case st: StagedTable => st.commitStagedChanges() case _ => } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2MetricsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2MetricsSuite.scala new file mode 100644 index 0000000000000..95200825cf02b --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2MetricsSuite.scala @@ -0,0 +1,133 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor 
license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.connector + +import java.util + +import scala.jdk.CollectionConverters.MapHasAsJava + +import org.apache.spark.sql.connector.catalog.{CatalogV2Util, Column, Identifier, InMemoryTable, InMemoryTableCatalog, StagedTable, StagedTableWithCommitMetrics, StagingInMemoryTableCatalog} +import org.apache.spark.sql.connector.catalog.CatalogV2Implicits.IdentifierHelper +import org.apache.spark.sql.connector.expressions.Transform +import org.apache.spark.sql.execution.CommandResultExec + +class StagingInMemoryTableCatalogWithMetrics extends StagingInMemoryTableCatalog { + + private class TestStagedTableWithMetric( + ident: Identifier, + delegateTable: InMemoryTable + ) extends TestStagedTable(ident, delegateTable) with StagedTableWithCommitMetrics { + + override def commitStagedChanges(): Unit = { + tables.put(ident, delegateTable) + } + + override def getCommitMetrics: util.Map[String, java.lang.Long] = { + StagingInMemoryTableCatalogWithMetrics.testMetrics + .asInstanceOf[Map[String, java.lang.Long]].asJava + } + } + + override def stageCreate( + ident: Identifier, + columns: Array[Column], + partitions: Array[Transform], + properties: util.Map[String, String]): StagedTable = { + new TestStagedTableWithMetric( + ident, + new InMemoryTable(s"$name.${ident.quoted}", + CatalogV2Util.v2ColumnsToStructType(columns), partitions, properties)) + } + + override def stageReplace( + ident: Identifier, + columns: Array[Column], + partitions: Array[Transform], + properties: util.Map[String, String]): StagedTable = + stageCreate(ident, columns, partitions, properties) + + override def stageCreateOrReplace( + ident: Identifier, + columns: Array[Column], + partitions: Array[Transform], + properties: util.Map[String, String]): StagedTable = + stageCreate(ident, columns, partitions, properties) +} + +object StagingInMemoryTableCatalogWithMetrics { + + val testMetrics: Map[String, Long] = Map( + "numFiles" -> 1337, + "numOutputRows" -> 1338, + "numOutputBytes" -> 1339, + ) +} + +class DataSourceV2MetricsSuite extends DatasourceV2SQLBase { + + private val testCatalog = "test_catalog" + private val atomicTestCatalog = "atomic_test_catalog" + private val nonExistingTable = "non_existing_table" + private val existingTable = "existing_table" + + private def commands(catalogName: String) = Seq( + s"CREATE TABLE $catalogName.$nonExistingTable AS SELECT * FROM $existingTable", + s"CREATE OR REPLACE TABLE $catalogName.$nonExistingTable AS SELECT * FROM $existingTable", + s"REPLACE TABLE $catalogName.$existingTable AS SELECT * FROM $existingTable", + s"REPLACE TABLE $catalogName.$existingTable (id bigint, data string)", + ) + + private def catalogCommitMetricsTest( + testName: String, catalogName: String)(testFunction: String => Unit): Unit = { + commands(catalogName).foreach { command => 
+ test(s"$testName - $command") { + registerCatalog(testCatalog, classOf[InMemoryTableCatalog]) + registerCatalog(atomicTestCatalog, classOf[StagingInMemoryTableCatalogWithMetrics]) + withTable(existingTable, s"$catalogName.$existingTable") { + sql(s"CREATE TABLE $existingTable (id bigint, data string)") + sql(s"CREATE TABLE $catalogName.$existingTable (id bigint, data string)") + + testFunction(command) + } + } + } + } + + catalogCommitMetricsTest( + "Plan metrics are 0 if the catalog does not support them", testCatalog) { command => + val df = sql(command) + val metrics = df.queryExecution.executedPlan match { + case c: CommandResultExec => c.commandPhysicalPlan.metrics + } + + assert(metrics.forall(_._2.value == 0)) + } + + catalogCommitMetricsTest( + "Plan metrics values are the values from the catalog", atomicTestCatalog) { command => + val df = sql(command) + val metrics = df.queryExecution.executedPlan match { + case c: CommandResultExec => c.commandPhysicalPlan.metrics + } + + assert(metrics.size === StagingInMemoryTableCatalogWithMetrics.testMetrics.size) + StagingInMemoryTableCatalogWithMetrics.testMetrics.foreach { case (k, v) => + assert(metrics(k).value === v) + } + } +} From 134b03bbaa198b30652bfbde9dd26eb219fa3502 Mon Sep 17 00:00:00 2001 From: Ole Sasse Date: Wed, 13 Nov 2024 16:52:16 +0100 Subject: [PATCH 02/17] Fix scalastyle. Only add metrics if catalog supports them --- .../catalog/StagingTableCatalog.java | 6 ++++++ .../datasources/v2/ReplaceTableExec.scala | 11 ++++++---- .../v2/WriteToDataSourceV2Exec.scala | 21 +++++++++++++++---- .../connector/DataSourceV2MetricsSuite.scala | 12 +++++------ 4 files changed, 36 insertions(+), 14 deletions(-) diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/StagingTableCatalog.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/StagingTableCatalog.java index eead1ade40791..ac103d43ade8e 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/StagingTableCatalog.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/StagingTableCatalog.java @@ -200,4 +200,10 @@ default StagedTable stageCreateOrReplace( return stageCreateOrReplace( ident, CatalogV2Util.v2ColumnsToStructType(columns), partitions, properties); } + + /** + * @return True if the catalog returns instances of type {@link StagedTableWithCommitMetrics} + * which support to retrieve commit metrics after a successful commit. 
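+ *
+ * <p>Illustrative example: a catalog whose staged tables implement
+ * {@link StagedTableWithCommitMetrics} can opt in by overriding this method to
+ * {@code return true; }.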
+ */ + default boolean supportsCommitMetrics() { return false; } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/ReplaceTableExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/ReplaceTableExec.scala index f2150cfca6394..5743b36e9f280 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/ReplaceTableExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/ReplaceTableExec.scala @@ -67,10 +67,13 @@ case class AtomicReplaceTableExec( val tableProperties = CatalogV2Util.convertTableProperties(tableSpec) - override val metrics: Map[String, SQLMetric] = Map( - "numFiles" -> SQLMetrics.createMetric(sparkContext, "number of written files"), - "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"), - "numOutputBytes" -> SQLMetrics.createMetric(sparkContext, "written output")) + override val metrics: Map[String, SQLMetric] = if (catalog.supportsCommitMetrics()) { + Map("numFiles" -> SQLMetrics.createMetric(sparkContext, "number of written files"), + "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"), + "numOutputBytes" -> SQLMetrics.createMetric(sparkContext, "written output")) + } else { + Map.empty + } override protected def run(): Seq[InternalRow] = { if (catalog.tableExists(identifier)) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala index 756ecf7095d07..37e5bf66f7844 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala @@ -75,6 +75,8 @@ case class CreateTableAsSelectExec( val properties = CatalogV2Util.convertTableProperties(tableSpec) + override val metrics: Map[String, SQLMetric] = commitMetrics(catalog) + override protected def run(): Seq[InternalRow] = { if (catalog.tableExists(ident)) { if (ifNotExists) { @@ -110,6 +112,8 @@ case class AtomicCreateTableAsSelectExec( val properties = CatalogV2Util.convertTableProperties(tableSpec) + override val metrics: Map[String, SQLMetric] = commitMetrics(catalog) + override protected def run(): Seq[InternalRow] = { if (catalog.tableExists(ident)) { if (ifNotExists) { @@ -148,6 +152,8 @@ case class ReplaceTableAsSelectExec( val properties = CatalogV2Util.convertTableProperties(tableSpec) + override val metrics: Map[String, SQLMetric] = commitMetrics(catalog) + override protected def run(): Seq[InternalRow] = { // Note that this operation is potentially unsafe, but these are the strict semantics of // RTAS if the catalog does not support atomic operations. 
@@ -197,6 +203,8 @@ case class AtomicReplaceTableAsSelectExec( val properties = CatalogV2Util.convertTableProperties(tableSpec) + override val metrics: Map[String, SQLMetric] = commitMetrics(catalog) + override protected def run(): Seq[InternalRow] = { val columns = getV2Columns(query.schema, catalog.useNullableQuerySchema) if (catalog.tableExists(ident)) { @@ -599,10 +607,15 @@ case class DeltaWithMetadataWritingSparkTask( private[v2] trait V2CreateTableAsSelectBaseExec extends LeafV2CommandExec { override def output: Seq[Attribute] = Nil - override val metrics: Map[String, SQLMetric] = Map( - "numFiles" -> SQLMetrics.createMetric(sparkContext, "number of written files"), - "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"), - "numOutputBytes" -> SQLMetrics.createMetric(sparkContext, "written output")) + protected def commitMetrics(tableCatalog: TableCatalog): Map[String, SQLMetric] = { + tableCatalog match { + case st: StagingTableCatalog if st.supportsCommitMetrics() => + Map("numFiles" -> SQLMetrics.createMetric(sparkContext, "number of written files"), + "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"), + "numOutputBytes" -> SQLMetrics.createMetric(sparkContext, "written output")) + case _ => Map.empty + } + } protected def getV2Columns(schema: StructType, forceNullable: Boolean): Array[Column] = { val rawSchema = CharVarcharUtils.getRawSchema(removeInternalMetadata(schema), conf) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2MetricsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2MetricsSuite.scala index 95200825cf02b..29d4696958844 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2MetricsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2MetricsSuite.scala @@ -28,6 +28,8 @@ import org.apache.spark.sql.execution.CommandResultExec class StagingInMemoryTableCatalogWithMetrics extends StagingInMemoryTableCatalog { + override def supportsCommitMetrics(): Boolean = true + private class TestStagedTableWithMetric( ident: Identifier, delegateTable: InMemoryTable @@ -74,8 +76,7 @@ object StagingInMemoryTableCatalogWithMetrics { val testMetrics: Map[String, Long] = Map( "numFiles" -> 1337, "numOutputRows" -> 1338, - "numOutputBytes" -> 1339, - ) + "numOutputBytes" -> 1339) } class DataSourceV2MetricsSuite extends DatasourceV2SQLBase { @@ -89,8 +90,7 @@ class DataSourceV2MetricsSuite extends DatasourceV2SQLBase { s"CREATE TABLE $catalogName.$nonExistingTable AS SELECT * FROM $existingTable", s"CREATE OR REPLACE TABLE $catalogName.$nonExistingTable AS SELECT * FROM $existingTable", s"REPLACE TABLE $catalogName.$existingTable AS SELECT * FROM $existingTable", - s"REPLACE TABLE $catalogName.$existingTable (id bigint, data string)", - ) + s"REPLACE TABLE $catalogName.$existingTable (id bigint, data string)") private def catalogCommitMetricsTest( testName: String, catalogName: String)(testFunction: String => Unit): Unit = { @@ -109,13 +109,13 @@ class DataSourceV2MetricsSuite extends DatasourceV2SQLBase { } catalogCommitMetricsTest( - "Plan metrics are 0 if the catalog does not support them", testCatalog) { command => + "No metrics in the plan if the catalog does not support them", testCatalog) { command => val df = sql(command) val metrics = df.queryExecution.executedPlan match { case c: CommandResultExec => c.commandPhysicalPlan.metrics } - assert(metrics.forall(_._2.value == 0)) + assert(metrics.isEmpty) } 
catalogCommitMetricsTest( From ab92ca9378b33c45b334619caf9f054863756c53 Mon Sep 17 00:00:00 2001 From: Ole Sasse Date: Wed, 13 Nov 2024 21:17:34 +0100 Subject: [PATCH 03/17] Remove unused import --- .../sql/connector/catalog/StagedTableWithCommitMetrics.java | 1 - 1 file changed, 1 deletion(-) diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/StagedTableWithCommitMetrics.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/StagedTableWithCommitMetrics.java index 230312644e011..54f0632219cca 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/StagedTableWithCommitMetrics.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/StagedTableWithCommitMetrics.java @@ -20,7 +20,6 @@ import java.util.Map; import org.apache.spark.annotation.Evolving; -import org.apache.spark.sql.types.DataType; /** * An extension of the {@link StagedTable} interface that provides metrics after a commit. From 26965c45fae5a4427083fa75fabb1446919247e3 Mon Sep 17 00:00:00 2001 From: Ole Sasse Date: Thu, 14 Nov 2024 14:13:52 +0100 Subject: [PATCH 04/17] Add a supportedCommitMetrics to the catalog --- .../catalog/StagedTableWithCommitMetrics.java | 13 +++--- .../catalog/StagingTableCatalog.java | 7 +-- .../datasources/v2/ReplaceTableExec.scala | 16 +++---- .../v2/WriteToDataSourceV2Exec.scala | 13 +++--- .../connector/DataSourceV2MetricsSuite.scala | 43 ++++++++++++++----- 5 files changed, 53 insertions(+), 39 deletions(-) diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/StagedTableWithCommitMetrics.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/StagedTableWithCommitMetrics.java index 54f0632219cca..825bbef4a27c0 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/StagedTableWithCommitMetrics.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/StagedTableWithCommitMetrics.java @@ -17,21 +17,18 @@ package org.apache.spark.sql.connector.catalog; -import java.util.Map; - import org.apache.spark.annotation.Evolving; +import org.apache.spark.sql.connector.metric.CustomTaskMetric; /** - * An extension of the {@link StagedTable} interface that provides metrics after a commit. + * An extension of the {@link StagedTable} interface that allows to retrieve metrics after a commit. */ @Evolving public interface StagedTableWithCommitMetrics extends StagedTable { /** - * Returns a map of commit metric values after a successful commit. Throws otherwise. - * - * @return a {@link Map} of commit metric values. The keys are the commit names, the values - * are the metrics values of type Long. + * @return a {@link java.lang.Iterable} of commit metric values. Throws if the table has not + * been committed yet. 
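+ *
+ * <p>Illustrative example: each returned {@link CustomTaskMetric} pairs a metric name with
+ * a value, e.g. {@code name() = "numOutputRows"} and {@code value() = 1338L}; the name
+ * should match one of the metrics declared by {@code supportedCommitMetrics()}.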
*/ - Map getCommitMetrics() throws AssertionError; + Iterable getCommitMetrics() throws AssertionError; } diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/StagingTableCatalog.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/StagingTableCatalog.java index ac103d43ade8e..9396e52dec4c1 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/StagingTableCatalog.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/StagingTableCatalog.java @@ -17,10 +17,12 @@ package org.apache.spark.sql.connector.catalog; +import java.util.Collections; import java.util.Map; import org.apache.spark.annotation.Evolving; import org.apache.spark.sql.connector.expressions.Transform; +import org.apache.spark.sql.connector.metric.CustomMetric; import org.apache.spark.sql.connector.write.LogicalWriteInfo; import org.apache.spark.sql.catalyst.analysis.NoSuchNamespaceException; import org.apache.spark.sql.catalyst.analysis.NoSuchTableException; @@ -202,8 +204,7 @@ default StagedTable stageCreateOrReplace( } /** - * @return True if the catalog returns instances of type {@link StagedTableWithCommitMetrics} - * which support to retrieve commit metrics after a successful commit. + * @return A {@link java.lang.Iterable} of commit metrics that are supported by the catalog. */ - default boolean supportsCommitMetrics() { return false; } + default Iterable supportedCommitMetrics() { return Collections.emptyList(); } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/ReplaceTableExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/ReplaceTableExec.scala index 5743b36e9f280..68b4e0c1cb34b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/ReplaceTableExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/ReplaceTableExec.scala @@ -67,13 +67,10 @@ case class AtomicReplaceTableExec( val tableProperties = CatalogV2Util.convertTableProperties(tableSpec) - override val metrics: Map[String, SQLMetric] = if (catalog.supportsCommitMetrics()) { - Map("numFiles" -> SQLMetrics.createMetric(sparkContext, "number of written files"), - "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"), - "numOutputBytes" -> SQLMetrics.createMetric(sparkContext, "written output")) - } else { - Map.empty - } + override val metrics: Map[String, SQLMetric] = + catalog.supportedCommitMetrics().asScala.map { metric => + metric.name() -> SQLMetrics.createMetric(sparkContext, metric.name()) + }.toMap override protected def run(): Seq[InternalRow] = { if (catalog.tableExists(identifier)) { @@ -106,9 +103,8 @@ case class AtomicReplaceTableExec( case st: StagedTableWithCommitMetrics => st.commitStagedChanges() - st.getCommitMetrics.forEach { - case (name: String, value: java.lang.Long) => - metrics.get(name).foreach(_.set(value)) + st.getCommitMetrics.forEach { taskMetric => + metrics.get(taskMetric.name()).foreach(_.set(taskMetric.value())) } val executionId = sparkContext.getLocalProperty(SQLExecution.EXECUTION_ID_KEY) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala index 37e5bf66f7844..9bcf9178e3875 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala +++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala @@ -609,10 +609,10 @@ private[v2] trait V2CreateTableAsSelectBaseExec extends LeafV2CommandExec { protected def commitMetrics(tableCatalog: TableCatalog): Map[String, SQLMetric] = { tableCatalog match { - case st: StagingTableCatalog if st.supportsCommitMetrics() => - Map("numFiles" -> SQLMetrics.createMetric(sparkContext, "number of written files"), - "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"), - "numOutputBytes" -> SQLMetrics.createMetric(sparkContext, "written output")) + case st: StagingTableCatalog => + st.supportedCommitMetrics().asScala.map { + metric => metric.name() -> SQLMetrics.createV2CustomMetric(sparkContext, metric) + }.toMap case _ => Map.empty } } @@ -639,9 +639,8 @@ private[v2] trait V2CreateTableAsSelectBaseExec extends LeafV2CommandExec { case st: StagedTableWithCommitMetrics => st.commitStagedChanges() - st.getCommitMetrics.forEach { - case (name: String, value: java.lang.Long) => - metrics.get(name).foreach(_.set(value)) + st.getCommitMetrics.forEach { taskMetric => + metrics.get(taskMetric.name()).foreach(_.set(taskMetric.value())) } val executionId = sparkContext.getLocalProperty(SQLExecution.EXECUTION_ID_KEY) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2MetricsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2MetricsSuite.scala index 29d4696958844..aded4c004dbc6 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2MetricsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2MetricsSuite.scala @@ -19,16 +19,29 @@ package org.apache.spark.sql.connector import java.util -import scala.jdk.CollectionConverters.MapHasAsJava +import scala.jdk.CollectionConverters.IterableHasAsJava import org.apache.spark.sql.connector.catalog.{CatalogV2Util, Column, Identifier, InMemoryTable, InMemoryTableCatalog, StagedTable, StagedTableWithCommitMetrics, StagingInMemoryTableCatalog} import org.apache.spark.sql.connector.catalog.CatalogV2Implicits.IdentifierHelper import org.apache.spark.sql.connector.expressions.Transform +import org.apache.spark.sql.connector.metric.{CustomMetric, CustomSumMetric, CustomTaskMetric} import org.apache.spark.sql.execution.CommandResultExec class StagingInMemoryTableCatalogWithMetrics extends StagingInMemoryTableCatalog { - override def supportsCommitMetrics(): Boolean = true + override def supportedCommitMetrics(): java.lang.Iterable[CustomMetric] = java.util.List.of( + new CustomSumMetric { + override def name(): String = "numFiles" + override def description(): String = "number of written files" + }, + new CustomSumMetric { + override def name(): String = "numOutputRows" + override def description(): String = "number of output rows" + }, + new CustomSumMetric { + override def name(): String = "numOutputBytes" + override def description(): String = "written output" + }) private class TestStagedTableWithMetric( ident: Identifier, @@ -39,9 +52,8 @@ class StagingInMemoryTableCatalogWithMetrics extends StagingInMemoryTableCatalog tables.put(ident, delegateTable) } - override def getCommitMetrics: util.Map[String, java.lang.Long] = { - StagingInMemoryTableCatalogWithMetrics.testMetrics - .asInstanceOf[Map[String, java.lang.Long]].asJava + override def getCommitMetrics: java.lang.Iterable[CustomTaskMetric] = { + StagingInMemoryTableCatalogWithMetrics.testMetrics.asJava } } @@ -73,10 +85,19 @@ class 
StagingInMemoryTableCatalogWithMetrics extends StagingInMemoryTableCatalog object StagingInMemoryTableCatalogWithMetrics { - val testMetrics: Map[String, Long] = Map( - "numFiles" -> 1337, - "numOutputRows" -> 1338, - "numOutputBytes" -> 1339) + val testMetrics: Seq[CustomTaskMetric] = Seq( + new CustomTaskMetric { + override def name(): String = "numFiles" + override def value(): Long = 1337 + }, + new CustomTaskMetric { + override def name(): String = "numOutputRows" + override def value(): Long = 1338 + }, + new CustomTaskMetric { + override def name(): String = "numOutputBytes" + override def value(): Long = 1339 + }) } class DataSourceV2MetricsSuite extends DatasourceV2SQLBase { @@ -126,8 +147,8 @@ class DataSourceV2MetricsSuite extends DatasourceV2SQLBase { } assert(metrics.size === StagingInMemoryTableCatalogWithMetrics.testMetrics.size) - StagingInMemoryTableCatalogWithMetrics.testMetrics.foreach { case (k, v) => - assert(metrics(k).value === v) + StagingInMemoryTableCatalogWithMetrics.testMetrics.foreach { case customTaskMetric => + assert(metrics(customTaskMetric.name()).value === customTaskMetric.value()) } } } From 11911d86ab5f2bb7143b7a0e0aae0dbf2151f983 Mon Sep 17 00:00:00 2001 From: Ole Sasse Date: Thu, 14 Nov 2024 14:20:30 +0100 Subject: [PATCH 05/17] Use createV2CustomMetric in ReplaceTableExec --- .../spark/sql/execution/datasources/v2/ReplaceTableExec.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/ReplaceTableExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/ReplaceTableExec.scala index 68b4e0c1cb34b..e82380ecc0384 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/ReplaceTableExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/ReplaceTableExec.scala @@ -69,7 +69,7 @@ case class AtomicReplaceTableExec( override val metrics: Map[String, SQLMetric] = catalog.supportedCommitMetrics().asScala.map { metric => - metric.name() -> SQLMetrics.createMetric(sparkContext, metric.name()) + metric.name() -> SQLMetrics.createV2CustomMetric(sparkContext, metric) }.toMap override protected def run(): Seq[InternalRow] = { From 58d3efc74599b3f314ad853c10522803528cb4b9 Mon Sep 17 00:00:00 2001 From: Ole Sasse Date: Thu, 14 Nov 2024 20:09:08 +0100 Subject: [PATCH 06/17] Add test cases where the operations are triggered from DataFrame operations --- .../connector/DataSourceV2MetricsSuite.scala | 79 +++++++++++++------ 1 file changed, 57 insertions(+), 22 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2MetricsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2MetricsSuite.scala index aded4c004dbc6..452c9cda2013c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2MetricsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2MetricsSuite.scala @@ -25,7 +25,9 @@ import org.apache.spark.sql.connector.catalog.{CatalogV2Util, Column, Identifier import org.apache.spark.sql.connector.catalog.CatalogV2Implicits.IdentifierHelper import org.apache.spark.sql.connector.expressions.Transform import org.apache.spark.sql.connector.metric.{CustomMetric, CustomSumMetric, CustomTaskMetric} -import org.apache.spark.sql.execution.CommandResultExec +import org.apache.spark.sql.execution.{QueryExecution, SparkPlan} +import 
org.apache.spark.sql.execution.datasources.v2.{AtomicCreateTableAsSelectExec, AtomicReplaceTableAsSelectExec, AtomicReplaceTableExec, CreateTableAsSelectExec, ReplaceTableAsSelectExec, ReplaceTableExec} +import org.apache.spark.sql.util.QueryExecutionListener class StagingInMemoryTableCatalogWithMetrics extends StagingInMemoryTableCatalog { @@ -107,15 +109,55 @@ class DataSourceV2MetricsSuite extends DatasourceV2SQLBase { private val nonExistingTable = "non_existing_table" private val existingTable = "existing_table" - private def commands(catalogName: String) = Seq( - s"CREATE TABLE $catalogName.$nonExistingTable AS SELECT * FROM $existingTable", - s"CREATE OR REPLACE TABLE $catalogName.$nonExistingTable AS SELECT * FROM $existingTable", - s"REPLACE TABLE $catalogName.$existingTable AS SELECT * FROM $existingTable", - s"REPLACE TABLE $catalogName.$existingTable (id bigint, data string)") + private def captureExecutedPlan(command: => Unit): SparkPlan = { + var commandPlan: Option[SparkPlan] = None + var otherPlans = Seq.empty[SparkPlan] + + val listener = new QueryExecutionListener { + override def onSuccess(funcName: String, qe: QueryExecution, durationNs: Long): Unit = { + qe.executedPlan match { + case _: CreateTableAsSelectExec | _: AtomicCreateTableAsSelectExec + | _: ReplaceTableAsSelectExec | _: AtomicReplaceTableAsSelectExec + | _: ReplaceTableExec | _: AtomicReplaceTableExec => + assert(commandPlan.isEmpty) + commandPlan = Some(qe.executedPlan) + case _ => + otherPlans :+= qe.executedPlan + } + } + + override def onFailure(funcName: String, qe: QueryExecution, exception: Exception): Unit = {} + } + + spark.listenerManager.register(listener) + + command + + sparkContext.listenerBus.waitUntilEmpty() + + assert(commandPlan.nonEmpty, s"No command plan found, but saw $otherPlans") + commandPlan.get + } + + private def commands: Seq[String => Unit] = Seq( + { catalogName => + sql(s"CREATE TABLE $catalogName.$nonExistingTable AS SELECT * FROM $existingTable") }, + { catalogName => + spark.table(existingTable).write.saveAsTable(s"$catalogName.$nonExistingTable") }, + { catalogName => + sql(s"CREATE OR REPLACE TABLE $catalogName.$nonExistingTable " + + s"AS SELECT * FROM $existingTable") }, + { catalogName => + sql(s"REPLACE TABLE $catalogName.$existingTable AS SELECT * FROM $existingTable") }, + { catalogName => + spark.table(existingTable) + .write.mode("overwrite").saveAsTable(s"$catalogName.$existingTable") }, + { catalogName => + sql(s"REPLACE TABLE $catalogName.$existingTable (id bigint, data string)") }) private def catalogCommitMetricsTest( - testName: String, catalogName: String)(testFunction: String => Unit): Unit = { - commands(catalogName).foreach { command => + testName: String, catalogName: String)(testFunction: SparkPlan => Unit): Unit = { + commands.foreach { command => test(s"$testName - $command") { registerCatalog(testCatalog, classOf[InMemoryTableCatalog]) registerCatalog(atomicTestCatalog, classOf[StagingInMemoryTableCatalogWithMetrics]) @@ -123,32 +165,25 @@ class DataSourceV2MetricsSuite extends DatasourceV2SQLBase { sql(s"CREATE TABLE $existingTable (id bigint, data string)") sql(s"CREATE TABLE $catalogName.$existingTable (id bigint, data string)") - testFunction(command) + testFunction(captureExecutedPlan(command(catalogName))) } } } } catalogCommitMetricsTest( - "No metrics in the plan if the catalog does not support them", testCatalog) { command => - val df = sql(command) - val metrics = df.queryExecution.executedPlan match { - case c: CommandResultExec => 
c.commandPhysicalPlan.metrics - } + "No metrics in the plan if the catalog does not support them", testCatalog) { sparkPlan => + val metrics = sparkPlan.metrics assert(metrics.isEmpty) } catalogCommitMetricsTest( - "Plan metrics values are the values from the catalog", atomicTestCatalog) { command => - val df = sql(command) - val metrics = df.queryExecution.executedPlan match { - case c: CommandResultExec => c.commandPhysicalPlan.metrics - } + "Plan metrics values are the values from the catalog", atomicTestCatalog) { sparkPlan => + val metrics = sparkPlan.metrics assert(metrics.size === StagingInMemoryTableCatalogWithMetrics.testMetrics.size) - StagingInMemoryTableCatalogWithMetrics.testMetrics.foreach { case customTaskMetric => - assert(metrics(customTaskMetric.name()).value === customTaskMetric.value()) - } + StagingInMemoryTableCatalogWithMetrics.testMetrics.foreach(customTaskMetric => + assert(metrics(customTaskMetric.name()).value === customTaskMetric.value())) } } From c7b192f63e25a9c03baf05255344565fe4bd9c65 Mon Sep 17 00:00:00 2001 From: Ole Sasse Date: Sat, 16 Nov 2024 22:55:56 +0100 Subject: [PATCH 07/17] Introduce withPhysicalPlansCaptured and use case instead of anonymous classes --- .../org/apache/spark/sql/QueryTest.scala | 25 ++++++- .../connector/DataSourceV2MetricsSuite.scala | 73 ++++++------------- 2 files changed, 45 insertions(+), 53 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala index f5ba655e3e85f..30180d48da71a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala @@ -27,8 +27,9 @@ import org.scalatest.Assertions import org.apache.spark.sql.catalyst.ExtendedAnalysisException import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.util._ -import org.apache.spark.sql.execution.SQLExecution +import org.apache.spark.sql.execution.{QueryExecution, SparkPlan, SQLExecution} import org.apache.spark.sql.execution.columnar.InMemoryRelation +import org.apache.spark.sql.util.QueryExecutionListener import org.apache.spark.storage.StorageLevel import org.apache.spark.util.ArrayImplicits._ @@ -447,6 +448,28 @@ object QueryTest extends Assertions { case None => } } + + def withPhysicalPlansCaptured(spark: SparkSession, thunk: => Unit): Seq[SparkPlan] = { + var capturedPlans = Seq.empty[SparkPlan] + + val listener = new QueryExecutionListener { + override def onSuccess(funcName: String, qe: QueryExecution, durationNs: Long): Unit = { + capturedPlans = capturedPlans :+ qe.executedPlan + } + override def onFailure(funcName: String, qe: QueryExecution, exception: Exception): Unit = {} + } + + spark.sparkContext.listenerBus.waitUntilEmpty(15000) + spark.listenerManager.register(listener) + try { + thunk + spark.sparkContext.listenerBus.waitUntilEmpty(15000) + } finally { + spark.listenerManager.unregister(listener) + } + + capturedPlans + } } class QueryTestSuite extends QueryTest with test.SharedSparkSession { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2MetricsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2MetricsSuite.scala index 452c9cda2013c..3bd2b78b70e45 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2MetricsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2MetricsSuite.scala @@ -21,6 +21,7 @@ import java.util import 
scala.jdk.CollectionConverters.IterableHasAsJava +import org.apache.spark.sql.QueryTest.withPhysicalPlansCaptured import org.apache.spark.sql.connector.catalog.{CatalogV2Util, Column, Identifier, InMemoryTable, InMemoryTableCatalog, StagedTable, StagedTableWithCommitMetrics, StagingInMemoryTableCatalog} import org.apache.spark.sql.connector.catalog.CatalogV2Implicits.IdentifierHelper import org.apache.spark.sql.connector.expressions.Transform @@ -31,19 +32,12 @@ import org.apache.spark.sql.util.QueryExecutionListener class StagingInMemoryTableCatalogWithMetrics extends StagingInMemoryTableCatalog { + case class TestSupportedCommitMetric(name: String, description: String) extends CustomSumMetric + override def supportedCommitMetrics(): java.lang.Iterable[CustomMetric] = java.util.List.of( - new CustomSumMetric { - override def name(): String = "numFiles" - override def description(): String = "number of written files" - }, - new CustomSumMetric { - override def name(): String = "numOutputRows" - override def description(): String = "number of output rows" - }, - new CustomSumMetric { - override def name(): String = "numOutputBytes" - override def description(): String = "written output" - }) + TestSupportedCommitMetric("numFiles", "number of written files"), + TestSupportedCommitMetric("numOutputRows", "number of output rows"), + TestSupportedCommitMetric("numOutputBytes", "written output")) private class TestStagedTableWithMetric( ident: Identifier, @@ -87,19 +81,12 @@ class StagingInMemoryTableCatalogWithMetrics extends StagingInMemoryTableCatalog object StagingInMemoryTableCatalogWithMetrics { + case class TestCustomTaskMetric(name: String, value: Long) extends CustomTaskMetric + val testMetrics: Seq[CustomTaskMetric] = Seq( - new CustomTaskMetric { - override def name(): String = "numFiles" - override def value(): Long = 1337 - }, - new CustomTaskMetric { - override def name(): String = "numOutputRows" - override def value(): Long = 1338 - }, - new CustomTaskMetric { - override def name(): String = "numOutputBytes" - override def value(): Long = 1339 - }) + TestCustomTaskMetric("numFiles", 1337), + TestCustomTaskMetric("numOutputRows", 1338), + TestCustomTaskMetric("numOutputBytes", 1339)) } class DataSourceV2MetricsSuite extends DatasourceV2SQLBase { @@ -109,34 +96,16 @@ class DataSourceV2MetricsSuite extends DatasourceV2SQLBase { private val nonExistingTable = "non_existing_table" private val existingTable = "existing_table" - private def captureExecutedPlan(command: => Unit): SparkPlan = { - var commandPlan: Option[SparkPlan] = None - var otherPlans = Seq.empty[SparkPlan] - - val listener = new QueryExecutionListener { - override def onSuccess(funcName: String, qe: QueryExecution, durationNs: Long): Unit = { - qe.executedPlan match { - case _: CreateTableAsSelectExec | _: AtomicCreateTableAsSelectExec - | _: ReplaceTableAsSelectExec | _: AtomicReplaceTableAsSelectExec - | _: ReplaceTableExec | _: AtomicReplaceTableExec => - assert(commandPlan.isEmpty) - commandPlan = Some(qe.executedPlan) - case _ => - otherPlans :+= qe.executedPlan - } - } - - override def onFailure(funcName: String, qe: QueryExecution, exception: Exception): Unit = {} + private def captureStagedTableWrite(thunk: => Unit): SparkPlan = { + val physicalPlans = withPhysicalPlansCaptured(spark, thunk) + val stagedTableWrites = physicalPlans.filter { + case _: AtomicCreateTableAsSelectExec | _: CreateTableAsSelectExec | + _: AtomicReplaceTableAsSelectExec | _: ReplaceTableAsSelectExec | + _: AtomicReplaceTableExec | _: 
ReplaceTableExec => true + case _ => false } - - spark.listenerManager.register(listener) - - command - - sparkContext.listenerBus.waitUntilEmpty() - - assert(commandPlan.nonEmpty, s"No command plan found, but saw $otherPlans") - commandPlan.get + assert(stagedTableWrites.size === 1) + stagedTableWrites.head } private def commands: Seq[String => Unit] = Seq( @@ -165,7 +134,7 @@ class DataSourceV2MetricsSuite extends DatasourceV2SQLBase { sql(s"CREATE TABLE $existingTable (id bigint, data string)") sql(s"CREATE TABLE $catalogName.$existingTable (id bigint, data string)") - testFunction(captureExecutedPlan(command(catalogName))) + testFunction(captureStagedTableWrite(command(catalogName))) } } } From 87ad166080fc6ce092ceac489baa1d7b2387f5ed Mon Sep 17 00:00:00 2001 From: Ole Sasse Date: Sat, 16 Nov 2024 23:10:05 +0100 Subject: [PATCH 08/17] Align interface with Write.java, add commit assert --- .../catalog/StagedTableWithCommitMetrics.java | 5 ++--- .../datasources/v2/ReplaceTableExec.scala | 2 +- .../v2/WriteToDataSourceV2Exec.scala | 2 +- .../connector/DataSourceV2MetricsSuite.scala | 17 +++++++++-------- 4 files changed, 13 insertions(+), 13 deletions(-) diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/StagedTableWithCommitMetrics.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/StagedTableWithCommitMetrics.java index 825bbef4a27c0..1b9e8daafa685 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/StagedTableWithCommitMetrics.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/StagedTableWithCommitMetrics.java @@ -27,8 +27,7 @@ public interface StagedTableWithCommitMetrics extends StagedTable { /** - * @return a {@link java.lang.Iterable} of commit metric values. Throws if the table has not - * been committed yet. + * @return an Array of commit metric values. Throws if the table has not been committed yet. 
*/ - Iterable getCommitMetrics() throws AssertionError; + CustomTaskMetric[] reportDriverMetrics() throws AssertionError; } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/ReplaceTableExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/ReplaceTableExec.scala index e82380ecc0384..624a2869ac64f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/ReplaceTableExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/ReplaceTableExec.scala @@ -103,7 +103,7 @@ case class AtomicReplaceTableExec( case st: StagedTableWithCommitMetrics => st.commitStagedChanges() - st.getCommitMetrics.forEach { taskMetric => + for (taskMetric <- st.reportDriverMetrics) { metrics.get(taskMetric.name()).foreach(_.set(taskMetric.value())) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala index 9bcf9178e3875..737dfad790c9f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala @@ -639,7 +639,7 @@ private[v2] trait V2CreateTableAsSelectBaseExec extends LeafV2CommandExec { case st: StagedTableWithCommitMetrics => st.commitStagedChanges() - st.getCommitMetrics.forEach { taskMetric => + for (taskMetric <- st.reportDriverMetrics) { metrics.get(taskMetric.name()).foreach(_.set(taskMetric.value())) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2MetricsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2MetricsSuite.scala index 3bd2b78b70e45..c169d62ed54ad 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2MetricsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2MetricsSuite.scala @@ -19,16 +19,13 @@ package org.apache.spark.sql.connector import java.util -import scala.jdk.CollectionConverters.IterableHasAsJava - import org.apache.spark.sql.QueryTest.withPhysicalPlansCaptured import org.apache.spark.sql.connector.catalog.{CatalogV2Util, Column, Identifier, InMemoryTable, InMemoryTableCatalog, StagedTable, StagedTableWithCommitMetrics, StagingInMemoryTableCatalog} import org.apache.spark.sql.connector.catalog.CatalogV2Implicits.IdentifierHelper import org.apache.spark.sql.connector.expressions.Transform import org.apache.spark.sql.connector.metric.{CustomMetric, CustomSumMetric, CustomTaskMetric} -import org.apache.spark.sql.execution.{QueryExecution, SparkPlan} +import org.apache.spark.sql.execution.SparkPlan import org.apache.spark.sql.execution.datasources.v2.{AtomicCreateTableAsSelectExec, AtomicReplaceTableAsSelectExec, AtomicReplaceTableExec, CreateTableAsSelectExec, ReplaceTableAsSelectExec, ReplaceTableExec} -import org.apache.spark.sql.util.QueryExecutionListener class StagingInMemoryTableCatalogWithMetrics extends StagingInMemoryTableCatalog { @@ -44,12 +41,16 @@ class StagingInMemoryTableCatalogWithMetrics extends StagingInMemoryTableCatalog delegateTable: InMemoryTable ) extends TestStagedTable(ident, delegateTable) with StagedTableWithCommitMetrics { + private var stagedChangesCommitted = false + override def commitStagedChanges(): Unit = { tables.put(ident, delegateTable) + stagedChangesCommitted = true } - override def getCommitMetrics: 
java.lang.Iterable[CustomTaskMetric] = { - StagingInMemoryTableCatalogWithMetrics.testMetrics.asJava + override def reportDriverMetrics: Array[CustomTaskMetric] = { + assert(stagedChangesCommitted) + StagingInMemoryTableCatalogWithMetrics.testMetrics } } @@ -83,7 +84,7 @@ object StagingInMemoryTableCatalogWithMetrics { case class TestCustomTaskMetric(name: String, value: Long) extends CustomTaskMetric - val testMetrics: Seq[CustomTaskMetric] = Seq( + val testMetrics: Array[CustomTaskMetric] = Array( TestCustomTaskMetric("numFiles", 1337), TestCustomTaskMetric("numOutputRows", 1338), TestCustomTaskMetric("numOutputBytes", 1339)) @@ -151,7 +152,7 @@ class DataSourceV2MetricsSuite extends DatasourceV2SQLBase { "Plan metrics values are the values from the catalog", atomicTestCatalog) { sparkPlan => val metrics = sparkPlan.metrics - assert(metrics.size === StagingInMemoryTableCatalogWithMetrics.testMetrics.size) + assert(metrics.size === StagingInMemoryTableCatalogWithMetrics.testMetrics.length) StagingInMemoryTableCatalogWithMetrics.testMetrics.foreach(customTaskMetric => assert(metrics(customTaskMetric.name()).value === customTaskMetric.value())) } From 43bd2aecebc3239834c895bb1bacc291efabb13c Mon Sep 17 00:00:00 2001 From: Ole Sasse Date: Sat, 16 Nov 2024 23:26:46 +0100 Subject: [PATCH 09/17] Align the interfaces with the interfaces in Write --- .../catalog/StagedTableWithCommitMetrics.java | 6 ++++++ .../sql/connector/catalog/StagingTableCatalog.java | 11 ++++++++--- .../execution/datasources/v2/ReplaceTableExec.scala | 2 +- .../datasources/v2/WriteToDataSourceV2Exec.scala | 2 +- .../sql/connector/DataSourceV2MetricsSuite.scala | 2 +- 5 files changed, 17 insertions(+), 6 deletions(-) diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/StagedTableWithCommitMetrics.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/StagedTableWithCommitMetrics.java index 1b9e8daafa685..47e76ae88c904 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/StagedTableWithCommitMetrics.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/StagedTableWithCommitMetrics.java @@ -19,6 +19,7 @@ import org.apache.spark.annotation.Evolving; import org.apache.spark.sql.connector.metric.CustomTaskMetric; +import org.apache.spark.sql.connector.write.Write; /** * An extension of the {@link StagedTable} interface that allows to retrieve metrics after a commit. @@ -27,6 +28,11 @@ public interface StagedTableWithCommitMetrics extends StagedTable { /** + * Retrieve driver metrics after a commit. This is analogous + * to {@link Write#reportDriverMetrics()}. Note that these metrics must be included in the + * supported custom metrics reported by `supportedCustomMetrics` of the + * {@link StagingTableCatalog} that returned the staged table. + * * @return an Array of commit metric values. Throws if the table has not been committed yet. 
*/ CustomTaskMetric[] reportDriverMetrics() throws AssertionError; diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/StagingTableCatalog.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/StagingTableCatalog.java index 9396e52dec4c1..39f7da6d05069 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/StagingTableCatalog.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/StagingTableCatalog.java @@ -17,7 +17,6 @@ package org.apache.spark.sql.connector.catalog; -import java.util.Collections; import java.util.Map; import org.apache.spark.annotation.Evolving; @@ -28,6 +27,7 @@ import org.apache.spark.sql.catalyst.analysis.NoSuchTableException; import org.apache.spark.sql.catalyst.analysis.TableAlreadyExistsException; import org.apache.spark.sql.connector.write.BatchWrite; +import org.apache.spark.sql.connector.write.Write; import org.apache.spark.sql.connector.write.WriterCommitMessage; import org.apache.spark.sql.errors.QueryCompilationErrors; import org.apache.spark.sql.types.StructType; @@ -204,7 +204,12 @@ default StagedTable stageCreateOrReplace( } /** - * @return A {@link java.lang.Iterable} of commit metrics that are supported by the catalog. + * @return An Array of commit metrics that are supported by the catalog. This is analogous to + * {@link Write#supportedCustomMetrics()}. The corresponding + * {@link StagedTableWithCommitMetrics#reportDriverMetrics()} method must be called to + * retrieve the actual metric values after a commit. The methods are not in the same class + * because the supported metrics are required before the staged table object is created + * and only the staged table object can capture the write metrics during the commit. */ - default Iterable supportedCommitMetrics() { return Collections.emptyList(); } + default CustomMetric[] supportedCustomMetrics() { return new CustomMetric[0]; } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/ReplaceTableExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/ReplaceTableExec.scala index 624a2869ac64f..bbba24d0f1d59 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/ReplaceTableExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/ReplaceTableExec.scala @@ -68,7 +68,7 @@ case class AtomicReplaceTableExec( val tableProperties = CatalogV2Util.convertTableProperties(tableSpec) override val metrics: Map[String, SQLMetric] = - catalog.supportedCommitMetrics().asScala.map { metric => + catalog.supportedCustomMetrics().map { metric => metric.name() -> SQLMetrics.createV2CustomMetric(sparkContext, metric) }.toMap diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala index bfe9d2e766863..55a9ba3d2e302 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala @@ -623,7 +623,7 @@ private[v2] trait V2CreateTableAsSelectBaseExec extends LeafV2CommandExec { protected def commitMetrics(tableCatalog: TableCatalog): Map[String, SQLMetric] = { tableCatalog match { case st: StagingTableCatalog => - st.supportedCommitMetrics().asScala.map { + st.supportedCustomMetrics().map { metric => metric.name() -> 
SQLMetrics.createV2CustomMetric(sparkContext, metric) }.toMap case _ => Map.empty diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2MetricsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2MetricsSuite.scala index c169d62ed54ad..9a9f29d77e073 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2MetricsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2MetricsSuite.scala @@ -31,7 +31,7 @@ class StagingInMemoryTableCatalogWithMetrics extends StagingInMemoryTableCatalog case class TestSupportedCommitMetric(name: String, description: String) extends CustomSumMetric - override def supportedCommitMetrics(): java.lang.Iterable[CustomMetric] = java.util.List.of( + override def supportedCustomMetrics(): Array[CustomMetric] = Array( TestSupportedCommitMetric("numFiles", "number of written files"), TestSupportedCommitMetric("numOutputRows", "number of output rows"), TestSupportedCommitMetric("numOutputBytes", "written output")) From 2faff80c66603c39c7ba80642b58f74e472855cc Mon Sep 17 00:00:00 2001 From: Ole Sasse Date: Thu, 21 Nov 2024 13:25:49 +0100 Subject: [PATCH 10/17] Remove StagedTableWithCommitMetrics and move reportDriverMetrics to StagedTable --- .../sql/connector/catalog/StagedTable.java | 14 +++++++ .../catalog/StagedTableWithCommitMetrics.java | 39 ------------------- .../catalog/StagingTableCatalog.java | 2 +- .../datasources/v2/ReplaceTableExec.scala | 18 ++++----- .../v2/WriteToDataSourceV2Exec.scala | 19 +++++---- .../connector/DataSourceV2MetricsSuite.scala | 4 +- 6 files changed, 36 insertions(+), 60 deletions(-) delete mode 100644 sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/StagedTableWithCommitMetrics.java diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/StagedTable.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/StagedTable.java index 60b250adb41ef..cbaea8cad8582 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/StagedTable.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/StagedTable.java @@ -21,7 +21,9 @@ import org.apache.spark.annotation.Evolving; import org.apache.spark.sql.connector.expressions.Transform; +import org.apache.spark.sql.connector.metric.CustomTaskMetric; import org.apache.spark.sql.connector.write.LogicalWriteInfo; +import org.apache.spark.sql.connector.write.Write; import org.apache.spark.sql.types.StructType; /** @@ -52,4 +54,16 @@ public interface StagedTable extends Table { * table's writers. */ void abortStagedChanges(); + + /** + * Retrieve driver metrics after a commit. This is analogous + * to {@link Write#reportDriverMetrics()}. Note that these metrics must be included in the + * supported custom metrics reported by `supportedCustomMetrics` of the + * {@link StagingTableCatalog} that returned the staged table. + * + * @return an Array of commit metric values. Throws if the table has not been committed yet. 
+ */ + default CustomTaskMetric[] reportDriverMetrics() throws RuntimeException { + return new CustomTaskMetric[0]; + } } diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/StagedTableWithCommitMetrics.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/StagedTableWithCommitMetrics.java deleted file mode 100644 index 47e76ae88c904..0000000000000 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/StagedTableWithCommitMetrics.java +++ /dev/null @@ -1,39 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.connector.catalog; - -import org.apache.spark.annotation.Evolving; -import org.apache.spark.sql.connector.metric.CustomTaskMetric; -import org.apache.spark.sql.connector.write.Write; - -/** - * An extension of the {@link StagedTable} interface that allows to retrieve metrics after a commit. - */ -@Evolving -public interface StagedTableWithCommitMetrics extends StagedTable { - - /** - * Retrieve driver metrics after a commit. This is analogous - * to {@link Write#reportDriverMetrics()}. Note that these metrics must be included in the - * supported custom metrics reported by `supportedCustomMetrics` of the - * {@link StagingTableCatalog} that returned the staged table. - * - * @return an Array of commit metric values. Throws if the table has not been committed yet. - */ - CustomTaskMetric[] reportDriverMetrics() throws AssertionError; -} diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/StagingTableCatalog.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/StagingTableCatalog.java index 39f7da6d05069..f457a4a3d7863 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/StagingTableCatalog.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/StagingTableCatalog.java @@ -206,7 +206,7 @@ default StagedTable stageCreateOrReplace( /** * @return An Array of commit metrics that are supported by the catalog. This is analogous to * {@link Write#supportedCustomMetrics()}. The corresponding - * {@link StagedTableWithCommitMetrics#reportDriverMetrics()} method must be called to + * {@link StagedTable#reportDriverMetrics()} method must be called to * retrieve the actual metric values after a commit. The methods are not in the same class * because the supported metrics are required before the staged table object is created * and only the staged table object can capture the write metrics during the commit. 
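Editor's illustration (not taken from the patches): the two halves of this API are meant to be used together — the StagingTableCatalog declares which commit metrics it can produce, and the StagedTable it returns reports matching values once commitStagedChanges() has succeeded. A minimal Scala sketch of what a connector might do; the numFiles metric and the CommitMetricsReporting trait are made-up names, and the sketch assumes the reportDriverMetrics default method introduced by this patch:

    import org.apache.spark.sql.connector.catalog.StagedTable
    import org.apache.spark.sql.connector.metric.{CustomMetric, CustomSumMetric, CustomTaskMetric}

    // Declared by the connector's StagingTableCatalog, e.g.
    //   override def supportedCustomMetrics(): Array[CustomMetric] = Array(NumFilesMetric)
    object NumFilesMetric extends CustomSumMetric {
      override def name(): String = "numFiles"
      override def description(): String = "number of written files"
    }

    // Mixed into the connector's StagedTable implementation; the connector's
    // commitStagedChanges() is expected to populate committedFiles on success.
    trait CommitMetricsReporting extends StagedTable {
      protected var committedFiles: Option[Long] = None

      override def reportDriverMetrics(): Array[CustomTaskMetric] = {
        val files = committedFiles.getOrElse(
          throw new IllegalStateException("metrics requested before a successful commit"))
        // The metric name must match one declared by supportedCustomMetrics().
        Array(new CustomTaskMetric {
          override def name(): String = NumFilesMetric.name()
          override def value(): Long = files
        })
      }
    }
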
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/ReplaceTableExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/ReplaceTableExec.scala index bbba24d0f1d59..8e28bda5feb7a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/ReplaceTableExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/ReplaceTableExec.scala @@ -23,7 +23,7 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.NoSuchTableException import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.catalyst.plans.logical.TableSpec -import org.apache.spark.sql.connector.catalog.{CatalogV2Util, Column, Identifier, StagedTable, StagedTableWithCommitMetrics, StagingTableCatalog, Table, TableCatalog} +import org.apache.spark.sql.connector.catalog.{CatalogV2Util, Column, Identifier, StagedTable, StagingTableCatalog, Table, TableCatalog} import org.apache.spark.sql.connector.expressions.Transform import org.apache.spark.sql.errors.QueryCompilationErrors import org.apache.spark.sql.execution.SQLExecution @@ -99,17 +99,15 @@ case class AtomicReplaceTableExec( private def commitOrAbortStagedChanges(staged: StagedTable): Unit = { Utils.tryWithSafeFinallyAndFailureCallbacks({ - staged match { - case st: StagedTableWithCommitMetrics => - st.commitStagedChanges() + staged.commitStagedChanges() - for (taskMetric <- st.reportDriverMetrics) { - metrics.get(taskMetric.name()).foreach(_.set(taskMetric.value())) - } + if (catalog.supportedCustomMetrics().nonEmpty) { + for (taskMetric <- staged.reportDriverMetrics) { + metrics.get(taskMetric.name()).foreach(_.set(taskMetric.value())) + } - val executionId = sparkContext.getLocalProperty(SQLExecution.EXECUTION_ID_KEY) - SQLMetrics.postDriverMetricUpdates(sparkContext, executionId, metrics.values.toSeq) - case st: StagedTable => st.commitStagedChanges() + val executionId = sparkContext.getLocalProperty(SQLExecution.EXECUTION_ID_KEY) + SQLMetrics.postDriverMetricUpdates(sparkContext, executionId, metrics.values.toSeq) } })(catchBlock = { staged.abortStagedChanges() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala index 55a9ba3d2e302..d54332a47db09 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala @@ -28,7 +28,7 @@ import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.catalyst.plans.logical.{AppendData, LogicalPlan, TableSpec, UnaryNode} import org.apache.spark.sql.catalyst.util.{removeInternalMetadata, CharVarcharUtils, WriteDeltaProjections} import org.apache.spark.sql.catalyst.util.RowDeltaUtils.{DELETE_OPERATION, INSERT_OPERATION, UPDATE_OPERATION} -import org.apache.spark.sql.connector.catalog.{CatalogV2Util, Column, Identifier, StagedTable, StagedTableWithCommitMetrics, StagingTableCatalog, Table, TableCatalog, TableWritePrivilege} +import org.apache.spark.sql.connector.catalog.{CatalogV2Util, Column, Identifier, StagedTable, StagingTableCatalog, Table, TableCatalog, TableWritePrivilege} import org.apache.spark.sql.connector.expressions.Transform import org.apache.spark.sql.connector.metric.CustomMetric import 
org.apache.spark.sql.connector.write.{BatchWrite, DataWriter, DataWriterFactory, DeltaWrite, DeltaWriter, PhysicalWriteInfoImpl, Write, WriterCommitMessage} @@ -649,16 +649,19 @@ private[v2] trait V2CreateTableAsSelectBaseExec extends LeafV2CommandExec { qe.assertCommandExecuted() table match { - case st: StagedTableWithCommitMetrics => + case st: StagedTable => st.commitStagedChanges() - for (taskMetric <- st.reportDriverMetrics) { - metrics.get(taskMetric.name()).foreach(_.set(taskMetric.value())) - } + catalog match { + case stagingTableCatalog: StagingTableCatalog + if stagingTableCatalog.supportedCustomMetrics().nonEmpty => + for (taskMetric <- st.reportDriverMetrics) { + metrics.get(taskMetric.name()).foreach(_.set(taskMetric.value())) + } - val executionId = sparkContext.getLocalProperty(SQLExecution.EXECUTION_ID_KEY) - SQLMetrics.postDriverMetricUpdates(sparkContext, executionId, metrics.values.toSeq) - case st: StagedTable => st.commitStagedChanges() + val executionId = sparkContext.getLocalProperty(SQLExecution.EXECUTION_ID_KEY) + SQLMetrics.postDriverMetricUpdates(sparkContext, executionId, metrics.values.toSeq) + } case _ => } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2MetricsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2MetricsSuite.scala index 9a9f29d77e073..4c54b63730e5d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2MetricsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2MetricsSuite.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.connector import java.util import org.apache.spark.sql.QueryTest.withPhysicalPlansCaptured -import org.apache.spark.sql.connector.catalog.{CatalogV2Util, Column, Identifier, InMemoryTable, InMemoryTableCatalog, StagedTable, StagedTableWithCommitMetrics, StagingInMemoryTableCatalog} +import org.apache.spark.sql.connector.catalog.{CatalogV2Util, Column, Identifier, InMemoryTable, InMemoryTableCatalog, StagedTable, StagingInMemoryTableCatalog} import org.apache.spark.sql.connector.catalog.CatalogV2Implicits.IdentifierHelper import org.apache.spark.sql.connector.expressions.Transform import org.apache.spark.sql.connector.metric.{CustomMetric, CustomSumMetric, CustomTaskMetric} @@ -39,7 +39,7 @@ class StagingInMemoryTableCatalogWithMetrics extends StagingInMemoryTableCatalog private class TestStagedTableWithMetric( ident: Identifier, delegateTable: InMemoryTable - ) extends TestStagedTable(ident, delegateTable) with StagedTableWithCommitMetrics { + ) extends TestStagedTable(ident, delegateTable) with StagedTable { private var stagedChangesCommitted = false From 4564caded87809bc3852d40d30cff83cbb46ff70 Mon Sep 17 00:00:00 2001 From: Ole Sasse Date: Mon, 25 Nov 2024 22:28:01 +0100 Subject: [PATCH 11/17] remove metrics for non atomic --- .../v2/WriteToDataSourceV2Exec.scala | 17 +++++++---------- 1 file changed, 7 insertions(+), 10 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala index d54332a47db09..2eed8cfa2440c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala @@ -75,8 +75,6 @@ case class CreateTableAsSelectExec( val properties = 
CatalogV2Util.convertTableProperties(tableSpec) - override val metrics: Map[String, SQLMetric] = commitMetrics(catalog) - override protected def run(): Seq[InternalRow] = { if (catalog.tableExists(ident)) { if (ifNotExists) { @@ -652,15 +650,14 @@ private[v2] trait V2CreateTableAsSelectBaseExec extends LeafV2CommandExec { case st: StagedTable => st.commitStagedChanges() - catalog match { - case stagingTableCatalog: StagingTableCatalog - if stagingTableCatalog.supportedCustomMetrics().nonEmpty => - for (taskMetric <- st.reportDriverMetrics) { - metrics.get(taskMetric.name()).foreach(_.set(taskMetric.value())) - } + val driverMetrics = st.reportDriverMetrics() + if (driverMetrics.nonEmpty) { + for (taskMetric <- driverMetrics) { + metrics.get(taskMetric.name()).foreach(_.set(taskMetric.value())) + } - val executionId = sparkContext.getLocalProperty(SQLExecution.EXECUTION_ID_KEY) - SQLMetrics.postDriverMetricUpdates(sparkContext, executionId, metrics.values.toSeq) + val executionId = sparkContext.getLocalProperty(SQLExecution.EXECUTION_ID_KEY) + SQLMetrics.postDriverMetricUpdates(sparkContext, executionId, metrics.values.toSeq) } case _ => } From 4462b48e2a20473cd30a956a34ba4ecf9e3f7d9b Mon Sep 17 00:00:00 2001 From: Ole Sasse Date: Tue, 26 Nov 2024 09:16:01 +0100 Subject: [PATCH 12/17] Used renamed withQueryExecutionsCaptured --- .../apache/spark/sql/connector/DataSourceV2MetricsSuite.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2MetricsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2MetricsSuite.scala index 4c54b63730e5d..fe28b85528632 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2MetricsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2MetricsSuite.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.connector import java.util -import org.apache.spark.sql.QueryTest.withPhysicalPlansCaptured +import org.apache.spark.sql.QueryTest.withQueryExecutionsCaptured import org.apache.spark.sql.connector.catalog.{CatalogV2Util, Column, Identifier, InMemoryTable, InMemoryTableCatalog, StagedTable, StagingInMemoryTableCatalog} import org.apache.spark.sql.connector.catalog.CatalogV2Implicits.IdentifierHelper import org.apache.spark.sql.connector.expressions.Transform @@ -98,7 +98,7 @@ class DataSourceV2MetricsSuite extends DatasourceV2SQLBase { private val existingTable = "existing_table" private def captureStagedTableWrite(thunk: => Unit): SparkPlan = { - val physicalPlans = withPhysicalPlansCaptured(spark, thunk) + val physicalPlans = withQueryExecutionsCaptured(spark)(thunk).map(_.executedPlan) val stagedTableWrites = physicalPlans.filter { case _: AtomicCreateTableAsSelectExec | _: CreateTableAsSelectExec | _: AtomicReplaceTableAsSelectExec | _: ReplaceTableAsSelectExec | From e1eaf8a41109e205fc06da0671e73d35e07e014e Mon Sep 17 00:00:00 2001 From: Ole Sasse Date: Tue, 26 Nov 2024 19:38:41 +0100 Subject: [PATCH 13/17] only add commitMetrics when we have a staged table --- .../datasources/v2/WriteToDataSourceV2Exec.scala | 14 ++++---------- 1 file changed, 4 insertions(+), 10 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala index 2eed8cfa2440c..4c1a8fdb27671 100644 --- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala @@ -150,8 +150,6 @@ case class ReplaceTableAsSelectExec( val properties = CatalogV2Util.convertTableProperties(tableSpec) - override val metrics: Map[String, SQLMetric] = commitMetrics(catalog) - override protected def run(): Seq[InternalRow] = { // Note that this operation is potentially unsafe, but these are the strict semantics of // RTAS if the catalog does not support atomic operations. @@ -618,14 +616,10 @@ case class DeltaWithMetadataWritingSparkTask( private[v2] trait V2CreateTableAsSelectBaseExec extends LeafV2CommandExec { override def output: Seq[Attribute] = Nil - protected def commitMetrics(tableCatalog: TableCatalog): Map[String, SQLMetric] = { - tableCatalog match { - case st: StagingTableCatalog => - st.supportedCustomMetrics().map { - metric => metric.name() -> SQLMetrics.createV2CustomMetric(sparkContext, metric) - }.toMap - case _ => Map.empty - } + protected def commitMetrics(tableCatalog: StagingTableCatalog): Map[String, SQLMetric] = { + tableCatalog.supportedCustomMetrics().map { + metric => metric.name() -> SQLMetrics.createV2CustomMetric(sparkContext, metric) + }.toMap } protected def getV2Columns(schema: StructType, forceNullable: Boolean): Array[Column] = { From c8019fc0254036f5153ea1907a34a748abf2f86f Mon Sep 17 00:00:00 2001 From: Ole Sasse Date: Wed, 27 Nov 2024 09:18:49 +0100 Subject: [PATCH 14/17] Move more code to shared classes --- .../datasources/v2/DataSourceV2Utils.scala | 33 ++++++++++- .../datasources/v2/ReplaceTableExec.scala | 23 +------- .../v2/WriteToDataSourceV2Exec.scala | 56 ++++++------------- 3 files changed, 52 insertions(+), 60 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Utils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Utils.scala index 9ffa0d728ca28..6c9177a1fd40e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Utils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Utils.scala @@ -23,19 +23,23 @@ import scala.jdk.CollectionConverters._ import com.fasterxml.jackson.databind.ObjectMapper +import org.apache.spark.SparkContext import org.apache.spark.internal.Logging import org.apache.spark.sql.{DataFrame, Dataset, SparkSession} import org.apache.spark.sql.catalyst.analysis.TimeTravelSpec import org.apache.spark.sql.catalyst.expressions.Literal import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap -import org.apache.spark.sql.connector.catalog.{CatalogV2Util, SessionConfigSupport, SupportsCatalogOptions, SupportsRead, Table, TableProvider} +import org.apache.spark.sql.connector.catalog.{CatalogV2Util, SessionConfigSupport, StagedTable, SupportsCatalogOptions, SupportsRead, Table, TableProvider} import org.apache.spark.sql.connector.catalog.TableCapability.BATCH_READ import org.apache.spark.sql.errors.QueryExecutionErrors +import org.apache.spark.sql.execution.SQLExecution import org.apache.spark.sql.execution.command.DDLUtils import org.apache.spark.sql.execution.datasources.DataSource +import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.{LongType, StructType} import org.apache.spark.sql.util.CaseInsensitiveStringMap +import 
org.apache.spark.util.Utils private[sql] object DataSourceV2Utils extends Logging { @@ -179,4 +183,31 @@ private[sql] object DataSourceV2Utils extends Logging { extraOptions + ("paths" -> objectMapper.writeValueAsString(paths.toArray)) } } + + /** + * If `table` is a StagedTable, commit the staged changes and report the commit metrics. + * Abort the changes in case of an error. Do nothing if the table is not a StagedTable. + */ + def commitOrAbortStagedChanges( + sparkContext: SparkContext, table: Table, metrics: Map[String, SQLMetric]): Unit = { + table match { + case stagedTable: StagedTable => + Utils.tryWithSafeFinallyAndFailureCallbacks({ + stagedTable.commitStagedChanges() + + val driverMetrics = stagedTable.reportDriverMetrics() + if (driverMetrics.nonEmpty) { + for (taskMetric <- driverMetrics) { + metrics.get(taskMetric.name()).foreach(_.set(taskMetric.value())) + } + + val executionId = sparkContext.getLocalProperty(SQLExecution.EXECUTION_ID_KEY) + SQLMetrics.postDriverMetricUpdates(sparkContext, executionId, metrics.values.toSeq) + } + })(catchBlock = { + stagedTable.abortStagedChanges() + }) + case _ => + } + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/ReplaceTableExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/ReplaceTableExec.scala index 8e28bda5feb7a..07dfee56305b6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/ReplaceTableExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/ReplaceTableExec.scala @@ -23,12 +23,10 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.NoSuchTableException import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.catalyst.plans.logical.TableSpec -import org.apache.spark.sql.connector.catalog.{CatalogV2Util, Column, Identifier, StagedTable, StagingTableCatalog, Table, TableCatalog} +import org.apache.spark.sql.connector.catalog.{CatalogV2Util, Column, Identifier, StagingTableCatalog, Table, TableCatalog} import org.apache.spark.sql.connector.expressions.Transform import org.apache.spark.sql.errors.QueryCompilationErrors -import org.apache.spark.sql.execution.SQLExecution import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics} -import org.apache.spark.util.Utils case class ReplaceTableExec( catalog: TableCatalog, @@ -91,26 +89,9 @@ case class AtomicReplaceTableExec( } else { throw QueryCompilationErrors.cannotReplaceMissingTableError(identifier) } - commitOrAbortStagedChanges(staged) + DataSourceV2Utils.commitOrAbortStagedChanges(sparkContext, staged, metrics) Seq.empty } override def output: Seq[Attribute] = Seq.empty - - private def commitOrAbortStagedChanges(staged: StagedTable): Unit = { - Utils.tryWithSafeFinallyAndFailureCallbacks({ - staged.commitStagedChanges() - - if (catalog.supportedCustomMetrics().nonEmpty) { - for (taskMetric <- staged.reportDriverMetrics) { - metrics.get(taskMetric.name()).foreach(_.set(taskMetric.value())) - } - - val executionId = sparkContext.getLocalProperty(SQLExecution.EXECUTION_ID_KEY) - SQLMetrics.postDriverMetricUpdates(sparkContext, executionId, metrics.values.toSeq) - } - })(catchBlock = { - staged.abortStagedChanges() - }) - } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala index 4c1a8fdb27671..a8c34ff553855 
100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala @@ -28,7 +28,7 @@ import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.catalyst.plans.logical.{AppendData, LogicalPlan, TableSpec, UnaryNode} import org.apache.spark.sql.catalyst.util.{removeInternalMetadata, CharVarcharUtils, WriteDeltaProjections} import org.apache.spark.sql.catalyst.util.RowDeltaUtils.{DELETE_OPERATION, INSERT_OPERATION, UPDATE_OPERATION} -import org.apache.spark.sql.connector.catalog.{CatalogV2Util, Column, Identifier, StagedTable, StagingTableCatalog, Table, TableCatalog, TableWritePrivilege} +import org.apache.spark.sql.connector.catalog.{CatalogV2Util, Column, Identifier, StagingTableCatalog, Table, TableCatalog, TableWritePrivilege} import org.apache.spark.sql.connector.expressions.Transform import org.apache.spark.sql.connector.metric.CustomMetric import org.apache.spark.sql.connector.write.{BatchWrite, DataWriter, DataWriterFactory, DeltaWrite, DeltaWriter, PhysicalWriteInfoImpl, Write, WriterCommitMessage} @@ -110,8 +110,6 @@ case class AtomicCreateTableAsSelectExec( val properties = CatalogV2Util.convertTableProperties(tableSpec) - override val metrics: Map[String, SQLMetric] = commitMetrics(catalog) - override protected def run(): Seq[InternalRow] = { if (catalog.tableExists(ident)) { if (ifNotExists) { @@ -199,8 +197,6 @@ case class AtomicReplaceTableAsSelectExec( val properties = CatalogV2Util.convertTableProperties(tableSpec) - override val metrics: Map[String, SQLMetric] = commitMetrics(catalog) - override protected def run(): Seq[InternalRow] = { val columns = getV2Columns(query.schema, catalog.useNullableQuerySchema) if (catalog.tableExists(ident)) { @@ -616,10 +612,16 @@ case class DeltaWithMetadataWritingSparkTask( private[v2] trait V2CreateTableAsSelectBaseExec extends LeafV2CommandExec { override def output: Seq[Attribute] = Nil - protected def commitMetrics(tableCatalog: StagingTableCatalog): Map[String, SQLMetric] = { - tableCatalog.supportedCustomMetrics().map { - metric => metric.name() -> SQLMetrics.createV2CustomMetric(sparkContext, metric) - }.toMap + def catalog: TableCatalog + + override val metrics: Map[String, SQLMetric] = { + catalog match { + case stagingCatalog: StagingTableCatalog => + stagingCatalog.supportedCustomMetrics().map { + metric => metric.name() -> SQLMetrics.createV2CustomMetric(sparkContext, metric) + }.toMap + case _ => Map.empty + } } protected def getV2Columns(schema: StructType, forceNullable: Boolean): Array[Column] = { @@ -634,36 +636,14 @@ private[v2] trait V2CreateTableAsSelectBaseExec extends LeafV2CommandExec { writeOptions: Map[String, String], ident: Identifier, query: LogicalPlan): Seq[InternalRow] = { - Utils.tryWithSafeFinallyAndFailureCallbacks({ - val relation = DataSourceV2Relation.create(table, Some(catalog), Some(ident)) - val append = AppendData.byPosition(relation, query, writeOptions) - val qe = session.sessionState.executePlan(append) - qe.assertCommandExecuted() - - table match { - case st: StagedTable => - st.commitStagedChanges() - - val driverMetrics = st.reportDriverMetrics() - if (driverMetrics.nonEmpty) { - for (taskMetric <- driverMetrics) { - metrics.get(taskMetric.name()).foreach(_.set(taskMetric.value())) - } - - val executionId = sparkContext.getLocalProperty(SQLExecution.EXECUTION_ID_KEY) - SQLMetrics.postDriverMetricUpdates(sparkContext, 
executionId, metrics.values.toSeq) - } - case _ => - } + val relation = DataSourceV2Relation.create(table, Some(catalog), Some(ident)) + val append = AppendData.byPosition(relation, query, writeOptions) + val qe = session.sessionState.executePlan(append) + qe.assertCommandExecuted() - Nil - })(catchBlock = { - table match { - // Failure rolls back the staged writes and metadata changes. - case st: StagedTable => st.abortStagedChanges() - case _ => catalog.dropTable(ident) - } - }) + DataSourceV2Utils.commitOrAbortStagedChanges(sparkContext, table, metrics) + + Nil } } From 9f8e98243e908441b31643aedc3deb8d1776ce10 Mon Sep 17 00:00:00 2001 From: Ole Sasse Date: Wed, 27 Nov 2024 11:44:10 +0100 Subject: [PATCH 15/17] Move error handling out of DataSourceV2Utils --- .../datasources/v2/DataSourceV2Utils.scala | 27 +++++-------- .../datasources/v2/ReplaceTableExec.scala | 13 +++++- .../v2/WriteToDataSourceV2Exec.scala | 40 +++++++++++-------- 3 files changed, 45 insertions(+), 35 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Utils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Utils.scala index 6c9177a1fd40e..f8b878fac0a33 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Utils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Utils.scala @@ -39,7 +39,6 @@ import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.{LongType, StructType} import org.apache.spark.sql.util.CaseInsensitiveStringMap -import org.apache.spark.util.Utils private[sql] object DataSourceV2Utils extends Logging { @@ -186,27 +185,23 @@ private[sql] object DataSourceV2Utils extends Logging { /** * If `table` is a StagedTable, commit the staged changes and report the commit metrics. - * Abort the changes in case of an error. Do nothing if the table is not a StagedTable. + * Do nothing if the table is not a StagedTable. 
*/ - def commitOrAbortStagedChanges( + def commitStagedChanges( sparkContext: SparkContext, table: Table, metrics: Map[String, SQLMetric]): Unit = { table match { case stagedTable: StagedTable => - Utils.tryWithSafeFinallyAndFailureCallbacks({ - stagedTable.commitStagedChanges() + stagedTable.commitStagedChanges() - val driverMetrics = stagedTable.reportDriverMetrics() - if (driverMetrics.nonEmpty) { - for (taskMetric <- driverMetrics) { - metrics.get(taskMetric.name()).foreach(_.set(taskMetric.value())) - } - - val executionId = sparkContext.getLocalProperty(SQLExecution.EXECUTION_ID_KEY) - SQLMetrics.postDriverMetricUpdates(sparkContext, executionId, metrics.values.toSeq) + val driverMetrics = stagedTable.reportDriverMetrics() + if (driverMetrics.nonEmpty) { + for (taskMetric <- driverMetrics) { + metrics.get(taskMetric.name()).foreach(_.set(taskMetric.value())) } - })(catchBlock = { - stagedTable.abortStagedChanges() - }) + + val executionId = sparkContext.getLocalProperty(SQLExecution.EXECUTION_ID_KEY) + SQLMetrics.postDriverMetricUpdates(sparkContext, executionId, metrics.values.toSeq) + } case _ => } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/ReplaceTableExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/ReplaceTableExec.scala index 07dfee56305b6..ca95641235316 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/ReplaceTableExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/ReplaceTableExec.scala @@ -23,10 +23,11 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.NoSuchTableException import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.catalyst.plans.logical.TableSpec -import org.apache.spark.sql.connector.catalog.{CatalogV2Util, Column, Identifier, StagingTableCatalog, Table, TableCatalog} +import org.apache.spark.sql.connector.catalog.{CatalogV2Util, Column, Identifier, StagedTable, StagingTableCatalog, Table, TableCatalog} import org.apache.spark.sql.connector.expressions.Transform import org.apache.spark.sql.errors.QueryCompilationErrors import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics} +import org.apache.spark.util.Utils case class ReplaceTableExec( catalog: TableCatalog, @@ -89,9 +90,17 @@ case class AtomicReplaceTableExec( } else { throw QueryCompilationErrors.cannotReplaceMissingTableError(identifier) } - DataSourceV2Utils.commitOrAbortStagedChanges(sparkContext, staged, metrics) + commitOrAbortStagedChanges(staged) Seq.empty } override def output: Seq[Attribute] = Seq.empty + + private def commitOrAbortStagedChanges(staged: StagedTable): Unit = { + Utils.tryWithSafeFinallyAndFailureCallbacks({ + DataSourceV2Utils.commitStagedChanges(sparkContext, staged, metrics) + })(catchBlock = { + staged.abortStagedChanges() + }) + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala index a8c34ff553855..b0c636c1e51ab 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala @@ -28,7 +28,7 @@ import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.catalyst.plans.logical.{AppendData, LogicalPlan, 
TableSpec, UnaryNode} import org.apache.spark.sql.catalyst.util.{removeInternalMetadata, CharVarcharUtils, WriteDeltaProjections} import org.apache.spark.sql.catalyst.util.RowDeltaUtils.{DELETE_OPERATION, INSERT_OPERATION, UPDATE_OPERATION} -import org.apache.spark.sql.connector.catalog.{CatalogV2Util, Column, Identifier, StagingTableCatalog, Table, TableCatalog, TableWritePrivilege} +import org.apache.spark.sql.connector.catalog.{CatalogV2Util, Column, Identifier, StagedTable, StagingTableCatalog, Table, TableCatalog, TableWritePrivilege} import org.apache.spark.sql.connector.expressions.Transform import org.apache.spark.sql.connector.metric.CustomMetric import org.apache.spark.sql.connector.write.{BatchWrite, DataWriter, DataWriterFactory, DeltaWrite, DeltaWriter, PhysicalWriteInfoImpl, Write, WriterCommitMessage} @@ -110,6 +110,8 @@ case class AtomicCreateTableAsSelectExec( val properties = CatalogV2Util.convertTableProperties(tableSpec) + override val metrics: Map[String, SQLMetric] = commitMetrics(catalog) + override protected def run(): Seq[InternalRow] = { if (catalog.tableExists(ident)) { if (ifNotExists) { @@ -197,6 +199,8 @@ case class AtomicReplaceTableAsSelectExec( val properties = CatalogV2Util.convertTableProperties(tableSpec) + override val metrics: Map[String, SQLMetric] = commitMetrics(catalog) + override protected def run(): Seq[InternalRow] = { val columns = getV2Columns(query.schema, catalog.useNullableQuerySchema) if (catalog.tableExists(ident)) { @@ -612,16 +616,10 @@ case class DeltaWithMetadataWritingSparkTask( private[v2] trait V2CreateTableAsSelectBaseExec extends LeafV2CommandExec { override def output: Seq[Attribute] = Nil - def catalog: TableCatalog - - override val metrics: Map[String, SQLMetric] = { - catalog match { - case stagingCatalog: StagingTableCatalog => - stagingCatalog.supportedCustomMetrics().map { - metric => metric.name() -> SQLMetrics.createV2CustomMetric(sparkContext, metric) - }.toMap - case _ => Map.empty - } + protected def commitMetrics(tableCatalog: StagingTableCatalog): Map[String, SQLMetric] = { + tableCatalog.supportedCustomMetrics().map { + metric => metric.name() -> SQLMetrics.createV2CustomMetric(sparkContext, metric) + }.toMap } protected def getV2Columns(schema: StructType, forceNullable: Boolean): Array[Column] = { @@ -636,14 +634,22 @@ private[v2] trait V2CreateTableAsSelectBaseExec extends LeafV2CommandExec { writeOptions: Map[String, String], ident: Identifier, query: LogicalPlan): Seq[InternalRow] = { - val relation = DataSourceV2Relation.create(table, Some(catalog), Some(ident)) - val append = AppendData.byPosition(relation, query, writeOptions) - val qe = session.sessionState.executePlan(append) - qe.assertCommandExecuted() + Utils.tryWithSafeFinallyAndFailureCallbacks({ + val relation = DataSourceV2Relation.create(table, Some(catalog), Some(ident)) + val append = AppendData.byPosition(relation, query, writeOptions) + val qe = session.sessionState.executePlan(append) + qe.assertCommandExecuted() - DataSourceV2Utils.commitOrAbortStagedChanges(sparkContext, table, metrics) + DataSourceV2Utils.commitStagedChanges(sparkContext, table, metrics) - Nil + Nil + })(catchBlock = { + table match { + // Failure rolls back the staged writes and metadata changes. 
+ case st: StagedTable => st.abortStagedChanges() + case _ => catalog.dropTable(ident) + } + }) } } From 46d098a8840081d2d1c50ccba006418a1fc65a1f Mon Sep 17 00:00:00 2001 From: Ole Sasse Date: Wed, 27 Nov 2024 17:00:10 +0100 Subject: [PATCH 16/17] Move commitMetrics to DataSouceV2Utils --- .../execution/datasources/v2/DataSourceV2Utils.scala | 10 ++++++++-- .../execution/datasources/v2/ReplaceTableExec.scala | 6 ++---- .../datasources/v2/WriteToDataSourceV2Exec.scala | 12 ++++-------- 3 files changed, 14 insertions(+), 14 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Utils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Utils.scala index f8b878fac0a33..c474c51666006 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Utils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Utils.scala @@ -22,14 +22,13 @@ import java.util.regex.Pattern import scala.jdk.CollectionConverters._ import com.fasterxml.jackson.databind.ObjectMapper - import org.apache.spark.SparkContext import org.apache.spark.internal.Logging import org.apache.spark.sql.{DataFrame, Dataset, SparkSession} import org.apache.spark.sql.catalyst.analysis.TimeTravelSpec import org.apache.spark.sql.catalyst.expressions.Literal import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap -import org.apache.spark.sql.connector.catalog.{CatalogV2Util, SessionConfigSupport, StagedTable, SupportsCatalogOptions, SupportsRead, Table, TableProvider} +import org.apache.spark.sql.connector.catalog.{CatalogV2Util, SessionConfigSupport, StagedTable, StagingTableCatalog, SupportsCatalogOptions, SupportsRead, Table, TableProvider} import org.apache.spark.sql.connector.catalog.TableCapability.BATCH_READ import org.apache.spark.sql.errors.QueryExecutionErrors import org.apache.spark.sql.execution.SQLExecution @@ -205,4 +204,11 @@ private[sql] object DataSourceV2Utils extends Logging { case _ => } } + + def commitMetrics( + sparkContext: SparkContext, tableCatalog: StagingTableCatalog): Map[String, SQLMetric] = { + tableCatalog.supportedCustomMetrics().map { + metric => metric.name() -> SQLMetrics.createV2CustomMetric(sparkContext, metric) + }.toMap + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/ReplaceTableExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/ReplaceTableExec.scala index ca95641235316..894a3a10d4193 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/ReplaceTableExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/ReplaceTableExec.scala @@ -26,7 +26,7 @@ import org.apache.spark.sql.catalyst.plans.logical.TableSpec import org.apache.spark.sql.connector.catalog.{CatalogV2Util, Column, Identifier, StagedTable, StagingTableCatalog, Table, TableCatalog} import org.apache.spark.sql.connector.expressions.Transform import org.apache.spark.sql.errors.QueryCompilationErrors -import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics} +import org.apache.spark.sql.execution.metric.SQLMetric import org.apache.spark.util.Utils case class ReplaceTableExec( @@ -67,9 +67,7 @@ case class AtomicReplaceTableExec( val tableProperties = CatalogV2Util.convertTableProperties(tableSpec) override val metrics: Map[String, SQLMetric] = - catalog.supportedCustomMetrics().map { metric => - metric.name() -> 
SQLMetrics.createV2CustomMetric(sparkContext, metric) - }.toMap + DataSourceV2Utils.commitMetrics(sparkContext, catalog) override protected def run(): Seq[InternalRow] = { if (catalog.tableExists(identifier)) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala index b0c636c1e51ab..bdcf7b8260a7a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala @@ -110,7 +110,8 @@ case class AtomicCreateTableAsSelectExec( val properties = CatalogV2Util.convertTableProperties(tableSpec) - override val metrics: Map[String, SQLMetric] = commitMetrics(catalog) + override val metrics: Map[String, SQLMetric] = + DataSourceV2Utils.commitMetrics(sparkContext, catalog) override protected def run(): Seq[InternalRow] = { if (catalog.tableExists(ident)) { @@ -199,7 +200,8 @@ case class AtomicReplaceTableAsSelectExec( val properties = CatalogV2Util.convertTableProperties(tableSpec) - override val metrics: Map[String, SQLMetric] = commitMetrics(catalog) + override val metrics: Map[String, SQLMetric] = + DataSourceV2Utils.commitMetrics(sparkContext, catalog) override protected def run(): Seq[InternalRow] = { val columns = getV2Columns(query.schema, catalog.useNullableQuerySchema) @@ -616,12 +618,6 @@ case class DeltaWithMetadataWritingSparkTask( private[v2] trait V2CreateTableAsSelectBaseExec extends LeafV2CommandExec { override def output: Seq[Attribute] = Nil - protected def commitMetrics(tableCatalog: StagingTableCatalog): Map[String, SQLMetric] = { - tableCatalog.supportedCustomMetrics().map { - metric => metric.name() -> SQLMetrics.createV2CustomMetric(sparkContext, metric) - }.toMap - } - protected def getV2Columns(schema: StructType, forceNullable: Boolean): Array[Column] = { val rawSchema = CharVarcharUtils.getRawSchema(removeInternalMetadata(schema), conf) val tableSchema = if (forceNullable) rawSchema.asNullable else rawSchema From 41a6fafc035f4cca321c076e45064e20b05181a6 Mon Sep 17 00:00:00 2001 From: Ole Sasse Date: Wed, 27 Nov 2024 18:48:39 +0100 Subject: [PATCH 17/17] fix scalastyle --- .../spark/sql/execution/datasources/v2/DataSourceV2Utils.scala | 1 + 1 file changed, 1 insertion(+) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Utils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Utils.scala index c474c51666006..9c19609dce79a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Utils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Utils.scala @@ -22,6 +22,7 @@ import java.util.regex.Pattern import scala.jdk.CollectionConverters._ import com.fasterxml.jackson.databind.ObjectMapper + import org.apache.spark.SparkContext import org.apache.spark.internal.Logging import org.apache.spark.sql.{DataFrame, Dataset, SparkSession}
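Editor's illustration (not taken from the patches): end to end, the commit metrics surface as SQL metrics on the atomic write nodes. A hedged usage sketch in the spirit of DataSourceV2MetricsSuite's captureStagedTableWrite helper, relying on test-scope classes; the catalog name, namespace and table name are placeholders:

    import org.apache.spark.sql.SparkSession
    import org.apache.spark.sql.QueryTest.withQueryExecutionsCaptured
    import org.apache.spark.sql.connector.StagingInMemoryTableCatalogWithMetrics
    import org.apache.spark.sql.execution.datasources.v2.AtomicCreateTableAsSelectExec

    val spark = SparkSession.builder().master("local[1]").getOrCreate()
    spark.conf.set("spark.sql.catalog.teststaging",
      classOf[StagingInMemoryTableCatalogWithMetrics].getName)

    // An atomic CTAS commits through a StagedTable; after commitStagedChanges()
    // the plan node's "numFiles", "numOutputRows" and "numOutputBytes" SQL metrics
    // are filled from StagedTable.reportDriverMetrics().
    val plans = withQueryExecutionsCaptured(spark) {
      spark.sql("CREATE TABLE teststaging.ns.tbl AS SELECT 1 AS id")
    }.map(_.executedPlan)

    val ctas = plans.collectFirst { case p: AtomicCreateTableAsSelectExec => p }
    ctas.foreach(p => println(p.metrics("numFiles").value))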