Skip to content

Commit

Permalink
Fix context init
Browse files Browse the repository at this point in the history
  • Loading branch information
miland-db committed Dec 4, 2024
1 parent 3e86522 commit 24f53b1
Show file tree
Hide file tree
Showing 5 changed files with 46 additions and 32 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,8 @@ class SqlScriptingExecution(
// Frames to keep what is being executed.
private val context: SqlScriptingExecutionContext = {
val ctx = new SqlScriptingExecutionContext()
interpreter.buildExecutionPlan(sqlScript, args, ctx)
val executionPlan = interpreter.buildExecutionPlan(sqlScript, args, ctx)
ctx.frames.addOne(new SqlScriptingExecutionFrame(executionPlan))
ctx
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -183,14 +183,14 @@ class CompoundBodyExec(
private var scopeExited = false

def enterScope(): Unit = {
if (label.isDefined && !scopeEntered) {
if (context != null && label.isDefined && !scopeEntered) {
scopeEntered = true
context.enterScope(label.get)
}
}

def exitScope(): Unit = {
if (label.isDefined && !scopeExited) {
if (context != null && label.isDefined && !scopeExited) {
scopeExited = true
context.exitScope(label.get)
}
Expand Down Expand Up @@ -218,6 +218,7 @@ class CompoundBodyExec(

@scala.annotation.tailrec
override def next(): CompoundStatementExec = {
enterScope()
curr match {
case None => throw SparkException.internalError(
"No more elements to iterate through in the current SQL compound statement.")
Expand All @@ -233,11 +234,6 @@ class CompoundBodyExec(
curr = if (localIterator.hasNext) Some(localIterator.next()) else None
statement
case Some(body: NonLeafStatementExec) =>
body match {
case compound: CompoundBodyExec =>
compound.enterScope()
case _ => // pass
}
if (body.getTreeIterator.hasNext) {
body.getTreeIterator.next() match {
case leaveStatement: LeaveStatementExec =>
Expand Down Expand Up @@ -273,6 +269,9 @@ class CompoundBodyExec(
// Stop the iteration.
stopIteration = true

// Exit scope if leave statement is encountered.
exitScope()

// TODO: Variable cleanup (once we add SQL script execution logic).
// TODO: Add interpreter tests as well.

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,11 +46,8 @@ case class SqlScriptingInterpreter(session: SparkSession) {
compound: CompoundBody,
args: Map[String, Expression],
context: SqlScriptingExecutionContext): Iterator[CompoundStatementExec] = {
val compoundBodyExec = transformTreeIntoExecutable(compound, args, context)
.asInstanceOf[CompoundBodyExec]
context.frames.addOne(new SqlScriptingExecutionFrame(compoundBodyExec.getTreeIterator))
compoundBodyExec.enterScope()
compoundBodyExec.getTreeIterator
transformTreeIntoExecutable(compound, args, context)
.asInstanceOf[CompoundBodyExec].getTreeIterator
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,22 +35,23 @@ class SqlScriptingExecutionNodeSuite extends SparkFunSuite with SharedSparkSessi
case class TestCompoundBody(
statements: Seq[CompoundStatementExec],
label: Option[String] = None,
context: SqlScriptingExecutionContext = new SqlScriptingExecutionContext)
context: SqlScriptingExecutionContext = null)
extends CompoundBodyExec(statements, label, context)

case class TestForStatement(
query: SingleStatementExec,
variableName: Option[String],
body: CompoundBodyExec,
override val label: Option[String],
session: SparkSession)
session: SparkSession,
context: SqlScriptingExecutionContext = null)
extends ForStatementExec(
query,
variableName,
body,
label,
session,
new SqlScriptingExecutionContext) {
context) {
override def reset(): Unit = ()
}

Expand All @@ -64,7 +65,7 @@ class SqlScriptingExecutionNodeSuite extends SparkFunSuite with SharedSparkSessi
Origin(startIndex = Some(0), stopIndex = Some(description.length)),
Map.empty,
isInternal = false,
new SqlScriptingExecutionContext
null
)

case class DummyLogicalPlan() extends LeafNode {
Expand All @@ -78,7 +79,8 @@ class SqlScriptingExecutionNodeSuite extends SparkFunSuite with SharedSparkSessi
Origin(startIndex = Some(0), stopIndex = Some(description.length)),
Map.empty,
isInternal = false,
new SqlScriptingExecutionContext)
null
)

class LoopBooleanConditionEvaluator(condition: TestLoopCondition) {
private var callCount: Int = 0
Expand Down Expand Up @@ -126,7 +128,7 @@ class SqlScriptingExecutionNodeSuite extends SparkFunSuite with SharedSparkSessi
Origin(startIndex = Some(0), stopIndex = Some(description.length)),
Map.empty,
isInternal = false,
new SqlScriptingExecutionContext) {
null) {
override def buildDataFrame(session: SparkSession): DataFrame = {
val data = Seq.range(0, numberOfRows).map(Row(_))
val schema = List(StructField(columnName, IntegerType))
Expand Down Expand Up @@ -445,13 +447,18 @@ class SqlScriptingExecutionNodeSuite extends SparkFunSuite with SharedSparkSessi
}

test("leave compound block") {
val context = new SqlScriptingExecutionContext
val labelText = "lbl"
val iter = TestCompoundBody(
statements = Seq(
TestLeafStatement("one"),
new LeaveStatementExec("lbl")
),
label = Some("lbl")
label = Some(labelText),
context = context
).getTreeIterator
context.frames.addOne(new SqlScriptingExecutionFrame(iter))
context.enterScope(labelText)
val statements = iter.map(extractStatementValue).toSeq
assert(statements === Seq("one", "lbl"))
}
Expand Down Expand Up @@ -792,23 +799,31 @@ class SqlScriptingExecutionNodeSuite extends SparkFunSuite with SharedSparkSessi
}

test("for statement - nested") {
val context = new SqlScriptingExecutionContext
val labelText = "lbl"
val iter = TestCompoundBody(Seq(
TestForStatement(
query = MockQuery(2, "intCol", "query1"),
variableName = Some("x"),
label = Some("for1"),
session = spark,
body = TestCompoundBody(Seq(
TestForStatement(
query = MockQuery(2, "intCol1", "query2"),
variableName = Some("y"),
label = Some("for2"),
session = spark,
body = TestCompoundBody(Seq(TestLeafStatement("body")))
)
))
)
)).getTreeIterator
body =
TestCompoundBody(Seq(
TestForStatement(
query = MockQuery(2, "intCol1", "query2"),
variableName = Some("y"),
label = Some("for2"),
session = spark,
body = TestCompoundBody(Seq(TestLeafStatement("body"))),
context = context
)
),
context = context),
context = context
)),
label = Some(labelText)).getTreeIterator
context.frames.addOne(new SqlScriptingExecutionFrame(iter))
context.enterScope(labelText)
val statements = iter.map(extractStatementValue).toSeq
assert(statements === Seq(
"body",
Expand Down Expand Up @@ -842,7 +857,6 @@ class SqlScriptingExecutionNodeSuite extends SparkFunSuite with SharedSparkSessi
}

test("for statement no variable - enters body with multiple statements multiple times") {
val context = new SqlScriptingExecutionContext
val iter = TestCompoundBody(Seq(
TestForStatement(
query = MockQuery(2, "intCol", "query1"),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,8 +45,11 @@ class SqlScriptingInterpreterSuite extends QueryTest with SharedSparkSession {
args: Map[String, Expression] = Map.empty): Array[DataFrame] = {
val interpreter = SqlScriptingInterpreter(spark)
val compoundBody = spark.sessionState.sqlParser.parsePlan(sqlText).asInstanceOf[CompoundBody]

// Initialize context so scopes can be entered correctly.
val context = new SqlScriptingExecutionContext()
val executionPlan = interpreter.buildExecutionPlan(compoundBody, args, context)
context.frames.addOne(new SqlScriptingExecutionFrame(executionPlan))

executionPlan.flatMap {
case statement: SingleStatementExec =>
Expand Down

0 comments on commit 24f53b1

Please sign in to comment.