Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SPARK-44549][SQL] Support window functions in correlated scalar subqueries #42383

Closed
wants to merge 15 commits into the base branch from the contributor's branch
Original file line number Diff line number Diff line change
Expand Up @@ -1351,6 +1351,11 @@ trait CheckAnalysis extends PredicateHelper with LookupCatalog with QueryErrorsB
failOnInvalidOuterReference(a)
checkPlan(a.child, aggregated = true, canContainOuter)

// Window operators are treated like Aggregate above: outer references
// hosted by the Window itself are validated, and the child is checked
// with aggregated = true so only supported correlated predicates are
// allowed on the correlation path below it.
case w: Window =>
failOnInvalidOuterReference(w)
checkPlan(w.child, aggregated = true, canContainOuter)

// Distinct does not host any correlated expressions, but during the optimization phase
// it will be rewritten as Aggregate, which can only be on a correlation path if the
// correlation contains only the supported correlated equality predicates.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -478,7 +478,8 @@ object DecorrelateInnerQuery extends PredicateHelper {
// parentOuterReferences: a set of parent outer references. As we recurse down we collect the
// set of outer references that are part of the Domain, and use it to construct the DomainJoins
// and join conditions.
// aggregated: a boolean flag indicating whether the result of the plan will be aggregated.
// aggregated: a boolean flag indicating whether the result of the plan will be aggregated
// (or used as an input for a window function)
// underSetOp: a boolean flag indicating whether a set operator (e.g. UNION) is a parent of the
// inner plan.
//
Expand Down Expand Up @@ -654,6 +655,25 @@ object DecorrelateInnerQuery extends PredicateHelper {
val newProject = Project(newProjectList ++ referencesToAdd, newChild)
(newProject, joinCond, outerReferenceMap)

// Decorrelate a Window operator. The child is decorrelated with
// aggregated = true (the window computation, like an aggregate, must be
// evaluated per outer row), and any columns required by the rewritten
// join condition are appended to both the project list and the partition
// spec so the window is partitioned per outer-row domain value.
case w @ Window(projectList, partitionSpec, orderSpec, child) =>
  val outerReferences = collectOuterReferences(w.expressions)
  // Correlated references inside the window expressions themselves are
  // not supported; CheckAnalysis is expected to have rejected them, so
  // this assertion is a defensive invariant check.
  assert(outerReferences.isEmpty, s"Correlated column is not allowed in window " +
    s"function: $w")
  val newOuterReferences = parentOuterReferences ++ outerReferences
  val (newChild, joinCond, outerReferenceMap) =
    decorrelate(child, newOuterReferences, aggregated = true, underSetOp)
  // For now these are no-op, as we don't allow correlated references in the window
  // function itself.
  val newProjectList = replaceOuterReferences(projectList, outerReferenceMap)
  val newPartitionSpec = replaceOuterReferences(partitionSpec, outerReferenceMap)
  val newOrderSpec = replaceOuterReferences(orderSpec, outerReferenceMap)
  // Columns referenced by the join condition but missing from the output
  // must be added to both the output and the partitioning.
  val referencesToAdd = missingReferences(newProjectList, joinCond)

  val newWindow = Window(newProjectList ++ referencesToAdd,
    partitionSpec = newPartitionSpec ++ referencesToAdd,
    orderSpec = newOrderSpec, newChild)
  (newWindow, joinCond, outerReferenceMap)

case a @ Aggregate(groupingExpressions, aggregateExpressions, child) =>
val outerReferences = collectOuterReferences(a.expressions)
val newOuterReferences = parentOuterReferences ++ outerReferences
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -581,4 +581,48 @@ class DecorrelateInnerQuerySuite extends PlanTest {
Project(Seq(a4, b4), testRelation4)))))
check(innerPlan, outerPlan, correctAnswer, Seq(a <=> a))
}

test("window function with correlated equality predicate") {
  val outerPlan = testRelation2
  // A window computed over rows filtered by an equality correlation
  // (outer x == inner a) plus a local predicate.
  val correlatedFilter = Filter(And(OuterReference(x) === a, b === 3), testRelation)
  val innerPlan = Window(Seq(b, c),
    partitionSpec = Seq(c), orderSpec = Seq(b.asc),
    correlatedFilter)
  // After decorrelation the correlated column 'a' is appended to both the
  // project list and the partition spec, and the equality predicate is
  // lifted into the join condition.
  val decorrelatedFilter = Filter(b === 3, testRelation)
  val correctAnswer = Window(Seq(b, c, a),
    partitionSpec = Seq(c, a), orderSpec = Seq(b.asc),
    decorrelatedFilter)
  check(innerPlan, outerPlan, correctAnswer, Seq(x === a))
}

