Skip to content

Commit

Permalink
[SPARK-50330][SQL] Add hints to Sort and Window nodes
Browse files Browse the repository at this point in the history
### What changes were proposed in this pull request?

Following #48523, this PR adds an optional `hint` field (defaulting to `None`) to the `Sort` and `Window` logical plan nodes.

### Why are the changes needed?

This allows users to attach concrete hints (subclasses of the new abstract `SortHint` and `WindowHint` classes) to Sort and Window nodes.

### Does this PR introduce _any_ user-facing change?

No

### How was this patch tested?

Unit test in SparkSessionExtensionSuite.scala

### Was this patch authored or co-authored using generative AI tooling?

No

Closes #48812 from agubichev/sort_window_hint.

Authored-by: Andrey Gubichev <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>
  • Loading branch information
agubichev authored and cloud-fan committed Nov 18, 2024
1 parent b626528 commit a01856d
Show file tree
Hide file tree
Showing 20 changed files with 94 additions and 52 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -700,7 +700,7 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor
f.copy(condition = newCond)

// We should make sure all [[SortOrder]]s have been resolved.
case s @ Sort(order, _, child)
case s @ Sort(order, _, child, _)
if order.exists(hasGroupingFunction) && order.forall(_.resolved) =>
val groupingExprs = findGroupingExprs(child)
val gid = VirtualColumn.groupingIdAttribute
Expand Down Expand Up @@ -1815,7 +1815,7 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor
case p if !p.childrenResolved => p
// Replace the index with the related attribute for ORDER BY,
// which is a 1-base position of the projection list.
case Sort(orders, global, child)
case Sort(orders, global, child, hint)
if orders.exists(_.child.isInstanceOf[UnresolvedOrdinal]) =>
val newOrders = orders map {
case s @ SortOrder(UnresolvedOrdinal(index), direction, nullOrdering, _) =>
Expand All @@ -1826,14 +1826,14 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor
}
case o => o
}
Sort(newOrders, global, child)
Sort(newOrders, global, child, hint)

// Replace the index with the corresponding expression in aggregateExpressions. The index is
// a 1-base position of aggregateExpressions, which is output columns (select expression)
case Aggregate(groups, aggs, child, _) if aggs.forall(_.resolved) &&
case Aggregate(groups, aggs, child, hint) if aggs.forall(_.resolved) &&
groups.exists(containUnresolvedOrdinal) =>
val newGroups = groups.map(resolveGroupByExpressionOrdinal(_, aggs))
Aggregate(newGroups, aggs, child)
Aggregate(newGroups, aggs, child, hint)
}

private def containUnresolvedOrdinal(e: Expression): Boolean = e match {
Expand Down Expand Up @@ -2357,15 +2357,15 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor
Filter(newExprs.head, newChild)
})

case s @ Sort(_, _, agg: Aggregate) if agg.resolved && s.order.forall(_.resolved) =>
case s @ Sort(_, _, agg: Aggregate, _) if agg.resolved && s.order.forall(_.resolved) =>
resolveOperatorWithAggregate(s.order.map(_.child), agg, (newExprs, newChild) => {
val newSortOrder = s.order.zip(newExprs).map {
case (sortOrder, expr) => sortOrder.copy(child = expr)
}
s.copy(order = newSortOrder, child = newChild)
})

