Skip to content

Commit

Permalink
Fetch globalTempDatabase name directly without invoking initializatio…
Browse files Browse the repository at this point in the history
…n of GlobalaTempViewManager.
  • Loading branch information
willwwt committed Jun 7, 2024
1 parent ce1b08f commit ec391c7
Show file tree
Hide file tree
Showing 22 changed files with 56 additions and 52 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ import org.apache.spark.sql.errors.QueryCompilationErrors
*
* @param database The system preserved virtual database that keeps all the global temporary views.
*/
class GlobalTempViewManager(val database: String) {
class GlobalTempViewManager(database: String) {

/** List of view definitions, mapping from view name to logical plan. */
@GuardedBy("this")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,7 @@ class SessionCatalog(

lazy val externalCatalog = externalCatalogBuilder()
lazy val globalTempViewManager = globalTempViewManagerBuilder()
val globalTempDatabase: String = SQLConf.get.globalTempDatabase

/** List of temporary views, mapping from table name to their logical plan. */
@GuardedBy("this")
Expand Down Expand Up @@ -273,9 +274,9 @@ class SessionCatalog(

def createDatabase(dbDefinition: CatalogDatabase, ignoreIfExists: Boolean): Unit = {
val dbName = format(dbDefinition.name)
if (dbName == globalTempViewManager.database) {
if (dbName == globalTempDatabase) {
throw QueryCompilationErrors.cannotCreateDatabaseWithSameNameAsPreservedDatabaseError(
globalTempViewManager.database)
globalTempDatabase)
}
validateName(dbName)
externalCatalog.createDatabase(
Expand Down Expand Up @@ -333,9 +334,9 @@ class SessionCatalog(

def setCurrentDatabase(db: String): Unit = {
val dbName = format(db)
if (dbName == globalTempViewManager.database) {
if (dbName == globalTempDatabase) {
throw QueryCompilationErrors.cannotUsePreservedDatabaseAsCurrentDatabaseError(
globalTempViewManager.database)
globalTempDatabase)
}
requireDbExists(dbName)
synchronized { currentDb = dbName }
Expand Down Expand Up @@ -659,7 +660,7 @@ class SessionCatalog(
} else {
false
}
} else if (format(name.database.get) == globalTempViewManager.database) {
} else if (format(name.database.get) == globalTempDatabase) {
globalTempViewManager.update(viewName, viewDefinition)
} else {
false
Expand Down Expand Up @@ -767,9 +768,9 @@ class SessionCatalog(
val table = format(name.table)
if (name.database.isEmpty) {
tempViews.get(table).map(_.tableMeta).getOrElse(getTableMetadata(name))
} else if (format(name.database.get) == globalTempViewManager.database) {
} else if (format(name.database.get) == globalTempDatabase) {
globalTempViewManager.get(table).map(_.tableMeta)
.getOrElse(throw new NoSuchTableException(globalTempViewManager.database, table))
.getOrElse(throw new NoSuchTableException(globalTempDatabase, table))
} else {
getTableMetadata(name)
}
Expand All @@ -795,7 +796,7 @@ class SessionCatalog(

val oldTableName = qualifiedIdent.table
val newTableName = format(newName.table)
if (db == globalTempViewManager.database) {
if (db == globalTempDatabase) {
globalTempViewManager.rename(oldTableName, newTableName)
} else {
requireDbExists(db)
Expand Down Expand Up @@ -832,10 +833,10 @@ class SessionCatalog(
val qualifiedIdent = qualifyIdentifier(name)
val db = qualifiedIdent.database.get
val table = qualifiedIdent.table
if (db == globalTempViewManager.database) {
if (db == globalTempDatabase) {
val viewExists = globalTempViewManager.remove(table)
if (!viewExists && !ignoreIfNotExists) {
throw new NoSuchTableException(globalTempViewManager.database, table)
throw new NoSuchTableException(globalTempDatabase, table)
}
} else {
if (name.database.isDefined || !tempViews.contains(table)) {
Expand Down Expand Up @@ -873,7 +874,7 @@ class SessionCatalog(
val qualifiedIdent = qualifyIdentifier(name)
val db = qualifiedIdent.database.get
val table = qualifiedIdent.table
if (db == globalTempViewManager.database) {
if (db == globalTempDatabase) {
globalTempViewManager.get(table).map { viewDef =>
SubqueryAlias(table, db, getTempViewPlan(viewDef))
}.getOrElse(throw new NoSuchTableException(db, table))
Expand Down Expand Up @@ -1026,7 +1027,7 @@ class SessionCatalog(
}

def isGlobalTempViewDB(dbName: String): Boolean = {
globalTempViewManager.database.equalsIgnoreCase(dbName)
globalTempDatabase.equalsIgnoreCase(dbName)
}

/**
Expand Down Expand Up @@ -1085,9 +1086,9 @@ class SessionCatalog(
pattern: String,
includeLocalTempViews: Boolean): Seq[TableIdentifier] = {
val dbName = format(db)
val dbTables = if (dbName == globalTempViewManager.database) {
val dbTables = if (dbName == globalTempDatabase) {
globalTempViewManager.listViewNames(pattern).map { name =>
TableIdentifier(name, Some(globalTempViewManager.database))
TableIdentifier(name, Some(globalTempDatabase))
}
} else {
requireDbExists(dbName)
Expand All @@ -1108,9 +1109,9 @@ class SessionCatalog(
*/
def listViews(db: String, pattern: String): Seq[TableIdentifier] = {
val dbName = format(db)
val dbViews = if (dbName == globalTempViewManager.database) {
val dbViews = if (dbName == globalTempDatabase) {
globalTempViewManager.listViewNames(pattern).map { name =>
TableIdentifier(name, Some(globalTempViewManager.database))
TableIdentifier(name, Some(globalTempDatabase))
}
} else {
requireDbExists(dbName)
Expand All @@ -1126,7 +1127,7 @@ class SessionCatalog(
* List all matching temp views in the specified database, including global/local temporary views.
*/
def listTempViews(db: String, pattern: String): Seq[CatalogTable] = {
val globalTempViews = if (format(db) == globalTempViewManager.database) {
val globalTempViews = if (format(db) == globalTempDatabase) {
globalTempViewManager.listViewNames(pattern).flatMap { viewName =>
globalTempViewManager.get(viewName).map(_.tableMeta)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5949,6 +5949,8 @@ class SQLConf extends Serializable with Logging with SqlApiConf {

def defaultDatabase: String = getConf(StaticSQLConf.CATALOG_DEFAULT_DATABASE)

def globalTempDatabase: String = getConf(StaticSQLConf.GLOBAL_TEMP_DATABASE)

def allowsTempViewCreationWithMultipleNameparts: Boolean =
getConf(SQLConf.ALLOW_TEMP_VIEW_CREATION_WITH_MULTIPLE_NAME_PARTS)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -933,17 +933,17 @@ abstract class SessionCatalogSuite extends AnalysisTest with Eventually {
createTempView(catalog, "temp_view4", tempTable, overrideIfExists = false)
createGlobalTempView(catalog, "global_temp_view1", tempTable, overrideIfExists = false)
createGlobalTempView(catalog, "global_temp_view2", tempTable, overrideIfExists = false)
assert(catalog.listTables(catalog.globalTempViewManager.database, "*").toSet ==
assert(catalog.listTables(catalog.globalTempDatabase, "*").toSet ==
Set(TableIdentifier("temp_view1"),
TableIdentifier("temp_view4"),
TableIdentifier("global_temp_view1", Some(catalog.globalTempViewManager.database)),
TableIdentifier("global_temp_view2", Some(catalog.globalTempViewManager.database))))
assert(catalog.listTables(catalog.globalTempViewManager.database, "*temp_view1").toSet ==
TableIdentifier("global_temp_view1", Some(catalog.globalTempDatabase)),
TableIdentifier("global_temp_view2", Some(catalog.globalTempDatabase))))
assert(catalog.listTables(catalog.globalTempDatabase, "*temp_view1").toSet ==
Set(TableIdentifier("temp_view1"),
TableIdentifier("global_temp_view1", Some(catalog.globalTempViewManager.database))))
assert(catalog.listTables(catalog.globalTempViewManager.database, "global*").toSet ==
Set(TableIdentifier("global_temp_view1", Some(catalog.globalTempViewManager.database)),
TableIdentifier("global_temp_view2", Some(catalog.globalTempViewManager.database))))
TableIdentifier("global_temp_view1", Some(catalog.globalTempDatabase))))
assert(catalog.listTables(catalog.globalTempDatabase, "global*").toSet ==
Set(TableIdentifier("global_temp_view1", Some(catalog.globalTempDatabase)),
TableIdentifier("global_temp_view2", Some(catalog.globalTempDatabase))))
}
}

Expand Down Expand Up @@ -1906,9 +1906,9 @@ abstract class SessionCatalogSuite extends AnalysisTest with Eventually {
assert(catalog.getCachedTable(qualifiedName1) != null)

createGlobalTempView(catalog, "tbl2", Range(2, 10, 1, 10), false)
val qualifiedName2 = QualifiedTableName(catalog.globalTempViewManager.database, "tbl2")
val qualifiedName2 = QualifiedTableName(catalog.globalTempDatabase, "tbl2")
catalog.cacheTable(qualifiedName2, Range(2, 10, 1, 10))
catalog.refreshTable(TableIdentifier("tbl2", Some(catalog.globalTempViewManager.database)))
catalog.refreshTable(TableIdentifier("tbl2", Some(catalog.globalTempDatabase)))
assert(catalog.getCachedTable(qualifiedName2) != null)
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ case class AnalyzeColumnCommand(
val sessionState = sparkSession.sessionState

tableIdent.database match {
case Some(db) if db == sparkSession.sharedState.globalTempViewManager.database =>
case Some(db) if db == sparkSession.sharedState.globalTempDB =>
val plan = sessionState.catalog.getGlobalTempView(tableIdent.identifier).getOrElse {
throw QueryCompilationErrors.noSuchTableError(db, tableIdent.identifier)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -169,11 +169,12 @@ private[sql] class SharedState(
wrapped
}

val globalTempDB = conf.get(GLOBAL_TEMP_DATABASE)

/**
* A manager for global temporary views.
*/
lazy val globalTempViewManager: GlobalTempViewManager = {
val globalTempDB = conf.get(GLOBAL_TEMP_DATABASE)
if (externalCatalog.databaseExists(globalTempDB)) {
throw QueryExecutionErrors.databaseNameConflictWithSystemPreservedDatabaseError(globalTempDB)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1436,7 +1436,7 @@ class CachedTableSuite extends QueryTest with SQLTestUtils
withSQLConf(SQLConf.STORE_ANALYZED_PLAN_FOR_VIEW.key -> storeAnalyzed.toString) {
withGlobalTempView("view1") {
withTempView("view2") {
val db = spark.sharedState.globalTempViewManager.database
val db = spark.sharedState.globalTempDB
sql("CREATE GLOBAL TEMPORARY VIEW view1 AS SELECT * FROM testData WHERE key > 1")
sql(s"CACHE TABLE view2 AS SELECT * FROM ${db}.view1 WHERE value > 1")
assert(spark.catalog.isCached("view2"))
Expand Down Expand Up @@ -1487,7 +1487,7 @@ class CachedTableSuite extends QueryTest with SQLTestUtils
withSQLConf(SQLConf.STORE_ANALYZED_PLAN_FOR_VIEW.key -> storeAnalyzed.toString) {
withGlobalTempView("view1") {
withTempView("view2") {
val db = spark.sharedState.globalTempViewManager.database
val db = spark.sharedState.globalTempDB
sql("CREATE GLOBAL TEMPORARY VIEW view1 AS SELECT * FROM testData WHERE key > 1")
sql(s"CACHE TABLE view2 AS SELECT * FROM $db.view1 WHERE value > 1")
assert(spark.catalog.isCached("view2"))
Expand Down Expand Up @@ -1517,7 +1517,7 @@ class CachedTableSuite extends QueryTest with SQLTestUtils
Seq(true, false).foreach { storeAnalyzed =>
withSQLConf(SQLConf.STORE_ANALYZED_PLAN_FOR_VIEW.key -> storeAnalyzed.toString) {
withGlobalTempView("global_tv") {
val db = spark.sharedState.globalTempViewManager.database
val db = spark.sharedState.globalTempDB
testAlterTemporaryViewAsWithCache(TableIdentifier("global_tv", Some(db)), storeAnalyzed)
}
}
Expand Down Expand Up @@ -1575,7 +1575,7 @@ class CachedTableSuite extends QueryTest with SQLTestUtils

test("SPARK-34699: CREATE GLOBAL TEMP VIEW USING should uncache correctly") {
withGlobalTempView("global_tv") {
val db = spark.sharedState.globalTempViewManager.database
val db = spark.sharedState.globalTempDB
testCreateTemporaryViewUsingWithCache(TableIdentifier("global_tv", Some(db)))
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -644,7 +644,7 @@ class StatisticsCollectionSuite extends StatisticsCollectionTestBase with Shared

test("analyzes column statistics in cached global temporary view") {
withGlobalTempView("gTempView") {
val globalTempDB = spark.sharedState.globalTempViewManager.database
val globalTempDB = spark.sharedState.globalTempDB
val e1 = intercept[AnalysisException] {
sql(s"ANALYZE TABLE $globalTempDB.gTempView COMPUTE STATISTICS FOR COLUMNS id")
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ class GlobalTempViewSuite extends QueryTest with SharedSparkSession {

override protected def beforeAll(): Unit = {
super.beforeAll()
globalTempDB = spark.sharedState.globalTempViewManager.database
globalTempDB = spark.sharedState.globalTempDB
}

private var globalTempDB: String = _
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ abstract class SQLViewSuite extends QueryTest with SQLTestUtils {
"objName" -> s"`$SESSION_CATALOG_NAME`.`default`.`jtv1`",
"tempObj" -> "VIEW",
"tempObjName" -> "`temp_jtv1`"))
val globalTempDB = spark.sharedState.globalTempViewManager.database
val globalTempDB = spark.sharedState.globalTempDB
sql("CREATE GLOBAL TEMP VIEW global_temp_jtv1 AS SELECT * FROM jt WHERE id > 0")
checkError(
exception = intercept[AnalysisException] {
Expand Down Expand Up @@ -1102,7 +1102,7 @@ abstract class SQLViewSuite extends QueryTest with SQLTestUtils {
test("local temp view refers global temp view") {
withGlobalTempView("v1") {
withTempView("v2") {
val globalTempDB = spark.sharedState.globalTempViewManager.database
val globalTempDB = spark.sharedState.globalTempDB
sql("CREATE GLOBAL TEMPORARY VIEW v1 AS SELECT 1")
sql(s"CREATE TEMPORARY VIEW v2 AS SELECT * FROM ${globalTempDB}.v1")
checkAnswer(sql("SELECT * FROM v2"), Seq(Row(1)))
Expand All @@ -1113,7 +1113,7 @@ abstract class SQLViewSuite extends QueryTest with SQLTestUtils {
test("global temp view refers local temp view") {
withTempView("v1") {
withGlobalTempView("v2") {
val globalTempDB = spark.sharedState.globalTempViewManager.database
val globalTempDB = spark.sharedState.globalTempDB
sql("CREATE TEMPORARY VIEW v1 AS SELECT 1")
sql(s"CREATE GLOBAL TEMPORARY VIEW v2 AS SELECT * FROM v1")
checkAnswer(sql(s"SELECT * FROM ${globalTempDB}.v2"), Seq(Row(1)))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -536,7 +536,7 @@ class LocalTempViewTestSuite extends TempViewTestSuite with SharedSparkSession {
}

class GlobalTempViewTestSuite extends TempViewTestSuite with SharedSparkSession {
private def db: String = spark.sharedState.globalTempViewManager.database
private def db: String = spark.sharedState.globalTempDB
override protected def viewTypeString: String = "GLOBAL TEMPORARY VIEW"
override protected def formattedViewName(viewName: String): String = {
s"$db.$viewName"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -236,7 +236,7 @@ trait AlterTableDropPartitionSuiteBase extends QueryTest with DDLCommandTestUtil
checkCachedRelation("v1", Seq(Row(0, 0), Row(3, 3)))
}

val v2 = s"${spark.sharedState.globalTempViewManager.database}.v2"
val v2 = s"${spark.sharedState.globalTempDB}.v2"
withGlobalTempView("v2") {
sql(s"CREATE GLOBAL TEMP VIEW v2 AS SELECT * FROM $t")
cacheRelation(v2)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -222,7 +222,7 @@ trait AlterTableRenamePartitionSuiteBase extends QueryTest with DDLCommandTestUt
checkCachedRelation("v1", Seq(Row(0, 2), Row(1, 3)))
}

val v2 = s"${spark.sharedState.globalTempViewManager.database}.v2"
val v2 = s"${spark.sharedState.globalTempDB}.v2"
withGlobalTempView("v2") {
sql(s"CREATE GLOBAL TEMP VIEW v2 AS SELECT * FROM $t")
cacheRelation(v2)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2260,7 +2260,7 @@ abstract class DDLSuite extends QueryTest with DDLSuiteBase {
)

withGlobalTempView("src") {
val globalTempDB = spark.sharedState.globalTempViewManager.database
val globalTempDB = spark.sharedState.globalTempDB
sql("CREATE GLOBAL TEMP VIEW src AS SELECT 1 AS a, '2' AS b")
sql(s"CREATE TABLE t4 LIKE $globalTempDB.src USING parquet")
val table = catalog.getTableMetadata(TableIdentifier("t4"))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -206,7 +206,7 @@ trait TruncateTableSuiteBase extends QueryTest with DDLCommandTestUtils {
)
}

val v2 = s"${spark.sharedState.globalTempViewManager.database}.v2"
val v2 = s"${spark.sharedState.globalTempDB}.v2"
withGlobalTempView("v2") {
sql(s"CREATE GLOBAL TEMP VIEW v2 AS SELECT * FROM $t")
checkError(
Expand Down Expand Up @@ -245,7 +245,7 @@ trait TruncateTableSuiteBase extends QueryTest with DDLCommandTestUtils {
checkCachedRelation("v1", Seq(Row(0, 0, 0)))
}

val v2 = s"${spark.sharedState.globalTempViewManager.database}.v2"
val v2 = s"${spark.sharedState.globalTempDB}.v2"
withGlobalTempView("v2") {
sql(s"INSERT INTO $t PARTITION (width = 10, length = 10) SELECT 10")
sql(s"CREATE GLOBAL TEMP VIEW v2 AS SELECT * FROM $t")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ trait AlterTableAddPartitionSuiteBase extends command.AlterTableAddPartitionSuit
checkCachedRelation("v1", Seq(Row(0, 0), Row(0, 1), Row(0, 2)))
}

val v2 = s"${spark.sharedState.globalTempViewManager.database}.v2"
val v2 = s"${spark.sharedState.globalTempDB}.v2"
withGlobalTempView("v2") {
sql(s"CREATE GLOBAL TEMP VIEW v2 AS SELECT * FROM $t")
cacheRelation(v2)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ class AlterTableAddPartitionSuite
checkCachedRelation("v1", Seq(Row(0, 0), Row(0, 1), Row(1, 2)))
}

val v2 = s"${spark.sharedState.globalTempViewManager.database}.v2"
val v2 = s"${spark.sharedState.globalTempDB}.v2"
withGlobalTempView(v2) {
sql(s"CREATE GLOBAL TEMP VIEW v2 AS SELECT * FROM $t")
cacheRelation(v2)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ private[hive] class SparkGetColumnsOperation(
}

// Global temporary views
val globalTempViewDb = catalog.globalTempViewManager.database
val globalTempViewDb = catalog.globalTempDatabase
val databasePattern = Pattern.compile(CLIServiceUtils.patternToRegex(schemaName))
if (databasePattern.matcher(globalTempViewDb).matches()) {
catalog.globalTempViewManager.listViewNames(tablePattern).foreach { globalTempView =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ private[hive] class SparkGetSchemasOperation(
rowSet.addRow(Array[AnyRef](dbName, DEFAULT_HIVE_CATALOG))
}

val globalTempViewDb = sqlContext.sessionState.catalog.globalTempViewManager.database
val globalTempViewDb = sqlContext.sessionState.catalog.globalTempDatabase
val databasePattern = Pattern.compile(CLIServiceUtils.patternToRegex(schemaName))
if (schemaName == null || schemaName.isEmpty ||
databasePattern.matcher(globalTempViewDb).matches()) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ private[hive] class SparkGetTablesOperation(

// Temporary views and global temporary views
if (tableTypes == null || tableTypes.isEmpty || tableTypes.contains(VIEW.name)) {
val globalTempViewDb = catalog.globalTempViewManager.database
val globalTempViewDb = catalog.globalTempDatabase
val databasePattern = Pattern.compile(CLIServiceUtils.patternToRegex(schemaName))
val tempViews = if (databasePattern.matcher(globalTempViewDb).matches()) {
catalog.listTables(globalTempViewDb, tablePattern, includeLocalTempViews = true)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ trait ThriftServerWithSparkContextSuite extends SharedThriftServer {

test("SPARK-29911: Uncache cached tables when session closed") {
val cacheManager = spark.sharedState.cacheManager
val globalTempDB = spark.sharedState.globalTempViewManager.database
val globalTempDB = spark.sharedState.globalTempDB
withJdbcStatement() { statement =>
statement.execute("CACHE TABLE tempTbl AS SELECT 1")
}
Expand Down
Loading

0 comments on commit ec391c7

Please sign in to comment.