diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/StagedTable.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/StagedTable.java
index 60b250adb41ef..cbaea8cad8582 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/StagedTable.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/StagedTable.java
@@ -21,7 +21,9 @@
 
 import org.apache.spark.annotation.Evolving;
 import org.apache.spark.sql.connector.expressions.Transform;
+import org.apache.spark.sql.connector.metric.CustomTaskMetric;
 import org.apache.spark.sql.connector.write.LogicalWriteInfo;
+import org.apache.spark.sql.connector.write.Write;
 import org.apache.spark.sql.types.StructType;
 
 /**
@@ -52,4 +54,16 @@ public interface StagedTable extends Table {
    * table's writers.
    */
  void abortStagedChanges();
+
+  /**
+   * Retrieve driver metrics after a commit. This is analogous
+   * to {@link Write#reportDriverMetrics()}. Note that these metrics must be included in the
+   * supported custom metrics reported by `supportedCustomMetrics` of the
+   * {@link StagingTableCatalog} that returned the staged table.
+   *
+   * @return an Array of commit metric values. Throws if the table has not been committed yet.
+   */
+  default CustomTaskMetric[] reportDriverMetrics() throws RuntimeException {
+    return new CustomTaskMetric[0];
+  }
 }
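For connector authors, the contract above is easiest to see in code. Below is a minimal sketch (not part of this patch) of a `StagedTable` that records a statistic while committing and hands it back through `reportDriverMetrics()`. `MyStagedTable`, `doCommit()` and the `numFiles` metric name are hypothetical; only `StagedTable` and `CustomTaskMetric` are the real Spark interfaces.

```scala
import org.apache.spark.sql.connector.catalog.StagedTable
import org.apache.spark.sql.connector.metric.CustomTaskMetric

// Simple value holder implementing the CustomTaskMetric interface.
case class CommitMetricValue(name: String, value: Long) extends CustomTaskMetric

abstract class MyStagedTable extends StagedTable {
  // Filled in while the staged changes are committed.
  @volatile private var numCommittedFiles: Option[Long] = None

  override def commitStagedChanges(): Unit = {
    // doCommit() is a placeholder for the connector's atomic commit; here it is assumed
    // to return the number of files it wrote.
    numCommittedFiles = Some(doCommit())
  }

  override def reportDriverMetrics(): Array[CustomTaskMetric] = numCommittedFiles match {
    case Some(files) => Array(CommitMetricValue("numFiles", files))
    case None => throw new IllegalStateException("Staged changes have not been committed yet")
  }

  // Hypothetical connector-specific commit returning the number of files written.
  protected def doCommit(): Long
}
```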
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/StagingTableCatalog.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/StagingTableCatalog.java
index eead1ade40791..f457a4a3d7863 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/StagingTableCatalog.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/StagingTableCatalog.java
@@ -21,11 +21,13 @@
 
 import org.apache.spark.annotation.Evolving;
 import org.apache.spark.sql.connector.expressions.Transform;
+import org.apache.spark.sql.connector.metric.CustomMetric;
 import org.apache.spark.sql.connector.write.LogicalWriteInfo;
 import org.apache.spark.sql.catalyst.analysis.NoSuchNamespaceException;
 import org.apache.spark.sql.catalyst.analysis.NoSuchTableException;
 import org.apache.spark.sql.catalyst.analysis.TableAlreadyExistsException;
 import org.apache.spark.sql.connector.write.BatchWrite;
+import org.apache.spark.sql.connector.write.Write;
 import org.apache.spark.sql.connector.write.WriterCommitMessage;
 import org.apache.spark.sql.errors.QueryCompilationErrors;
 import org.apache.spark.sql.types.StructType;
@@ -200,4 +202,14 @@ default StagedTable stageCreateOrReplace(
     return stageCreateOrReplace(
       ident, CatalogV2Util.v2ColumnsToStructType(columns), partitions, properties);
   }
+
+  /**
+   * @return An Array of commit metrics that are supported by the catalog. This is analogous to
+   *   {@link Write#supportedCustomMetrics()}. The corresponding
+   *   {@link StagedTable#reportDriverMetrics()} method must be called to
+   *   retrieve the actual metric values after a commit. The methods are not in the same class
+   *   because the supported metrics are required before the staged table object is created
+   *   and only the staged table object can capture the write metrics during the commit.
+   */
+  default CustomMetric[] supportedCustomMetrics() { return new CustomMetric[0]; }
 }
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/StagingInMemoryTableCatalog.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/StagingInMemoryTableCatalog.scala
index f3c7bc98cec09..2a207901b83f5 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/StagingInMemoryTableCatalog.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/StagingInMemoryTableCatalog.scala
@@ -78,7 +78,7 @@ class StagingInMemoryTableCatalog extends InMemoryTableCatalog with StagingTable
     maybeSimulateFailedTableCreation(properties)
   }
 
-  private abstract class TestStagedTable(
+  protected abstract class TestStagedTable(
       ident: Identifier,
       delegateTable: InMemoryTable)
     extends StagedTable with SupportsWrite with SupportsRead {
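The catalog side of the contract is symmetric: the `StagingTableCatalog` declares the metric names up front so Spark can register `SQLMetric`s before execution, and those names must match what its staged tables later return from `reportDriverMetrics()`. A minimal sketch, continuing the hypothetical connector from the previous example (`CommitMetric` and `MyStagingCatalog` are illustrative; `CustomSumMetric` is Spark's built-in helper that sums the reported values):

```scala
import org.apache.spark.sql.connector.catalog.StagingTableCatalog
import org.apache.spark.sql.connector.metric.{CustomMetric, CustomSumMetric}

// Declares a commit metric by name and description; aggregation comes from CustomSumMetric.
case class CommitMetric(name: String, description: String) extends CustomSumMetric

trait MyStagingCatalog extends StagingTableCatalog {
  // The "numFiles" name matches the CustomTaskMetric reported by MyStagedTable above.
  override def supportedCustomMetrics(): Array[CustomMetric] =
    Array(CommitMetric("numFiles", "number of files committed"))
}
```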
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Utils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Utils.scala
index 9ffa0d728ca28..9c19609dce79a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Utils.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Utils.scala
@@ -23,16 +23,19 @@ import scala.jdk.CollectionConverters._
 
 import com.fasterxml.jackson.databind.ObjectMapper
 
+import org.apache.spark.SparkContext
 import org.apache.spark.internal.Logging
 import org.apache.spark.sql.{DataFrame, Dataset, SparkSession}
 import org.apache.spark.sql.catalyst.analysis.TimeTravelSpec
 import org.apache.spark.sql.catalyst.expressions.Literal
 import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap
-import org.apache.spark.sql.connector.catalog.{CatalogV2Util, SessionConfigSupport, SupportsCatalogOptions, SupportsRead, Table, TableProvider}
+import org.apache.spark.sql.connector.catalog.{CatalogV2Util, SessionConfigSupport, StagedTable, StagingTableCatalog, SupportsCatalogOptions, SupportsRead, Table, TableProvider}
 import org.apache.spark.sql.connector.catalog.TableCapability.BATCH_READ
 import org.apache.spark.sql.errors.QueryExecutionErrors
+import org.apache.spark.sql.execution.SQLExecution
 import org.apache.spark.sql.execution.command.DDLUtils
 import org.apache.spark.sql.execution.datasources.DataSource
+import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics}
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types.{LongType, StructType}
 import org.apache.spark.sql.util.CaseInsensitiveStringMap
@@ -179,4 +182,34 @@ private[sql] object DataSourceV2Utils extends Logging {
       extraOptions + ("paths" -> objectMapper.writeValueAsString(paths.toArray))
     }
   }
+
+  /**
+   * If `table` is a StagedTable, commit the staged changes and report the commit metrics.
+   * Do nothing if the table is not a StagedTable.
+   */
+  def commitStagedChanges(
+      sparkContext: SparkContext, table: Table, metrics: Map[String, SQLMetric]): Unit = {
+    table match {
+      case stagedTable: StagedTable =>
+        stagedTable.commitStagedChanges()
+
+        val driverMetrics = stagedTable.reportDriverMetrics()
+        if (driverMetrics.nonEmpty) {
+          for (taskMetric <- driverMetrics) {
+            metrics.get(taskMetric.name()).foreach(_.set(taskMetric.value()))
+          }
+
+          val executionId = sparkContext.getLocalProperty(SQLExecution.EXECUTION_ID_KEY)
+          SQLMetrics.postDriverMetricUpdates(sparkContext, executionId, metrics.values.toSeq)
+        }
+      case _ =>
+    }
+  }
+
+  def commitMetrics(
+      sparkContext: SparkContext, tableCatalog: StagingTableCatalog): Map[String, SQLMetric] = {
+    tableCatalog.supportedCustomMetrics().map {
+      metric => metric.name() -> SQLMetrics.createV2CustomMetric(sparkContext, metric)
+    }.toMap
+  }
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/ReplaceTableExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/ReplaceTableExec.scala
index 104d8a706efb7..894a3a10d4193 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/ReplaceTableExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/ReplaceTableExec.scala
@@ -26,6 +26,7 @@ import org.apache.spark.sql.catalyst.plans.logical.TableSpec
 import org.apache.spark.sql.connector.catalog.{CatalogV2Util, Column, Identifier, StagedTable, StagingTableCatalog, Table, TableCatalog}
 import org.apache.spark.sql.connector.expressions.Transform
 import org.apache.spark.sql.errors.QueryCompilationErrors
+import org.apache.spark.sql.execution.metric.SQLMetric
 import org.apache.spark.util.Utils
 
 case class ReplaceTableExec(
@@ -65,6 +66,9 @@ case class AtomicReplaceTableExec(
 
   val tableProperties = CatalogV2Util.convertTableProperties(tableSpec)
 
+  override val metrics: Map[String, SQLMetric] =
+    DataSourceV2Utils.commitMetrics(sparkContext, catalog)
+
   override protected def run(): Seq[InternalRow] = {
     if (catalog.tableExists(identifier)) {
       val table = catalog.loadTable(identifier)
@@ -92,7 +96,7 @@ case class AtomicReplaceTableExec(
 
   private def commitOrAbortStagedChanges(staged: StagedTable): Unit = {
     Utils.tryWithSafeFinallyAndFailureCallbacks({
-      staged.commitStagedChanges()
+      DataSourceV2Utils.commitStagedChanges(sparkContext, staged, metrics)
     })(catchBlock = {
       staged.abortStagedChanges()
     })
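The two helpers are meant to be used together by the atomic write nodes: `commitMetrics` is called when the node is constructed so the metrics are registered with the plan, and `commitStagedChanges` is called at commit time to fill in and post the values. A schematic sketch of that wiring follows; `MyAtomicWriteExec` is hypothetical and, because `DataSourceV2Utils` is `private[sql]`, it would have to live inside Spark's own packages. The real call sites are `AtomicReplaceTableExec` above and the `Atomic*TableAsSelectExec` nodes in the next file.

```scala
package org.apache.spark.sql.execution.datasources.v2

import org.apache.spark.SparkContext
import org.apache.spark.sql.connector.catalog.{StagedTable, StagingTableCatalog}
import org.apache.spark.sql.execution.metric.SQLMetric

// Hypothetical node-like wrapper showing the call order used by the exec nodes in this patch.
class MyAtomicWriteExec(sparkContext: SparkContext, catalog: StagingTableCatalog) {

  // Registered up front from the catalog's declared metrics, so the SQL UI knows about them.
  val metrics: Map[String, SQLMetric] =
    DataSourceV2Utils.commitMetrics(sparkContext, catalog)

  def commit(staged: StagedTable): Unit = {
    // Commits the staged changes; if the table reports driver metrics, their values are copied
    // into `metrics` and posted for the current SQL execution.
    DataSourceV2Utils.commitStagedChanges(sparkContext, staged, metrics)
  }
}
```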
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala
index b238b0ce9760c..bdcf7b8260a7a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala
@@ -110,6 +110,9 @@ case class AtomicCreateTableAsSelectExec(
 
   val properties = CatalogV2Util.convertTableProperties(tableSpec)
 
+  override val metrics: Map[String, SQLMetric] =
+    DataSourceV2Utils.commitMetrics(sparkContext, catalog)
+
   override protected def run(): Seq[InternalRow] = {
     if (catalog.tableExists(ident)) {
       if (ifNotExists) {
@@ -197,6 +200,9 @@ case class AtomicReplaceTableAsSelectExec(
 
   val properties = CatalogV2Util.convertTableProperties(tableSpec)
 
+  override val metrics: Map[String, SQLMetric] =
+    DataSourceV2Utils.commitMetrics(sparkContext, catalog)
+
   override protected def run(): Seq[InternalRow] = {
     val columns = getV2Columns(query.schema, catalog.useNullableQuerySchema)
     if (catalog.tableExists(ident)) {
@@ -630,10 +636,7 @@ private[v2] trait V2CreateTableAsSelectBaseExec extends LeafV2CommandExec {
       val qe = session.sessionState.executePlan(append)
       qe.assertCommandExecuted()
 
-      table match {
-        case st: StagedTable => st.commitStagedChanges()
-        case _ =>
-      }
+      DataSourceV2Utils.commitStagedChanges(sparkContext, table, metrics)
 
       Nil
     })(catchBlock = {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2MetricsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2MetricsSuite.scala
new file mode 100644
index 0000000000000..fe28b85528632
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2MetricsSuite.scala
@@ -0,0 +1,159 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.connector
+
+import java.util
+
+import org.apache.spark.sql.QueryTest.withQueryExecutionsCaptured
+import org.apache.spark.sql.connector.catalog.{CatalogV2Util, Column, Identifier, InMemoryTable, InMemoryTableCatalog, StagedTable, StagingInMemoryTableCatalog}
+import org.apache.spark.sql.connector.catalog.CatalogV2Implicits.IdentifierHelper
+import org.apache.spark.sql.connector.expressions.Transform
+import org.apache.spark.sql.connector.metric.{CustomMetric, CustomSumMetric, CustomTaskMetric}
+import org.apache.spark.sql.execution.SparkPlan
+import org.apache.spark.sql.execution.datasources.v2.{AtomicCreateTableAsSelectExec, AtomicReplaceTableAsSelectExec, AtomicReplaceTableExec, CreateTableAsSelectExec, ReplaceTableAsSelectExec, ReplaceTableExec}
+
+class StagingInMemoryTableCatalogWithMetrics extends StagingInMemoryTableCatalog {
+
+  case class TestSupportedCommitMetric(name: String, description: String) extends CustomSumMetric
+
+  override def supportedCustomMetrics(): Array[CustomMetric] = Array(
+    TestSupportedCommitMetric("numFiles", "number of written files"),
+    TestSupportedCommitMetric("numOutputRows", "number of output rows"),
+    TestSupportedCommitMetric("numOutputBytes", "written output"))
+
+  private class TestStagedTableWithMetric(
+      ident: Identifier,
+      delegateTable: InMemoryTable
+  ) extends TestStagedTable(ident, delegateTable) with StagedTable {
+
+    private var stagedChangesCommitted = false
+
+    override def commitStagedChanges(): Unit = {
+      tables.put(ident, delegateTable)
+      stagedChangesCommitted = true
+    }
+
+    override def reportDriverMetrics: Array[CustomTaskMetric] = {
+      assert(stagedChangesCommitted)
+      StagingInMemoryTableCatalogWithMetrics.testMetrics
+    }
+  }
+
+  override def stageCreate(
+      ident: Identifier,
+      columns: Array[Column],
+      partitions: Array[Transform],
+      properties: util.Map[String, String]): StagedTable = {
+    new TestStagedTableWithMetric(
+      ident,
+      new InMemoryTable(s"$name.${ident.quoted}",
+        CatalogV2Util.v2ColumnsToStructType(columns), partitions, properties))
+  }
+
+  override def stageReplace(
+      ident: Identifier,
+      columns: Array[Column],
+      partitions: Array[Transform],
+      properties: util.Map[String, String]): StagedTable =
+    stageCreate(ident, columns, partitions, properties)
+
+  override def stageCreateOrReplace(
+      ident: Identifier,
+      columns: Array[Column],
+      partitions: Array[Transform],
+      properties: util.Map[String, String]): StagedTable =
+    stageCreate(ident, columns, partitions, properties)
+}
+
+object StagingInMemoryTableCatalogWithMetrics {
+
+  case class TestCustomTaskMetric(name: String, value: Long) extends CustomTaskMetric
+
+  val testMetrics: Array[CustomTaskMetric] = Array(
+    TestCustomTaskMetric("numFiles", 1337),
+    TestCustomTaskMetric("numOutputRows", 1338),
+    TestCustomTaskMetric("numOutputBytes", 1339))
+}
+
+class DataSourceV2MetricsSuite extends DatasourceV2SQLBase {
+
+  private val testCatalog = "test_catalog"
+  private val atomicTestCatalog = "atomic_test_catalog"
+  private val nonExistingTable = "non_existing_table"
+  private val existingTable = "existing_table"
+
+  private def captureStagedTableWrite(thunk: => Unit): SparkPlan = {
+    val physicalPlans = withQueryExecutionsCaptured(spark)(thunk).map(_.executedPlan)
+    val stagedTableWrites = physicalPlans.filter {
+      case _: AtomicCreateTableAsSelectExec | _: CreateTableAsSelectExec |
+           _: AtomicReplaceTableAsSelectExec | _: ReplaceTableAsSelectExec |
+           _: AtomicReplaceTableExec | _: ReplaceTableExec => true
+      case _ => false
+    }
+    assert(stagedTableWrites.size === 1)
+    stagedTableWrites.head
+  }
+
+  private def commands: Seq[String => Unit] = Seq(
+    { catalogName =>
+      sql(s"CREATE TABLE $catalogName.$nonExistingTable AS SELECT * FROM $existingTable") },
+    { catalogName =>
+      spark.table(existingTable).write.saveAsTable(s"$catalogName.$nonExistingTable") },
+    { catalogName =>
+      sql(s"CREATE OR REPLACE TABLE $catalogName.$nonExistingTable " +
+        s"AS SELECT * FROM $existingTable") },
+    { catalogName =>
+      sql(s"REPLACE TABLE $catalogName.$existingTable AS SELECT * FROM $existingTable") },
+    { catalogName =>
+      spark.table(existingTable)
+        .write.mode("overwrite").saveAsTable(s"$catalogName.$existingTable") },
+    { catalogName =>
+      sql(s"REPLACE TABLE $catalogName.$existingTable (id bigint, data string)") })
+
+  private def catalogCommitMetricsTest(
+      testName: String, catalogName: String)(testFunction: SparkPlan => Unit): Unit = {
+    commands.foreach { command =>
+      test(s"$testName - $command") {
+        registerCatalog(testCatalog, classOf[InMemoryTableCatalog])
+        registerCatalog(atomicTestCatalog, classOf[StagingInMemoryTableCatalogWithMetrics])
+        withTable(existingTable, s"$catalogName.$existingTable") {
+          sql(s"CREATE TABLE $existingTable (id bigint, data string)")
+          sql(s"CREATE TABLE $catalogName.$existingTable (id bigint, data string)")
+
+          testFunction(captureStagedTableWrite(command(catalogName)))
+        }
+      }
+    }
+  }
+
+  catalogCommitMetricsTest(
+    "No metrics in the plan if the catalog does not support them", testCatalog) { sparkPlan =>
+    val metrics = sparkPlan.metrics
+
+    assert(metrics.isEmpty)
+  }
+
+  catalogCommitMetricsTest(
+    "Plan metrics values are the values from the catalog", atomicTestCatalog) { sparkPlan =>
+    val metrics = sparkPlan.metrics
+
+    assert(metrics.size === StagingInMemoryTableCatalogWithMetrics.testMetrics.length)
+    StagingInMemoryTableCatalogWithMetrics.testMetrics.foreach(customTaskMetric =>
+      assert(metrics(customTaskMetric.name()).value === customTaskMetric.value()))
+  }
+}
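As an end-to-end usage sketch (not part of the patch), registering a catalog like the test one above and running an atomic CTAS is enough to see the commit metrics on the write node; the catalog property key follows Spark's standard `spark.sql.catalog.<name>` pattern and the values come from `reportDriverMetrics()`.

```scala
// Assumes a live SparkSession `spark` and the StagingInMemoryTableCatalogWithMetrics class
// above on the classpath; catalog and metric names are taken from the test suite.
spark.conf.set(
  "spark.sql.catalog.atomic_test_catalog",
  classOf[StagingInMemoryTableCatalogWithMetrics].getName)

spark.sql("CREATE TABLE atomic_test_catalog.t AS SELECT 1L AS id, 'a' AS data")

// The AtomicCreateTableAsSelectExec node for this query now exposes numFiles, numOutputRows
// and numOutputBytes with the values the staged table reported after its commit, and the same
// values appear under that node in the SQL tab of the UI.
```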