Skip to content

Commit

Permalink
add drop variables
Browse files Browse the repository at this point in the history
  • Loading branch information
dusantism-db committed Nov 13, 2024
1 parent 2200076 commit 0ed3e2b
Show file tree
Hide file tree
Showing 2 changed files with 45 additions and 27 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ import org.apache.spark.internal.Logging
import org.apache.spark.sql.{DataFrame, Dataset, Row, SparkSession}
import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedIdentifier}
import org.apache.spark.sql.catalyst.expressions.{Alias, CreateMap, CreateNamedStruct, Expression, Literal}
import org.apache.spark.sql.catalyst.plans.logical.{CreateVariable, DefaultValueExpression, LogicalPlan, OneRowRelation, Project, SetVariable}
import org.apache.spark.sql.catalyst.plans.logical.{CreateVariable, DefaultValueExpression, DropVariable, LogicalPlan, OneRowRelation, Project, SetVariable}
import org.apache.spark.sql.catalyst.trees.{Origin, WithOrigin}
import org.apache.spark.sql.errors.SqlScriptingErrors
import org.apache.spark.sql.types.BooleanType
Expand Down Expand Up @@ -723,7 +723,7 @@ class ForStatementExec(
}
variablesMap.keys.toSeq
.map(colName => createSetVarExec(colName, variablesMap(colName)))
.foreach(exec => exec.buildDataFrame(session).collect())
.foreach(setVarExec => setVarExec.buildDataFrame(session).collect())
state = ForState.Body
body.reset()
next()
Expand All @@ -738,7 +738,7 @@ class ForStatementExec(
leaveStatementExec.hasBeenMatched = true
}
interrupted = true
// drop vars
dropVars()
return retStmt
case iterStatementExec: IterateStatementExec if !iterStatementExec.hasBeenMatched =>
if (label.contains(iterStatementExec.label)) {
Expand All @@ -753,6 +753,11 @@ class ForStatementExec(
if (!body.getTreeIterator.hasNext) {
currRow += 1
state = ForState.VariableAssignment

// on final iteration, drop variables
if (currRow == cachedQueryResult().length) {
dropVars()
}
}
retStmt
}
Expand Down Expand Up @@ -795,6 +800,13 @@ class ForStatementExec(
variablesMap
}

private def dropVars() = {
variablesMap.keys.toSeq
.map(colName => createDropVarExec(colName))
.foreach(dropVarExec => dropVarExec.buildDataFrame(session).collect())
areVariablesDeclared = false
}

private def createDeclareVarExec(varName: String, variable: Expression): SingleStatementExec = {
val defaultExpression = DefaultValueExpression(Literal(null, variable.dataType), "null")
val declareVariable = CreateVariable(
Expand All @@ -815,13 +827,19 @@ class ForStatementExec(
new SingleStatementExec(setIdentifierToCurrentRow, Origin(), isInternal = true)
}

private def createDropVarExec(varName: String): SingleStatementExec = {
val dropVar = DropVariable(UnresolvedIdentifier(Seq(varName)), ifExists = true)
new SingleStatementExec(dropVar, Origin(), isInternal = true)
}

override def getTreeIterator: Iterator[CompoundStatementExec] = treeIterator

override def reset(): Unit = {
state = ForState.VariableAssignment
isResultCacheValid = false
currRow = 0
variablesMap = Map()
areVariablesDeclared = false
body.reset()
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -93,14 +93,14 @@ class SqlScriptingExecutionNodeSuite extends SparkFunSuite with SharedSparkSessi
statement: LeafStatementExec): Boolean = evaluator.evaluateLoopBooleanCondition()
}

case class TestForStatementQuery(numberOfRows: Int, description: String)
case class TestForStatementQuery(numberOfRows: Int, columnName: String, description: String)
extends SingleStatementExec(
DummyLogicalPlan(),
Origin(startIndex = Some(0), stopIndex = Some(description.length)),
isInternal = false) {
override def buildDataFrame(session: SparkSession): DataFrame = {
val data = Seq.range(0, numberOfRows).map(Row(_))
val schema = List(StructField("intCol", IntegerType))
val schema = List(StructField(columnName, IntegerType))

spark.createDataFrame(
spark.sparkContext.parallelize(data),
Expand Down Expand Up @@ -708,7 +708,7 @@ class SqlScriptingExecutionNodeSuite extends SparkFunSuite with SharedSparkSessi
test("for statement - enters body once") {
val iter = new CompoundBodyExec(Seq(
new ForStatementExec(
query = TestForStatementQuery(1, "query1"),
query = TestForStatementQuery(1, "intCol", "query1"),
variableName = Some("x"),
body = new CompoundBodyExec(Seq(TestLeafStatement("body"))),
label = Some("for1"),
Expand All @@ -724,7 +724,7 @@ class SqlScriptingExecutionNodeSuite extends SparkFunSuite with SharedSparkSessi
test("for statement - enters body with multiple statements multiple times") {
val iter = new CompoundBodyExec(Seq(
new ForStatementExec(
query = TestForStatementQuery(2, "query1"),
query = TestForStatementQuery(2, "intCol", "query1"),
variableName = Some("x"),
body = new CompoundBodyExec(Seq(
TestLeafStatement("statement1"),
Expand All @@ -745,7 +745,7 @@ class SqlScriptingExecutionNodeSuite extends SparkFunSuite with SharedSparkSessi
test("for statement - empty result") {
val iter = new CompoundBodyExec(Seq(
new ForStatementExec(
query = TestForStatementQuery(0, "query1"),
query = TestForStatementQuery(0, "intCol", "query1"),
variableName = Some("x"),
body = new CompoundBodyExec(Seq(TestLeafStatement("body1"))),
label = Some("for1"),
Expand All @@ -759,11 +759,11 @@ class SqlScriptingExecutionNodeSuite extends SparkFunSuite with SharedSparkSessi
test("for statement - nested") {
val iter = new CompoundBodyExec(Seq(
new ForStatementExec(
query = TestForStatementQuery(2, "query1"),
query = TestForStatementQuery(2, "intCol", "query1"),
variableName = Some("x"),
body = new CompoundBodyExec(Seq(
new ForStatementExec(
query = TestForStatementQuery(2, "query2"),
query = TestForStatementQuery(2, "intCol1", "query2"),
variableName = Some("y"),
body = new CompoundBodyExec(Seq(TestLeafStatement("body"))),
label = Some("for2"),
Expand All @@ -786,7 +786,7 @@ class SqlScriptingExecutionNodeSuite extends SparkFunSuite with SharedSparkSessi
test("for statement no variable - enters body once") {
val iter = new CompoundBodyExec(Seq(
new ForStatementExec(
query = TestForStatementQuery(1, "query1"),
query = TestForStatementQuery(1, "intCol", "query1"),
variableName = None,
body = new CompoundBodyExec(Seq(TestLeafStatement("body"))),
label = Some("for1"),
Expand All @@ -800,7 +800,7 @@ class SqlScriptingExecutionNodeSuite extends SparkFunSuite with SharedSparkSessi
test("for statement no variable - enters body with multiple statements multiple times") {
val iter = new CompoundBodyExec(Seq(
new ForStatementExec(
query = TestForStatementQuery(2, "query1"),
query = TestForStatementQuery(2, "intCol", "query1"),
variableName = None,
body = new CompoundBodyExec(Seq(
TestLeafStatement("statement1"),
Expand All @@ -816,7 +816,7 @@ class SqlScriptingExecutionNodeSuite extends SparkFunSuite with SharedSparkSessi
test("for statement no variable - empty result") {
val iter = new CompoundBodyExec(Seq(
new ForStatementExec(
query = TestForStatementQuery(0, "query1"),
query = TestForStatementQuery(0, "intCol", "query1"),
variableName = None,
body = new CompoundBodyExec(Seq(TestLeafStatement("body1"))),
label = Some("for1"),
Expand All @@ -830,11 +830,11 @@ class SqlScriptingExecutionNodeSuite extends SparkFunSuite with SharedSparkSessi
test("for statement no variable - nested") {
val iter = new CompoundBodyExec(Seq(
new ForStatementExec(
query = TestForStatementQuery(2, "query1"),
query = TestForStatementQuery(2, "intCol", "query1"),
variableName = None,
body = new CompoundBodyExec(Seq(
new ForStatementExec(
query = TestForStatementQuery(2, "query2"),
query = TestForStatementQuery(2, "intCol1", "query2"),
variableName = None,
body = new CompoundBodyExec(Seq(TestLeafStatement("body"))),
label = Some("for2"),
Expand All @@ -852,7 +852,7 @@ class SqlScriptingExecutionNodeSuite extends SparkFunSuite with SharedSparkSessi
test("for statement - iterate") {
val iter = new CompoundBodyExec(Seq(
new ForStatementExec(
query = TestForStatementQuery(2, "query1"),
query = TestForStatementQuery(2, "intCol", "query1"),
variableName = Some("x"),
body = new CompoundBodyExec(Seq(
TestLeafStatement("statement1"),
Expand All @@ -874,7 +874,7 @@ class SqlScriptingExecutionNodeSuite extends SparkFunSuite with SharedSparkSessi
test("for statement - leave") {
val iter = new CompoundBodyExec(Seq(
new ForStatementExec(
query = TestForStatementQuery(2, "query1"),
query = TestForStatementQuery(2, "intCol", "query1"),
variableName = Some("x"),
body = new CompoundBodyExec(Seq(
TestLeafStatement("statement1"),
Expand All @@ -894,11 +894,11 @@ class SqlScriptingExecutionNodeSuite extends SparkFunSuite with SharedSparkSessi
test("for statement - nested - iterate outer loop") {
val iter = new CompoundBodyExec(Seq(
new ForStatementExec(
query = TestForStatementQuery(2, "query1"),
query = TestForStatementQuery(2, "intCol", "query1"),
variableName = Some("x"),
body = new CompoundBodyExec(Seq(
new ForStatementExec(
query = TestForStatementQuery(2, "query2"),
query = TestForStatementQuery(2, "intCol1", "query2"),
variableName = Some("y"),
body = new CompoundBodyExec(Seq(
TestLeafStatement("body1"),
Expand All @@ -924,11 +924,11 @@ class SqlScriptingExecutionNodeSuite extends SparkFunSuite with SharedSparkSessi
test("for statement - nested - leave outer loop") {
val iter = new CompoundBodyExec(Seq(
new ForStatementExec(
query = TestForStatementQuery(2, "query1"),
query = TestForStatementQuery(2, "intCol", "query1"),
variableName = Some("x"),
body = new CompoundBodyExec(Seq(
new ForStatementExec(
query = TestForStatementQuery(2, "query2"),
query = TestForStatementQuery(2, "intCol", "query2"),
variableName = Some("y"),
body = new CompoundBodyExec(Seq(
TestLeafStatement("body1"),
Expand All @@ -952,7 +952,7 @@ class SqlScriptingExecutionNodeSuite extends SparkFunSuite with SharedSparkSessi
test("for statement no variable - iterate") {
val iter = new CompoundBodyExec(Seq(
new ForStatementExec(
query = TestForStatementQuery(2, "query1"),
query = TestForStatementQuery(2, "intCol", "query1"),
variableName = None,
body = new CompoundBodyExec(Seq(
TestLeafStatement("statement1"),
Expand All @@ -969,7 +969,7 @@ class SqlScriptingExecutionNodeSuite extends SparkFunSuite with SharedSparkSessi
test("for statement no variable - leave") {
val iter = new CompoundBodyExec(Seq(
new ForStatementExec(
query = TestForStatementQuery(2, "query1"),
query = TestForStatementQuery(2, "intCol", "query1"),
variableName = None,
body = new CompoundBodyExec(Seq(
TestLeafStatement("statement1"),
Expand All @@ -986,11 +986,11 @@ class SqlScriptingExecutionNodeSuite extends SparkFunSuite with SharedSparkSessi
test("for statement no variable - nested - iterate outer loop") {
val iter = new CompoundBodyExec(Seq(
new ForStatementExec(
query = TestForStatementQuery(2, "query1"),
query = TestForStatementQuery(2, "intCol", "query1"),
variableName = None,
body = new CompoundBodyExec(Seq(
new ForStatementExec(
query = TestForStatementQuery(2, "query2"),
query = TestForStatementQuery(2, "intCol1", "query2"),
variableName = None,
body = new CompoundBodyExec(Seq(
TestLeafStatement("body1"),
Expand All @@ -1011,11 +1011,11 @@ class SqlScriptingExecutionNodeSuite extends SparkFunSuite with SharedSparkSessi
test("for statement no variable - nested - leave outer loop") {
val iter = new CompoundBodyExec(Seq(
new ForStatementExec(
query = TestForStatementQuery(2, "query1"),
query = TestForStatementQuery(2, "intCol", "query1"),
variableName = None,
body = new CompoundBodyExec(Seq(
new ForStatementExec(
query = TestForStatementQuery(2, "query2"),
query = TestForStatementQuery(2, "intCol1", "query2"),
variableName = None,
body = new CompoundBodyExec(Seq(
TestLeafStatement("body1"),
Expand Down

0 comments on commit 0ed3e2b

Please sign in to comment.