test("window function with correlated non-equality predicate") {
  val outerPlan = testRelation2
  // A window over rows filtered by a non-equality correlation
  // (outer x > inner a) plus a local predicate.
  val correlatedFilter = Filter(And(OuterReference(x) > a, b === 3), testRelation)
  val innerPlan = Window(Seq(b, c),
    partitionSpec = Seq(c), orderSpec = Seq(b.asc),
    correlatedFilter)
  // Non-equality predicates cannot be lifted directly: a DomainJoin is
  // introduced to produce the 'x' values, the predicate stays in the
  // filter, and 'x' is appended to the project list and partition spec.
  val domainFilter = Filter(And(b === 3, x > a), DomainJoin(Seq(x), testRelation))
  val correctAnswer = Window(Seq(b, c, x),
    partitionSpec = Seq(c, x), orderSpec = Seq(b.asc),
    domainFilter)
  check(innerPlan, outerPlan, correctAnswer, Seq(x <=> x))
}

test("window function with correlated columns inside") {
  val outerPlan = testRelation2
  // Correlated references inside the window expressions themselves (here:
  // the partition spec) are unsupported; decorrelation must fail fast.
  val innerPlan = Window(Seq(b, c),
    partitionSpec = Seq(c, OuterReference(x)), orderSpec = Seq(b.asc),
    Filter(b === 3, testRelation))
  val error = intercept[java.lang.AssertionError] {
    DecorrelateInnerQuery(innerPlan, outerPlan.select())
  }
  assert(error.getMessage.contains("Correlated column is not allowed in"))
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -1450,21 +1450,22 @@ SELECT * FROM t1 JOIN LATERAL
FROM t2
WHERE t2.c1 >= t1.c1)
-- !query analysis
org.apache.spark.sql.catalyst.ExtendedAnalysisException
{
"errorClass" : "UNSUPPORTED_SUBQUERY_EXPRESSION_CATEGORY.ACCESSING_OUTER_QUERY_COLUMN_IS_NOT_ALLOWED",
"sqlState" : "0A000",
"messageParameters" : {
"treeNode" : "Filter (c1#x >= outer(c1#x))\n+- SubqueryAlias spark_catalog.default.t2\n +- View (`spark_catalog`.`default`.`t2`, [c1#x,c2#x])\n +- Project [cast(col1#x as int) AS c1#x, cast(col2#x as int) AS c2#x]\n +- LocalRelation [col1#x, col2#x]\n"
},
"queryContext" : [ {
"objectType" : "",
"objectName" : "",
"startIndex" : 34,
"stopIndex" : 108,
"fragment" : "SELECT sum(t2.c2) over (order by t2.c1)\n FROM t2\n WHERE t2.c1 >= t1.c1"
} ]
}
Project [c1#x, c2#x, sum(c2) OVER (ORDER BY c1 ASC NULLS FIRST RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW)#xL]
+- LateralJoin lateral-subquery#x [c1#x], Inner
: +- SubqueryAlias __auto_generated_subquery_name
: +- Project [sum(c2) OVER (ORDER BY c1 ASC NULLS FIRST RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW)#xL]
: +- Project [c2#x, c1#x, sum(c2) OVER (ORDER BY c1 ASC NULLS FIRST RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW)#xL, sum(c2) OVER (ORDER BY c1 ASC NULLS FIRST RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW)#xL]
: +- Window [sum(c2#x) windowspecdefinition(c1#x ASC NULLS FIRST, specifiedwindowframe(RangeFrame, unboundedpreceding$(), currentrow$())) AS sum(c2) OVER (ORDER BY c1 ASC NULLS FIRST RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW)#xL], [c1#x ASC NULLS FIRST]
: +- Project [c2#x, c1#x]
: +- Filter (c1#x >= outer(c1#x))
: +- SubqueryAlias spark_catalog.default.t2
: +- View (`spark_catalog`.`default`.`t2`, [c1#x,c2#x])
: +- Project [cast(col1#x as int) AS c1#x, cast(col2#x as int) AS c2#x]
: +- LocalRelation [col1#x, col2#x]
+- SubqueryAlias spark_catalog.default.t1
+- View (`spark_catalog`.`default`.`t1`, [c1#x,c2#x])
+- Project [cast(col1#x as int) AS c1#x, cast(col2#x as int) AS c2#x]
+- LocalRelation [col1#x, col2#x]


-- !query
Expand Down Expand Up @@ -2007,21 +2008,29 @@ SELECT * FROM t1 JOIN LATERAL
SELECT t4.c2
FROM t4)
-- !query analysis
org.apache.spark.sql.catalyst.ExtendedAnalysisException
{
"errorClass" : "UNSUPPORTED_SUBQUERY_EXPRESSION_CATEGORY.ACCESSING_OUTER_QUERY_COLUMN_IS_NOT_ALLOWED",
"sqlState" : "0A000",
"messageParameters" : {
"treeNode" : "Filter (c1#x >= outer(c1#x))\n+- SubqueryAlias spark_catalog.default.t2\n +- View (`spark_catalog`.`default`.`t2`, [c1#x,c2#x])\n +- Project [cast(col1#x as int) AS c1#x, cast(col2#x as int) AS c2#x]\n +- LocalRelation [col1#x, col2#x]\n"
},
"queryContext" : [ {
"objectType" : "",
"objectName" : "",
"startIndex" : 34,
"stopIndex" : 108,
"fragment" : "SELECT sum(t2.c2) over (order by t2.c1)\n FROM t2\n WHERE t2.c1 >= t1.c1"
} ]
}
Project [c1#x, c2#x, sum(c2) OVER (ORDER BY c1 ASC NULLS FIRST RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW)#xL]
+- LateralJoin lateral-subquery#x [c1#x], Inner
: +- SubqueryAlias __auto_generated_subquery_name
: +- Union false, false
: :- Project [sum(c2) OVER (ORDER BY c1 ASC NULLS FIRST RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW)#xL]
: : +- Project [c2#x, c1#x, sum(c2) OVER (ORDER BY c1 ASC NULLS FIRST RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW)#xL, sum(c2) OVER (ORDER BY c1 ASC NULLS FIRST RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW)#xL]
: : +- Window [sum(c2#x) windowspecdefinition(c1#x ASC NULLS FIRST, specifiedwindowframe(RangeFrame, unboundedpreceding$(), currentrow$())) AS sum(c2) OVER (ORDER BY c1 ASC NULLS FIRST RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW)#xL], [c1#x ASC NULLS FIRST]
: : +- Project [c2#x, c1#x]
: : +- Filter (c1#x >= outer(c1#x))
: : +- SubqueryAlias spark_catalog.default.t2
: : +- View (`spark_catalog`.`default`.`t2`, [c1#x,c2#x])
: : +- Project [cast(col1#x as int) AS c1#x, cast(col2#x as int) AS c2#x]
: : +- LocalRelation [col1#x, col2#x]
: +- Project [cast(c2#x as bigint) AS c2#xL]
: +- Project [c2#x]
: +- SubqueryAlias spark_catalog.default.t4
: +- View (`spark_catalog`.`default`.`t4`, [c1#x,c2#x])
: +- Project [cast(col1#x as int) AS c1#x, cast(col2#x as int) AS c2#x]
: +- LocalRelation [col1#x, col2#x]
+- SubqueryAlias spark_catalog.default.t1
+- View (`spark_catalog`.`default`.`t1`, [c1#x,c2#x])
+- Project [cast(col1#x as int) AS c1#x, cast(col2#x as int) AS c2#x]
+- LocalRelation [col1#x, col2#x]


-- !query
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -310,3 +310,27 @@ Project [emp_name#x, bonus_amt#x]
+- Project [emp_name#x, bonus_amt#x]
+- SubqueryAlias BONUS
+- LocalRelation [emp_name#x, bonus_amt#x]


-- !query
SELECT *
FROM BONUS
WHERE EXISTS(SELECT RANK() OVER (PARTITION BY hiredate ORDER BY salary) AS s
FROM EMP, DEPT where EMP.dept_id = DEPT.dept_id
AND DEPT.dept_name < BONUS.emp_name)
-- !query analysis
org.apache.spark.sql.catalyst.ExtendedAnalysisException
{
"errorClass" : "UNSUPPORTED_SUBQUERY_EXPRESSION_CATEGORY.CORRELATED_COLUMN_IS_NOT_ALLOWED_IN_PREDICATE",
"sqlState" : "0A000",
"messageParameters" : {
"treeNode" : "(dept_name#x < outer(emp_name#x))\nFilter ((dept_id#x = dept_id#x) AND (dept_name#x < outer(emp_name#x)))\n+- Join Inner\n :- SubqueryAlias emp\n : +- View (`EMP`, [id#x,emp_name#x,hiredate#x,salary#x,dept_id#x])\n : +- Project [cast(id#x as int) AS id#x, cast(emp_name#x as string) AS emp_name#x, cast(hiredate#x as date) AS hiredate#x, cast(salary#x as double) AS salary#x, cast(dept_id#x as int) AS dept_id#x]\n : +- Project [id#x, emp_name#x, hiredate#x, salary#x, dept_id#x]\n : +- SubqueryAlias EMP\n : +- LocalRelation [id#x, emp_name#x, hiredate#x, salary#x, dept_id#x]\n +- SubqueryAlias dept\n +- View (`DEPT`, [dept_id#x,dept_name#x,state#x])\n +- Project [cast(dept_id#x as int) AS dept_id#x, cast(dept_name#x as string) AS dept_name#x, cast(state#x as string) AS state#x]\n +- Project [dept_id#x, dept_name#x, state#x]\n +- SubqueryAlias DEPT\n +- LocalRelation [dept_id#x, dept_name#x, state#x]\n"
},
"queryContext" : [ {
"objectType" : "",
"objectName" : "",
"startIndex" : 34,
"stopIndex" : 224,
"fragment" : "SELECT RANK() OVER (PARTITION BY hiredate ORDER BY salary) AS s\n FROM EMP, DEPT where EMP.dept_id = DEPT.dept_id\n AND DEPT.dept_name < BONUS.emp_name"
} ]
}
Original file line number Diff line number Diff line change
Expand Up @@ -639,3 +639,26 @@ Filter isnotnull(min(t1b)#x)
+- Project [t1a#x, t1b#x, t1c#x, t1d#xL, t1e#x, t1f#x, t1g#x, t1h#x, t1i#x]
+- SubqueryAlias t1
+- LocalRelation [t1a#x, t1b#x, t1c#x, t1d#xL, t1e#x, t1f#x, t1g#x, t1h#x, t1i#x]


-- !query
select t1a
from t1
where t1f IN (SELECT RANK() OVER (partition by t3c order by t2b) as s
FROM t2, t3 where t2.t2c = t3.t3c and t2.t2a < t1.t1a)
-- !query analysis
org.apache.spark.sql.catalyst.ExtendedAnalysisException
{
"errorClass" : "UNSUPPORTED_SUBQUERY_EXPRESSION_CATEGORY.CORRELATED_COLUMN_IS_NOT_ALLOWED_IN_PREDICATE",
"sqlState" : "0A000",
"messageParameters" : {
"treeNode" : "(t2a#x < outer(t1a#x))\nFilter ((t2c#x = t3c#x) AND (t2a#x < outer(t1a#x)))\n+- Join Inner\n :- SubqueryAlias t2\n : +- View (`t2`, [t2a#x,t2b#x,t2c#x,t2d#xL,t2e#x,t2f#x,t2g#x,t2h#x,t2i#x])\n : +- Project [cast(t2a#x as string) AS t2a#x, cast(t2b#x as smallint) AS t2b#x, cast(t2c#x as int) AS t2c#x, cast(t2d#xL as bigint) AS t2d#xL, cast(t2e#x as float) AS t2e#x, cast(t2f#x as double) AS t2f#x, cast(t2g#x as decimal(4,0)) AS t2g#x, cast(t2h#x as timestamp) AS t2h#x, cast(t2i#x as date) AS t2i#x]\n : +- Project [t2a#x, t2b#x, t2c#x, t2d#xL, t2e#x, t2f#x, t2g#x, t2h#x, t2i#x]\n : +- SubqueryAlias t2\n : +- LocalRelation [t2a#x, t2b#x, t2c#x, t2d#xL, t2e#x, t2f#x, t2g#x, t2h#x, t2i#x]\n +- SubqueryAlias t3\n +- View (`t3`, [t3a#x,t3b#x,t3c#x,t3d#xL,t3e#x,t3f#x,t3g#x,t3h#x,t3i#x])\n +- Project [cast(t3a#x as string) AS t3a#x, cast(t3b#x as smallint) AS t3b#x, cast(t3c#x as int) AS t3c#x, cast(t3d#xL as bigint) AS t3d#xL, cast(t3e#x as float) AS t3e#x, cast(t3f#x as double) AS t3f#x, cast(t3g#x as decimal(4,0)) AS t3g#x, cast(t3h#x as timestamp) AS t3h#x, cast(t3i#x as date) AS t3i#x]\n +- Project [t3a#x, t3b#x, t3c#x, t3d#xL, t3e#x, t3f#x, t3g#x, t3h#x, t3i#x]\n +- SubqueryAlias t3\n +- LocalRelation [t3a#x, t3b#x, t3c#x, t3d#xL, t3e#x, t3f#x, t3g#x, t3h#x, t3i#x]\n"
},
"queryContext" : [ {
"objectType" : "",
"objectName" : "",
"startIndex" : 34,
"stopIndex" : 172,
"fragment" : "SELECT RANK() OVER (partition by t3c order by t2b) as s\n FROM t2, t3 where t2.t2c = t3.t3c and t2.t2a < t1.t1a"
} ]
}
Loading