Skip to content

Commit

Permalink
[SPARK-48355][SQL] Support for CASE statement
Browse files Browse the repository at this point in the history
### What changes were proposed in this pull request?
Add support for [case statements](https://docs.google.com/document/d/1cpSuR3KxRuTSJ4ZMQ73FJ4_-hjouNNU2zfI4vri6yhs/edit#heading=h.ofijhkunigv) to sql scripting. There are 2 types of case statement - simple and searched (EXAMPLES BELOW). Proposed changes are:

- Add `caseStatement` grammar rule to SqlBaseParser.g4
- Add visit case statement methods to `AstBuilder`
- Add `SearchedCaseStatement` and `SearchedCaseStatementExec` classes, to enable them to be run in sql scripts.

The reason only searched case nodes are added is that, in the current implementation, a simple case is parsed into a searched case, by creating internal `EqualTo` expressions to compare the main case expression to the expressions in the when clauses. This approach is similar to the existing case **expressions**, which are parsed in the same way. The problem with this approach is that the main expression is unnecessarily evaluated N times, where N is the number of when clauses, which can be quite inefficient, for example if the expression is a complex query. Optimally, the main expression would be evaluated once, and then compared to the other expressions. I'm open to suggestions as to what the best approach to achieve this would be.

Simple case compares one expression (case variable) to others, until an equal one is found. Else clause is optional.
```
BEGIN
  CASE 1
    WHEN 1 THEN
      SELECT 1;
    WHEN 2 THEN
      SELECT 2;
    ELSE
      SELECT 3;
  END CASE;
END
```

Searched case evaluates boolean expressions. Else clause is optional.
```
BEGIN
  CASE
    WHEN 1 = 1 THEN
      SELECT 1;
    WHEN 2 IN (1,2,3) THEN
      SELECT 2;
    ELSE
      SELECT 3;
  END CASE;
END
```

### Why are the changes needed?
Case statements are currently not implemented in sql scripting.

### Does this PR introduce _any_ user-facing change?
Yes, users will now be able to use case statements in their sql scripts.

### How was this patch tested?
Tests for both simple and searched case statements are added to SqlScriptingParserSuite, SqlScriptingExecutionNodeSuite and SqlScriptingInterpreterSuite.

### Was this patch authored or co-authored using generative AI tooling?
No

Closes #47672 from dusantism-db/sql-scripting-case-statement.

Authored-by: Dušan Tišma <[email protected]>
Signed-off-by: Max Gekk <[email protected]>
  • Loading branch information
dusantism-db authored and MaxGekk committed Sep 13, 2024
1 parent aa54ed1 commit 5533c81
Show file tree
Hide file tree
Showing 8 changed files with 920 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ compoundStatement
| setStatementWithOptionalVarKeyword
| beginEndCompoundBlock
| ifElseStatement
| caseStatement
| whileStatement
| repeatStatement
| leaveStatement
Expand Down Expand Up @@ -98,6 +99,13 @@ iterateStatement
: ITERATE multipartIdentifier
;

caseStatement
: CASE (WHEN conditions+=booleanExpression THEN conditionalBodies+=compoundBody)+
(ELSE elseBody=compoundBody)? END CASE #searchedCaseStatement
| CASE caseVariable=expression (WHEN conditionExpressions+=expression THEN conditionalBodies+=compoundBody)+
(ELSE elseBody=compoundBody)? END CASE #simpleCaseStatement
;

singleStatement
: (statement|setResetStatement) SEMICOLON* EOF
;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -261,6 +261,52 @@ class AstBuilder extends DataTypeAstBuilder
WhileStatement(condition, body, Some(labelText))
}

override def visitSearchedCaseStatement(ctx: SearchedCaseStatementContext): CaseStatement = {
val conditions = ctx.conditions.asScala.toList.map(boolExpr => withOrigin(boolExpr) {
SingleStatement(
Project(
Seq(Alias(expression(boolExpr), "condition")()),
OneRowRelation()))
})
val conditionalBodies =
ctx.conditionalBodies.asScala.toList.map(body => visitCompoundBody(body))

if (conditions.length != conditionalBodies.length) {
throw SparkException.internalError(
s"Mismatched number of conditions ${conditions.length} and condition bodies" +
s" ${conditionalBodies.length} in case statement")
}

CaseStatement(
conditions = conditions,
conditionalBodies = conditionalBodies,
elseBody = Option(ctx.elseBody).map(body => visitCompoundBody(body)))
}

override def visitSimpleCaseStatement(ctx: SimpleCaseStatementContext): CaseStatement = {
// uses EqualTo to compare the case variable(the main case expression)
// to the WHEN clause expressions
val conditions = ctx.conditionExpressions.asScala.toList.map(expr => withOrigin(expr) {
SingleStatement(
Project(
Seq(Alias(EqualTo(expression(ctx.caseVariable), expression(expr)), "condition")()),
OneRowRelation()))
})
val conditionalBodies =
ctx.conditionalBodies.asScala.toList.map(body => visitCompoundBody(body))

if (conditions.length != conditionalBodies.length) {
throw SparkException.internalError(
s"Mismatched number of conditions ${conditions.length} and condition bodies" +
s" ${conditionalBodies.length} in case statement")
}

CaseStatement(
conditions = conditions,
conditionalBodies = conditionalBodies,
elseBody = Option(ctx.elseBody).map(body => visitCompoundBody(body)))
}

override def visitRepeatStatement(ctx: RepeatStatementContext): RepeatStatement = {
val labelText = generateLabelText(Option(ctx.beginLabel()), Option(ctx.endLabel()))
val boolExpr = ctx.booleanExpression()
Expand Down Expand Up @@ -292,7 +338,7 @@ class AstBuilder extends DataTypeAstBuilder
case c: RepeatStatementContext
if Option(c.beginLabel()).isDefined &&
c.beginLabel().multipartIdentifier().getText.toLowerCase(Locale.ROOT).equals(label)
=> true
=> true
case _ => false
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -124,3 +124,17 @@ case class LeaveStatement(label: String) extends CompoundPlanStatement
* @param label Label of the loop to iterate.
*/
case class IterateStatement(label: String) extends CompoundPlanStatement

/**
* Logical operator for CASE statement.
* @param conditions Collection of conditions which correspond to WHEN clauses.
* @param conditionalBodies Collection of bodies that have a corresponding condition,
* in WHEN branches.
* @param elseBody Body that is executed if none of the conditions are met, i.e. ELSE branch.
*/
case class CaseStatement(
conditions: Seq[SingleStatement],
conditionalBodies: Seq[CompoundBody],
elseBody: Option[CompoundBody]) extends CompoundPlanStatement {
assert(conditions.length == conditionalBodies.length)
}
Loading

0 comments on commit 5533c81

Please sign in to comment.