diff --git a/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4 b/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4 index 2e3235d6f932c..4b7b4634b74b2 100644 --- a/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4 +++ b/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4 @@ -70,6 +70,7 @@ compoundStatement | leaveStatement | iterateStatement | loopStatement + | forStatement ; setStatementWithOptionalVarKeyword @@ -111,6 +112,10 @@ loopStatement : beginLabel? LOOP compoundBody END LOOP endLabel? ; +forStatement + : beginLabel? FOR (multipartIdentifier AS)? query DO compoundBody END FOR endLabel? + ; + singleStatement : (statement|setResetStatement) SEMICOLON* EOF ; diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index 6a9a97d0f5c8c..d558689a5c196 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -226,6 +226,8 @@ class AstBuilder extends DataTypeAstBuilder visitSearchedCaseStatementImpl(searchedCaseContext, labelCtx) case simpleCaseContext: SimpleCaseStatementContext => visitSimpleCaseStatementImpl(simpleCaseContext, labelCtx) + case forStatementContext: ForStatementContext => + visitForStatementImpl(forStatementContext, labelCtx) case stmt => visit(stmt).asInstanceOf[CompoundPlanStatement] } } else { @@ -347,28 +349,48 @@ class AstBuilder extends DataTypeAstBuilder RepeatStatement(condition, body, Some(labelText)) } + private def visitForStatementImpl( + ctx: ForStatementContext, + labelCtx: SqlScriptingLabelContext): ForStatement = { + val labelText = labelCtx.enterLabeledScope(Option(ctx.beginLabel()), Option(ctx.endLabel())) + + val queryCtx = ctx.query() + val query = withOrigin(queryCtx) { + SingleStatement(visitQuery(queryCtx)) + } + val varName = Option(ctx.multipartIdentifier()).map(_.getText) + val body = visitCompoundBodyImpl(ctx.compoundBody(), None, allowVarDeclare = false, labelCtx) + labelCtx.exitLabeledScope(Option(ctx.beginLabel())) + + ForStatement(query, varName, body, Some(labelText)) + } + private def leaveOrIterateContextHasLabel( ctx: RuleContext, label: String, isIterate: Boolean): Boolean = { ctx match { case c: BeginEndCompoundBlockContext - if Option(c.beginLabel()).isDefined && - c.beginLabel().multipartIdentifier().getText.toLowerCase(Locale.ROOT).equals(label) => - if (isIterate) { + if Option(c.beginLabel()).exists { b => + b.multipartIdentifier().getText.toLowerCase(Locale.ROOT).equals(label) + } => if (isIterate) { throw SqlScriptingErrors.invalidIterateLabelUsageForCompound(CurrentOrigin.get, label) } true case c: WhileStatementContext - if Option(c.beginLabel()).isDefined && - c.beginLabel().multipartIdentifier().getText.toLowerCase(Locale.ROOT).equals(label) - => true + if Option(c.beginLabel()).exists { b => + b.multipartIdentifier().getText.toLowerCase(Locale.ROOT).equals(label) + } => true case c: RepeatStatementContext - if Option(c.beginLabel()).isDefined && - c.beginLabel().multipartIdentifier().getText.toLowerCase(Locale.ROOT).equals(label) - => true + if Option(c.beginLabel()).exists { b => + b.multipartIdentifier().getText.toLowerCase(Locale.ROOT).equals(label) + } => true case c: LoopStatementContext - if Option(c.beginLabel()).isDefined && - c.beginLabel().multipartIdentifier().getText.toLowerCase(Locale.ROOT).equals(label) - => true + if Option(c.beginLabel()).exists { b => + b.multipartIdentifier().getText.toLowerCase(Locale.ROOT).equals(label) + } => true + case c: ForStatementContext + if Option(c.beginLabel()).exists { b => + b.multipartIdentifier().getText.toLowerCase(Locale.ROOT).equals(label) + } => true case _ => false } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/SqlScriptingLogicalPlans.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/SqlScriptingLogicalPlans.scala index e6018e5e57b9c..4faf1f5d26672 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/SqlScriptingLogicalPlans.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/SqlScriptingLogicalPlans.scala @@ -267,3 +267,31 @@ case class LoopStatement( LoopStatement(newChildren(0).asInstanceOf[CompoundBody], label) } } + +/** + * Logical operator for FOR statement. + * @param query Query which is executed once, then it's result set is iterated on, row by row. + * @param variableName Name of variable which is used to access the current row during iteration. + * @param body Compound body is a collection of statements that are executed for each row in + * the result set of the query. + * @param label An optional label for the loop which is unique amongst all labels for statements + * within which the FOR statement is contained. + * If an end label is specified it must match the beginning label. + * The label can be used to LEAVE or ITERATE the loop. + */ +case class ForStatement( + query: SingleStatement, + variableName: Option[String], + body: CompoundBody, + label: Option[String]) extends CompoundPlanStatement { + + override def output: Seq[Attribute] = Seq.empty + + override def children: Seq[LogicalPlan] = Seq(query, body) + + override protected def withNewChildrenInternal( + newChildren: IndexedSeq[LogicalPlan]): LogicalPlan = newChildren match { + case IndexedSeq(query: SingleStatement, body: CompoundBody) => + ForStatement(query, variableName, body, label) + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/SqlScriptingParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/SqlScriptingParserSuite.scala index 3bb84f603dc67..ab647f83b42a4 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/SqlScriptingParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/SqlScriptingParserSuite.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst.parser import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.expressions.{Alias, EqualTo, Expression, In, Literal, ScalarSubquery} import org.apache.spark.sql.catalyst.plans.SQLHelper -import org.apache.spark.sql.catalyst.plans.logical.{CaseStatement, CompoundBody, CreateVariable, IfElseStatement, IterateStatement, LeaveStatement, LoopStatement, Project, RepeatStatement, SingleStatement, WhileStatement} +import org.apache.spark.sql.catalyst.plans.logical.{CaseStatement, CompoundBody, CreateVariable, ForStatement, IfElseStatement, IterateStatement, LeaveStatement, LoopStatement, Project, RepeatStatement, SingleStatement, WhileStatement} import org.apache.spark.sql.errors.DataTypeErrors.toSQLId import org.apache.spark.sql.exceptions.SqlScriptingException import org.apache.spark.sql.internal.SQLConf @@ -1176,7 +1176,6 @@ class SqlScriptingParserSuite extends SparkFunSuite with SQLHelper { head.asInstanceOf[SingleStatement].getText == "SELECT 42") assert(whileStmt.label.contains("lbl")) - } test("searched case statement") { @@ -1823,6 +1822,25 @@ class SqlScriptingParserSuite extends SparkFunSuite with SQLHelper { parameters = Map("label" -> toSQLId("l_loop"))) } + test("unique label names: nested for loops") { + val sqlScriptText = + """BEGIN + |f_loop: FOR x AS SELECT 1 DO + | f_loop: FOR y AS SELECT 2 DO + | SELECT 1; + | END FOR; + |END FOR; + |END + """.stripMargin + val exception = intercept[SqlScriptingException] { + parsePlan(sqlScriptText).asInstanceOf[CompoundBody] + } + checkError( + exception = exception, + condition = "LABEL_ALREADY_EXISTS", + parameters = Map("label" -> toSQLId("f_loop"))) + } + test("unique label names: begin-end block on the same level") { val sqlScriptText = """BEGIN @@ -1858,10 +1876,13 @@ class SqlScriptingParserSuite extends SparkFunSuite with SQLHelper { | SELECT 4; |UNTIL 1=1 |END REPEAT; + |lbl: FOR x AS SELECT 1 DO + | SELECT 5; + |END FOR; |END """.stripMargin val tree = parsePlan(sqlScriptText).asInstanceOf[CompoundBody] - assert(tree.collection.length == 4) + assert(tree.collection.length == 5) assert(tree.collection.head.isInstanceOf[CompoundBody]) assert(tree.collection.head.asInstanceOf[CompoundBody].label.get == "lbl") assert(tree.collection(1).isInstanceOf[WhileStatement]) @@ -1870,6 +1891,8 @@ class SqlScriptingParserSuite extends SparkFunSuite with SQLHelper { assert(tree.collection(2).asInstanceOf[LoopStatement].label.get == "lbl") assert(tree.collection(3).isInstanceOf[RepeatStatement]) assert(tree.collection(3).asInstanceOf[RepeatStatement].label.get == "lbl") + assert(tree.collection(4).isInstanceOf[ForStatement]) + assert(tree.collection(4).asInstanceOf[ForStatement].label.get == "lbl") } test("unique label names: nested labeled scope statements") { @@ -1879,7 +1902,9 @@ class SqlScriptingParserSuite extends SparkFunSuite with SQLHelper { | lbl_1: WHILE 1=1 DO | lbl_2: LOOP | lbl_3: REPEAT - | SELECT 4; + | lbl_4: FOR x AS SELECT 1 DO + | SELECT 4; + | END FOR; | UNTIL 1=1 | END REPEAT; | END LOOP; @@ -1905,6 +1930,241 @@ class SqlScriptingParserSuite extends SparkFunSuite with SQLHelper { // Repeat statement val repeatStatement = loopStatement.body.collection.head.asInstanceOf[RepeatStatement] assert(repeatStatement.label.get == "lbl_3") + // For statement + val forStatement = repeatStatement.body.collection.head.asInstanceOf[ForStatement] + assert(forStatement.label.get == "lbl_4") + } + + test("for statement") { + val sqlScriptText = + """ + |BEGIN + | lbl: FOR x AS SELECT 5 DO + | SELECT 1; + | END FOR; + |END""".stripMargin + val tree = parsePlan(sqlScriptText).asInstanceOf[CompoundBody] + assert(tree.collection.length == 1) + assert(tree.collection.head.isInstanceOf[ForStatement]) + + val forStmt = tree.collection.head.asInstanceOf[ForStatement] + assert(forStmt.query.isInstanceOf[SingleStatement]) + assert(forStmt.query.getText == "SELECT 5") + assert(forStmt.variableName.contains("x")) + + assert(forStmt.body.isInstanceOf[CompoundBody]) + assert(forStmt.body.collection.length == 1) + assert(forStmt.body.collection.head.isInstanceOf[SingleStatement]) + assert(forStmt.body.collection.head.asInstanceOf[SingleStatement].getText == "SELECT 1") + + assert(forStmt.label.contains("lbl")) + } + + test("for statement - no label") { + val sqlScriptText = + """ + |BEGIN + | FOR x AS SELECT 5 DO + | SELECT 1; + | END FOR; + |END""".stripMargin + val tree = parsePlan(sqlScriptText).asInstanceOf[CompoundBody] + assert(tree.collection.length == 1) + assert(tree.collection.head.isInstanceOf[ForStatement]) + + val forStmt = tree.collection.head.asInstanceOf[ForStatement] + assert(forStmt.query.isInstanceOf[SingleStatement]) + assert(forStmt.query.getText == "SELECT 5") + assert(forStmt.variableName.contains("x")) + + assert(forStmt.body.isInstanceOf[CompoundBody]) + assert(forStmt.body.collection.length == 1) + assert(forStmt.body.collection.head.isInstanceOf[SingleStatement]) + assert(forStmt.body.collection.head.asInstanceOf[SingleStatement].getText == "SELECT 1") + + // when not explicitly set, label is random UUID + assert(forStmt.label.isDefined) + } + + test("for statement - with complex subquery") { + val sqlScriptText = + """ + |BEGIN + | lbl: FOR x AS SELECT c1, c2 FROM t WHERE c2 = 5 GROUP BY c1 ORDER BY c1 DO + | SELECT x.c1; + | SELECT x.c2; + | END FOR; + |END""".stripMargin + val tree = parsePlan(sqlScriptText).asInstanceOf[CompoundBody] + assert(tree.collection.length == 1) + assert(tree.collection.head.isInstanceOf[ForStatement]) + + val forStmt = tree.collection.head.asInstanceOf[ForStatement] + assert(forStmt.query.isInstanceOf[SingleStatement]) + assert(forStmt.query.getText == "SELECT c1, c2 FROM t WHERE c2 = 5 GROUP BY c1 ORDER BY c1") + assert(forStmt.variableName.contains("x")) + + assert(forStmt.body.isInstanceOf[CompoundBody]) + assert(forStmt.body.collection.length == 2) + assert(forStmt.body.collection.head.isInstanceOf[SingleStatement]) + assert(forStmt.body.collection.head.asInstanceOf[SingleStatement].getText == "SELECT x.c1") + assert(forStmt.body.collection(1).isInstanceOf[SingleStatement]) + assert(forStmt.body.collection(1).asInstanceOf[SingleStatement].getText == "SELECT x.c2") + + assert(forStmt.label.contains("lbl")) + } + + test("for statement - nested") { + val sqlScriptText = + """ + |BEGIN + | lbl1: FOR i AS SELECT 1 DO + | lbl2: FOR j AS SELECT 2 DO + | SELECT i + j; + | END FOR lbl2; + | END FOR lbl1; + |END""".stripMargin + val tree = parsePlan(sqlScriptText).asInstanceOf[CompoundBody] + assert(tree.collection.length == 1) + assert(tree.collection.head.isInstanceOf[ForStatement]) + + val forStmt = tree.collection.head.asInstanceOf[ForStatement] + assert(forStmt.query.isInstanceOf[SingleStatement]) + assert(forStmt.query.getText == "SELECT 1") + assert(forStmt.variableName.contains("i")) + assert(forStmt.label.contains("lbl1")) + + assert(forStmt.body.isInstanceOf[CompoundBody]) + assert(forStmt.body.collection.length == 1) + assert(forStmt.body.collection.head.isInstanceOf[ForStatement]) + val nestedForStmt = forStmt.body.collection.head.asInstanceOf[ForStatement] + + assert(nestedForStmt.query.isInstanceOf[SingleStatement]) + assert(nestedForStmt.query.getText == "SELECT 2") + assert(nestedForStmt.variableName.contains("j")) + assert(nestedForStmt.label.contains("lbl2")) + + assert(nestedForStmt.body.isInstanceOf[CompoundBody]) + assert(nestedForStmt.body.collection.length == 1) + assert(nestedForStmt.body.collection.head.isInstanceOf[SingleStatement]) + assert(nestedForStmt.body.collection. + head.asInstanceOf[SingleStatement].getText == "SELECT i + j") + } + + test("for statement - no variable") { + val sqlScriptText = + """ + |BEGIN + | lbl: FOR SELECT 5 DO + | SELECT 1; + | END FOR; + |END""".stripMargin + val tree = parsePlan(sqlScriptText).asInstanceOf[CompoundBody] + assert(tree.collection.length == 1) + assert(tree.collection.head.isInstanceOf[ForStatement]) + + val forStmt = tree.collection.head.asInstanceOf[ForStatement] + assert(forStmt.query.isInstanceOf[SingleStatement]) + assert(forStmt.query.getText == "SELECT 5") + assert(forStmt.variableName.isEmpty) + + assert(forStmt.body.isInstanceOf[CompoundBody]) + assert(forStmt.body.collection.length == 1) + assert(forStmt.body.collection.head.isInstanceOf[SingleStatement]) + assert(forStmt.body.collection.head.asInstanceOf[SingleStatement].getText == "SELECT 1") + + assert(forStmt.label.contains("lbl")) + } + + test("for statement - no variable - no label") { + val sqlScriptText = + """ + |BEGIN + | FOR SELECT 5 DO + | SELECT 1; + | END FOR; + |END""".stripMargin + val tree = parsePlan(sqlScriptText).asInstanceOf[CompoundBody] + assert(tree.collection.length == 1) + assert(tree.collection.head.isInstanceOf[ForStatement]) + + val forStmt = tree.collection.head.asInstanceOf[ForStatement] + assert(forStmt.query.isInstanceOf[SingleStatement]) + assert(forStmt.query.getText == "SELECT 5") + assert(forStmt.variableName.isEmpty) + + assert(forStmt.body.isInstanceOf[CompoundBody]) + assert(forStmt.body.collection.length == 1) + assert(forStmt.body.collection.head.isInstanceOf[SingleStatement]) + assert(forStmt.body.collection.head.asInstanceOf[SingleStatement].getText == "SELECT 1") + + // when not explicitly set, label is random UUID + assert(forStmt.label.isDefined) + } + + test("for statement - no variable - with complex subquery") { + val sqlScriptText = + """ + |BEGIN + | lbl: FOR SELECT c1, c2 FROM t WHERE c2 = 5 GROUP BY c1 ORDER BY c1 DO + | SELECT 1; + | SELECT 2; + | END FOR; + |END""".stripMargin + val tree = parsePlan(sqlScriptText).asInstanceOf[CompoundBody] + assert(tree.collection.length == 1) + assert(tree.collection.head.isInstanceOf[ForStatement]) + + val forStmt = tree.collection.head.asInstanceOf[ForStatement] + assert(forStmt.query.isInstanceOf[SingleStatement]) + assert(forStmt.query.getText == "SELECT c1, c2 FROM t WHERE c2 = 5 GROUP BY c1 ORDER BY c1") + assert(forStmt.variableName.isEmpty) + + assert(forStmt.body.isInstanceOf[CompoundBody]) + assert(forStmt.body.collection.length == 2) + assert(forStmt.body.collection.head.isInstanceOf[SingleStatement]) + assert(forStmt.body.collection.head.asInstanceOf[SingleStatement].getText == "SELECT 1") + assert(forStmt.body.collection(1).isInstanceOf[SingleStatement]) + assert(forStmt.body.collection(1).asInstanceOf[SingleStatement].getText == "SELECT 2") + + assert(forStmt.label.contains("lbl")) + } + + test("for statement - no variable - nested") { + val sqlScriptText = + """ + |BEGIN + | lbl1: FOR SELECT 1 DO + | lbl2: FOR SELECT 2 DO + | SELECT 3; + | END FOR lbl2; + | END FOR lbl1; + |END""".stripMargin + val tree = parsePlan(sqlScriptText).asInstanceOf[CompoundBody] + assert(tree.collection.length == 1) + assert(tree.collection.head.isInstanceOf[ForStatement]) + + val forStmt = tree.collection.head.asInstanceOf[ForStatement] + assert(forStmt.query.isInstanceOf[SingleStatement]) + assert(forStmt.query.getText == "SELECT 1") + assert(forStmt.variableName.isEmpty) + assert(forStmt.label.contains("lbl1")) + + assert(forStmt.body.isInstanceOf[CompoundBody]) + assert(forStmt.body.collection.length == 1) + assert(forStmt.body.collection.head.isInstanceOf[ForStatement]) + val nestedForStmt = forStmt.body.collection.head.asInstanceOf[ForStatement] + + assert(nestedForStmt.query.isInstanceOf[SingleStatement]) + assert(nestedForStmt.query.getText == "SELECT 2") + assert(nestedForStmt.variableName.isEmpty) + assert(nestedForStmt.label.contains("lbl2")) + + assert(nestedForStmt.body.isInstanceOf[CompoundBody]) + assert(nestedForStmt.body.collection.length == 1) + assert(nestedForStmt.body.collection.head.isInstanceOf[SingleStatement]) + assert(nestedForStmt.body.collection. + head.asInstanceOf[SingleStatement].getText == "SELECT 3") } // Helper methods diff --git a/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala index 94284ec514f55..e3559e8f18ae2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala @@ -17,12 +17,14 @@ package org.apache.spark.sql.scripting +import java.util + import org.apache.spark.SparkException import org.apache.spark.internal.Logging -import org.apache.spark.sql.{DataFrame, Dataset, SparkSession} -import org.apache.spark.sql.catalyst.analysis.NameParameterizedQuery -import org.apache.spark.sql.catalyst.expressions.Expression -import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.{DataFrame, Dataset, Row, SparkSession} +import org.apache.spark.sql.catalyst.analysis.{NameParameterizedQuery, UnresolvedAttribute, UnresolvedIdentifier} +import org.apache.spark.sql.catalyst.expressions.{Alias, CreateArray, CreateMap, CreateNamedStruct, Expression, Literal} +import org.apache.spark.sql.catalyst.plans.logical.{CreateVariable, DefaultValueExpression, DropVariable, LogicalPlan, OneRowRelation, Project, SetVariable} import org.apache.spark.sql.catalyst.trees.{Origin, WithOrigin} import org.apache.spark.sql.errors.SqlScriptingErrors import org.apache.spark.sql.types.BooleanType @@ -662,3 +664,222 @@ class LoopStatementExec( body.reset() } } + +/** + * Executable node for ForStatement. + * @param query Executable node for the query. + * @param variableName Name of variable used for accessing current row during iteration. + * @param body Executable node for the body. + * @param label Label set to ForStatement by user or None otherwise. + * @param session Spark session that SQL script is executed within. + */ +class ForStatementExec( + query: SingleStatementExec, + variableName: Option[String], + body: CompoundBodyExec, + val label: Option[String], + session: SparkSession) extends NonLeafStatementExec { + + private object ForState extends Enumeration { + val VariableAssignment, Body, VariableCleanup = Value + } + private var state = ForState.VariableAssignment + private var areVariablesDeclared = false + + // map of all variables created internally by the for statement + // (variableName -> variableExpression) + private var variablesMap: Map[String, Expression] = Map() + + // compound body used for dropping variables while in ForState.VariableAssignment + private var dropVariablesExec: CompoundBodyExec = null + + private var queryResult: util.Iterator[Row] = _ + private var isResultCacheValid = false + private def cachedQueryResult(): util.Iterator[Row] = { + if (!isResultCacheValid) { + queryResult = query.buildDataFrame(session).toLocalIterator() + query.isExecuted = true + isResultCacheValid = true + } + queryResult + } + + /** + * For can be interrupted by LeaveStatementExec + */ + private var interrupted: Boolean = false + + private lazy val treeIterator: Iterator[CompoundStatementExec] = + new Iterator[CompoundStatementExec] { + + override def hasNext: Boolean = !interrupted && (state match { + case ForState.VariableAssignment => cachedQueryResult().hasNext + case ForState.Body => true + case ForState.VariableCleanup => dropVariablesExec.getTreeIterator.hasNext + }) + + override def next(): CompoundStatementExec = state match { + + case ForState.VariableAssignment => + variablesMap = createVariablesMapFromRow(cachedQueryResult().next()) + + if (!areVariablesDeclared) { + // create and execute declare var statements + variablesMap.keys.toSeq + .map(colName => createDeclareVarExec(colName, variablesMap(colName))) + .foreach(declareVarExec => declareVarExec.buildDataFrame(session).collect()) + areVariablesDeclared = true + } + + // create and execute set var statements + variablesMap.keys.toSeq + .map(colName => createSetVarExec(colName, variablesMap(colName))) + .foreach(setVarExec => setVarExec.buildDataFrame(session).collect()) + + state = ForState.Body + body.reset() + next() + + case ForState.Body => + val retStmt = body.getTreeIterator.next() + + // Handle LEAVE or ITERATE statement if it has been encountered. + retStmt match { + case leaveStatementExec: LeaveStatementExec if !leaveStatementExec.hasBeenMatched => + if (label.contains(leaveStatementExec.label)) { + leaveStatementExec.hasBeenMatched = true + } + interrupted = true + // If this for statement encounters LEAVE, it will either not be executed + // again, or it will be reset before being executed. + // In either case, variables will not + // be dropped normally, from ForState.VariableCleanup, so we drop them here. + dropVars() + return retStmt + case iterStatementExec: IterateStatementExec if !iterStatementExec.hasBeenMatched => + if (label.contains(iterStatementExec.label)) { + iterStatementExec.hasBeenMatched = true + } else { + // if an outer loop is being iterated, this for statement will either not be + // executed again, or it will be reset before being executed. + // In either case, variables will not + // be dropped normally, from ForState.VariableCleanup, so we drop them here. + dropVars() + } + switchStateFromBody() + return retStmt + case _ => + } + + if (!body.getTreeIterator.hasNext) { + switchStateFromBody() + } + retStmt + + case ForState.VariableCleanup => + dropVariablesExec.getTreeIterator.next() + } + } + + /** + * Recursively creates a Catalyst expression from Scala value.
+ * See https://spark.apache.org/docs/latest/sql-ref-datatypes.html for Spark -> Scala mappings + */ + private def createExpressionFromValue(value: Any): Expression = value match { + case m: Map[_, _] => + // arguments of CreateMap are in the format: (key1, val1, key2, val2, ...) + val mapArgs = m.keys.toSeq.flatMap { key => + Seq(createExpressionFromValue(key), createExpressionFromValue(m(key))) + } + CreateMap(mapArgs, useStringTypeWhenEmpty = false) + + // structs and rows match this case + case s: Row => + // arguments of CreateNamedStruct are in the format: (name1, val1, name2, val2, ...) + val namedStructArgs = s.schema.names.toSeq.flatMap { colName => + val valueExpression = createExpressionFromValue(s.getAs(colName)) + Seq(Literal(colName), valueExpression) + } + CreateNamedStruct(namedStructArgs) + + // arrays match this case + case a: collection.Seq[_] => + val arrayArgs = a.toSeq.map(createExpressionFromValue(_)) + CreateArray(arrayArgs, useStringTypeWhenEmpty = false) + + case _ => Literal(value) + } + + private def createVariablesMapFromRow(row: Row): Map[String, Expression] = { + var variablesMap = row.schema.names.toSeq.map { colName => + colName -> createExpressionFromValue(row.getAs(colName)) + }.toMap + + if (variableName.isDefined) { + val namedStructArgs = variablesMap.keys.toSeq.flatMap { colName => + Seq(Literal(colName), variablesMap(colName)) + } + val forVariable = CreateNamedStruct(namedStructArgs) + variablesMap = variablesMap + (variableName.get -> forVariable) + } + variablesMap + } + + /** + * Create and immediately execute dropVariable exec nodes for all variables in variablesMap. + */ + private def dropVars(): Unit = { + variablesMap.keys.toSeq + .map(colName => createDropVarExec(colName)) + .foreach(dropVarExec => dropVarExec.buildDataFrame(session).collect()) + areVariablesDeclared = false + } + + private def switchStateFromBody(): Unit = { + state = if (cachedQueryResult().hasNext) ForState.VariableAssignment + else { + // create compound body for dropping nodes after execution is complete + dropVariablesExec = new CompoundBodyExec( + variablesMap.keys.toSeq.map(colName => createDropVarExec(colName)) + ) + ForState.VariableCleanup + } + } + + private def createDeclareVarExec(varName: String, variable: Expression): SingleStatementExec = { + val defaultExpression = DefaultValueExpression(Literal(null, variable.dataType), "null") + val declareVariable = CreateVariable( + UnresolvedIdentifier(Seq(varName)), + defaultExpression, + replace = true + ) + new SingleStatementExec(declareVariable, Origin(), Map.empty, isInternal = true) + } + + private def createSetVarExec(varName: String, variable: Expression): SingleStatementExec = { + val projectNamedStruct = Project( + Seq(Alias(variable, varName)()), + OneRowRelation() + ) + val setIdentifierToCurrentRow = + SetVariable(Seq(UnresolvedAttribute(varName)), projectNamedStruct) + new SingleStatementExec(setIdentifierToCurrentRow, Origin(), Map.empty, isInternal = true) + } + + private def createDropVarExec(varName: String): SingleStatementExec = { + val dropVar = DropVariable(UnresolvedIdentifier(Seq(varName)), ifExists = true) + new SingleStatementExec(dropVar, Origin(), Map.empty, isInternal = true) + } + + override def getTreeIterator: Iterator[CompoundStatementExec] = treeIterator + + override def reset(): Unit = { + state = ForState.VariableAssignment + isResultCacheValid = false + variablesMap = Map() + areVariablesDeclared = false + dropVariablesExec = null + interrupted = false + body.reset() + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreter.scala b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreter.scala index 387ae36b881f4..a3dc3d4599314 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreter.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.scripting import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.analysis.UnresolvedIdentifier import org.apache.spark.sql.catalyst.expressions.Expression -import org.apache.spark.sql.catalyst.plans.logical.{CaseStatement, CompoundBody, CompoundPlanStatement, CreateVariable, DropVariable, IfElseStatement, IterateStatement, LeaveStatement, LogicalPlan, LoopStatement, RepeatStatement, SingleStatement, WhileStatement} +import org.apache.spark.sql.catalyst.plans.logical.{CaseStatement, CompoundBody, CompoundPlanStatement, CreateVariable, DropVariable, ForStatement, IfElseStatement, IterateStatement, LeaveStatement, LogicalPlan, LoopStatement, RepeatStatement, SingleStatement, WhileStatement} import org.apache.spark.sql.catalyst.trees.Origin /** @@ -145,6 +145,17 @@ case class SqlScriptingInterpreter(session: SparkSession) { .asInstanceOf[CompoundBodyExec] new LoopStatementExec(bodyExec, label) + case ForStatement(query, variableNameOpt, body, label) => + val queryExec = + new SingleStatementExec( + query.parsedPlan, + query.origin, + args, + isInternal = false) + val bodyExec = + transformTreeIntoExecutable(body, args).asInstanceOf[CompoundBodyExec] + new ForStatementExec(queryExec, variableNameOpt, bodyExec, label, session) + case leaveStatement: LeaveStatement => new LeaveStatementExec(leaveStatement.label) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNodeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNodeSuite.scala index 4874ea3d2795f..a997b5beadd34 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNodeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNodeSuite.scala @@ -18,11 +18,12 @@ package org.apache.spark.sql.scripting import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.{DataFrame, Row, SparkSession} import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, Literal} -import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, OneRowRelation, Project} +import org.apache.spark.sql.catalyst.plans.logical.{DropVariable, LeafNode, OneRowRelation, Project} import org.apache.spark.sql.catalyst.trees.Origin import org.apache.spark.sql.test.SharedSparkSession +import org.apache.spark.sql.types.{IntegerType, StructField, StructType} /** * Unit tests for execution nodes from SqlScriptingExecutionNode.scala. @@ -82,9 +83,9 @@ class SqlScriptingExecutionNodeSuite extends SparkFunSuite with SharedSparkSessi } case class TestRepeat( - condition: TestLoopCondition, - body: CompoundBodyExec, - label: Option[String] = None) + condition: TestLoopCondition, + body: CompoundBodyExec, + label: Option[String] = None) extends RepeatStatementExec(condition, body, label, spark) { private val evaluator = new LoopBooleanConditionEvaluator(condition) @@ -94,6 +95,23 @@ class SqlScriptingExecutionNodeSuite extends SparkFunSuite with SharedSparkSessi statement: LeafStatementExec): Boolean = evaluator.evaluateLoopBooleanCondition() } + case class MockQuery(numberOfRows: Int, columnName: String, description: String) + extends SingleStatementExec( + DummyLogicalPlan(), + Origin(startIndex = Some(0), stopIndex = Some(description.length)), + Map.empty, + isInternal = false) { + override def buildDataFrame(session: SparkSession): DataFrame = { + val data = Seq.range(0, numberOfRows).map(Row(_)) + val schema = List(StructField(columnName, IntegerType)) + + spark.createDataFrame( + spark.sparkContext.parallelize(data), + StructType(schema) + ) + } + } + private def extractStatementValue(statement: CompoundStatementExec): String = statement match { case TestLeafStatement(testVal) => testVal @@ -102,6 +120,9 @@ class SqlScriptingExecutionNodeSuite extends SparkFunSuite with SharedSparkSessi case loopStmt: LoopStatementExec => loopStmt.label.get case leaveStmt: LeaveStatementExec => leaveStmt.label case iterateStmt: IterateStatementExec => iterateStmt.label + case forStmt: ForStatementExec => forStmt.label.get + case dropStmt: SingleStatementExec if dropStmt.parsedPlan.isInstanceOf[DropVariable] + => "DropVariable" case _ => fail("Unexpected statement type") } @@ -688,4 +709,362 @@ class SqlScriptingExecutionNodeSuite extends SparkFunSuite with SharedSparkSessi val statements = iter.map(extractStatementValue).toSeq assert(statements === Seq("body1", "lbl")) } + + test("for statement - enters body once") { + val iter = new CompoundBodyExec(Seq( + new ForStatementExec( + query = MockQuery(1, "intCol", "query1"), + variableName = Some("x"), + label = Some("for1"), + session = spark, + body = new CompoundBodyExec(Seq(TestLeafStatement("body"))) + ) + )).getTreeIterator + val statements = iter.map(extractStatementValue).toSeq + assert(statements === Seq( + "body", + "DropVariable", // drop for query var intCol + "DropVariable" // drop for loop var x + )) + } + + test("for statement - enters body with multiple statements multiple times") { + val iter = new CompoundBodyExec(Seq( + new ForStatementExec( + query = MockQuery(2, "intCol", "query1"), + variableName = Some("x"), + label = Some("for1"), + session = spark, + body = new CompoundBodyExec( + Seq(TestLeafStatement("statement1"), TestLeafStatement("statement2")) + ) + ) + )).getTreeIterator + val statements = iter.map(extractStatementValue).toSeq + assert(statements === Seq( + "statement1", + "statement2", + "statement1", + "statement2", + "DropVariable", // drop for query var intCol + "DropVariable" // drop for loop var x + )) + } + + test("for statement - empty result") { + val iter = new CompoundBodyExec(Seq( + new ForStatementExec( + query = MockQuery(0, "intCol", "query1"), + variableName = Some("x"), + label = Some("for1"), + session = spark, + body = new CompoundBodyExec(Seq(TestLeafStatement("body1"))) + ) + )).getTreeIterator + val statements = iter.map(extractStatementValue).toSeq + assert(statements === Seq.empty[String]) + } + + test("for statement - nested") { + val iter = new CompoundBodyExec(Seq( + new ForStatementExec( + query = MockQuery(2, "intCol", "query1"), + variableName = Some("x"), + label = Some("for1"), + session = spark, + body = new CompoundBodyExec(Seq( + new ForStatementExec( + query = MockQuery(2, "intCol1", "query2"), + variableName = Some("y"), + label = Some("for2"), + session = spark, + body = new CompoundBodyExec(Seq(TestLeafStatement("body"))) + ) + )) + ) + )).getTreeIterator + val statements = iter.map(extractStatementValue).toSeq + assert(statements === Seq( + "body", + "body", + "DropVariable", // drop for query var intCol1 + "DropVariable", // drop for loop var y + "body", + "body", + "DropVariable", // drop for query var intCol1 + "DropVariable", // drop for loop var y + "DropVariable", // drop for query var intCol + "DropVariable" // drop for loop var x + )) + } + + test("for statement no variable - enters body once") { + val iter = new CompoundBodyExec(Seq( + new ForStatementExec( + query = MockQuery(1, "intCol", "query1"), + variableName = None, + label = Some("for1"), + session = spark, + body = new CompoundBodyExec(Seq(TestLeafStatement("body"))) + ) + )).getTreeIterator + val statements = iter.map(extractStatementValue).toSeq + assert(statements === Seq( + "body", + "DropVariable" // drop for query var intCol + )) + } + + test("for statement no variable - enters body with multiple statements multiple times") { + val iter = new CompoundBodyExec(Seq( + new ForStatementExec( + query = MockQuery(2, "intCol", "query1"), + variableName = None, + label = Some("for1"), + session = spark, + body = new CompoundBodyExec(Seq( + TestLeafStatement("statement1"), + TestLeafStatement("statement2"))) + ) + )).getTreeIterator + val statements = iter.map(extractStatementValue).toSeq + assert(statements === Seq( + "statement1", "statement2", "statement1", "statement2", + "DropVariable" // drop for query var intCol + )) + } + + test("for statement no variable - empty result") { + val iter = new CompoundBodyExec(Seq( + new ForStatementExec( + query = MockQuery(0, "intCol", "query1"), + variableName = None, + label = Some("for1"), + session = spark, + body = new CompoundBodyExec(Seq(TestLeafStatement("body1"))) + ) + )).getTreeIterator + val statements = iter.map(extractStatementValue).toSeq + assert(statements === Seq.empty[String]) + } + + test("for statement no variable - nested") { + val iter = new CompoundBodyExec(Seq( + new ForStatementExec( + query = MockQuery(2, "intCol", "query1"), + variableName = None, + label = Some("for1"), + session = spark, + body = new CompoundBodyExec(Seq( + new ForStatementExec( + query = MockQuery(2, "intCol1", "query2"), + variableName = None, + label = Some("for2"), + session = spark, + body = new CompoundBodyExec(Seq(TestLeafStatement("body"))) + ) + )) + ) + )).getTreeIterator + val statements = iter.map(extractStatementValue).toSeq + assert(statements === Seq( + "body", "body", + "DropVariable", // drop for query var intCol1 + "body", "body", + "DropVariable", // drop for query var intCol1 + "DropVariable" // drop for query var intCol + )) + } + + test("for statement - iterate") { + val iter = new CompoundBodyExec(Seq( + new ForStatementExec( + query = MockQuery(2, "intCol", "query1"), + variableName = Some("x"), + label = Some("lbl1"), + session = spark, + body = new CompoundBodyExec(Seq( + TestLeafStatement("statement1"), + new IterateStatementExec("lbl1"), + TestLeafStatement("statement2"))) + ) + )).getTreeIterator + val statements = iter.map(extractStatementValue).toSeq + assert(statements === Seq( + "statement1", + "lbl1", + "statement1", + "lbl1", + "DropVariable", // drop for query var intCol + "DropVariable" // drop for loop var x + )) + } + + test("for statement - leave") { + val iter = new CompoundBodyExec(Seq( + new ForStatementExec( + query = MockQuery(2, "intCol", "query1"), + variableName = Some("x"), + label = Some("lbl1"), + session = spark, + body = new CompoundBodyExec(Seq( + TestLeafStatement("statement1"), + new LeaveStatementExec("lbl1"), + TestLeafStatement("statement2"))) + ) + )).getTreeIterator + val statements = iter.map(extractStatementValue).toSeq + assert(statements === Seq("statement1", "lbl1")) + } + + test("for statement - nested - iterate outer loop") { + val iter = new CompoundBodyExec(Seq( + new ForStatementExec( + query = MockQuery(2, "intCol", "query1"), + variableName = Some("x"), + label = Some("lbl1"), + session = spark, + body = new CompoundBodyExec(Seq( + TestLeafStatement("outer_body"), + new ForStatementExec( + query = MockQuery(2, "intCol1", "query2"), + variableName = Some("y"), + label = Some("lbl2"), + session = spark, + body = new CompoundBodyExec(Seq( + TestLeafStatement("body1"), + new IterateStatementExec("lbl1"), + TestLeafStatement("body2"))) + ) + )) + ) + )).getTreeIterator + val statements = iter.map(extractStatementValue).toSeq + assert(statements === Seq( + "outer_body", + "body1", + "lbl1", + "outer_body", + "body1", + "lbl1", + "DropVariable", // drop for query var intCol + "DropVariable" // drop for loop var x + )) + } + + test("for statement - nested - leave outer loop") { + val iter = new CompoundBodyExec(Seq( + new ForStatementExec( + query = MockQuery(2, "intCol", "query1"), + variableName = Some("x"), + label = Some("lbl1"), + session = spark, + body = new CompoundBodyExec(Seq( + new ForStatementExec( + query = MockQuery(2, "intCol", "query2"), + variableName = Some("y"), + label = Some("lbl2"), + session = spark, + body = new CompoundBodyExec(Seq( + TestLeafStatement("body1"), + new LeaveStatementExec("lbl1"), + TestLeafStatement("body2"))) + ) + )) + ) + )).getTreeIterator + val statements = iter.map(extractStatementValue).toSeq + assert(statements === Seq("body1", "lbl1")) + } + + test("for statement no variable - iterate") { + val iter = new CompoundBodyExec(Seq( + new ForStatementExec( + query = MockQuery(2, "intCol", "query1"), + variableName = None, + label = Some("lbl1"), + session = spark, + body = new CompoundBodyExec(Seq( + TestLeafStatement("statement1"), + new IterateStatementExec("lbl1"), + TestLeafStatement("statement2"))) + ) + )).getTreeIterator + val statements = iter.map(extractStatementValue).toSeq + assert(statements === Seq( + "statement1", "lbl1", "statement1", "lbl1", + "DropVariable" // drop for query var intCol + )) + } + + test("for statement no variable - leave") { + val iter = new CompoundBodyExec(Seq( + new ForStatementExec( + query = MockQuery(2, "intCol", "query1"), + variableName = None, + label = Some("lbl1"), + session = spark, + body = new CompoundBodyExec(Seq( + TestLeafStatement("statement1"), + new LeaveStatementExec("lbl1"), + TestLeafStatement("statement2"))) + ) + )).getTreeIterator + val statements = iter.map(extractStatementValue).toSeq + assert(statements === Seq("statement1", "lbl1")) + } + + test("for statement no variable - nested - iterate outer loop") { + val iter = new CompoundBodyExec(Seq( + new ForStatementExec( + query = MockQuery(2, "intCol", "query1"), + variableName = None, + label = Some("lbl1"), + session = spark, + body = new CompoundBodyExec(Seq( + TestLeafStatement("outer_body"), + new ForStatementExec( + query = MockQuery(2, "intCol1", "query2"), + variableName = None, + label = Some("lbl2"), + session = spark, + body = new CompoundBodyExec(Seq( + TestLeafStatement("body1"), + new IterateStatementExec("lbl1"), + TestLeafStatement("body2"))) + ) + )) + ) + )).getTreeIterator + val statements = iter.map(extractStatementValue).toSeq + assert(statements === Seq( + "outer_body", "body1", "lbl1", "outer_body", "body1", "lbl1", + "DropVariable" // drop for query var intCol + )) + } + + test("for statement no variable - nested - leave outer loop") { + val iter = new CompoundBodyExec(Seq( + new ForStatementExec( + query = MockQuery(2, "intCol", "query1"), + variableName = None, + label = Some("lbl1"), + session = spark, + body = new CompoundBodyExec(Seq( + new ForStatementExec( + query = MockQuery(2, "intCol1", "query2"), + variableName = None, + label = Some("lbl2"), + session = spark, + body = new CompoundBodyExec(Seq( + TestLeafStatement("body1"), + new LeaveStatementExec("lbl1"), + TestLeafStatement("body2"))) + ) + )) + ) + )).getTreeIterator + val statements = iter.map(extractStatementValue).toSeq + assert(statements === Seq("body1", "lbl1")) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreterSuite.scala index 177ffc24d180a..71556c5502225 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreterSuite.scala @@ -1550,4 +1550,1058 @@ class SqlScriptingInterpreterSuite extends QueryTest with SharedSparkSession { ) verifySqlScriptResult(sqlScriptText, expected) } + + test("for statement - enters body once") { + withTable("t") { + val sqlScript = + """ + |BEGIN + | CREATE TABLE t (intCol INT, stringCol STRING, doubleCol DOUBLE) using parquet; + | INSERT INTO t VALUES (1, 'first', 1.0); + | FOR row AS SELECT * FROM t DO + | SELECT row.intCol; + | END FOR; + |END + |""".stripMargin + + val expected = Seq( + Seq.empty[Row], // create table + Seq.empty[Row], // insert + Seq(Row(1)), // select row.intCol + Seq.empty[Row], // drop local var + Seq.empty[Row], // drop local var + Seq.empty[Row], // drop local var + Seq.empty[Row] // drop local var + ) + verifySqlScriptResult(sqlScript, expected) + } + } + + test("for statement - enters body with multiple statements multiple times") { + withTable("t") { + val sqlScript = + """ + |BEGIN + | CREATE TABLE t (intCol INT, stringCol STRING, doubleCol DOUBLE) using parquet; + | INSERT INTO t VALUES (1, 'first', 1.0); + | INSERT INTO t VALUES (2, 'second', 2.0); + | FOR row AS SELECT * FROM t ORDER BY intCol DO + | SELECT row.intCol; + | SELECT intCol; + | SELECT row.stringCol; + | SELECT stringCol; + | SELECT row.doubleCol; + | SELECT doubleCol; + | END FOR; + |END + |""".stripMargin + + val expected = Seq( + Seq.empty[Row], // create table + Seq.empty[Row], // insert + Seq.empty[Row], // insert + Seq(Row(1)), // select row.intCol + Seq(Row(1)), // select intCol + Seq(Row("first")), // select row.stringCol + Seq(Row("first")), // select stringCol + Seq(Row(1.0)), // select row.doubleCol + Seq(Row(1.0)), // select doubleCol + Seq(Row(2)), // select row.intCol + Seq(Row(2)), // select intCol + Seq(Row("second")), // select row.stringCol + Seq(Row("second")), // select stringCol + Seq(Row(2.0)), // select row.doubleCol + Seq(Row(2.0)), // select doubleCol + Seq.empty[Row], // drop local var + Seq.empty[Row], // drop local var + Seq.empty[Row], // drop local var + Seq.empty[Row] // drop local var + ) + verifySqlScriptResult(sqlScript, expected) + } + } + + test("for statement - sum of column from table") { + withTable("t") { + val sqlScript = + """ + |BEGIN + | DECLARE sumOfCols = 0; + | CREATE TABLE t (intCol INT) using parquet; + | INSERT INTO t VALUES (1), (2), (3), (4); + | FOR row AS SELECT * FROM t DO + | SET sumOfCols = sumOfCols + row.intCol; + | END FOR; + | SELECT sumOfCols; + |END + |""".stripMargin + + val expected = Seq( + Seq.empty[Row], // declare sumOfCols + Seq.empty[Row], // create table + Seq.empty[Row], // insert + Seq.empty[Row], // set sumOfCols + Seq.empty[Row], // set sumOfCols + Seq.empty[Row], // set sumOfCols + Seq.empty[Row], // set sumOfCols + Seq.empty[Row], // drop local var + Seq.empty[Row], // drop local var + Seq(Row(10)), // select sumOfCols + Seq.empty[Row] // drop sumOfCols + ) + verifySqlScriptResult(sqlScript, expected) + } + } + + test("for statement - map, struct, array") { + withTable("t") { + val sqlScript = + """ + |BEGIN + | CREATE TABLE t (int_column INT, map_column MAP, + | struct_column STRUCT, array_column ARRAY); + | INSERT INTO t VALUES + | (1, MAP('a', 1), STRUCT('John', 25), ARRAY('apricot', 'quince')), + | (2, MAP('b', 2), STRUCT('Jane', 30), ARRAY('plum', 'pear')); + | FOR row AS SELECT * FROM t ORDER BY int_column DO + | SELECT row.map_column; + | SELECT map_column; + | SELECT row.struct_column; + | SELECT struct_column; + | SELECT row.array_column; + | SELECT array_column; + | END FOR; + |END + |""".stripMargin + + val expected = Seq( + Seq.empty[Row], // create table + Seq.empty[Row], // insert + Seq(Row(Map("a" -> 1))), // select row.map_column + Seq(Row(Map("a" -> 1))), // select map_column + Seq(Row(Row("John", 25))), // select row.struct_column + Seq(Row(Row("John", 25))), // select struct_column + Seq(Row(Array("apricot", "quince"))), // select row.array_column + Seq(Row(Array("apricot", "quince"))), // select array_column + Seq(Row(Map("b" -> 2))), // select row.map_column + Seq(Row(Map("b" -> 2))), // select map_column + Seq(Row(Row("Jane", 30))), // select row.struct_column + Seq(Row(Row("Jane", 30))), // select struct_column + Seq(Row(Array("plum", "pear"))), // select row.array_column + Seq(Row(Array("plum", "pear"))), // select array_column + Seq.empty[Row], // drop local var + Seq.empty[Row], // drop local var + Seq.empty[Row], // drop local var + Seq.empty[Row], // drop local var + Seq.empty[Row] // drop local var + ) + verifySqlScriptResult(sqlScript, expected) + } + } + + test("for statement - nested struct") { + withTable("t") { + val sqlScript = + """ + |BEGIN + | CREATE TABLE t + | (int_column INT, + | struct_column STRUCT>>); + | INSERT INTO t VALUES + | (1, STRUCT(1, STRUCT(STRUCT("one")))), + | (2, STRUCT(2, STRUCT(STRUCT("two")))); + | FOR row AS SELECT * FROM t ORDER BY int_column DO + | SELECT row.struct_column; + | SELECT struct_column; + | END FOR; + |END + |""".stripMargin + + val expected = Seq( + Seq.empty[Row], // create table + Seq.empty[Row], // insert + Seq(Row(Row(1, Row(Row("one"))))), // select row.struct_column + Seq(Row(Row(1, Row(Row("one"))))), // select struct_column + Seq(Row(Row(2, Row(Row("two"))))), // select row.struct_column + Seq(Row(Row(2, Row(Row("two"))))), // select struct_column + Seq.empty[Row], // drop local var + Seq.empty[Row], // drop local var + Seq.empty[Row] // drop local var + ) + verifySqlScriptResult(sqlScript, expected) + } + } + + test("for statement - nested map") { + withTable("t") { + val sqlScript = + """ + |BEGIN + | CREATE TABLE t (int_column INT, map_column MAP>>); + | INSERT INTO t VALUES + | (1, MAP('a', MAP(1, MAP(false, 10)))), + | (2, MAP('b', MAP(2, MAP(true, 20)))); + | FOR row AS SELECT * FROM t ORDER BY int_column DO + | SELECT row.map_column; + | SELECT map_column; + | END FOR; + |END + |""".stripMargin + + val expected = Seq( + Seq.empty[Row], // create table + Seq.empty[Row], // insert + Seq(Row(Map("a" -> Map(1 -> Map(false -> 10))))), // select row.map_column + Seq(Row(Map("a" -> Map(1 -> Map(false -> 10))))), // select map_column + Seq(Row(Map("b" -> Map(2 -> Map(true -> 20))))), // select row.map_column + Seq(Row(Map("b" -> Map(2 -> Map(true -> 20))))), // select map_column + Seq.empty[Row], // drop local var + Seq.empty[Row], // drop local var + Seq.empty[Row] // drop local var + ) + verifySqlScriptResult(sqlScript, expected) + } + } + + test("for statement - nested array") { + withTable("t") { + val sqlScript = + """ + |BEGIN + | CREATE TABLE t + | (int_column INT, array_column ARRAY>>); + | INSERT INTO t VALUES + | (1, ARRAY(ARRAY(ARRAY(1, 2), ARRAY(3, 4)), ARRAY(ARRAY(5, 6)))), + | (2, ARRAY(ARRAY(ARRAY(7, 8), ARRAY(9, 10)), ARRAY(ARRAY(11, 12)))); + | FOR row AS SELECT * FROM t ORDER BY int_column DO + | SELECT row.array_column; + | SELECT array_column; + | END FOR; + |END + |""".stripMargin + + val expected = Seq( + Seq.empty[Row], // create table + Seq.empty[Row], // insert + Seq(Row(Seq(Seq(Seq(1, 2), Seq(3, 4)), Seq(Seq(5, 6))))), // row.array_column + Seq(Row(Seq(Seq(Seq(1, 2), Seq(3, 4)), Seq(Seq(5, 6))))), // array_column + Seq(Row(Array(Seq(Seq(7, 8), Seq(9, 10)), Seq(Seq(11, 12))))), // row.array_column + Seq(Row(Array(Seq(Seq(7, 8), Seq(9, 10)), Seq(Seq(11, 12))))), // array_column + Seq.empty[Row], // drop local var + Seq.empty[Row], // drop local var + Seq.empty[Row] // drop local var + ) + verifySqlScriptResult(sqlScript, expected) + } + } + + test("for statement empty result") { + withTable("t") { + val sqlScript = + """ + |BEGIN + | CREATE TABLE t (intCol INT) using parquet; + | FOR row AS SELECT * FROM t ORDER BY intCol DO + | SELECT row.intCol; + | END FOR; + |END + |""".stripMargin + + val expected = Seq( + Seq.empty[Row] // create table + ) + verifySqlScriptResult(sqlScript, expected) + } + } + + test("for statement iterate") { + withTable("t") { + val sqlScript = + """ + |BEGIN + | CREATE TABLE t (intCol INT, stringCol STRING) using parquet; + | INSERT INTO t VALUES (1, 'first'), (2, 'second'), (3, 'third'), (4, 'fourth'); + | + | lbl: FOR x AS SELECT * FROM t ORDER BY intCol DO + | IF x.intCol = 2 THEN + | ITERATE lbl; + | END IF; + | SELECT stringCol; + | SELECT x.stringCol; + | END FOR; + |END + |""".stripMargin + + val expected = Seq( + Seq.empty[Row], // create table + Seq.empty[Row], // insert + Seq(Row("first")), // select stringCol + Seq(Row("first")), // select x.stringCol + Seq(Row("third")), // select stringCol + Seq(Row("third")), // select x.stringCol + Seq(Row("fourth")), // select stringCol + Seq(Row("fourth")), // select x.stringCol + Seq.empty[Row], // drop local var + Seq.empty[Row], // drop local var + Seq.empty[Row] // drop local var + ) + verifySqlScriptResult(sqlScript, expected) + } + } + + test("for statement leave") { + withTable("t") { + val sqlScript = + """ + |BEGIN + | CREATE TABLE t (intCol INT, stringCol STRING) using parquet; + | INSERT INTO t VALUES (1, 'first'), (2, 'second'), (3, 'third'), (4, 'fourth'); + | + | lbl: FOR x AS SELECT * FROM t ORDER BY intCol DO + | IF x.intCol = 3 THEN + | LEAVE lbl; + | END IF; + | SELECT stringCol; + | SELECT x.stringCol; + | END FOR; + |END + |""".stripMargin + + val expected = Seq( + Seq.empty[Row], // create table + Seq.empty[Row], // insert + Seq(Row("first")), // select stringCol + Seq(Row("first")), // select x.stringCol + Seq(Row("second")), // select stringCol + Seq(Row("second")) // select x.stringCol + ) + verifySqlScriptResult(sqlScript, expected) + } + } + + test("for statement - nested - in while") { + withTable("t") { + val sqlScript = + """ + |BEGIN + | DECLARE cnt = 0; + | CREATE TABLE t (intCol INT) using parquet; + | INSERT INTO t VALUES (0); + | WHILE cnt < 2 DO + | SET cnt = cnt + 1; + | FOR x AS SELECT * FROM t ORDER BY intCol DO + | SELECT x.intCol; + | END FOR; + | INSERT INTO t VALUES (cnt); + | END WHILE; + |END + |""".stripMargin + + val expected = Seq( + Seq.empty[Row], // declare cnt + Seq.empty[Row], // create table + Seq.empty[Row], // insert + Seq.empty[Row], // set cnt + Seq(Row(0)), // select intCol + Seq.empty[Row], // insert + Seq.empty[Row], // drop local var + Seq.empty[Row], // drop local var + Seq.empty[Row], // set cnt + Seq(Row(0)), // select intCol + Seq(Row(1)), // select intCol + Seq.empty[Row], // insert + Seq.empty[Row], // drop local var + Seq.empty[Row], // drop local var + Seq.empty[Row] // drop cnt + ) + verifySqlScriptResult(sqlScript, expected) + } + } + + test("for statement - nested - in other for") { + withTable("t", "t2") { + val sqlScript = + """ + |BEGIN + | CREATE TABLE t (intCol INT) using parquet; + | CREATE TABLE t2 (intCol2 INT) using parquet; + | INSERT INTO t VALUES (0), (1); + | INSERT INTO t2 VALUES (2), (3); + | FOR x as SELECT * FROM t ORDER BY intCol DO + | FOR y AS SELECT * FROM t2 ORDER BY intCol2 DESC DO + | SELECT x.intCol; + | SELECT intCol; + | SELECT y.intCol2; + | SELECT intCol2; + | END FOR; + | END FOR; + |END + |""".stripMargin + + val expected = Seq( + Seq.empty[Row], // create table + Seq.empty[Row], // create table + Seq.empty[Row], // insert + Seq.empty[Row], // insert + Seq(Row(0)), // select x.intCol + Seq(Row(0)), // select intCol + Seq(Row(3)), // select y.intCol2 + Seq(Row(3)), // select intCol2 + Seq(Row(0)), // select x.intCol + Seq(Row(0)), // select intCol + Seq(Row(2)), // select y.intCol2 + Seq(Row(2)), // select intCol2 + Seq.empty[Row], // drop local var + Seq.empty[Row], // drop local var + Seq(Row(1)), // select x.intCol + Seq(Row(1)), // select intCol + Seq(Row(3)), // select y.intCol2 + Seq(Row(3)), // select intCol2 + Seq(Row(1)), // select x.intCol + Seq(Row(1)), // select intCol + Seq(Row(2)), // select y.intCol2 + Seq(Row(2)), // select intCol2 + Seq.empty[Row], // drop local var + Seq.empty[Row], // drop local var + Seq.empty[Row], // drop outer var + Seq.empty[Row] // drop outer var + ) + verifySqlScriptResult(sqlScript, expected) + } + } + + // ignored until loops are fixed to support empty bodies + ignore("for statement - nested - empty result set") { + withTable("t") { + val sqlScript = + """ + |BEGIN + | CREATE TABLE t (intCol INT) using parquet; + | REPEAT + | FOR x AS SELECT * FROM t ORDER BY intCol DO + | SELECT x.intCol; + | END FOR; + | UNTIL 1 = 1 + | END REPEAT; + |END + |""".stripMargin + + val expected = Seq( + Seq.empty[Row], // declare cnt + Seq.empty[Row], // create table + Seq.empty[Row], // insert + Seq.empty[Row], // set cnt + Seq(Row(0)), // select intCol + Seq.empty[Row], // insert + Seq.empty[Row], // drop local var + Seq.empty[Row], // drop local var + Seq.empty[Row], // set cnt + Seq(Row(0)), // select intCol + Seq(Row(1)), // select intCol + Seq.empty[Row], // insert + Seq.empty[Row], // drop local var + Seq.empty[Row], // drop local var + Seq.empty[Row] // drop cnt + ) + verifySqlScriptResult(sqlScript, expected) + } + } + + test("for statement - nested - iterate outer loop") { + withTable("t", "t2") { + val sqlScript = + """ + |BEGIN + | CREATE TABLE t (intCol INT) using parquet; + | CREATE TABLE t2 (intCol2 INT) using parquet; + | INSERT INTO t VALUES (0), (1); + | INSERT INTO t2 VALUES (2), (3); + | lbl1: FOR x as SELECT * FROM t ORDER BY intCol DO + | lbl2: FOR y AS SELECT * FROM t2 ORDER BY intCol2 DESC DO + | SELECT y.intCol2; + | SELECT intCol2; + | ITERATE lbl1; + | SELECT 1; + | END FOR; + | END FOR; + |END + |""".stripMargin + + val expected = Seq( + Seq.empty[Row], // create table + Seq.empty[Row], // create table + Seq.empty[Row], // insert + Seq.empty[Row], // insert + Seq(Row(3)), // select y.intCol2 + Seq(Row(3)), // select intCol2 + Seq(Row(3)), // select y.intCol2 + Seq(Row(3)), // select intCol2 + Seq.empty[Row], // drop outer var + Seq.empty[Row] // drop outer var + ) + verifySqlScriptResult(sqlScript, expected) + } + } + + test("for statement - nested - leave outer loop") { + withTable("t", "t2") { + val sqlScript = + """ + |BEGIN + | CREATE TABLE t (intCol INT) using parquet; + | CREATE TABLE t2 (intCol2 INT) using parquet; + | INSERT INTO t VALUES (0), (1); + | INSERT INTO t2 VALUES (2), (3); + | lbl1: FOR x as SELECT * FROM t ORDER BY intCol DO + | lbl2: FOR y AS SELECT * FROM t2 ORDER BY intCol2 DESC DO + | SELECT y.intCol2; + | SELECT intCol2; + | LEAVE lbl1; + | SELECT 1; + | END FOR; + | END FOR; + |END + |""".stripMargin + + val expected = Seq( + Seq.empty[Row], // create table + Seq.empty[Row], // create table + Seq.empty[Row], // insert + Seq.empty[Row], // insert + Seq(Row(3)), // select y.intCol2 + Seq(Row(3)) // select intCol2 + ) + verifySqlScriptResult(sqlScript, expected) + } + } + + test("for statement - nested - leave inner loop") { + withTable("t", "t2") { + val sqlScript = + """ + |BEGIN + | CREATE TABLE t (intCol INT) using parquet; + | CREATE TABLE t2 (intCol2 INT) using parquet; + | INSERT INTO t VALUES (0), (1); + | INSERT INTO t2 VALUES (2), (3); + | lbl1: FOR x as SELECT * FROM t ORDER BY intCol DO + | lbl2: FOR y AS SELECT * FROM t2 ORDER BY intCol2 DESC DO + | SELECT y.intCol2; + | SELECT intCol2; + | LEAVE lbl2; + | SELECT 1; + | END FOR; + | END FOR; + |END + |""".stripMargin + + val expected = Seq( + Seq.empty[Row], // create table + Seq.empty[Row], // create table + Seq.empty[Row], // insert + Seq.empty[Row], // insert + Seq(Row(3)), // select y.intCol2 + Seq(Row(3)), // select intCol2 + Seq(Row(3)), // select y.intCol2 + Seq(Row(3)), // select intCol2 + Seq.empty[Row], // drop outer var + Seq.empty[Row] // drop outer var + ) + verifySqlScriptResult(sqlScript, expected) + } + } + + test("for statement - no variable - enters body once") { + withTable("t") { + val sqlScript = + """ + |BEGIN + | CREATE TABLE t (intCol INT, stringCol STRING, doubleCol DOUBLE) using parquet; + | INSERT INTO t VALUES (1, 'first', 1.0); + | FOR SELECT * FROM t DO + | SELECT intCol; + | SELECT stringCol; + | SELECT doubleCol; + | END FOR; + |END + |""".stripMargin + + val expected = Seq( + Seq.empty[Row], // create table + Seq.empty[Row], // insert + Seq(Row(1)), // select intCol + Seq(Row("first")), // select stringCol + Seq(Row(1.0)), // select doubleCol + Seq.empty[Row], // drop local var + Seq.empty[Row], // drop local var + Seq.empty[Row] // drop local var + ) + verifySqlScriptResult(sqlScript, expected) + } + } + + test("for statement - no variable - enters body with multiple statements multiple times") { + withTable("t") { + val sqlScript = + """ + |BEGIN + | CREATE TABLE t (intCol INT, stringCol STRING, doubleCol DOUBLE) using parquet; + | INSERT INTO t VALUES (1, 'first', 1.0); + | INSERT INTO t VALUES (2, 'second', 2.0); + | FOR SELECT * FROM t ORDER BY intCol DO + | SELECT intCol; + | SELECT stringCol; + | SELECT doubleCol; + | END FOR; + |END + |""".stripMargin + + val expected = Seq( + Seq.empty[Row], // create table + Seq.empty[Row], // insert + Seq.empty[Row], // insert + Seq(Row(1)), // select intCol + Seq(Row("first")), // select stringCol + Seq(Row(1.0)), // select doubleCol + Seq(Row(2)), // select intCol + Seq(Row("second")), // select stringCol + Seq(Row(2.0)), // select doubleCol + Seq.empty[Row], // drop local var + Seq.empty[Row], // drop local var + Seq.empty[Row] // drop local var + ) + verifySqlScriptResult(sqlScript, expected) + } + } + + test("for statement - no variable - sum of column from table") { + withTable("t") { + val sqlScript = + """ + |BEGIN + | DECLARE sumOfCols = 0; + | CREATE TABLE t (intCol INT) using parquet; + | INSERT INTO t VALUES (1), (2), (3), (4); + | FOR SELECT * FROM t DO + | SET sumOfCols = sumOfCols + intCol; + | END FOR; + | SELECT sumOfCols; + |END + |""".stripMargin + + val expected = Seq( + Seq.empty[Row], // declare sumOfCols + Seq.empty[Row], // create table + Seq.empty[Row], // insert + Seq.empty[Row], // set sumOfCols + Seq.empty[Row], // set sumOfCols + Seq.empty[Row], // set sumOfCols + Seq.empty[Row], // set sumOfCols + Seq.empty[Row], // drop local var + Seq(Row(10)), // select sumOfCols + Seq.empty[Row] // drop sumOfCols + ) + verifySqlScriptResult(sqlScript, expected) + } + } + + test("for statement - no variable - map, struct, array") { + withTable("t") { + val sqlScript = + """ + |BEGIN + | CREATE TABLE t (int_column INT, map_column MAP, + | struct_column STRUCT, array_column ARRAY); + | INSERT INTO t VALUES + | (1, MAP('a', 1), STRUCT('John', 25), ARRAY('apricot', 'quince')), + | (2, MAP('b', 2), STRUCT('Jane', 30), ARRAY('plum', 'pear')); + | FOR SELECT * FROM t ORDER BY int_column DO + | SELECT map_column; + | SELECT struct_column; + | SELECT array_column; + | END FOR; + |END + |""".stripMargin + + val expected = Seq( + Seq.empty[Row], // create table + Seq.empty[Row], // insert + Seq(Row(Map("a" -> 1))), // select map_column + Seq(Row(Row("John", 25))), // select struct_column + Seq(Row(Array("apricot", "quince"))), // select array_column + Seq(Row(Map("b" -> 2))), // select map_column + Seq(Row(Row("Jane", 30))), // select struct_column + Seq(Row(Array("plum", "pear"))), // select array_column + Seq.empty[Row], // drop local var + Seq.empty[Row], // drop local var + Seq.empty[Row], // drop local var + Seq.empty[Row] // drop local var + ) + verifySqlScriptResult(sqlScript, expected) + } + } + + test("for statement - no variable - nested struct") { + withTable("t") { + val sqlScript = + """ + |BEGIN + | CREATE TABLE t (int_column INT, + | struct_column STRUCT>>); + | INSERT INTO t VALUES + | (1, STRUCT(1, STRUCT(STRUCT("one")))), + | (2, STRUCT(2, STRUCT(STRUCT("two")))); + | FOR SELECT * FROM t ORDER BY int_column DO + | SELECT struct_column; + | END FOR; + |END + |""".stripMargin + + val expected = Seq( + Seq.empty[Row], // create table + Seq.empty[Row], // insert + Seq(Row(Row(1, Row(Row("one"))))), // select struct_column + Seq(Row(Row(2, Row(Row("two"))))), // select struct_column + Seq.empty[Row], // drop local var + Seq.empty[Row] // drop local var + ) + verifySqlScriptResult(sqlScript, expected) + } + } + + test("for statement - no variable - nested map") { + withTable("t") { + val sqlScript = + """ + |BEGIN + | CREATE TABLE t (int_column INT, map_column MAP>>); + | INSERT INTO t VALUES + | (1, MAP('a', MAP(1, MAP(false, 10)))), + | (2, MAP('b', MAP(2, MAP(true, 20)))); + | FOR SELECT * FROM t ORDER BY int_column DO + | SELECT map_column; + | END FOR; + |END + |""".stripMargin + + val expected = Seq( + Seq.empty[Row], // create table + Seq.empty[Row], // insert + Seq(Row(Map("a" -> Map(1 -> Map(false -> 10))))), // select map_column + Seq(Row(Map("b" -> Map(2 -> Map(true -> 20))))), // select map_column + Seq.empty[Row], // drop local var + Seq.empty[Row] // drop local var + ) + verifySqlScriptResult(sqlScript, expected) + } + } + + test("for statement - no variable - nested array") { + withTable("t") { + val sqlScript = + """ + |BEGIN + | CREATE TABLE t + | (int_column INT, array_column ARRAY>>); + | INSERT INTO t VALUES + | (1, ARRAY(ARRAY(ARRAY(1, 2), ARRAY(3, 4)), ARRAY(ARRAY(5, 6)))), + | (2, ARRAY(ARRAY(ARRAY(7, 8), ARRAY(9, 10)), ARRAY(ARRAY(11, 12)))); + | FOR SELECT * FROM t ORDER BY int_column DO + | SELECT array_column; + | END FOR; + |END + |""".stripMargin + + val expected = Seq( + Seq.empty[Row], // create table + Seq.empty[Row], // insert + Seq(Row(Seq(Seq(Seq(1, 2), Seq(3, 4)), Seq(Seq(5, 6))))), // array_column + Seq(Row(Array(Seq(Seq(7, 8), Seq(9, 10)), Seq(Seq(11, 12))))), // array_column + Seq.empty[Row], // drop local var + Seq.empty[Row] // drop local var + ) + verifySqlScriptResult(sqlScript, expected) + } + } + + test("for statement - no variable - empty result") { + withTable("t") { + val sqlScript = + """ + |BEGIN + | CREATE TABLE t (intCol INT) using parquet; + | FOR SELECT * FROM t ORDER BY intCol DO + | SELECT intCol; + | END FOR; + |END + |""".stripMargin + + val expected = Seq( + Seq.empty[Row] // create table + ) + verifySqlScriptResult(sqlScript, expected) + } + } + + test("for statement - no variable - iterate") { + withTable("t") { + val sqlScript = + """ + |BEGIN + | CREATE TABLE t (intCol INT, stringCol STRING) using parquet; + | INSERT INTO t VALUES (1, 'first'), (2, 'second'), (3, 'third'), (4, 'fourth'); + | + | lbl: FOR SELECT * FROM t ORDER BY intCol DO + | IF intCol = 2 THEN + | ITERATE lbl; + | END IF; + | SELECT stringCol; + | END FOR; + |END + |""".stripMargin + + val expected = Seq( + Seq.empty[Row], // create table + Seq.empty[Row], // insert + Seq(Row("first")), // select stringCol + Seq(Row("third")), // select stringCol + Seq(Row("fourth")), // select stringCol + Seq.empty[Row], // drop local var + Seq.empty[Row] // drop local var + ) + verifySqlScriptResult(sqlScript, expected) + } + } + + test("for statement - no variable - leave") { + withTable("t") { + val sqlScript = + """ + |BEGIN + | CREATE TABLE t (intCol INT, stringCol STRING) using parquet; + | INSERT INTO t VALUES (1, 'first'), (2, 'second'), (3, 'third'), (4, 'fourth'); + | + | lbl: FOR SELECT * FROM t ORDER BY intCol DO + | IF intCol = 3 THEN + | LEAVE lbl; + | END IF; + | SELECT stringCol; + | END FOR; + |END + |""".stripMargin + + val expected = Seq( + Seq.empty[Row], // create table + Seq.empty[Row], // insert + Seq(Row("first")), // select stringCol + Seq(Row("second")) // select stringCol + ) + verifySqlScriptResult(sqlScript, expected) + } + } + + test("for statement - no variable - nested - in while") { + withTable("t") { + val sqlScript = + """ + |BEGIN + | DECLARE cnt = 0; + | CREATE TABLE t (intCol INT) using parquet; + | INSERT INTO t VALUES (0); + | WHILE cnt < 2 DO + | SET cnt = cnt + 1; + | FOR SELECT * FROM t ORDER BY intCol DO + | SELECT intCol; + | END FOR; + | INSERT INTO t VALUES (cnt); + | END WHILE; + |END + |""".stripMargin + + val expected = Seq( + Seq.empty[Row], // declare cnt + Seq.empty[Row], // create table + Seq.empty[Row], // insert + Seq.empty[Row], // set cnt + Seq(Row(0)), // select intCol + Seq.empty[Row], // insert + Seq.empty[Row], // drop local var + Seq.empty[Row], // set cnt + Seq(Row(0)), // select intCol + Seq(Row(1)), // select intCol + Seq.empty[Row], // insert + Seq.empty[Row], // drop local var + Seq.empty[Row] // drop cnt + ) + verifySqlScriptResult(sqlScript, expected) + } + } + + test("for statement - no variable - nested - in other for") { + withTable("t", "t2") { + val sqlScript = + """ + |BEGIN + | CREATE TABLE t (intCol INT) using parquet; + | CREATE TABLE t2 (intCol2 INT) using parquet; + | INSERT INTO t VALUES (0), (1); + | INSERT INTO t2 VALUES (2), (3); + | FOR SELECT * FROM t ORDER BY intCol DO + | FOR SELECT * FROM t2 ORDER BY intCol2 DESC DO + | SELECT intCol; + | SELECT intCol2; + | END FOR; + | END FOR; + |END + |""".stripMargin + + val expected = Seq( + Seq.empty[Row], // create table + Seq.empty[Row], // create table + Seq.empty[Row], // insert + Seq.empty[Row], // insert + Seq(Row(0)), // select intCol + Seq(Row(3)), // select intCol2 + Seq(Row(0)), // select intCol + Seq(Row(2)), // select intCol2 + Seq.empty[Row], // drop local var + Seq(Row(1)), // select intCol + Seq(Row(3)), // select intCol2 + Seq(Row(1)), // select intCol + Seq(Row(2)), // select intCol2 + Seq.empty[Row], // drop local var + Seq.empty[Row] // drop outer var + ) + verifySqlScriptResult(sqlScript, expected) + } + } + + // ignored until loops are fixed to support empty bodies + ignore("for statement - no variable - nested - empty result set") { + withTable("t") { + val sqlScript = + """ + |BEGIN + | CREATE TABLE t (intCol INT) using parquet; + | REPEAT + | FOR SELECT * FROM t ORDER BY intCol DO + | SELECT intCol; + | END FOR; + | UNTIL 1 = 1 + | END REPEAT; + |END + |""".stripMargin + + val expected = Seq( + Seq.empty[Row], // declare cnt + Seq.empty[Row], // create table + Seq.empty[Row], // insert + Seq.empty[Row], // set cnt + Seq(Row(0)), // select intCol + Seq.empty[Row], // insert + Seq.empty[Row], // drop local var + Seq.empty[Row], // set cnt + Seq(Row(0)), // select intCol + Seq(Row(1)), // select intCol + Seq.empty[Row], // insert + Seq.empty[Row], // drop local var + Seq.empty[Row] // drop cnt + ) + verifySqlScriptResult(sqlScript, expected) + } + } + + test("for statement - no variable - nested - iterate outer loop") { + withTable("t", "t2") { + val sqlScript = + """ + |BEGIN + | CREATE TABLE t (intCol INT) using parquet; + | CREATE TABLE t2 (intCol2 INT) using parquet; + | INSERT INTO t VALUES (0), (1); + | INSERT INTO t2 VALUES (2), (3); + | lbl1: FOR SELECT * FROM t ORDER BY intCol DO + | lbl2: FOR SELECT * FROM t2 ORDER BY intCol2 DESC DO + | SELECT intCol2; + | ITERATE lbl1; + | SELECT 1; + | END FOR; + | END FOR; + |END + |""".stripMargin + + val expected = Seq( + Seq.empty[Row], // create table + Seq.empty[Row], // create table + Seq.empty[Row], // insert + Seq.empty[Row], // insert + Seq(Row(3)), // select intCol2 + Seq(Row(3)), // select intCol2 + Seq.empty[Row] // drop outer var + ) + verifySqlScriptResult(sqlScript, expected) + } + } + + test("for statement - no variable - nested - leave outer loop") { + withTable("t", "t2") { + val sqlScript = + """ + |BEGIN + | CREATE TABLE t (intCol INT) using parquet; + | CREATE TABLE t2 (intCol2 INT) using parquet; + | INSERT INTO t VALUES (0), (1); + | INSERT INTO t2 VALUES (2), (3); + | lbl1: FOR SELECT * FROM t ORDER BY intCol DO + | lbl2: FOR SELECT * FROM t2 ORDER BY intCol2 DESC DO + | SELECT intCol2; + | LEAVE lbl1; + | SELECT 1; + | END FOR; + | END FOR; + |END + |""".stripMargin + + val expected = Seq( + Seq.empty[Row], // create table + Seq.empty[Row], // create table + Seq.empty[Row], // insert + Seq.empty[Row], // insert + Seq(Row(3)) // select intCol2 + ) + verifySqlScriptResult(sqlScript, expected) + } + } + + test("for statement - no variable - nested - leave inner loop") { + withTable("t", "t2") { + val sqlScript = + """ + |BEGIN + | CREATE TABLE t (intCol INT) using parquet; + | CREATE TABLE t2 (intCol2 INT) using parquet; + | INSERT INTO t VALUES (0), (1); + | INSERT INTO t2 VALUES (2), (3); + | lbl1: FOR SELECT * FROM t ORDER BY intCol DO + | lbl2: FOR SELECT * FROM t2 ORDER BY intCol2 DESC DO + | SELECT intCol2; + | LEAVE lbl2; + | SELECT 1; + | END FOR; + | END FOR; + |END + |""".stripMargin + + val expected = Seq( + Seq.empty[Row], // create table + Seq.empty[Row], // create table + Seq.empty[Row], // insert + Seq.empty[Row], // insert + Seq(Row(3)), // select intCol2 + Seq(Row(3)), // select intCol2 + Seq.empty[Row] // drop outer var + ) + verifySqlScriptResult(sqlScript, expected) + } + } }