case s @ Sort(_, _, f @ Filter(cond, agg: Aggregate))
case s @ Sort(_, _, f @ Filter(cond, agg: Aggregate), _)
if agg.resolved && cond.resolved && s.order.forall(_.resolved) =>
resolveOperatorWithAggregate(s.order.map(_.child), agg, (newExprs, newChild) => {
val newSortOrder = s.order.zip(newExprs).map {
Expand Down Expand Up @@ -3618,10 +3618,10 @@ object CleanupAliases extends Rule[LogicalPlan] with AliasHelper {
val cleanedAggs = aggs.map(trimNonTopLevelAliases)
Aggregate(grouping.map(trimAliases), cleanedAggs, child, hint)

case Window(windowExprs, partitionSpec, orderSpec, child) =>
case Window(windowExprs, partitionSpec, orderSpec, child, hint) =>
val cleanedWindowExprs = windowExprs.map(trimNonTopLevelAliases)
Window(cleanedWindowExprs, partitionSpec.map(trimAliases),
orderSpec.map(trimAliases(_).asInstanceOf[SortOrder]), child)
orderSpec.map(trimAliases(_).asInstanceOf[SortOrder]), child, hint)

case CollectMetrics(name, metrics, child, dataframeId) =>
val cleanedMetrics = metrics.map(trimNonTopLevelAliases)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -603,7 +603,7 @@ trait CheckAnalysis extends PredicateHelper with LookupCatalog with QueryErrorsB
case up: Unpivot if up.canBeCoercioned && !up.valuesTypeCoercioned =>
throw QueryCompilationErrors.unpivotValueDataTypeMismatchError(up.values.get)

case Sort(orders, _, _) =>
case Sort(orders, _, _, _) =>
orders.foreach { order =>
if (!RowOrdering.isOrderable(order.dataType)) {
order.failAnalysis(
Expand All @@ -612,7 +612,7 @@ trait CheckAnalysis extends PredicateHelper with LookupCatalog with QueryErrorsB
}
}

case Window(_, partitionSpec, _, _) =>
case Window(_, partitionSpec, _, _, _) =>
// Both `partitionSpec` and `orderSpec` must be orderable. We only need an extra check
// for `partitionSpec` here because `orderSpec` has the type check itself.
partitionSpec.foreach { p =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -423,7 +423,7 @@ object DeduplicateRelations extends Rule[LogicalPlan] {
newVersion.copyTagsFrom(oldVersion)
Seq((oldVersion, newVersion))

case oldVersion @ Window(windowExpressions, _, _, child)
case oldVersion @ Window(windowExpressions, _, _, child, _)
if AttributeSet(windowExpressions.map(_.toAttribute)).intersect(conflictingAttributes)
.nonEmpty =>
val newVersion = oldVersion.copy(windowExpressions = newAliases(windowExpressions))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -484,14 +484,14 @@ object UnsupportedOperationChecker extends Logging {

case Offset(_, _) => throwError("Offset is not supported on streaming DataFrames/Datasets")

case Sort(_, _, _) if !containsCompleteData(subPlan) =>
case Sort(_, _, _, _) if !containsCompleteData(subPlan) =>
throwError("Sorting is not supported on streaming DataFrames/Datasets, unless it is on " +
"aggregated DataFrame/Dataset in Complete output mode")

case Sample(_, _, _, _, child) if child.isStreaming =>
throwError("Sampling is not supported on streaming DataFrames/Datasets")

case Window(windowExpression, _, _, child) if child.isStreaming =>
case Window(windowExpression, _, _, child, _) if child.isStreaming =>
val (windowFuncList, columnNameList, windowSpecList) = windowExpression.flatMap { e =>
e.collect {
case we: WindowExpression =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -662,7 +662,7 @@ object DecorrelateInnerQuery extends PredicateHelper {
// of limit in that case. This branch is for the case where there's no limit operator
// above offset.
val (child, ordering) = input match {
case Sort(order, _, child) => (child, order)
case Sort(order, _, child, _) => (child, order)
case _ => (input, Seq())
}
val (newChild, joinCond, outerReferenceMap) =
Expand Down Expand Up @@ -705,8 +705,8 @@ object DecorrelateInnerQuery extends PredicateHelper {
// SELECT T2.a, row_number() OVER (PARTITION BY T2.b ORDER BY T2.c) AS rn FROM T2)
// WHERE rn > 2 AND rn <= 2+3
val (child, ordering, offsetExpr) = input match {
case Sort(order, _, child) => (child, order, Literal(0))
case Offset(offsetExpr, offsetChild@(Sort(order, _, child))) =>
case Sort(order, _, child, _) => (child, order, Literal(0))
case Offset(offsetExpr, offsetChild@(Sort(order, _, child, _))) =>
(child, order, offsetExpr)
case Offset(offsetExpr, child) =>
(child, Seq(), offsetExpr)
Expand Down Expand Up @@ -754,7 +754,7 @@ object DecorrelateInnerQuery extends PredicateHelper {
(project, joinCond, outerReferenceMap)
}

case w @ Window(projectList, partitionSpec, orderSpec, child) =>
case w @ Window(projectList, partitionSpec, orderSpec, child, hint) =>
val outerReferences = collectOuterReferences(w.expressions)
assert(outerReferences.isEmpty, s"Correlated column is not allowed in window " +
s"function: $w")
Expand All @@ -770,7 +770,7 @@ object DecorrelateInnerQuery extends PredicateHelper {

val newWindow = Window(newProjectList ++ referencesToAdd,
partitionSpec = newPartitionSpec ++ referencesToAdd,
orderSpec = newOrderSpec, newChild)
orderSpec = newOrderSpec, newChild, hint)
(newWindow, joinCond, outerReferenceMap)

case a @ Aggregate(groupingExpressions, aggregateExpressions, child, _) =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ import org.apache.spark.sql.catalyst.trees.TreePattern.{WINDOW, WINDOW_EXPRESSIO
object EliminateWindowPartitions extends Rule[LogicalPlan] {
override def apply(plan: LogicalPlan): LogicalPlan = plan.transformWithPruning(
_.containsPattern(WINDOW), ruleId) {
case w @ Window(windowExprs, partitionSpec, _, _) if partitionSpec.exists(_.foldable) =>
case w @ Window(windowExprs, partitionSpec, _, _, _) if partitionSpec.exists(_.foldable) =>
val newWindowExprs = windowExprs.map(_.transformWithPruning(
_.containsPattern(WINDOW_EXPRESSION)) {
case windowExpr @ WindowExpression(_, wsd @ WindowSpecDefinition(ps, _, _))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ object InferWindowGroupLimit extends Rule[LogicalPlan] with PredicateHelper {

plan.transformWithPruning(_.containsAllPatterns(FILTER, WINDOW), ruleId) {
case filter @ Filter(condition,
window @ Window(windowExpressions, partitionSpec, orderSpec, child))
window @ Window(windowExpressions, partitionSpec, orderSpec, child, _))
if !child.isInstanceOf[WindowGroupLimit] && windowExpressions.forall(isExpandingWindow) &&
orderSpec.nonEmpty =>
val limits = windowExpressions.collect {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,14 +43,14 @@ object LimitPushDownThroughWindow extends Rule[LogicalPlan] {
_.containsAllPatterns(WINDOW, LIMIT), ruleId) {
// Adding an extra Limit below WINDOW when the partitionSpec of all window functions is empty.
case LocalLimit(limitExpr @ IntegerLiteral(limit),
window @ Window(windowExpressions, Nil, orderSpec, child))
window @ Window(windowExpressions, Nil, orderSpec, child, _))
if supportsPushdownThroughWindow(windowExpressions) && child.maxRows.forall(_ > limit) &&
limit < conf.topKSortFallbackThreshold =>
// Sort is needed here because we need global sort.
window.copy(child = Limit(limitExpr, Sort(orderSpec, true, child)))
// There is a Project between LocalLimit and Window if they do not have the same output.
case LocalLimit(limitExpr @ IntegerLiteral(limit), project @ Project(_,
window @ Window(windowExpressions, Nil, orderSpec, child)))
window @ Window(windowExpressions, Nil, orderSpec, child, _)))
if supportsPushdownThroughWindow(windowExpressions) && child.maxRows.forall(_ > limit) &&
limit < conf.topKSortFallbackThreshold =>
// Sort is needed here because we need global sort.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,9 +46,9 @@ object OptimizeOneRowPlan extends Rule[LogicalPlan] {
val enableForStreaming = conf.getConf(SQLConf.STREAMING_OPTIMIZE_ONE_ROW_PLAN_ENABLED)

plan.transformUpWithPruning(_.containsAnyPattern(SORT, AGGREGATE), ruleId) {
case Sort(_, _, child) if child.maxRows.exists(_ <= 1L) &&
case Sort(_, _, child, _) if child.maxRows.exists(_ <= 1L) &&
isChildEligible(child, enableForStreaming) => child
case Sort(_, false, child) if child.maxRowsPerPartition.exists(_ <= 1L) &&
case Sort(_, false, child, _) if child.maxRowsPerPartition.exists(_ <= 1L) &&
isChildEligible(child, enableForStreaming) => child
case agg @ Aggregate(_, _, child, _) if agg.groupOnly && child.maxRows.exists(_ <= 1L) &&
isChildEligible(child, enableForStreaming) =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -334,7 +334,7 @@ abstract class Optimizer(catalogManager: CatalogManager)
return plan
}
plan match {
case Sort(_, _, child) => child
case Sort(_, _, child, _) => child
case Project(fields, child) => Project(fields, removeTopLevelSort(child))
case other => other
}
Expand Down Expand Up @@ -1303,7 +1303,7 @@ object CollapseRepartition extends Rule[LogicalPlan] {
// Case 2: When a RepartitionByExpression has a child of global Sort, Repartition or
// RepartitionByExpression we can remove the child.
case r @ RepartitionByExpression(
_, child @ (Sort(_, true, _) | _: RepartitionOperation), _, _) =>
_, child @ (Sort(_, true, _, _) | _: RepartitionOperation), _, _) =>
r.withNewChildren(child.children)
// Case 3: When a RebalancePartitions has a child of local or global Sort, Repartition or
// RepartitionByExpression we can remove the child.
Expand Down Expand Up @@ -1370,11 +1370,11 @@ object CollapseWindow extends Rule[LogicalPlan] {

def apply(plan: LogicalPlan): LogicalPlan = plan.transformUpWithPruning(
_.containsPattern(WINDOW), ruleId) {
case w1 @ Window(we1, _, _, w2 @ Window(we2, _, _, grandChild))
case w1 @ Window(we1, _, _, w2 @ Window(we2, _, _, grandChild, _), _)
if windowsCompatible(w1, w2) =>
w1.copy(windowExpressions = we2 ++ we1, child = grandChild)

case w1 @ Window(we1, _, _, Project(pl, w2 @ Window(we2, _, _, grandChild)))
case w1 @ Window(we1, _, _, Project(pl, w2 @ Window(we2, _, _, grandChild, _)), _)
if windowsCompatible(w1, w2) && w1.references.subsetOf(grandChild.outputSet) =>
Project(
pl ++ w1.windowOutputSet,
Expand Down Expand Up @@ -1403,11 +1403,11 @@ object TransposeWindow extends Rule[LogicalPlan] {

def apply(plan: LogicalPlan): LogicalPlan = plan.transformUpWithPruning(
_.containsPattern(WINDOW), ruleId) {
case w1 @ Window(_, _, _, w2 @ Window(_, _, _, grandChild))
case w1 @ Window(_, _, _, w2 @ Window(_, _, _, grandChild, _), _)
if windowsCompatible(w1, w2) =>
Project(w1.output, w2.copy(child = w1.copy(child = grandChild)))

case w1 @ Window(_, _, _, Project(pl, w2 @ Window(_, _, _, grandChild)))
case w1 @ Window(_, _, _, Project(pl, w2 @ Window(_, _, _, grandChild, _)), _)
if windowsCompatible(w1, w2) && w1.references.subsetOf(grandChild.outputSet) =>
Project(
pl ++ w1.windowOutputSet,
Expand Down Expand Up @@ -1649,14 +1649,14 @@ object CombineFilters extends Rule[LogicalPlan] with PredicateHelper {
*/
object EliminateSorts extends Rule[LogicalPlan] {
def apply(plan: LogicalPlan): LogicalPlan = plan.transformUpWithPruning(_.containsPattern(SORT)) {
case s @ Sort(orders, _, child) if orders.isEmpty || orders.exists(_.child.foldable) =>
case s @ Sort(orders, _, child, _) if orders.isEmpty || orders.exists(_.child.foldable) =>
val newOrders = orders.filterNot(_.child.foldable)
if (newOrders.isEmpty) {
child
} else {
s.copy(order = newOrders)
}
case s @ Sort(_, global, child) => s.copy(child = recursiveRemoveSort(child, global))
case s @ Sort(_, global, child, _) => s.copy(child = recursiveRemoveSort(child, global))
case j @ Join(originLeft, originRight, _, cond, _) if cond.forall(_.deterministic) =>
j.copy(left = recursiveRemoveSort(originLeft, true),
right = recursiveRemoveSort(originRight, true))
Expand All @@ -1675,7 +1675,7 @@ object EliminateSorts extends Rule[LogicalPlan] {
return plan
}
plan match {
case Sort(_, global, child) if canRemoveGlobalSort || !global =>
case Sort(_, global, child, _) if canRemoveGlobalSort || !global =>
recursiveRemoveSort(child, canRemoveGlobalSort)
case other if canEliminateSort(other) =>
other.withNewChildren(other.children.map(c => recursiveRemoveSort(c, canRemoveGlobalSort)))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,14 +36,14 @@ object RemoveRedundantSorts extends Rule[LogicalPlan] {
return plan
}
plan match {
case s @ Sort(orders, false, child) =>
case s @ Sort(orders, false, child, _) =>
if (SortOrder.orderingSatisfies(child.outputOrdering, orders)) {
recursiveRemoveSort(child, optimizeGlobalSort = false)
} else {
s.withNewChildren(Seq(recursiveRemoveSort(child, optimizeGlobalSort = true)))
}

case s @ Sort(orders, true, child) =>
case s @ Sort(orders, true, child, _) =>
val newChild = recursiveRemoveSort(child, optimizeGlobalSort = false)
if (optimizeGlobalSort) {
// For this case, the upper sort is local so the ordering of present sort is unnecessary,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -364,7 +364,7 @@ object PhysicalWindow {
(WindowFunctionType, Seq[NamedExpression], Seq[Expression], Seq[SortOrder], LogicalPlan)

def unapply(a: Any): Option[ReturnType] = a match {
case expr @ logical.Window(windowExpressions, partitionSpec, orderSpec, child) =>
case expr @ logical.Window(windowExpressions, partitionSpec, orderSpec, child, _) =>

// The window expression should not be empty here, otherwise it's a bug.
if (windowExpressions.isEmpty) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -946,7 +946,8 @@ case class WithWindowDefinition(
case class Sort(
order: Seq[SortOrder],
global: Boolean,
child: LogicalPlan) extends UnaryNode {
child: LogicalPlan,
hint: Option[SortHint] = None) extends UnaryNode {
override def output: Seq[Attribute] = child.output
override def maxRows: Option[Long] = child.maxRows
override def maxRowsPerPartition: Option[Long] = {
Expand Down Expand Up @@ -1266,7 +1267,8 @@ case class Window(
windowExpressions: Seq[NamedExpression],
partitionSpec: Seq[Expression],
orderSpec: Seq[SortOrder],
child: LogicalPlan) extends UnaryNode {
child: LogicalPlan,
hint: Option[WindowHint] = None) extends UnaryNode {
override def maxRows: Option[Long] = child.maxRows
override def output: Seq[Attribute] =
child.output ++ windowExpressions.map(_.toAttribute)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -199,6 +199,10 @@ case object NO_BROADCAST_AND_REPLICATION extends JoinStrategyHint {

/**
 * Abstract base class for aggregate hints. It declares no members of its own, so concrete
 * hints must be supplied as subclasses.
 * NOTE(review): presumably this is the type carried by the `hint` field of `Aggregate` nodes;
 * confirm against the `Aggregate` definition.
 */
abstract class AggregateHint;

/**
 * Abstract base class for hints attached to `Window` logical plan nodes through their optional
 * `hint: Option[WindowHint]` parameter. It declares no members of its own, so concrete hints
 * must be supplied as subclasses.
 */
abstract class WindowHint;

/**
 * Abstract base class for hints attached to `Sort` logical plan nodes through their optional
 * `hint: Option[SortHint]` parameter. It declares no members of its own, so concrete hints
 * must be supplied as subclasses.
 */
abstract class SortHint;

/**
* The callback for implementing customized strategies of handling hint errors.
*/
Expand Down
Loading

0 comments on commit a01856d

Please sign in to comment.