Skip to content

Commit

Permalink
[SPARK-49653][SQL] Single join for correlated scalar subqueries
Browse files Browse the repository at this point in the history
### What changes were proposed in this pull request?

Single join is a left outer join that checks that there is at most 1 build row for every probe row.

This PR adds a single join implementation to support correlated scalar subqueries for which the optimizer can't guarantee that at most 1 row is returned, e.g.:
select *, (select t1.x from t1 where t1.y >= t_outer.y) from t_outer.
-- this subquery is going to be rewritten as a single join that makes sure there is at most 1 matching build row for every probe row; otherwise it raises a Spark runtime error.

Design doc: https://docs.google.com/document/d/1NTsvtBTB9XvvyRvH62QzWIZuw4hXktALUG1fBP7ha1Q/edit

The optimizer introduces a single join in cases that were previously returning incorrect results (or were unsupported).
Only a hash-based implementation is supported; the optimizer makes sure we don't plan a single join as a sort-merge join.

### Why are the changes needed?

Expands our subquery coverage.

### Does this PR introduce _any_ user-facing change?

Yes, previously unsupported scalar subqueries should now work.

### How was this patch tested?

Unit tests for the single join operator. Query tests for the subqueries.

### Was this patch authored or co-authored using generative AI tooling?

No.

Closes #48145 from agubichev/single_join.

Authored-by: Andrey Gubichev <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>
  • Loading branch information
agubichev authored and cloud-fan committed Sep 23, 2024
1 parent 0eeb61f commit 3c81f07
Show file tree
Hide file tree
Showing 25 changed files with 613 additions and 132 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -2716,7 +2716,7 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor
*/
private def resolveSubQueries(plan: LogicalPlan, outer: LogicalPlan): LogicalPlan = {
plan.transformAllExpressionsWithPruning(_.containsPattern(PLAN_EXPRESSION), ruleId) {
case s @ ScalarSubquery(sub, _, exprId, _, _, _) if !sub.resolved =>
case s @ ScalarSubquery(sub, _, exprId, _, _, _, _) if !sub.resolved =>
resolveSubQuery(s, outer)(ScalarSubquery(_, _, exprId))
case e @ Exists(sub, _, exprId, _, _) if !sub.resolved =>
resolveSubQuery(e, outer)(Exists(_, _, exprId))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -952,19 +952,7 @@ trait CheckAnalysis extends PredicateHelper with LookupCatalog with QueryErrorsB
messageParameters = Map.empty)
}

// SPARK-18504/SPARK-18814: Block cases where GROUP BY columns
// are not part of the correlated columns.

// Collect the inner query expressions that are guaranteed to have a single value for each
// outer row. See comment on getCorrelatedEquivalentInnerExpressions.
val correlatedEquivalentExprs = getCorrelatedEquivalentInnerExpressions(query)
// Grouping expressions, except outer refs and constant expressions - grouping by an
// outer ref or a constant is always ok
val groupByExprs =
ExpressionSet(agg.groupingExpressions.filter(x => !x.isInstanceOf[OuterReference] &&
x.references.nonEmpty))
val nonEquivalentGroupByExprs = groupByExprs -- correlatedEquivalentExprs

