Skip to content

Commit

Permalink
Add isScope
Browse files Browse the repository at this point in the history
  • Loading branch information
miland-db committed Dec 4, 2024
1 parent 24f53b1 commit 5f6d075
Show file tree
Hide file tree
Showing 5 changed files with 82 additions and 28 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -145,8 +145,13 @@ class AstBuilder extends DataTypeAstBuilder
override def visitSingleCompoundStatement(ctx: SingleCompoundStatementContext): CompoundBody = {
val labelCtx = new SqlScriptingLabelContext()
val labelText = labelCtx.enterLabeledScope(None, None)
val script =
visitCompoundBodyImpl(ctx.compoundBody(), Some(labelText), allowVarDeclare = true, labelCtx)
val script = visitCompoundBodyImpl(
ctx.compoundBody(),
Some(labelText),
allowVarDeclare = true,
labelCtx,
isScope = true
)
labelCtx.exitLabeledScope(None)
script
}
Expand All @@ -155,7 +160,8 @@ class AstBuilder extends DataTypeAstBuilder
ctx: CompoundBodyContext,
label: Option[String],
allowVarDeclare: Boolean,
labelCtx: SqlScriptingLabelContext): CompoundBody = {
labelCtx: SqlScriptingLabelContext,
isScope: Boolean): CompoundBody = {
val buff = ListBuffer[CompoundPlanStatement]()
ctx.compoundStatements.forEach(
compoundStatement => buff += visitCompoundStatementImpl(compoundStatement, labelCtx))
Expand Down Expand Up @@ -187,7 +193,7 @@ class AstBuilder extends DataTypeAstBuilder
case _ =>
}

CompoundBody(buff.toSeq, label)
CompoundBody(buff.toSeq, label, isScope)
}

private def visitBeginEndCompoundBlockImpl(
Expand All @@ -199,7 +205,8 @@ class AstBuilder extends DataTypeAstBuilder
ctx.compoundBody(),
Some(labelText),
allowVarDeclare = true,
labelCtx
labelCtx,
isScope = true
)
labelCtx.exitLabeledScope(Option(ctx.beginLabel()))
body
Expand Down Expand Up @@ -251,10 +258,12 @@ class AstBuilder extends DataTypeAstBuilder
OneRowRelation()))
}),
conditionalBodies = ctx.conditionalBodies.asScala.toList.map(
body => visitCompoundBodyImpl(body, None, allowVarDeclare = false, labelCtx)
body =>
visitCompoundBodyImpl(body, None, allowVarDeclare = false, labelCtx, isScope = false)
),
elseBody = Option(ctx.elseBody).map(
body => visitCompoundBodyImpl(body, None, allowVarDeclare = false, labelCtx)
body =>
visitCompoundBodyImpl(body, None, allowVarDeclare = false, labelCtx, isScope = false)
)
)
}
Expand All @@ -271,7 +280,13 @@ class AstBuilder extends DataTypeAstBuilder
Project(
Seq(Alias(expression(boolExpr), "condition")()),
OneRowRelation()))}
val body = visitCompoundBodyImpl(ctx.compoundBody(), None, allowVarDeclare = false, labelCtx)
val body = visitCompoundBodyImpl(
ctx.compoundBody(),
None,
allowVarDeclare = false,
labelCtx,
isScope = false
)
labelCtx.exitLabeledScope(Option(ctx.beginLabel()))

WhileStatement(condition, body, Some(labelText))
Expand All @@ -288,7 +303,8 @@ class AstBuilder extends DataTypeAstBuilder
})
val conditionalBodies =
ctx.conditionalBodies.asScala.toList.map(
body => visitCompoundBodyImpl(body, None, allowVarDeclare = false, labelCtx)
body =>
visitCompoundBodyImpl(body, None, allowVarDeclare = false, labelCtx, isScope = false)
)

if (conditions.length != conditionalBodies.length) {
Expand All @@ -301,7 +317,8 @@ class AstBuilder extends DataTypeAstBuilder
conditions = conditions,
conditionalBodies = conditionalBodies,
elseBody = Option(ctx.elseBody).map(
body => visitCompoundBodyImpl(body, None, allowVarDeclare = false, labelCtx)
body =>
visitCompoundBodyImpl(body, None, allowVarDeclare = false, labelCtx, isScope = false)
))
}

Expand All @@ -318,7 +335,8 @@ class AstBuilder extends DataTypeAstBuilder
})
val conditionalBodies =
ctx.conditionalBodies.asScala.toList.map(
body => visitCompoundBodyImpl(body, None, allowVarDeclare = false, labelCtx)
body =>
visitCompoundBodyImpl(body, None, allowVarDeclare = false, labelCtx, isScope = false)
)

if (conditions.length != conditionalBodies.length) {
Expand All @@ -331,7 +349,8 @@ class AstBuilder extends DataTypeAstBuilder
conditions = conditions,
conditionalBodies = conditionalBodies,
elseBody = Option(ctx.elseBody).map(
body => visitCompoundBodyImpl(body, None, allowVarDeclare = false, labelCtx)
body =>
visitCompoundBodyImpl(body, None, allowVarDeclare = false, labelCtx, isScope = false)
))
}

Expand All @@ -347,7 +366,13 @@ class AstBuilder extends DataTypeAstBuilder
Project(
Seq(Alias(expression(boolExpr), "condition")()),
OneRowRelation()))}
val body = visitCompoundBodyImpl(ctx.compoundBody(), None, allowVarDeclare = false, labelCtx)
val body = visitCompoundBodyImpl(
ctx.compoundBody(),
None,
allowVarDeclare = false,
labelCtx,
isScope = false
)
labelCtx.exitLabeledScope(Option(ctx.beginLabel()))

