Commit d604c1d: Address comments v2

miland-db committed Nov 26, 2024
1 parent 85c9cf8
Showing 6 changed files with 73 additions and 24 deletions.
20 changes: 14 additions & 6 deletions sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala
@@ -413,6 +413,14 @@ class SparkSession private(
| Everything else |
* ----------------- */

/**
* Executes the given script and returns the result of the last statement.
*
* @param script A SQL script to execute.
* @param args A map of parameter names to SQL literal expressions.
*
* @return The result as a `DataFrame`.
*/
private def executeSqlScript(
script: CompoundBody,
args: Map[String, Expression] = Map.empty): DataFrame = {
@@ -425,7 +433,7 @@ class SparkSession private(
if (sse.hasNext) {
df.write.format("noop").mode("overwrite").save()
} else {
// Collect results from the last DataFrame
// Collect results from the last DataFrame.
result = Some(df.collect().toSeq)
}
}
@@ -462,7 +470,7 @@ class SparkSession private(
parsedPlan match {
case compoundBody: CompoundBody =>
if (args.nonEmpty) {
// Positional parameters are not supported for SQL scripting
// Positional parameters are not supported for SQL scripting.
throw SqlScriptingErrors.positionalParametersAreNotSupportedWithSqlScripting()
}
compoundBody
@@ -477,10 +485,10 @@ class SparkSession private(

plan match {
case compoundBody: CompoundBody =>
// execute the SQL script
// Execute the SQL script.
executeSqlScript(compoundBody)
case logicalPlan: LogicalPlan =>
// execute the standalone SQL statement
// Execute the standalone SQL statement.
Dataset.ofRows(self, plan, tracker)
}
}
@@ -528,10 +536,10 @@ class SparkSession private(

plan match {
case compoundBody: CompoundBody =>
// execute the SQL script
// Execute the SQL script.
executeSqlScript(compoundBody, args.transform((_, v) => lit(v).expr))
case logicalPlan: LogicalPlan =>
// execute the standalone SQL statement
// Execute the standalone SQL statement.
Dataset.ofRows(self, plan, tracker)
}
}
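
For orientation, a minimal hypothetical sketch (not part of this diff) of what the dispatch above means for a caller, assuming a Spark build with SQL scripting enabled; the script text is illustrative.

// A compound body is routed through executeSqlScript: every statement in the
// script is executed, but only the last statement's rows come back.
val script =
  """
    |BEGIN
    |  SELECT 1;
    |  SELECT 2;
    |END""".stripMargin

val last = spark.sql(script)   // DataFrame holding the result of `SELECT 2`
last.show()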
@@ -17,19 +17,25 @@

package org.apache.spark.sql.scripting

import org.apache.spark.SparkException
import org.apache.spark.sql.{DataFrame, SparkSession}
import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.catalyst.plans.logical.{CommandResult, CompoundBody, MultiResult}

/**
* SQL scripting executor - executes a script and returns result statements.
* This supports returning multiple result statements from a single script.
*
* @param sqlScript CompoundBody which needs to be executed.
* @param session Spark session that SQL script is executed within.
* @param args A map of parameter names to SQL literal expressions.
*/
class SqlScriptingExecution(
sqlScript: CompoundBody,
session: SparkSession,
args: Map[String, Expression]) extends Iterator[DataFrame] {

// Build the execution plan for the script
// Build the execution plan for the script.
private val executionPlan: Iterator[CompoundStatementExec] =
SqlScriptingInterpreter(session).buildExecutionPlan(sqlScript, args)

@@ -38,38 +44,41 @@ class SqlScriptingExecution(
override def hasNext: Boolean = current.isDefined

override def next(): DataFrame = {
if (!hasNext) {
throw new NoSuchElementException("No more statements to execute")
if (!hasNext) {
  throw SparkException.internalError("No more elements to iterate through.")
}
val nextDataFrame = current.get
current = getNextResult
nextDataFrame
}

/** Helper method to iterate through statements until next result statement is encountered */
/** Helper method to iterate through statements until the next result statement is encountered. */
private def getNextResult: Option[DataFrame] = {
var currentStatement = if (executionPlan.hasNext) Some(executionPlan.next()) else None
// While we don't have a result statement, execute the statements

def getNextStatement: Option[CompoundStatementExec] =
if (executionPlan.hasNext) Some(executionPlan.next()) else None

var currentStatement = getNextStatement
// While we don't have a result statement, execute the statements.
while (currentStatement.isDefined) {
currentStatement match {
case Some(stmt: SingleStatementExec) if !stmt.isExecuted =>
withErrorHandling {
val df = stmt.buildDataFrame(session)
if (!df.logicalPlan.isInstanceOf[CommandResult]
&& !df.logicalPlan.isInstanceOf[MultiResult]) {
// If the statement is a result, we need to return it to the caller
return Some(df)
df.logicalPlan match {
case _: CommandResult | _: MultiResult => // pass
case _ => return Some(df) // If the statement is a result, return it to the caller.
}
}
case _ => // pass
}
currentStatement = if (executionPlan.hasNext) Some(executionPlan.next()) else None
currentStatement = getNextStatement
}
None
}

private def handleException(e: Exception): Unit = {
// Rethrow the exception
private def handleException(e: Throwable): Unit = {
// Rethrow the exception.
// TODO: SPARK-48353 Add error handling for SQL scripts
throw e
}
@@ -78,7 +87,7 @@ class SqlScriptingExecution(
try {
f
} catch {
case e: Exception =>
case e: Throwable =>
handleException(e)
}
}
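
As a companion to the class above, a condensed sketch (not part of this diff) of the consumption pattern that SparkSession.executeSqlScript applies to this iterator; `compoundBody` is assumed to be an already-parsed CompoundBody, `spark` an active session, and imports (e.g. Row) are omitted.

// Intermediate result DataFrames are forced through the no-op sink so their
// statements still execute; only the final DataFrame's rows are collected.
val sse = new SqlScriptingExecution(compoundBody, spark, Map.empty)
var result: Option[Seq[Row]] = None
while (sse.hasNext) {
  val df = sse.next()
  if (sse.hasNext) {
    df.write.format("noop").mode("overwrite").save()  // execute, discard rows
  } else {
    result = Some(df.collect().toSeq)                 // keep the last result set
  }
}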
@@ -107,6 +107,8 @@ trait NonLeafStatementExec extends CompoundStatementExec {
* Logical plan of the parsed statement.
* @param origin
* Origin descriptor for the statement.
* @param args
* A map of parameter names to SQL literal expressions.
* @param isInternal
* Whether the statement originates from the SQL script or it is created during the
* interpretation. Example: DropVariable statements are automatically created at the end of each
@@ -148,7 +150,7 @@ class SingleStatementExec(

/**
* Builds a DataFrame from the parsedPlan of this SingleStatementExec
* @param session The SparkSession on which the parsedPlan is built
* @param session The SparkSession on which the parsedPlan is built.
* @return
* The DataFrame.
*/
@@ -37,6 +37,8 @@ case class SqlScriptingInterpreter(session: SparkSession) {
*
* @param compound
* CompoundBody for which to build the plan.
* @param args
* A map of parameter names to SQL literal expressions.
* @return
* Iterator through collection of statements to be executed.
*/
@@ -65,6 +67,8 @@ case class SqlScriptingInterpreter(session: SparkSession) {
*
* @param node
* Root node of the parsed tree.
* @param args
* A map of parameter names to SQL literal expressions.
* @return
* Executable statement.
*/
@@ -102,7 +106,6 @@ case class SqlScriptingInterpreter(session: SparkSession) {

case CaseStatement(conditions, conditionalBodies, elseBody) =>
val conditionsExec = conditions.map(condition =>
// todo: what to put here for isInternal, in case of simple case statement
new SingleStatementExec(
condition.parsedPlan,
condition.origin,
@@ -18,7 +18,7 @@
package org.apache.spark.sql.scripting

import org.apache.spark.SparkConf
import org.apache.spark.sql.{QueryTest, Row}
import org.apache.spark.sql.{AnalysisException, QueryTest, Row}
import org.apache.spark.sql.catalyst.plans.logical.CompoundBody
import org.apache.spark.sql.catalyst.util.QuotingUtils.toSQLConf
import org.apache.spark.sql.exceptions.SqlScriptingException
@@ -151,7 +151,7 @@ class SqlScriptingE2eSuite extends QueryTest with SharedSparkSession {
| SELECT ?;
| END IF;
|END""".stripMargin
// Define an array with SQL parameters in the correct order
// Define an array with SQL parameters in the correct order.
val args: Array[Any] = Array(5, "greater", "smaller")
checkError(
exception = intercept[SqlScriptingException] {
@@ -160,4 +160,29 @@ class SqlScriptingE2eSuite extends QueryTest with SharedSparkSession {
condition = "UNSUPPORTED_FEATURE.SQL_SCRIPTING_WITH_POSITIONAL_PARAMETERS",
parameters = Map.empty)
}

test("named params with positional params - should fail") {
val sqlScriptText =
"""
|BEGIN
| SELECT ?;
| IF :param > 10 THEN
| SELECT 1;
| ELSE
| SELECT 2;
| END IF;
|END""".stripMargin
// Define a map with SQL parameters.
val args: Map[String, Any] = Map("param" -> 5)
checkError(
exception = intercept[AnalysisException] {
spark.sql(sqlScriptText, args).asInstanceOf[CompoundBody]
},
condition = "UNBOUND_SQL_PARAMETER",
parameters = Map("name" -> "_16"),
context = ExpectedContext(
fragment = "?",
start = 16,
stop = 16))
}
}
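
For contrast with the failing cases above, a hypothetical sketch (not part of this diff) of the supported form, where every parameter in the script is a named marker bound through the Map overload; the script text and values are illustrative.

val okScript =
  """
    |BEGIN
    |  IF :param > 10 THEN
    |    SELECT 1;
    |  ELSE
    |    SELECT 2;
    |  END IF;
    |END""".stripMargin

// Named parameters are converted to literal expressions before interpretation,
// so this binds cleanly and returns the rows of the chosen SELECT.
spark.sql(okScript, Map("param" -> 5)).show()   // prints 2, since 5 <= 10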
@@ -1024,6 +1024,8 @@ class SqlScriptingExecutionSuite extends QueryTest with SharedSparkSession {
| END LOOP;
| END LOOP;
|END""".stripMargin
// Execution immediately leaves the outer loop after SELECT,
// so we expect only a single row in the result set.
val expected = Seq(Seq(Row(1)))
verifySqlScriptResult(sqlScriptText, expected)
}
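
To illustrate the LEAVE semantics the comment above refers to, a hypothetical nested-loop script (not the truncated test script itself) in which leaving the labeled outer loop right after the first SELECT yields a single result row; the label name and structure are illustrative.

val leaveExample =
  """
    |BEGIN
    |  lbl: LOOP
    |    SELECT 1;
    |    LOOP
    |      LEAVE lbl;
    |    END LOOP;
    |  END LOOP;
    |END""".stripMargin
// Expected outcome: one result set containing Row(1).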