val nonEquivalentGroupByExprs = nonEquivalentGroupbyCols(query, agg)
val invalidCols = if (!SQLConf.get.getConf(
SQLConf.LEGACY_SCALAR_SUBQUERY_ALLOW_GROUP_BY_NON_EQUALITY_CORRELATED_PREDICATE)) {
nonEquivalentGroupByExprs
Expand Down Expand Up @@ -1044,23 +1032,25 @@ trait CheckAnalysis extends PredicateHelper with LookupCatalog with QueryErrorsB
checkOuterReference(plan, expr)

expr match {
case ScalarSubquery(query, outerAttrs, _, _, _, _) =>
case ScalarSubquery(query, outerAttrs, _, _, _, _, _) =>
// Scalar subquery must return one column as output.
if (query.output.size != 1) {
throw QueryCompilationErrors.subqueryReturnMoreThanOneColumn(query.output.size,
expr.origin)
}

if (outerAttrs.nonEmpty) {
cleanQueryInScalarSubquery(query) match {
case a: Aggregate => checkAggregateInScalarSubquery(outerAttrs, query, a)
case Filter(_, a: Aggregate) => checkAggregateInScalarSubquery(outerAttrs, query, a)
case p: LogicalPlan if p.maxRows.exists(_ <= 1) => // Ok
case other =>
expr.failAnalysis(
errorClass = "UNSUPPORTED_SUBQUERY_EXPRESSION_CATEGORY." +
"MUST_AGGREGATE_CORRELATED_SCALAR_SUBQUERY",
messageParameters = Map.empty)
if (!SQLConf.get.getConf(SQLConf.SCALAR_SUBQUERY_USE_SINGLE_JOIN)) {
cleanQueryInScalarSubquery(query) match {
case a: Aggregate => checkAggregateInScalarSubquery(outerAttrs, query, a)
case Filter(_, a: Aggregate) => checkAggregateInScalarSubquery(outerAttrs, query, a)
case p: LogicalPlan if p.maxRows.exists(_ <= 1) => // Ok
case other =>
expr.failAnalysis(
errorClass = "UNSUPPORTED_SUBQUERY_EXPRESSION_CATEGORY." +
"MUST_AGGREGATE_CORRELATED_SCALAR_SUBQUERY",
messageParameters = Map.empty)
}
}

// Only certain operators are allowed to host subquery expression containing
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -358,6 +358,20 @@ object SubExprUtils extends PredicateHelper {
case _ => ExpressionSet().empty
}
}

/**
 * Computes the grouping expressions of `aggNode` (the aggregate of a scalar subquery) that have
 * no equivalent column in the outer query, i.e. that are not bound to the outer query by an
 * equality predicate such as `col = outer(c)`.
 *
 * Callers use an empty result as evidence that the scalar subquery returns at most 1 row.
 *
 * @param query   plan of the scalar subquery; source of the correlated-equivalent expressions
 * @param aggNode aggregate whose grouping expressions are inspected
 * @return the grouping expressions that remain after removing the correlated-equivalent ones
 */
def nonEquivalentGroupbyCols(query: LogicalPlan, aggNode: Aggregate): ExpressionSet = {
  val equivalentInnerExprs = getCorrelatedEquivalentInnerExpressions(query)
  // Grouping by an outer reference or by a constant (an expression without references) is
  // always ok, so those expressions are never reported.
  val candidateGroupings = aggNode.groupingExpressions.filterNot { expr =>
    expr.isInstanceOf[OuterReference] || expr.references.isEmpty
  }
  ExpressionSet(candidateGroupings) -- equivalentInnerExprs
}
}

/**
Expand All @@ -371,14 +385,20 @@ object SubExprUtils extends PredicateHelper {
* case the subquery yields no row at all on empty input to the GROUP BY, which evaluates to NULL.
* It is set in PullupCorrelatedPredicates to true/false, before it is set its value is None.
* See constructLeftJoins in RewriteCorrelatedScalarSubquery for more details.
*
* 'needSingleJoin' is set to true if we can't guarantee that the correlated scalar subquery
* returns at most 1 row. For such subqueries we use a modification of an outer join called
* LeftSingle join. This value is set in PullupCorrelatedPredicates and used in
* RewriteCorrelatedScalarSubquery.
*/
case class ScalarSubquery(
plan: LogicalPlan,
outerAttrs: Seq[Expression] = Seq.empty,
exprId: ExprId = NamedExpression.newExprId,
joinCond: Seq[Expression] = Seq.empty,
hint: Option[HintInfo] = None,
mayHaveCountBug: Option[Boolean] = None)
mayHaveCountBug: Option[Boolean] = None,
needSingleJoin: Option[Boolean] = None)
extends SubqueryExpression(plan, outerAttrs, exprId, joinCond, hint) with Unevaluable {
override def dataType: DataType = {
if (!plan.schema.fields.nonEmpty) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -338,7 +338,7 @@ abstract class Optimizer(catalogManager: CatalogManager)
case d: DynamicPruningSubquery => d
case s @ ScalarSubquery(
PhysicalOperation(projections, predicates, a @ Aggregate(group, _, child)),
_, _, _, _, mayHaveCountBug)
_, _, _, _, mayHaveCountBug, _)
if conf.getConf(SQLConf.DECORRELATE_SUBQUERY_PREVENT_CONSTANT_FOLDING_FOR_COUNT_BUG) &&
mayHaveCountBug.nonEmpty && mayHaveCountBug.get =>
// This is a subquery with an aggregate that may suffer from a COUNT bug.
Expand Down Expand Up @@ -1988,7 +1988,8 @@ object PushPredicateThroughJoin extends Rule[LogicalPlan] with PredicateHelper {
}

private def canPushThrough(joinType: JoinType): Boolean = joinType match {
case _: InnerLike | LeftSemi | RightOuter | LeftOuter | LeftAnti | ExistenceJoin(_) => true
case _: InnerLike | LeftSemi | RightOuter | LeftOuter | LeftSingle |
LeftAnti | ExistenceJoin(_) => true
case _ => false
}

Expand Down Expand Up @@ -2028,7 +2029,7 @@ object PushPredicateThroughJoin extends Rule[LogicalPlan] with PredicateHelper {

(leftFilterConditions ++ commonFilterCondition).
reduceLeftOption(And).map(Filter(_, newJoin)).getOrElse(newJoin)
case LeftOuter | LeftExistence(_) =>
case LeftOuter | LeftSingle | LeftExistence(_) =>
// push down the left side only `where` condition
val newLeft = leftFilterConditions.
reduceLeftOption(And).map(Filter(_, left)).getOrElse(left)
Expand Down Expand Up @@ -2074,6 +2075,8 @@ object PushPredicateThroughJoin extends Rule[LogicalPlan] with PredicateHelper {
val newJoinCond = (leftJoinConditions ++ commonJoinCondition).reduceLeftOption(And)

Join(newLeft, newRight, joinType, newJoinCond, hint)
// Do not move join predicates of a single join.
case LeftSingle => j

case other =>
throw SparkException.internalError(s"Unexpected join type: $other")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ object ConstantFolding extends Rule[LogicalPlan] {
}

// Don't replace ScalarSubquery if its plan is an aggregate that may suffer from a COUNT bug.
case s @ ScalarSubquery(_, _, _, _, _, mayHaveCountBug)
case s @ ScalarSubquery(_, _, _, _, _, mayHaveCountBug, _)
if conf.getConf(SQLConf.DECORRELATE_SUBQUERY_PREVENT_CONSTANT_FOLDING_FOR_COUNT_BUG) &&
mayHaveCountBug.nonEmpty && mayHaveCountBug.get =>
s
Expand Down Expand Up @@ -1007,7 +1007,7 @@ object FoldablePropagation extends Rule[LogicalPlan] {
replaceFoldable(j.withNewChildren(newChildren).asInstanceOf[Join], foldableMap)
val missDerivedAttrsSet: AttributeSet = AttributeSet(newJoin.joinType match {
case _: InnerLike | LeftExistence(_) => Nil
case LeftOuter => newJoin.right.output
case LeftOuter | LeftSingle => newJoin.right.output
case RightOuter => newJoin.left.output
case FullOuter => newJoin.left.output ++ newJoin.right.output
case _ => Nil
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -339,8 +339,8 @@ trait JoinSelectionHelper extends Logging {
)
}

def getBroadcastNestedLoopJoinBuildSide(hint: JoinHint): Option[BuildSide] = {
if (hintToNotBroadcastAndReplicateLeft(hint)) {
def getBroadcastNestedLoopJoinBuildSide(hint: JoinHint, joinType: JoinType): Option[BuildSide] = {
if (hintToNotBroadcastAndReplicateLeft(hint) || joinType == LeftSingle) {
Some(BuildRight)
} else if (hintToNotBroadcastAndReplicateRight(hint)) {
Some(BuildLeft)
Expand Down Expand Up @@ -375,7 +375,7 @@ trait JoinSelectionHelper extends Logging {

def canBuildBroadcastRight(joinType: JoinType): Boolean = {
joinType match {
case _: InnerLike | LeftOuter | LeftSemi | LeftAnti | _: ExistenceJoin => true
case _: InnerLike | LeftOuter | LeftSingle | LeftSemi | LeftAnti | _: ExistenceJoin => true
case _ => false
}
}
Expand All @@ -389,7 +389,7 @@ trait JoinSelectionHelper extends Logging {

def canBuildShuffledHashJoinRight(joinType: JoinType): Boolean = {
joinType match {
case _: InnerLike | LeftOuter | FullOuter | RightOuter |
case _: InnerLike | LeftOuter | LeftSingle | FullOuter | RightOuter |
LeftSemi | LeftAnti | _: ExistenceJoin => true
case _ => false
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -456,6 +456,31 @@ object PullupCorrelatedPredicates extends Rule[LogicalPlan] with PredicateHelper
(newPlan, newCond)
}

/**
 * Returns true if `query` is guaranteed to return at most 1 row (despite the method name, zero
 * rows also satisfy the guarantee).
 *
 * The guarantee holds when either:
 *  - `query.maxRows` is defined and not greater than 1, or
 *  - `query` is an [[Aggregate]], optionally wrapped in a single Filter (HAVING), Limit or
 *    Project, that contains at least one [[AggregateExpression]] and whose grouping expressions
 *    all have equivalent columns in the outer query (see `nonEquivalentGroupbyCols`).
 *
 * @param query plan of the correlated scalar subquery
 * @return whether at most one row can be produced by `query`
 */
private def guaranteedToReturnOneRow(query: LogicalPlan): Boolean = {
  // Aggregate sitting at the top of the subquery, looking through at most one wrapping operator.
  def topAggregate: Option[Aggregate] = query match {
    case Filter(_, agg: Aggregate) => Some(agg)
    case agg: Aggregate => Some(agg)
    // LIMIT 1 is already covered by the maxRows check, this is for all other types of LIMITs
    case Limit(_, agg: Aggregate) => Some(agg)
    case Project(_, agg: Aggregate) => Some(agg)
    case _ => None
  }

  query.maxRows.exists(_ <= 1) || topAggregate.exists { agg =>
    // An Aggregate without any aggregate function gives no guarantee about the row count.
    val hasAggregateFunction = agg.expressions.exists { expr =>
      expr.collectFirst { case a: AggregateExpression => a }.isDefined
    }
    hasAggregateFunction && nonEquivalentGroupbyCols(query, agg).isEmpty
  }
}

private def rewriteSubQueries(plan: LogicalPlan): LogicalPlan = {
/**
* This function is used as a aid to enforce idempotency of pullUpCorrelatedPredicate rule.
Expand All @@ -481,7 +506,8 @@ object PullupCorrelatedPredicates extends Rule[LogicalPlan] with PredicateHelper
}

plan.transformExpressionsWithPruning(_.containsPattern(PLAN_EXPRESSION)) {
case ScalarSubquery(sub, children, exprId, conditions, hint, mayHaveCountBugOld)
case ScalarSubquery(sub, children, exprId, conditions, hint,
mayHaveCountBugOld, needSingleJoinOld)
if children.nonEmpty =>

def mayHaveCountBugAgg(a: Aggregate): Boolean = {
Expand Down Expand Up @@ -527,8 +553,13 @@ object PullupCorrelatedPredicates extends Rule[LogicalPlan] with PredicateHelper
val (topPart, havingNode, aggNode) = splitSubquery(sub)
(aggNode.isDefined && aggNode.get.groupingExpressions.isEmpty)
}
val needSingleJoin = if (needSingleJoinOld.isDefined) {
needSingleJoinOld.get
} else {
conf.getConf(SQLConf.SCALAR_SUBQUERY_USE_SINGLE_JOIN) && !guaranteedToReturnOneRow(sub)
}
ScalarSubquery(newPlan, children, exprId, getJoinCondition(newCond, conditions),
hint, Some(mayHaveCountBug))
hint, Some(mayHaveCountBug), Some(needSingleJoin))
case Exists(sub, children, exprId, conditions, hint) if children.nonEmpty =>
val (newPlan, newCond) = if (SQLConf.get.decorrelateInnerQueryEnabledForExistsIn) {
decorrelate(sub, plan, handleCountBug = true)
Expand Down Expand Up @@ -786,17 +817,22 @@ object RewriteCorrelatedScalarSubquery extends Rule[LogicalPlan] with AliasHelpe
subqueries: ArrayBuffer[ScalarSubquery]): (LogicalPlan, AttributeMap[Attribute]) = {
val subqueryAttrMapping = ArrayBuffer[(Attribute, Attribute)]()
val newChild = subqueries.foldLeft(child) {
case (currentChild, ScalarSubquery(sub, _, _, conditions, subHint, mayHaveCountBug)) =>
case (currentChild, ScalarSubquery(sub, _, _, conditions, subHint, mayHaveCountBug,
needSingleJoin)) =>
val query = DecorrelateInnerQuery.rewriteDomainJoins(currentChild, sub, conditions)
val origOutput = query.output.head
// The subquery appears on the right side of the join, hence add its hint to the right
// of a join hint
val joinHint = JoinHint(None, subHint)

val resultWithZeroTups = evalSubqueryOnZeroTups(query)
val joinType = needSingleJoin match {
case Some(true) => LeftSingle
case _ => LeftOuter
}
lazy val planWithoutCountBug = Project(
currentChild.output :+ origOutput,
Join(currentChild, query, LeftOuter, conditions.reduceOption(And), joinHint))
Join(currentChild, query, joinType, conditions.reduceOption(And), joinHint))

if (Utils.isTesting) {
assert(mayHaveCountBug.isDefined)
Expand Down Expand Up @@ -845,7 +881,7 @@ object RewriteCorrelatedScalarSubquery extends Rule[LogicalPlan] with AliasHelpe
currentChild.output :+ subqueryResultExpr,
Join(currentChild,
Project(query.output :+ alwaysTrueExpr, query),
LeftOuter, conditions.reduceOption(And), joinHint))
joinType, conditions.reduceOption(And), joinHint))

} else {
// CASE 3: Subquery with HAVING clause. Pull the HAVING clause above the join.
Expand Down Expand Up @@ -877,7 +913,7 @@ object RewriteCorrelatedScalarSubquery extends Rule[LogicalPlan] with AliasHelpe
currentChild.output :+ caseExpr,
Join(currentChild,
Project(subqueryRoot.output :+ alwaysTrueExpr, subqueryRoot),
LeftOuter, conditions.reduceOption(And), joinHint))
joinType, conditions.reduceOption(And), joinHint))
}
}
}
Expand Down Expand Up @@ -1028,7 +1064,7 @@ object OptimizeOneRowRelationSubquery extends Rule[LogicalPlan] {

case p: LogicalPlan => p.transformExpressionsUpWithPruning(
_.containsPattern(SCALAR_SUBQUERY)) {
case s @ ScalarSubquery(OneRowSubquery(p @ Project(_, _: OneRowRelation)), _, _, _, _, _)
case s @ ScalarSubquery(OneRowSubquery(p @ Project(_, _: OneRowRelation)), _, _, _, _, _, _)
if !hasCorrelatedSubquery(s.plan) && s.joinCond.isEmpty =>
assert(p.projectList.size == 1)
stripOuterReferences(p.projectList).head
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,10 @@ case object LeftAnti extends JoinType {
override def sql: String = "LEFT ANTI"
}

/**
 * Join type for a "single join": a left outer join that additionally checks that there is at
 * most one matching build row for every probe row. It is the join type chosen for correlated
 * scalar subqueries whose `needSingleJoin` flag is true, i.e. when the optimizer can't guarantee
 * that the subquery returns at most 1 row.
 * NOTE(review): per the commit description a violation surfaces as a Spark runtime error; the
 * physical operator that performs the check is not part of this view.
 */
case object LeftSingle extends JoinType {
// SQL rendering of this join type.
override def sql: String = "LEFT SINGLE"
}

case class ExistenceJoin(exists: Attribute) extends JoinType {
override def sql: String = {
// This join type is only used in the end of optimizer and physical plans, we will not
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -559,12 +559,12 @@ case class Join(

override def maxRows: Option[Long] = {
joinType match {
case Inner | Cross | FullOuter | LeftOuter | RightOuter
case Inner | Cross | FullOuter | LeftOuter | RightOuter | LeftSingle
if left.maxRows.isDefined && right.maxRows.isDefined =>
val leftMaxRows = BigInt(left.maxRows.get)
val rightMaxRows = BigInt(right.maxRows.get)
val minRows = joinType match {
case LeftOuter => leftMaxRows
case LeftOuter | LeftSingle => leftMaxRows
case RightOuter => rightMaxRows
case FullOuter => leftMaxRows + rightMaxRows
case _ => BigInt(0)
Expand All @@ -590,7 +590,7 @@ case class Join(
left.output :+ j.exists
case LeftExistence(_) =>
left.output
case LeftOuter =>
case LeftOuter | LeftSingle =>
left.output ++ right.output.map(_.withNullability(true))
case RightOuter =>
left.output.map(_.withNullability(true)) ++ right.output
Expand Down Expand Up @@ -627,7 +627,7 @@ case class Join(
left.constraints.union(right.constraints)
case LeftExistence(_) =>
left.constraints
case LeftOuter =>
case LeftOuter | LeftSingle =>
left.constraints
case RightOuter =>
right.constraints
Expand Down Expand Up @@ -659,7 +659,7 @@ case class Join(
var patterns = Seq(JOIN)
joinType match {
case _: InnerLike => patterns = patterns :+ INNER_LIKE_JOIN
case LeftOuter | FullOuter | RightOuter => patterns = patterns :+ OUTER_JOIN
case LeftOuter | FullOuter | RightOuter | LeftSingle => patterns = patterns :+ OUTER_JOIN
case LeftSemiOrAnti(_) => patterns = patterns :+ LEFT_SEMI_OR_ANTI_JOIN
case NaturalJoin(_) | UsingJoin(_, _) => patterns = patterns :+ NATURAL_LIKE_JOIN
case _ =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2477,6 +2477,12 @@ private[sql] object QueryExecutionErrors extends QueryErrorsBase with ExecutionE
summary = getSummary(context))
}

/**
 * Builds the runtime error carrying the `SCALAR_SUBQUERY_TOO_MANY_ROWS` error class.
 * The error class takes no message parameters.
 * NOTE(review): presumably raised by the single join when a probe row matches more than one
 * build row; the call sites are not part of this view.
 *
 * @return a new [[SparkRuntimeException]]; the caller is responsible for throwing it
 */
def scalarSubqueryReturnsMultipleRows(): SparkRuntimeException =
  new SparkRuntimeException(
    errorClass = "SCALAR_SUBQUERY_TOO_MANY_ROWS",
    messageParameters = Map.empty)

def comparatorReturnsNull(firstValue: String, secondValue: String): Throwable = {
new SparkException(
errorClass = "COMPARATOR_RETURNS_NULL",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5090,6 +5090,15 @@ object SQLConf {
.booleanConf
.createWithDefault(true)

// Internal boolean flag, enabled by default, that controls the single-join rewrite of
// correlated scalar subqueries.
// When true, PullupCorrelatedPredicates marks a subquery as needing a single join if it cannot
// show that the subquery returns at most 1 row, and CheckAnalysis no longer rejects correlated
// scalar subqueries with MUST_AGGREGATE_CORRELATED_SCALAR_SUBQUERY.
// When false, that analysis check applies and the subquery is not marked as needing a single join.
val SCALAR_SUBQUERY_USE_SINGLE_JOIN =
buildConf("spark.sql.optimizer.scalarSubqueryUseSingleJoin")
.internal()
.doc("When set to true, use LEFT_SINGLE join for correlated scalar subqueries where " +
"optimizer can't prove that only 1 row will be returned")
.version("4.0.0")
.booleanConf
.createWithDefault(true)

val ALLOW_SUBQUERY_EXPRESSIONS_IN_LAMBDAS_AND_HIGHER_ORDER_FUNCTIONS =
buildConf("spark.sql.analyzer.allowSubqueryExpressionsInLambdasOrHigherOrderFunctions")
.internal()
Expand Down
Loading

0 comments on commit 3c81f07

Please sign in to comment.