Skip to content

Commit

Permalink
[SPARK-49396][SQL] Modify nullability check for CaseWhen expression
Browse files Browse the repository at this point in the history
### What changes were proposed in this pull request?

Previously, the nullability check of CaseWhen checks that
(1) either of the branches including elseValue is nullable or
(2) elseValue is None.

The pr changes this nullability check to add concerns for TrueLiteral branches.
If there are trueLiteral branches, the nullability check will only consider branches before first TrueLiteral branch and the value of TrueLiteral branch as later branches will never be invoked.

### Why are the changes needed?

This nullability check is more accurate and align with SimplifyConditional rule.

### Does this PR introduce _any_ user-facing change?

no

### How was this patch tested?

added unit test

### Was this patch authored or co-authored using generative AI tooling?

no

Closes #47981 from averyqi-db/SPARK-49396-fix-nullability-check.

Authored-by: Avery Qi <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>
  • Loading branch information
averyqi-db authored and cloud-fan committed Sep 6, 2024
1 parent f3d2ebd commit 2c1e69c
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.{FunctionRegistryBase, TypeCheckResult, TypeCoercion}
import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.DataTypeMismatch
import org.apache.spark.sql.catalyst.expressions.Literal.TrueLiteral
import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.expressions.codegen.Block._
import org.apache.spark.sql.catalyst.trees.TernaryLike
Expand Down Expand Up @@ -188,6 +189,13 @@ case class CaseWhen(
}

override def nullable: Boolean = {
if (branches.exists(_._1 == TrueLiteral)) {
// if any of the branch is always true
// nullability check should only be related to branches
// before the TrueLiteral and value of the first TrueLiteral branch
val (h, t) = branches.span(_._1 != TrueLiteral)
return h.exists(_._2.nullable) || t.head._2.nullable
}
// Result is nullable if any of the branch is nullable, or if the else value is nullable
branches.exists(_._2.nullable) || elseValue.map(_.nullable).getOrElse(true)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.expressions
import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.DataTypeMismatch
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.expressions.Literal.{FalseLiteral, TrueLiteral}
import org.apache.spark.sql.catalyst.expressions.codegen.CodegenContext
import org.apache.spark.sql.types._

Expand Down Expand Up @@ -277,4 +278,27 @@ class ConditionalExpressionSuite extends SparkFunSuite with ExpressionEvalHelper
assert(!caseWhenObj1.semanticEquals(caseWhenObj2))
assert(!caseWhenObj2.semanticEquals(caseWhenObj1))
}

test("SPARK-49396 accurate nullability check") {
val trueBranch = (TrueLiteral, Literal(5))
val normalBranch = (NonFoldableLiteral(true), Literal(10))

val nullLiteral = Literal.create(null, BooleanType)
val noElseValue = CaseWhen(normalBranch :: trueBranch :: Nil, None)
assert(!noElseValue.nullable)
val withElseValue = CaseWhen(normalBranch :: trueBranch :: Nil, Some(Literal(1)))
assert(!withElseValue.nullable)
val withNullableElseValue = CaseWhen(normalBranch :: trueBranch :: Nil, Some(nullLiteral))
assert(!withNullableElseValue.nullable)
val firstTrueNonNullableSecondTrueNullable = CaseWhen(trueBranch ::
(TrueLiteral, nullLiteral) :: Nil, None)
assert(!firstTrueNonNullableSecondTrueNullable.nullable)
val firstTrueNullableSecondTrueNonNullable = CaseWhen((TrueLiteral, nullLiteral) ::
trueBranch :: Nil, None)
assert(firstTrueNullableSecondTrueNonNullable.nullable)
val hasNullInNotTrueBranch = CaseWhen(trueBranch :: (FalseLiteral, nullLiteral) :: Nil, None)
assert(!hasNullInNotTrueBranch.nullable)
val noTrueBranch = CaseWhen(normalBranch :: Nil, Literal(1))
assert(!noTrueBranch.nullable)
}
}

0 comments on commit 2c1e69c

Please sign in to comment.