From fe7cabad9a4ddc45f516c1bdb9f694fb065438d2 Mon Sep 17 00:00:00 2001 From: Chengjie Li <109656400+ChengjieLi28@users.noreply.github.com> Date: Wed, 20 Sep 2023 17:19:40 +0800 Subject: [PATCH] BUG: Column pruning failed when `groupby` by multi series (#708) Co-authored-by: mergify[bot] <37929162+mergify[bot]@users.noreply.github.com> --- .../tileable/column_pruning/input_column_selector.py | 5 +++-- .../tests/test_input_column_selector.py | 12 ++++++++++++ 2 files changed, 15 insertions(+), 2 deletions(-) diff --git a/python/xorbits/_mars/optimization/logical/tileable/column_pruning/input_column_selector.py b/python/xorbits/_mars/optimization/logical/tileable/column_pruning/input_column_selector.py index dda97f7f0..97e866e95 100644 --- a/python/xorbits/_mars/optimization/logical/tileable/column_pruning/input_column_selector.py +++ b/python/xorbits/_mars/optimization/logical/tileable/column_pruning/input_column_selector.py @@ -164,9 +164,10 @@ def df_groupby_agg_select_function( ret = {} # group by a series groupby_series = False - if isinstance(by, list) and len(by) == 1 and isinstance(by[0], BaseSeriesData): + if isinstance(by, list) and all([isinstance(_by, BaseSeriesData) for _by in by]): groupby_series = True - ret[by[0]] = {by[0].name} + for _by in by: + ret[_by] = {_by.name} if isinstance(inp, BaseSeriesData): ret[inp] = {inp.name} diff --git a/python/xorbits/_mars/optimization/logical/tileable/column_pruning/tests/test_input_column_selector.py b/python/xorbits/_mars/optimization/logical/tileable/column_pruning/tests/test_input_column_selector.py index e351e339e..82cf899e8 100644 --- a/python/xorbits/_mars/optimization/logical/tileable/column_pruning/tests/test_input_column_selector.py +++ b/python/xorbits/_mars/optimization/logical/tileable/column_pruning/tests/test_input_column_selector.py @@ -116,6 +116,18 @@ def test_df_groupby_agg(): assert labels.data in input_columns assert input_columns[labels.data] == {"label"} + label1 = Series([1, 1, 1, 1], name="label1") + label2 = Series([2, 2, 3, 3], name="label2") + s = df.groupby(by=[label1, label2]).sum() + input_columns = InputColumnSelector.select(s.data, {"foo"}) + assert len(input_columns) == 3 + assert df.data in input_columns + assert input_columns[df.data] == {"foo"} + assert label1.data in input_columns + assert input_columns[label1.data] == {"label1"} + assert label2.data in input_columns + assert input_columns[label2.data] == {"label2"} + @pytest.mark.skip(reason="group by index is not supported yet") def test_df_groupby_index_agg():