diff --git a/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4 b/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4
index 2e3235d6f932c..4b7b4634b74b2 100644
--- a/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4
+++ b/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4
@@ -70,6 +70,7 @@ compoundStatement
| leaveStatement
| iterateStatement
| loopStatement
+ | forStatement
;
setStatementWithOptionalVarKeyword
@@ -111,6 +112,10 @@ loopStatement
: beginLabel? LOOP compoundBody END LOOP endLabel?
;
+forStatement // FOR loop: optional begin label, optional "<identifier> AS" prefix, a query, then DO <body> END FOR and an optional end label
+ : beginLabel? FOR (multipartIdentifier AS)? query DO compoundBody END FOR endLabel?
+ ;
+
singleStatement
: (statement|setResetStatement) SEMICOLON* EOF
;
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
index 6a9a97d0f5c8c..d558689a5c196 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
@@ -226,6 +226,8 @@ class AstBuilder extends DataTypeAstBuilder
visitSearchedCaseStatementImpl(searchedCaseContext, labelCtx)
case simpleCaseContext: SimpleCaseStatementContext =>
visitSimpleCaseStatementImpl(simpleCaseContext, labelCtx)
+ case forStatementContext: ForStatementContext =>
+ visitForStatementImpl(forStatementContext, labelCtx)
case stmt => visit(stmt).asInstanceOf[CompoundPlanStatement]
}
} else {
@@ -347,28 +349,48 @@ class AstBuilder extends DataTypeAstBuilder
RepeatStatement(condition, body, Some(labelText))
}
+  /**
+   * Builds a [[ForStatement]] from a parsed FOR statement.
+   * The labeled scope is entered before the query and the body are visited and is exited
+   * once both have been built.
+   */
+  private def visitForStatementImpl(
+      ctx: ForStatementContext,
+      labelCtx: SqlScriptingLabelContext): ForStatement = {
+    val beginLabelOpt = Option(ctx.beginLabel())
+    val scopeLabel = labelCtx.enterLabeledScope(beginLabelOpt, Option(ctx.endLabel()))
+
+    // The query keeps its own origin so its text can be recovered later.
+    val parsedQuery = withOrigin(ctx.query()) {
+      SingleStatement(visitQuery(ctx.query()))
+    }
+    // The "<identifier> AS" prefix is optional in the grammar.
+    val loopVariable = Option(ctx.multipartIdentifier()).map(identifier => identifier.getText)
+    val loopBody =
+      visitCompoundBodyImpl(ctx.compoundBody(), None, allowVarDeclare = false, labelCtx)
+
+    labelCtx.exitLabeledScope(beginLabelOpt)
+    ForStatement(parsedQuery, loopVariable, loopBody, Some(scopeLabel))
+  }
+
private def leaveOrIterateContextHasLabel(
ctx: RuleContext, label: String, isIterate: Boolean): Boolean = {
ctx match {
case c: BeginEndCompoundBlockContext
- if Option(c.beginLabel()).isDefined &&
- c.beginLabel().multipartIdentifier().getText.toLowerCase(Locale.ROOT).equals(label) =>
- if (isIterate) {
+ if Option(c.beginLabel()).exists { b => // a missing begin label never matches; otherwise its lower-cased text is compared with `label`
+ b.multipartIdentifier().getText.toLowerCase(Locale.ROOT).equals(label)
+ } => if (isIterate) { // ITERATE targeting a BEGIN ... END block is reported as an error
throw SqlScriptingErrors.invalidIterateLabelUsageForCompound(CurrentOrigin.get, label)
}
true
case c: WhileStatementContext
- if Option(c.beginLabel()).isDefined &&
- c.beginLabel().multipartIdentifier().getText.toLowerCase(Locale.ROOT).equals(label)
- => true
+ if Option(c.beginLabel()).exists { b =>
+ b.multipartIdentifier().getText.toLowerCase(Locale.ROOT).equals(label)
+ } => true
case c: RepeatStatementContext
- if Option(c.beginLabel()).isDefined &&
- c.beginLabel().multipartIdentifier().getText.toLowerCase(Locale.ROOT).equals(label)
- => true
+ if Option(c.beginLabel()).exists { b =>
+ b.multipartIdentifier().getText.toLowerCase(Locale.ROOT).equals(label)
+ } => true
case c: LoopStatementContext
- if Option(c.beginLabel()).isDefined &&
- c.beginLabel().multipartIdentifier().getText.toLowerCase(Locale.ROOT).equals(label)
- => true
+ if Option(c.beginLabel()).exists { b =>
+ b.multipartIdentifier().getText.toLowerCase(Locale.ROOT).equals(label)
+ } => true
+ case c: ForStatementContext // a labeled FOR loop matches for both LEAVE and ITERATE
+ if Option(c.beginLabel()).exists { b =>
+ b.multipartIdentifier().getText.toLowerCase(Locale.ROOT).equals(label)
+ } => true
case _ => false
}
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/SqlScriptingLogicalPlans.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/SqlScriptingLogicalPlans.scala
index e6018e5e57b9c..4faf1f5d26672 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/SqlScriptingLogicalPlans.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/SqlScriptingLogicalPlans.scala
@@ -267,3 +267,31 @@ case class LoopStatement(
LoopStatement(newChildren(0).asInstanceOf[CompoundBody], label)
}
}
+
+/**
+ * Logical operator for FOR statement.
+ * @param query Query that is executed once; its result set is then iterated on, row by row.
+ * @param variableName Optional name of the variable used to access the current row during
+ *                     iteration.
+ * @param body Compound body, i.e. the collection of statements executed for each row in
+ *             the result set of the query.
+ * @param label An optional label for the loop which is unique amongst all labels for statements
+ *              within which the FOR statement is contained.
+ *              If an end label is specified it must match the beginning label.
+ *              The label can be used to LEAVE or ITERATE the loop.
+ */
+case class ForStatement(
+    query: SingleStatement,
+    variableName: Option[String],
+    body: CompoundBody,
+    label: Option[String]) extends CompoundPlanStatement {
+
+  override def output: Seq[Attribute] = Nil
+
+  // The query comes first and the body second; withNewChildrenInternal relies on this order.
+  override def children: Seq[LogicalPlan] = List(query, body)
+
+  override protected def withNewChildrenInternal(
+      newChildren: IndexedSeq[LogicalPlan]): LogicalPlan = {
+    newChildren match {
+      case IndexedSeq(newQuery: SingleStatement, newBody: CompoundBody) =>
+        ForStatement(newQuery, variableName, newBody, label)
+    }
+  }
+}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/SqlScriptingParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/SqlScriptingParserSuite.scala
index 3bb84f603dc67..ab647f83b42a4 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/SqlScriptingParserSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/SqlScriptingParserSuite.scala
@@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst.parser
import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.expressions.{Alias, EqualTo, Expression, In, Literal, ScalarSubquery}
import org.apache.spark.sql.catalyst.plans.SQLHelper
-import org.apache.spark.sql.catalyst.plans.logical.{CaseStatement, CompoundBody, CreateVariable, IfElseStatement, IterateStatement, LeaveStatement, LoopStatement, Project, RepeatStatement, SingleStatement, WhileStatement}
+import org.apache.spark.sql.catalyst.plans.logical.{CaseStatement, CompoundBody, CreateVariable, ForStatement, IfElseStatement, IterateStatement, LeaveStatement, LoopStatement, Project, RepeatStatement, SingleStatement, WhileStatement}
import org.apache.spark.sql.errors.DataTypeErrors.toSQLId
import org.apache.spark.sql.exceptions.SqlScriptingException
import org.apache.spark.sql.internal.SQLConf
@@ -1176,7 +1176,6 @@ class SqlScriptingParserSuite extends SparkFunSuite with SQLHelper {
head.asInstanceOf[SingleStatement].getText == "SELECT 42")
assert(whileStmt.label.contains("lbl"))
-
}
test("searched case statement") {
@@ -1823,6 +1822,25 @@ class SqlScriptingParserSuite extends SparkFunSuite with SQLHelper {
parameters = Map("label" -> toSQLId("l_loop")))
}
+  test("unique label names: nested for loops") {
+    // The inner FOR reuses the label of the enclosing FOR, which has to be rejected.
+    val script =
+      """BEGIN
+        |f_loop: FOR x AS SELECT 1 DO
+        | f_loop: FOR y AS SELECT 2 DO
+        | SELECT 1;
+        | END FOR;
+        |END FOR;
+        |END
+ """.stripMargin
+    val duplicateLabelError = intercept[SqlScriptingException] {
+      parsePlan(script).asInstanceOf[CompoundBody]
+    }
+    checkError(
+      exception = duplicateLabelError,
+      condition = "LABEL_ALREADY_EXISTS",
+      parameters = Map("label" -> toSQLId("f_loop")))
+  }
+
test("unique label names: begin-end block on the same level") {
val sqlScriptText =
"""BEGIN
@@ -1858,10 +1876,13 @@ class SqlScriptingParserSuite extends SparkFunSuite with SQLHelper {
| SELECT 4;
|UNTIL 1=1
|END REPEAT;
+ |lbl: FOR x AS SELECT 1 DO
+ | SELECT 5;
+ |END FOR;
|END
""".stripMargin
val tree = parsePlan(sqlScriptText).asInstanceOf[CompoundBody]
- assert(tree.collection.length == 4)
+ assert(tree.collection.length == 5)
assert(tree.collection.head.isInstanceOf[CompoundBody])
assert(tree.collection.head.asInstanceOf[CompoundBody].label.get == "lbl")
assert(tree.collection(1).isInstanceOf[WhileStatement])
@@ -1870,6 +1891,8 @@ class SqlScriptingParserSuite extends SparkFunSuite with SQLHelper {
assert(tree.collection(2).asInstanceOf[LoopStatement].label.get == "lbl")
assert(tree.collection(3).isInstanceOf[RepeatStatement])
assert(tree.collection(3).asInstanceOf[RepeatStatement].label.get == "lbl")
+ assert(tree.collection(4).isInstanceOf[ForStatement])
+ assert(tree.collection(4).asInstanceOf[ForStatement].label.get == "lbl")
}
test("unique label names: nested labeled scope statements") {
@@ -1879,7 +1902,9 @@ class SqlScriptingParserSuite extends SparkFunSuite with SQLHelper {
| lbl_1: WHILE 1=1 DO
| lbl_2: LOOP
| lbl_3: REPEAT
- | SELECT 4;
+ | lbl_4: FOR x AS SELECT 1 DO
+ | SELECT 4;
+ | END FOR;
| UNTIL 1=1
| END REPEAT;
| END LOOP;
@@ -1905,6 +1930,241 @@ class SqlScriptingParserSuite extends SparkFunSuite with SQLHelper {
// Repeat statement
val repeatStatement = loopStatement.body.collection.head.asInstanceOf[RepeatStatement]
assert(repeatStatement.label.get == "lbl_3")
+ // For statement
+ val forStatement = repeatStatement.body.collection.head.asInstanceOf[ForStatement]
+ assert(forStatement.label.get == "lbl_4")
+ }
+
+  test("for statement") {
+    // Labeled FOR loop with a row variable and a single-statement body.
+    val script =
+      """
+        |BEGIN
+        | lbl: FOR x AS SELECT 5 DO
+        | SELECT 1;
+        | END FOR;
+        |END""".stripMargin
+
+    // query and body are statically typed, so only their contents are checked
+    parsePlan(script).asInstanceOf[CompoundBody].collection match {
+      case Seq(loop: ForStatement) =>
+        assert(loop.query.getText == "SELECT 5")
+        assert(loop.variableName.contains("x"))
+        assert(loop.label.contains("lbl"))
+        loop.body.collection match {
+          case Seq(only: SingleStatement) => assert(only.getText == "SELECT 1")
+          case other => fail(s"Expected one SingleStatement in the loop body, got: $other")
+        }
+      case other => fail(s"Expected a single ForStatement, got: $other")
+    }
+  }
+
+  test("for statement - no label") {
+    // Same loop as the labeled case, but without a user-provided label.
+    val script =
+      """
+        |BEGIN
+        | FOR x AS SELECT 5 DO
+        | SELECT 1;
+        | END FOR;
+        |END""".stripMargin
+
+    parsePlan(script).asInstanceOf[CompoundBody].collection match {
+      case Seq(loop: ForStatement) =>
+        assert(loop.query.getText == "SELECT 5")
+        assert(loop.variableName.contains("x"))
+        // when not explicitly set, label is random UUID
+        assert(loop.label.isDefined)
+        loop.body.collection match {
+          case Seq(only: SingleStatement) => assert(only.getText == "SELECT 1")
+          case other => fail(s"Expected one SingleStatement in the loop body, got: $other")
+        }
+      case other => fail(s"Expected a single ForStatement, got: $other")
+    }
+  }
+
+  test("for statement - with complex subquery") {
+    // The whole query text, including WHERE / GROUP BY / ORDER BY, must end up in the query node.
+    val expectedQuery = "SELECT c1, c2 FROM t WHERE c2 = 5 GROUP BY c1 ORDER BY c1"
+    val script =
+      """
+        |BEGIN
+        | lbl: FOR x AS SELECT c1, c2 FROM t WHERE c2 = 5 GROUP BY c1 ORDER BY c1 DO
+        | SELECT x.c1;
+        | SELECT x.c2;
+        | END FOR;
+        |END""".stripMargin
+
+    parsePlan(script).asInstanceOf[CompoundBody].collection match {
+      case Seq(loop: ForStatement) =>
+        assert(loop.query.getText == expectedQuery)
+        assert(loop.variableName.contains("x"))
+        assert(loop.label.contains("lbl"))
+        loop.body.collection match {
+          case Seq(first: SingleStatement, second: SingleStatement) =>
+            assert(first.getText == "SELECT x.c1")
+            assert(second.getText == "SELECT x.c2")
+          case other => fail(s"Expected two SingleStatements in the loop body, got: $other")
+        }
+      case other => fail(s"Expected a single ForStatement, got: $other")
+    }
+  }
+
+  test("for statement - nested") {
+    // Two labeled FOR loops, the second one being the only statement of the first one's body.
+    val script =
+      """
+        |BEGIN
+        | lbl1: FOR i AS SELECT 1 DO
+        | lbl2: FOR j AS SELECT 2 DO
+        | SELECT i + j;
+        | END FOR lbl2;
+        | END FOR lbl1;
+        |END""".stripMargin
+
+    parsePlan(script).asInstanceOf[CompoundBody].collection match {
+      case Seq(outer: ForStatement) =>
+        assert(outer.query.getText == "SELECT 1")
+        assert(outer.variableName.contains("i"))
+        assert(outer.label.contains("lbl1"))
+        outer.body.collection match {
+          case Seq(inner: ForStatement) =>
+            assert(inner.query.getText == "SELECT 2")
+            assert(inner.variableName.contains("j"))
+            assert(inner.label.contains("lbl2"))
+            inner.body.collection match {
+              case Seq(stmt: SingleStatement) => assert(stmt.getText == "SELECT i + j")
+              case other => fail(s"Expected one inner SingleStatement, got: $other")
+            }
+          case other => fail(s"Expected a nested ForStatement, got: $other")
+        }
+      case other => fail(s"Expected a single ForStatement, got: $other")
+    }
+  }
+
+  test("for statement - no variable") {
+    // Without the "<identifier> AS" prefix no row variable name is recorded.
+    val script =
+      """
+        |BEGIN
+        | lbl: FOR SELECT 5 DO
+        | SELECT 1;
+        | END FOR;
+        |END""".stripMargin
+
+    parsePlan(script).asInstanceOf[CompoundBody].collection match {
+      case Seq(loop: ForStatement) =>
+        assert(loop.query.getText == "SELECT 5")
+        assert(loop.variableName.isEmpty)
+        assert(loop.label.contains("lbl"))
+        loop.body.collection match {
+          case Seq(only: SingleStatement) => assert(only.getText == "SELECT 1")
+          case other => fail(s"Expected one SingleStatement in the loop body, got: $other")
+        }
+      case other => fail(s"Expected a single ForStatement, got: $other")
+    }
+  }
+
+  test("for statement - no variable - no label") {
+    // Neither a row variable nor a user-provided label.
+    val script =
+      """
+        |BEGIN
+        | FOR SELECT 5 DO
+        | SELECT 1;
+        | END FOR;
+        |END""".stripMargin
+
+    parsePlan(script).asInstanceOf[CompoundBody].collection match {
+      case Seq(loop: ForStatement) =>
+        assert(loop.query.getText == "SELECT 5")
+        assert(loop.variableName.isEmpty)
+        // when not explicitly set, label is random UUID
+        assert(loop.label.isDefined)
+        loop.body.collection match {
+          case Seq(only: SingleStatement) => assert(only.getText == "SELECT 1")
+          case other => fail(s"Expected one SingleStatement in the loop body, got: $other")
+        }
+      case other => fail(s"Expected a single ForStatement, got: $other")
+    }
+  }
+
+  test("for statement - no variable - with complex subquery") {
+    // The whole query text must be captured even when no row variable precedes it.
+    val expectedQuery = "SELECT c1, c2 FROM t WHERE c2 = 5 GROUP BY c1 ORDER BY c1"
+    val script =
+      """
+        |BEGIN
+        | lbl: FOR SELECT c1, c2 FROM t WHERE c2 = 5 GROUP BY c1 ORDER BY c1 DO
+        | SELECT 1;
+        | SELECT 2;
+        | END FOR;
+        |END""".stripMargin
+
+    parsePlan(script).asInstanceOf[CompoundBody].collection match {
+      case Seq(loop: ForStatement) =>
+        assert(loop.query.getText == expectedQuery)
+        assert(loop.variableName.isEmpty)
+        assert(loop.label.contains("lbl"))
+        loop.body.collection match {
+          case Seq(first: SingleStatement, second: SingleStatement) =>
+            assert(first.getText == "SELECT 1")
+            assert(second.getText == "SELECT 2")
+          case other => fail(s"Expected two SingleStatements in the loop body, got: $other")
+        }
+      case other => fail(s"Expected a single ForStatement, got: $other")
+    }
+  }
+
+  test("for statement - no variable - nested") {
+    // Two labeled FOR loops without row variables, one nested inside the other.
+    val script =
+      """
+        |BEGIN
+        | lbl1: FOR SELECT 1 DO
+        | lbl2: FOR SELECT 2 DO
+        | SELECT 3;
+        | END FOR lbl2;
+        | END FOR lbl1;
+        |END""".stripMargin
+
+    parsePlan(script).asInstanceOf[CompoundBody].collection match {
+      case Seq(outer: ForStatement) =>
+        assert(outer.query.getText == "SELECT 1")
+        assert(outer.variableName.isEmpty)
+        assert(outer.label.contains("lbl1"))
+        outer.body.collection match {
+          case Seq(inner: ForStatement) =>
+            assert(inner.query.getText == "SELECT 2")
+            assert(inner.variableName.isEmpty)
+            assert(inner.label.contains("lbl2"))
+            inner.body.collection match {
+              case Seq(stmt: SingleStatement) => assert(stmt.getText == "SELECT 3")
+              case other => fail(s"Expected one inner SingleStatement, got: $other")
+            }
+          case other => fail(s"Expected a nested ForStatement, got: $other")
+        }
+      case other => fail(s"Expected a single ForStatement, got: $other")
+    }
}
// Helper methods
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala
index 94284ec514f55..e3559e8f18ae2 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala
@@ -17,12 +17,14 @@
package org.apache.spark.sql.scripting
+import java.util
+
import org.apache.spark.SparkException
import org.apache.spark.internal.Logging
-import org.apache.spark.sql.{DataFrame, Dataset, SparkSession}
-import org.apache.spark.sql.catalyst.analysis.NameParameterizedQuery
-import org.apache.spark.sql.catalyst.expressions.Expression
-import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
+import org.apache.spark.sql.{DataFrame, Dataset, Row, SparkSession}
+import org.apache.spark.sql.catalyst.analysis.{NameParameterizedQuery, UnresolvedAttribute, UnresolvedIdentifier}
+import org.apache.spark.sql.catalyst.expressions.{Alias, CreateArray, CreateMap, CreateNamedStruct, Expression, Literal}
+import org.apache.spark.sql.catalyst.plans.logical.{CreateVariable, DefaultValueExpression, DropVariable, LogicalPlan, OneRowRelation, Project, SetVariable}
import org.apache.spark.sql.catalyst.trees.{Origin, WithOrigin}
import org.apache.spark.sql.errors.SqlScriptingErrors
import org.apache.spark.sql.types.BooleanType
@@ -662,3 +664,222 @@ class LoopStatementExec(
body.reset()
}
}
+
+/**
+ * Executable node for ForStatement.
+ *
+ * Execution is driven by a small state machine:
+ *  - VariableAssignment: fetch the next row of the query result, declare the loop variables
+ *    if they are not declared yet and assign the values of the row to them.
+ *  - Body: hand out the statements of the body one by one, reacting to LEAVE and ITERATE.
+ *  - VariableCleanup: once the result set is exhausted, hand out the DROP VARIABLE statements
+ *    for every variable created by this node.
+ *
+ * @param query Executable node for the query.
+ * @param variableName Name of variable used for accessing current row during iteration.
+ * @param body Executable node for the body.
+ * @param label Label of the FOR statement, if any.
+ * @param session Spark session that SQL script is executed within.
+ */
+class ForStatementExec(
+    query: SingleStatementExec,
+    variableName: Option[String],
+    body: CompoundBodyExec,
+    val label: Option[String],
+    session: SparkSession) extends NonLeafStatementExec {
+
+  private object ForState extends Enumeration {
+    val VariableAssignment, Body, VariableCleanup = Value
+  }
+  private var state = ForState.VariableAssignment
+  // true while the loop variables are declared (set on declaration, cleared by dropVars/reset)
+  private var areVariablesDeclared = false
+
+  // map of all variables created internally by the for statement
+  // (variableName -> variableExpression)
+  private var variablesMap: Map[String, Expression] = Map()
+
+  // compound body used for dropping variables while in ForState.VariableCleanup
+  private var dropVariablesExec: CompoundBodyExec = null
+
+  private var queryResult: util.Iterator[Row] = _
+  private var isResultCacheValid = false
+
+  /**
+   * Returns the iterator over the query result. The query is run on first access (and again
+   * after reset()); later calls return the same, partially consumed, iterator.
+   */
+  private def cachedQueryResult(): util.Iterator[Row] = {
+    if (!isResultCacheValid) {
+      queryResult = query.buildDataFrame(session).toLocalIterator()
+      query.isExecuted = true
+      isResultCacheValid = true
+    }
+    queryResult
+  }
+
+  /**
+   * For can be interrupted by LeaveStatementExec
+   */
+  private var interrupted: Boolean = false
+
+  private lazy val treeIterator: Iterator[CompoundStatementExec] =
+    new Iterator[CompoundStatementExec] {
+
+      override def hasNext: Boolean = !interrupted && (state match {
+        case ForState.VariableAssignment => cachedQueryResult().hasNext
+        case ForState.Body => true
+        case ForState.VariableCleanup => dropVariablesExec.getTreeIterator.hasNext
+      })
+
+      override def next(): CompoundStatementExec = state match {
+
+        case ForState.VariableAssignment =>
+          variablesMap = createVariablesMapFromRow(cachedQueryResult().next())
+
+          if (!areVariablesDeclared) {
+            // create and execute declare var statements
+            variablesMap.keys.toSeq
+              .map(colName => createDeclareVarExec(colName, variablesMap(colName)))
+              .foreach(declareVarExec => declareVarExec.buildDataFrame(session).collect())
+            areVariablesDeclared = true
+          }
+
+          // create and execute set var statements
+          variablesMap.keys.toSeq
+            .map(colName => createSetVarExec(colName, variablesMap(colName)))
+            .foreach(setVarExec => setVarExec.buildDataFrame(session).collect())
+
+          // variable handling is internal: continue straight into the first body statement
+          state = ForState.Body
+          body.reset()
+          next()
+
+        case ForState.Body =>
+          val retStmt = body.getTreeIterator.next()
+
+          // Handle LEAVE or ITERATE statement if it has been encountered.
+          retStmt match {
+            case leaveStatementExec: LeaveStatementExec if !leaveStatementExec.hasBeenMatched =>
+              if (label.contains(leaveStatementExec.label)) {
+                leaveStatementExec.hasBeenMatched = true
+              }
+              interrupted = true
+              // If this for statement encounters LEAVE, it will either not be executed
+              // again, or it will be reset before being executed.
+              // In either case, variables will not
+              // be dropped normally, from ForState.VariableCleanup, so we drop them here.
+              dropVars()
+              return retStmt
+            case iterStatementExec: IterateStatementExec if !iterStatementExec.hasBeenMatched =>
+              if (label.contains(iterStatementExec.label)) {
+                iterStatementExec.hasBeenMatched = true
+              } else {
+                // if an outer loop is being iterated, this for statement will either not be
+                // executed again, or it will be reset before being executed.
+                // In either case, variables will not
+                // be dropped normally, from ForState.VariableCleanup, so we drop them here.
+                dropVars()
+              }
+              switchStateFromBody()
+              return retStmt
+            case _ =>
+          }
+
+          if (!body.getTreeIterator.hasNext) {
+            switchStateFromBody()
+          }
+          retStmt
+
+        case ForState.VariableCleanup =>
+          dropVariablesExec.getTreeIterator.next()
+      }
+    }
+
+  /**
+   * Recursively creates a Catalyst expression from Scala value.
+   * See https://spark.apache.org/docs/latest/sql-ref-datatypes.html for Spark -> Scala mappings
+   */
+  private def createExpressionFromValue(value: Any): Expression = value match {
+    case m: Map[_, _] =>
+      // arguments of CreateMap are in the format: (key1, val1, key2, val2, ...)
+      val mapArgs = m.keys.toSeq.flatMap { key =>
+        Seq(createExpressionFromValue(key), createExpressionFromValue(m(key)))
+      }
+      CreateMap(mapArgs, useStringTypeWhenEmpty = false)
+
+    // structs and rows match this case
+    case s: Row =>
+      // arguments of CreateNamedStruct are in the format: (name1, val1, name2, val2, ...)
+      val namedStructArgs = s.schema.names.toSeq.flatMap { colName =>
+        val valueExpression = createExpressionFromValue(s.getAs(colName))
+        Seq(Literal(colName), valueExpression)
+      }
+      CreateNamedStruct(namedStructArgs)
+
+    // arrays match this case
+    case a: collection.Seq[_] =>
+      val arrayArgs = a.toSeq.map(createExpressionFromValue(_))
+      CreateArray(arrayArgs, useStringTypeWhenEmpty = false)
+
+    case _ => Literal(value)
+  }
+
+  /**
+   * Builds the variables for one row of the query result: one variable per column and, when
+   * variableName is set, one additional struct variable holding all the columns.
+   *
+   * @param row Current row of the query result.
+   * @return Map of variable name to the expression producing its value.
+   */
+  private def createVariablesMapFromRow(row: Row): Map[String, Expression] = {
+    // Keep the columns as an ordered sequence. A plain Map gives no iteration-order guarantee,
+    // so building the struct from Map keys could reorder its fields relative to the query.
+    // `distinct` keeps one entry per column name, exactly as the Map does.
+    val columnVariables: Seq[(String, Expression)] = row.schema.names.toSeq.distinct.map {
+      colName => colName -> createExpressionFromValue(row.getAs(colName))
+    }
+    val variablesByName = columnVariables.toMap
+
+    variableName match {
+      case Some(forVariableName) =>
+        // arguments of CreateNamedStruct are in the format: (name1, val1, name2, val2, ...)
+        val namedStructArgs = columnVariables.flatMap { case (colName, valueExpression) =>
+          Seq(Literal(colName), valueExpression)
+        }
+        variablesByName + (forVariableName -> CreateNamedStruct(namedStructArgs))
+      case None =>
+        variablesByName
+    }
+  }
+
+  /**
+   * Create and immediately execute dropVariable exec nodes for all variables in variablesMap.
+   */
+  private def dropVars(): Unit = {
+    variablesMap.keys.toSeq
+      .map(colName => createDropVarExec(colName))
+      .foreach(dropVarExec => dropVarExec.buildDataFrame(session).collect())
+    areVariablesDeclared = false
+  }
+
+  /**
+   * Leaves ForState.Body: continues with the next row if there is one, otherwise prepares the
+   * DROP VARIABLE statements and moves to ForState.VariableCleanup.
+   */
+  private def switchStateFromBody(): Unit = {
+    state = if (cachedQueryResult().hasNext) ForState.VariableAssignment
+    else {
+      // create compound body for dropping nodes after execution is complete
+      dropVariablesExec = new CompoundBodyExec(
+        variablesMap.keys.toSeq.map(colName => createDropVarExec(colName))
+      )
+      ForState.VariableCleanup
+    }
+  }
+
+  // Internal "declare or replace variable" statement; the default is a typed null taken from
+  // the data type of the given expression.
+  private def createDeclareVarExec(varName: String, variable: Expression): SingleStatementExec = {
+    val defaultExpression = DefaultValueExpression(Literal(null, variable.dataType), "null")
+    val declareVariable = CreateVariable(
+      UnresolvedIdentifier(Seq(varName)),
+      defaultExpression,
+      replace = true
+    )
+    new SingleStatementExec(declareVariable, Origin(), Map.empty, isInternal = true)
+  }
+
+  // Internal "set variable" statement assigning the given expression to the variable.
+  private def createSetVarExec(varName: String, variable: Expression): SingleStatementExec = {
+    val projectNamedStruct = Project(
+      Seq(Alias(variable, varName)()),
+      OneRowRelation()
+    )
+    val setIdentifierToCurrentRow =
+      SetVariable(Seq(UnresolvedAttribute(varName)), projectNamedStruct)
+    new SingleStatementExec(setIdentifierToCurrentRow, Origin(), Map.empty, isInternal = true)
+  }
+
+  // Internal "drop variable if exists" statement.
+  private def createDropVarExec(varName: String): SingleStatementExec = {
+    val dropVar = DropVariable(UnresolvedIdentifier(Seq(varName)), ifExists = true)
+    new SingleStatementExec(dropVar, Origin(), Map.empty, isInternal = true)
+  }
+
+  override def getTreeIterator: Iterator[CompoundStatementExec] = treeIterator
+
+  // Brings the node back to its initial state; the query is executed again on next use.
+  override def reset(): Unit = {
+    state = ForState.VariableAssignment
+    isResultCacheValid = false
+    variablesMap = Map()
+    areVariablesDeclared = false
+    dropVariablesExec = null
+    interrupted = false
+    body.reset()
+  }
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreter.scala b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreter.scala
index 387ae36b881f4..a3dc3d4599314 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreter.scala
@@ -20,7 +20,7 @@ package org.apache.spark.sql.scripting
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.analysis.UnresolvedIdentifier
import org.apache.spark.sql.catalyst.expressions.Expression
-import org.apache.spark.sql.catalyst.plans.logical.{CaseStatement, CompoundBody, CompoundPlanStatement, CreateVariable, DropVariable, IfElseStatement, IterateStatement, LeaveStatement, LogicalPlan, LoopStatement, RepeatStatement, SingleStatement, WhileStatement}
+import org.apache.spark.sql.catalyst.plans.logical.{CaseStatement, CompoundBody, CompoundPlanStatement, CreateVariable, DropVariable, ForStatement, IfElseStatement, IterateStatement, LeaveStatement, LogicalPlan, LoopStatement, RepeatStatement, SingleStatement, WhileStatement}
import org.apache.spark.sql.catalyst.trees.Origin
/**
@@ -145,6 +145,17 @@ case class SqlScriptingInterpreter(session: SparkSession) {
.asInstanceOf[CompoundBodyExec]
new LoopStatementExec(bodyExec, label)
+ case ForStatement(query, variableNameOpt, body, label) =>
+ val queryExec =
+ new SingleStatementExec(
+ query.parsedPlan,
+ query.origin,
+ args,
+ isInternal = false)
+ val bodyExec =
+ transformTreeIntoExecutable(body, args).asInstanceOf[CompoundBodyExec]
+ new ForStatementExec(queryExec, variableNameOpt, bodyExec, label, session)
+
case leaveStatement: LeaveStatement =>
new LeaveStatementExec(leaveStatement.label)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNodeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNodeSuite.scala
index 4874ea3d2795f..a997b5beadd34 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNodeSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNodeSuite.scala
@@ -18,11 +18,12 @@
package org.apache.spark.sql.scripting
import org.apache.spark.SparkFunSuite
-import org.apache.spark.sql.SparkSession
+import org.apache.spark.sql.{DataFrame, Row, SparkSession}
import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, Literal}
-import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, OneRowRelation, Project}
+import org.apache.spark.sql.catalyst.plans.logical.{DropVariable, LeafNode, OneRowRelation, Project}
import org.apache.spark.sql.catalyst.trees.Origin
import org.apache.spark.sql.test.SharedSparkSession
+import org.apache.spark.sql.types.{IntegerType, StructField, StructType}
/**
* Unit tests for execution nodes from SqlScriptingExecutionNode.scala.
@@ -82,9 +83,9 @@ class SqlScriptingExecutionNodeSuite extends SparkFunSuite with SharedSparkSessi
}
case class TestRepeat(
- condition: TestLoopCondition,
- body: CompoundBodyExec,
- label: Option[String] = None)
+ condition: TestLoopCondition,
+ body: CompoundBodyExec,
+ label: Option[String] = None)
extends RepeatStatementExec(condition, body, label, spark) {
private val evaluator = new LoopBooleanConditionEvaluator(condition)
@@ -94,6 +95,23 @@ class SqlScriptingExecutionNodeSuite extends SparkFunSuite with SharedSparkSessi
statement: LeafStatementExec): Boolean = evaluator.evaluateLoopBooleanCondition()
}
+  /**
+   * Test double for the query of a FOR statement: instead of running a plan it returns a
+   * DataFrame with a single IntegerType column `columnName` holding 0 until `numberOfRows`.
+   * The DataFrame is built with the suite's `spark` session; the `session` argument is unused.
+   */
+  case class MockQuery(numberOfRows: Int, columnName: String, description: String)
+    extends SingleStatementExec(
+      DummyLogicalPlan(),
+      Origin(startIndex = Some(0), stopIndex = Some(description.length)),
+      Map.empty,
+      isInternal = false) {
+    override def buildDataFrame(session: SparkSession): DataFrame = {
+      val singleColumnSchema = StructType(Seq(StructField(columnName, IntegerType)))
+      val rows = (0 until numberOfRows).map(value => Row(value))
+      spark.createDataFrame(spark.sparkContext.parallelize(rows), singleColumnSchema)
+    }
+  }
+
private def extractStatementValue(statement: CompoundStatementExec): String =
statement match {
case TestLeafStatement(testVal) => testVal
@@ -102,6 +120,9 @@ class SqlScriptingExecutionNodeSuite extends SparkFunSuite with SharedSparkSessi
case loopStmt: LoopStatementExec => loopStmt.label.get
case leaveStmt: LeaveStatementExec => leaveStmt.label
case iterateStmt: IterateStatementExec => iterateStmt.label
+ case forStmt: ForStatementExec => forStmt.label.get
+ case dropStmt: SingleStatementExec if dropStmt.parsedPlan.isInstanceOf[DropVariable]
+ => "DropVariable"
case _ => fail("Unexpected statement type")
}
@@ -688,4 +709,362 @@ class SqlScriptingExecutionNodeSuite extends SparkFunSuite with SharedSparkSessi
val statements = iter.map(extractStatementValue).toSeq
assert(statements === Seq("body1", "lbl"))
}
+
+ test("for statement - enters body once") {
+ val iter = new CompoundBodyExec(Seq(
+ new ForStatementExec(
+ query = MockQuery(1, "intCol", "query1"),
+ variableName = Some("x"),
+ label = Some("for1"),
+ session = spark,
+ body = new CompoundBodyExec(Seq(TestLeafStatement("body")))
+ )
+ )).getTreeIterator
+ val statements = iter.map(extractStatementValue).toSeq
+ assert(statements === Seq(
+ "body",
+ "DropVariable", // drop for query var intCol
+ "DropVariable" // drop for loop var x
+ ))
+ }
+
+ test("for statement - enters body with multiple statements multiple times") {
+ val iter = new CompoundBodyExec(Seq(
+ new ForStatementExec(
+ query = MockQuery(2, "intCol", "query1"),
+ variableName = Some("x"),
+ label = Some("for1"),
+ session = spark,
+ body = new CompoundBodyExec(
+ Seq(TestLeafStatement("statement1"), TestLeafStatement("statement2"))
+ )
+ )
+ )).getTreeIterator
+ val statements = iter.map(extractStatementValue).toSeq
+ assert(statements === Seq(
+ "statement1",
+ "statement2",
+ "statement1",
+ "statement2",
+ "DropVariable", // drop for query var intCol
+ "DropVariable" // drop for loop var x
+ ))
+ }
+
+ test("for statement - empty result") {
+ val iter = new CompoundBodyExec(Seq(
+ new ForStatementExec(
+ query = MockQuery(0, "intCol", "query1"),
+ variableName = Some("x"),
+ label = Some("for1"),
+ session = spark,
+ body = new CompoundBodyExec(Seq(TestLeafStatement("body1")))
+ )
+ )).getTreeIterator
+ val statements = iter.map(extractStatementValue).toSeq
+ assert(statements === Seq.empty[String])
+ }
+
+ test("for statement - nested") {
+ val iter = new CompoundBodyExec(Seq(
+ new ForStatementExec(
+ query = MockQuery(2, "intCol", "query1"),
+ variableName = Some("x"),
+ label = Some("for1"),
+ session = spark,
+ body = new CompoundBodyExec(Seq(
+ new ForStatementExec(
+ query = MockQuery(2, "intCol1", "query2"),
+ variableName = Some("y"),
+ label = Some("for2"),
+ session = spark,
+ body = new CompoundBodyExec(Seq(TestLeafStatement("body")))
+ )
+ ))
+ )
+ )).getTreeIterator
+ val statements = iter.map(extractStatementValue).toSeq
+ assert(statements === Seq(
+ "body",
+ "body",
+ "DropVariable", // drop for query var intCol1
+ "DropVariable", // drop for loop var y
+ "body",
+ "body",
+ "DropVariable", // drop for query var intCol1
+ "DropVariable", // drop for loop var y
+ "DropVariable", // drop for query var intCol
+ "DropVariable" // drop for loop var x
+ ))
+ }
+
+ test("for statement no variable - enters body once") {
+ val iter = new CompoundBodyExec(Seq(
+ new ForStatementExec(
+ query = MockQuery(1, "intCol", "query1"),
+ variableName = None,
+ label = Some("for1"),
+ session = spark,
+ body = new CompoundBodyExec(Seq(TestLeafStatement("body")))
+ )
+ )).getTreeIterator
+ val statements = iter.map(extractStatementValue).toSeq
+ assert(statements === Seq(
+ "body",
+ "DropVariable" // drop for query var intCol
+ ))
+ }
+
+ test("for statement no variable - enters body with multiple statements multiple times") {
+ val iter = new CompoundBodyExec(Seq(
+ new ForStatementExec(
+ query = MockQuery(2, "intCol", "query1"),
+ variableName = None,
+ label = Some("for1"),
+ session = spark,
+ body = new CompoundBodyExec(Seq(
+ TestLeafStatement("statement1"),
+ TestLeafStatement("statement2")))
+ )
+ )).getTreeIterator
+ val statements = iter.map(extractStatementValue).toSeq
+ assert(statements === Seq(
+ "statement1", "statement2", "statement1", "statement2",
+ "DropVariable" // drop for query var intCol
+ ))
+ }
+
+ test("for statement no variable - empty result") {
+ val iter = new CompoundBodyExec(Seq(
+ new ForStatementExec(
+ query = MockQuery(0, "intCol", "query1"),
+ variableName = None,
+ label = Some("for1"),
+ session = spark,
+ body = new CompoundBodyExec(Seq(TestLeafStatement("body1")))
+ )
+ )).getTreeIterator
+ val statements = iter.map(extractStatementValue).toSeq
+ assert(statements === Seq.empty[String])
+ }
+
+ test("for statement no variable - nested") {
+ val iter = new CompoundBodyExec(Seq(
+ new ForStatementExec(
+ query = MockQuery(2, "intCol", "query1"),
+ variableName = None,
+ label = Some("for1"),
+ session = spark,
+ body = new CompoundBodyExec(Seq(
+ new ForStatementExec(
+ query = MockQuery(2, "intCol1", "query2"),
+ variableName = None,
+ label = Some("for2"),
+ session = spark,
+ body = new CompoundBodyExec(Seq(TestLeafStatement("body")))
+ )
+ ))
+ )
+ )).getTreeIterator
+ val statements = iter.map(extractStatementValue).toSeq
+ assert(statements === Seq(
+ "body", "body",
+ "DropVariable", // drop for query var intCol1
+ "body", "body",
+ "DropVariable", // drop for query var intCol1
+ "DropVariable" // drop for query var intCol
+ ))
+ }
+
+ test("for statement - iterate") {
+ val iter = new CompoundBodyExec(Seq(
+ new ForStatementExec(
+ query = MockQuery(2, "intCol", "query1"),
+ variableName = Some("x"),
+ label = Some("lbl1"),
+ session = spark,
+ body = new CompoundBodyExec(Seq(
+ TestLeafStatement("statement1"),
+ new IterateStatementExec("lbl1"),
+ TestLeafStatement("statement2")))
+ )
+ )).getTreeIterator
+ val statements = iter.map(extractStatementValue).toSeq
+ assert(statements === Seq(
+ "statement1",
+ "lbl1",
+ "statement1",
+ "lbl1",
+ "DropVariable", // drop for query var intCol
+ "DropVariable" // drop for loop var x
+ ))
+ }
+
+ test("for statement - leave") {
+ val iter = new CompoundBodyExec(Seq(
+ new ForStatementExec(
+ query = MockQuery(2, "intCol", "query1"),
+ variableName = Some("x"),
+ label = Some("lbl1"),
+ session = spark,
+ body = new CompoundBodyExec(Seq(
+ TestLeafStatement("statement1"),
+ new LeaveStatementExec("lbl1"),
+ TestLeafStatement("statement2")))
+ )
+ )).getTreeIterator
+ val statements = iter.map(extractStatementValue).toSeq
+ assert(statements === Seq("statement1", "lbl1"))
+ }
+
+ test("for statement - nested - iterate outer loop") {
+ val iter = new CompoundBodyExec(Seq(
+ new ForStatementExec(
+ query = MockQuery(2, "intCol", "query1"),
+ variableName = Some("x"),
+ label = Some("lbl1"),
+ session = spark,
+ body = new CompoundBodyExec(Seq(
+ TestLeafStatement("outer_body"),
+ new ForStatementExec(
+ query = MockQuery(2, "intCol1", "query2"),
+ variableName = Some("y"),
+ label = Some("lbl2"),
+ session = spark,
+ body = new CompoundBodyExec(Seq(
+ TestLeafStatement("body1"),
+ new IterateStatementExec("lbl1"),
+ TestLeafStatement("body2")))
+ )
+ ))
+ )
+ )).getTreeIterator
+ val statements = iter.map(extractStatementValue).toSeq
+ assert(statements === Seq(
+ "outer_body",
+ "body1",
+ "lbl1",
+ "outer_body",
+ "body1",
+ "lbl1",
+ "DropVariable", // drop for query var intCol
+ "DropVariable" // drop for loop var x
+ ))
+ }
+
+ test("for statement - nested - leave outer loop") {
+ val iter = new CompoundBodyExec(Seq(
+ new ForStatementExec(
+ query = MockQuery(2, "intCol", "query1"),
+ variableName = Some("x"),
+ label = Some("lbl1"),
+ session = spark,
+ body = new CompoundBodyExec(Seq(
+ new ForStatementExec(
+ query = MockQuery(2, "intCol", "query2"),
+ variableName = Some("y"),
+ label = Some("lbl2"),
+ session = spark,
+ body = new CompoundBodyExec(Seq(
+ TestLeafStatement("body1"),
+ new LeaveStatementExec("lbl1"),
+ TestLeafStatement("body2")))
+ )
+ ))
+ )
+ )).getTreeIterator
+ val statements = iter.map(extractStatementValue).toSeq
+ assert(statements === Seq("body1", "lbl1"))
+ }
+
+ test("for statement no variable - iterate") {
+ val iter = new CompoundBodyExec(Seq(
+ new ForStatementExec(
+ query = MockQuery(2, "intCol", "query1"),
+ variableName = None,
+ label = Some("lbl1"),
+ session = spark,
+ body = new CompoundBodyExec(Seq(
+ TestLeafStatement("statement1"),
+ new IterateStatementExec("lbl1"),
+ TestLeafStatement("statement2")))
+ )
+ )).getTreeIterator
+ val statements = iter.map(extractStatementValue).toSeq
+ assert(statements === Seq(
+ "statement1", "lbl1", "statement1", "lbl1",
+ "DropVariable" // drop for query var intCol
+ ))
+ }
+
+ test("for statement no variable - leave") {
+ val iter = new CompoundBodyExec(Seq(
+ new ForStatementExec(
+ query = MockQuery(2, "intCol", "query1"),
+ variableName = None,
+ label = Some("lbl1"),
+ session = spark,
+ body = new CompoundBodyExec(Seq(
+ TestLeafStatement("statement1"),
+ new LeaveStatementExec("lbl1"),
+ TestLeafStatement("statement2")))
+ )
+ )).getTreeIterator
+ val statements = iter.map(extractStatementValue).toSeq
+ assert(statements === Seq("statement1", "lbl1"))
+ }
+
+ test("for statement no variable - nested - iterate outer loop") {
+ val iter = new CompoundBodyExec(Seq(
+ new ForStatementExec(
+ query = MockQuery(2, "intCol", "query1"),
+ variableName = None,
+ label = Some("lbl1"),
+ session = spark,
+ body = new CompoundBodyExec(Seq(
+ TestLeafStatement("outer_body"),
+ new ForStatementExec(
+ query = MockQuery(2, "intCol1", "query2"),
+ variableName = None,
+ label = Some("lbl2"),
+ session = spark,
+ body = new CompoundBodyExec(Seq(
+ TestLeafStatement("body1"),
+ new IterateStatementExec("lbl1"),
+ TestLeafStatement("body2")))
+ )
+ ))
+ )
+ )).getTreeIterator
+ val statements = iter.map(extractStatementValue).toSeq
+ assert(statements === Seq(
+ "outer_body", "body1", "lbl1", "outer_body", "body1", "lbl1",
+ "DropVariable" // drop for query var intCol
+ ))
+ }
+
+ test("for statement no variable - nested - leave outer loop") {
+ val iter = new CompoundBodyExec(Seq(
+ new ForStatementExec(
+ query = MockQuery(2, "intCol", "query1"),
+ variableName = None,
+ label = Some("lbl1"),
+ session = spark,
+ body = new CompoundBodyExec(Seq(
+ new ForStatementExec(
+ query = MockQuery(2, "intCol1", "query2"),
+ variableName = None,
+ label = Some("lbl2"),
+ session = spark,
+ body = new CompoundBodyExec(Seq(
+ TestLeafStatement("body1"),
+ new LeaveStatementExec("lbl1"),
+ TestLeafStatement("body2")))
+ )
+ ))
+ )
+ )).getTreeIterator
+ val statements = iter.map(extractStatementValue).toSeq
+ assert(statements === Seq("body1", "lbl1"))
+ }
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreterSuite.scala
index 177ffc24d180a..71556c5502225 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreterSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreterSuite.scala
@@ -1550,4 +1550,1058 @@ class SqlScriptingInterpreterSuite extends QueryTest with SharedSparkSession {
)
verifySqlScriptResult(sqlScriptText, expected)
}
+
+ test("for statement - enters body once") {
+ withTable("t") {
+ val sqlScript =
+ """
+ |BEGIN
+ | CREATE TABLE t (intCol INT, stringCol STRING, doubleCol DOUBLE) using parquet;
+ | INSERT INTO t VALUES (1, 'first', 1.0);
+ | FOR row AS SELECT * FROM t DO
+ | SELECT row.intCol;
+ | END FOR;
+ |END
+ |""".stripMargin
+
+ val expected = Seq(
+ Seq.empty[Row], // create table
+ Seq.empty[Row], // insert
+ Seq(Row(1)), // select row.intCol
+ Seq.empty[Row], // drop local var
+ Seq.empty[Row], // drop local var
+ Seq.empty[Row], // drop local var
+ Seq.empty[Row] // drop local var
+ )
+ verifySqlScriptResult(sqlScript, expected)
+ }
+ }
+
+ test("for statement - enters body with multiple statements multiple times") {
+ withTable("t") {
+ val sqlScript =
+ """
+ |BEGIN
+ | CREATE TABLE t (intCol INT, stringCol STRING, doubleCol DOUBLE) using parquet;
+ | INSERT INTO t VALUES (1, 'first', 1.0);
+ | INSERT INTO t VALUES (2, 'second', 2.0);
+ | FOR row AS SELECT * FROM t ORDER BY intCol DO
+ | SELECT row.intCol;
+ | SELECT intCol;
+ | SELECT row.stringCol;
+ | SELECT stringCol;
+ | SELECT row.doubleCol;
+ | SELECT doubleCol;
+ | END FOR;
+ |END
+ |""".stripMargin
+
+ val expected = Seq(
+ Seq.empty[Row], // create table
+ Seq.empty[Row], // insert
+ Seq.empty[Row], // insert
+ Seq(Row(1)), // select row.intCol
+ Seq(Row(1)), // select intCol
+ Seq(Row("first")), // select row.stringCol
+ Seq(Row("first")), // select stringCol
+ Seq(Row(1.0)), // select row.doubleCol
+ Seq(Row(1.0)), // select doubleCol
+ Seq(Row(2)), // select row.intCol
+ Seq(Row(2)), // select intCol
+ Seq(Row("second")), // select row.stringCol
+ Seq(Row("second")), // select stringCol
+ Seq(Row(2.0)), // select row.doubleCol
+ Seq(Row(2.0)), // select doubleCol
+ Seq.empty[Row], // drop local var
+ Seq.empty[Row], // drop local var
+ Seq.empty[Row], // drop local var
+ Seq.empty[Row] // drop local var
+ )
+ verifySqlScriptResult(sqlScript, expected)
+ }
+ }
+
+ test("for statement - sum of column from table") {
+ withTable("t") {
+ val sqlScript =
+ """
+ |BEGIN
+ | DECLARE sumOfCols = 0;
+ | CREATE TABLE t (intCol INT) using parquet;
+ | INSERT INTO t VALUES (1), (2), (3), (4);
+ | FOR row AS SELECT * FROM t DO
+ | SET sumOfCols = sumOfCols + row.intCol;
+ | END FOR;
+ | SELECT sumOfCols;
+ |END
+ |""".stripMargin
+
+ val expected = Seq(
+ Seq.empty[Row], // declare sumOfCols
+ Seq.empty[Row], // create table
+ Seq.empty[Row], // insert
+ Seq.empty[Row], // set sumOfCols
+ Seq.empty[Row], // set sumOfCols
+ Seq.empty[Row], // set sumOfCols
+ Seq.empty[Row], // set sumOfCols
+ Seq.empty[Row], // drop local var
+ Seq.empty[Row], // drop local var
+ Seq(Row(10)), // select sumOfCols
+ Seq.empty[Row] // drop sumOfCols
+ )
+ verifySqlScriptResult(sqlScript, expected)
+ }
+ }
+
+ test("for statement - map, struct, array") {
+ withTable("t") {
+ val sqlScript =
+ """
+ |BEGIN
+ | CREATE TABLE t (int_column INT, map_column MAP<STRING, INT>,
+ | struct_column STRUCT<name: STRING, age: INT>, array_column ARRAY<STRING>);
+ | INSERT INTO t VALUES
+ | (1, MAP('a', 1), STRUCT('John', 25), ARRAY('apricot', 'quince')),
+ | (2, MAP('b', 2), STRUCT('Jane', 30), ARRAY('plum', 'pear'));
+ | FOR row AS SELECT * FROM t ORDER BY int_column DO
+ | SELECT row.map_column;
+ | SELECT map_column;
+ | SELECT row.struct_column;
+ | SELECT struct_column;
+ | SELECT row.array_column;
+ | SELECT array_column;
+ | END FOR;
+ |END
+ |""".stripMargin
+
+ val expected = Seq(
+ Seq.empty[Row], // create table
+ Seq.empty[Row], // insert
+ Seq(Row(Map("a" -> 1))), // select row.map_column
+ Seq(Row(Map("a" -> 1))), // select map_column
+ Seq(Row(Row("John", 25))), // select row.struct_column
+ Seq(Row(Row("John", 25))), // select struct_column
+ Seq(Row(Array("apricot", "quince"))), // select row.array_column
+ Seq(Row(Array("apricot", "quince"))), // select array_column
+ Seq(Row(Map("b" -> 2))), // select row.map_column
+ Seq(Row(Map("b" -> 2))), // select map_column
+ Seq(Row(Row("Jane", 30))), // select row.struct_column
+ Seq(Row(Row("Jane", 30))), // select struct_column
+ Seq(Row(Array("plum", "pear"))), // select row.array_column
+ Seq(Row(Array("plum", "pear"))), // select array_column
+ Seq.empty[Row], // drop local var
+ Seq.empty[Row], // drop local var
+ Seq.empty[Row], // drop local var
+ Seq.empty[Row], // drop local var
+ Seq.empty[Row] // drop local var
+ )
+ verifySqlScriptResult(sqlScript, expected)
+ }
+ }
+
+ test("for statement - nested struct") {
+ withTable("t") {
+ val sqlScript =
+ """
+ |BEGIN
+ | CREATE TABLE t
+ | (int_column INT,
+ | struct_column STRUCT<num: INT, struct2: STRUCT<struct3: STRUCT<name: STRING>>>);
+ | INSERT INTO t VALUES
+ | (1, STRUCT(1, STRUCT(STRUCT("one")))),
+ | (2, STRUCT(2, STRUCT(STRUCT("two"))));
+ | FOR row AS SELECT * FROM t ORDER BY int_column DO
+ | SELECT row.struct_column;
+ | SELECT struct_column;
+ | END FOR;
+ |END
+ |""".stripMargin
+
+ val expected = Seq(
+ Seq.empty[Row], // create table
+ Seq.empty[Row], // insert
+ Seq(Row(Row(1, Row(Row("one"))))), // select row.struct_column
+ Seq(Row(Row(1, Row(Row("one"))))), // select struct_column
+ Seq(Row(Row(2, Row(Row("two"))))), // select row.struct_column
+ Seq(Row(Row(2, Row(Row("two"))))), // select struct_column
+ Seq.empty[Row], // drop local var
+ Seq.empty[Row], // drop local var
+ Seq.empty[Row] // drop local var
+ )
+ verifySqlScriptResult(sqlScript, expected)
+ }
+ }
+
+ test("for statement - nested map") {
+ withTable("t") {
+ val sqlScript =
+ """
+ |BEGIN
+ | CREATE TABLE t (int_column INT, map_column MAP<STRING, MAP<INT, MAP<BOOLEAN, INT>>>);
+ | INSERT INTO t VALUES
+ | (1, MAP('a', MAP(1, MAP(false, 10)))),
+ | (2, MAP('b', MAP(2, MAP(true, 20))));
+ | FOR row AS SELECT * FROM t ORDER BY int_column DO
+ | SELECT row.map_column;
+ | SELECT map_column;
+ | END FOR;
+ |END
+ |""".stripMargin
+
+ val expected = Seq(
+ Seq.empty[Row], // create table
+ Seq.empty[Row], // insert
+ Seq(Row(Map("a" -> Map(1 -> Map(false -> 10))))), // select row.map_column
+ Seq(Row(Map("a" -> Map(1 -> Map(false -> 10))))), // select map_column
+ Seq(Row(Map("b" -> Map(2 -> Map(true -> 20))))), // select row.map_column
+ Seq(Row(Map("b" -> Map(2 -> Map(true -> 20))))), // select map_column
+ Seq.empty[Row], // drop local var
+ Seq.empty[Row], // drop local var
+ Seq.empty[Row] // drop local var
+ )
+ verifySqlScriptResult(sqlScript, expected)
+ }
+ }
+
+ test("for statement - nested array") {
+ withTable("t") {
+ val sqlScript =
+ """
+ |BEGIN
+ | CREATE TABLE t
+ | (int_column INT, array_column ARRAY<ARRAY<ARRAY<INT>>>);
+ | INSERT INTO t VALUES
+ | (1, ARRAY(ARRAY(ARRAY(1, 2), ARRAY(3, 4)), ARRAY(ARRAY(5, 6)))),
+ | (2, ARRAY(ARRAY(ARRAY(7, 8), ARRAY(9, 10)), ARRAY(ARRAY(11, 12))));
+ | FOR row AS SELECT * FROM t ORDER BY int_column DO
+ | SELECT row.array_column;
+ | SELECT array_column;
+ | END FOR;
+ |END
+ |""".stripMargin
+
+ val expected = Seq(
+ Seq.empty[Row], // create table
+ Seq.empty[Row], // insert
+ Seq(Row(Seq(Seq(Seq(1, 2), Seq(3, 4)), Seq(Seq(5, 6))))), // row.array_column
+ Seq(Row(Seq(Seq(Seq(1, 2), Seq(3, 4)), Seq(Seq(5, 6))))), // array_column
+ Seq(Row(Array(Seq(Seq(7, 8), Seq(9, 10)), Seq(Seq(11, 12))))), // row.array_column
+ Seq(Row(Array(Seq(Seq(7, 8), Seq(9, 10)), Seq(Seq(11, 12))))), // array_column
+ Seq.empty[Row], // drop local var
+ Seq.empty[Row], // drop local var
+ Seq.empty[Row] // drop local var
+ )
+ verifySqlScriptResult(sqlScript, expected)
+ }
+ }
+
+ test("for statement empty result") {
+ withTable("t") {
+ val sqlScript =
+ """
+ |BEGIN
+ | CREATE TABLE t (intCol INT) using parquet;
+ | FOR row AS SELECT * FROM t ORDER BY intCol DO
+ | SELECT row.intCol;
+ | END FOR;
+ |END
+ |""".stripMargin
+
+ val expected = Seq(
+ Seq.empty[Row] // create table
+ )
+ verifySqlScriptResult(sqlScript, expected)
+ }
+ }
+
+ test("for statement iterate") {
+ withTable("t") {
+ val sqlScript =
+ """
+ |BEGIN
+ | CREATE TABLE t (intCol INT, stringCol STRING) using parquet;
+ | INSERT INTO t VALUES (1, 'first'), (2, 'second'), (3, 'third'), (4, 'fourth');
+ |
+ | lbl: FOR x AS SELECT * FROM t ORDER BY intCol DO
+ | IF x.intCol = 2 THEN
+ | ITERATE lbl;
+ | END IF;
+ | SELECT stringCol;
+ | SELECT x.stringCol;
+ | END FOR;
+ |END
+ |""".stripMargin
+
+ val expected = Seq(
+ Seq.empty[Row], // create table
+ Seq.empty[Row], // insert
+ Seq(Row("first")), // select stringCol
+ Seq(Row("first")), // select x.stringCol
+ Seq(Row("third")), // select stringCol
+ Seq(Row("third")), // select x.stringCol
+ Seq(Row("fourth")), // select stringCol
+ Seq(Row("fourth")), // select x.stringCol
+ Seq.empty[Row], // drop local var
+ Seq.empty[Row], // drop local var
+ Seq.empty[Row] // drop local var
+ )
+ verifySqlScriptResult(sqlScript, expected)
+ }
+ }
+
+ test("for statement leave") {
+ withTable("t") {
+ val sqlScript =
+ """
+ |BEGIN
+ | CREATE TABLE t (intCol INT, stringCol STRING) using parquet;
+ | INSERT INTO t VALUES (1, 'first'), (2, 'second'), (3, 'third'), (4, 'fourth');
+ |
+ | lbl: FOR x AS SELECT * FROM t ORDER BY intCol DO
+ | IF x.intCol = 3 THEN
+ | LEAVE lbl;
+ | END IF;
+ | SELECT stringCol;
+ | SELECT x.stringCol;
+ | END FOR;
+ |END
+ |""".stripMargin
+
+ val expected = Seq(
+ Seq.empty[Row], // create table
+ Seq.empty[Row], // insert
+ Seq(Row("first")), // select stringCol
+ Seq(Row("first")), // select x.stringCol
+ Seq(Row("second")), // select stringCol
+ Seq(Row("second")) // select x.stringCol
+ )
+ verifySqlScriptResult(sqlScript, expected)
+ }
+ }
+
+ test("for statement - nested - in while") {
+ withTable("t") {
+ val sqlScript =
+ """
+ |BEGIN
+ | DECLARE cnt = 0;
+ | CREATE TABLE t (intCol INT) using parquet;
+ | INSERT INTO t VALUES (0);
+ | WHILE cnt < 2 DO
+ | SET cnt = cnt + 1;
+ | FOR x AS SELECT * FROM t ORDER BY intCol DO
+ | SELECT x.intCol;
+ | END FOR;
+ | INSERT INTO t VALUES (cnt);
+ | END WHILE;
+ |END
+ |""".stripMargin
+
+ val expected = Seq(
+ Seq.empty[Row], // declare cnt
+ Seq.empty[Row], // create table
+ Seq.empty[Row], // insert
+ Seq.empty[Row], // set cnt
+ Seq(Row(0)), // select intCol
+ Seq.empty[Row], // drop local var
+ Seq.empty[Row], // drop local var
+ Seq.empty[Row], // insert
+ Seq.empty[Row], // set cnt
+ Seq(Row(0)), // select intCol
+ Seq(Row(1)), // select intCol
+ Seq.empty[Row], // drop local var
+ Seq.empty[Row], // drop local var
+ Seq.empty[Row], // insert
+ Seq.empty[Row] // drop cnt
+ )
+ verifySqlScriptResult(sqlScript, expected)
+ }
+ }
+
+ test("for statement - nested - in other for") {
+ withTable("t", "t2") {
+ val sqlScript =
+ """
+ |BEGIN
+ | CREATE TABLE t (intCol INT) using parquet;
+ | CREATE TABLE t2 (intCol2 INT) using parquet;
+ | INSERT INTO t VALUES (0), (1);
+ | INSERT INTO t2 VALUES (2), (3);
+ | FOR x as SELECT * FROM t ORDER BY intCol DO
+ | FOR y AS SELECT * FROM t2 ORDER BY intCol2 DESC DO
+ | SELECT x.intCol;
+ | SELECT intCol;
+ | SELECT y.intCol2;
+ | SELECT intCol2;
+ | END FOR;
+ | END FOR;
+ |END
+ |""".stripMargin
+
+ val expected = Seq(
+ Seq.empty[Row], // create table
+ Seq.empty[Row], // create table
+ Seq.empty[Row], // insert
+ Seq.empty[Row], // insert
+ Seq(Row(0)), // select x.intCol
+ Seq(Row(0)), // select intCol
+ Seq(Row(3)), // select y.intCol2
+ Seq(Row(3)), // select intCol2
+ Seq(Row(0)), // select x.intCol
+ Seq(Row(0)), // select intCol
+ Seq(Row(2)), // select y.intCol2
+ Seq(Row(2)), // select intCol2
+ Seq.empty[Row], // drop local var
+ Seq.empty[Row], // drop local var
+ Seq(Row(1)), // select x.intCol
+ Seq(Row(1)), // select intCol
+ Seq(Row(3)), // select y.intCol2
+ Seq(Row(3)), // select intCol2
+ Seq(Row(1)), // select x.intCol
+ Seq(Row(1)), // select intCol
+ Seq(Row(2)), // select y.intCol2
+ Seq(Row(2)), // select intCol2
+ Seq.empty[Row], // drop local var
+ Seq.empty[Row], // drop local var
+ Seq.empty[Row], // drop outer var
+ Seq.empty[Row] // drop outer var
+ )
+ verifySqlScriptResult(sqlScript, expected)
+ }
+ }
+
+ // ignored until loops are fixed to support empty bodies
+ ignore("for statement - nested - empty result set") {
+ withTable("t") {
+ val sqlScript =
+ """
+ |BEGIN
+ | CREATE TABLE t (intCol INT) using parquet;
+ | REPEAT
+ | FOR x AS SELECT * FROM t ORDER BY intCol DO
+ | SELECT x.intCol;
+ | END FOR;
+ | UNTIL 1 = 1
+ | END REPEAT;
+ |END
+ |""".stripMargin
+
+ // t is never populated, so the FOR body does not run and no loop variables are
+ // created or dropped; CREATE TABLE is the only statement expected to yield a result.
+ val expected = Seq(
+ Seq.empty[Row] // create table
+ )
+ verifySqlScriptResult(sqlScript, expected)
+ }
+ }
+
+ test("for statement - nested - iterate outer loop") {
+ withTable("t", "t2") {
+ val sqlScript =
+ """
+ |BEGIN
+ | CREATE TABLE t (intCol INT) using parquet;
+ | CREATE TABLE t2 (intCol2 INT) using parquet;
+ | INSERT INTO t VALUES (0), (1);
+ | INSERT INTO t2 VALUES (2), (3);
+ | lbl1: FOR x as SELECT * FROM t ORDER BY intCol DO
+ | lbl2: FOR y AS SELECT * FROM t2 ORDER BY intCol2 DESC DO
+ | SELECT y.intCol2;
+ | SELECT intCol2;
+ | ITERATE lbl1;
+ | SELECT 1;
+ | END FOR;
+ | END FOR;
+ |END
+ |""".stripMargin
+
+ val expected = Seq(
+ Seq.empty[Row], // create table
+ Seq.empty[Row], // create table
+ Seq.empty[Row], // insert
+ Seq.empty[Row], // insert
+ Seq(Row(3)), // select y.intCol2
+ Seq(Row(3)), // select intCol2
+ Seq(Row(3)), // select y.intCol2
+ Seq(Row(3)), // select intCol2
+ Seq.empty[Row], // drop outer var
+ Seq.empty[Row] // drop outer var
+ )
+ verifySqlScriptResult(sqlScript, expected)
+ }
+ }
+
+ test("for statement - nested - leave outer loop") {
+ withTable("t", "t2") {
+ val sqlScript =
+ """
+ |BEGIN
+ | CREATE TABLE t (intCol INT) using parquet;
+ | CREATE TABLE t2 (intCol2 INT) using parquet;
+ | INSERT INTO t VALUES (0), (1);
+ | INSERT INTO t2 VALUES (2), (3);
+ | lbl1: FOR x as SELECT * FROM t ORDER BY intCol DO
+ | lbl2: FOR y AS SELECT * FROM t2 ORDER BY intCol2 DESC DO
+ | SELECT y.intCol2;
+ | SELECT intCol2;
+ | LEAVE lbl1;
+ | SELECT 1;
+ | END FOR;
+ | END FOR;
+ |END
+ |""".stripMargin
+
+ val expected = Seq(
+ Seq.empty[Row], // create table
+ Seq.empty[Row], // create table
+ Seq.empty[Row], // insert
+ Seq.empty[Row], // insert
+ Seq(Row(3)), // select y.intCol2
+ Seq(Row(3)) // select intCol2
+ )
+ verifySqlScriptResult(sqlScript, expected)
+ }
+ }
+
+ test("for statement - nested - leave inner loop") {
+ withTable("t", "t2") {
+ val sqlScript =
+ """
+ |BEGIN
+ | CREATE TABLE t (intCol INT) using parquet;
+ | CREATE TABLE t2 (intCol2 INT) using parquet;
+ | INSERT INTO t VALUES (0), (1);
+ | INSERT INTO t2 VALUES (2), (3);
+ | lbl1: FOR x as SELECT * FROM t ORDER BY intCol DO
+ | lbl2: FOR y AS SELECT * FROM t2 ORDER BY intCol2 DESC DO
+ | SELECT y.intCol2;
+ | SELECT intCol2;
+ | LEAVE lbl2;
+ | SELECT 1;
+ | END FOR;
+ | END FOR;
+ |END
+ |""".stripMargin
+
+ val expected = Seq(
+ Seq.empty[Row], // create table
+ Seq.empty[Row], // create table
+ Seq.empty[Row], // insert
+ Seq.empty[Row], // insert
+ Seq(Row(3)), // select y.intCol2
+ Seq(Row(3)), // select intCol2
+ Seq(Row(3)), // select y.intCol2
+ Seq(Row(3)), // select intCol2
+ Seq.empty[Row], // drop outer var
+ Seq.empty[Row] // drop outer var
+ )
+ verifySqlScriptResult(sqlScript, expected)
+ }
+ }
+
+ test("for statement - no variable - enters body once") {
+ withTable("t") {
+ val sqlScript =
+ """
+ |BEGIN
+ | CREATE TABLE t (intCol INT, stringCol STRING, doubleCol DOUBLE) using parquet;
+ | INSERT INTO t VALUES (1, 'first', 1.0);
+ | FOR SELECT * FROM t DO
+ | SELECT intCol;
+ | SELECT stringCol;
+ | SELECT doubleCol;
+ | END FOR;
+ |END
+ |""".stripMargin
+
+ val expected = Seq(
+ Seq.empty[Row], // create table
+ Seq.empty[Row], // insert
+ Seq(Row(1)), // select intCol
+ Seq(Row("first")), // select stringCol
+ Seq(Row(1.0)), // select doubleCol
+ Seq.empty[Row], // drop local var
+ Seq.empty[Row], // drop local var
+ Seq.empty[Row] // drop local var
+ )
+ verifySqlScriptResult(sqlScript, expected)
+ }
+ }
+
+ test("for statement - no variable - enters body with multiple statements multiple times") {
+ withTable("t") {
+ val sqlScript =
+ """
+ |BEGIN
+ | CREATE TABLE t (intCol INT, stringCol STRING, doubleCol DOUBLE) using parquet;
+ | INSERT INTO t VALUES (1, 'first', 1.0);
+ | INSERT INTO t VALUES (2, 'second', 2.0);
+ | FOR SELECT * FROM t ORDER BY intCol DO
+ | SELECT intCol;
+ | SELECT stringCol;
+ | SELECT doubleCol;
+ | END FOR;
+ |END
+ |""".stripMargin
+
+ val expected = Seq(
+ Seq.empty[Row], // create table
+ Seq.empty[Row], // insert
+ Seq.empty[Row], // insert
+ Seq(Row(1)), // select intCol
+ Seq(Row("first")), // select stringCol
+ Seq(Row(1.0)), // select doubleCol
+ Seq(Row(2)), // select intCol
+ Seq(Row("second")), // select stringCol
+ Seq(Row(2.0)), // select doubleCol
+ Seq.empty[Row], // drop local var
+ Seq.empty[Row], // drop local var
+ Seq.empty[Row] // drop local var
+ )
+ verifySqlScriptResult(sqlScript, expected)
+ }
+ }
+
+ test("for statement - no variable - sum of column from table") {
+ withTable("t") {
+ val sqlScript =
+ """
+ |BEGIN
+ | DECLARE sumOfCols = 0;
+ | CREATE TABLE t (intCol INT) using parquet;
+ | INSERT INTO t VALUES (1), (2), (3), (4);
+ | FOR SELECT * FROM t DO
+ | SET sumOfCols = sumOfCols + intCol;
+ | END FOR;
+ | SELECT sumOfCols;
+ |END
+ |""".stripMargin
+
+ val expected = Seq(
+ Seq.empty[Row], // declare sumOfCols
+ Seq.empty[Row], // create table
+ Seq.empty[Row], // insert
+ Seq.empty[Row], // set sumOfCols
+ Seq.empty[Row], // set sumOfCols
+ Seq.empty[Row], // set sumOfCols
+ Seq.empty[Row], // set sumOfCols
+ Seq.empty[Row], // drop local var
+ Seq(Row(10)), // select sumOfCols
+ Seq.empty[Row] // drop sumOfCols
+ )
+ verifySqlScriptResult(sqlScript, expected)
+ }
+ }
+
+ test("for statement - no variable - map, struct, array") {
+ withTable("t") {
+ val sqlScript =
+ """
+ |BEGIN
+ | CREATE TABLE t (int_column INT, map_column MAP<STRING, INT>,
+ | struct_column STRUCT<name: STRING, age: INT>, array_column ARRAY<STRING>);
+ | INSERT INTO t VALUES
+ | (1, MAP('a', 1), STRUCT('John', 25), ARRAY('apricot', 'quince')),
+ | (2, MAP('b', 2), STRUCT('Jane', 30), ARRAY('plum', 'pear'));
+ | FOR SELECT * FROM t ORDER BY int_column DO
+ | SELECT map_column;
+ | SELECT struct_column;
+ | SELECT array_column;
+ | END FOR;
+ |END
+ |""".stripMargin
+
+ val expected = Seq(
+ Seq.empty[Row], // create table
+ Seq.empty[Row], // insert
+ Seq(Row(Map("a" -> 1))), // select map_column
+ Seq(Row(Row("John", 25))), // select struct_column
+ Seq(Row(Array("apricot", "quince"))), // select array_column
+ Seq(Row(Map("b" -> 2))), // select map_column
+ Seq(Row(Row("Jane", 30))), // select struct_column
+ Seq(Row(Array("plum", "pear"))), // select array_column
+ Seq.empty[Row], // drop local var
+ Seq.empty[Row], // drop local var
+ Seq.empty[Row], // drop local var
+ Seq.empty[Row] // drop local var
+ )
+ verifySqlScriptResult(sqlScript, expected)
+ }
+ }
+
+ test("for statement - no variable - nested struct") {
+ withTable("t") {
+ val sqlScript =
+ """
+ |BEGIN
+ | CREATE TABLE t (int_column INT,
+ | struct_column STRUCT<num: INT, struct2: STRUCT<struct3: STRUCT<name: STRING>>>);
+ | INSERT INTO t VALUES
+ | (1, STRUCT(1, STRUCT(STRUCT("one")))),
+ | (2, STRUCT(2, STRUCT(STRUCT("two"))));
+ | FOR SELECT * FROM t ORDER BY int_column DO
+ | SELECT struct_column;
+ | END FOR;
+ |END
+ |""".stripMargin
+
+ val expected = Seq(
+ Seq.empty[Row], // create table
+ Seq.empty[Row], // insert
+ Seq(Row(Row(1, Row(Row("one"))))), // select struct_column
+ Seq(Row(Row(2, Row(Row("two"))))), // select struct_column
+ Seq.empty[Row], // drop local var
+ Seq.empty[Row] // drop local var
+ )
+ verifySqlScriptResult(sqlScript, expected)
+ }
+ }
+
+ test("for statement - no variable - nested map") {
+ withTable("t") {
+ val sqlScript =
+ """
+ |BEGIN
+ | CREATE TABLE t (int_column INT, map_column MAP<STRING, MAP<INT, MAP<BOOLEAN, INT>>>);
+ | INSERT INTO t VALUES
+ | (1, MAP('a', MAP(1, MAP(false, 10)))),
+ | (2, MAP('b', MAP(2, MAP(true, 20))));
+ | FOR SELECT * FROM t ORDER BY int_column DO
+ | SELECT map_column;
+ | END FOR;
+ |END
+ |""".stripMargin
+
+ val expected = Seq(
+ Seq.empty[Row], // create table
+ Seq.empty[Row], // insert
+ Seq(Row(Map("a" -> Map(1 -> Map(false -> 10))))), // select map_column
+ Seq(Row(Map("b" -> Map(2 -> Map(true -> 20))))), // select map_column
+ Seq.empty[Row], // drop local var
+ Seq.empty[Row] // drop local var
+ )
+ verifySqlScriptResult(sqlScript, expected)
+ }
+ }
+
+ test("for statement - no variable - nested array") {
+ withTable("t") {
+ val sqlScript =
+ """
+ |BEGIN
+ | CREATE TABLE t
+ | (int_column INT, array_column ARRAY<ARRAY<ARRAY<INT>>>);
+ | INSERT INTO t VALUES
+ | (1, ARRAY(ARRAY(ARRAY(1, 2), ARRAY(3, 4)), ARRAY(ARRAY(5, 6)))),
+ | (2, ARRAY(ARRAY(ARRAY(7, 8), ARRAY(9, 10)), ARRAY(ARRAY(11, 12))));
+ | FOR SELECT * FROM t ORDER BY int_column DO
+ | SELECT array_column;
+ | END FOR;
+ |END
+ |""".stripMargin
+
+ val expected = Seq(
+ Seq.empty[Row], // create table
+ Seq.empty[Row], // insert
+ Seq(Row(Seq(Seq(Seq(1, 2), Seq(3, 4)), Seq(Seq(5, 6))))), // array_column
+ Seq(Row(Array(Seq(Seq(7, 8), Seq(9, 10)), Seq(Seq(11, 12))))), // array_column
+ Seq.empty[Row], // drop local var
+ Seq.empty[Row] // drop local var
+ )
+ verifySqlScriptResult(sqlScript, expected)
+ }
+ }
+
+ // The FOR query runs over a table that never receives rows; the only result the
+ // test expects from the script is the empty one for CREATE TABLE.
+ test("for statement - no variable - empty result") {
+   withTable("t") {
+     val script =
+       """
+         |BEGIN
+         | CREATE TABLE t (intCol INT) using parquet;
+         | FOR SELECT * FROM t ORDER BY intCol DO
+         | SELECT intCol;
+         | END FOR;
+         |END
+         |""".stripMargin
+
+     // create table
+     val createTableResult = Seq.empty[Row]
+     verifySqlScriptResult(script, Seq(createTableResult))
+   }
+ }
+
+ // ITERATE on the FOR label is issued for the row with intCol = 2, so the test expects
+ // stringCol to be selected for every row except 'second'.
+ test("for statement - no variable - iterate") {
+   withTable("t") {
+     val script =
+       """
+         |BEGIN
+         | CREATE TABLE t (intCol INT, stringCol STRING) using parquet;
+         | INSERT INTO t VALUES (1, 'first'), (2, 'second'), (3, 'third'), (4, 'fourth');
+         |
+         | lbl: FOR SELECT * FROM t ORDER BY intCol DO
+         | IF intCol = 2 THEN
+         | ITERATE lbl;
+         | END IF;
+         | SELECT stringCol;
+         | END FOR;
+         |END
+         |""".stripMargin
+
+     val noRows = Seq.empty[Row]
+     // create table, insert
+     val setup = Seq(noRows, noRows)
+     // select stringCol, one result per row that is not skipped
+     val selected = Seq("first", "third", "fourth").map(value => Seq(Row(value)))
+     // drop local var, drop local var
+     val cleanup = Seq(noRows, noRows)
+
+     verifySqlScriptResult(script, setup ++ selected ++ cleanup)
+   }
+ }
+
+ // LEAVE on the FOR label is issued for the row with intCol = 3, so the test expects
+ // stringCol to be selected only for the two rows ordered before it.
+ test("for statement - no variable - leave") {
+   withTable("t") {
+     val script =
+       """
+         |BEGIN
+         | CREATE TABLE t (intCol INT, stringCol STRING) using parquet;
+         | INSERT INTO t VALUES (1, 'first'), (2, 'second'), (3, 'third'), (4, 'fourth');
+         |
+         | lbl: FOR SELECT * FROM t ORDER BY intCol DO
+         | IF intCol = 3 THEN
+         | LEAVE lbl;
+         | END IF;
+         | SELECT stringCol;
+         | END FOR;
+         |END
+         |""".stripMargin
+
+     val noRows = Seq.empty[Row]
+     // create table, insert
+     val setup = Seq(noRows, noRows)
+     // select stringCol, one result per row reached before LEAVE
+     val selected = Seq("first", "second").map(value => Seq(Row(value)))
+
+     verifySqlScriptResult(script, setup ++ selected)
+   }
+ }
+
+ // A FOR loop nested in a WHILE that runs twice. Each WHILE iteration inserts one more
+ // row after the FOR, so the second pass of the FOR selects one row more than the first.
+ test("for statement - no variable - nested - in while") {
+   withTable("t") {
+     val script =
+       """
+         |BEGIN
+         | DECLARE cnt = 0;
+         | CREATE TABLE t (intCol INT) using parquet;
+         | INSERT INTO t VALUES (0);
+         | WHILE cnt < 2 DO
+         | SET cnt = cnt + 1;
+         | FOR SELECT * FROM t ORDER BY intCol DO
+         | SELECT intCol;
+         | END FOR;
+         | INSERT INTO t VALUES (cnt);
+         | END WHILE;
+         |END
+         |""".stripMargin
+
+     val noRows = Seq.empty[Row]
+     val expectedResults = Seq(
+       noRows, // declare cnt
+       noRows, // create table
+       noRows, // insert
+       // results listed for the first WHILE iteration
+       noRows, // set cnt
+       Seq(Row(0)), // select intCol
+       noRows, // insert
+       noRows, // drop local var
+       // results listed for the second WHILE iteration
+       noRows, // set cnt
+       Seq(Row(0)), // select intCol
+       Seq(Row(1)), // select intCol
+       noRows, // insert
+       noRows, // drop local var
+       noRows // drop cnt
+     )
+     verifySqlScriptResult(script, expectedResults)
+   }
+ }
+
+ // Two nested FOR loops without loop variables: the inner body reads a column of the
+ // outer query (intCol) and a column of the inner query (intCol2) by plain name.
+ test("for statement - no variable - nested - in other for") {
+   withTable("t", "t2") {
+     val script =
+       """
+         |BEGIN
+         | CREATE TABLE t (intCol INT) using parquet;
+         | CREATE TABLE t2 (intCol2 INT) using parquet;
+         | INSERT INTO t VALUES (0), (1);
+         | INSERT INTO t2 VALUES (2), (3);
+         | FOR SELECT * FROM t ORDER BY intCol DO
+         | FOR SELECT * FROM t2 ORDER BY intCol2 DESC DO
+         | SELECT intCol;
+         | SELECT intCol2;
+         | END FOR;
+         | END FOR;
+         |END
+         |""".stripMargin
+
+     val noRows = Seq.empty[Row]
+     // Result of a SELECT that returns a single one-column row.
+     def single(value: Int): Seq[Row] = Seq(Row(value))
+
+     val expectedResults = Seq(
+       noRows, // create table
+       noRows, // create table
+       noRows, // insert
+       noRows, // insert
+       // outer row intCol = 0; inner rows arrive in descending intCol2 order
+       single(0), // select intCol
+       single(3), // select intCol2
+       single(0), // select intCol
+       single(2), // select intCol2
+       noRows, // drop local var
+       // outer row intCol = 1
+       single(1), // select intCol
+       single(3), // select intCol2
+       single(1), // select intCol
+       single(2), // select intCol2
+       noRows, // drop local var
+       noRows // drop outer var
+     )
+     verifySqlScriptResult(script, expectedResults)
+   }
+ }
+
+ // ignored until loops are fixed to support empty bodies
+ // A FOR over a table that never receives rows, nested in a REPEAT whose UNTIL
+ // condition (1 = 1) is true on its first evaluation.
+ ignore("for statement - no variable - nested - empty result set") {
+   withTable("t") {
+     val sqlScript =
+       """
+         |BEGIN
+         | CREATE TABLE t (intCol INT) using parquet;
+         | REPEAT
+         | FOR SELECT * FROM t ORDER BY intCol DO
+         | SELECT intCol;
+         | END FOR;
+         | UNTIL 1 = 1
+         | END REPEAT;
+         |END
+         |""".stripMargin
+
+     // The script has no DECLARE, INSERT or SET and the FOR query returns no rows, so
+     // the only result listed is the one for CREATE TABLE.
+     // NOTE(review): this test is ignored, so the expectation is unverified; confirm it
+     // against the interpreter when enabling the test.
+     val expected = Seq(
+       Seq.empty[Row] // create table
+     )
+     verifySqlScriptResult(sqlScript, expected)
+   }
+ }
+
+ // ITERATE targets the outer label from inside the inner FOR. The test expects a single
+ // intCol2 result (3, the first row in DESC order) per outer row and no `SELECT 1` result.
+ test("for statement - no variable - nested - iterate outer loop") {
+   withTable("t", "t2") {
+     val script =
+       """
+         |BEGIN
+         | CREATE TABLE t (intCol INT) using parquet;
+         | CREATE TABLE t2 (intCol2 INT) using parquet;
+         | INSERT INTO t VALUES (0), (1);
+         | INSERT INTO t2 VALUES (2), (3);
+         | lbl1: FOR SELECT * FROM t ORDER BY intCol DO
+         | lbl2: FOR SELECT * FROM t2 ORDER BY intCol2 DESC DO
+         | SELECT intCol2;
+         | ITERATE lbl1;
+         | SELECT 1;
+         | END FOR;
+         | END FOR;
+         |END
+         |""".stripMargin
+
+     val noRows = Seq.empty[Row]
+     // create table, create table, insert, insert
+     val setup = Seq.fill(4)(noRows)
+     // select intCol2, once for each of the two outer rows
+     val selected = Seq.fill(2)(Seq(Row(3)))
+     // drop outer var
+     val cleanup = Seq(noRows)
+
+     verifySqlScriptResult(script, setup ++ selected ++ cleanup)
+   }
+ }
+
+ // LEAVE targets the outer label from inside the inner FOR. The test expects exactly one
+ // intCol2 result (3, the first row in DESC order) and nothing after it.
+ test("for statement - no variable - nested - leave outer loop") {
+   withTable("t", "t2") {
+     val script =
+       """
+         |BEGIN
+         | CREATE TABLE t (intCol INT) using parquet;
+         | CREATE TABLE t2 (intCol2 INT) using parquet;
+         | INSERT INTO t VALUES (0), (1);
+         | INSERT INTO t2 VALUES (2), (3);
+         | lbl1: FOR SELECT * FROM t ORDER BY intCol DO
+         | lbl2: FOR SELECT * FROM t2 ORDER BY intCol2 DESC DO
+         | SELECT intCol2;
+         | LEAVE lbl1;
+         | SELECT 1;
+         | END FOR;
+         | END FOR;
+         |END
+         |""".stripMargin
+
+     val noRows = Seq.empty[Row]
+     // create table, create table, insert, insert
+     val setup = Seq.fill(4)(noRows)
+     // select intCol2
+     val selected = Seq(Seq(Row(3)))
+
+     verifySqlScriptResult(script, setup ++ selected)
+   }
+ }
+
+ // LEAVE targets the inner label. The test expects one intCol2 result (3, the first row
+ // in DESC order) for each of the two outer rows and no `SELECT 1` result.
+ test("for statement - no variable - nested - leave inner loop") {
+   withTable("t", "t2") {
+     val script =
+       """
+         |BEGIN
+         | CREATE TABLE t (intCol INT) using parquet;
+         | CREATE TABLE t2 (intCol2 INT) using parquet;
+         | INSERT INTO t VALUES (0), (1);
+         | INSERT INTO t2 VALUES (2), (3);
+         | lbl1: FOR SELECT * FROM t ORDER BY intCol DO
+         | lbl2: FOR SELECT * FROM t2 ORDER BY intCol2 DESC DO
+         | SELECT intCol2;
+         | LEAVE lbl2;
+         | SELECT 1;
+         | END FOR;
+         | END FOR;
+         |END
+         |""".stripMargin
+
+     val noRows = Seq.empty[Row]
+     // create table, create table, insert, insert
+     val setup = Seq.fill(4)(noRows)
+     // select intCol2, once for each of the two outer rows
+     val selected = Seq.fill(2)(Seq(Row(3)))
+     // drop outer var
+     val cleanup = Seq(noRows)
+
+     verifySqlScriptResult(script, setup ++ selected ++ cleanup)
+   }
+ }
}