Skip to content

Commit

Permalink
[SPARK-48346][SQL] Support for IF ELSE statements in SQL scripts
Browse files Browse the repository at this point in the history
### What changes were proposed in this pull request?
This PR introduces the IF/ELSE statement to the SQL scripting language.
Evaluating the conditions in IF and ELSE IF clauses also requires a boolean statement evaluator, which this PR introduces as well.

Changes summary:
- Grammar/parser changes:
  - `ifElseStatement` grammar rule
  - `visitIfElseStatement` rule visitor
  - `IfElseStatement` logical operator
- `IfElseStatementExec` execution node:
  - Internal states - `Condition` and `Body`
  - Iterator implementation - iterate over conditions until the one that evaluates to `true` is found
  - Use `StatementBooleanEvaluator` implementation to evaluate conditions
- `DataFrameEvaluator`:
  - Implementation of `StatementBooleanEvaluator`
  - Evaluates a result to `true` only if it consists of a single row with a single column of boolean type whose value is `true`
- `SqlScriptingInterpreter` - add logic to transform `IfElseStatement` to `IfElseStatementExec`

### Why are the changes needed?
We are gradually introducing SQL Scripting to Spark, and IF/ELSE is one of the basic control flow constructs in the SQL language. For more details, check [JIRA item](https://issues.apache.org/jira/browse/SPARK-48346).

### Does this PR introduce _any_ user-facing change?
No.

### How was this patch tested?
New tests are added to all three scripting test suites: `SqlScriptingParserSuite`, `SqlScriptingExecutionNodeSuite` and `SqlScriptingInterpreterSuite`.

### Was this patch authored or co-authored using generative AI tooling?
No.

Closes apache#47442 from davidm-db/sql_scripting_if_else.

Authored-by: David Milicevic <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>
  • Loading branch information
davidm-db authored and cloud-fan committed Aug 5, 2024
1 parent f01eafd commit f99291a
Show file tree
Hide file tree
Showing 8 changed files with 663 additions and 37 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ compoundStatement
: statement
| setStatementWithOptionalVarKeyword
| beginEndCompoundBlock
| ifElseStatement
;

setStatementWithOptionalVarKeyword
Expand All @@ -71,6 +72,12 @@ setStatementWithOptionalVarKeyword
LEFT_PAREN query RIGHT_PAREN #setVariableWithOptionalKeyword
;

// IF/ELSE statement: one mandatory IF branch, any number of ELSE IF branches and an
// optional ELSE branch, closed by END IF.
// `conditionalBodies` accumulates the body that follows each booleanExpression (IF first,
// then every ELSE IF, in source order); `elseBody` is set only when an ELSE branch is present.
ifElseStatement
: IF booleanExpression THEN conditionalBodies+=compoundBody
(ELSE IF booleanExpression THEN conditionalBodies+=compoundBody)*
(ELSE elseBody=compoundBody)? END IF
;

// Exactly one `statement` or `setResetStatement`, followed by any number of semicolons and
// then end of input.
singleStatement
: (statement|setResetStatement) SEMICOLON* EOF
;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -205,10 +205,23 @@ class AstBuilder extends DataTypeAstBuilder
.map { s =>
SingleStatement(parsedPlan = visit(s).asInstanceOf[LogicalPlan])
}.getOrElse {
visit(ctx.beginEndCompoundBlock()).asInstanceOf[CompoundPlanStatement]
visitChildren(ctx).asInstanceOf[CompoundPlanStatement]
}
}

/**
 * Builds an [[IfElseStatement]] from an `ifElseStatement` parse-tree node.
 *
 * Each boolean expression reported by the context (the IF condition followed by the ELSE IF
 * conditions) becomes a [[SingleStatement]] wrapping a `Project` over `OneRowRelation()` whose
 * only output is the expression aliased as "condition"; the statement is created inside
 * `withOrigin` for that expression. Branch bodies and the optional ELSE body are built with
 * `visitCompoundBody`.
 *
 * @param ctx Parse-tree context of the IF/ELSE statement.
 * @return The corresponding logical operator.
 */
