Skip to content

Commit

Permalink
Address comments v1
Browse files Browse the repository at this point in the history
  • Loading branch information
miland-db committed Nov 25, 2024
1 parent b25ee02 commit c9162f0
Show file tree
Hide file tree
Showing 4 changed files with 31 additions and 25 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -417,22 +417,21 @@ class SparkSession private(
script: CompoundBody,
args: Map[String, Expression] = Map.empty): DataFrame = {
val sse = new SqlScriptingExecution(script, this, args)
var df: DataFrame = null
var result: Option[Seq[Row]] = null
var result: Option[Seq[Row]] = None

while (sse.hasNext) {
sse.withErrorHandling() {
df = sse.next()
val df = sse.next()
if (sse.hasNext) {
df.collect()
df.write.format("noop").mode("overwrite").save()
} else {
// Collect results from the last DataFrame
result = Some(df.collect().toSeq)
}
}
}

if (result == null) {
if (result.isEmpty) {
emptyDataFrame
} else {
val attributes = DataTypeUtils.toAttributes(result.get.head.schema)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ package org.apache.spark.sql.scripting

import org.apache.spark.sql.{DataFrame, SparkSession}
import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.catalyst.plans.logical.CompoundBody
import org.apache.spark.sql.catalyst.plans.logical.{CommandResult, CompoundBody}

/**
* SQL scripting executor - executes script and returns result statements.
Expand All @@ -41,26 +41,33 @@ class SqlScriptingExecution(
if (!hasNext) {
throw new NoSuchElementException("No more statements to execute")
}
val nextDataFrame = current.get.asInstanceOf[SingleStatementExec].buildDataFrame(session)
val nextDataFrame = current.get
current = getNextResult
nextDataFrame
}

/** Helper method to iterate through statements until next result statement is encountered */
private def getNextResult: Option[CompoundStatementExec] = {
private def getNextResult: Option[DataFrame] = {
var currentStatement = if (executionPlan.hasNext) Some(executionPlan.next()) else None
// While we don't have a result statement, execute the statements
while (currentStatement.isDefined && !currentStatement.get.isResult) {
while (currentStatement.isDefined) {
currentStatement match {
case Some(stmt: SingleStatementExec) if !stmt.isExecuted =>
withErrorHandling() {
stmt.buildDataFrame(session).collect()
val df = stmt.buildDataFrame(session)
if (df.logicalPlan.isInstanceOf[CommandResult]) {
// If the statement is not a result, we need to write it to a noop sink to execute it
df.write.format("noop").mode("overwrite").save()
} else {
// If the statement is a result, we need to return it to the caller
return Some(df)
}
}
case _ => // pass
}
currentStatement = if (executionPlan.hasNext) Some(executionPlan.next()) else None
}
currentStatement
None
}

private def handleException(e: Exception): Unit = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ import org.apache.spark.internal.Logging
import org.apache.spark.sql.{DataFrame, Dataset, SparkSession}
import org.apache.spark.sql.catalyst.analysis.NameParameterizedQuery
import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Project}
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.trees.{Origin, WithOrigin}
import org.apache.spark.sql.errors.SqlScriptingErrors
import org.apache.spark.sql.types.BooleanType
Expand All @@ -38,11 +38,6 @@ sealed trait CompoundStatementExec extends Logging {
*/
val isInternal: Boolean = false

/**
* Whether the statement originates from the SQL statement that returns the result.
*/
def isResult: Boolean = false

/**
* Reset execution of the current node.
*/
Expand Down Expand Up @@ -133,17 +128,14 @@ class SingleStatementExec(
/**
* Plan with named parameters.
*/
lazy val resolvedPlan: LogicalPlan = {
private lazy val preparedPlan: LogicalPlan = {
if (args.nonEmpty) {
NameParameterizedQuery(parsedPlan, args)
} else {
parsedPlan
}
}

/** Statement is result if it is a SELECT query, and it is not in control flow condition */
override def isResult: Boolean = parsedPlan.isInstanceOf[Project] && !isExecuted

/**
* Get the SQL query text corresponding to this statement.
* @return
Expand All @@ -155,14 +147,13 @@ class SingleStatementExec(
}

/**
* Builds a DataFrame from the parsedPlan of this SingleStatementExec,
* logging Origin.sqlText if it exists
* Builds a DataFrame from the parsedPlan of this SingleStatementExec
* @param session The SparkSession on which the parsedPlan is built
* @return
* The DataFrame.
*/
def buildDataFrame(session: SparkSession): DataFrame = {
Dataset.ofRows(session, resolvedPlan)
Dataset.ofRows(session, preparedPlan)
}

override def reset(): Unit = isExecuted = false
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ class SqlScriptingE2eSuite extends QueryTest with SharedSparkSession {
}
}

test("last statement without result") {
test("script without result statement") {
val sqlScript =
"""
|BEGIN
Expand All @@ -111,6 +111,15 @@ class SqlScriptingE2eSuite extends QueryTest with SharedSparkSession {
verifySqlScriptResult(sqlScript, Seq.empty)
}

test("empty script") {
val sqlScript =
"""
|BEGIN
|END
|""".stripMargin
verifySqlScriptResult(sqlScript, Seq.empty)
}

test("named params") {
val sqlScriptText =
"""
Expand Down

0 comments on commit c9162f0

Please sign in to comment.