diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index 23ecb54c095f5..bed7bea61597f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -700,7 +700,7 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor
         f.copy(condition = newCond)
 
       // We should make sure all [[SortOrder]]s have been resolved.
-      case s @ Sort(order, _, child)
+      case s @ Sort(order, _, child, _)
         if order.exists(hasGroupingFunction) && order.forall(_.resolved) =>
         val groupingExprs = findGroupingExprs(child)
         val gid = VirtualColumn.groupingIdAttribute
@@ -1815,7 +1815,7 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor
       case p if !p.childrenResolved => p
       // Replace the index with the related attribute for ORDER BY,
       // which is a 1-base position of the projection list.
-      case Sort(orders, global, child)
+      case Sort(orders, global, child, hint)
         if orders.exists(_.child.isInstanceOf[UnresolvedOrdinal]) =>
         val newOrders = orders map {
           case s @ SortOrder(UnresolvedOrdinal(index), direction, nullOrdering, _) =>
@@ -1826,14 +1826,14 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor
           }
           case o => o
         }
-        Sort(newOrders, global, child)
+        Sort(newOrders, global, child, hint)
 
       // Replace the index with the corresponding expression in aggregateExpressions. The index is
       // a 1-base position of aggregateExpressions, which is output columns (select expression)
-      case Aggregate(groups, aggs, child, _) if aggs.forall(_.resolved) &&
+      case Aggregate(groups, aggs, child, hint) if aggs.forall(_.resolved) &&
         groups.exists(containUnresolvedOrdinal) =>
         val newGroups = groups.map(resolveGroupByExpressionOrdinal(_, aggs))
-        Aggregate(newGroups, aggs, child)
+        Aggregate(newGroups, aggs, child, hint)
     }
 
   private def containUnresolvedOrdinal(e: Expression): Boolean = e match {
@@ -2357,7 +2357,7 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor
           Filter(newExprs.head, newChild)
         })
 