RepeatStatement(condition, body, Some(labelText))
Expand All @@ -363,7 +388,13 @@ class AstBuilder extends DataTypeAstBuilder
SingleStatement(visitQuery(queryCtx))
}
val varName = Option(ctx.multipartIdentifier()).map(_.getText)
val body = visitCompoundBodyImpl(ctx.compoundBody(), None, allowVarDeclare = false, labelCtx)
val body = visitCompoundBodyImpl(
ctx.compoundBody(),
None,
allowVarDeclare = false,
labelCtx,
isScope = false
)
labelCtx.exitLabeledScope(Option(ctx.beginLabel()))

ForStatement(query, varName, body, Some(labelText))
Expand Down Expand Up @@ -436,7 +467,13 @@ class AstBuilder extends DataTypeAstBuilder
labelCtx: SqlScriptingLabelContext): LoopStatement = {
val labelText =
labelCtx.enterLabeledScope(Option(ctx.beginLabel()), Option(ctx.endLabel()))
val body = visitCompoundBodyImpl(ctx.compoundBody(), None, allowVarDeclare = false, labelCtx)
val body = visitCompoundBodyImpl(
ctx.compoundBody(),
None,
allowVarDeclare = false,
labelCtx,
isScope = false
)
labelCtx.exitLabeledScope(Option(ctx.beginLabel()))

LoopStatement(body, Some(labelText))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -65,13 +65,14 @@ case class SingleStatement(parsedPlan: LogicalPlan)
*/
case class CompoundBody(
collection: Seq[CompoundPlanStatement],
label: Option[String]) extends Command with CompoundPlanStatement {
label: Option[String],
isScope: Boolean) extends Command with CompoundPlanStatement {

override def children: Seq[LogicalPlan] = collection

override protected def withNewChildrenInternal(
newChildren: IndexedSeq[LogicalPlan]): LogicalPlan = {
CompoundBody(newChildren.map(_.asInstanceOf[CompoundPlanStatement]), label)
CompoundBody(newChildren.map(_.asInstanceOf[CompoundPlanStatement]), label, isScope)
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,7 @@ class SingleStatementExec(
class CompoundBodyExec(
statements: Seq[CompoundStatementExec],
label: Option[String] = None,
isScope: Boolean,
context: SqlScriptingExecutionContext)
extends NonLeafStatementExec {

Expand All @@ -182,16 +183,18 @@ class CompoundBodyExec(
private var scopeEntered = false
private var scopeExited = false

def enterScope(): Unit = {
if (context != null && label.isDefined && !scopeEntered) {
protected def enterScope(): Unit = {
if (isScope && !scopeEntered) {
scopeEntered = true
scopeExited = false
context.enterScope(label.get)
}
}

def exitScope(): Unit = {
if (context != null && label.isDefined && !scopeExited) {
protected def exitScope(): Unit = {
if (isScope && !scopeExited) {
scopeExited = true
scopeEntered = false
context.exitScope(label.get)
}
}
Expand All @@ -209,11 +212,7 @@ class CompoundBodyExec(
case _ => throw SparkException.internalError(
"Unknown statement type encountered during SQL script interpretation.")
}
val result = !stopIteration && (localIterator.hasNext || childHasNext)
if (!result) {
exitScope()
}
result
!stopIteration && (localIterator.hasNext || childHasNext)
}

@scala.annotation.tailrec
Expand All @@ -235,6 +234,10 @@ class CompoundBodyExec(
statement
case Some(body: NonLeafStatementExec) =>
if (body.getTreeIterator.hasNext) {
body match {
case exec: CompoundBodyExec => exec.enterScope()
case _ => // pass
}
body.getTreeIterator.next() match {
case leaveStatement: LeaveStatementExec =>
handleLeaveStatement(leaveStatement)
Expand All @@ -245,6 +248,10 @@ class CompoundBodyExec(
case other => other
}
} else {
body match {
case exec: CompoundBodyExec => exec.exitScope()
case _ => // pass
}
curr = if (localIterator.hasNext) Some(localIterator.next()) else None
next()
}
Expand Down Expand Up @@ -871,6 +878,7 @@ class ForStatementExec(
dropVariablesExec = new CompoundBodyExec(
variablesMap.keys.toSeq.map(colName => createDropVarExec(colName)),
None,
isScope = false,
context
)
ForState.VariableCleanup
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ case class SqlScriptingInterpreter(session: SparkSession) {
args: Map[String, Expression],
context: SqlScriptingExecutionContext): CompoundStatementExec =
node match {
case CompoundBody(collection, label) =>
case CompoundBody(collection, label, isScope) =>
// TODO [SPARK-48530]: Current logic doesn't support scoped variables and shadowing.
val variables = collection.flatMap {
case st: SingleStatement => getDeclareVarNameFromPlan(st.parsedPlan)
Expand All @@ -91,6 +91,7 @@ case class SqlScriptingInterpreter(session: SparkSession) {
new CompoundBodyExec(
collection.map(st => transformTreeIntoExecutable(st, args, context)) ++ dropVariables,
label,
isScope,
context)

case IfElseStatement(conditions, conditionalBodies, elseBody) =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,15 @@ class SqlScriptingExecutionNodeSuite extends SparkFunSuite with SharedSparkSessi
case class TestCompoundBody(
statements: Seq[CompoundStatementExec],
label: Option[String] = None,
isScope: Boolean = false,
context: SqlScriptingExecutionContext = null)
extends CompoundBodyExec(statements, label, context)
extends CompoundBodyExec(statements, label, isScope, context) {

override def enterScope(): Unit = ()

override def exitScope(): Unit = ()

}

case class TestForStatement(
query: SingleStatementExec,
Expand Down

0 comments on commit 5f6d075

Please sign in to comment.