[SPARK-49211][SQL] V2 Catalog can also support built-in data sources
### What changes were proposed in this pull request?

V2 Catalog can also support built-in data sources.

### Why are the changes needed?

A V2 catalog can still support Spark's built-in data sources as long as it returns a `V1Table` and does not track partitions in the catalog. With this relaxation we do not need to require a V2 catalog to implement everything before it can use the built-in data sources (which would be a big chunk of work).
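
As a rough usage sketch (not part of this PR's diff): assume a custom catalog implementation, here the placeholder class name `com.example.MyV1ReturningCatalog`, whose `loadTable` returns a `V1Table` with `tracksPartitionsInCatalog = false`, similar to the `V2CatalogSupportBuiltinDataSource` test catalog added below. Registering it under a non-session catalog name and using a built-in file source against it would look like:

```scala
import org.apache.spark.sql.SparkSession

// Minimal sketch; `com.example.MyV1ReturningCatalog` is a hypothetical catalog whose
// loadTable returns a V1Table that does not track partitions in the catalog.
val spark = SparkSession.builder()
  .master("local[*]")
  .config("spark.sql.catalog.my_cat", "com.example.MyV1ReturningCatalog")
  .getOrCreate()

// With this change, the built-in parquet source handles the table even though
// `my_cat` is not the session catalog.
spark.sql("CREATE TABLE my_cat.default.t (name STRING) USING parquet LOCATION '/tmp/t'")
spark.sql("INSERT INTO my_cat.default.t VALUES ('Bob')")
spark.sql("SELECT * FROM my_cat.default.t").show()
```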

### Does this PR introduce _any_ user-facing change?

No

### How was this patch tested?

Unit tests.

### Was this patch authored or co-authored using generative AI tooling?

No

Closes #47723 from amaliujia/v2catalog_can_support_built_in_data_source.

Lead-authored-by: Rui Wang <[email protected]>
Co-authored-by: Wenchen Fan <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>
amaliujia and cloud-fan committed Aug 15, 2024
1 parent 0219d60 commit b9fbdf0
Showing 11 changed files with 168 additions and 37 deletions.
@@ -1246,7 +1246,14 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor
options: CaseInsensitiveStringMap,
isStreaming: Boolean): Option[LogicalPlan] = {
table.map {
case v1Table: V1Table if CatalogV2Util.isSessionCatalog(catalog) =>
// To use this code path to execute V1 commands, e.g. INSERT, the table
// must either come from the session catalog, or have
// tracksPartitionsInCatalog set to false so that it does not rely on its
// catalog to manage partitions. We cannot run a V1Table through the V1
// code path if the table is not from the session catalog and still
// requires its catalog to manage partitions.
case v1Table: V1Table if CatalogV2Util.isSessionCatalog(catalog)
|| !v1Table.catalogTable.tracksPartitionsInCatalog =>
if (isStreaming) {
if (v1Table.v1Table.tableType == CatalogTableType.VIEW) {
throw QueryCompilationErrors.permanentViewNotSupportedByStreamingReadingAPIError(
@@ -41,6 +41,7 @@ import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Project, Subque
import org.apache.spark.sql.catalyst.trees.CurrentOrigin
import org.apache.spark.sql.catalyst.util.{CharVarcharUtils, StringUtils}
import org.apache.spark.sql.connector.catalog.CatalogManager
import org.apache.spark.sql.connector.catalog.CatalogManager.SESSION_CATALOG_NAME
import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.internal.StaticSQLConf.GLOBAL_TEMP_DATABASE
@@ -196,41 +197,42 @@ class SessionCatalog(
}
}

private val tableRelationCache: Cache[QualifiedTableName, LogicalPlan] = {
private val tableRelationCache: Cache[FullQualifiedTableName, LogicalPlan] = {
var builder = CacheBuilder.newBuilder()
.maximumSize(cacheSize)

if (cacheTTL > 0) {
builder = builder.expireAfterWrite(cacheTTL, TimeUnit.SECONDS)
}

builder.build[QualifiedTableName, LogicalPlan]()
builder.build[FullQualifiedTableName, LogicalPlan]()
}

/** This method provides a way to get a cached plan. */
def getCachedPlan(t: QualifiedTableName, c: Callable[LogicalPlan]): LogicalPlan = {
def getCachedPlan(t: FullQualifiedTableName, c: Callable[LogicalPlan]): LogicalPlan = {
tableRelationCache.get(t, c)
}

/** This method provides a way to get a cached plan if the key exists. */
def getCachedTable(key: QualifiedTableName): LogicalPlan = {
def getCachedTable(key: FullQualifiedTableName): LogicalPlan = {
tableRelationCache.getIfPresent(key)
}

/** This method provides a way to cache a plan. */
def cacheTable(t: QualifiedTableName, l: LogicalPlan): Unit = {
def cacheTable(t: FullQualifiedTableName, l: LogicalPlan): Unit = {
tableRelationCache.put(t, l)
}

/** This method provides a way to invalidate a cached plan. */
def invalidateCachedTable(key: QualifiedTableName): Unit = {
def invalidateCachedTable(key: FullQualifiedTableName): Unit = {
tableRelationCache.invalidate(key)
}

/** This method discards any cached table relation plans for the given table identifier. */
def invalidateCachedTable(name: TableIdentifier): Unit = {
val qualified = qualifyIdentifier(name)
invalidateCachedTable(QualifiedTableName(qualified.database.get, qualified.table))
invalidateCachedTable(FullQualifiedTableName(
qualified.catalog.get, qualified.database.get, qualified.table))
}

/** This method provides a way to invalidate all the cached plans. */
@@ -299,7 +301,7 @@ class SessionCatalog(
}
if (cascade && databaseExists(dbName)) {
listTables(dbName).foreach { t =>
invalidateCachedTable(QualifiedTableName(dbName, t.table))
invalidateCachedTable(FullQualifiedTableName(SESSION_CATALOG_NAME, dbName, t.table))
}
}
externalCatalog.dropDatabase(dbName, ignoreIfNotExists, cascade)
@@ -1181,7 +1183,8 @@ class SessionCatalog(
def refreshTable(name: TableIdentifier): Unit = synchronized {
getLocalOrGlobalTempView(name).map(_.refresh()).getOrElse {
val qualifiedIdent = qualifyIdentifier(name)
val qualifiedTableName = QualifiedTableName(qualifiedIdent.database.get, qualifiedIdent.table)
val qualifiedTableName = FullQualifiedTableName(
qualifiedIdent.catalog.get, qualifiedIdent.database.get, qualifiedIdent.table)
tableRelationCache.invalidate(qualifiedTableName)
}
}
@@ -111,6 +111,10 @@ case class QualifiedTableName(database: String, name: String) {
override def toString: String = s"$database.$name"
}

case class FullQualifiedTableName(catalog: String, database: String, name: String) {
override def toString: String = s"$catalog.$database.$name"
}

object TableIdentifier {
def apply(tableName: String): TableIdentifier = new TableIdentifier(tableName)
def apply(table: String, database: Option[String]): TableIdentifier =
@@ -22,7 +22,7 @@ import scala.concurrent.duration._
import org.scalatest.concurrent.Eventually

import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.{AliasIdentifier, FunctionIdentifier, QualifiedTableName, TableIdentifier}
import org.apache.spark.sql.catalyst.{AliasIdentifier, FullQualifiedTableName, FunctionIdentifier, TableIdentifier}
import org.apache.spark.sql.catalyst.analysis._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.parser.CatalystSqlParser
@@ -1883,7 +1883,8 @@ abstract class SessionCatalogSuite extends AnalysisTest with Eventually {
conf.setConf(StaticSQLConf.METADATA_CACHE_TTL_SECONDS, 1L)

withConfAndEmptyCatalog(conf) { catalog =>
val table = QualifiedTableName(catalog.getCurrentDatabase, "test")
val table = FullQualifiedTableName(
CatalogManager.SESSION_CATALOG_NAME, catalog.getCurrentDatabase, "test")

// First, make sure the test table is not cached.
assert(catalog.getCachedTable(table) === null)
@@ -1902,13 +1903,14 @@ abstract class SessionCatalogSuite extends AnalysisTest with Eventually {
test("SPARK-34197: refreshTable should not invalidate the relation cache for temporary views") {
withBasicCatalog { catalog =>
createTempView(catalog, "tbl1", Range(1, 10, 1, 10), false)
val qualifiedName1 = QualifiedTableName("default", "tbl1")
val qualifiedName1 = FullQualifiedTableName(SESSION_CATALOG_NAME, "default", "tbl1")
catalog.cacheTable(qualifiedName1, Range(1, 10, 1, 10))
catalog.refreshTable(TableIdentifier("tbl1"))
assert(catalog.getCachedTable(qualifiedName1) != null)

createGlobalTempView(catalog, "tbl2", Range(2, 10, 1, 10), false)
val qualifiedName2 = QualifiedTableName(catalog.globalTempDatabase, "tbl2")
val qualifiedName2 =
FullQualifiedTableName(SESSION_CATALOG_NAME, catalog.globalTempDatabase, "tbl2")
catalog.cacheTable(qualifiedName2, Range(2, 10, 1, 10))
catalog.refreshTable(TableIdentifier("tbl2", Some(catalog.globalTempDatabase)))
assert(catalog.getCachedTable(qualifiedName2) != null)
@@ -28,28 +28,27 @@ import org.apache.spark.internal.{Logging, MDC}
import org.apache.spark.internal.LogKeys.PREDICATES
import org.apache.spark.rdd.RDD
import org.apache.spark.sql._
import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow, QualifiedTableName, SQLConfHelper}
import org.apache.spark.sql.catalyst.{expressions, CatalystTypeConverters, FullQualifiedTableName, InternalRow, SQLConfHelper}
import org.apache.spark.sql.catalyst.CatalystTypeConverters.convertToScala
import org.apache.spark.sql.catalyst.analysis._
import org.apache.spark.sql.catalyst.catalog._
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import org.apache.spark.sql.catalyst.expressions
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
import org.apache.spark.sql.catalyst.planning.PhysicalOperation
import org.apache.spark.sql.catalyst.plans.logical.{InsertIntoDir, InsertIntoStatement, LogicalPlan, Project}
import org.apache.spark.sql.catalyst.plans.logical.{AppendData, InsertIntoDir, InsertIntoStatement, LogicalPlan, Project}
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.catalyst.streaming.StreamingRelationV2
import org.apache.spark.sql.catalyst.types.DataTypeUtils
import org.apache.spark.sql.catalyst.util.{GeneratedColumn, ResolveDefaultColumns, V2ExpressionBuilder}
import org.apache.spark.sql.connector.catalog.SupportsRead
import org.apache.spark.sql.connector.catalog.{SupportsRead, V1Table}
import org.apache.spark.sql.connector.catalog.TableCapability._
import org.apache.spark.sql.connector.expressions.{Expression => V2Expression, NullOrdering, SortDirection, SortOrder => V2SortOrder, SortValue}
import org.apache.spark.sql.connector.expressions.aggregate.{AggregateFunc, Aggregation}
import org.apache.spark.sql.errors.QueryCompilationErrors
import org.apache.spark.sql.execution.{RowDataSourceScanExec, SparkPlan}
import org.apache.spark.sql.execution.command._
import org.apache.spark.sql.execution.datasources.v2.PushedDownOperators
import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2Relation, PushedDownOperators}
import org.apache.spark.sql.execution.streaming.StreamingRelation
import org.apache.spark.sql.sources._
import org.apache.spark.sql.types._
@@ -244,7 +243,8 @@ object DataSourceAnalysis extends Rule[LogicalPlan] {
class FindDataSourceTable(sparkSession: SparkSession) extends Rule[LogicalPlan] {
private def readDataSourceTable(
table: CatalogTable, extraOptions: CaseInsensitiveStringMap): LogicalPlan = {
val qualifiedTableName = QualifiedTableName(table.database, table.identifier.table)
val qualifiedTableName =
FullQualifiedTableName(table.identifier.catalog.get, table.database, table.identifier.table)
val catalog = sparkSession.sessionState.catalog
val dsOptions = DataSourceUtils.generateDatasourceOptions(extraOptions, table)
catalog.getCachedPlan(qualifiedTableName, () => {
@@ -286,6 +286,13 @@ class FindDataSourceTable(sparkSession: SparkSession) extends Rule[LogicalPlan]
_, _, _, _, _, _) =>
i.copy(table = DDLUtils.readHiveTable(tableMeta))

case append @ AppendData(
DataSourceV2Relation(
V1Table(table: CatalogTable), _, _, _, _), _, _, _, _, _) if !append.isByName =>
InsertIntoStatement(UnresolvedCatalogRelation(table),
table.partitionColumnNames.map(name => name -> None).toMap,
Seq.empty, append.query, false, append.isByName)

case UnresolvedCatalogRelation(tableMeta, options, false)
if DDLUtils.isDatasourceTable(tableMeta) =>
readDataSourceTable(tableMeta, options)
@@ -24,7 +24,7 @@ import scala.collection.mutable
import scala.jdk.CollectionConverters._

import org.apache.spark.SparkUnsupportedOperationException
import org.apache.spark.sql.catalyst.{FunctionIdentifier, QualifiedTableName, SQLConfHelper, TableIdentifier}
import org.apache.spark.sql.catalyst.{FullQualifiedTableName, FunctionIdentifier, SQLConfHelper, TableIdentifier}
import org.apache.spark.sql.catalyst.analysis.{NoSuchNamespaceException, NoSuchTableException, TableAlreadyExistsException}
import org.apache.spark.sql.catalyst.catalog.{CatalogDatabase, CatalogStorageFormat, CatalogTable, CatalogTableType, CatalogUtils, ClusterBySpec, SessionCatalog}
import org.apache.spark.sql.catalyst.util.TypeUtils._
@@ -93,7 +93,8 @@ class V2SessionCatalog(catalog: SessionCatalog)
// table here. To avoid breaking it we do not resolve the table provider and still return
// `V1Table` if the custom session catalog is present.
if (table.provider.isDefined && !hasCustomSessionCatalog) {
val qualifiedTableName = QualifiedTableName(table.database, table.identifier.table)
val qualifiedTableName = FullQualifiedTableName(
table.identifier.catalog.get, table.database, table.identifier.table)
// Check if the table is in the v1 table cache to skip the v2 table lookup.
if (catalog.getCachedTable(qualifiedTableName) != null) {
return V1Table(table)
@@ -25,13 +25,14 @@ import java.time.LocalDateTime
import scala.collection.mutable
import scala.util.Random

import org.apache.spark.sql.catalyst.{QualifiedTableName, TableIdentifier}
import org.apache.spark.sql.catalyst.{FullQualifiedTableName, TableIdentifier}
import org.apache.spark.sql.catalyst.catalog.{CatalogColumnStat, CatalogStatistics, CatalogTable, HiveTableRelation}
import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec
import org.apache.spark.sql.catalyst.expressions.AttributeMap
import org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, Histogram, HistogramBin, HistogramSerializer, LogicalPlan, Statistics}
import org.apache.spark.sql.catalyst.util.DateTimeTestUtils._
import org.apache.spark.sql.catalyst.util.DateTimeUtils
import org.apache.spark.sql.connector.catalog.CatalogManager
import org.apache.spark.sql.execution.datasources.LogicalRelation
import org.apache.spark.sql.internal.{SQLConf, StaticSQLConf}
import org.apache.spark.sql.test.SQLTestUtils
@@ -269,7 +270,8 @@ abstract class StatisticsCollectionTestBase extends QueryTest with SQLTestUtils

def getTableFromCatalogCache(tableName: String): LogicalPlan = {
val catalog = spark.sessionState.catalog
val qualifiedTableName = QualifiedTableName(catalog.getCurrentDatabase, tableName)
val qualifiedTableName = FullQualifiedTableName(
CatalogManager.SESSION_CATALOG_NAME, catalog.getCurrentDatabase, tableName)
catalog.getCachedTable(qualifiedTableName)
}

@@ -27,9 +27,10 @@ import scala.jdk.CollectionConverters._

import org.apache.spark.{SparkException, SparkRuntimeException, SparkUnsupportedOperationException}
import org.apache.spark.sql._
import org.apache.spark.sql.catalyst.{FullQualifiedTableName, InternalRow, TableIdentifier}
import org.apache.spark.sql.catalyst.CurrentUserContext.CURRENT_USER
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.{CannotReplaceMissingTableException, NoSuchNamespaceException, TableAlreadyExistsException}
import org.apache.spark.sql.catalyst.catalog.{CatalogStorageFormat, CatalogTable, CatalogTableType, CatalogUtils}
import org.apache.spark.sql.catalyst.parser.ParseException
import org.apache.spark.sql.catalyst.plans.logical.{AppendData, ColumnStat, CommandResult, OverwriteByExpression}
import org.apache.spark.sql.catalyst.statsEstimation.StatsEstimationTestBase
@@ -42,6 +43,7 @@ import org.apache.spark.sql.errors.QueryErrorsBase
import org.apache.spark.sql.execution.FilterExec
import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
import org.apache.spark.sql.execution.columnar.InMemoryRelation
import org.apache.spark.sql.execution.datasources.{HadoopFsRelation, LogicalRelation}
import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation
import org.apache.spark.sql.execution.datasources.v2.DataSourceV2ScanRelation
import org.apache.spark.sql.execution.streaming.MemoryStream
@@ -3638,6 +3640,69 @@ class DataSourceV2SQLSuiteV1Filter
}
}

test("SPARK-49211: V2 Catalog can support built-in data sources") {
def checkParquet(tableName: String, path: String): Unit = {
withTable(tableName) {
sql("CREATE TABLE " + tableName +
" (name STRING) USING PARQUET LOCATION '" + path + "'")
sql("INSERT INTO " + tableName + " VALUES('Bob')")
val df = sql("SELECT * FROM " + tableName)
assert(df.queryExecution.analyzed.exists {
case LogicalRelation(_: HadoopFsRelation, _, _, _) => true
case _ => false
})
checkAnswer(df, Row("Bob"))
}
}

// Reset CatalogManager to clear the materialized `spark_catalog` instance, so that we can
// configure a new implementation.
val table1 = FullQualifiedTableName(SESSION_CATALOG_NAME, "default", "t")
spark.sessionState.catalogManager.reset()
withSQLConf(
V2_SESSION_CATALOG_IMPLEMENTATION.key ->
classOf[V2CatalogSupportBuiltinDataSource].getName) {
withTempPath { path =>
checkParquet(table1.toString, path.getAbsolutePath)
}
}
val table2 = FullQualifiedTableName("testcat3", "default", "t")
withSQLConf(
"spark.sql.catalog.testcat3" -> classOf[V2CatalogSupportBuiltinDataSource].getName) {
withTempPath { path =>
checkParquet(table2.toString, path.getAbsolutePath)
}
}
}

test("SPARK-49211: V2 Catalog support CTAS") {
def checkCTAS(tableName: String, path: String): Unit = {
sql("CREATE TABLE " + tableName + " USING PARQUET LOCATION '" + path +
"' AS SELECT 1, 2, 3")
checkAnswer(sql("SELECT * FROM " + tableName), Row(1, 2, 3))
}

// Reset CatalogManager to clear the materialized `spark_catalog` instance, so that we can
// configure a new implementation.
spark.sessionState.catalogManager.reset()
val table1 = FullQualifiedTableName(SESSION_CATALOG_NAME, "default", "t")
withSQLConf(
V2_SESSION_CATALOG_IMPLEMENTATION.key ->
classOf[V2CatalogSupportBuiltinDataSource].getName) {
withTempPath { path =>
checkCTAS(table1.toString, path.getAbsolutePath)
}
}

val table2 = FullQualifiedTableName("testcat3", "default", "t")
withSQLConf(
"spark.sql.catalog.testcat3" -> classOf[V2CatalogSupportBuiltinDataSource].getName) {
withTempPath { path =>
checkCTAS(table2.toString, path.getAbsolutePath)
}
}
}

private def testNotSupportedV2Command(
sqlCommand: String,
sqlParams: String,
@@ -3673,3 +3738,36 @@ class SimpleDelegatingCatalog extends DelegatingCatalogExtension {
super.createTable(ident, columns, partitions, newProps)
}
}


class V2CatalogSupportBuiltinDataSource extends InMemoryCatalog {
override def createTable(
ident: Identifier,
columns: Array[ColumnV2],
partitions: Array[Transform],
properties: util.Map[String, String]): Table = {
super.createTable(ident, columns, partitions, properties)
null
}

override def loadTable(ident: Identifier): Table = {
val superTable = super.loadTable(ident)
val tableIdent = {
TableIdentifier(ident.name(), Some(ident.namespace().head), Some(name))
}
val uri = CatalogUtils.stringToURI(superTable.properties().get(TableCatalog.PROP_LOCATION))
val sparkTable = CatalogTable(
tableIdent,
tableType = CatalogTableType.EXTERNAL,
storage = CatalogStorageFormat.empty.copy(
locationUri = Some(uri),
properties = superTable.properties().asScala.toMap
),
schema = superTable.schema(),
provider = Some(superTable.properties().get(TableCatalog.PROP_PROVIDER)),
tracksPartitionsInCatalog = false
)
V1Table(sparkTable)
}
}
