Skip to content

Commit

Permalink
fix for nested arrays, and change drop variable logic to work with le…
Browse files Browse the repository at this point in the history
…ave/iterate/normal case
  • Loading branch information
dusantism-db committed Nov 14, 2024
1 parent 0ed3e2b commit 6358c11
Show file tree
Hide file tree
Showing 3 changed files with 281 additions and 148 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ import org.apache.spark.SparkException
import org.apache.spark.internal.Logging
import org.apache.spark.sql.{DataFrame, Dataset, Row, SparkSession}
import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedIdentifier}
import org.apache.spark.sql.catalyst.expressions.{Alias, CreateMap, CreateNamedStruct, Expression, Literal}
import org.apache.spark.sql.catalyst.expressions.{Alias, CreateArray, CreateMap, CreateNamedStruct, Expression, Literal}
import org.apache.spark.sql.catalyst.plans.logical.{CreateVariable, DefaultValueExpression, DropVariable, LogicalPlan, OneRowRelation, Project, SetVariable}
import org.apache.spark.sql.catalyst.trees.{Origin, WithOrigin}
import org.apache.spark.sql.errors.SqlScriptingErrors
Expand Down Expand Up @@ -667,8 +667,7 @@ class LoopStatementExec(
* Executable node for ForStatement.
* @param query Executable node for the query.
* @param variableName Name of variable used for accessing current row during iteration.
* @param body Executable node for the body. If variableName is not None, will have DropVariable
* as the last statement.
* @param body Executable node for the body.
* @param label Label set to ForStatement by user or None otherwise.
* @param session Spark session that SQL script is executed within.
*/
Expand All @@ -680,7 +679,7 @@ class ForStatementExec(
session: SparkSession) extends NonLeafStatementExec {

private object ForState extends Enumeration {
val VariableAssignment, Body = Value
val VariableAssignment, Body, VariableCleanup = Value
}
private var state = ForState.VariableAssignment
private var currRow = 0
Expand All @@ -690,6 +689,9 @@ class ForStatementExec(
// (variableName -> variableExpression)
private var variablesMap: Map[String, Expression] = Map()

// compound body used for dropping variables
private var dropVariablesExec: CompoundBodyExec = null

private var queryResult: Array[Row] = null
private var isResultCacheValid = false
private def cachedQueryResult(): Array[Row] = {
Expand All @@ -707,15 +709,20 @@ class ForStatementExec(

private lazy val treeIterator: Iterator[CompoundStatementExec] =
new Iterator[CompoundStatementExec] {
override def hasNext: Boolean =
!interrupted && cachedQueryResult().length > 0 && currRow < cachedQueryResult().length
override def hasNext: Boolean = {
  // The query result is fetched up front on every call, regardless of the current state.
  val rowCount = cachedQueryResult().length
  if (state == ForState.VariableCleanup) {
    // Pending cleanup statements must still be handed out, even if the loop was interrupted
    // or all rows have already been consumed.
    true
  } else {
    !interrupted && rowCount > 0 && currRow < rowCount
  }
}

override def next(): CompoundStatementExec = state match {

case ForState.VariableAssignment =>
variablesMap = createVariablesMapFromRow(currRow)

if (!areVariablesDeclared) {
// create and execute declare var statements
variablesMap.keys.toSeq
.map(colName => createDeclareVarExec(colName, variablesMap(colName)))
.foreach(declareVarExec => declareVarExec.buildDataFrame(session).collect())
Expand All @@ -738,49 +745,66 @@ class ForStatementExec(
leaveStatementExec.hasBeenMatched = true
}
interrupted = true
// If this FOR statement encounters LEAVE, it will either never be executed again
// or it will be reset before its next execution. In both cases the variables would
// not be dropped through the regular ForState.VariableCleanup path, so drop them here.
dropVars()
return retStmt
case iterStatementExec: IterateStatementExec if !iterStatementExec.hasBeenMatched =>
if (label.contains(iterStatementExec.label)) {
iterStatementExec.hasBeenMatched = true
} else {
// If an outer loop is being iterated, this FOR statement will either never be
// executed again or it will be reset before its next execution. In both cases the
// variables would not be dropped through the regular ForState.VariableCleanup path,
// so drop them here.
dropVars()
}
currRow += 1
state = ForState.VariableAssignment
switchStateFromBody()
return retStmt
case _ =>
}

if (!body.getTreeIterator.hasNext) {
currRow += 1
state = ForState.VariableAssignment

// on final iteration, drop variables
if (currRow == cachedQueryResult().length) {
dropVars()
}
switchStateFromBody()
}
retStmt

case ForState.VariableCleanup =>
val ret = dropVariablesExec.getTreeIterator.next()
if (!dropVariablesExec.getTreeIterator.hasNext) {
state = ForState.VariableAssignment
}
ret
}
}

/**
* Creates a Catalyst expression from Scala value.
* Creates a Catalyst expression from Scala value.<br>
* See https://spark.apache.org/docs/latest/sql-ref-datatypes.html for Spark -> Scala mappings
*/
private def createExpressionFromValue(value: Any): Expression = value match {
case m: Map[_, _] =>
// arguments of CreateMap are in the format: (key1, val1, key2, val2, ...)
val mapArgs = m.keys.toSeq.flatMap { key =>
Seq(createExpressionFromValue(key), createExpressionFromValue(m(key)))
}
CreateMap(mapArgs, false)
CreateMap(mapArgs, useStringTypeWhenEmpty = false)

// structs match this case
case s: Row =>
// struct types match this case
// arguments of CreateNamedStruct are in the format: (name1, val1, name2, val2, ...)
val namedStructArgs = s.schema.names.toSeq.flatMap { colName =>
val valueExpression = createExpressionFromValue(s.getAs(colName))
Seq(Literal(colName), valueExpression)
}
CreateNamedStruct(namedStructArgs)

// arrays match this case
case a: collection.Seq[_] =>
val arrayArgs = a.toSeq.map(createExpressionFromValue(_))
CreateArray(arrayArgs, useStringTypeWhenEmpty = false)
case _ => Literal(value)
}

Expand All @@ -800,13 +824,24 @@ class ForStatementExec(
variablesMap
}

private def dropVars() = {
private def dropVars(): Unit = {
  // Build one drop-variable statement per entry of variablesMap first, then execute them
  // eagerly by collecting each resulting DataFrame.
  val dropStatements = variablesMap.keys.toSeq.map(name => createDropVarExec(name))
  dropStatements.foreach(stmt => stmt.buildDataFrame(session).collect())
  // Mark the variables as no longer declared.
  areVariablesDeclared = false
}

/**
 * Moves on from the body to the next state.
 * Advances the current row index; if rows remain, the next state is VariableAssignment.
 * Otherwise a compound body with one drop-variable statement per entry of variablesMap is
 * prepared and the next state is VariableCleanup.
 */
private def switchStateFromBody(): Unit = {
  currRow += 1
  if (currRow < cachedQueryResult().length) {
    state = ForState.VariableAssignment
  } else {
    // No rows left: prepare the statements that drop the loop variables.
    val dropStatements = variablesMap.keys.toSeq.map(name => createDropVarExec(name))
    dropVariablesExec = new CompoundBodyExec(dropStatements)
    state = ForState.VariableCleanup
  }
}

private def createDeclareVarExec(varName: String, variable: Expression): SingleStatementExec = {
val defaultExpression = DefaultValueExpression(Literal(null, variable.dataType), "null")
val declareVariable = CreateVariable(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,7 @@ class SqlScriptingExecutionNodeSuite extends SparkFunSuite with SharedSparkSessi
case leaveStmt: LeaveStatementExec => leaveStmt.label
case iterateStmt: IterateStatementExec => iterateStmt.label
case forStmt: ForStatementExec => forStmt.label.get
case _: SingleStatementExec => "SingleStatementExec"
case _ => fail("Unexpected statement type")
}

Expand Down Expand Up @@ -717,7 +718,9 @@ class SqlScriptingExecutionNodeSuite extends SparkFunSuite with SharedSparkSessi
)).getTreeIterator
val statements = iter.map(extractStatementValue).toSeq
assert(statements === Seq(
"body"
"body",
"SingleStatementExec", // drop local var
"SingleStatementExec" // drop local var
))
}

Expand All @@ -738,7 +741,9 @@ class SqlScriptingExecutionNodeSuite extends SparkFunSuite with SharedSparkSessi
"statement1",
"statement2",
"statement1",
"statement2"
"statement2",
"SingleStatementExec", // drop local var
"SingleStatementExec", // drop local var
))
}

Expand Down Expand Up @@ -778,8 +783,14 @@ class SqlScriptingExecutionNodeSuite extends SparkFunSuite with SharedSparkSessi
assert(statements === Seq(
"body",
"body",
"SingleStatementExec", // drop inner local var
"SingleStatementExec", // drop inner local var
"body",
"body",
"body"
"SingleStatementExec", // drop inner local var
"SingleStatementExec", // drop inner local var
"SingleStatementExec", // drop outer local var
"SingleStatementExec", // drop outer local var
))
}

