[SPARK-46085][CONNECT] Dataset.groupingSets in Scala Spark Connect client

### What changes were proposed in this pull request?

This PR proposes to add the `Dataset.groupingSets` API, introduced in #43813, to the Scala Spark Connect client.

### Why are the changes needed?

For feature parity with the `Dataset.groupingSets` API introduced in #43813.

### Does this PR introduce _any_ user-facing change?

Yes, it adds a new API to the Scala Spark Connect client.
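
For illustration, here is a minimal sketch of a call against the new API, following the example in the scaladoc added by this commit (`ds`, `department`, and `group` are placeholder names):

```scala
// Assumes a SparkSession `spark` and `import spark.implicits._` for the $"..." syntax.
// GROUPING SETS ((department, group), ()): averages per (department, group)
// pair, plus a grand-total row for the empty grouping set.
ds.groupingSets(Seq(Seq($"department", $"group"), Seq()), $"department", $"group")
  .avg()
```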

### How was this patch tested?

A unit test was added.

### Was this patch authored or co-authored using generative AI tooling?

No.

Closes #43995 from HyukjinKwon/SPARK-46085.

Authored-by: Hyukjin Kwon <[email protected]>
Signed-off-by: Dongjoon Hyun <[email protected]>
HyukjinKwon authored and dongjoon-hyun committed Nov 25, 2023
1 parent a694a8a commit 5211f6b
Showing 6 changed files with 102 additions and 1 deletion.
@@ -1532,6 +1532,41 @@ class Dataset[T] private[sql] (
proto.Aggregate.GroupType.GROUP_TYPE_CUBE)
}

/**
* Create multi-dimensional aggregation for the current Dataset using the specified grouping
* sets, so we can run aggregation on them. See [[RelationalGroupedDataset]] for all the
* available aggregate functions.
*
* {{{
* // Compute the average for all numeric columns, grouped by specific grouping sets.
* ds.groupingSets(Seq(Seq($"department", $"group"), Seq()), $"department", $"group").avg()
*
* // Compute the max age and average salary, grouped by specific grouping sets.
* ds.groupingSets(Seq(Seq($"department", $"gender"), Seq()), $"department", $"gender").agg(Map(
* "salary" -> "avg",
* "age" -> "max"
* ))
* }}}
*
* @group untypedrel
* @since 4.0.0
*/
@scala.annotation.varargs
def groupingSets(groupingSets: Seq[Seq[Column]], cols: Column*): RelationalGroupedDataset = {
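// Convert each grouping set (a Seq of columns) into its proto representation,
// adding one expression per grouping column.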
val groupingSetMsgs = groupingSets.map { groupingSet =>
val groupingSetMsg = proto.Aggregate.GroupingSets.newBuilder()
for (groupCol <- groupingSet) {
groupingSetMsg.addGroupingSet(groupCol.expr)
}
groupingSetMsg.build()
}
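// Reuse RelationalGroupedDataset with the GROUPING_SETS group type; the proto
// grouping-set messages are passed through the new groupingSets parameter.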
new RelationalGroupedDataset(
toDF(),
cols,
proto.Aggregate.GroupType.GROUP_TYPE_GROUPING_SETS,
groupingSets = Some(groupingSetMsgs))
}

/**
* (Scala-specific) Aggregates on the entire Dataset without groups.
* {{{
@@ -39,7 +39,8 @@ class RelationalGroupedDataset private[sql] (
private[sql] val df: DataFrame,
private[sql] val groupingExprs: Seq[Column],
groupType: proto.Aggregate.GroupType,
pivot: Option[proto.Aggregate.Pivot] = None) {
pivot: Option[proto.Aggregate.Pivot] = None,
groupingSets: Option[Seq[proto.Aggregate.GroupingSets]] = None) {

private[this] def toDF(aggExprs: Seq[Column]): DataFrame = {
df.sparkSession.newDataFrame { builder =>
@@ -60,6 +61,11 @@
builder.getAggregateBuilder
.setGroupType(proto.Aggregate.GroupType.GROUP_TYPE_PIVOT)
.setPivot(pivot.get)
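// Newly added: set the GROUPING_SETS group type and attach each grouping set
// message to the Aggregate proto.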
case proto.Aggregate.GroupType.GROUP_TYPE_GROUPING_SETS =>
assert(groupingSets.isDefined)
val aggBuilder = builder.getAggregateBuilder
.setGroupType(proto.Aggregate.GroupType.GROUP_TYPE_GROUPING_SETS)
groupingSets.get.foreach(aggBuilder.addGroupingSets)
case g => throw new UnsupportedOperationException(g.toString)
}
}
@@ -3017,6 +3017,12 @@ class PlanGenerationTestSuite
simple.groupBy(Column("id")).pivot("a").agg(functions.count(Column("b")))
}

test("groupingSets") {
simple
.groupingSets(Seq(Seq(fn.col("a")), Seq.empty[Column]), fn.col("a"))
.agg("a" -> "max", "a" -> "count")
}

test("width_bucket") {
simple.select(fn.width_bucket(fn.col("b"), fn.col("b"), fn.col("b"), fn.col("a")))
}
@@ -0,0 +1,4 @@
Aggregate [a#0, spark_grouping_id#0L], [a#0, max(a#0) AS max(a)#0, count(a#0) AS count(a)#0L]
+- Expand [[id#0L, a#0, b#0, a#0, 0], [id#0L, a#0, b#0, null, 1]], [id#0L, a#0, b#0, a#0, spark_grouping_id#0L]
+- Project [id#0L, a#0, b#0, a#0 AS a#0]
+- LocalRelation <empty>, [id#0L, a#0, b#0]
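
For intuition: the plan above shows the grouping sets ((a), ()) expanded into one projection per set, distinguished by `spark_grouping_id`. An assumed-equivalent SQL formulation of the test query (illustrative only, not part of this commit; `tbl` is a placeholder table name):

```scala
// Assumed-equivalent SQL for the groupingSets test query above.
spark.sql("""
  SELECT a, max(a), count(a)
  FROM tbl
  GROUP BY GROUPING SETS ((a), ())
""")
```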
@@ -0,0 +1,50 @@
{
"common": {
"planId": "1"
},
"aggregate": {
"input": {
"common": {
"planId": "0"
},
"localRelation": {
"schema": "struct\u003cid:bigint,a:int,b:double\u003e"
}
},
"groupType": "GROUP_TYPE_GROUPING_SETS",
"groupingExpressions": [{
"unresolvedAttribute": {
"unparsedIdentifier": "a"
}
}],
"aggregateExpressions": [{
"unresolvedFunction": {
"functionName": "max",
"arguments": [{
"unresolvedAttribute": {
"unparsedIdentifier": "a",
"planId": "0"
}
}]
}
}, {
"unresolvedFunction": {
"functionName": "count",
"arguments": [{
"unresolvedAttribute": {
"unparsedIdentifier": "a",
"planId": "0"
}
}]
}
}],
"groupingSets": [{
"groupingSet": [{
"unresolvedAttribute": {
"unparsedIdentifier": "a"
}
}]
}, {
}]
}
}
Binary file not shown.
