Skip to content

Commit

Permalink
Refactor iterator in SqlScriptingExecution
Browse files Browse the repository at this point in the history
  • Loading branch information
miland-db committed Dec 3, 2024
1 parent 7a5cf44 commit 3e86522
Show file tree
Hide file tree
Showing 6 changed files with 31 additions and 20 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,11 @@ class AstBuilder extends DataTypeAstBuilder

override def visitSingleCompoundStatement(ctx: SingleCompoundStatementContext): CompoundBody = {
val labelCtx = new SqlScriptingLabelContext()
visitCompoundBodyImpl(ctx.compoundBody(), Some("root"), allowVarDeclare = true, labelCtx)
val labelText = labelCtx.enterLabeledScope(None, None)
val script =
visitCompoundBodyImpl(ctx.compoundBody(), Some(labelText), allowVarDeclare = true, labelCtx)
labelCtx.exitLabeledScope(None)
script
}

private def visitCompoundBodyImpl(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,24 +40,29 @@ class SqlScriptingExecution(
// Frames to keep what is being executed.
private val context: SqlScriptingExecutionContext = {
val ctx = new SqlScriptingExecutionContext()
val executionPlan = interpreter.buildExecutionPlan(sqlScript, args, ctx)
val frame = new SqlScriptingExecutionFrame(executionPlan)
frame.enterScope(sqlScript.label.get)
ctx.frames
.addOne(new SqlScriptingExecutionFrame(
interpreter.buildExecutionPlan(sqlScript, args, ctx)))
interpreter.buildExecutionPlan(sqlScript, args, ctx)
ctx
}

private var current = getNextResult
private var current: Option[DataFrame] = None
private var resultConsumed: Boolean = true

override def hasNext: Boolean = current.isDefined
override def hasNext: Boolean = {
// If the previous result was not consumed, return true if current element exists.
if (!resultConsumed) {
return current.isDefined
}

// If the previous result was consumed, get the next result and return true if it exists.
current = getNextResult
resultConsumed = false
current.isDefined
}

override def next(): DataFrame = {
if (!hasNext) throw SparkException.internalError("No more elements to iterate through.")
val nextDataFrame = current.get
current = getNextResult
nextDataFrame
resultConsumed = true
current.get
}

/** Helper method to iterate get next statements from the first available frame. */
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ import org.apache.spark.SparkException
*/
class SqlScriptingExecutionContext {
// List of frames that are currently active.
val frames: ListBuffer[SqlScriptingExecutionFrame] = ListBuffer()
val frames: ListBuffer[SqlScriptingExecutionFrame] = ListBuffer.empty

def enterScope(label: String): Unit = {
if (frames.isEmpty) {
Expand All @@ -53,8 +53,7 @@ class SqlScriptingExecutionFrame(
executionPlan: Iterator[CompoundStatementExec]) extends Iterator[CompoundStatementExec] {

// List of scopes that are currently active.
private val scopes: ListBuffer[SqlScriptingExecutionScope] =
ListBuffer(new SqlScriptingExecutionScope("root"))
private val scopes: ListBuffer[SqlScriptingExecutionScope] = ListBuffer.empty

override def hasNext: Boolean = executionPlan.hasNext

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -182,14 +182,14 @@ class CompoundBodyExec(
private var scopeEntered = false
private var scopeExited = false

private def enterScope(): Unit = {
def enterScope(): Unit = {
if (label.isDefined && !scopeEntered) {
scopeEntered = true
context.enterScope(label.get)
}
}

private def exitScope(): Unit = {
def exitScope(): Unit = {
if (label.isDefined && !scopeExited) {
scopeExited = true
context.exitScope(label.get)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,8 +46,11 @@ case class SqlScriptingInterpreter(session: SparkSession) {
compound: CompoundBody,
args: Map[String, Expression],
context: SqlScriptingExecutionContext): Iterator[CompoundStatementExec] = {
transformTreeIntoExecutable(compound, args, context)
.asInstanceOf[CompoundBodyExec].getTreeIterator
val compoundBodyExec = transformTreeIntoExecutable(compound, args, context)
.asInstanceOf[CompoundBodyExec]
context.frames.addOne(new SqlScriptingExecutionFrame(compoundBodyExec.getTreeIterator))
compoundBodyExec.enterScope()
compoundBodyExec.getTreeIterator
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ class SqlScriptingInterpreterSuite extends QueryTest with SharedSparkSession {
val compoundBody = spark.sessionState.sqlParser.parsePlan(sqlText).asInstanceOf[CompoundBody]
val context = new SqlScriptingExecutionContext()
val executionPlan = interpreter.buildExecutionPlan(compoundBody, args, context)
context.frames.addOne(new SqlScriptingExecutionFrame(executionPlan))

executionPlan.flatMap {
case statement: SingleStatementExec =>
if (statement.isExecuted) {
Expand Down

0 comments on commit 3e86522

Please sign in to comment.