Skip to content

Commit

Permalink
Add setStatementWithOptionalVarKeyword to the handler grammar and vis…
Browse files Browse the repository at this point in the history
…itor
  • Loading branch information
miland-db committed Aug 13, 2024
1 parent a1573f6 commit 7598e0f
Show file tree
Hide file tree
Showing 4 changed files with 42 additions and 42 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ declareCondition
;

declareHandler
: DECLARE (CONTINUE | EXIT) HANDLER FOR conditionValueList (BEGIN compoundBody END | statement)
: DECLARE (CONTINUE | EXIT) HANDLER FOR conditionValueList (BEGIN compoundBody END | statement | setStatementWithOptionalVarKeyword)
;

beginLabel
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -302,7 +302,7 @@ class AstBuilder extends DataTypeAstBuilder
val handlerType = Option(ctx.EXIT()).map(_ => HandlerType.EXIT).getOrElse(HandlerType.CONTINUE)

val body = Option(ctx.compoundBody()).map(visit).getOrElse {
val logicalPlan = visit(ctx.statement()).asInstanceOf[LogicalPlan]
val logicalPlan = visitChildren(ctx).asInstanceOf[LogicalPlan]
CompoundBody(Seq(SingleStatement(parsedPlan = logicalPlan)))
}.asInstanceOf[CompoundBody]

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -268,32 +268,32 @@ class SqlScriptingExecutionNodeSuite extends SparkFunSuite with SharedSparkSessi
}

test("while - doesn't enter body") {
val iter = new CompoundBodyExec(Seq(
val iter = TestBody(Seq(
TestWhile(
condition = TestWhileCondition(condVal = true, reps = 0, description = "con1"),
body = new CompoundBodyExec(Seq(TestLeafStatement("body1")))
body = TestBody(Seq(TestLeafStatement("body1")))
)
)).getTreeIterator
val statements = iter.map(extractStatementValue).toSeq
assert(statements === Seq("con1"))
}

test("while - enters body once") {
val iter = new CompoundBodyExec(Seq(
val iter = TestBody(Seq(
TestWhile(
condition = TestWhileCondition(condVal = true, reps = 1, description = "con1"),
body = new CompoundBodyExec(Seq(TestLeafStatement("body1")))
body = TestBody(Seq(TestLeafStatement("body1")))
)
)).getTreeIterator
val statements = iter.map(extractStatementValue).toSeq
assert(statements === Seq("con1", "body1", "con1"))
}

test("while - enters body with multiple statements multiple times") {
val iter = new CompoundBodyExec(Seq(
val iter = TestBody(Seq(
TestWhile(
condition = TestWhileCondition(condVal = true, reps = 2, description = "con1"),
body = new CompoundBodyExec(Seq(
body = TestBody(Seq(
TestLeafStatement("statement1"),
TestLeafStatement("statement2")))
)
Expand All @@ -304,13 +304,13 @@ class SqlScriptingExecutionNodeSuite extends SparkFunSuite with SharedSparkSessi
}

test("nested while - 2 times outer 2 times inner") {
val iter = new CompoundBodyExec(Seq(
val iter = TestBody(Seq(
TestWhile(
condition = TestWhileCondition(condVal = true, reps = 2, description = "con1"),
body = new CompoundBodyExec(Seq(
body = TestBody(Seq(
TestWhile(
condition = TestWhileCondition(condVal = true, reps = 2, description = "con2"),
body = new CompoundBodyExec(Seq(TestLeafStatement("body1")))
body = TestBody(Seq(TestLeafStatement("body1")))
))
)
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -696,14 +696,14 @@ class SqlScriptingInterpreterSuite extends SparkFunSuite with SharedSparkSession
|""".stripMargin

val expected = Seq(
Seq.empty[Row], // declare i
Seq(Row(0)), // select i
Seq.empty[Row], // set i
Seq(Row(1)), // select i
Seq.empty[Row], // set i
Seq(Row(2)), // select i
Seq.empty[Row], // set i
Seq.empty[Row] // drop var
Array.empty[Row], // declare i
Array(Row(0)), // select i
Array.empty[Row], // set i
Array(Row(1)), // select i
Array.empty[Row], // set i
Array(Row(2)), // select i
Array.empty[Row], // set i
Array.empty[Row] // drop var
)
verifySqlScriptResult(commands, expected)
}
Expand All @@ -721,8 +721,8 @@ class SqlScriptingInterpreterSuite extends SparkFunSuite with SharedSparkSession
|""".stripMargin

val expected = Seq(
Seq.empty[Row], // declare i
Seq.empty[Row] // drop i
Array.empty[Row], // declare i
Array.empty[Row] // drop i
)
verifySqlScriptResult(commands, expected)
}
Expand All @@ -745,22 +745,22 @@ class SqlScriptingInterpreterSuite extends SparkFunSuite with SharedSparkSession
|""".stripMargin

val expected = Seq(
Seq.empty[Row], // declare i
Seq.empty[Row], // declare j
Seq.empty[Row], // set j to 0
Seq(Row(0, 0)), // select i, j
Seq.empty[Row], // increase j
Seq(Row(0, 1)), // select i, j
Seq.empty[Row], // increase j
Seq.empty[Row], // increase i
Seq.empty[Row], // set j to 0
Seq(Row(1, 0)), // select i, j
Seq.empty[Row], // increase j
Seq(Row(1, 1)), // select i, j
Seq.empty[Row], // increase j
Seq.empty[Row], // increase i
Seq.empty[Row], // drop j
Seq.empty[Row] // drop i
Array.empty[Row], // declare i
Array.empty[Row], // declare j
Array.empty[Row], // set j to 0
Array(Row(0, 0)), // select i, j
Array.empty[Row], // increase j
Array(Row(0, 1)), // select i, j
Array.empty[Row], // increase j
Array.empty[Row], // increase i
Array.empty[Row], // set j to 0
Array(Row(1, 0)), // select i, j
Array.empty[Row], // increase j
Array(Row(1, 1)), // select i, j
Array.empty[Row], // increase j
Array.empty[Row], // increase i
Array.empty[Row], // drop j
Array.empty[Row] // drop i
)
verifySqlScriptResult(commands, expected)
}
Expand All @@ -779,11 +779,11 @@ class SqlScriptingInterpreterSuite extends SparkFunSuite with SharedSparkSession
|""".stripMargin

val expected = Seq(
Seq.empty[Row], // create table
Seq(Row(42)), // select
Seq.empty[Row], // insert
Seq(Row(42)), // select
Seq.empty[Row] // insert
Array.empty[Row], // create table
Array(Row(42)), // select
Array.empty[Row], // insert
Array(Row(42)), // select
Array.empty[Row] // insert
)
verifySqlScriptResult(commands, expected)
}
Expand Down

0 comments on commit 7598e0f

Please sign in to comment.