Skip to content

Commit

Permalink
[SPARK-45009][SQL] Decorrelate predicate subqueries in join condition
Browse files Browse the repository at this point in the history
### What changes were proposed in this pull request?

This PR pulls up correlated subquery predicates found in join conditions and rewrites them into ExistenceJoins when they are not pushed down into the join inputs.

### Why are the changes needed?

This change allows correlated IN and EXISTS subqueries in join conditions. This is valid SQL that is not yet supported by Spark SQL.

### Does this PR introduce _any_ user-facing change?

Yes, previously unsupported queries become supported.

### How was this patch tested?

Added SQL tests for IN and EXISTS in join conditions, and cross-checked correctness against Postgres (except for ANTI joins, which Postgres does not support).

Permutations of the tests:
1. Exists / Not exists / in / not in
2. Subquery references left child / right child
3. Join type: inner / left outer
4. Transitive predicates to try invoking filter inference

Closes #42725 from andylam-db/correlated-subquery-in-join-cond.

Authored-by: Andy Lam <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>
  • Loading branch information
andylam-db authored and cloud-fan committed Oct 16, 2023
1 parent d8dbb66 commit 4fd2d68
Show file tree
Hide file tree
Showing 13 changed files with 3,203 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,8 @@ abstract class Optimizer(catalogManager: CatalogManager)
Set(
"PartitionPruning",
"RewriteSubquery",
"Extract Python UDFs")
"Extract Python UDFs",
"Infer Filters")

protected def fixedPoint =
FixedPoint(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ import org.apache.spark.sql.catalyst.rules._
import org.apache.spark.sql.catalyst.trees.TreePattern.{EXISTS_SUBQUERY, IN_SUBQUERY, LATERAL_JOIN, LIST_SUBQUERY, PLAN_EXPRESSION, SCALAR_SUBQUERY}
import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.internal.SQLConf.{DECORRELATE_PREDICATE_SUBQUERIES_IN_JOIN_CONDITION, OPTIMIZE_UNCORRELATED_IN_SUBQUERIES_IN_JOIN_CONDITION}
import org.apache.spark.sql.types._
import org.apache.spark.util.Utils

Expand Down Expand Up @@ -176,6 +177,71 @@ object RewritePredicateSubquery extends Rule[LogicalPlan] with PredicateHelper {
Project(p.output, Filter(newCond.get, inputPlan))
}

// Handles predicate subqueries that are still sitting in a join condition, i.e. the ones that
// [[PushDownPredicates]] did not push down into the join's children.
case j: Join if j.condition.exists(SubqueryExpression.hasInOrCorrelatedExistsSubquery) &&
    conf.getConf(DECORRELATE_PREDICATE_SUBQUERIES_IN_JOIN_CONDITION) =>

  val joinCondition = j.condition.get
  val rewriteUncorrelatedIn =
    conf.getConf(OPTIMIZE_UNCORRELATED_IN_SUBQUERIES_IN_JOIN_CONDITION)

  // Candidates for rewriting: every correlated IN / EXISTS, plus uncorrelated IN subqueries
  // when the corresponding optimization flag is on.
  val candidates = joinCondition.collect {
    case inSubquery: InSubquery if inSubquery.query.isCorrelated || rewriteUncorrelatedIn =>
      inSubquery
    case exists: Exists if exists.isCorrelated =>
      exists
  }

  // Tag each candidate with two flags:
  //   _2: does it reference attributes produced by the left join input?
  //   _3: does it reference attributes produced by the right join input?
  val tagged = candidates.map { subquery =>
    val refs = subquery.references
    val touchesLeft = refs.intersect(j.left.outputSet).nonEmpty
    val touchesRight = refs.intersect(j.right.outputSet).nonEmpty
    (subquery, touchesLeft, touchesRight)
  }

  // Subqueries that reference both join inputs are not supported yet.
  val onBothSides = tagged.collect { case (subquery, true, true) => subquery }
  if (onBothSides.nonEmpty) {
    throw QueryCompilationErrors.unsupportedCorrelatedSubqueryInJoinConditionError(onBothSides)
  }

  val leftSubqueries = tagged.collect { case (subquery, true, _) => subquery }
  val rightSubqueries = tagged.collect { case (subquery, _, true) => subquery }

  if (leftSubqueries.isEmpty && rightSubqueries.isEmpty) {
    // Nothing to rewrite (this also covers the case where no candidate was collected).
    j
  } else {
    // Rewrites `subqueries` one at a time against `input`. Each step replaces the subquery
    // expression inside `condition` with the expression returned by `rewriteExistentialExpr`
    // and continues with the plan it returned.
    def rewriteAgainst(
        input: LogicalPlan,
        subqueries: Seq[Expression],
        condition: Expression): (LogicalPlan, Expression) = {
      subqueries.foldLeft((input, condition)) {
        case ((plan, cond), subquery) =>
          val (replacement, rewrittenPlan) = rewriteExistentialExpr(Seq(subquery), plan)
          val updatedCond = cond.transform {
            case expr if expr.fastEquals(subquery) => replacement.get
          }
          (rewrittenPlan, updatedCond)
      }
    }

    // Left input first, then right, carrying the partially rewritten condition across.
    val (newLeft, conditionAfterLeft) = rewriteAgainst(j.left, leftSubqueries, joinCondition)
    val (newRight, newCondition) = rewriteAgainst(j.right, rightSubqueries, conditionAfterLeft)

    // The rewritten inputs can expose extra `exists` columns coming from the new existence
    // joins; project back to the original join output to drop them.
    val rewrittenJoin = j.copy(
      left = newLeft,
      right = newRight,
      condition = Some(newCondition))
    Project(j.output, rewrittenJoin)
  }

case u: UnaryNode if u.expressions.exists(
SubqueryExpression.hasInOrCorrelatedExistsSubquery) =>
var newChild = u.child
Expand Down Expand Up @@ -410,6 +476,8 @@ object PullupCorrelatedPredicates extends Rule[LogicalPlan] with PredicateHelper
} else {
newPlan
}
// Joins are passed to `rewriteSubQueries` only when decorrelation of predicate subqueries in
// join conditions is enabled; with the flag off this case does not match and the join falls
// through to the cases below.
case j: Join if conf.getConf(DECORRELATE_PREDICATE_SUBQUERIES_IN_JOIN_CONDITION) =>
rewriteSubQueries(j)
// Only a few unary nodes (Project/Filter/Aggregate) can contain subqueries.
case q: UnaryNode =>
rewriteSubQueries(q)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2099,6 +2099,15 @@ private[sql] object QueryCompilationErrors extends QueryErrorsBase with Compilat
messageParameters = Map("expr" -> expr.sql, "dataType" -> dataType.typeName))
}