Expand All @@ -794,7 +805,10 @@ class SqlScriptingExecutionNodeSuite extends SparkFunSuite with SharedSparkSessi
)
)).getTreeIterator
val statements = iter.map(extractStatementValue).toSeq
assert(statements === Seq("body"))
assert(statements === Seq(
"body",
"SingleStatementExec", // drop local var
))
}

test("for statement no variable - enters body with multiple statements multiple times") {
Expand All @@ -810,7 +824,10 @@ class SqlScriptingExecutionNodeSuite extends SparkFunSuite with SharedSparkSessi
)
)).getTreeIterator
val statements = iter.map(extractStatementValue).toSeq
assert(statements === Seq("statement1", "statement2", "statement1", "statement2"))
assert(statements === Seq(
"statement1", "statement2", "statement1", "statement2",
"SingleStatementExec", // drop local var
))
}

test("for statement no variable - empty result") {
Expand Down Expand Up @@ -846,7 +863,13 @@ class SqlScriptingExecutionNodeSuite extends SparkFunSuite with SharedSparkSessi
)
)).getTreeIterator
val statements = iter.map(extractStatementValue).toSeq
assert(statements === Seq("body", "body", "body", "body"))
assert(statements === Seq(
"body", "body",
"SingleStatementExec", // drop inner local var
"body", "body",
"SingleStatementExec", // drop inner local var
"SingleStatementExec", // drop outer local var
))
}