-      case s @ Sort(_, _, agg: Aggregate) if agg.resolved && s.order.forall(_.resolved) =>
+      case s @ Sort(_, _, agg: Aggregate, _) if agg.resolved && s.order.forall(_.resolved) =>
         resolveOperatorWithAggregate(s.order.map(_.child), agg, (newExprs, newChild) => {
           val newSortOrder = s.order.zip(newExprs).map {
             case (sortOrder, expr) => sortOrder.copy(child = expr)
@@ -2365,7 +2365,7 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor
           s.copy(order = newSortOrder, child = newChild)
         })
 
-      case s @ Sort(_, _, f @ Filter(cond, agg: Aggregate))
+      case s @ Sort(_, _, f @ Filter(cond, agg: Aggregate), _)
         if agg.resolved && cond.resolved && s.order.forall(_.resolved) =>
         resolveOperatorWithAggregate(s.order.map(_.child), agg, (newExprs, newChild) => {
           val newSortOrder = s.order.zip(newExprs).map {
@@ -3618,10 +3618,10 @@ object CleanupAliases extends Rule[LogicalPlan] with AliasHelper {
       val cleanedAggs = aggs.map(trimNonTopLevelAliases)
       Aggregate(grouping.map(trimAliases), cleanedAggs, child, hint)
 
-    case Window(windowExprs, partitionSpec, orderSpec, child) =>
+    case Window(windowExprs, partitionSpec, orderSpec, child, hint) =>
       val cleanedWindowExprs = windowExprs.map(trimNonTopLevelAliases)
       Window(cleanedWindowExprs, partitionSpec.map(trimAliases),
-        orderSpec.map(trimAliases(_).asInstanceOf[SortOrder]), child)
+        orderSpec.map(trimAliases(_).asInstanceOf[SortOrder]), child, hint)
 
     case CollectMetrics(name, metrics, child, dataframeId) =>
       val cleanedMetrics = metrics.map(trimNonTopLevelAliases)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
index c7e5fa9f2b6c6..586a0312e1507 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
@@ -603,7 +603,7 @@ trait CheckAnalysis extends PredicateHelper with LookupCatalog with QueryErrorsB
         case up: Unpivot if up.canBeCoercioned && !up.valuesTypeCoercioned =>
           throw QueryCompilationErrors.unpivotValueDataTypeMismatchError(up.values.get)
 
-        case Sort(orders, _, _) =>
+        case Sort(orders, _, _, _) =>
           orders.foreach { order =>
             if (!RowOrdering.isOrderable(order.dataType)) {
               order.failAnalysis(
@@ -612,7 +612,7 @@ trait CheckAnalysis extends PredicateHelper with LookupCatalog with QueryErrorsB
            }
          }
 
-        case Window(_, partitionSpec, _, _) =>
+        case Window(_, partitionSpec, _, _, _) =>
          // Both `partitionSpec` and `orderSpec` must be orderable. We only need an extra check
          // for `partitionSpec` here because `orderSpec` has the type check itself.
          partitionSpec.foreach { p =>
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DeduplicateRelations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DeduplicateRelations.scala
index ca5a6eee9bc9d..c1535343d7686 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DeduplicateRelations.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DeduplicateRelations.scala
@@ -423,7 +423,7 @@ object DeduplicateRelations extends Rule[LogicalPlan] {
         newVersion.copyTagsFrom(oldVersion)
         Seq((oldVersion, newVersion))
 
-      case oldVersion @ Window(windowExpressions, _, _, child)
+      case oldVersion @ Window(windowExpressions, _, _, child, _)
         if AttributeSet(windowExpressions.map(_.toAttribute)).intersect(conflictingAttributes)
           .nonEmpty =>
         val newVersion = oldVersion.copy(windowExpressions = newAliases(windowExpressions))
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala
index 69639b69290c2..4f33c26d5c3c3 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala
@@ -484,14 +484,14 @@ object UnsupportedOperationChecker extends Logging {
         case Offset(_, _) =>
           throwError("Offset is not supported on streaming DataFrames/Datasets")
 
-        case Sort(_, _, _) if !containsCompleteData(subPlan) =>
+        case Sort(_, _, _, _) if !containsCompleteData(subPlan) =>
           throwError("Sorting is not supported on streaming DataFrames/Datasets, unless it is on " +
             "aggregated DataFrame/Dataset in Complete output mode")
 
         case Sample(_, _, _, _, child) if child.isStreaming =>
           throwError("Sampling is not supported on streaming DataFrames/Datasets")
 
-        case Window(windowExpression, _, _, child) if child.isStreaming =>
+        case Window(windowExpression, _, _, child, _) if child.isStreaming =>
          val (windowFuncList, columnNameList, windowSpecList) = windowExpression.flatMap { e =>
            e.collect {
              case we: WindowExpression =>
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/DecorrelateInnerQuery.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/DecorrelateInnerQuery.scala
index 9758f37efc2dc..2b97b2621b5be 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/DecorrelateInnerQuery.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/DecorrelateInnerQuery.scala
@@ -662,7 +662,7 @@ object DecorrelateInnerQuery extends PredicateHelper {
       // of limit in that case. This branch is for the case where there's no limit operator
      // above offset.
      val (child, ordering) = input match {
-        case Sort(order, _, child) => (child, order)
+        case Sort(order, _, child, _) => (child, order)
        case _ => (input, Seq())
      }
      val (newChild, joinCond, outerReferenceMap) =
@@ -705,8 +705,8 @@ object DecorrelateInnerQuery extends PredicateHelper {
      //   SELECT T2.a, row_number() OVER (PARTITION BY T2.b ORDER BY T2.c) AS rn FROM T2)
      //   WHERE rn > 2 AND rn <= 2+3
      val (child, ordering, offsetExpr) = input match {
-        case Sort(order, _, child) => (child, order, Literal(0))
-        case Offset(offsetExpr, offsetChild@(Sort(order, _, child))) =>
+        case Sort(order, _, child, _) => (child, order, Literal(0))
+        case Offset(offsetExpr, offsetChild@(Sort(order, _, child, _))) =>
          (child, order, offsetExpr)
        case Offset(offsetExpr, child) =>
          (child, Seq(), offsetExpr)
@@ -754,7 +754,7 @@ object DecorrelateInnerQuery extends PredicateHelper {
          (project, joinCond, outerReferenceMap)
        }
 
-      case w @ Window(projectList, partitionSpec, orderSpec, child) =>
+      case w @ Window(projectList, partitionSpec, orderSpec, child, hint) =>
        val outerReferences = collectOuterReferences(w.expressions)
        assert(outerReferences.isEmpty, s"Correlated column is not allowed in window " +
          s"function: $w")
@@ -770,7 +770,7 @@ object DecorrelateInnerQuery extends PredicateHelper {
 
        val newWindow = Window(newProjectList ++ referencesToAdd,
          partitionSpec = newPartitionSpec ++ referencesToAdd,
-          orderSpec = newOrderSpec, newChild)
+          orderSpec = newOrderSpec, newChild, hint)
        (newWindow, joinCond, outerReferenceMap)
 
      case a @ Aggregate(groupingExpressions, aggregateExpressions, child, _) =>
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/EliminateWindowPartitions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/EliminateWindowPartitions.scala
index e3d1b05443583..ca3ee12a3d7db 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/EliminateWindowPartitions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/EliminateWindowPartitions.scala
@@ -28,7 +28,7 @@ import org.apache.spark.sql.catalyst.trees.TreePattern.{WINDOW, WINDOW_EXPRESSIO
 object EliminateWindowPartitions extends Rule[LogicalPlan] {
   override def apply(plan: LogicalPlan): LogicalPlan = plan.transformWithPruning(
     _.containsPattern(WINDOW), ruleId) {
-    case w @ Window(windowExprs, partitionSpec, _, _) if partitionSpec.exists(_.foldable) =>
+    case w @ Window(windowExprs, partitionSpec, _, _, _) if partitionSpec.exists(_.foldable) =>
       val newWindowExprs = windowExprs.map(_.transformWithPruning(
         _.containsPattern(WINDOW_EXPRESSION)) {
         case windowExpr @ WindowExpression(_, wsd @ WindowSpecDefinition(ps, _, _))
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/InferWindowGroupLimit.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/InferWindowGroupLimit.scala
index f2e99721e9261..46815969e7ece 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/InferWindowGroupLimit.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/InferWindowGroupLimit.scala
@@ -74,7 +74,7 @@ object InferWindowGroupLimit extends Rule[LogicalPlan] with PredicateHelper {
 
     plan.transformWithPruning(_.containsAllPatterns(FILTER, WINDOW), ruleId) {
       case filter @ Filter(condition,
-        window @ Window(windowExpressions, partitionSpec, orderSpec, child))
+        window @ Window(windowExpressions, partitionSpec, orderSpec, child, _))
        if !child.isInstanceOf[WindowGroupLimit] && windowExpressions.forall(isExpandingWindow) &&
          orderSpec.nonEmpty =>
        val limits = windowExpressions.collect {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/LimitPushDownThroughWindow.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/LimitPushDownThroughWindow.scala
index de1368d28168f..c73d6ad8fa956 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/LimitPushDownThroughWindow.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/LimitPushDownThroughWindow.scala
@@ -43,14 +43,14 @@ object LimitPushDownThroughWindow extends Rule[LogicalPlan] {
     _.containsAllPatterns(WINDOW, LIMIT), ruleId) {
     // Adding an extra Limit below WINDOW when the partitionSpec of all window functions is empty.
     case LocalLimit(limitExpr @ IntegerLiteral(limit),
-      window @ Window(windowExpressions, Nil, orderSpec, child))
+      window @ Window(windowExpressions, Nil, orderSpec, child, _))
        if supportsPushdownThroughWindow(windowExpressions) && child.maxRows.forall(_ > limit) &&
          limit < conf.topKSortFallbackThreshold =>
      // Sort is needed here because we need global sort.
      window.copy(child = Limit(limitExpr, Sort(orderSpec, true, child)))
     // There is a Project between LocalLimit and Window if they do not have the same output.
     case LocalLimit(limitExpr @ IntegerLiteral(limit), project @ Project(_,
-      window @ Window(windowExpressions, Nil, orderSpec, child)))
+      window @ Window(windowExpressions, Nil, orderSpec, child, _)))
        if supportsPushdownThroughWindow(windowExpressions) && child.maxRows.forall(_ > limit) &&
          limit < conf.topKSortFallbackThreshold =>
      // Sort is needed here because we need global sort.
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeOneRowPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeOneRowPlan.scala
index 8e066d1cd6340..6f732b2d0f20a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeOneRowPlan.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeOneRowPlan.scala
@@ -46,9 +46,9 @@ object OptimizeOneRowPlan extends Rule[LogicalPlan] {
     val enableForStreaming = conf.getConf(SQLConf.STREAMING_OPTIMIZE_ONE_ROW_PLAN_ENABLED)
 
     plan.transformUpWithPruning(_.containsAnyPattern(SORT, AGGREGATE), ruleId) {
-      case Sort(_, _, child) if child.maxRows.exists(_ <= 1L) &&
+      case Sort(_, _, child, _) if child.maxRows.exists(_ <= 1L) &&
         isChildEligible(child, enableForStreaming) => child
-      case Sort(_, false, child) if child.maxRowsPerPartition.exists(_ <= 1L) &&
+      case Sort(_, false, child, _) if child.maxRowsPerPartition.exists(_ <= 1L) &&
         isChildEligible(child, enableForStreaming) => child
       case agg @ Aggregate(_, _, child, _) if agg.groupOnly && child.maxRows.exists(_ <= 1L) &&
         isChildEligible(child, enableForStreaming) =>
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index 90d9bd5d5d88e..29216523fefc5 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -334,7 +334,7 @@ abstract class Optimizer(catalogManager: CatalogManager)
       return plan
     }
     plan match {
-      case Sort(_, _, child) => child
+      case Sort(_, _, child, _) => child
       case Project(fields, child) => Project(fields, removeTopLevelSort(child))
       case other => other
     }
@@ -1303,7 +1303,7 @@ object CollapseRepartition extends Rule[LogicalPlan] {
     // Case 2: When a RepartitionByExpression has a child of global Sort, Repartition or
     // RepartitionByExpression we can remove the child.
     case r @ RepartitionByExpression(
-      _, child @ (Sort(_, true, _) | _: RepartitionOperation), _, _) =>
+      _, child @ (Sort(_, true, _, _) | _: RepartitionOperation), _, _) =>
       r.withNewChildren(child.children)
     // Case 3: When a RebalancePartitions has a child of local or global Sort, Repartition or
     // RepartitionByExpression we can remove the child.
@@ -1370,11 +1370,11 @@ object CollapseWindow extends Rule[LogicalPlan] {
 
   def apply(plan: LogicalPlan): LogicalPlan = plan.transformUpWithPruning(
     _.containsPattern(WINDOW), ruleId) {
-    case w1 @ Window(we1, _, _, w2 @ Window(we2, _, _, grandChild))
+    case w1 @ Window(we1, _, _, w2 @ Window(we2, _, _, grandChild, _), _)
         if windowsCompatible(w1, w2) =>
       w1.copy(windowExpressions = we2 ++ we1, child = grandChild)
 
-    case w1 @ Window(we1, _, _, Project(pl, w2 @ Window(we2, _, _, grandChild)))
+    case w1 @ Window(we1, _, _, Project(pl, w2 @ Window(we2, _, _, grandChild, _)), _)
         if windowsCompatible(w1, w2) && w1.references.subsetOf(grandChild.outputSet) =>
       Project(
         pl ++ w1.windowOutputSet,
@@ -1403,11 +1403,11 @@ object TransposeWindow extends Rule[LogicalPlan] {
 
   def apply(plan: LogicalPlan): LogicalPlan = plan.transformUpWithPruning(
     _.containsPattern(WINDOW), ruleId) {
-    case w1 @ Window(_, _, _, w2 @ Window(_, _, _, grandChild))
+    case w1 @ Window(_, _, _, w2 @ Window(_, _, _, grandChild, _), _)
         if windowsCompatible(w1, w2) =>
       Project(w1.output, w2.copy(child = w1.copy(child = grandChild)))
 
-    case w1 @ Window(_, _, _, Project(pl, w2 @ Window(_, _, _, grandChild)))
+    case w1 @ Window(_, _, _, Project(pl, w2 @ Window(_, _, _, grandChild, _)), _)
         if windowsCompatible(w1, w2) && w1.references.subsetOf(grandChild.outputSet) =>
       Project(
         pl ++ w1.windowOutputSet,
@@ -1649,14 +1649,14 @@ object CombineFilters extends Rule[LogicalPlan] with PredicateHelper {
  */
 object EliminateSorts extends Rule[LogicalPlan] {
   def apply(plan: LogicalPlan): LogicalPlan = plan.transformUpWithPruning(_.containsPattern(SORT)) {
-    case s @ Sort(orders, _, child) if orders.isEmpty || orders.exists(_.child.foldable) =>
+    case s @ Sort(orders, _, child, _) if orders.isEmpty || orders.exists(_.child.foldable) =>
       val newOrders = orders.filterNot(_.child.foldable)
       if (newOrders.isEmpty) {
         child
       } else {
         s.copy(order = newOrders)
       }
-    case s @ Sort(_, global, child) => s.copy(child = recursiveRemoveSort(child, global))
+    case s @ Sort(_, global, child, _) => s.copy(child = recursiveRemoveSort(child, global))
     case j @ Join(originLeft, originRight, _, cond, _) if cond.forall(_.deterministic) =>
       j.copy(left = recursiveRemoveSort(originLeft, true),
         right = recursiveRemoveSort(originRight, true))
@@ -1675,7 +1675,7 @@ object EliminateSorts extends Rule[LogicalPlan] {
       return plan
     }
     plan match {
-      case Sort(_, global, child) if canRemoveGlobalSort || !global =>
+      case Sort(_, global, child, _) if canRemoveGlobalSort || !global =>
         recursiveRemoveSort(child, canRemoveGlobalSort)
       case other if canEliminateSort(other) =>
         other.withNewChildren(other.children.map(c => recursiveRemoveSort(c, canRemoveGlobalSort)))
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RemoveRedundantSorts.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RemoveRedundantSorts.scala
index 204d2a34675bc..3923b9b1b7fae 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RemoveRedundantSorts.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RemoveRedundantSorts.scala
@@ -36,14 +36,14 @@ object RemoveRedundantSorts extends Rule[LogicalPlan] {
       return plan
     }
     plan match {
-      case s @ Sort(orders, false, child) =>
+      case s @ Sort(orders, false, child, _) =>
         if (SortOrder.orderingSatisfies(child.outputOrdering, orders)) {
           recursiveRemoveSort(child, optimizeGlobalSort = false)
         } else {
           s.withNewChildren(Seq(recursiveRemoveSort(child, optimizeGlobalSort = true)))
         }
 
-      case s @ Sort(orders, true, child) =>
+      case s @ Sort(orders, true, child, _) =>
         val newChild = recursiveRemoveSort(child, optimizeGlobalSort = false)
         if (optimizeGlobalSort) {
           // For this case, the upper sort is local so the ordering of present sort is unnecessary,
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala
index a666b977030e6..d0a0fc307756c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala
@@ -364,7 +364,7 @@ object PhysicalWindow {
     (WindowFunctionType, Seq[NamedExpression], Seq[Expression], Seq[SortOrder], LogicalPlan)
 
   def unapply(a: Any): Option[ReturnType] = a match {
-    case expr @ logical.Window(windowExpressions, partitionSpec, orderSpec, child) =>
+    case expr @ logical.Window(windowExpressions, partitionSpec, orderSpec, child, _) =>
 
       // The window expression should not be empty here, otherwise it's a bug.
       if (windowExpressions.isEmpty) {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
index dc286183ac689..0cb04064a6178 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
@@ -946,7 +946,8 @@ case class WithWindowDefinition(
 case class Sort(
     order: Seq[SortOrder],
     global: Boolean,
-    child: LogicalPlan) extends UnaryNode {
+    child: LogicalPlan,
+    hint: Option[SortHint] = None) extends UnaryNode {
   override def output: Seq[Attribute] = child.output
   override def maxRows: Option[Long] = child.maxRows
   override def maxRowsPerPartition: Option[Long] = {
@@ -1266,7 +1267,8 @@ case class Window(
     windowExpressions: Seq[NamedExpression],
     partitionSpec: Seq[Expression],
     orderSpec: Seq[SortOrder],
-    child: LogicalPlan) extends UnaryNode {
+    child: LogicalPlan,
+    hint: Option[WindowHint] = None) extends UnaryNode {
   override def maxRows: Option[Long] = child.maxRows
 
   override def output: Seq[Attribute] = child.output ++ windowExpressions.map(_.toAttribute)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/hints.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/hints.scala
index 82260755977f0..c8d2be2987457 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/hints.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/hints.scala
@@ -199,6 +199,10 @@ case object NO_BROADCAST_AND_REPLICATION extends JoinStrategyHint {
 
 abstract class AggregateHint;
 
+abstract class WindowHint;
+
+abstract class SortHint;
+
 /**
  * The callback for implementing customized strategies of handling hint errors.
 */
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
index 96f3ddb72f054..22082aca81a22 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -106,28 +106,28 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
 
   private def planTakeOrdered(plan: LogicalPlan): Option[SparkPlan] = plan match {
     // We should match the combination of limit and offset first, to get the optimal physical
     // plan, instead of planning limit and offset separately.
-    case LimitAndOffset(limit, offset, Sort(order, true, child))
+    case LimitAndOffset(limit, offset, Sort(order, true, child, _))
         if limit < conf.topKSortFallbackThreshold =>
       Some(TakeOrderedAndProjectExec(
         limit, order, child.output, planLater(child), offset))
-    case LimitAndOffset(limit, offset, Project(projectList, Sort(order, true, child)))
+    case LimitAndOffset(limit, offset, Project(projectList, Sort(order, true, child, _)))
         if limit < conf.topKSortFallbackThreshold =>
       Some(TakeOrderedAndProjectExec(
         limit, order, projectList, planLater(child), offset))
     // 'Offset a' then 'Limit b' is the same as 'Limit a + b' then 'Offset a'.
-    case OffsetAndLimit(offset, limit, Sort(order, true, child))
+    case OffsetAndLimit(offset, limit, Sort(order, true, child, _))
         if offset + limit < conf.topKSortFallbackThreshold =>
       Some(TakeOrderedAndProjectExec(
         offset + limit, order, child.output, planLater(child), offset))
-    case OffsetAndLimit(offset, limit, Project(projectList, Sort(order, true, child)))
+    case OffsetAndLimit(offset, limit, Project(projectList, Sort(order, true, child, _)))
        if offset + limit < conf.topKSortFallbackThreshold =>
      Some(TakeOrderedAndProjectExec(
        offset + limit, order, projectList, planLater(child), offset))
-    case Limit(IntegerLiteral(limit), Sort(order, true, child))
+    case Limit(IntegerLiteral(limit), Sort(order, true, child, _))
        if limit < conf.topKSortFallbackThreshold =>
      Some(TakeOrderedAndProjectExec(
        limit, order, child.output, planLater(child)))
-    case Limit(IntegerLiteral(limit), Project(projectList, Sort(order, true, child)))
+    case Limit(IntegerLiteral(limit), Project(projectList, Sort(order, true, child, _)))
       if limit < conf.topKSortFallbackThreshold =>
       Some(TakeOrderedAndProjectExec(
         limit, order, projectList, planLater(child)))
@@ -978,7 +978,7 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
         } else {
           execution.CoalesceExec(numPartitions, planLater(child)) :: Nil
         }
-      case logical.Sort(sortExprs, global, child) =>
+      case logical.Sort(sortExprs, global, child, _) =>
         execution.SortExec(sortExprs, global, planLater(child)) :: Nil
       case logical.Project(projectList, child) =>
         execution.ProjectExec(projectList, planLater(child)) :: Nil
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala
index 2b6fcd9d547f1..23b2647b62a19 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala
@@ -420,7 +420,8 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper {
         sHolder.pushedLimit = Some(limit)
       }
       (operation, isPushed && !isPartiallyPushed)
-    case s @ Sort(order, _, operation @ PhysicalOperation(project, Nil, sHolder: ScanBuilderHolder))
+    case s @ Sort(order, _,
+        operation @ PhysicalOperation(project, Nil, sHolder: ScanBuilderHolder), _)
       if CollapseProject.canCollapseExpressions(order, project, alwaysInline = true) =>
       val aliasMap = getAliasMap(project)
       val aliasReplacedOrder = order.map(replaceAlias(_, aliasMap))
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestHelper.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestHelper.scala
index 0872efd92002c..7daf2c6b1b58b 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestHelper.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestHelper.scala
@@ -114,7 +114,7 @@ trait SQLQueryTestHelper extends Logging {
            | _: DescribeColumnCommand
            | _: DescribeRelation
            | _: DescribeColumn => true
-      case PhysicalOperation(_, _, Sort(_, true, _)) => true
+      case PhysicalOperation(_, _, Sort(_, true, _, _)) => true
       case _ => plan.children.iterator.exists(isSemanticallySorted)
     }
 
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionExtensionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionExtensionSuite.scala
index 5ec557462de10..8750c398cc942 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionExtensionSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionExtensionSuite.scala
@@ -32,7 +32,7 @@ import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, Final, Max, Partial}
 import org.apache.spark.sql.catalyst.parser.{CatalystSqlParser, CompoundBody, ParserInterface}
 import org.apache.spark.sql.catalyst.plans.{PlanTest, SQLHelper}
-import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, AggregateHint, ColumnStat, Limit, LocalRelation, LogicalPlan, Statistics, UnresolvedHint}
+import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, AggregateHint, ColumnStat, Limit, LocalRelation, LogicalPlan, Sort, SortHint, Statistics, UnresolvedHint}
 import org.apache.spark.sql.catalyst.plans.physical.{Partitioning, SinglePartition}
 import org.apache.spark.sql.catalyst.rules.Rule
 import org.apache.spark.sql.catalyst.trees.TreeNodeTag
@@ -559,6 +559,17 @@ class SparkSessionExtensionSuite extends SparkFunSuite with SQLHelper with Adapt
       compareExpressions(expectedAlias, res.asInstanceOf[Aggregate].aggregateExpressions.head)
     }
   }
+
+  test("custom sort hint") {
+    // The custom hint allows us to replace the sort with its input
+    withSession(Seq(_.injectHintResolutionRule(CustomerSortHintResolutionRule),
+        _.injectOptimizerRule(CustomSortRule))) { session =>
+      val res = session.range(10).sort("id")
+        .hint("INPUT_SORTED")
+        .queryExecution.optimizedPlan
+      assert(res.collect {case s: Sort => s}.isEmpty)
+    }
+  }
 }
 
 case class MyRule(spark: SparkSession) extends Rule[LogicalPlan] {
@@ -1302,3 +1313,27 @@ case class CustomAggregateRule(spark: SparkSession) extends Rule[LogicalPlan] {
     }
   }
 }
+
+// Example of a Sort hint that tells that the input is already sorted,
+// and the rule that removes all Sort nodes based on such hint.
+case class CustomSortHint(inputSorted: Boolean) extends SortHint
+
+// Attaches the CustomSortHint to the sort node.
+case class CustomerSortHintResolutionRule(spark: SparkSession) extends Rule[LogicalPlan] {
+  val MY_HINT_NAME = Set("INPUT_SORTED")
+
+  private def applySortHint(plan: LogicalPlan): LogicalPlan = plan.transformDown {
+    case s @ Sort(_, _, _, None) => s.copy(hint = Some(CustomSortHint(true)))
+  }
+
+  override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators {
+    case h: UnresolvedHint if MY_HINT_NAME.contains(h.name.toUpperCase(Locale.ROOT)) =>
+      applySortHint(h.child)
+  }
+}
+
+case class CustomSortRule(spark: SparkSession) extends Rule[LogicalPlan] {
+  override def apply(plan: LogicalPlan): LogicalPlan = plan transformDown {
+    case s @ Sort(_, _, _, Some(CustomSortHint(true))) => s.child
+  }
+}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonUDTFSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonUDTFSuite.scala
index a6bf95be837da..8c20a40fede72 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonUDTFSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonUDTFSuite.scala
@@ -221,7 +221,7 @@ class PythonUDTFSuite extends QueryTest with SharedSparkSession {
           _, false, RepartitionByExpression(
             _, Project(
               _, SubqueryAlias(
-                _, _: LocalRelation)), _, _)) =>
+                _, _: LocalRelation)), _, _), _) =>
       case other =>
         failure(other)
     }
@@ -235,7 +235,7 @@ class PythonUDTFSuite extends QueryTest with SharedSparkSession {
         case Sort(
           _, false, Repartition(
             1, true, SubqueryAlias(
-              _, _: LocalRelation))) =>
+              _, _: LocalRelation)), _) =>
       case other =>
         failure(other)
     }
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala
index 87e58bb8fa13a..73dda42568a71 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala
@@ -162,7 +162,7 @@ abstract class HiveComparisonTest extends SparkFunSuite with BeforeAndAfterAll {
 
     def isSorted(plan: LogicalPlan): Boolean = plan match {
       case _: Join | _: Aggregate | _: Generate | _: Sample | _: Distinct => false
-      case PhysicalOperation(_, _, Sort(_, true, _)) => true
+      case PhysicalOperation(_, _, Sort(_, true, _, _)) => true
       case _ => plan.children.iterator.exists(isSorted)
     }