Skip to content

Commit

Permalink
feat: add support for rollup (#87)
Browse files Browse the repository at this point in the history
  • Loading branch information
richtia authored Sep 18, 2024
1 parent 2eb1d59 commit 768f9d1
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 2 deletions.
26 changes: 26 additions & 0 deletions src/gateway/converter/spark_to_substrait.py
Original file line number Diff line number Diff line change
Expand Up @@ -1317,6 +1317,10 @@ def handle_grouping_and_measures(self, rel: spark_relations_pb2.Aggregate,
# Generate and add all groupings required for CUBE
cube_groupings = self.create_cube_groupings(rel_grouping_expressions)
aggregate.groupings.extend(cube_groupings)
case spark_relations_pb2.Aggregate.GroupType.GROUP_TYPE_ROLLUP:
# Generate and add all groupings required for ROLLUP
rollup_groupings = self.create_rollup_groupings(rel_grouping_expressions)
aggregate.groupings.extend(rollup_groupings)
case _:
raise NotImplementedError(
"Only GROUPBY and CUBE group types are currently supported."
Expand Down Expand Up @@ -1356,6 +1360,28 @@ def create_cube_groupings(self, grouping_expressions):

return cube_groupings

def create_rollup_groupings(self, grouping_expressions):
"""Create all combinations of grouping expressions for rollup."""
num_expressions = len(grouping_expressions)
rollup_groupings = []

for i in range(num_expressions):
current_grouping = []
for j in range(i + 1):
converted_expression = self.convert_expression(grouping_expressions[j])
current_grouping.append(converted_expression)
rollup_groupings.append(
algebra_pb2.AggregateRel.Grouping(grouping_expressions=current_grouping)
)

# Add a final grouping with no expressions for the grand total.
# The grand total aggregates over all rows.
rollup_groupings.append(
algebra_pb2.AggregateRel.Grouping(grouping_expressions=[])
)

return rollup_groupings

# pylint: disable=too-many-locals,pointless-string-statement
def convert_show_string_relation(self, rel: spark_relations_pb2.ShowString) -> algebra_pb2.Rel:
"""Convert a show string relation into a Substrait subplan."""
Expand Down
2 changes: 0 additions & 2 deletions src/gateway/tests/test_dataframe_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,8 +154,6 @@ def mark_dataframe_tests_as_xfail(request):
pytest.skip(reason="inf vs -inf difference")
if source == "gateway-over-duckdb" and originalname in ["test_union", "test_unionall"]:
pytest.skip(reason="distinct not handled properly")
if source == "gateway-over-datafusion" and originalname == "test_rollup":
pytest.skip(reason="rollup aggregation not yet implemented in gateway")
if source == "gateway-over-duckdb" and originalname == "test_rollup":
pytest.skip(reason="rollup aggregation not yet implemented in gateway")
if source == "gateway-over-duckdb" and originalname == "test_cube":
Expand Down

0 comments on commit 768f9d1

Please sign in to comment.