Skip to content

Commit

Permalink
[SPARK-41631][SQL] Support implicit lateral column alias resolution o…
Browse files Browse the repository at this point in the history
…n Aggregate

### What changes were proposed in this pull request?

This PR implements the implicit lateral column alias on `Aggregate` case. For example,
```sql
-- LCA in Aggregate. The avg_salary references an attribute defined by a previous alias
SELECT dept, average(salary) AS avg_salary, avg_salary + average(bonus)
FROM employee
GROUP BY dept
```

The high level implementation idea is to insert the `Project` node above, and falling back to the resolution of lateral alias of Project code path in the last PR.

* Phase 1: recognize resolved lateral alias, wrap the attributes referencing them with `LateralColumnAliasReference`
* Phase 2: when the `Aggregate` operator is resolved, it goes through the whole aggregation list, extracts the aggregation expressions and grouping expressions to keep them in this `Aggregate` node, and add a `Project` above with the original output. It doesn't do anything on `LateralColumnAliasReference`, but completely leave it to the Project in the future turns of this rule.

Example:
```
 // Before rewrite:
 Aggregate [dept#14] [dept#14 AS a#12, 'a + 1, avg(salary#16) AS b#13, 'b + avg(bonus#17)]
 +- Child [dept#14,name#15,salary#16,bonus#17]

 // After phase 1:
 Aggregate [dept#14] [dept#14 AS a#12, lca(a) + 1, avg(salary#16) AS b#13, lca(b) + avg(bonus#17)]
 +- Child [dept#14,name#15,salary#16,bonus#17]

 // After phase 2:
 Project [dept#14 AS a#12, lca(a) + 1, avg(salary)apache#26 AS b#13, lca(b) + avg(bonus)apache#27]
 +- Aggregate [dept#14] [avg(salary#16) AS avg(salary)apache#26, avg(bonus#17) AS avg(bonus)apache#27, dept#14]
     +- Child [dept#14,name#15,salary#16,bonus#17]

 // Now the problem falls back to the lateral alias resolution in Project.
 // After future rounds of this rule:
 Project [a#12, a#12 + 1, b#13, b#13 + avg(bonus)apache#27]
 +- Project [dept#14 AS a#12, avg(salary)apache#26 AS b#13]
    +- Aggregate [dept#14] [avg(salary#16) AS avg(salary)apache#26, avg(bonus#17) AS avg(bonus)apache#27, dept#14]
       +- Child [dept#14,name#15,salary#16,bonus#17]
```