override def visitIfElseStatement(ctx: IfElseStatementContext): IfElseStatement = {
  val conditionStatements = ctx.booleanExpression().asScala.toList.map { conditionCtx =>
    withOrigin(conditionCtx) {
      val conditionPlan = Project(
        Seq(Alias(expression(conditionCtx), "condition")()),
        OneRowRelation())
      SingleStatement(conditionPlan)
    }
  }
  val branchBodies =
    ctx.conditionalBodies.asScala.toList.map(bodyCtx => visitCompoundBody(bodyCtx))
  // `elseBody` is null when the statement has no ELSE branch.
  val elseBranch = Option(ctx.elseBody).map(bodyCtx => visitCompoundBody(bodyCtx))

  IfElseStatement(
    conditions = conditionStatements,
    conditionalBodies = branchBodies,
    elseBody = elseBranch)
}

override def visitSingleStatement(ctx: SingleStatementContext): LogicalPlan = withOrigin(ctx) {
Option(ctx.statement().asInstanceOf[ParserRuleContext])
.orElse(Option(ctx.setResetStatement().asInstanceOf[ParserRuleContext]))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -58,3 +58,19 @@ case class SingleStatement(parsedPlan: LogicalPlan)
case class CompoundBody(
collection: Seq[CompoundPlanStatement],
label: Option[String]) extends CompoundPlanStatement

/**
 * Logical operator describing an IF / ELSE IF / ELSE statement.
 *
 * @param conditions Conditions in source order: the first one belongs to the IF clause and
 *                   every following one to an ELSE IF clause.
 * @param conditionalBodies Bodies of the IF and ELSE IF branches, one per condition.
 *                          Construction asserts that there are exactly as many bodies
 *                          as conditions.
 * @param elseBody Body of the ELSE branch, i.e. the body to run when none of the conditions
 *                 are met. Empty when the statement has no ELSE branch.
 */
case class IfElseStatement(
    conditions: Seq[SingleStatement],
    conditionalBodies: Seq[CompoundBody],
    elseBody: Option[CompoundBody]) extends CompoundPlanStatement {
  // Every condition must be paired with exactly one body.
  assert(conditionalBodies.length == conditions.length)
}
Original file line number Diff line number Diff line change
Expand Up @@ -348,6 +348,184 @@ class SqlScriptingParserSuite extends SparkFunSuite with SQLHelper {
assert(e.getMessage.contains("Syntax error"))
}

// Single IF branch without ELSE IF / ELSE: verifies the parsed condition, the branch body
// and the absence of an ELSE body (the body checks mirror the other IF/ELSE tests).
test("if") {
  val sqlScriptText =
    """
|BEGIN
| IF 1=1 THEN
| SELECT 42;
| END IF;
|END
|""".stripMargin
  val tree = parseScript(sqlScriptText)
  assert(tree.collection.length == 1)
  assert(tree.collection.head.isInstanceOf[IfElseStatement])

  val ifStmt = tree.collection.head.asInstanceOf[IfElseStatement]
  assert(ifStmt.conditions.length == 1)
  assert(ifStmt.conditionalBodies.length == 1)
  assert(ifStmt.elseBody.isEmpty)

  assert(ifStmt.conditions.head.isInstanceOf[SingleStatement])
  assert(ifStmt.conditions.head.getText == "1=1")

  // The THEN body must contain exactly the one SELECT from the script.
  assert(ifStmt.conditionalBodies.head.collection.length == 1)
  assert(ifStmt.conditionalBodies.head.collection.head.isInstanceOf[SingleStatement])
  assert(ifStmt.conditionalBodies.head.collection.head.asInstanceOf[SingleStatement]
    .getText == "SELECT 42")
}

// IF with an ELSE branch: one condition, one conditional body and an ELSE body.
test("if else") {
  val script =
    """BEGIN
|IF 1 = 1 THEN
| SELECT 1;
|ELSE
| SELECT 2;
|END IF;
|END
""".stripMargin
  val parsed = parseScript(script)
  assert(parsed.collection.length == 1)
  assert(parsed.collection.head.isInstanceOf[IfElseStatement])

  val ifElse = parsed.collection.head.asInstanceOf[IfElseStatement]
  assert(ifElse.conditions.length == 1)
  assert(ifElse.conditionalBodies.length == 1)
  assert(ifElse.elseBody.isDefined)

  val ifCondition = ifElse.conditions.head
  assert(ifCondition.isInstanceOf[SingleStatement])
  assert(ifCondition.getText == "1 = 1")

  val thenStatements = ifElse.conditionalBodies.head.collection
  assert(thenStatements.length == 1)
  assert(thenStatements.head.isInstanceOf[SingleStatement])
  assert(thenStatements.head.asInstanceOf[SingleStatement].getText == "SELECT 1")

  val elseStatements = ifElse.elseBody.get.collection
  assert(elseStatements.length == 1)
  assert(elseStatements.head.isInstanceOf[SingleStatement])
  assert(elseStatements.head.asInstanceOf[SingleStatement].getText == "SELECT 2")
}

// IF + one ELSE IF + ELSE: two conditions with their bodies, plus an ELSE body.
test("if else if") {
  val script =
    """BEGIN
|IF 1 = 1 THEN
| SELECT 1;
|ELSE IF 2 = 2 THEN
| SELECT 2;
|ELSE
| SELECT 3;
|END IF;
|END
""".stripMargin
  val parsed = parseScript(script)
  assert(parsed.collection.length == 1)
  assert(parsed.collection.head.isInstanceOf[IfElseStatement])

  val ifElse = parsed.collection.head.asInstanceOf[IfElseStatement]
  assert(ifElse.conditions.length == 2)
  assert(ifElse.conditionalBodies.length == 2)
  assert(ifElse.elseBody.isDefined)

  // Expected (condition text, first body statement text) per conditional branch, in order.
  val expectedBranches = Seq("1 = 1" -> "SELECT 1", "2 = 2" -> "SELECT 2")
  expectedBranches.zipWithIndex.foreach { case ((conditionText, bodyText), idx) =>
    val condition = ifElse.conditions(idx)
    assert(condition.isInstanceOf[SingleStatement])
    assert(condition.getText == conditionText)

    val firstBodyStmt = ifElse.conditionalBodies(idx).collection.head
    assert(firstBodyStmt.isInstanceOf[SingleStatement])
    assert(firstBodyStmt.asInstanceOf[SingleStatement].getText == bodyText)
  }

  val firstElseStmt = ifElse.elseBody.get.collection.head
  assert(firstElseStmt.isInstanceOf[SingleStatement])
  assert(firstElseStmt.asInstanceOf[SingleStatement].getText == "SELECT 3")
}

// IF + two ELSE IF branches and no ELSE: three conditions with their bodies.
test("if multi else if") {
  val script =
    """BEGIN
|IF 1 = 1 THEN
| SELECT 1;
|ELSE IF 2 = 2 THEN
| SELECT 2;
|ELSE IF 3 = 3 THEN
| SELECT 3;
|END IF;
|END
""".stripMargin
  val parsed = parseScript(script)
  assert(parsed.collection.length == 1)
  assert(parsed.collection.head.isInstanceOf[IfElseStatement])

  val ifElse = parsed.collection.head.asInstanceOf[IfElseStatement]
  assert(ifElse.conditions.length == 3)
  assert(ifElse.conditionalBodies.length == 3)
  assert(ifElse.elseBody.isEmpty)

  // Expected (condition text, first body statement text) per conditional branch, in order.
  val expectedBranches =
    Seq("1 = 1" -> "SELECT 1", "2 = 2" -> "SELECT 2", "3 = 3" -> "SELECT 3")
  expectedBranches.zipWithIndex.foreach { case ((conditionText, bodyText), idx) =>
    val condition = ifElse.conditions(idx)
    assert(condition.isInstanceOf[SingleStatement])
    assert(condition.getText == conditionText)

    val firstBodyStmt = ifElse.conditionalBodies(idx).collection.head
    assert(firstBodyStmt.isInstanceOf[SingleStatement])
    assert(firstBodyStmt.asInstanceOf[SingleStatement].getText == bodyText)
  }
}

// IF whose body is itself an IF/ELSE statement: checks the outer and the nested operator.
test("if nested") {
  val script =
    """
|BEGIN
| IF 1=1 THEN
| IF 2=1 THEN
| SELECT 41;
| ELSE
| SELECT 42;
| END IF;
| END IF;
|END
|""".stripMargin
  val parsed = parseScript(script)
  assert(parsed.collection.length == 1)
  assert(parsed.collection.head.isInstanceOf[IfElseStatement])

  val outerIf = parsed.collection.head.asInstanceOf[IfElseStatement]
  assert(outerIf.conditions.length == 1)
  assert(outerIf.conditionalBodies.length == 1)
  assert(outerIf.elseBody.isEmpty)

  val outerCondition = outerIf.conditions.head
  assert(outerCondition.isInstanceOf[SingleStatement])
  assert(outerCondition.getText == "1=1")

  // The only statement of the outer THEN body is the nested IF/ELSE.
  val outerFirstStmt = outerIf.conditionalBodies.head.collection.head
  assert(outerFirstStmt.isInstanceOf[IfElseStatement])
  val innerIf = outerFirstStmt.asInstanceOf[IfElseStatement]

  assert(innerIf.conditions.length == 1)
  assert(innerIf.conditionalBodies.length == 1)
  assert(innerIf.elseBody.isDefined)

  val innerCondition = innerIf.conditions.head
  assert(innerCondition.isInstanceOf[SingleStatement])
  assert(innerCondition.getText == "2=1")

  val innerThenStmt = innerIf.conditionalBodies.head.collection.head
  assert(innerThenStmt.isInstanceOf[SingleStatement])
  assert(innerThenStmt.asInstanceOf[SingleStatement].getText == "SELECT 41")

  val innerElseStmt = innerIf.elseBody.get.collection.head
  assert(innerElseStmt.isInstanceOf[SingleStatement])
  assert(innerElseStmt.asInstanceOf[SingleStatement].getText == "SELECT 42")
}

// Helper methods
def cleanupStatementString(statementStr: String): String = {
statementStr
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,10 @@ package org.apache.spark.sql.scripting

import org.apache.spark.SparkException
import org.apache.spark.internal.Logging
import org.apache.spark.sql.{Dataset, SparkSession}
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.trees.{Origin, WithOrigin}
import org.apache.spark.sql.types.BooleanType

/**
* Trait for all SQL scripting execution nodes used during interpretation phase.
Expand Down Expand Up @@ -55,6 +57,33 @@ trait NonLeafStatementExec extends CompoundStatementExec {
* Tree iterator.
*/
def getTreeIterator: Iterator[CompoundStatementExec]

/**
 * Evaluate the boolean condition represented by the statement.
 *
 * The statement is marked as executed and its plan is run. The condition holds only when
 * the result has a single column of boolean type and exactly one row whose value is a
 * non-null `true`. Every other outcome (non-boolean or multi-column schema, zero or several
 * rows, NULL or `false` value, or a statement that is not a [[SingleStatementExec]])
 * evaluates to `false`.
 *
 * @param session SparkSession that SQL script is executed within.
 * @param statement Statement representing the boolean condition to evaluate.
 * @return Whether the condition evaluates to True.
 */
protected def evaluateBooleanCondition(
    session: SparkSession,
    statement: LeafStatementExec): Boolean = statement match {
  case statement: SingleStatementExec =>
    // A condition statement must be evaluated at most once between resets.
    assert(!statement.isExecuted)
    statement.isExecuted = true

    val df = Dataset.ofRows(session, statement.parsedPlan)
    df.schema.fields match {
      case Array(field) if field.dataType == BooleanType =>
        // Fetching at most two rows is enough to tell "exactly one row" from "more than one".
        df.limit(2).collect() match {
          // `getBoolean` throws on a NULL value, so a NULL condition (e.g. `1 = NULL`)
          // is checked first and treated as "not true", like every other non-true result.
          case Array(row) => !row.isNullAt(0) && row.getBoolean(0)
          case _ => false
        }
      case _ => false
    }
  case _ => false
}
}

/**
Expand Down Expand Up @@ -155,3 +184,79 @@ abstract class CompoundNestedStatementIteratorExec(collection: Seq[CompoundState
*/
class CompoundBodyExec(statements: Seq[CompoundStatementExec])
extends CompoundNestedStatementIteratorExec(statements)

/**
 * Executable node for IfElseStatement.
 *
 * The tree iterator first emits condition statements one at a time, evaluating each one as
 * it is emitted, until a condition evaluates to true or no conditions remain. It then emits
 * the statements produced by the selected body's own tree iterator: the body matching the
 * true condition, or the ELSE body when no condition held and an ELSE branch exists.
 *
 * @param conditions Collection of executable conditions. First condition corresponds to IF clause,
 * while others (if any) correspond to following ELSE IF clauses.
 * @param conditionalBodies Collection of executable bodies that have a corresponding condition,
 * in IF or ELSE IF branches.
 * @param elseBody Body that is executed if none of the conditions are met,
 * i.e. ELSE branch.
 * @param session Spark session that SQL script is executed within.
 */
class IfElseStatementExec(
conditions: Seq[SingleStatementExec],
conditionalBodies: Seq[CompoundBodyExec],
elseBody: Option[CompoundBodyExec],
session: SparkSession) extends NonLeafStatementExec {
// Phase of the iteration: still evaluating conditions, or emitting the selected body.
private object IfElseState extends Enumeration {
val Condition, Body = Value
}

// Iteration starts at the IF condition.
private var state = IfElseState.Condition
// Node that the next call to `next()` will process; None once iteration is finished.
private var curr: Option[CompoundStatementExec] = Some(conditions.head)

// Index of the IF / ELSE IF clause whose condition is evaluated next; also used to pick
// the matching entry of `conditionalBodies`.
private var clauseIdx: Int = 0
private val conditionsCount = conditions.length
// Conditions and conditional bodies are paired by index, so the counts must match.
assert(conditionsCount == conditionalBodies.length)

private lazy val treeIterator: Iterator[CompoundStatementExec] =
new Iterator[CompoundStatementExec] {
// There is more to emit for as long as a current node is set.
override def hasNext: Boolean = curr.nonEmpty

override def next(): CompoundStatementExec = state match {
case IfElseState.Condition =>
assert(curr.get.isInstanceOf[SingleStatementExec])
val condition = curr.get.asInstanceOf[SingleStatementExec]
// state/curr are advanced before returning, so the following call continues with
// either the next condition, the selected body, or nothing at all.
if (evaluateBooleanCondition(session, condition)) {
state = IfElseState.Body
curr = Some(conditionalBodies(clauseIdx))
} else {
clauseIdx += 1
if (clauseIdx < conditionsCount) {
// There are ELSE IF clauses remaining.
state = IfElseState.Condition
curr = Some(conditions(clauseIdx))
} else if (elseBody.isDefined) {
// ELSE clause exists.
state = IfElseState.Body
curr = Some(elseBody.get)
} else {
// No remaining clauses.
curr = None
}
}
// The condition statement that was just evaluated is what this call emits.
condition
case IfElseState.Body =>
assert(curr.get.isInstanceOf[CompoundBodyExec])
val currBody = curr.get.asInstanceOf[CompoundBodyExec]
// NOTE(review): next() is called on the body's iterator without checking its hasNext
// first - confirm that a selected body can never have an empty iterator.
val retStmt = currBody.getTreeIterator.next()
// Finish the whole IF/ELSE iteration once the selected body has nothing more to emit.
if (!currBody.getTreeIterator.hasNext) {
curr = None
}
retStmt
}
}

override def getTreeIterator: Iterator[CompoundStatementExec] = treeIterator

// Restores the initial iteration state (first condition, Condition phase) and resets every
// condition and body.
override def reset(): Unit = {
state = IfElseState.Condition
curr = Some(conditions.head)
clauseIdx = 0
conditions.foreach(c => c.reset())
conditionalBodies.foreach(b => b.reset())
elseBody.foreach(b => b.reset())
}
}
Loading

0 comments on commit f99291a

Please sign in to comment.