From 2c2c0e05512731e73fda98715551c021f12e24e6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Du=C5=A1an=20Ti=C5=A1ma?= Date: Thu, 28 Nov 2024 20:47:12 +0800 Subject: [PATCH] [SPARK-48356][SQL] Support for FOR statement MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ### What changes were proposed in this pull request? In this PR, support for FOR statement in SQL scripting is introduced. Examples: ``` FOR row AS SELECT * FROM t DO SELECT row.intCol; END FOR; ``` ``` FOR SELECT * FROM t DO SELECT intCol; END FOR; ``` Implementation notes: As local variables for SQL scripting are currently a work in progress, session variables are used to simulate them. When FOR begins executing, session variables are declared for each column in the result set, and optionally for the for variable if it is present ("row" in the example above). On each iteration, these variables are overwritten with the values from the row currently being iterated. The variables are dropped upon loop completion. This means that if a session variable which matches the name of a column in the result set already exists, the for statement will drop that variable after completion. If that variable would be referenced after the for statement, the script would fail as the variable would not exist. This limitation is already present in the current iteration of SQL scripting, and will be fixed once local variables are introduced. Also, with local variables the implementation of for statement will be much simpler. Grammar/parser changes: `forStatement` grammar rule `visitForStatement` rule visitor `ForStatement` logical operator ### Why are the changes needed? FOR statement is a part of SQL scripting control flow logic. ### Does this PR introduce _any_ user-facing change? No ### How was this patch tested? New tests are introduced to all of the three scripting test suites: `SqlScriptingParserSuite`, `SqlScriptingExecutionNodeSuite` and `SqlScriptingInterpreterSuite`. 
### Was this patch authored or co-authored using generative AI tooling? No Closes #48794 from dusantism-db/scripting-for-loop. Authored-by: Dušan Tišma Signed-off-by: Wenchen Fan --- .../sql/catalyst/parser/SqlBaseParser.g4 | 5 + .../sql/catalyst/parser/AstBuilder.scala | 46 +- .../logical/SqlScriptingLogicalPlans.scala | 28 + .../parser/SqlScriptingParserSuite.scala | 268 ++++- .../scripting/SqlScriptingExecutionNode.scala | 229 +++- .../scripting/SqlScriptingInterpreter.scala | 13 +- .../SqlScriptingExecutionNodeSuite.scala | 389 +++++- .../SqlScriptingInterpreterSuite.scala | 1054 +++++++++++++++++ 8 files changed, 2006 insertions(+), 26 deletions(-) diff --git a/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4 b/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4 index 2e3235d6f932c..4b7b4634b74b2 100644 --- a/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4 +++ b/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4 @@ -70,6 +70,7 @@ compoundStatement | leaveStatement | iterateStatement | loopStatement + | forStatement ; setStatementWithOptionalVarKeyword @@ -111,6 +112,10 @@ loopStatement : beginLabel? LOOP compoundBody END LOOP endLabel? ; +forStatement + : beginLabel? FOR (multipartIdentifier AS)? query DO compoundBody END FOR endLabel? 
+ ; + singleStatement : (statement|setResetStatement) SEMICOLON* EOF ; diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index 6a9a97d0f5c8c..d558689a5c196 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -226,6 +226,8 @@ class AstBuilder extends DataTypeAstBuilder visitSearchedCaseStatementImpl(searchedCaseContext, labelCtx) case simpleCaseContext: SimpleCaseStatementContext => visitSimpleCaseStatementImpl(simpleCaseContext, labelCtx) + case forStatementContext: ForStatementContext => + visitForStatementImpl(forStatementContext, labelCtx) case stmt => visit(stmt).asInstanceOf[CompoundPlanStatement] } } else { @@ -347,28 +349,48 @@ class AstBuilder extends DataTypeAstBuilder RepeatStatement(condition, body, Some(labelText)) } + private def visitForStatementImpl( + ctx: ForStatementContext, + labelCtx: SqlScriptingLabelContext): ForStatement = { + val labelText = labelCtx.enterLabeledScope(Option(ctx.beginLabel()), Option(ctx.endLabel())) + + val queryCtx = ctx.query() + val query = withOrigin(queryCtx) { + SingleStatement(visitQuery(queryCtx)) + } + val varName = Option(ctx.multipartIdentifier()).map(_.getText) + val body = visitCompoundBodyImpl(ctx.compoundBody(), None, allowVarDeclare = false, labelCtx) + labelCtx.exitLabeledScope(Option(ctx.beginLabel())) + + ForStatement(query, varName, body, Some(labelText)) + } + private def leaveOrIterateContextHasLabel( ctx: RuleContext, label: String, isIterate: Boolean): Boolean = { ctx match { case c: BeginEndCompoundBlockContext - if Option(c.beginLabel()).isDefined && - c.beginLabel().multipartIdentifier().getText.toLowerCase(Locale.ROOT).equals(label) => - if (isIterate) { + if Option(c.beginLabel()).exists { b => + 
b.multipartIdentifier().getText.toLowerCase(Locale.ROOT).equals(label) + } => if (isIterate) { throw SqlScriptingErrors.invalidIterateLabelUsageForCompound(CurrentOrigin.get, label) } true case c: WhileStatementContext - if Option(c.beginLabel()).isDefined && - c.beginLabel().multipartIdentifier().getText.toLowerCase(Locale.ROOT).equals(label) - => true + if Option(c.beginLabel()).exists { b => + b.multipartIdentifier().getText.toLowerCase(Locale.ROOT).equals(label) + } => true case c: RepeatStatementContext - if Option(c.beginLabel()).isDefined && - c.beginLabel().multipartIdentifier().getText.toLowerCase(Locale.ROOT).equals(label) - => true + if Option(c.beginLabel()).exists { b => + b.multipartIdentifier().getText.toLowerCase(Locale.ROOT).equals(label) + } => true case c: LoopStatementContext - if Option(c.beginLabel()).isDefined && - c.beginLabel().multipartIdentifier().getText.toLowerCase(Locale.ROOT).equals(label) - => true + if Option(c.beginLabel()).exists { b => + b.multipartIdentifier().getText.toLowerCase(Locale.ROOT).equals(label) + } => true + case c: ForStatementContext + if Option(c.beginLabel()).exists { b => + b.multipartIdentifier().getText.toLowerCase(Locale.ROOT).equals(label) + } => true case _ => false } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/SqlScriptingLogicalPlans.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/SqlScriptingLogicalPlans.scala index e6018e5e57b9c..4faf1f5d26672 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/SqlScriptingLogicalPlans.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/SqlScriptingLogicalPlans.scala @@ -267,3 +267,31 @@ case class LoopStatement( LoopStatement(newChildren(0).asInstanceOf[CompoundBody], label) } } + +/** + * Logical operator for FOR statement. + * @param query Query which is executed once, then it's result set is iterated on, row by row. 
+ * @param variableName Name of variable which is used to access the current row during iteration. + * @param body Compound body is a collection of statements that are executed for each row in + * the result set of the query. + * @param label An optional label for the loop which is unique amongst all labels for statements + * within which the FOR statement is contained. + * If an end label is specified it must match the beginning label. + * The label can be used to LEAVE or ITERATE the loop. + */ +case class ForStatement( + query: SingleStatement, + variableName: Option[String], + body: CompoundBody, + label: Option[String]) extends CompoundPlanStatement { + + override def output: Seq[Attribute] = Seq.empty + + override def children: Seq[LogicalPlan] = Seq(query, body) + + override protected def withNewChildrenInternal( + newChildren: IndexedSeq[LogicalPlan]): LogicalPlan = newChildren match { + case IndexedSeq(query: SingleStatement, body: CompoundBody) => + ForStatement(query, variableName, body, label) + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/SqlScriptingParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/SqlScriptingParserSuite.scala index 3bb84f603dc67..ab647f83b42a4 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/SqlScriptingParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/SqlScriptingParserSuite.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst.parser import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.expressions.{Alias, EqualTo, Expression, In, Literal, ScalarSubquery} import org.apache.spark.sql.catalyst.plans.SQLHelper -import org.apache.spark.sql.catalyst.plans.logical.{CaseStatement, CompoundBody, CreateVariable, IfElseStatement, IterateStatement, LeaveStatement, LoopStatement, Project, RepeatStatement, SingleStatement, WhileStatement} +import 
org.apache.spark.sql.catalyst.plans.logical.{CaseStatement, CompoundBody, CreateVariable, ForStatement, IfElseStatement, IterateStatement, LeaveStatement, LoopStatement, Project, RepeatStatement, SingleStatement, WhileStatement} import org.apache.spark.sql.errors.DataTypeErrors.toSQLId import org.apache.spark.sql.exceptions.SqlScriptingException import org.apache.spark.sql.internal.SQLConf @@ -1176,7 +1176,6 @@ class SqlScriptingParserSuite extends SparkFunSuite with SQLHelper { head.asInstanceOf[SingleStatement].getText == "SELECT 42") assert(whileStmt.label.contains("lbl")) - } test("searched case statement") { @@ -1823,6 +1822,25 @@ class SqlScriptingParserSuite extends SparkFunSuite with SQLHelper { parameters = Map("label" -> toSQLId("l_loop"))) } + test("unique label names: nested for loops") { + val sqlScriptText = + """BEGIN + |f_loop: FOR x AS SELECT 1 DO + | f_loop: FOR y AS SELECT 2 DO + | SELECT 1; + | END FOR; + |END FOR; + |END + """.stripMargin + val exception = intercept[SqlScriptingException] { + parsePlan(sqlScriptText).asInstanceOf[CompoundBody] + } + checkError( + exception = exception, + condition = "LABEL_ALREADY_EXISTS", + parameters = Map("label" -> toSQLId("f_loop"))) + } + test("unique label names: begin-end block on the same level") { val sqlScriptText = """BEGIN @@ -1858,10 +1876,13 @@ class SqlScriptingParserSuite extends SparkFunSuite with SQLHelper { | SELECT 4; |UNTIL 1=1 |END REPEAT; + |lbl: FOR x AS SELECT 1 DO + | SELECT 5; + |END FOR; |END """.stripMargin val tree = parsePlan(sqlScriptText).asInstanceOf[CompoundBody] - assert(tree.collection.length == 4) + assert(tree.collection.length == 5) assert(tree.collection.head.isInstanceOf[CompoundBody]) assert(tree.collection.head.asInstanceOf[CompoundBody].label.get == "lbl") assert(tree.collection(1).isInstanceOf[WhileStatement]) @@ -1870,6 +1891,8 @@ class SqlScriptingParserSuite extends SparkFunSuite with SQLHelper { assert(tree.collection(2).asInstanceOf[LoopStatement].label.get == 
"lbl") assert(tree.collection(3).isInstanceOf[RepeatStatement]) assert(tree.collection(3).asInstanceOf[RepeatStatement].label.get == "lbl") + assert(tree.collection(4).isInstanceOf[ForStatement]) + assert(tree.collection(4).asInstanceOf[ForStatement].label.get == "lbl") } test("unique label names: nested labeled scope statements") { @@ -1879,7 +1902,9 @@ class SqlScriptingParserSuite extends SparkFunSuite with SQLHelper { | lbl_1: WHILE 1=1 DO | lbl_2: LOOP | lbl_3: REPEAT - | SELECT 4; + | lbl_4: FOR x AS SELECT 1 DO + | SELECT 4; + | END FOR; | UNTIL 1=1 | END REPEAT; | END LOOP; @@ -1905,6 +1930,241 @@ class SqlScriptingParserSuite extends SparkFunSuite with SQLHelper { // Repeat statement val repeatStatement = loopStatement.body.collection.head.asInstanceOf[RepeatStatement] assert(repeatStatement.label.get == "lbl_3") + // For statement + val forStatement = repeatStatement.body.collection.head.asInstanceOf[ForStatement] + assert(forStatement.label.get == "lbl_4") + } + + test("for statement") { + val sqlScriptText = + """ + |BEGIN + | lbl: FOR x AS SELECT 5 DO + | SELECT 1; + | END FOR; + |END""".stripMargin + val tree = parsePlan(sqlScriptText).asInstanceOf[CompoundBody] + assert(tree.collection.length == 1) + assert(tree.collection.head.isInstanceOf[ForStatement]) + + val forStmt = tree.collection.head.asInstanceOf[ForStatement] + assert(forStmt.query.isInstanceOf[SingleStatement]) + assert(forStmt.query.getText == "SELECT 5") + assert(forStmt.variableName.contains("x")) + + assert(forStmt.body.isInstanceOf[CompoundBody]) + assert(forStmt.body.collection.length == 1) + assert(forStmt.body.collection.head.isInstanceOf[SingleStatement]) + assert(forStmt.body.collection.head.asInstanceOf[SingleStatement].getText == "SELECT 1") + + assert(forStmt.label.contains("lbl")) + } + + test("for statement - no label") { + val sqlScriptText = + """ + |BEGIN + | FOR x AS SELECT 5 DO + | SELECT 1; + | END FOR; + |END""".stripMargin + val tree = 
parsePlan(sqlScriptText).asInstanceOf[CompoundBody] + assert(tree.collection.length == 1) + assert(tree.collection.head.isInstanceOf[ForStatement]) + + val forStmt = tree.collection.head.asInstanceOf[ForStatement] + assert(forStmt.query.isInstanceOf[SingleStatement]) + assert(forStmt.query.getText == "SELECT 5") + assert(forStmt.variableName.contains("x")) + + assert(forStmt.body.isInstanceOf[CompoundBody]) + assert(forStmt.body.collection.length == 1) + assert(forStmt.body.collection.head.isInstanceOf[SingleStatement]) + assert(forStmt.body.collection.head.asInstanceOf[SingleStatement].getText == "SELECT 1") + + // when not explicitly set, label is random UUID + assert(forStmt.label.isDefined) + } + + test("for statement - with complex subquery") { + val sqlScriptText = + """ + |BEGIN + | lbl: FOR x AS SELECT c1, c2 FROM t WHERE c2 = 5 GROUP BY c1 ORDER BY c1 DO + | SELECT x.c1; + | SELECT x.c2; + | END FOR; + |END""".stripMargin + val tree = parsePlan(sqlScriptText).asInstanceOf[CompoundBody] + assert(tree.collection.length == 1) + assert(tree.collection.head.isInstanceOf[ForStatement]) + + val forStmt = tree.collection.head.asInstanceOf[ForStatement] + assert(forStmt.query.isInstanceOf[SingleStatement]) + assert(forStmt.query.getText == "SELECT c1, c2 FROM t WHERE c2 = 5 GROUP BY c1 ORDER BY c1") + assert(forStmt.variableName.contains("x")) + + assert(forStmt.body.isInstanceOf[CompoundBody]) + assert(forStmt.body.collection.length == 2) + assert(forStmt.body.collection.head.isInstanceOf[SingleStatement]) + assert(forStmt.body.collection.head.asInstanceOf[SingleStatement].getText == "SELECT x.c1") + assert(forStmt.body.collection(1).isInstanceOf[SingleStatement]) + assert(forStmt.body.collection(1).asInstanceOf[SingleStatement].getText == "SELECT x.c2") + + assert(forStmt.label.contains("lbl")) + } + + test("for statement - nested") { + val sqlScriptText = + """ + |BEGIN + | lbl1: FOR i AS SELECT 1 DO + | lbl2: FOR j AS SELECT 2 DO + | SELECT i + j; + | END FOR 
lbl2; + | END FOR lbl1; + |END""".stripMargin + val tree = parsePlan(sqlScriptText).asInstanceOf[CompoundBody] + assert(tree.collection.length == 1) + assert(tree.collection.head.isInstanceOf[ForStatement]) + + val forStmt = tree.collection.head.asInstanceOf[ForStatement] + assert(forStmt.query.isInstanceOf[SingleStatement]) + assert(forStmt.query.getText == "SELECT 1") + assert(forStmt.variableName.contains("i")) + assert(forStmt.label.contains("lbl1")) + + assert(forStmt.body.isInstanceOf[CompoundBody]) + assert(forStmt.body.collection.length == 1) + assert(forStmt.body.collection.head.isInstanceOf[ForStatement]) + val nestedForStmt = forStmt.body.collection.head.asInstanceOf[ForStatement] + + assert(nestedForStmt.query.isInstanceOf[SingleStatement]) + assert(nestedForStmt.query.getText == "SELECT 2") + assert(nestedForStmt.variableName.contains("j")) + assert(nestedForStmt.label.contains("lbl2")) + + assert(nestedForStmt.body.isInstanceOf[CompoundBody]) + assert(nestedForStmt.body.collection.length == 1) + assert(nestedForStmt.body.collection.head.isInstanceOf[SingleStatement]) + assert(nestedForStmt.body.collection. 
+ head.asInstanceOf[SingleStatement].getText == "SELECT i + j") + } + + test("for statement - no variable") { + val sqlScriptText = + """ + |BEGIN + | lbl: FOR SELECT 5 DO + | SELECT 1; + | END FOR; + |END""".stripMargin + val tree = parsePlan(sqlScriptText).asInstanceOf[CompoundBody] + assert(tree.collection.length == 1) + assert(tree.collection.head.isInstanceOf[ForStatement]) + + val forStmt = tree.collection.head.asInstanceOf[ForStatement] + assert(forStmt.query.isInstanceOf[SingleStatement]) + assert(forStmt.query.getText == "SELECT 5") + assert(forStmt.variableName.isEmpty) + + assert(forStmt.body.isInstanceOf[CompoundBody]) + assert(forStmt.body.collection.length == 1) + assert(forStmt.body.collection.head.isInstanceOf[SingleStatement]) + assert(forStmt.body.collection.head.asInstanceOf[SingleStatement].getText == "SELECT 1") + + assert(forStmt.label.contains("lbl")) + } + + test("for statement - no variable - no label") { + val sqlScriptText = + """ + |BEGIN + | FOR SELECT 5 DO + | SELECT 1; + | END FOR; + |END""".stripMargin + val tree = parsePlan(sqlScriptText).asInstanceOf[CompoundBody] + assert(tree.collection.length == 1) + assert(tree.collection.head.isInstanceOf[ForStatement]) + + val forStmt = tree.collection.head.asInstanceOf[ForStatement] + assert(forStmt.query.isInstanceOf[SingleStatement]) + assert(forStmt.query.getText == "SELECT 5") + assert(forStmt.variableName.isEmpty) + + assert(forStmt.body.isInstanceOf[CompoundBody]) + assert(forStmt.body.collection.length == 1) + assert(forStmt.body.collection.head.isInstanceOf[SingleStatement]) + assert(forStmt.body.collection.head.asInstanceOf[SingleStatement].getText == "SELECT 1") + + // when not explicitly set, label is random UUID + assert(forStmt.label.isDefined) + } + + test("for statement - no variable - with complex subquery") { + val sqlScriptText = + """ + |BEGIN + | lbl: FOR SELECT c1, c2 FROM t WHERE c2 = 5 GROUP BY c1 ORDER BY c1 DO + | SELECT 1; + | SELECT 2; + | END FOR; + 
|END""".stripMargin + val tree = parsePlan(sqlScriptText).asInstanceOf[CompoundBody] + assert(tree.collection.length == 1) + assert(tree.collection.head.isInstanceOf[ForStatement]) + + val forStmt = tree.collection.head.asInstanceOf[ForStatement] + assert(forStmt.query.isInstanceOf[SingleStatement]) + assert(forStmt.query.getText == "SELECT c1, c2 FROM t WHERE c2 = 5 GROUP BY c1 ORDER BY c1") + assert(forStmt.variableName.isEmpty) + + assert(forStmt.body.isInstanceOf[CompoundBody]) + assert(forStmt.body.collection.length == 2) + assert(forStmt.body.collection.head.isInstanceOf[SingleStatement]) + assert(forStmt.body.collection.head.asInstanceOf[SingleStatement].getText == "SELECT 1") + assert(forStmt.body.collection(1).isInstanceOf[SingleStatement]) + assert(forStmt.body.collection(1).asInstanceOf[SingleStatement].getText == "SELECT 2") + + assert(forStmt.label.contains("lbl")) + } + + test("for statement - no variable - nested") { + val sqlScriptText = + """ + |BEGIN + | lbl1: FOR SELECT 1 DO + | lbl2: FOR SELECT 2 DO + | SELECT 3; + | END FOR lbl2; + | END FOR lbl1; + |END""".stripMargin + val tree = parsePlan(sqlScriptText).asInstanceOf[CompoundBody] + assert(tree.collection.length == 1) + assert(tree.collection.head.isInstanceOf[ForStatement]) + + val forStmt = tree.collection.head.asInstanceOf[ForStatement] + assert(forStmt.query.isInstanceOf[SingleStatement]) + assert(forStmt.query.getText == "SELECT 1") + assert(forStmt.variableName.isEmpty) + assert(forStmt.label.contains("lbl1")) + + assert(forStmt.body.isInstanceOf[CompoundBody]) + assert(forStmt.body.collection.length == 1) + assert(forStmt.body.collection.head.isInstanceOf[ForStatement]) + val nestedForStmt = forStmt.body.collection.head.asInstanceOf[ForStatement] + + assert(nestedForStmt.query.isInstanceOf[SingleStatement]) + assert(nestedForStmt.query.getText == "SELECT 2") + assert(nestedForStmt.variableName.isEmpty) + assert(nestedForStmt.label.contains("lbl2")) + + 
assert(nestedForStmt.body.isInstanceOf[CompoundBody]) + assert(nestedForStmt.body.collection.length == 1) + assert(nestedForStmt.body.collection.head.isInstanceOf[SingleStatement]) + assert(nestedForStmt.body.collection. + head.asInstanceOf[SingleStatement].getText == "SELECT 3") } // Helper methods diff --git a/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala index 94284ec514f55..e3559e8f18ae2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala @@ -17,12 +17,14 @@ package org.apache.spark.sql.scripting +import java.util + import org.apache.spark.SparkException import org.apache.spark.internal.Logging -import org.apache.spark.sql.{DataFrame, Dataset, SparkSession} -import org.apache.spark.sql.catalyst.analysis.NameParameterizedQuery -import org.apache.spark.sql.catalyst.expressions.Expression -import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.{DataFrame, Dataset, Row, SparkSession} +import org.apache.spark.sql.catalyst.analysis.{NameParameterizedQuery, UnresolvedAttribute, UnresolvedIdentifier} +import org.apache.spark.sql.catalyst.expressions.{Alias, CreateArray, CreateMap, CreateNamedStruct, Expression, Literal} +import org.apache.spark.sql.catalyst.plans.logical.{CreateVariable, DefaultValueExpression, DropVariable, LogicalPlan, OneRowRelation, Project, SetVariable} import org.apache.spark.sql.catalyst.trees.{Origin, WithOrigin} import org.apache.spark.sql.errors.SqlScriptingErrors import org.apache.spark.sql.types.BooleanType @@ -662,3 +664,222 @@ class LoopStatementExec( body.reset() } } + +/** + * Executable node for ForStatement. + * @param query Executable node for the query. 
+ * @param variableName Name of variable used for accessing current row during iteration. + * @param body Executable node for the body. + * @param label Label set to ForStatement by user or None otherwise. + * @param session Spark session that SQL script is executed within. + */ +class ForStatementExec( + query: SingleStatementExec, + variableName: Option[String], + body: CompoundBodyExec, + val label: Option[String], + session: SparkSession) extends NonLeafStatementExec { + + private object ForState extends Enumeration { + val VariableAssignment, Body, VariableCleanup = Value + } + private var state = ForState.VariableAssignment + private var areVariablesDeclared = false + + // map of all variables created internally by the for statement + // (variableName -> variableExpression) + private var variablesMap: Map[String, Expression] = Map() + + // compound body used for dropping variables while in ForState.VariableAssignment + private var dropVariablesExec: CompoundBodyExec = null + + private var queryResult: util.Iterator[Row] = _ + private var isResultCacheValid = false + private def cachedQueryResult(): util.Iterator[Row] = { + if (!isResultCacheValid) { + queryResult = query.buildDataFrame(session).toLocalIterator() + query.isExecuted = true + isResultCacheValid = true + } + queryResult + } + + /** + * For can be interrupted by LeaveStatementExec + */ + private var interrupted: Boolean = false + + private lazy val treeIterator: Iterator[CompoundStatementExec] = + new Iterator[CompoundStatementExec] { + + override def hasNext: Boolean = !interrupted && (state match { + case ForState.VariableAssignment => cachedQueryResult().hasNext + case ForState.Body => true + case ForState.VariableCleanup => dropVariablesExec.getTreeIterator.hasNext + }) + + override def next(): CompoundStatementExec = state match { + + case ForState.VariableAssignment => + variablesMap = createVariablesMapFromRow(cachedQueryResult().next()) + + if (!areVariablesDeclared) { + // create and 
execute declare var statements + variablesMap.keys.toSeq + .map(colName => createDeclareVarExec(colName, variablesMap(colName))) + .foreach(declareVarExec => declareVarExec.buildDataFrame(session).collect()) + areVariablesDeclared = true + } + + // create and execute set var statements + variablesMap.keys.toSeq + .map(colName => createSetVarExec(colName, variablesMap(colName))) + .foreach(setVarExec => setVarExec.buildDataFrame(session).collect()) + + state = ForState.Body + body.reset() + next() + + case ForState.Body => + val retStmt = body.getTreeIterator.next() + + // Handle LEAVE or ITERATE statement if it has been encountered. + retStmt match { + case leaveStatementExec: LeaveStatementExec if !leaveStatementExec.hasBeenMatched => + if (label.contains(leaveStatementExec.label)) { + leaveStatementExec.hasBeenMatched = true + } + interrupted = true + // If this for statement encounters LEAVE, it will either not be executed + // again, or it will be reset before being executed. + // In either case, variables will not + // be dropped normally, from ForState.VariableCleanup, so we drop them here. + dropVars() + return retStmt + case iterStatementExec: IterateStatementExec if !iterStatementExec.hasBeenMatched => + if (label.contains(iterStatementExec.label)) { + iterStatementExec.hasBeenMatched = true + } else { + // if an outer loop is being iterated, this for statement will either not be + // executed again, or it will be reset before being executed. + // In either case, variables will not + // be dropped normally, from ForState.VariableCleanup, so we drop them here. + dropVars() + } + switchStateFromBody() + return retStmt + case _ => + } + + if (!body.getTreeIterator.hasNext) { + switchStateFromBody() + } + retStmt + + case ForState.VariableCleanup => + dropVariablesExec.getTreeIterator.next() + } + } + + /** + * Recursively creates a Catalyst expression from Scala value.
+ * See https://spark.apache.org/docs/latest/sql-ref-datatypes.html for Spark -> Scala mappings + */ + private def createExpressionFromValue(value: Any): Expression = value match { + case m: Map[_, _] => + // arguments of CreateMap are in the format: (key1, val1, key2, val2, ...) + val mapArgs = m.keys.toSeq.flatMap { key => + Seq(createExpressionFromValue(key), createExpressionFromValue(m(key))) + } + CreateMap(mapArgs, useStringTypeWhenEmpty = false) + + // structs and rows match this case + case s: Row => + // arguments of CreateNamedStruct are in the format: (name1, val1, name2, val2, ...) + val namedStructArgs = s.schema.names.toSeq.flatMap { colName => + val valueExpression = createExpressionFromValue(s.getAs(colName)) + Seq(Literal(colName), valueExpression) + } + CreateNamedStruct(namedStructArgs) + + // arrays match this case + case a: collection.Seq[_] => + val arrayArgs = a.toSeq.map(createExpressionFromValue(_)) + CreateArray(arrayArgs, useStringTypeWhenEmpty = false) + + case _ => Literal(value) + } + + private def createVariablesMapFromRow(row: Row): Map[String, Expression] = { + var variablesMap = row.schema.names.toSeq.map { colName => + colName -> createExpressionFromValue(row.getAs(colName)) + }.toMap + + if (variableName.isDefined) { + val namedStructArgs = variablesMap.keys.toSeq.flatMap { colName => + Seq(Literal(colName), variablesMap(colName)) + } + val forVariable = CreateNamedStruct(namedStructArgs) + variablesMap = variablesMap + (variableName.get -> forVariable) + } + variablesMap + } + + /** + * Create and immediately execute dropVariable exec nodes for all variables in variablesMap. 
+ */ + private def dropVars(): Unit = { + variablesMap.keys.toSeq + .map(colName => createDropVarExec(colName)) + .foreach(dropVarExec => dropVarExec.buildDataFrame(session).collect()) + areVariablesDeclared = false + } + + private def switchStateFromBody(): Unit = { + state = if (cachedQueryResult().hasNext) ForState.VariableAssignment + else { + // create compound body for dropping nodes after execution is complete + dropVariablesExec = new CompoundBodyExec( + variablesMap.keys.toSeq.map(colName => createDropVarExec(colName)) + ) + ForState.VariableCleanup + } + } + + private def createDeclareVarExec(varName: String, variable: Expression): SingleStatementExec = { + val defaultExpression = DefaultValueExpression(Literal(null, variable.dataType), "null") + val declareVariable = CreateVariable( + UnresolvedIdentifier(Seq(varName)), + defaultExpression, + replace = true + ) + new SingleStatementExec(declareVariable, Origin(), Map.empty, isInternal = true) + } + + private def createSetVarExec(varName: String, variable: Expression): SingleStatementExec = { + val projectNamedStruct = Project( + Seq(Alias(variable, varName)()), + OneRowRelation() + ) + val setIdentifierToCurrentRow = + SetVariable(Seq(UnresolvedAttribute(varName)), projectNamedStruct) + new SingleStatementExec(setIdentifierToCurrentRow, Origin(), Map.empty, isInternal = true) + } + + private def createDropVarExec(varName: String): SingleStatementExec = { + val dropVar = DropVariable(UnresolvedIdentifier(Seq(varName)), ifExists = true) + new SingleStatementExec(dropVar, Origin(), Map.empty, isInternal = true) + } + + override def getTreeIterator: Iterator[CompoundStatementExec] = treeIterator + + override def reset(): Unit = { + state = ForState.VariableAssignment + isResultCacheValid = false + variablesMap = Map() + areVariablesDeclared = false + dropVariablesExec = null + interrupted = false + body.reset() + } +} diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreter.scala b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreter.scala index 387ae36b881f4..a3dc3d4599314 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreter.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.scripting import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.analysis.UnresolvedIdentifier import org.apache.spark.sql.catalyst.expressions.Expression -import org.apache.spark.sql.catalyst.plans.logical.{CaseStatement, CompoundBody, CompoundPlanStatement, CreateVariable, DropVariable, IfElseStatement, IterateStatement, LeaveStatement, LogicalPlan, LoopStatement, RepeatStatement, SingleStatement, WhileStatement} +import org.apache.spark.sql.catalyst.plans.logical.{CaseStatement, CompoundBody, CompoundPlanStatement, CreateVariable, DropVariable, ForStatement, IfElseStatement, IterateStatement, LeaveStatement, LogicalPlan, LoopStatement, RepeatStatement, SingleStatement, WhileStatement} import org.apache.spark.sql.catalyst.trees.Origin /** @@ -145,6 +145,17 @@ case class SqlScriptingInterpreter(session: SparkSession) { .asInstanceOf[CompoundBodyExec] new LoopStatementExec(bodyExec, label) + case ForStatement(query, variableNameOpt, body, label) => + val queryExec = + new SingleStatementExec( + query.parsedPlan, + query.origin, + args, + isInternal = false) + val bodyExec = + transformTreeIntoExecutable(body, args).asInstanceOf[CompoundBodyExec] + new ForStatementExec(queryExec, variableNameOpt, bodyExec, label, session) + case leaveStatement: LeaveStatement => new LeaveStatementExec(leaveStatement.label) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNodeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNodeSuite.scala index 
4874ea3d2795f..a997b5beadd34 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNodeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNodeSuite.scala @@ -18,11 +18,12 @@ package org.apache.spark.sql.scripting import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.{DataFrame, Row, SparkSession} import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, Literal} -import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, OneRowRelation, Project} +import org.apache.spark.sql.catalyst.plans.logical.{DropVariable, LeafNode, OneRowRelation, Project} import org.apache.spark.sql.catalyst.trees.Origin import org.apache.spark.sql.test.SharedSparkSession +import org.apache.spark.sql.types.{IntegerType, StructField, StructType} /** * Unit tests for execution nodes from SqlScriptingExecutionNode.scala. @@ -82,9 +83,9 @@ class SqlScriptingExecutionNodeSuite extends SparkFunSuite with SharedSparkSessi } case class TestRepeat( - condition: TestLoopCondition, - body: CompoundBodyExec, - label: Option[String] = None) + condition: TestLoopCondition, + body: CompoundBodyExec, + label: Option[String] = None) extends RepeatStatementExec(condition, body, label, spark) { private val evaluator = new LoopBooleanConditionEvaluator(condition) @@ -94,6 +95,23 @@ class SqlScriptingExecutionNodeSuite extends SparkFunSuite with SharedSparkSessi statement: LeafStatementExec): Boolean = evaluator.evaluateLoopBooleanCondition() } + case class MockQuery(numberOfRows: Int, columnName: String, description: String) + extends SingleStatementExec( + DummyLogicalPlan(), + Origin(startIndex = Some(0), stopIndex = Some(description.length)), + Map.empty, + isInternal = false) { + override def buildDataFrame(session: SparkSession): DataFrame = { + val data = Seq.range(0, numberOfRows).map(Row(_)) + val schema = List(StructField(columnName, IntegerType)) + + 
spark.createDataFrame( + spark.sparkContext.parallelize(data), + StructType(schema) + ) + } + } + private def extractStatementValue(statement: CompoundStatementExec): String = statement match { case TestLeafStatement(testVal) => testVal @@ -102,6 +120,9 @@ class SqlScriptingExecutionNodeSuite extends SparkFunSuite with SharedSparkSessi case loopStmt: LoopStatementExec => loopStmt.label.get case leaveStmt: LeaveStatementExec => leaveStmt.label case iterateStmt: IterateStatementExec => iterateStmt.label + case forStmt: ForStatementExec => forStmt.label.get + case dropStmt: SingleStatementExec if dropStmt.parsedPlan.isInstanceOf[DropVariable] + => "DropVariable" case _ => fail("Unexpected statement type") } @@ -688,4 +709,362 @@ class SqlScriptingExecutionNodeSuite extends SparkFunSuite with SharedSparkSessi val statements = iter.map(extractStatementValue).toSeq assert(statements === Seq("body1", "lbl")) } + + test("for statement - enters body once") { + val iter = new CompoundBodyExec(Seq( + new ForStatementExec( + query = MockQuery(1, "intCol", "query1"), + variableName = Some("x"), + label = Some("for1"), + session = spark, + body = new CompoundBodyExec(Seq(TestLeafStatement("body"))) + ) + )).getTreeIterator + val statements = iter.map(extractStatementValue).toSeq + assert(statements === Seq( + "body", + "DropVariable", // drop for query var intCol + "DropVariable" // drop for loop var x + )) + } + + test("for statement - enters body with multiple statements multiple times") { + val iter = new CompoundBodyExec(Seq( + new ForStatementExec( + query = MockQuery(2, "intCol", "query1"), + variableName = Some("x"), + label = Some("for1"), + session = spark, + body = new CompoundBodyExec( + Seq(TestLeafStatement("statement1"), TestLeafStatement("statement2")) + ) + ) + )).getTreeIterator + val statements = iter.map(extractStatementValue).toSeq + assert(statements === Seq( + "statement1", + "statement2", + "statement1", + "statement2", + "DropVariable", // drop for query 
var intCol + "DropVariable" // drop for loop var x + )) + } + + test("for statement - empty result") { + val iter = new CompoundBodyExec(Seq( + new ForStatementExec( + query = MockQuery(0, "intCol", "query1"), + variableName = Some("x"), + label = Some("for1"), + session = spark, + body = new CompoundBodyExec(Seq(TestLeafStatement("body1"))) + ) + )).getTreeIterator + val statements = iter.map(extractStatementValue).toSeq + assert(statements === Seq.empty[String]) + } + + test("for statement - nested") { + val iter = new CompoundBodyExec(Seq( + new ForStatementExec( + query = MockQuery(2, "intCol", "query1"), + variableName = Some("x"), + label = Some("for1"), + session = spark, + body = new CompoundBodyExec(Seq( + new ForStatementExec( + query = MockQuery(2, "intCol1", "query2"), + variableName = Some("y"), + label = Some("for2"), + session = spark, + body = new CompoundBodyExec(Seq(TestLeafStatement("body"))) + ) + )) + ) + )).getTreeIterator + val statements = iter.map(extractStatementValue).toSeq + assert(statements === Seq( + "body", + "body", + "DropVariable", // drop for query var intCol1 + "DropVariable", // drop for loop var y + "body", + "body", + "DropVariable", // drop for query var intCol1 + "DropVariable", // drop for loop var y + "DropVariable", // drop for query var intCol + "DropVariable" // drop for loop var x + )) + } + + test("for statement no variable - enters body once") { + val iter = new CompoundBodyExec(Seq( + new ForStatementExec( + query = MockQuery(1, "intCol", "query1"), + variableName = None, + label = Some("for1"), + session = spark, + body = new CompoundBodyExec(Seq(TestLeafStatement("body"))) + ) + )).getTreeIterator + val statements = iter.map(extractStatementValue).toSeq + assert(statements === Seq( + "body", + "DropVariable" // drop for query var intCol + )) + } + + test("for statement no variable - enters body with multiple statements multiple times") { + val iter = new CompoundBodyExec(Seq( + new ForStatementExec( + query = 
MockQuery(2, "intCol", "query1"), + variableName = None, + label = Some("for1"), + session = spark, + body = new CompoundBodyExec(Seq( + TestLeafStatement("statement1"), + TestLeafStatement("statement2"))) + ) + )).getTreeIterator + val statements = iter.map(extractStatementValue).toSeq + assert(statements === Seq( + "statement1", "statement2", "statement1", "statement2", + "DropVariable" // drop for query var intCol + )) + } + + test("for statement no variable - empty result") { + val iter = new CompoundBodyExec(Seq( + new ForStatementExec( + query = MockQuery(0, "intCol", "query1"), + variableName = None, + label = Some("for1"), + session = spark, + body = new CompoundBodyExec(Seq(TestLeafStatement("body1"))) + ) + )).getTreeIterator + val statements = iter.map(extractStatementValue).toSeq + assert(statements === Seq.empty[String]) + } + + test("for statement no variable - nested") { + val iter = new CompoundBodyExec(Seq( + new ForStatementExec( + query = MockQuery(2, "intCol", "query1"), + variableName = None, + label = Some("for1"), + session = spark, + body = new CompoundBodyExec(Seq( + new ForStatementExec( + query = MockQuery(2, "intCol1", "query2"), + variableName = None, + label = Some("for2"), + session = spark, + body = new CompoundBodyExec(Seq(TestLeafStatement("body"))) + ) + )) + ) + )).getTreeIterator + val statements = iter.map(extractStatementValue).toSeq + assert(statements === Seq( + "body", "body", + "DropVariable", // drop for query var intCol1 + "body", "body", + "DropVariable", // drop for query var intCol1 + "DropVariable" // drop for query var intCol + )) + } + + test("for statement - iterate") { + val iter = new CompoundBodyExec(Seq( + new ForStatementExec( + query = MockQuery(2, "intCol", "query1"), + variableName = Some("x"), + label = Some("lbl1"), + session = spark, + body = new CompoundBodyExec(Seq( + TestLeafStatement("statement1"), + new IterateStatementExec("lbl1"), + TestLeafStatement("statement2"))) + ) + )).getTreeIterator + val 
statements = iter.map(extractStatementValue).toSeq + assert(statements === Seq( + "statement1", + "lbl1", + "statement1", + "lbl1", + "DropVariable", // drop for query var intCol + "DropVariable" // drop for loop var x + )) + } + + test("for statement - leave") { + val iter = new CompoundBodyExec(Seq( + new ForStatementExec( + query = MockQuery(2, "intCol", "query1"), + variableName = Some("x"), + label = Some("lbl1"), + session = spark, + body = new CompoundBodyExec(Seq( + TestLeafStatement("statement1"), + new LeaveStatementExec("lbl1"), + TestLeafStatement("statement2"))) + ) + )).getTreeIterator + val statements = iter.map(extractStatementValue).toSeq + assert(statements === Seq("statement1", "lbl1")) + } + + test("for statement - nested - iterate outer loop") { + val iter = new CompoundBodyExec(Seq( + new ForStatementExec( + query = MockQuery(2, "intCol", "query1"), + variableName = Some("x"), + label = Some("lbl1"), + session = spark, + body = new CompoundBodyExec(Seq( + TestLeafStatement("outer_body"), + new ForStatementExec( + query = MockQuery(2, "intCol1", "query2"), + variableName = Some("y"), + label = Some("lbl2"), + session = spark, + body = new CompoundBodyExec(Seq( + TestLeafStatement("body1"), + new IterateStatementExec("lbl1"), + TestLeafStatement("body2"))) + ) + )) + ) + )).getTreeIterator + val statements = iter.map(extractStatementValue).toSeq + assert(statements === Seq( + "outer_body", + "body1", + "lbl1", + "outer_body", + "body1", + "lbl1", + "DropVariable", // drop for query var intCol + "DropVariable" // drop for loop var x + )) + } + + test("for statement - nested - leave outer loop") { + val iter = new CompoundBodyExec(Seq( + new ForStatementExec( + query = MockQuery(2, "intCol", "query1"), + variableName = Some("x"), + label = Some("lbl1"), + session = spark, + body = new CompoundBodyExec(Seq( + new ForStatementExec( + query = MockQuery(2, "intCol", "query2"), + variableName = Some("y"), + label = Some("lbl2"), + session = spark, + 
body = new CompoundBodyExec(Seq( + TestLeafStatement("body1"), + new LeaveStatementExec("lbl1"), + TestLeafStatement("body2"))) + ) + )) + ) + )).getTreeIterator + val statements = iter.map(extractStatementValue).toSeq + assert(statements === Seq("body1", "lbl1")) + } + + test("for statement no variable - iterate") { + val iter = new CompoundBodyExec(Seq( + new ForStatementExec( + query = MockQuery(2, "intCol", "query1"), + variableName = None, + label = Some("lbl1"), + session = spark, + body = new CompoundBodyExec(Seq( + TestLeafStatement("statement1"), + new IterateStatementExec("lbl1"), + TestLeafStatement("statement2"))) + ) + )).getTreeIterator + val statements = iter.map(extractStatementValue).toSeq + assert(statements === Seq( + "statement1", "lbl1", "statement1", "lbl1", + "DropVariable" // drop for query var intCol + )) + } + + test("for statement no variable - leave") { + val iter = new CompoundBodyExec(Seq( + new ForStatementExec( + query = MockQuery(2, "intCol", "query1"), + variableName = None, + label = Some("lbl1"), + session = spark, + body = new CompoundBodyExec(Seq( + TestLeafStatement("statement1"), + new LeaveStatementExec("lbl1"), + TestLeafStatement("statement2"))) + ) + )).getTreeIterator + val statements = iter.map(extractStatementValue).toSeq + assert(statements === Seq("statement1", "lbl1")) + } + + test("for statement no variable - nested - iterate outer loop") { + val iter = new CompoundBodyExec(Seq( + new ForStatementExec( + query = MockQuery(2, "intCol", "query1"), + variableName = None, + label = Some("lbl1"), + session = spark, + body = new CompoundBodyExec(Seq( + TestLeafStatement("outer_body"), + new ForStatementExec( + query = MockQuery(2, "intCol1", "query2"), + variableName = None, + label = Some("lbl2"), + session = spark, + body = new CompoundBodyExec(Seq( + TestLeafStatement("body1"), + new IterateStatementExec("lbl1"), + TestLeafStatement("body2"))) + ) + )) + ) + )).getTreeIterator + val statements = 
iter.map(extractStatementValue).toSeq + assert(statements === Seq( + "outer_body", "body1", "lbl1", "outer_body", "body1", "lbl1", + "DropVariable" // drop for query var intCol + )) + } + + test("for statement no variable - nested - leave outer loop") { + val iter = new CompoundBodyExec(Seq( + new ForStatementExec( + query = MockQuery(2, "intCol", "query1"), + variableName = None, + label = Some("lbl1"), + session = spark, + body = new CompoundBodyExec(Seq( + new ForStatementExec( + query = MockQuery(2, "intCol1", "query2"), + variableName = None, + label = Some("lbl2"), + session = spark, + body = new CompoundBodyExec(Seq( + TestLeafStatement("body1"), + new LeaveStatementExec("lbl1"), + TestLeafStatement("body2"))) + ) + )) + ) + )).getTreeIterator + val statements = iter.map(extractStatementValue).toSeq + assert(statements === Seq("body1", "lbl1")) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreterSuite.scala index 177ffc24d180a..71556c5502225 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreterSuite.scala @@ -1550,4 +1550,1058 @@ class SqlScriptingInterpreterSuite extends QueryTest with SharedSparkSession { ) verifySqlScriptResult(sqlScriptText, expected) } + + test("for statement - enters body once") { + withTable("t") { + val sqlScript = + """ + |BEGIN + | CREATE TABLE t (intCol INT, stringCol STRING, doubleCol DOUBLE) using parquet; + | INSERT INTO t VALUES (1, 'first', 1.0); + | FOR row AS SELECT * FROM t DO + | SELECT row.intCol; + | END FOR; + |END + |""".stripMargin + + val expected = Seq( + Seq.empty[Row], // create table + Seq.empty[Row], // insert + Seq(Row(1)), // select row.intCol + Seq.empty[Row], // drop local var + Seq.empty[Row], // drop local var + Seq.empty[Row], // drop 
local var + Seq.empty[Row] // drop local var + ) + verifySqlScriptResult(sqlScript, expected) + } + } + + test("for statement - enters body with multiple statements multiple times") { + withTable("t") { + val sqlScript = + """ + |BEGIN + | CREATE TABLE t (intCol INT, stringCol STRING, doubleCol DOUBLE) using parquet; + | INSERT INTO t VALUES (1, 'first', 1.0); + | INSERT INTO t VALUES (2, 'second', 2.0); + | FOR row AS SELECT * FROM t ORDER BY intCol DO + | SELECT row.intCol; + | SELECT intCol; + | SELECT row.stringCol; + | SELECT stringCol; + | SELECT row.doubleCol; + | SELECT doubleCol; + | END FOR; + |END + |""".stripMargin + + val expected = Seq( + Seq.empty[Row], // create table + Seq.empty[Row], // insert + Seq.empty[Row], // insert + Seq(Row(1)), // select row.intCol + Seq(Row(1)), // select intCol + Seq(Row("first")), // select row.stringCol + Seq(Row("first")), // select stringCol + Seq(Row(1.0)), // select row.doubleCol + Seq(Row(1.0)), // select doubleCol + Seq(Row(2)), // select row.intCol + Seq(Row(2)), // select intCol + Seq(Row("second")), // select row.stringCol + Seq(Row("second")), // select stringCol + Seq(Row(2.0)), // select row.doubleCol + Seq(Row(2.0)), // select doubleCol + Seq.empty[Row], // drop local var + Seq.empty[Row], // drop local var + Seq.empty[Row], // drop local var + Seq.empty[Row] // drop local var + ) + verifySqlScriptResult(sqlScript, expected) + } + } + + test("for statement - sum of column from table") { + withTable("t") { + val sqlScript = + """ + |BEGIN + | DECLARE sumOfCols = 0; + | CREATE TABLE t (intCol INT) using parquet; + | INSERT INTO t VALUES (1), (2), (3), (4); + | FOR row AS SELECT * FROM t DO + | SET sumOfCols = sumOfCols + row.intCol; + | END FOR; + | SELECT sumOfCols; + |END + |""".stripMargin + + val expected = Seq( + Seq.empty[Row], // declare sumOfCols + Seq.empty[Row], // create table + Seq.empty[Row], // insert + Seq.empty[Row], // set sumOfCols + Seq.empty[Row], // set sumOfCols + Seq.empty[Row], // set 
sumOfCols + Seq.empty[Row], // set sumOfCols + Seq.empty[Row], // drop local var + Seq.empty[Row], // drop local var + Seq(Row(10)), // select sumOfCols + Seq.empty[Row] // drop sumOfCols + ) + verifySqlScriptResult(sqlScript, expected) + } + } + + test("for statement - map, struct, array") { + withTable("t") { + val sqlScript = + """ + |BEGIN + | CREATE TABLE t (int_column INT, map_column MAP<STRING, INT>, + | struct_column STRUCT<name: STRING, age: INT>, array_column ARRAY<STRING>); + | INSERT INTO t VALUES + | (1, MAP('a', 1), STRUCT('John', 25), ARRAY('apricot', 'quince')), + | (2, MAP('b', 2), STRUCT('Jane', 30), ARRAY('plum', 'pear')); + | FOR row AS SELECT * FROM t ORDER BY int_column DO + | SELECT row.map_column; + | SELECT map_column; + | SELECT row.struct_column; + | SELECT struct_column; + | SELECT row.array_column; + | SELECT array_column; + | END FOR; + |END + |""".stripMargin + + val expected = Seq( + Seq.empty[Row], // create table + Seq.empty[Row], // insert + Seq(Row(Map("a" -> 1))), // select row.map_column + Seq(Row(Map("a" -> 1))), // select map_column + Seq(Row(Row("John", 25))), // select row.struct_column + Seq(Row(Row("John", 25))), // select struct_column + Seq(Row(Array("apricot", "quince"))), // select row.array_column + Seq(Row(Array("apricot", "quince"))), // select array_column + Seq(Row(Map("b" -> 2))), // select row.map_column + Seq(Row(Map("b" -> 2))), // select map_column + Seq(Row(Row("Jane", 30))), // select row.struct_column + Seq(Row(Row("Jane", 30))), // select struct_column + Seq(Row(Array("plum", "pear"))), // select row.array_column + Seq(Row(Array("plum", "pear"))), // select array_column + Seq.empty[Row], // drop local var + Seq.empty[Row], // drop local var + Seq.empty[Row], // drop local var + Seq.empty[Row], // drop local var + Seq.empty[Row] // drop local var + ) + verifySqlScriptResult(sqlScript, expected) + } + } + + test("for statement - nested struct") { + withTable("t") { + val sqlScript = + """ + |BEGIN + | CREATE TABLE t + | (int_column INT, + |
struct_column STRUCT<num: INT, struct2: STRUCT<struct3: STRUCT<name: STRING>>>); + | INSERT INTO t VALUES + | (1, STRUCT(1, STRUCT(STRUCT("one")))), + | (2, STRUCT(2, STRUCT(STRUCT("two")))); + | FOR row AS SELECT * FROM t ORDER BY int_column DO + | SELECT row.struct_column; + | SELECT struct_column; + | END FOR; + |END + |""".stripMargin + + val expected = Seq( + Seq.empty[Row], // create table + Seq.empty[Row], // insert + Seq(Row(Row(1, Row(Row("one"))))), // select row.struct_column + Seq(Row(Row(1, Row(Row("one"))))), // select struct_column + Seq(Row(Row(2, Row(Row("two"))))), // select row.struct_column + Seq(Row(Row(2, Row(Row("two"))))), // select struct_column + Seq.empty[Row], // drop local var + Seq.empty[Row], // drop local var + Seq.empty[Row] // drop local var + ) + verifySqlScriptResult(sqlScript, expected) + } + } + + test("for statement - nested map") { + withTable("t") { + val sqlScript = + """ + |BEGIN + | CREATE TABLE t (int_column INT, map_column MAP<STRING, MAP<INT, MAP<BOOLEAN, INT>>>); + | INSERT INTO t VALUES + | (1, MAP('a', MAP(1, MAP(false, 10)))), + | (2, MAP('b', MAP(2, MAP(true, 20)))); + | FOR row AS SELECT * FROM t ORDER BY int_column DO + | SELECT row.map_column; + | SELECT map_column; + | END FOR; + |END + |""".stripMargin + + val expected = Seq( + Seq.empty[Row], // create table + Seq.empty[Row], // insert + Seq(Row(Map("a" -> Map(1 -> Map(false -> 10))))), // select row.map_column + Seq(Row(Map("a" -> Map(1 -> Map(false -> 10))))), // select map_column + Seq(Row(Map("b" -> Map(2 -> Map(true -> 20))))), // select row.map_column + Seq(Row(Map("b" -> Map(2 -> Map(true -> 20))))), // select map_column + Seq.empty[Row], // drop local var + Seq.empty[Row], // drop local var + Seq.empty[Row] // drop local var + ) + verifySqlScriptResult(sqlScript, expected) + } + } + + test("for statement - nested array") { + withTable("t") { + val sqlScript = + """ + |BEGIN + | CREATE TABLE t + | (int_column INT, array_column ARRAY<ARRAY<ARRAY<INT>>>); + | INSERT INTO t VALUES + | (1, ARRAY(ARRAY(ARRAY(1, 2), ARRAY(3, 4)), ARRAY(ARRAY(5, 6)))), + | (2,
ARRAY(ARRAY(ARRAY(7, 8), ARRAY(9, 10)), ARRAY(ARRAY(11, 12)))); + | FOR row AS SELECT * FROM t ORDER BY int_column DO + | SELECT row.array_column; + | SELECT array_column; + | END FOR; + |END + |""".stripMargin + + val expected = Seq( + Seq.empty[Row], // create table + Seq.empty[Row], // insert + Seq(Row(Seq(Seq(Seq(1, 2), Seq(3, 4)), Seq(Seq(5, 6))))), // row.array_column + Seq(Row(Seq(Seq(Seq(1, 2), Seq(3, 4)), Seq(Seq(5, 6))))), // array_column + Seq(Row(Array(Seq(Seq(7, 8), Seq(9, 10)), Seq(Seq(11, 12))))), // row.array_column + Seq(Row(Array(Seq(Seq(7, 8), Seq(9, 10)), Seq(Seq(11, 12))))), // array_column + Seq.empty[Row], // drop local var + Seq.empty[Row], // drop local var + Seq.empty[Row] // drop local var + ) + verifySqlScriptResult(sqlScript, expected) + } + } + + test("for statement empty result") { + withTable("t") { + val sqlScript = + """ + |BEGIN + | CREATE TABLE t (intCol INT) using parquet; + | FOR row AS SELECT * FROM t ORDER BY intCol DO + | SELECT row.intCol; + | END FOR; + |END + |""".stripMargin + + val expected = Seq( + Seq.empty[Row] // create table + ) + verifySqlScriptResult(sqlScript, expected) + } + } + + test("for statement iterate") { + withTable("t") { + val sqlScript = + """ + |BEGIN + | CREATE TABLE t (intCol INT, stringCol STRING) using parquet; + | INSERT INTO t VALUES (1, 'first'), (2, 'second'), (3, 'third'), (4, 'fourth'); + | + | lbl: FOR x AS SELECT * FROM t ORDER BY intCol DO + | IF x.intCol = 2 THEN + | ITERATE lbl; + | END IF; + | SELECT stringCol; + | SELECT x.stringCol; + | END FOR; + |END + |""".stripMargin + + val expected = Seq( + Seq.empty[Row], // create table + Seq.empty[Row], // insert + Seq(Row("first")), // select stringCol + Seq(Row("first")), // select x.stringCol + Seq(Row("third")), // select stringCol + Seq(Row("third")), // select x.stringCol + Seq(Row("fourth")), // select stringCol + Seq(Row("fourth")), // select x.stringCol + Seq.empty[Row], // drop local var + Seq.empty[Row], // drop local var + 
Seq.empty[Row] // drop local var + ) + verifySqlScriptResult(sqlScript, expected) + } + } + + test("for statement leave") { + withTable("t") { + val sqlScript = + """ + |BEGIN + | CREATE TABLE t (intCol INT, stringCol STRING) using parquet; + | INSERT INTO t VALUES (1, 'first'), (2, 'second'), (3, 'third'), (4, 'fourth'); + | + | lbl: FOR x AS SELECT * FROM t ORDER BY intCol DO + | IF x.intCol = 3 THEN + | LEAVE lbl; + | END IF; + | SELECT stringCol; + | SELECT x.stringCol; + | END FOR; + |END + |""".stripMargin + + val expected = Seq( + Seq.empty[Row], // create table + Seq.empty[Row], // insert + Seq(Row("first")), // select stringCol + Seq(Row("first")), // select x.stringCol + Seq(Row("second")), // select stringCol + Seq(Row("second")) // select x.stringCol + ) + verifySqlScriptResult(sqlScript, expected) + } + } + + test("for statement - nested - in while") { + withTable("t") { + val sqlScript = + """ + |BEGIN + | DECLARE cnt = 0; + | CREATE TABLE t (intCol INT) using parquet; + | INSERT INTO t VALUES (0); + | WHILE cnt < 2 DO + | SET cnt = cnt + 1; + | FOR x AS SELECT * FROM t ORDER BY intCol DO + | SELECT x.intCol; + | END FOR; + | INSERT INTO t VALUES (cnt); + | END WHILE; + |END + |""".stripMargin + + val expected = Seq( + Seq.empty[Row], // declare cnt + Seq.empty[Row], // create table + Seq.empty[Row], // insert + Seq.empty[Row], // set cnt + Seq(Row(0)), // select intCol + Seq.empty[Row], // insert + Seq.empty[Row], // drop local var + Seq.empty[Row], // drop local var + Seq.empty[Row], // set cnt + Seq(Row(0)), // select intCol + Seq(Row(1)), // select intCol + Seq.empty[Row], // insert + Seq.empty[Row], // drop local var + Seq.empty[Row], // drop local var + Seq.empty[Row] // drop cnt + ) + verifySqlScriptResult(sqlScript, expected) + } + } + + test("for statement - nested - in other for") { + withTable("t", "t2") { + val sqlScript = + """ + |BEGIN + | CREATE TABLE t (intCol INT) using parquet; + | CREATE TABLE t2 (intCol2 INT) using parquet; + | 
INSERT INTO t VALUES (0), (1); + | INSERT INTO t2 VALUES (2), (3); + | FOR x as SELECT * FROM t ORDER BY intCol DO + | FOR y AS SELECT * FROM t2 ORDER BY intCol2 DESC DO + | SELECT x.intCol; + | SELECT intCol; + | SELECT y.intCol2; + | SELECT intCol2; + | END FOR; + | END FOR; + |END + |""".stripMargin + + val expected = Seq( + Seq.empty[Row], // create table + Seq.empty[Row], // create table + Seq.empty[Row], // insert + Seq.empty[Row], // insert + Seq(Row(0)), // select x.intCol + Seq(Row(0)), // select intCol + Seq(Row(3)), // select y.intCol2 + Seq(Row(3)), // select intCol2 + Seq(Row(0)), // select x.intCol + Seq(Row(0)), // select intCol + Seq(Row(2)), // select y.intCol2 + Seq(Row(2)), // select intCol2 + Seq.empty[Row], // drop local var + Seq.empty[Row], // drop local var + Seq(Row(1)), // select x.intCol + Seq(Row(1)), // select intCol + Seq(Row(3)), // select y.intCol2 + Seq(Row(3)), // select intCol2 + Seq(Row(1)), // select x.intCol + Seq(Row(1)), // select intCol + Seq(Row(2)), // select y.intCol2 + Seq(Row(2)), // select intCol2 + Seq.empty[Row], // drop local var + Seq.empty[Row], // drop local var + Seq.empty[Row], // drop outer var + Seq.empty[Row] // drop outer var + ) + verifySqlScriptResult(sqlScript, expected) + } + } + + // ignored until loops are fixed to support empty bodies + ignore("for statement - nested - empty result set") { + withTable("t") { + val sqlScript = + """ + |BEGIN + | CREATE TABLE t (intCol INT) using parquet; + | REPEAT + | FOR x AS SELECT * FROM t ORDER BY intCol DO + | SELECT x.intCol; + | END FOR; + | UNTIL 1 = 1 + | END REPEAT; + |END + |""".stripMargin + + // t is empty, so the FOR body never runs; only CREATE TABLE yields a result + val expected = Seq( + Seq.empty[Row] // create table
+ ) + verifySqlScriptResult(sqlScript, expected) + } + } + + test("for statement - nested - iterate outer loop") { + withTable("t", "t2") { + val sqlScript = + """ + |BEGIN + | CREATE TABLE t (intCol INT) using parquet; + | CREATE TABLE t2 (intCol2 INT) using parquet; + | INSERT INTO t VALUES (0), (1); + | INSERT INTO t2 VALUES (2), (3); + | lbl1: FOR x as SELECT * FROM t ORDER BY intCol DO + | lbl2: FOR y AS SELECT * FROM t2 ORDER BY intCol2 DESC DO + | SELECT y.intCol2; + | SELECT intCol2; + | ITERATE lbl1; + | SELECT 1; + | END FOR; + | END FOR; + |END + |""".stripMargin + + val expected = Seq( + Seq.empty[Row], // create table + Seq.empty[Row], // create table + Seq.empty[Row], // insert + Seq.empty[Row], // insert + Seq(Row(3)), // select y.intCol2 + Seq(Row(3)), // select intCol2 + Seq(Row(3)), // select y.intCol2 + Seq(Row(3)), // select intCol2 + Seq.empty[Row], // drop outer var + Seq.empty[Row] // drop outer var + ) + verifySqlScriptResult(sqlScript, expected) + } + } + + test("for statement - nested - leave outer loop") { + withTable("t", "t2") { + val sqlScript = + """ + |BEGIN + | CREATE TABLE t (intCol INT) using parquet; + | CREATE TABLE t2 (intCol2 INT) using parquet; + | INSERT INTO t VALUES (0), (1); + | INSERT INTO t2 VALUES (2), (3); + | lbl1: FOR x as SELECT * FROM t ORDER BY intCol DO + | lbl2: FOR y AS SELECT * FROM t2 ORDER BY intCol2 DESC DO + | SELECT y.intCol2; + | SELECT intCol2; + | LEAVE lbl1; + | SELECT 1; + | END FOR; + | END FOR; + |END + |""".stripMargin + + val expected = Seq( + Seq.empty[Row], // create table + Seq.empty[Row], // create table + Seq.empty[Row], // insert + Seq.empty[Row], // insert + Seq(Row(3)), // select y.intCol2 + Seq(Row(3)) // select intCol2 + ) + verifySqlScriptResult(sqlScript, expected) + } + } + + test("for statement - nested - leave inner loop") { + withTable("t", "t2") { +
val sqlScript = + """ + |BEGIN + | CREATE TABLE t (intCol INT) using parquet; + | CREATE TABLE t2 (intCol2 INT) using parquet; + | INSERT INTO t VALUES (0), (1); + | INSERT INTO t2 VALUES (2), (3); + | lbl1: FOR x as SELECT * FROM t ORDER BY intCol DO + | lbl2: FOR y AS SELECT * FROM t2 ORDER BY intCol2 DESC DO + | SELECT y.intCol2; + | SELECT intCol2; + | LEAVE lbl2; + | SELECT 1; + | END FOR; + | END FOR; + |END + |""".stripMargin + + val expected = Seq( + Seq.empty[Row], // create table + Seq.empty[Row], // create table + Seq.empty[Row], // insert + Seq.empty[Row], // insert + Seq(Row(3)), // select y.intCol2 + Seq(Row(3)), // select intCol2 + Seq(Row(3)), // select y.intCol2 + Seq(Row(3)), // select intCol2 + Seq.empty[Row], // drop outer var + Seq.empty[Row] // drop outer var + ) + verifySqlScriptResult(sqlScript, expected) + } + } + + test("for statement - no variable - enters body once") { + withTable("t") { + val sqlScript = + """ + |BEGIN + | CREATE TABLE t (intCol INT, stringCol STRING, doubleCol DOUBLE) using parquet; + | INSERT INTO t VALUES (1, 'first', 1.0); + | FOR SELECT * FROM t DO + | SELECT intCol; + | SELECT stringCol; + | SELECT doubleCol; + | END FOR; + |END + |""".stripMargin + + val expected = Seq( + Seq.empty[Row], // create table + Seq.empty[Row], // insert + Seq(Row(1)), // select intCol + Seq(Row("first")), // select stringCol + Seq(Row(1.0)), // select doubleCol + Seq.empty[Row], // drop local var + Seq.empty[Row], // drop local var + Seq.empty[Row] // drop local var + ) + verifySqlScriptResult(sqlScript, expected) + } + } + + test("for statement - no variable - enters body with multiple statements multiple times") { + withTable("t") { + val sqlScript = + """ + |BEGIN + | CREATE TABLE t (intCol INT, stringCol STRING, doubleCol DOUBLE) using parquet; + | INSERT INTO t VALUES (1, 'first', 1.0); + | INSERT INTO t VALUES (2, 'second', 2.0); + | FOR SELECT * FROM t ORDER BY intCol DO + | SELECT intCol; + | SELECT stringCol; + | SELECT 
doubleCol; + | END FOR; + |END + |""".stripMargin + + val expected = Seq( + Seq.empty[Row], // create table + Seq.empty[Row], // insert + Seq.empty[Row], // insert + Seq(Row(1)), // select intCol + Seq(Row("first")), // select stringCol + Seq(Row(1.0)), // select doubleCol + Seq(Row(2)), // select intCol + Seq(Row("second")), // select stringCol + Seq(Row(2.0)), // select doubleCol + Seq.empty[Row], // drop local var + Seq.empty[Row], // drop local var + Seq.empty[Row] // drop local var + ) + verifySqlScriptResult(sqlScript, expected) + } + } + + test("for statement - no variable - sum of column from table") { + withTable("t") { + val sqlScript = + """ + |BEGIN + | DECLARE sumOfCols = 0; + | CREATE TABLE t (intCol INT) using parquet; + | INSERT INTO t VALUES (1), (2), (3), (4); + | FOR SELECT * FROM t DO + | SET sumOfCols = sumOfCols + intCol; + | END FOR; + | SELECT sumOfCols; + |END + |""".stripMargin + + val expected = Seq( + Seq.empty[Row], // declare sumOfCols + Seq.empty[Row], // create table + Seq.empty[Row], // insert + Seq.empty[Row], // set sumOfCols + Seq.empty[Row], // set sumOfCols + Seq.empty[Row], // set sumOfCols + Seq.empty[Row], // set sumOfCols + Seq.empty[Row], // drop local var + Seq(Row(10)), // select sumOfCols + Seq.empty[Row] // drop sumOfCols + ) + verifySqlScriptResult(sqlScript, expected) + } + } + + test("for statement - no variable - map, struct, array") { + withTable("t") { + val sqlScript = + """ + |BEGIN + | CREATE TABLE t (int_column INT, map_column MAP<STRING, INT>, + | struct_column STRUCT<name: STRING, age: INT>, array_column ARRAY<STRING>); + | INSERT INTO t VALUES + | (1, MAP('a', 1), STRUCT('John', 25), ARRAY('apricot', 'quince')), + | (2, MAP('b', 2), STRUCT('Jane', 30), ARRAY('plum', 'pear')); + | FOR SELECT * FROM t ORDER BY int_column DO + | SELECT map_column; + | SELECT struct_column; + | SELECT array_column; + | END FOR; + |END + |""".stripMargin + + val expected = Seq( + Seq.empty[Row], // create table + Seq.empty[Row], // insert + Seq(Row(Map("a" -> 1))), //
select map_column + Seq(Row(Row("John", 25))), // select struct_column + Seq(Row(Array("apricot", "quince"))), // select array_column + Seq(Row(Map("b" -> 2))), // select map_column + Seq(Row(Row("Jane", 30))), // select struct_column + Seq(Row(Array("plum", "pear"))), // select array_column + Seq.empty[Row], // drop local var + Seq.empty[Row], // drop local var + Seq.empty[Row], // drop local var + Seq.empty[Row] // drop local var + ) + verifySqlScriptResult(sqlScript, expected) + } + } + + test("for statement - no variable - nested struct") { + withTable("t") { + val sqlScript = + """ + |BEGIN + | CREATE TABLE t (int_column INT, + | struct_column STRUCT<num: INT, struct2: STRUCT<struct3: STRUCT<name: STRING>>>); + | INSERT INTO t VALUES + | (1, STRUCT(1, STRUCT(STRUCT("one")))), + | (2, STRUCT(2, STRUCT(STRUCT("two")))); + | FOR SELECT * FROM t ORDER BY int_column DO + | SELECT struct_column; + | END FOR; + |END + |""".stripMargin + + val expected = Seq( + Seq.empty[Row], // create table + Seq.empty[Row], // insert + Seq(Row(Row(1, Row(Row("one"))))), // select struct_column + Seq(Row(Row(2, Row(Row("two"))))), // select struct_column + Seq.empty[Row], // drop local var + Seq.empty[Row] // drop local var + ) + verifySqlScriptResult(sqlScript, expected) + } + } + + test("for statement - no variable - nested map") { + withTable("t") { + val sqlScript = + """ + |BEGIN + | CREATE TABLE t (int_column INT, map_column MAP<STRING, MAP<INT, MAP<BOOLEAN, INT>>>); + | INSERT INTO t VALUES + | (1, MAP('a', MAP(1, MAP(false, 10)))), + | (2, MAP('b', MAP(2, MAP(true, 20)))); + | FOR SELECT * FROM t ORDER BY int_column DO + | SELECT map_column; + | END FOR; + |END + |""".stripMargin + + val expected = Seq( + Seq.empty[Row], // create table + Seq.empty[Row], // insert + Seq(Row(Map("a" -> Map(1 -> Map(false -> 10))))), // select map_column + Seq(Row(Map("b" -> Map(2 -> Map(true -> 20))))), // select map_column + Seq.empty[Row], // drop local var + Seq.empty[Row] // drop local var + ) + verifySqlScriptResult(sqlScript, expected) + } + } + + test("for statement - no
variable - nested array") { + withTable("t") { + val sqlScript = + """ + |BEGIN + | CREATE TABLE t + | (int_column INT, array_column ARRAY<ARRAY<ARRAY<INT>>>); + | INSERT INTO t VALUES + | (1, ARRAY(ARRAY(ARRAY(1, 2), ARRAY(3, 4)), ARRAY(ARRAY(5, 6)))), + | (2, ARRAY(ARRAY(ARRAY(7, 8), ARRAY(9, 10)), ARRAY(ARRAY(11, 12)))); + | FOR SELECT * FROM t ORDER BY int_column DO + | SELECT array_column; + | END FOR; + |END + |""".stripMargin + + val expected = Seq( + Seq.empty[Row], // create table + Seq.empty[Row], // insert + Seq(Row(Seq(Seq(Seq(1, 2), Seq(3, 4)), Seq(Seq(5, 6))))), // array_column + Seq(Row(Array(Seq(Seq(7, 8), Seq(9, 10)), Seq(Seq(11, 12))))), // array_column + Seq.empty[Row], // drop local var + Seq.empty[Row] // drop local var + ) + verifySqlScriptResult(sqlScript, expected) + } + } + + test("for statement - no variable - empty result") { + withTable("t") { + val sqlScript = + """ + |BEGIN + | CREATE TABLE t (intCol INT) using parquet; + | FOR SELECT * FROM t ORDER BY intCol DO + | SELECT intCol; + | END FOR; + |END + |""".stripMargin + + val expected = Seq( + Seq.empty[Row] // create table + ) + verifySqlScriptResult(sqlScript, expected) + } + } + + test("for statement - no variable - iterate") { + withTable("t") { + val sqlScript = + """ + |BEGIN + | CREATE TABLE t (intCol INT, stringCol STRING) using parquet; + | INSERT INTO t VALUES (1, 'first'), (2, 'second'), (3, 'third'), (4, 'fourth'); + | + | lbl: FOR SELECT * FROM t ORDER BY intCol DO + | IF intCol = 2 THEN + | ITERATE lbl; + | END IF; + | SELECT stringCol; + | END FOR; + |END + |""".stripMargin + + val expected = Seq( + Seq.empty[Row], // create table + Seq.empty[Row], // insert + Seq(Row("first")), // select stringCol + Seq(Row("third")), // select stringCol + Seq(Row("fourth")), // select stringCol + Seq.empty[Row], // drop local var + Seq.empty[Row] // drop local var + ) + verifySqlScriptResult(sqlScript, expected) + } + } + + test("for statement - no variable - leave") { + withTable("t") { + val
sqlScript = + """ + |BEGIN + | CREATE TABLE t (intCol INT, stringCol STRING) using parquet; + | INSERT INTO t VALUES (1, 'first'), (2, 'second'), (3, 'third'), (4, 'fourth'); + | + | lbl: FOR SELECT * FROM t ORDER BY intCol DO + | IF intCol = 3 THEN + | LEAVE lbl; + | END IF; + | SELECT stringCol; + | END FOR; + |END + |""".stripMargin + + val expected = Seq( + Seq.empty[Row], // create table + Seq.empty[Row], // insert + Seq(Row("first")), // select stringCol + Seq(Row("second")) // select stringCol + ) + verifySqlScriptResult(sqlScript, expected) + } + } + + test("for statement - no variable - nested - in while") { + withTable("t") { + val sqlScript = + """ + |BEGIN + | DECLARE cnt = 0; + | CREATE TABLE t (intCol INT) using parquet; + | INSERT INTO t VALUES (0); + | WHILE cnt < 2 DO + | SET cnt = cnt + 1; + | FOR SELECT * FROM t ORDER BY intCol DO + | SELECT intCol; + | END FOR; + | INSERT INTO t VALUES (cnt); + | END WHILE; + |END + |""".stripMargin + + val expected = Seq( + Seq.empty[Row], // declare cnt + Seq.empty[Row], // create table + Seq.empty[Row], // insert + Seq.empty[Row], // set cnt + Seq(Row(0)), // select intCol + Seq.empty[Row], // insert + Seq.empty[Row], // drop local var + Seq.empty[Row], // set cnt + Seq(Row(0)), // select intCol + Seq(Row(1)), // select intCol + Seq.empty[Row], // insert + Seq.empty[Row], // drop local var + Seq.empty[Row] // drop cnt + ) + verifySqlScriptResult(sqlScript, expected) + } + } + + test("for statement - no variable - nested - in other for") { + withTable("t", "t2") { + val sqlScript = + """ + |BEGIN + | CREATE TABLE t (intCol INT) using parquet; + | CREATE TABLE t2 (intCol2 INT) using parquet; + | INSERT INTO t VALUES (0), (1); + | INSERT INTO t2 VALUES (2), (3); + | FOR SELECT * FROM t ORDER BY intCol DO + | FOR SELECT * FROM t2 ORDER BY intCol2 DESC DO + | SELECT intCol; + | SELECT intCol2; + | END FOR; + | END FOR; + |END + |""".stripMargin + + val expected = Seq( + Seq.empty[Row], // create table + 
Seq.empty[Row], // create table + Seq.empty[Row], // insert + Seq.empty[Row], // insert + Seq(Row(0)), // select intCol + Seq(Row(3)), // select intCol2 + Seq(Row(0)), // select intCol + Seq(Row(2)), // select intCol2 + Seq.empty[Row], // drop local var + Seq(Row(1)), // select intCol + Seq(Row(3)), // select intCol2 + Seq(Row(1)), // select intCol + Seq(Row(2)), // select intCol2 + Seq.empty[Row], // drop local var + Seq.empty[Row] // drop outer var + ) + verifySqlScriptResult(sqlScript, expected) + } + } + + // ignored until loops are fixed to support empty bodies + ignore("for statement - no variable - nested - empty result set") { + withTable("t") { + val sqlScript = + """ + |BEGIN + | CREATE TABLE t (intCol INT) using parquet; + | REPEAT + | FOR SELECT * FROM t ORDER BY intCol DO + | SELECT intCol; + | END FOR; + | UNTIL 1 = 1 + | END REPEAT; + |END + |""".stripMargin + + val expected = Seq( + Seq.empty[Row], // declare cnt + Seq.empty[Row], // create table + Seq.empty[Row], // insert + Seq.empty[Row], // set cnt + Seq(Row(0)), // select intCol + Seq.empty[Row], // insert + Seq.empty[Row], // drop local var + Seq.empty[Row], // set cnt + Seq(Row(0)), // select intCol + Seq(Row(1)), // select intCol + Seq.empty[Row], // insert + Seq.empty[Row], // drop local var + Seq.empty[Row] // drop cnt + ) + verifySqlScriptResult(sqlScript, expected) + } + } + + test("for statement - no variable - nested - iterate outer loop") { + withTable("t", "t2") { + val sqlScript = + """ + |BEGIN + | CREATE TABLE t (intCol INT) using parquet; + | CREATE TABLE t2 (intCol2 INT) using parquet; + | INSERT INTO t VALUES (0), (1); + | INSERT INTO t2 VALUES (2), (3); + | lbl1: FOR SELECT * FROM t ORDER BY intCol DO + | lbl2: FOR SELECT * FROM t2 ORDER BY intCol2 DESC DO + | SELECT intCol2; + | ITERATE lbl1; + | SELECT 1; + | END FOR; + | END FOR; + |END + |""".stripMargin + + val expected = Seq( + Seq.empty[Row], // create table + Seq.empty[Row], // create table + Seq.empty[Row], // insert 
+ Seq.empty[Row], // insert + Seq(Row(3)), // select intCol2 + Seq(Row(3)), // select intCol2 + Seq.empty[Row] // drop outer var + ) + verifySqlScriptResult(sqlScript, expected) + } + } + + test("for statement - no variable - nested - leave outer loop") { + withTable("t", "t2") { + val sqlScript = + """ + |BEGIN + | CREATE TABLE t (intCol INT) using parquet; + | CREATE TABLE t2 (intCol2 INT) using parquet; + | INSERT INTO t VALUES (0), (1); + | INSERT INTO t2 VALUES (2), (3); + | lbl1: FOR SELECT * FROM t ORDER BY intCol DO + | lbl2: FOR SELECT * FROM t2 ORDER BY intCol2 DESC DO + | SELECT intCol2; + | LEAVE lbl1; + | SELECT 1; + | END FOR; + | END FOR; + |END + |""".stripMargin + + val expected = Seq( + Seq.empty[Row], // create table + Seq.empty[Row], // create table + Seq.empty[Row], // insert + Seq.empty[Row], // insert + Seq(Row(3)) // select intCol2 + ) + verifySqlScriptResult(sqlScript, expected) + } + } + + test("for statement - no variable - nested - leave inner loop") { + withTable("t", "t2") { + val sqlScript = + """ + |BEGIN + | CREATE TABLE t (intCol INT) using parquet; + | CREATE TABLE t2 (intCol2 INT) using parquet; + | INSERT INTO t VALUES (0), (1); + | INSERT INTO t2 VALUES (2), (3); + | lbl1: FOR SELECT * FROM t ORDER BY intCol DO + | lbl2: FOR SELECT * FROM t2 ORDER BY intCol2 DESC DO + | SELECT intCol2; + | LEAVE lbl2; + | SELECT 1; + | END FOR; + | END FOR; + |END + |""".stripMargin + + val expected = Seq( + Seq.empty[Row], // create table + Seq.empty[Row], // create table + Seq.empty[Row], // insert + Seq.empty[Row], // insert + Seq(Row(3)), // select intCol2 + Seq(Row(3)), // select intCol2 + Seq.empty[Row] // drop outer var + ) + verifySqlScriptResult(sqlScript, expected) + } + } }