Similar as the last PR (apache#38776), because lateral column alias has higher resolution priority than outer reference, it will try to resolve an `OuterReference` using lateral column alias, similar as an `UnresolvedAttribute`. If success, it strips `OuterReference` and also wraps it with `LateralColumnAliasReference`.

### Why are the changes needed?
Similar as stated in apache#38776.

### Does this PR introduce _any_ user-facing change?

Yes, as shown in the above example, it will be able to resolve lateral column alias in Aggregate.

### How was this patch tested?

Existing tests and newly added tests.

Closes apache#39040 from anchovYu/SPARK-27561-agg.

Authored-by: Xinyi Yu <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>
  • Loading branch information
anchovYu authored and cloud-fan committed Dec 21, 2022
1 parent 9409465 commit fd6d226
Show file tree
Hide file tree
Showing 7 changed files with 674 additions and 102 deletions.
5 changes: 5 additions & 0 deletions core/src/main/resources/error/error-classes.json
Original file line number Diff line number Diff line change
Expand Up @@ -1339,6 +1339,11 @@
"The target JDBC server does not support transactions and can only support ALTER TABLE with a single action."
]
},
"LATERAL_COLUMN_ALIAS_IN_AGGREGATE_FUNC" : {
"message" : [
"Referencing a lateral column alias <lca> in the aggregate function <aggFunc>."
]
},
"LATERAL_JOIN_USING" : {
"message" : [
"JOIN USING with LATERAL correlation."
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1818,7 +1818,7 @@ class Analyzer(override val catalogManager: CatalogManager)
val aliases = aliasMap.get(u.nameParts.head).get
aliases.size match {
case n if n > 1 =>
throw QueryCompilationErrors.ambiguousLateralColumnAlias(u.name, n)
throw QueryCompilationErrors.ambiguousLateralColumnAliasError(u.name, n)
case n if n == 1 && aliases.head.alias.resolved =>
// Only resolved alias can be the lateral column alias
// The lateral alias can be a struct and have nested field, need to construct
Expand All @@ -1838,7 +1838,7 @@ class Analyzer(override val catalogManager: CatalogManager)
val aliases = aliasMap.get(nameParts.head).get
aliases.size match {
case n if n > 1 =>
throw QueryCompilationErrors.ambiguousLateralColumnAlias(nameParts, n)
throw QueryCompilationErrors.ambiguousLateralColumnAliasError(nameParts, n)
case n if n == 1 && aliases.head.alias.resolved =>
resolveByLateralAlias(nameParts, aliases.head.alias).getOrElse(o)
case _ => o
Expand All @@ -1853,8 +1853,8 @@ class Analyzer(override val catalogManager: CatalogManager)
plan.resolveOperatorsUpWithPruning(
_.containsAnyPattern(UNRESOLVED_ATTRIBUTE, OUTER_REFERENCE), ruleId) {
case p @ Project(projectList, _) if p.childrenResolved
&& !ResolveReferences.containsStar(projectList)
&& projectList.exists(_.containsAnyPattern(UNRESOLVED_ATTRIBUTE, OUTER_REFERENCE)) =>
&& !ResolveReferences.containsStar(projectList)
&& projectList.exists(_.containsAnyPattern(UNRESOLVED_ATTRIBUTE, OUTER_REFERENCE)) =>
var aliasMap = CaseInsensitiveMap(Map[String, Seq[AliasEntry]]())
val newProjectList = projectList.zipWithIndex.map {
case (a: Alias, idx) =>
Expand All @@ -1869,6 +1869,30 @@ class Analyzer(override val catalogManager: CatalogManager)
wrapLCARef(e, p, aliasMap)
}
p.copy(projectList = newProjectList)

// Implementation notes:
// In Aggregate, introducing and wrapping this resolved leaf expression
// LateralColumnAliasReference is especially needed because it needs an accurate condition
// to trigger adding a Project above and extracting and pushing down aggregate functions
// or grouping expressions. Such operation can only be done once. With this
// LateralColumnAliasReference, that condition can simply be when the whole Aggregate is
// resolved. Otherwise, it can't tell if all aggregate functions are created and
// resolved so that it can start the extraction, because the lateral alias reference is
// unresolved and can be the argument to functions, blocking the resolution of functions.
case agg @ Aggregate(_, aggExprs, _) if agg.childrenResolved
&& !ResolveReferences.containsStar(aggExprs)
&& aggExprs.exists(_.containsAnyPattern(UNRESOLVED_ATTRIBUTE, OUTER_REFERENCE)) =>

var aliasMap = CaseInsensitiveMap(Map[String, Seq[AliasEntry]]())
val newAggExprs = aggExprs.zipWithIndex.map {
case (a: Alias, idx) =>
val lcaWrapped = wrapLCARef(a, agg, aliasMap).asInstanceOf[Alias]
aliasMap = insertIntoAliasMap(lcaWrapped, idx, aliasMap)
lcaWrapped
case (e, _) =>
wrapLCARef(e, agg, aliasMap)
}
agg.copy(aggregateExpressions = newAggExprs)
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,14 @@

package org.apache.spark.sql.catalyst.analysis

import org.apache.spark.sql.catalyst.expressions.{Alias, AttributeMap, LateralColumnAliasReference, NamedExpression}
import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Project}
import org.apache.spark.sql.catalyst.expressions.{Alias, AttributeMap, Expression, LateralColumnAliasReference, LeafExpression, Literal, NamedExpression, ScalarSubquery}
import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan, Project}
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.catalyst.trees.TreeNodeTag
import org.apache.spark.sql.catalyst.trees.TreePattern.LATERAL_COLUMN_ALIAS_REFERENCE
import org.apache.spark.sql.catalyst.util.toPrettySQL
import org.apache.spark.sql.errors.QueryCompilationErrors
import org.apache.spark.sql.internal.SQLConf

/**
Expand All @@ -31,30 +34,54 @@ import org.apache.spark.sql.internal.SQLConf
* Plan-wise, it handles two types of operators: Project and Aggregate.
* - in Project, pushing down the referenced lateral alias into a newly created Project, resolve
* the attributes referencing these aliases
* - in Aggregate TODO.
* - in Aggregate, inserting the Project node above and falling back to the resolution of Project.
*
* The whole process is generally divided into two phases:
* 1) recognize resolved lateral alias, wrap the attributes referencing them with
* [[LateralColumnAliasReference]]
* 2) when the whole operator is resolved, unwrap [[LateralColumnAliasReference]].
* For Project, it further resolves the attributes and push down the referenced lateral aliases.
* For Aggregate, TODO
* 2) when the whole operator is resolved,
* For Project, it unwrap [[LateralColumnAliasReference]], further resolves the attributes and
* push down the referenced lateral aliases.
* For Aggregate, it goes through the whole aggregation list, extracts the aggregation
* expressions and grouping expressions to keep them in this Aggregate node, and add a Project
* above with the original output. It doesn't do anything on [[LateralColumnAliasReference]], but
* completely leave it to the Project in the future turns of this rule.
*
* Example for Project:
* ** Example for Project:
* Before rewrite:
* Project [age AS a, 'a + 1]
* +- Child
*
* After phase 1:
* Project [age AS a, lateralalias(a) + 1]
* Project [age AS a, lca(a) + 1]
* +- Child
*
* After phase 2:
* Project [a, a + 1]
* +- Project [child output, age AS a]
* +- Child
*
* Example for Aggregate TODO
* ** Example for Aggregate:
* Before rewrite:
* Aggregate [dept#14] [dept#14 AS a#12, 'a + 1, avg(salary#16) AS b#13, 'b + avg(bonus#17)]
* +- Child [dept#14,name#15,salary#16,bonus#17]
*
* After phase 1:
* Aggregate [dept#14] [dept#14 AS a#12, lca(a) + 1, avg(salary#16) AS b#13, lca(b) + avg(bonus#17)]
* +- Child [dept#14,name#15,salary#16,bonus#17]
*
* After phase 2:
* Project [dept#14 AS a#12, lca(a) + 1, avg(salary)#26 AS b#13, lca(b) + avg(bonus)#27]
* +- Aggregate [dept#14] [avg(salary#16) AS avg(salary)#26, avg(bonus#17) AS avg(bonus)#27,dept#14]
* +- Child [dept#14,name#15,salary#16,bonus#17]
*
* Now the problem falls back to the lateral alias resolution in Project.
* After future rounds of this rule:
* Project [a#12, a#12 + 1, b#13, b#13 + avg(bonus)#27]
* +- Project [dept#14 AS a#12, avg(salary)#26 AS b#13]
* +- Aggregate [dept#14] [avg(salary#16) AS avg(salary)#26, avg(bonus#17) AS avg(bonus)#27,
* dept#14]
* +- Child [dept#14,name#15,salary#16,bonus#17]
*
*
* The name resolution priority:
Expand All @@ -75,6 +102,13 @@ object ResolveLateralColumnAliasReference extends Rule[LogicalPlan] {
*/
val NAME_PARTS_FROM_UNRESOLVED_ATTR = TreeNodeTag[Seq[String]]("name_parts_from_unresolved_attr")

private def assignAlias(expr: Expression): NamedExpression = {
expr match {
case ne: NamedExpression => ne
case e => Alias(e, toPrettySQL(e))()
}
}

override def apply(plan: LogicalPlan): LogicalPlan = {
if (!conf.getConf(SQLConf.LATERAL_COLUMN_ALIAS_IMPLICIT_ENABLED)) {
plan
Expand Down Expand Up @@ -129,6 +163,61 @@ object ResolveLateralColumnAliasReference extends Rule[LogicalPlan] {
child = Project(innerProjectList.toSeq, child)
)
}

case agg @ Aggregate(groupingExpressions, aggregateExpressions, _) if agg.resolved
&& aggregateExpressions.exists(_.containsPattern(LATERAL_COLUMN_ALIAS_REFERENCE)) =>

// Check if current Aggregate is eligible to lift up with Project: the aggregate
// expression only contains: 1) aggregate functions, 2) grouping expressions, 3) lateral
// column alias reference or 4) literals.
// This check is to prevent unnecessary transformation on invalid plan, to guarantee it
// throws the same exception. For example, cases like non-aggregate expressions not
// in group by, once transformed, will throw a different exception: missing input.
def eligibleToLiftUp(exp: Expression): Boolean = {
exp match {
case e if AggregateExpression.isAggregate(e) => true
case e if groupingExpressions.exists(_.semanticEquals(e)) => true
case _: Literal | _: LateralColumnAliasReference => true
case s: ScalarSubquery if s.children.nonEmpty
&& !groupingExpressions.exists(_.semanticEquals(s)) => false
case _: LeafExpression => false
case e => e.children.forall(eligibleToLiftUp)
}
}
if (!aggregateExpressions.forall(eligibleToLiftUp)) {
return agg
}

val newAggExprs = collection.mutable.Set.empty[NamedExpression]
val expressionMap = collection.mutable.LinkedHashMap.empty[Expression, NamedExpression]
val projectExprs = aggregateExpressions.map { exp =>
exp.transformDown {
case aggExpr: AggregateExpression =>
// Doesn't support referencing a lateral alias in aggregate function
if (aggExpr.containsPattern(LATERAL_COLUMN_ALIAS_REFERENCE)) {
aggExpr.collectFirst {
case lcaRef: LateralColumnAliasReference =>
throw QueryCompilationErrors.lateralColumnAliasInAggFuncUnsupportedError(
lcaRef.nameParts, aggExpr)
}
}
val ne = expressionMap.getOrElseUpdate(aggExpr.canonicalized, assignAlias(aggExpr))
newAggExprs += ne
ne.toAttribute
case e if groupingExpressions.exists(_.semanticEquals(e)) =>
val ne = expressionMap.getOrElseUpdate(e.canonicalized, assignAlias(e))
newAggExprs += ne
ne.toAttribute
}.asInstanceOf[NamedExpression]
}
if (newAggExprs.isEmpty) {
agg
} else {
Project(
projectList = projectExprs,
child = agg.copy(aggregateExpressions = newAggExprs.toSeq)
)
}
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3395,7 +3395,7 @@ private[sql] object QueryCompilationErrors extends QueryErrorsBase {
}
}

def ambiguousLateralColumnAlias(name: String, numOfMatches: Int): Throwable = {
def ambiguousLateralColumnAliasError(name: String, numOfMatches: Int): Throwable = {
new AnalysisException(
errorClass = "AMBIGUOUS_LATERAL_COLUMN_ALIAS",
messageParameters = Map(
Expand All @@ -3404,7 +3404,7 @@ private[sql] object QueryCompilationErrors extends QueryErrorsBase {
)
)
}
def ambiguousLateralColumnAlias(nameParts: Seq[String], numOfMatches: Int): Throwable = {
def ambiguousLateralColumnAliasError(nameParts: Seq[String], numOfMatches: Int): Throwable = {
new AnalysisException(
errorClass = "AMBIGUOUS_LATERAL_COLUMN_ALIAS",
messageParameters = Map(
Expand All @@ -3413,4 +3413,15 @@ private[sql] object QueryCompilationErrors extends QueryErrorsBase {
)
)
}

def lateralColumnAliasInAggFuncUnsupportedError(
lcaNameParts: Seq[String], aggExpr: Expression): Throwable = {
new AnalysisException(
errorClass = "UNSUPPORTED_FEATURE.LATERAL_COLUMN_ALIAS_IN_AGGREGATE_FUNC",
messageParameters = Map(
"lca" -> toSQLId(lcaNameParts),
"aggFunc" -> toSQLExpr(aggExpr)
)
)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -4045,7 +4045,7 @@ object SQLConf {
"higher resolution priority than the lateral column alias.")
.version("3.4.0")
.booleanConf
.createWithDefault(false)
.createWithDefault(true)

/**
* Holds information about keys that have been deprecated.
Expand Down
Loading

0 comments on commit fd6d226

Please sign in to comment.