Skip to content

Commit

Permalink
[SPARK-46632][SQL] Fix subexpression elimination when equivalent tern…
Browse files Browse the repository at this point in the history
…ary expressions have different children

### What changes were proposed in this pull request?
Remove unexpected exception thrown in `EquivalentExpressions.updateExprInMap()`. Equivalent expressions may contain different children, it should happen expression not in map and `useCount` is -1.
For example, before this PR will throw IllegalStateException
```
Seq((1, 2, 3), (2, 3, 4)).toDF("a", "b", "c")
      .selectExpr("case when a + b + c>3 then 1 when c + a + b>0 then 2 else 0 end as d").show()
```

### Why are the changes needed?
Bug fix.

### Does this PR introduce _any_ user-facing change?
No.

### How was this patch tested?
New unit test, before this PR will throw IllegalStateException: *** with use count: -1

### Was this patch authored or co-authored using generative AI tooling?
No.

Closes apache#46135 from zml1206/SPARK-46632.

Authored-by: zml1206 <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>
  • Loading branch information
zml1206 authored and cloud-fan committed Aug 12, 2024
1 parent f33aa0a commit 2fb8dff
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -80,10 +80,6 @@ class EquivalentExpressions(
case _ =>
if (useCount > 0) {
map.put(wrapper, ExpressionStats(expr)(useCount))
} else {
// Should not happen
throw SparkException.internalError(
s"Cannot update expression: $expr in map: $map with use count: $useCount")
}
false
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -494,6 +494,18 @@ class SubexpressionEliminationSuite extends SparkFunSuite with ExpressionEvalHel
checkShortcut(Or(equal, Literal(true)), 1)
checkShortcut(Not(And(equal, Literal(false))), 1)
}

test("Equivalent ternary expressions have different children") {
val add1 = Add(Add(Literal(1), Literal(2)), Literal(3))
val add2 = Add(Add(Literal(3), Literal(1)), Literal(2))
val conditions1 = (GreaterThan(add1, Literal(3)), Literal(1)) ::
(GreaterThan(add2, Literal(0)), Literal(2)) :: Nil

val caseWhenExpr1 = CaseWhen(conditions1, Literal(0))
val equivalence1 = new EquivalentExpressions
equivalence1.addExprTree(caseWhenExpr1)
assert(equivalence1.getCommonSubexpressions.size == 1)
}
}

case class CodegenFallbackExpression(child: Expression)
Expand Down

0 comments on commit 2fb8dff

Please sign in to comment.