test("for statement - iterate") {
Expand All @@ -867,7 +890,9 @@ class SqlScriptingExecutionNodeSuite extends SparkFunSuite with SharedSparkSessi
"statement1",
"lbl1",
"statement1",
"lbl1"
"lbl1",
"SingleStatementExec", // drop local var
"SingleStatementExec", // drop local var
))
}

Expand Down Expand Up @@ -897,6 +922,7 @@ class SqlScriptingExecutionNodeSuite extends SparkFunSuite with SharedSparkSessi
query = TestForStatementQuery(2, "intCol", "query1"),
variableName = Some("x"),
body = new CompoundBodyExec(Seq(
TestLeafStatement("outer_body"),
new ForStatementExec(
query = TestForStatementQuery(2, "intCol1", "query2"),
variableName = Some("y"),
Expand All @@ -914,10 +940,14 @@ class SqlScriptingExecutionNodeSuite extends SparkFunSuite with SharedSparkSessi
)).getTreeIterator
val statements = iter.map(extractStatementValue).toSeq
assert(statements === Seq(
"outer_body",
"body1",
"lbl1",
"outer_body",
"body1",
"lbl1"
"lbl1",
"SingleStatementExec", // drop local var
"SingleStatementExec", // drop local var
))
}

Expand Down Expand Up @@ -963,7 +993,10 @@ class SqlScriptingExecutionNodeSuite extends SparkFunSuite with SharedSparkSessi
)
)).getTreeIterator
val statements = iter.map(extractStatementValue).toSeq
assert(statements === Seq("statement1", "lbl1", "statement1", "lbl1"))
assert(statements === Seq(
"statement1", "lbl1", "statement1", "lbl1",
"SingleStatementExec", // drop local var
))
}

test("for statement no variable - leave") {
Expand All @@ -989,6 +1022,7 @@ class SqlScriptingExecutionNodeSuite extends SparkFunSuite with SharedSparkSessi
query = TestForStatementQuery(2, "intCol", "query1"),
variableName = None,
body = new CompoundBodyExec(Seq(
TestLeafStatement("outer_body"),
new ForStatementExec(
query = TestForStatementQuery(2, "intCol1", "query2"),
variableName = None,
Expand All @@ -1005,7 +1039,10 @@ class SqlScriptingExecutionNodeSuite extends SparkFunSuite with SharedSparkSessi
)
)).getTreeIterator
val statements = iter.map(extractStatementValue).toSeq
assert(statements === Seq("body1", "lbl1", "body1", "lbl1"))
assert(statements === Seq(
"outer_body", "body1", "lbl1", "outer_body", "body1", "lbl1",
"SingleStatementExec", // drop local var
))
}

test("for statement no variable - nested - leave outer loop") {
Expand Down
Loading

0 comments on commit 6358c11

Please sign in to comment.