diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
index 0bb4fc9c90d8a..038f15ee11035 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
@@ -468,7 +468,7 @@ class AstBuilder extends DataTypeAstBuilder
         val (relationCtx, options, cols, partition, ifPartitionNotExists, byName)
           = visitInsertIntoTable(table)
         withIdentClause(relationCtx, Seq(query), (ident, otherPlans) => {
-          val insertIntoStatement = InsertIntoStatement(
+          InsertIntoStatement(
             createUnresolvedRelation(relationCtx, ident, options),
             partition,
             cols,
@@ -476,11 +476,6 @@ class AstBuilder extends DataTypeAstBuilder
             overwrite = false,
             ifPartitionNotExists,
             byName)
-          if (conf.getConf(SQLConf.OPTIMIZE_INSERT_INTO_VALUES_PARSER)) {
-            EvaluateUnresolvedInlineTable.evaluate(insertIntoStatement)
-          } else {
-            insertIntoStatement
-          }
         })
       case table: InsertOverwriteTableContext =>
         val (relationCtx, options, cols, partition, ifPartitionNotExists, byName)
@@ -1897,7 +1892,12 @@ class AstBuilder extends DataTypeAstBuilder
       Seq.tabulate(rows.head.size)(i => s"col${i + 1}")
     }
 
-    val table = UnresolvedInlineTable(aliases, rows.toSeq)
+    val unresolvedTable = UnresolvedInlineTable(aliases, rows.toSeq)
+    val table = if (conf.getConf(SQLConf.EAGER_EVAL_OF_UNRESOLVED_INLINE_TABLE_ENABLED)) {
+      EvaluateUnresolvedInlineTable.evaluate(unresolvedTable)
+    } else {
+      unresolvedTable
+    }
     table.optionalMap(ctx.tableAlias.strictIdentifier)(aliasPlan)
   }
 
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/EvaluateUnresolvedInlineTable.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/EvaluateUnresolvedInlineTable.scala
index a55f70c238a8a..51cab6bff3b03 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/EvaluateUnresolvedInlineTable.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/EvaluateUnresolvedInlineTable.scala
@@ -35,17 +35,8 @@ import org.apache.spark.sql.types.{StructField, StructType}
 object EvaluateUnresolvedInlineTable extends SQLConfHelper
   with AliasHelper with EvalHelper with CastSupport {
 
-  def evaluate(plan: LogicalPlan): LogicalPlan = {
-    traversePlanAndEvalUnresolvedInlineTable(plan)
-  }
-
-  def traversePlanAndEvalUnresolvedInlineTable(plan: LogicalPlan): LogicalPlan = {
-    plan match {
-      case table: UnresolvedInlineTable if table.expressionsResolved =>
-        evaluateUnresolvedInlineTable(table)
-      case _ => plan.mapChildren(traversePlanAndEvalUnresolvedInlineTable)
-    }
-  }
+  def evaluate(plan: UnresolvedInlineTable): LogicalPlan =
+    if (plan.expressionsResolved) evaluateUnresolvedInlineTable(plan) else plan
 
   def evaluateUnresolvedInlineTable(table: UnresolvedInlineTable): LogicalPlan = {
     validateInputDimension(table)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
index 935111387745e..096cc974fbe6a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
@@ -969,11 +969,11 @@ object SQLConf {
     .booleanConf
     .createWithDefault(true)
 
-  val OPTIMIZE_INSERT_INTO_VALUES_PARSER =
-    buildConf("spark.sql.parser.optimizeInsertIntoValuesParser")
+  val EAGER_EVAL_OF_UNRESOLVED_INLINE_TABLE_ENABLED =
+    buildConf("spark.sql.parser.eagerEvalOfUnresolvedInlineTable")
       .internal()
       .doc("Controls whether we optimize the ASTree that gets generated when parsing " +
-        "`insert into ... values` DML statements.")
+        "VALUES lists (UnresolvedInlineTable) by eagerly evaluating it in the AST Builder.")
       .booleanConf
       .createWithDefault(true)
 
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/DDLParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/DDLParserSuite.scala
index 59602a4c77d08..c930292f2793c 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/DDLParserSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/DDLParserSuite.scala
@@ -2633,7 +2633,7 @@ class DDLParserSuite extends AnalysisTest {
 
     for (optimizeInsertIntoValues <- Seq(true, false)) {
       withSQLConf(
-        SQLConf.OPTIMIZE_INSERT_INTO_VALUES_PARSER.key ->
+        SQLConf.EAGER_EVAL_OF_UNRESOLVED_INLINE_TABLE_ENABLED.key ->
           optimizeInsertIntoValues.toString) {
         comparePlans(parsePlan(dateTypeSql), insertPartitionPlan(
           "2019-01-02", optimizeInsertIntoValues))
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala
index 8d01040563361..e0217a5637a81 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala
@@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst.parser
 import scala.annotation.nowarn
 
 import org.apache.spark.SparkThrowable
-import org.apache.spark.sql.catalyst.{FunctionIdentifier, TableIdentifier}
+import org.apache.spark.sql.catalyst.{EvaluateUnresolvedInlineTable, FunctionIdentifier, TableIdentifier}
 import org.apache.spark.sql.catalyst.analysis.{AnalysisTest, NamedParameter, PosParameter, RelationTimeTravel, UnresolvedAlias, UnresolvedAttribute, UnresolvedFunction, UnresolvedGenerator, UnresolvedInlineTable, UnresolvedRelation, UnresolvedStar, UnresolvedSubqueryColumnAliases, UnresolvedTableValuedFunction, UnresolvedTVFAliases}
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans._
@@ -1000,14 +1000,28 @@ class PlanParserSuite extends AnalysisTest {
   }
 
   test("inline table") {
-    assertEqual("values 1, 2, 3, 4",
-      UnresolvedInlineTable(Seq("col1"), Seq(1, 2, 3, 4).map(x => Seq(Literal(x)))))
+    for (optimizeValues <- Seq(true, false)) {
+      withSQLConf(SQLConf.EAGER_EVAL_OF_UNRESOLVED_INLINE_TABLE_ENABLED.key ->
+        optimizeValues.toString) {
+        val unresolvedTable1 =
+          UnresolvedInlineTable(Seq("col1"), Seq(1, 2, 3, 4).map(x => Seq(Literal(x))))
+        val table1 = if (optimizeValues) {
+          EvaluateUnresolvedInlineTable.evaluate(unresolvedTable1)
+        } else {
+          unresolvedTable1
+        }
+        assertEqual("values 1, 2, 3, 4", table1)
 
-    assertEqual(
-      "values (1, 'a'), (2, 'b') as tbl(a, b)",
-      UnresolvedInlineTable(
-        Seq("a", "b"),
-        Seq(Literal(1), Literal("a")) :: Seq(Literal(2), Literal("b")) :: Nil).as("tbl"))
+        val unresolvedTable2 = UnresolvedInlineTable(
+          Seq("a", "b"), Seq(Literal(1), Literal("a")) :: Seq(Literal(2), Literal("b")) :: Nil)
+        val table2 = if (optimizeValues) {
+          EvaluateUnresolvedInlineTable.evaluate(unresolvedTable2)
+        } else {
+          unresolvedTable2
+        }
+        assertEqual("values (1, 'a'), (2, 'b') as tbl(a, b)", table2.as("tbl"))
+      }
+    }
   }
 
   test("simple select query with !> and !<") {
@@ -1907,12 +1921,22 @@ class PlanParserSuite extends AnalysisTest {
   }
 
   test("SPARK-42553: NonReserved keyword 'interval' can be column name") {
-    comparePlans(
-      parsePlan("SELECT interval FROM VALUES ('abc') AS tbl(interval);"),
-      UnresolvedInlineTable(
-        Seq("interval"),
-        Seq(Literal("abc")) :: Nil).as("tbl").select($"interval")
-    )
+    for (optimizeValues <- Seq(true, false)) {
+      withSQLConf(SQLConf.EAGER_EVAL_OF_UNRESOLVED_INLINE_TABLE_ENABLED.key ->
+        optimizeValues.toString) {
+        val unresolvedTable =
+          UnresolvedInlineTable(Seq("interval"), Seq(Literal("abc")) :: Nil)
+        val table = if (optimizeValues) {
+          EvaluateUnresolvedInlineTable.evaluate(unresolvedTable)
+        } else {
+          unresolvedTable
+        }
+        comparePlans(
+          parsePlan("SELECT interval FROM VALUES ('abc') AS tbl(interval);"),
+          table.as("tbl").select($"interval")
+        )
+      }
+    }
   }
 
   test("SPARK-44066: parsing of positional parameters") {
diff --git a/sql/core/src/test/resources/sql-tests/analyzer-results/inline-table.sql.out b/sql/core/src/test/resources/sql-tests/analyzer-results/inline-table.sql.out
index 988df7de1a3cf..78539effe188e 100644
--- a/sql/core/src/test/resources/sql-tests/analyzer-results/inline-table.sql.out
+++ b/sql/core/src/test/resources/sql-tests/analyzer-results/inline-table.sql.out
@@ -115,7 +115,7 @@ org.apache.spark.sql.AnalysisException
 -- !query
 select * from values ("one", 2.0), ("two") as data(a, b)
 -- !query analysis
-org.apache.spark.sql.AnalysisException
+org.apache.spark.sql.catalyst.parser.ParseException
 {
   "errorClass" : "INVALID_INLINE_TABLE.NUM_COLUMNS_MISMATCH",
   "sqlState" : "42000",
@@ -157,7 +157,7 @@ org.apache.spark.sql.AnalysisException
 -- !query
 select * from values ("one"), ("two") as data(a, b)
 -- !query analysis
-org.apache.spark.sql.AnalysisException
+org.apache.spark.sql.catalyst.parser.ParseException
 {
   "errorClass" : "INVALID_INLINE_TABLE.NUM_COLUMNS_MISMATCH",
   "sqlState" : "42000",
diff --git a/sql/core/src/test/resources/sql-tests/analyzer-results/postgreSQL/window_part4.sql.out b/sql/core/src/test/resources/sql-tests/analyzer-results/postgreSQL/window_part4.sql.out
index 2333cce874d31..f042116182f7d 100644
--- a/sql/core/src/test/resources/sql-tests/analyzer-results/postgreSQL/window_part4.sql.out
+++ b/sql/core/src/test/resources/sql-tests/analyzer-results/postgreSQL/window_part4.sql.out
@@ -498,7 +498,7 @@ SELECT a, b,
        SUM(b) OVER(ORDER BY A ROWS BETWEEN 1 PRECEDING AND CURRENT ROW)
 FROM (VALUES(1,1),(2,2),(3,(cast('nan' as int))),(4,3),(5,4)) t(a,b)
 -- !query analysis
-org.apache.spark.sql.AnalysisException
+org.apache.spark.sql.catalyst.parser.ParseException
 {
   "errorClass" : "INVALID_INLINE_TABLE.FAILED_SQL_EXPRESSION_EVALUATION",
   "sqlState" : "42000",
diff --git a/sql/core/src/test/resources/sql-tests/analyzer-results/udf/udf-inline-table.sql.out b/sql/core/src/test/resources/sql-tests/analyzer-results/udf/udf-inline-table.sql.out
index fb6130be5b6b4..786b5ac49b126 100644
--- a/sql/core/src/test/resources/sql-tests/analyzer-results/udf/udf-inline-table.sql.out
+++ b/sql/core/src/test/resources/sql-tests/analyzer-results/udf/udf-inline-table.sql.out
@@ -101,7 +101,7 @@ org.apache.spark.sql.AnalysisException
 -- !query
 select udf(a), udf(b) from values ("one", 2.0), ("two") as data(a, b)
 -- !query analysis
-org.apache.spark.sql.AnalysisException
+org.apache.spark.sql.catalyst.parser.ParseException
 {
   "errorClass" : "INVALID_INLINE_TABLE.NUM_COLUMNS_MISMATCH",
   "sqlState" : "42000",
@@ -143,7 +143,7 @@ org.apache.spark.sql.AnalysisException
 -- !query
 select udf(a), udf(b) from values ("one"), ("two") as data(a, b)
 -- !query analysis
-org.apache.spark.sql.AnalysisException
+org.apache.spark.sql.catalyst.parser.ParseException
 {
   "errorClass" : "INVALID_INLINE_TABLE.NUM_COLUMNS_MISMATCH",
   "sqlState" : "42000",
diff --git a/sql/core/src/test/resources/sql-tests/results/inline-table.sql.out b/sql/core/src/test/resources/sql-tests/results/inline-table.sql.out
index 4dcdf8ac3e980..0a2c7b0f55ed2 100644
--- a/sql/core/src/test/resources/sql-tests/results/inline-table.sql.out
+++ b/sql/core/src/test/resources/sql-tests/results/inline-table.sql.out
@@ -131,7 +131,7 @@ select * from values ("one", 2.0), ("two") as data(a, b)
 -- !query schema
 struct<>
 -- !query output
-org.apache.spark.sql.AnalysisException
+org.apache.spark.sql.catalyst.parser.ParseException
 {
   "errorClass" : "INVALID_INLINE_TABLE.NUM_COLUMNS_MISMATCH",
   "sqlState" : "42000",
@@ -177,7 +177,7 @@ select * from values ("one"), ("two") as data(a, b)
 -- !query schema
 struct<>
 -- !query output
-org.apache.spark.sql.AnalysisException
+org.apache.spark.sql.catalyst.parser.ParseException
 {
   "errorClass" : "INVALID_INLINE_TABLE.NUM_COLUMNS_MISMATCH",
   "sqlState" : "42000",
diff --git a/sql/core/src/test/resources/sql-tests/results/postgreSQL/window_part4.sql.out b/sql/core/src/test/resources/sql-tests/results/postgreSQL/window_part4.sql.out
index 2085186dc8cfa..2d539725b2a70 100644
--- a/sql/core/src/test/resources/sql-tests/results/postgreSQL/window_part4.sql.out
+++ b/sql/core/src/test/resources/sql-tests/results/postgreSQL/window_part4.sql.out
@@ -497,7 +497,7 @@ FROM (VALUES(1,1),(2,2),(3,(cast('nan' as int))),(4,3),(5,4)) t(a,b)
 -- !query schema
 struct<>
 -- !query output
-org.apache.spark.sql.AnalysisException
+org.apache.spark.sql.catalyst.parser.ParseException
 {
   "errorClass" : "INVALID_INLINE_TABLE.FAILED_SQL_EXPRESSION_EVALUATION",
   "sqlState" : "42000",
diff --git a/sql/core/src/test/resources/sql-tests/results/udf/udf-inline-table.sql.out b/sql/core/src/test/resources/sql-tests/results/udf/udf-inline-table.sql.out
index d09f56a836788..3e84ec09c2150 100644
--- a/sql/core/src/test/resources/sql-tests/results/udf/udf-inline-table.sql.out
+++ b/sql/core/src/test/resources/sql-tests/results/udf/udf-inline-table.sql.out
@@ -115,7 +115,7 @@ select udf(a), udf(b) from values ("one", 2.0), ("two") as data(a, b)
 -- !query schema
 struct<>
 -- !query output
-org.apache.spark.sql.AnalysisException
+org.apache.spark.sql.catalyst.parser.ParseException
 {
   "errorClass" : "INVALID_INLINE_TABLE.NUM_COLUMNS_MISMATCH",
   "sqlState" : "42000",
@@ -161,7 +161,7 @@ select udf(a), udf(b) from values ("one"), ("two") as data(a, b)
 -- !query schema
 struct<>
 -- !query output
-org.apache.spark.sql.AnalysisException
+org.apache.spark.sql.catalyst.parser.ParseException
 {
   "errorClass" : "INVALID_INLINE_TABLE.NUM_COLUMNS_MISMATCH",
   "sqlState" : "42000",
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/InlineTableParsingImprovementsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/InlineTableParsingImprovementsSuite.scala
index f305670dded8d..8c776874eaa1c 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/InlineTableParsingImprovementsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/InlineTableParsingImprovementsSuite.scala
@@ -19,11 +19,19 @@ package org.apache.spark.sql
 
 import java.util.UUID
 
+import org.apache.spark.sql.catalyst.analysis.UnresolvedInlineTable
+import org.apache.spark.sql.catalyst.expressions.{Expression, ScalarSubquery}
+import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan, Project}
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.test.SharedSparkSession
 
 class InlineTableParsingImprovementsSuite extends QueryTest with SharedSparkSession {
 
+  /**
+   * SQL parser.
+   */
+  private lazy val parser = spark.sessionState.sqlParser
+
   /**
    * Generate a random table name.
    */
@@ -59,7 +67,14 @@ class InlineTableParsingImprovementsSuite extends QueryTest with SharedSparkSess
    */
  private def generateInsertStatementWithLiterals(tableName: String, numRows: Int): String = {
     val baseQuery = s"INSERT INTO $tableName (id, first_name, last_name, age, gender," +
-      s" email, phone_number, address, city, state, zip_code, country, registration_date) VALUES "
+      s" email, phone_number, address, city, state, zip_code, country, registration_date) "
+    baseQuery + generateValuesWithLiterals(numRows) + ";"
+  }
+
+  /**
+   * Generate a VALUES clause with the given number of rows using basic literals.
+   */
+  private def generateValuesWithLiterals(numRows: Int = 10): String = {
     val rows = (1 to numRows).map { i =>
       val id = i
       val firstName = s"'FirstName_$id'"
@@ -79,7 +94,33 @@ class InlineTableParsingImprovementsSuite extends QueryTest with SharedSparkSess
       s" $address, $city, $state, $zipCode, $country, $registrationDate)"
     }.mkString(",\n")
 
-    baseQuery + rows + ";"
+    s" VALUES $rows"
+  }
+
+  /**
+   * Traverse the plan and check for the presence of the given node type.
+   */
+  private def traversePlanAndCheckForNodeType[T <: LogicalPlan](
+      plan: LogicalPlan, nodeType: Class[T]): Boolean = plan match {
+    case node if nodeType.isInstance(node) => true
+    case n: Project =>
+      // If the plan node is a Project, we need to check the expressions in the project list
+      // and the child nodes.
+      n.projectList.exists(traverseExpressionAndCheckForNodeType(_, nodeType)) ||
+        n.children.exists(traversePlanAndCheckForNodeType(_, nodeType))
+    case node if node.children.isEmpty => false
+    case _ => plan.children.exists(traversePlanAndCheckForNodeType(_, nodeType))
+  }
+
+  /**
+   * Traverse the expression and check for the presence of the given node type.
+   */
+  private def traverseExpressionAndCheckForNodeType[T <: LogicalPlan](
+      expression: Expression, nodeType: Class[T]): Boolean = expression match {
+    case scalarSubquery: ScalarSubquery => scalarSubquery.plan.exists(
+      traversePlanAndCheckForNodeType(_, nodeType))
+    case _ =>
+      expression.children.exists(traverseExpressionAndCheckForNodeType(_, nodeType))
   }
 
   /**
@@ -87,54 +128,64 @@ class InlineTableParsingImprovementsSuite extends QueryTest with SharedSparkSess
    */
   private def generateInsertStatementsWithComplexExpressions(
       tableName: String): String = {
-    s"""
-      INSERT INTO $tableName (id, first_name, last_name, age, gender,
-      email, phone_number, address, city, state, zip_code, country, registration_date) VALUES
-
-      (1, base64('FirstName_1'), base64('LastName_1'), 10+10, 'M', 'usr' || '@gmail.com',
-      concat('555','-1234'), hex('123 Fake St'), 'Anytown', 'CA', '12345', 'USA',
-      '2021-01-01'),
+    s"""
+    INSERT INTO $tableName (id, first_name, last_name, age, gender,
+    email, phone_number, address, city, state, zip_code, country, registration_date) VALUES
+    (1, base64('FirstName_1'), base64('LastName_1'), 10+10, 'M', 'usr' || '@gmail.com',
+    concat('555','-1234'), hex('123 Fake St'), 'Anytown', 'CA', '12345', 'USA',
+    '2021-01-01'),
 
-      (2, 'FirstName_2', string(5), abs(-8), 'F', 'usr@gmail.com', '555-1234', '123 Fake St',
-      concat('Anytown', 'sada'), 'CA', '12345', 'USA', '2021-01-01'),
+    (2, 'FirstName_2', string(5), abs(-8), 'F', 'usr@gmail.com', '555-1234', '123 Fake St',
+    concat('Anytown', 'sada'), 'CA', '12345', 'USA', '2021-01-01'),
 
-      (3, 'FirstName_3', 'LastName_3', 34::int, 'M', 'usr@gmail.com', '555-1234',
-      '123 Fake St', 'Anytown', 'CA', '12345', 'USA', '2021-01-01'),
-
-      (4, left('FirstName_4', 5), upper('LastName_4'), acos(1), 'F', 'user@gmail.com',
-      '555-1234', '123 Fake St', 'Anytown', 'CA', '12345', 'USA', '2021-01-01');
-    """
-  }
+    (3, 'FirstName_3', 'LastName_3', 34::int, 'M', 'usr@gmail.com', '555-1234',
+    '123 Fake St', 'Anytown', 'CA', '12345', 'USA', '2021-01-01'),
+    (4, left('FirstName_4', 5), upper('LastName_4'), acos(1), 'F', 'user@gmail.com',
+    '555-1234', '123 Fake St', 'Anytown', 'CA', '12345', 'USA', '2021-01-01');
+    """
+  }
 
   test("Insert Into Values optimization - Basic literals.") {
-    // Set the number of inserted rows to 10000.
-    val rowCount = 10000
+    // Set the number of inserted rows to 10.
+    val rowCount = 10
 
     var firstTableName: Option[String] = None
-    Seq(true, false).foreach { insertIntoValueImprovementEnabled =>
+    Seq(true, false).foreach { eagerEvalOfUnresolvedInlineTableEnabled =>
       // Create a table with a randomly generated name.
       val tableName = createTable
 
       // Set the feature flag for the InsertIntoValues improvement.
-      withSQLConf(SQLConf.OPTIMIZE_INSERT_INTO_VALUES_PARSER.key ->
-        insertIntoValueImprovementEnabled.toString) {
+      withSQLConf(SQLConf.EAGER_EVAL_OF_UNRESOLVED_INLINE_TABLE_ENABLED.key ->
+        eagerEvalOfUnresolvedInlineTableEnabled.toString) {
 
         // Generate an INSERT INTO VALUES statement.
         val sqlStatement = generateInsertStatementWithLiterals(tableName, rowCount)
+
+        // Parse the SQL statement.
+        val plan = parser.parsePlan(sqlStatement)
+
+        // Traverse the plan and check for the presence of appropriate nodes depending on the
+        // feature flag.
+        if (eagerEvalOfUnresolvedInlineTableEnabled) {
+          assert(traversePlanAndCheckForNodeType(plan, classOf[LocalRelation]))
+        } else {
+          assert(traversePlanAndCheckForNodeType(plan, classOf[UnresolvedInlineTable]))
+        }
+
         spark.sql(sqlStatement)
 
-          // Double check that the insertion was successful.
-          val countStar = spark.sql(s"SELECT count(*) FROM $tableName").collect()
-          assert(countStar.head.getLong(0) == rowCount,
-            "The number of rows in the table should match the number of rows inserted.")
+        // Double check that the insertion was successful.
+        val countStar = spark.sql(s"SELECT count(*) FROM $tableName").collect()
+        assert(countStar.head.getLong(0) == rowCount,
+          "The number of rows in the table should match the number of rows inserted.")
 
         // Check that both insertions will produce equivalent tables.
         if (firstTableName.isEmpty) {
           firstTableName = Some(tableName)
         } else {
-            val df1 = spark.table(firstTableName.get)
-            val df2 = spark.table(tableName)
-            checkAnswer(df1, df2)
+          val df1 = spark.table(firstTableName.get)
+          val df2 = spark.table(tableName)
+          checkAnswer(df1, df2)
         }
       }
     }
@@ -142,16 +193,27 @@ class InlineTableParsingImprovementsSuite extends QueryTest with SharedSparkSess
 
   test("Insert Into Values optimization - Basic literals & expressions.") {
     var firstTableName: Option[String] = None
-    Seq(true, false).foreach { insertIntoValueImprovementEnabled =>
+    Seq(true, false).foreach { eagerEvalOfUnresolvedInlineTableEnabled =>
       // Create a table with a randomly generated name.
       val tableName = createTable
 
       // Set the feature flag for the InsertIntoValues improvement.
-      withSQLConf(SQLConf.OPTIMIZE_INSERT_INTO_VALUES_PARSER.key ->
-        insertIntoValueImprovementEnabled.toString) {
+      withSQLConf(SQLConf.EAGER_EVAL_OF_UNRESOLVED_INLINE_TABLE_ENABLED.key ->
+        eagerEvalOfUnresolvedInlineTableEnabled.toString) {
 
         // Generate an INSERT INTO VALUES statement.
         val sqlStatement = generateInsertStatementsWithComplexExpressions(tableName)
+
+        // Parse the SQL statement.
+        val plan = parser.parsePlan(sqlStatement)
+
+        // Traverse the plan and check for the presence of appropriate nodes.
+        // In this case, the plan should always contain a UnresolvedInlineTable node
+        // because the expressions are not eagerly resolved, therefore
+        // `plan.expressionsResolved` in `EvaluateUnresolvedInlineTable.evaluate` will
+        // always be false.
+        assert(traversePlanAndCheckForNodeType(plan, classOf[UnresolvedInlineTable]))
+
         spark.sql(sqlStatement)
 
         // Check that both insertions will produce equivalent tables.
@@ -168,17 +230,30 @@ class InlineTableParsingImprovementsSuite extends QueryTest with SharedSparkSess
 
   test("Insert Into Values with defaults.") {
     var firstTableName: Option[String] = None
-    Seq(true, false).foreach { insertIntoValueImprovementEnabled =>
+    Seq(true, false).foreach { eagerEvalOfUnresolvedInlineTableEnabled =>
      // Create a table with default values specified.
       val tableName = createTable
 
       // Set the feature flag for the InsertIntoValues improvement.
-      withSQLConf(SQLConf.OPTIMIZE_INSERT_INTO_VALUES_PARSER.key ->
-        insertIntoValueImprovementEnabled.toString) {
+      withSQLConf(SQLConf.EAGER_EVAL_OF_UNRESOLVED_INLINE_TABLE_ENABLED.key ->
+        eagerEvalOfUnresolvedInlineTableEnabled.toString) {
 
         // Generate an INSERT INTO VALUES statement that omits all columns
         // containing a DEFAULT value.
-        spark.sql(s"INSERT INTO $tableName (id) VALUES (1);")
+        val sqlStatement = s"INSERT INTO $tableName (id) VALUES (1);"
+
+        // Parse the SQL statement.
+        val plan = parser.parsePlan(sqlStatement)
+
+        // Traverse the plan and check for the presence of appropriate nodes depending on the
+        // feature flag.
+        if (eagerEvalOfUnresolvedInlineTableEnabled) {
+          assert(traversePlanAndCheckForNodeType(plan, classOf[LocalRelation]))
+        } else {
+          assert(traversePlanAndCheckForNodeType(plan, classOf[UnresolvedInlineTable]))
+        }
+
+        spark.sql(sqlStatement)
 
         // Verify that the default values are applied correctly.
         val resultRow = spark.sql(
@@ -226,4 +301,72 @@ class InlineTableParsingImprovementsSuite extends QueryTest with SharedSparkSess
       }
     }
   }
+
+  test("SPARK-49269: Value list in subquery") {
+    var firstDF: Option[DataFrame] = None
+    val flagVals = Seq(true, false)
+    flagVals.foreach { eagerEvalOfUnresolvedInlineTableEnabled =>
+      // Set the feature flag for the InsertIntoValues improvement.
+      withSQLConf(SQLConf.EAGER_EVAL_OF_UNRESOLVED_INLINE_TABLE_ENABLED.key ->
+        eagerEvalOfUnresolvedInlineTableEnabled.toString) {
+
+        // Generate a subquery with a VALUES clause.
+        val sqlStatement = s"SELECT * FROM (${generateValuesWithLiterals()});"
+
+        // Parse the SQL statement.
+        val plan = parser.parsePlan(sqlStatement)
+
+        // Traverse the plan and check for the presence of appropriate nodes depending on the
+        // feature flag.
+        if (eagerEvalOfUnresolvedInlineTableEnabled) {
+          assert(traversePlanAndCheckForNodeType(plan, classOf[LocalRelation]))
+        } else {
+          assert(traversePlanAndCheckForNodeType(plan, classOf[UnresolvedInlineTable]))
+        }
+
+        val res = spark.sql(sqlStatement)
+
+        // Check that both insertions will produce equivalent tables.
+        if (flagVals.head == eagerEvalOfUnresolvedInlineTableEnabled) {
+          firstDF = Some(res)
+        } else {
+          checkAnswer(res, firstDF.get)
+        }
+      }
+    }
+  }
+
+  test("SPARK-49269: Value list in projection list subquery") {
+    var firstDF: Option[DataFrame] = None
+    val flagVals = Seq(true, false)
+    flagVals.foreach { eagerEvalOfUnresolvedInlineTableEnabled =>
+      // Set the feature flag for the InsertIntoValues improvement.
+      withSQLConf(SQLConf.EAGER_EVAL_OF_UNRESOLVED_INLINE_TABLE_ENABLED.key ->
+        eagerEvalOfUnresolvedInlineTableEnabled.toString) {
+
+        // Generate a subquery with a VALUES clause in the projection list.
+        val sqlStatement = s"SELECT (SELECT COUNT(*) FROM ${generateValuesWithLiterals()});"
+
+        // Parse the SQL statement.
+        val plan = parser.parsePlan(sqlStatement)
+
+        // Traverse the plan and check for the presence of appropriate nodes depending on the
+        // feature flag.
+        if (eagerEvalOfUnresolvedInlineTableEnabled) {
+          assert(traversePlanAndCheckForNodeType(plan, classOf[LocalRelation]))
+        } else {
+          assert(traversePlanAndCheckForNodeType(plan, classOf[UnresolvedInlineTable]))
+        }
+
+        val res = spark.sql(sqlStatement)
+
+        // Check that both insertions will produce equivalent tables.
+        if (flagVals.head == eagerEvalOfUnresolvedInlineTableEnabled) {
+          firstDF = Some(res)
+        } else {
+          checkAnswer(res, firstDF.get)
+        }
+      }
+    }
+  }
 }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DeclareVariableParserSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DeclareVariableParserSuite.scala
index a292afe6a7c28..bc42937b93a92 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DeclareVariableParserSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DeclareVariableParserSuite.scala
@@ -17,11 +17,13 @@
 
 package org.apache.spark.sql.execution.command
 
+import org.apache.spark.sql.catalyst.EvaluateUnresolvedInlineTable
 import org.apache.spark.sql.catalyst.analysis.{AnalysisTest, UnresolvedAttribute, UnresolvedFunction, UnresolvedIdentifier, UnresolvedInlineTable}
 import org.apache.spark.sql.catalyst.expressions.{Add, Cast, Divide, Literal, ScalarSubquery}
 import org.apache.spark.sql.catalyst.parser.CatalystSqlParser.parsePlan
 import org.apache.spark.sql.catalyst.parser.ParseException
 import org.apache.spark.sql.catalyst.plans.logical.{CreateVariable, DefaultValueExpression, Project, SubqueryAlias}
+import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.test.SharedSparkSession
 import org.apache.spark.sql.types.{Decimal, DecimalType, DoubleType, IntegerType, MapType, NullType, StringType}
 import org.apache.spark.unsafe.types.UTF8String
@@ -91,6 +93,13 @@ class DeclareVariableParserSuite extends AnalysisTest with SharedSparkSession {
           Cast(UnresolvedFunction("CURRENT_DATABASE", Nil, isDistinct = false), StringType),
           "CURRENT_DATABASE()"),
         replace = false))
+    val subqueryAliasChild =
+      if (conf.getConf(SQLConf.EAGER_EVAL_OF_UNRESOLVED_INLINE_TABLE_ENABLED)) {
+        EvaluateUnresolvedInlineTable.evaluate(
+          UnresolvedInlineTable(Seq("c1"), Seq(Literal(1)) :: Nil))
+      } else {
+        UnresolvedInlineTable(Seq("c1"), Seq(Literal(1)) :: Nil)
+      }
     comparePlans(
       parsePlan("DECLARE VARIABLE var1 INT DEFAULT (SELECT c1 FROM VALUES(1) AS T(c1))"),
       CreateVariable(
@@ -99,7 +108,7 @@ class DeclareVariableParserSuite extends AnalysisTest with SharedSparkSession {
         Cast(ScalarSubquery(
           Project(UnresolvedAttribute("c1") :: Nil,
             SubqueryAlias(Seq("T"),
-              UnresolvedInlineTable(Seq("c1"), Seq(Literal(1)) :: Nil)))), IntegerType),
+              subqueryAliasChild))), IntegerType),
         "(SELECT c1 FROM VALUES(1) AS T(c1))"),
       replace = false))
   }
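
For illustration (not part of the patch): a minimal sketch of the parse-time behavior the renamed flag toggles, assuming a live `SparkSession` named `spark`, e.g. in `spark-shell`. With eager evaluation enabled, a literal-only VALUES list is folded into a `LocalRelation` by the parser itself; with it disabled, the `UnresolvedInlineTable` survives until the analyzer resolves it.

```scala
import org.apache.spark.sql.catalyst.analysis.UnresolvedInlineTable
import org.apache.spark.sql.catalyst.plans.logical.LocalRelation
import org.apache.spark.sql.internal.SQLConf

val parser = spark.sessionState.sqlParser

// Flag on (the default): the parser eagerly evaluates the literal-only VALUES list.
spark.conf.set(SQLConf.EAGER_EVAL_OF_UNRESOLVED_INLINE_TABLE_ENABLED.key, "true")
assert(parser.parsePlan("VALUES (1, 'a'), (2, 'b')")
  .collectFirst { case r: LocalRelation => r }.isDefined)

// Flag off: the unresolved inline table is kept and resolved later, during analysis.
spark.conf.set(SQLConf.EAGER_EVAL_OF_UNRESOLVED_INLINE_TABLE_ENABLED.key, "false")
assert(parser.parsePlan("VALUES (1, 'a'), (2, 'b')")
  .collectFirst { case t: UnresolvedInlineTable => t }.isDefined)
```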