/**
 * Builds the error reported for predicate subqueries in a join condition that cannot be
 * rewritten.
 *
 * @param unsupportedSubqueryExpressions the offending subquery expressions; their SQL text is
 *                                       joined with ", " into the `subqueryExpression`
 *                                       message parameter
 * @return an [[AnalysisException]] with error class
 *         UNSUPPORTED_SUBQUERY_EXPRESSION_CATEGORY.
 *         UNSUPPORTED_CORRELATED_EXPRESSION_IN_JOIN_CONDITION
 */
def unsupportedCorrelatedSubqueryInJoinConditionError(
    unsupportedSubqueryExpressions: Seq[Expression]): Throwable = {
  val renderedSubqueries = unsupportedSubqueryExpressions.map(_.sql).mkString(", ")
  val params = Map("subqueryExpression" -> renderedSubqueries)
  new AnalysisException(
    errorClass = "UNSUPPORTED_SUBQUERY_EXPRESSION_CATEGORY." +
      "UNSUPPORTED_CORRELATED_EXPRESSION_IN_JOIN_CONDITION",
    messageParameters = params)
}

def functionCannotProcessInputError(
unbound: UnboundFunction,
arguments: Seq[Expression],
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4455,6 +4455,25 @@ object SQLConf {
.booleanConf
.createWithDefault(true)

// Internal flag (default: true) that gates decorrelation of IN / EXISTS predicate subqueries
// appearing in join conditions.
// NOTE(review): the key ends in "InJoinPredicate" while this constant's name and the sibling
// OPTIMIZE_UNCORRELATED_IN_SUBQUERIES_IN_JOIN_CONDITION key both say "InJoinCondition" —
// confirm the naming mismatch is intentional; the key itself must not change once released.
val DECORRELATE_PREDICATE_SUBQUERIES_IN_JOIN_CONDITION =
buildConf("spark.sql.optimizer.decorrelatePredicateSubqueriesInJoinPredicate.enabled")
.internal()
.doc("Decorrelate predicate (in and exists) subqueries with correlated references in join " +
"predicates.")
.version("4.0.0")
.booleanConf
.createWithDefault(true)

// Internal flag (default: true) controlling whether uncorrelated IN subqueries in join
// conditions are also rewritten to joins, in addition to the correlated ones.
// The doc string embeds the key of LEGACY_NULL_IN_EMPTY_LIST_BEHAVIOR because the two
// settings interact when IN predicates are rewritten.
val OPTIMIZE_UNCORRELATED_IN_SUBQUERIES_IN_JOIN_CONDITION =
buildConf("spark.sql.optimizer.optimizeUncorrelatedInSubqueriesInJoinCondition.enabled")
.internal()
.doc("When true, optimize uncorrelated IN subqueries in join predicates by rewriting them " +
s"to joins. This interacts with ${LEGACY_NULL_IN_EMPTY_LIST_BEHAVIOR.key} because it " +
"can rewrite IN predicates.")
.version("4.0.0")
.booleanConf
.createWithDefault(true)

val LEGACY_PERCENTILE_DISC_CALCULATION = buildConf("spark.sql.legacy.percentileDiscCalculation")
.internal()
.doc("If true, the old bogus percentile_disc calculation is used. The old calculation " +
Expand Down
Loading

0 comments on commit 4fd2d68

Please sign in to comment.