From d604c1daccad5ff8a3154950f92fd2b5b383485e Mon Sep 17 00:00:00 2001
From: Milan Dankovic
Date: Tue, 26 Nov 2024 13:47:41 +0100
Subject: [PATCH] Address comments v2

---
 .../org/apache/spark/sql/SparkSession.scala    | 20 +++++++---
 .../sql/scripting/SqlScriptingExecution.scala  | 37 ++++++++++++-------
 .../scripting/SqlScriptingExecutionNode.scala  |  4 +-
 .../scripting/SqlScriptingInterpreter.scala    |  5 ++-
 .../sql/scripting/SqlScriptingE2eSuite.scala   | 29 ++++++++++++++-
 .../SqlScriptingExecutionSuite.scala           |  2 +
 6 files changed, 73 insertions(+), 24 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala
index dd3b8ce4967ef..20d0bddcbfd0a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala
@@ -413,6 +413,14 @@ class SparkSession private(
    |  Everything else  |
    * ----------------- */

+  /**
+   * Executes the given script and returns the result of the last statement.
+   *
+   * @param script A SQL script to execute.
+   * @param args A map of parameter names to SQL literal expressions.
+   *
+   * @return The result as a `DataFrame`.
+   */
   private def executeSqlScript(
       script: CompoundBody,
       args: Map[String, Expression] = Map.empty): DataFrame = {
@@ -425,7 +433,7 @@ class SparkSession private(
       if (sse.hasNext) {
         df.write.format("noop").mode("overwrite").save()
       } else {
-        // Collect results from the last DataFrame
+        // Collect results from the last DataFrame.
         result = Some(df.collect().toSeq)
       }
     }
@@ -462,7 +470,7 @@ class SparkSession private(
       parsedPlan match {
         case compoundBody: CompoundBody =>
           if (args.nonEmpty) {
-            // Positional parameters are not supported for SQL scripting
+            // Positional parameters are not supported for SQL scripting.
             throw SqlScriptingErrors.positionalParametersAreNotSupportedWithSqlScripting()
           }
           compoundBody
@@ -477,10 +485,10 @@ class SparkSession private(

       plan match {
         case compoundBody: CompoundBody =>
-          // execute the SQL script
+          // Execute the SQL script.
           executeSqlScript(compoundBody)
         case logicalPlan: LogicalPlan =>
-          // execute the standalone SQL statement
+          // Execute the standalone SQL statement.
           Dataset.ofRows(self, plan, tracker)
       }
     }
@@ -528,10 +536,10 @@ class SparkSession private(

       plan match {
         case compoundBody: CompoundBody =>
-          // execute the SQL script
+          // Execute the SQL script.
           executeSqlScript(compoundBody, args.transform((_, v) => lit(v).expr))
         case logicalPlan: LogicalPlan =>
-          // execute the standalone SQL statement
+          // Execute the standalone SQL statement.
           Dataset.ofRows(self, plan, tracker)
       }
     }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecution.scala
index 427d3bbf8d04a..d124e3d484c71 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecution.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecution.scala
@@ -17,19 +17,25 @@
 package org.apache.spark.sql.scripting

+import org.apache.spark.SparkException
 import org.apache.spark.sql.{DataFrame, SparkSession}
 import org.apache.spark.sql.catalyst.expressions.Expression
 import org.apache.spark.sql.catalyst.plans.logical.{CommandResult, CompoundBody, MultiResult}

 /**
  * SQL scripting executor - executes script and returns result statements.
+ * This supports returning multiple result statements from a single script.
+ *
+ * @param sqlScript CompoundBody which needs to be executed.
+ * @param session Spark session that SQL script is executed within.
+ * @param args A map of parameter names to SQL literal expressions.
  */
 class SqlScriptingExecution(
     sqlScript: CompoundBody,
     session: SparkSession,
     args: Map[String, Expression]) extends Iterator[DataFrame] {

-  // Build the execution plan for the script
+  // Build the execution plan for the script.
   private val executionPlan: Iterator[CompoundStatementExec] =
     SqlScriptingInterpreter(session).buildExecutionPlan(sqlScript, args)
@@ -38,38 +44,41 @@ class SqlScriptingExecution(
   override def hasNext: Boolean = current.isDefined

   override def next(): DataFrame = {
-    if (!hasNext) {
-      throw new NoSuchElementException("No more statements to execute")
+    if (!hasNext) {
+      throw SparkException.internalError("No more elements to iterate through.")
     }
     val nextDataFrame = current.get
     current = getNextResult
     nextDataFrame
   }

-  /** Helper method to iterate through statements until next result statement is encountered */
+  /** Helper method to iterate through statements until next result statement is encountered. */
   private def getNextResult: Option[DataFrame] = {
-    var currentStatement = if (executionPlan.hasNext) Some(executionPlan.next()) else None
-    // While we don't have a result statement, execute the statements
+
+    def getNextStatement: Option[CompoundStatementExec] =
+      if (executionPlan.hasNext) Some(executionPlan.next()) else None
+
+    var currentStatement = getNextStatement
+    // While we don't have a result statement, execute the statements.
     while (currentStatement.isDefined) {
       currentStatement match {
         case Some(stmt: SingleStatementExec) if !stmt.isExecuted =>
           withErrorHandling {
             val df = stmt.buildDataFrame(session)
-            if (!df.logicalPlan.isInstanceOf[CommandResult]
-              && !df.logicalPlan.isInstanceOf[MultiResult]) {
-              // If the statement is a result, we need to return it to the caller
-              return Some(df)
+            df.logicalPlan match {
+              case _: CommandResult | _: MultiResult => // pass
+              case _ => return Some(df) // If the statement is a result, return it to the caller.
             }
           }
         case _ => // pass
       }
-      currentStatement = if (executionPlan.hasNext) Some(executionPlan.next()) else None
+      currentStatement = getNextStatement
     }
     None
   }

-  private def handleException(e: Exception): Unit = {
-    // Rethrow the exception
+  private def handleException(e: Throwable): Unit = {
+    // Rethrow the exception.
     // TODO: SPARK-48353 Add error handling for SQL scripts
     throw e
   }
@@ -78,7 +87,7 @@ class SqlScriptingExecution(
     try {
       f
     } catch {
-      case e: Exception =>
+      case e: Throwable =>
         handleException(e)
     }
   }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala
index 8bda29a22be34..94284ec514f55 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala
@@ -107,6 +107,8 @@ trait NonLeafStatementExec extends CompoundStatementExec {
  *   Logical plan of the parsed statement.
  * @param origin
  *   Origin descriptor for the statement.
+ * @param args
+ *   A map of parameter names to SQL literal expressions.
  * @param isInternal
  *   Whether the statement originates from the SQL script or it is created during the
  *   interpretation. Example: DropVariable statements are automatically created at the end of each
@@ -148,7 +150,7 @@ class SingleStatementExec(

   /**
    * Builds a DataFrame from the parsedPlan of this SingleStatementExec
-   * @param session The SparkSession on which the parsedPlan is built
+   * @param session The SparkSession on which the parsedPlan is built.
    * @return
    *   The DataFrame.
    */
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreter.scala b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreter.scala
index 2a395435f8af2..387ae36b881f4 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreter.scala
@@ -37,6 +37,8 @@ case class SqlScriptingInterpreter(session: SparkSession) {
    *
    * @param compound
    *   CompoundBody for which to build the plan.
+   * @param args
+   *   A map of parameter names to SQL literal expressions.
    * @return
    *   Iterator through collection of statements to be executed.
    */
@@ -65,6 +67,8 @@ case class SqlScriptingInterpreter(session: SparkSession) {
    *
    * @param node
    *   Root node of the parsed tree.
+   * @param args
+   *   A map of parameter names to SQL literal expressions.
    * @return
    *   Executable statement.
    */
@@ -102,7 +106,6 @@ case class SqlScriptingInterpreter(session: SparkSession) {

       case CaseStatement(conditions, conditionalBodies, elseBody) =>
         val conditionsExec = conditions.map(condition =>
-          // todo: what to put here for isInternal, in case of simple case statement
           new SingleStatementExec(
             condition.parsedPlan,
             condition.origin,
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingE2eSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingE2eSuite.scala
index 447ba98127876..afcdfd343e33b 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingE2eSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingE2eSuite.scala
@@ -18,7 +18,7 @@
 package org.apache.spark.sql.scripting

 import org.apache.spark.SparkConf
-import org.apache.spark.sql.{QueryTest, Row}
+import org.apache.spark.sql.{AnalysisException, QueryTest, Row}
 import org.apache.spark.sql.catalyst.plans.logical.CompoundBody
 import org.apache.spark.sql.catalyst.util.QuotingUtils.toSQLConf
 import org.apache.spark.sql.exceptions.SqlScriptingException
@@ -151,7 +151,7 @@ class SqlScriptingE2eSuite extends QueryTest with SharedSparkSession {
        |    SELECT ?;
        |  END IF;
        |END""".stripMargin
-    // Define an array with SQL parameters in the correct order
+    // Define an array with SQL parameters in the correct order.
     val args: Array[Any] = Array(5, "greater", "smaller")
     checkError(
       exception = intercept[SqlScriptingException] {
@@ -160,4 +160,29 @@ class SqlScriptingE2eSuite extends QueryTest with SharedSparkSession {
       condition = "UNSUPPORTED_FEATURE.SQL_SCRIPTING_WITH_POSITIONAL_PARAMETERS",
       parameters = Map.empty)
   }
+
+  test("named params with positional params - should fail") {
+    val sqlScriptText =
+      """
+        |BEGIN
+        |  SELECT ?;
+        |  IF :param > 10 THEN
+        |    SELECT 1;
+        |  ELSE
+        |    SELECT 2;
+        |  END IF;
+        |END""".stripMargin
+    // Define a map with SQL parameters.
+    val args: Map[String, Any] = Map("param" -> 5)
+    checkError(
+      exception = intercept[AnalysisException] {
+        spark.sql(sqlScriptText, args).asInstanceOf[CompoundBody]
+      },
+      condition = "UNBOUND_SQL_PARAMETER",
+      parameters = Map("name" -> "_16"),
+      context = ExpectedContext(
+        fragment = "?",
+        start = 16,
+        stop = 16))
+  }
 }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionSuite.scala
index 140fa495447e2..bbeae942f9fe7 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionSuite.scala
@@ -1024,6 +1024,8 @@ class SqlScriptingExecutionSuite extends QueryTest with SharedSparkSession {
         |    END LOOP;
         |  END LOOP;
         |END""".stripMargin
+    // Execution immediately leaves the outer loop after SELECT,
+    // so we expect only a single row in the result set.
     val expected = Seq(Seq(Row(1)))
     verifySqlScriptResult(sqlScriptText, expected)
   }
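
For context, a minimal sketch (not part of the patch) of the consumption pattern that executeSqlScript in SparkSession.scala settles on after this change: SqlScriptingExecution yields only the script's result statements, every result except the last is forced through the "noop" sink and discarded, and only the last result's rows are collected. The object and method names (SqlScriptingConsumptionSketch, runScript) are hypothetical and exist only for illustration.

    // Illustrative sketch mirroring the loop in SparkSession.executeSqlScript;
    // runScript is a hypothetical helper, not part of this patch.
    import org.apache.spark.sql.{Row, SparkSession}
    import org.apache.spark.sql.catalyst.expressions.Expression
    import org.apache.spark.sql.catalyst.plans.logical.CompoundBody
    import org.apache.spark.sql.scripting.SqlScriptingExecution

    object SqlScriptingConsumptionSketch {
      def runScript(
          session: SparkSession,
          script: CompoundBody,
          args: Map[String, Expression] = Map.empty): Seq[Row] = {
        // The execution is an Iterator[DataFrame] over the script's result statements.
        val sse = new SqlScriptingExecution(script, session, args)
        var result: Option[Seq[Row]] = None
        while (sse.hasNext) {
          val df = sse.next()
          if (sse.hasNext) {
            // Intermediate result statement: force execution, discard the rows.
            df.write.format("noop").mode("overwrite").save()
          } else {
            // Last result statement: collect and return its rows.
            result = Some(df.collect().toSeq)
          }
        }
        result.getOrElse(Seq.empty)
      }
    }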