From 24f53b1a9ff4a29cdea63df7d3819db4d29f050b Mon Sep 17 00:00:00 2001
From: Milan Dankovic
Date: Wed, 4 Dec 2024 11:28:33 +0100
Subject: [PATCH] Fix context init

Move execution-frame creation out of buildExecutionPlan and into its
callers: SqlScriptingExecution now pushes the root frame itself, and
CompoundBodyExec enters and exits its scope lazily, guarded against a
null context, instead of relying on the interpreter to do it.

---
 .../sql/scripting/SqlScriptingExecution.scala |  3 +-
 .../scripting/SqlScriptingExecutionNode.scala | 13 +++--
 .../scripting/SqlScriptingInterpreter.scala   |  7 +--
 .../SqlScriptingExecutionNodeSuite.scala      | 52 ++++++++++++-------
 .../SqlScriptingInterpreterSuite.scala        |  3 ++
 5 files changed, 46 insertions(+), 32 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecution.scala
index 37b2f89af376c..2c2eb0695ef70 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecution.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecution.scala
@@ -40,7 +40,8 @@ class SqlScriptingExecution(
   // Frames to keep what is being executed.
   private val context: SqlScriptingExecutionContext = {
     val ctx = new SqlScriptingExecutionContext()
-    interpreter.buildExecutionPlan(sqlScript, args, ctx)
+    val executionPlan = interpreter.buildExecutionPlan(sqlScript, args, ctx)
+    ctx.frames.addOne(new SqlScriptingExecutionFrame(executionPlan))
     ctx
   }

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala
index 01d178b2db028..7a35c5a509c05 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala
@@ -183,14 +183,14 @@ class CompoundBodyExec(
   private var scopeExited = false

   def enterScope(): Unit = {
-    if (label.isDefined && !scopeEntered) {
+    if (context != null && label.isDefined && !scopeEntered) {
       scopeEntered = true
       context.enterScope(label.get)
     }
   }

   def exitScope(): Unit = {
-    if (label.isDefined && !scopeExited) {
+    if (context != null && label.isDefined && !scopeExited) {
       scopeExited = true
       context.exitScope(label.get)
     }
@@ -218,6 +218,7 @@ class CompoundBodyExec(

     @scala.annotation.tailrec
     override def next(): CompoundStatementExec = {
+      enterScope()
       curr match {
         case None => throw SparkException.internalError(
           "No more elements to iterate through in the current SQL compound statement.")
@@ -233,11 +234,6 @@ class CompoundBodyExec(
           curr = if (localIterator.hasNext) Some(localIterator.next()) else None
           statement
         case Some(body: NonLeafStatementExec) =>
-          body match {
-            case compound: CompoundBodyExec =>
-              compound.enterScope()
-            case _ => // pass
-          }
           if (body.getTreeIterator.hasNext) {
             body.getTreeIterator.next() match {
               case leaveStatement: LeaveStatementExec =>
@@ -273,6 +269,9 @@ class CompoundBodyExec(
       // Stop the iteration.
       stopIteration = true

+      // Exit scope if leave statement is encountered.
+      exitScope()
+
       // TODO: Variable cleanup (once we add SQL script execution logic).
       // TODO: Add interpreter tests as well.

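Note: the change above makes scope entry lazy and idempotent. Below is a
minimal, self-contained sketch (illustrative only, not Spark code) of the
guard pattern CompoundBodyExec now uses: both calls become no-ops when a
test passes a null context, and repeated calls never re-enter the scope.

    object ScopeGuardSketch {
      // Stand-in for SqlScriptingExecutionContext; only the calls used here.
      final class Context {
        def enterScope(label: String): Unit = println(s"enter $label")
        def exitScope(label: String): Unit = println(s"exit $label")
      }

      // Mirrors the guarded enterScope/exitScope from the diff above.
      final class Body(label: Option[String], context: Context) {
        private var scopeEntered = false
        private var scopeExited = false

        def enterScope(): Unit = {
          if (context != null && label.isDefined && !scopeEntered) {
            scopeEntered = true
            context.enterScope(label.get)
          }
        }

        def exitScope(): Unit = {
          if (context != null && label.isDefined && !scopeExited) {
            scopeExited = true
            context.exitScope(label.get)
          }
        }
      }

      def main(args: Array[String]): Unit = {
        val body = new Body(Some("lbl"), new Context)
        body.enterScope() // prints "enter lbl"
        body.enterScope() // no-op: scope already entered
        body.exitScope()  // prints "exit lbl"
        new Body(Some("lbl"), null).enterScope() // no-op: no context supplied
      }
    }
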
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreter.scala b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreter.scala
index 4abee1a647bf5..4aac8f5fe1364 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreter.scala
@@ -46,11 +46,8 @@ case class SqlScriptingInterpreter(session: SparkSession) {
       compound: CompoundBody,
       args: Map[String, Expression],
       context: SqlScriptingExecutionContext): Iterator[CompoundStatementExec] = {
-    val compoundBodyExec = transformTreeIntoExecutable(compound, args, context)
-      .asInstanceOf[CompoundBodyExec]
-    context.frames.addOne(new SqlScriptingExecutionFrame(compoundBodyExec.getTreeIterator))
-    compoundBodyExec.enterScope()
-    compoundBodyExec.getTreeIterator
+    transformTreeIntoExecutable(compound, args, context)
+      .asInstanceOf[CompoundBodyExec].getTreeIterator
   }

   /**
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNodeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNodeSuite.scala
index f7029e6ff2e68..bcd0618947d80 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNodeSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNodeSuite.scala
@@ -35,7 +35,7 @@ class SqlScriptingExecutionNodeSuite extends SparkFunSuite with SharedSparkSessi
   case class TestCompoundBody(
       statements: Seq[CompoundStatementExec],
       label: Option[String] = None,
-      context: SqlScriptingExecutionContext = new SqlScriptingExecutionContext)
+      context: SqlScriptingExecutionContext = null)
     extends CompoundBodyExec(statements, label, context)

   case class TestForStatement(
@@ -43,14 +43,15 @@ class SqlScriptingExecutionNodeSuite extends SparkFunSuite with SharedSparkSessi
       variableName: Option[String],
       body: CompoundBodyExec,
       override val label: Option[String],
-      session: SparkSession)
+      session: SparkSession,
+      context: SqlScriptingExecutionContext = null)
     extends ForStatementExec(
       query,
       variableName,
       body,
       label,
       session,
-      new SqlScriptingExecutionContext) {
+      context) {
     override def reset(): Unit = ()
   }

@@ -64,7 +65,7 @@ class SqlScriptingExecutionNodeSuite extends SparkFunSuite with SharedSparkSessi
       Origin(startIndex = Some(0), stopIndex = Some(description.length)),
       Map.empty,
       isInternal = false,
-      new SqlScriptingExecutionContext
+      null
     )

   case class DummyLogicalPlan() extends LeafNode {
@@ -78,7 +79,8 @@ class SqlScriptingExecutionNodeSuite extends SparkFunSuite with SharedSparkSessi
       Origin(startIndex = Some(0), stopIndex = Some(description.length)),
       Map.empty,
       isInternal = false,
-      new SqlScriptingExecutionContext)
+      null
+    )

   class LoopBooleanConditionEvaluator(condition: TestLoopCondition) {
     private var callCount: Int = 0
@@ -126,7 +128,7 @@ class SqlScriptingExecutionNodeSuite extends SparkFunSuite with SharedSparkSessi
       Origin(startIndex = Some(0), stopIndex = Some(description.length)),
       Map.empty,
       isInternal = false,
-      new SqlScriptingExecutionContext) {
+      null) {
     override def buildDataFrame(session: SparkSession): DataFrame = {
       val data = Seq.range(0, numberOfRows).map(Row(_))
       val schema = List(StructField(columnName, IntegerType))
@@ -445,13 +447,18 @@ class SqlScriptingExecutionNodeSuite extends SparkFunSuite with SharedSparkSessi
   }

   test("leave compound block") {
+    val context = new SqlScriptingExecutionContext
+    val labelText = "lbl"
     val iter = TestCompoundBody(
       statements = Seq(
         TestLeafStatement("one"),
         new LeaveStatementExec("lbl")
       ),
-      label = Some("lbl")
+      label = Some(labelText),
+      context = context
     ).getTreeIterator
+    context.frames.addOne(new SqlScriptingExecutionFrame(iter))
+    context.enterScope(labelText)
     val statements = iter.map(extractStatementValue).toSeq
     assert(statements === Seq("one", "lbl"))
   }
@@ -792,23 +799,31 @@ class SqlScriptingExecutionNodeSuite extends SparkFunSuite with SharedSparkSessi
   }

   test("for statement - nested") {
+    val context = new SqlScriptingExecutionContext
+    val labelText = "lbl"
     val iter = TestCompoundBody(Seq(
       TestForStatement(
         query = MockQuery(2, "intCol", "query1"),
         variableName = Some("x"),
         label = Some("for1"),
         session = spark,
-        body = TestCompoundBody(Seq(
-          TestForStatement(
-            query = MockQuery(2, "intCol1", "query2"),
-            variableName = Some("y"),
-            label = Some("for2"),
-            session = spark,
-            body = TestCompoundBody(Seq(TestLeafStatement("body")))
-          )
-        ))
-      )
-    )).getTreeIterator
+        body =
+          TestCompoundBody(Seq(
+            TestForStatement(
+              query = MockQuery(2, "intCol1", "query2"),
+              variableName = Some("y"),
+              label = Some("for2"),
+              session = spark,
+              body = TestCompoundBody(Seq(TestLeafStatement("body"))),
+              context = context
+            )
+          ),
+          context = context),
+        context = context
+      )),
+      label = Some(labelText)).getTreeIterator
+    context.frames.addOne(new SqlScriptingExecutionFrame(iter))
+    context.enterScope(labelText)
     val statements = iter.map(extractStatementValue).toSeq
     assert(statements === Seq(
       "body",
@@ -842,7 +857,6 @@ class SqlScriptingExecutionNodeSuite extends SparkFunSuite with SharedSparkSessi
   }

   test("for statement no variable - enters body with multiple statements multiple times") {
-    val context = new SqlScriptingExecutionContext
     val iter = TestCompoundBody(Seq(
       TestForStatement(
         query = MockQuery(2, "intCol", "query1"),
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreterSuite.scala
index d2c9f7f0ea4c5..0e3da5fe78210 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreterSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreterSuite.scala
@@ -45,8 +45,11 @@ class SqlScriptingInterpreterSuite extends QueryTest with SharedSparkSession {
       args: Map[String, Expression] = Map.empty): Array[DataFrame] = {
     val interpreter = SqlScriptingInterpreter(spark)
     val compoundBody = spark.sessionState.sqlParser.parsePlan(sqlText).asInstanceOf[CompoundBody]
+
+    // Initialize context so scopes can be entered correctly.
     val context = new SqlScriptingExecutionContext()
     val executionPlan = interpreter.buildExecutionPlan(compoundBody, args, context)
+    context.frames.addOne(new SqlScriptingExecutionFrame(executionPlan))
     executionPlan.flatMap {
       case statement: SingleStatementExec =>
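
Note: after this change, buildExecutionPlan only transforms the tree; it no
longer pushes an execution frame or enters the root scope. Below is a hedged
sketch of the initialization every caller is now responsible for, mirroring
the suite change above. The helper name is illustrative and not part of the
patch; it assumes the internal org.apache.spark.sql.scripting API shown in
this diff, so imports are elided and it would live inside that package.

    // Illustrative helper, not part of the patch.
    def buildAndInitPlan(
        spark: SparkSession,
        sqlText: String,
        args: Map[String, Expression]): Iterator[CompoundStatementExec] = {
      val interpreter = SqlScriptingInterpreter(spark)
      val compoundBody =
        spark.sessionState.sqlParser.parsePlan(sqlText).asInstanceOf[CompoundBody]
      // buildExecutionPlan leaves the context empty: it only transforms
      // the parsed tree into an executable iterator...
      val context = new SqlScriptingExecutionContext()
      val executionPlan = interpreter.buildExecutionPlan(compoundBody, args, context)
      // ...so the caller must push the root frame itself; the root scope is
      // then entered lazily on the first next() call (see CompoundBodyExec).
      context.frames.addOne(new SqlScriptingExecutionFrame(executionPlan))
      executionPlan
    }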