[SNOW-1754978]: Add support for first and last in GroupBy.agg #2847

Merged (14 commits, Jan 16, 2025)
Changes from 6 commits
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -49,6 +49,7 @@
- %X: Locale’s appropriate time representation.
- %%: A literal '%' character.
- Added support for `Series.between`.
- Added support for `first` and `last` in `DataFrameGroupBy.agg` and `SeriesGroupBy.agg`.

#### Bug Fixes

4 changes: 4 additions & 0 deletions docs/source/modin/supported/agg_supp.rst
@@ -60,3 +60,7 @@ the aggregation is supported by ``SeriesGroupBy.agg``
+-----------------------------+-------------------------------------+----------------------------------+--------------------------------------------+-----------------------------------------+
| ``len`` | ``N`` | ``N`` | ``Y`` | ``Y`` |
+-----------------------------+-------------------------------------+----------------------------------+--------------------------------------------+-----------------------------------------+
| ``first`` | ``N`` | ``N`` | ``Y`` | ``Y`` |
+-----------------------------+-------------------------------------+----------------------------------+--------------------------------------------+-----------------------------------------+
| ``last`` | ``N`` | ``N`` | ``Y`` | ``Y`` |
+-----------------------------+-------------------------------------+----------------------------------+--------------------------------------------+-----------------------------------------+
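For orientation, a minimal sketch of what the two new "Y" cells for GroupBy enable (column names are illustrative; the usual Snowpark pandas imports and an active Snowflake session are assumed):

import modin.pandas as pd
import numpy as np
import snowflake.snowpark.modin.plugin  # noqa: F401  (registers the Snowpark pandas backend)

df = pd.DataFrame({"grp": ["A", "A", "B"], "val": [np.nan, 2.0, 3.0]})
# By default (skipna=True), "first"/"last" pick the first/last non-null
# value per group, matching native pandas GroupBy semantics.
result = df.groupby("grp").agg({"val": ["first", "last"]})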
171 changes: 158 additions & 13 deletions src/snowflake/snowpark/modin/plugin/_internal/aggregation_utils.py
@@ -430,6 +430,70 @@ def _columns_coalescing_sum(*cols: SnowparkColumn) -> Callable:
return sum(builtin("zeroifnull")(col) for col in cols)


def _column_first_value(
column: SnowparkColumn,
row_position_snowflake_quoted_identifier: str,
ignore_nulls: bool,
) -> SnowparkColumn:
"""
    Returns the first value (ordered by `row_position_snowflake_quoted_identifier`) over the specified group.

Parameters
----------
    column: SnowparkColumn
        The Snowpark column to aggregate.
row_position_snowflake_quoted_identifier: str
The Snowflake quoted identifier of the column to order by.
ignore_nulls: bool
Whether or not to ignore nulls.

Returns
-------
The aggregated Snowpark Column.
"""
if ignore_nulls:
        col_to_min_by = iff(
            col(column).is_null(),
            pandas_lit(None),
            col(row_position_snowflake_quoted_identifier),
        )
else:
col_to_min_by = col(row_position_snowflake_quoted_identifier)
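    # MIN_BY(column, key) returns the value of `column` from the row with the
    # smallest non-NULL key; because the key is nulled out above when the data
    # value is NULL (ignore_nulls=True), such rows are skipped entirely.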
return builtin("min_by")(col(column), col_to_min_by)


def _column_last_value(
column: SnowparkColumn,
row_position_snowflake_quoted_identifier: str,
ignore_nulls: bool,
) -> SnowparkColumn:
"""
    Returns the last value (ordered by `row_position_snowflake_quoted_identifier`) over the specified group.

Parameters
----------
    column: SnowparkColumn
        The Snowpark column to aggregate.
row_position_snowflake_quoted_identifier: str
The Snowflake quoted identifier of the column to order by.
ignore_nulls: bool
Whether or not to ignore nulls.

Returns
-------
The aggregated Snowpark Column.
"""
if ignore_nulls:
        col_to_max_by = iff(
            col(column).is_null(),
            pandas_lit(None),
            col(row_position_snowflake_quoted_identifier),
        )
else:
col_to_max_by = col(row_position_snowflake_quoted_identifier)
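    # MAX_BY mirrors MIN_BY: it returns the value from the row with the largest
    # non-NULL ordering key, i.e. the last (non-null, when ignore_nulls) value.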
return builtin("max_by")(col(column), col_to_max_by)


def _create_pandas_to_snowpark_pandas_aggregation_map(
pandas_functions: Iterable[AggFuncTypeBase],
snowpark_pandas_aggregation: _SnowparkPandasAggregation,
@@ -469,6 +533,18 @@ def _create_pandas_to_snowpark_pandas_aggregation_map(
preserves_snowpark_pandas_types=False,
),
),
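    # For axis=1, "first"/"last" reduce across a row's columns: the keepna
    # variants take the first/last column value as-is, while the skipna
    # variants use COALESCE (in forward or reversed column order) to pick the
    # first/last non-null value.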
"first": _SnowparkPandasAggregation(
axis_0_aggregation=_column_first_value,
axis_1_aggregation_keepna=lambda *cols: cols[0],
axis_1_aggregation_skipna=lambda *cols: coalesce(*cols),
preserves_snowpark_pandas_types=True,
),
"last": _SnowparkPandasAggregation(
axis_0_aggregation=_column_last_value,
axis_1_aggregation_keepna=lambda *cols: cols[-1],
axis_1_aggregation_skipna=lambda *cols: coalesce(*(cols[::-1])),
preserves_snowpark_pandas_types=True,
),
**_create_pandas_to_snowpark_pandas_aggregation_map(
("mean", np.mean),
_SnowparkPandasAggregation(
@@ -610,7 +686,10 @@ def is_snowflake_agg_func(agg_func: AggFuncTypeBase) -> bool:


def get_snowflake_agg_func(
agg_func: AggFuncTypeBase, agg_kwargs: dict[str, Any], axis: Literal[0, 1]
agg_func: AggFuncTypeBase,
agg_kwargs: dict[str, Any],
axis: Literal[0, 1],
_is_df_agg: bool = False,
) -> Optional[SnowflakeAggFunc]:
"""
Get the corresponding Snowflake/Snowpark aggregation function for the given aggregation function.
@@ -659,6 +738,23 @@
def snowpark_aggregation(col: SnowparkColumn) -> SnowparkColumn:
return column_quantile(col, interpolation, q)

elif (
snowpark_aggregation == _column_first_value
or snowpark_aggregation == _column_last_value
):
if _is_df_agg:
# First and last are not supported for df.agg.
return None
ignore_nulls = agg_kwargs.get("skipna", True)
row_position_snowflake_quoted_identifier = agg_kwargs.get(
"_first_last_row_pos_col", None
)
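        # The row position identifier is threaded through agg_kwargs by the
        # groupby query compiler (which ensures a row position column first),
        # since agg_kwargs is the only channel into this helper.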
snowpark_aggregation = functools.partial(
snowpark_aggregation,
ignore_nulls=ignore_nulls,
row_position_snowflake_quoted_identifier=row_position_snowflake_quoted_identifier,
)

assert (
snowpark_aggregation is not None
), "Internal error: Snowpark pandas should have identified a Snowpark aggregation."
@@ -707,7 +803,10 @@ def snowpark_aggregation(*cols: SnowparkColumn) -> SnowparkColumn:


def _is_supported_snowflake_agg_func(
agg_func: AggFuncTypeBase, agg_kwargs: dict[str, Any], axis: Literal[0, 1]
agg_func: AggFuncTypeBase,
agg_kwargs: dict[str, Any],
axis: Literal[0, 1],
_is_df_agg: bool = False,
) -> bool:
"""
check if the aggregation function is supported with snowflake. Current supported
@@ -724,11 +823,14 @@
# For named aggregations, like `df.agg(new_col=("old_col", "sum"))`,
# take the second part of the named aggregation.
agg_func = agg_func[0]
return get_snowflake_agg_func(agg_func, agg_kwargs, axis) is not None
return get_snowflake_agg_func(agg_func, agg_kwargs, axis, _is_df_agg) is not None


def _are_all_agg_funcs_supported_by_snowflake(
agg_funcs: list[AggFuncTypeBase], agg_kwargs: dict[str, Any], axis: Literal[0, 1]
agg_funcs: list[AggFuncTypeBase],
agg_kwargs: dict[str, Any],
axis: Literal[0, 1],
_is_df_agg: bool = False,
) -> bool:
"""
Check if all aggregation functions in the given list are snowflake supported
@@ -739,14 +841,16 @@
return False.
"""
return all(
_is_supported_snowflake_agg_func(func, agg_kwargs, axis) for func in agg_funcs
_is_supported_snowflake_agg_func(func, agg_kwargs, axis, _is_df_agg)
for func in agg_funcs
)


def check_is_aggregation_supported_in_snowflake(
agg_func: AggFuncType,
agg_kwargs: dict[str, Any],
axis: Literal[0, 1],
_is_df_agg: bool = False,
) -> bool:
"""
check if distributed implementation with snowflake is available for the aggregation
@@ -756,6 +860,8 @@
agg_func: the aggregation function to apply
agg_kwargs: keyword argument passed for the aggregation function, such as ddof, min_count etc.
The value can be different for different aggregation function.
        _is_df_agg: whether this is being called from df.agg; some aggregations
            (e.g. `first` and `last`) are supported only for groupby agg.
Returns:
bool
Whether the aggregation operation can be executed with snowflake sql engine.
@@ -765,15 +871,21 @@
if is_dict_like(agg_func):
return all(
(
_are_all_agg_funcs_supported_by_snowflake(value, agg_kwargs, axis)
_are_all_agg_funcs_supported_by_snowflake(
value, agg_kwargs, axis, _is_df_agg
)
if is_list_like(value) and not is_named_tuple(value)
else _is_supported_snowflake_agg_func(value, agg_kwargs, axis)
else _is_supported_snowflake_agg_func(
value, agg_kwargs, axis, _is_df_agg
)
)
for value in agg_func.values()
)
elif is_list_like(agg_func):
return _are_all_agg_funcs_supported_by_snowflake(agg_func, agg_kwargs, axis)
return _is_supported_snowflake_agg_func(agg_func, agg_kwargs, axis)
return _are_all_agg_funcs_supported_by_snowflake(
agg_func, agg_kwargs, axis, _is_df_agg
)
return _is_supported_snowflake_agg_func(agg_func, agg_kwargs, axis, _is_df_agg)


def _is_snowflake_numeric_type_required(snowflake_agg_func: Callable) -> bool:
@@ -1372,10 +1484,16 @@ def repr_aggregate_function(agg_func: AggFuncType, agg_kwargs: Mapping) -> str:
if using_named_aggregations_for_func(agg_func):
# New axis labels are sensitive, so replace them with "new_label."
# Existing axis labels are sensitive, so replace them with "label."
return ", ".join(
f"new_label=(label, {repr_aggregate_function(f, agg_kwargs)})"
for _, f in agg_kwargs.values()
)
if is_list_like(list(agg_kwargs.values())[0]):
return ", ".join(
f"new_label=(label, {repr_aggregate_function(f, agg_kwargs)})"
for _, f in agg_kwargs.values()
)
else:
return ", ".join(
f"new_label=(label, {repr_aggregate_function(f, agg_kwargs)})"
for f in agg_kwargs.values()
)
if isinstance(agg_func, str):
# Strings functions represent names of pandas functions, e.g.
# "sum" means to aggregate with pandas.Series.sum. string function
@@ -1422,3 +1540,30 @@ def repr_aggregate_function(agg_func: AggFuncType, agg_kwargs: Mapping) -> str:
# exposing sensitive user input in the NotImplemented error message and
# thus in telemetry.
return "Callable"


def is_first_last_in_agg_funcs(
column_to_agg_func: dict[str, Union[list[AggFuncInfo], AggFuncInfo]]
) -> bool:
"""
Helper function to check if the `first` or `last` aggregation functions have been specified.

Parameters
----------
column_to_agg_func: dict[str, Union[list[AggFuncInfo], AggFuncInfo]]
The mapping of column name to aggregation function (or functions) to apply.

Returns
-------
bool
Whether any of the functions to apply are either `first` or `last`.
"""

def _is_first_last_agg_func(value: AggFuncInfo) -> bool:
return value.func in ["first", "last"]

return any(
(isinstance(val, AggFuncInfo) and _is_first_last_agg_func(val))
or (isinstance(val, list) and any(_is_first_last_agg_func(v) for v in val))
for val in column_to_agg_func.values()
)
@@ -173,6 +173,7 @@
get_agg_func_to_col_map,
get_pandas_aggr_func_name,
get_snowflake_agg_func,
is_first_last_in_agg_funcs,
repr_aggregate_function,
using_named_aggregations_for_func,
)
@@ -3796,9 +3797,20 @@ def convert_func_to_agg_func_info(
internal_frame.index_column_snowflake_quoted_identifiers
)

        # Check whether `first` or `last` appears among the aggregation functions;
        # if so, we must ensure a row position column and pass its quoted identifier
        # in as an agg_kwarg (it is needed for the min_by/max_by functions).
first_last_present = is_first_last_in_agg_funcs(column_to_agg_func)
if first_last_present:
internal_frame = internal_frame.ensure_row_position_column()
agg_kwargs[
"_first_last_row_pos_col"
] = internal_frame.row_position_snowflake_quoted_identifier
agg_col_ops, new_data_column_index_names = generate_column_agg_info(
internal_frame, column_to_agg_func, agg_kwargs, is_series_groupby
)
if first_last_present:
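            # Drop the internal kwarg so it does not leak into user-visible
            # agg_kwargs after the aggregation column info has been generated.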
agg_kwargs.pop("_first_last_row_pos_col")
# the pandas label and quoted identifier generated for each result column
# after aggregation will be used as new pandas label and quoted identifiers.
new_data_column_pandas_labels = []
@@ -6026,7 +6038,9 @@ def agg(
# by snowflake engine.
# If we are using Named Aggregations, we need to do our supported check slightly differently.
uses_named_aggs = using_named_aggregations_for_func(func)
if not check_is_aggregation_supported_in_snowflake(func, kwargs, axis):
if not check_is_aggregation_supported_in_snowflake(
func, kwargs, axis, _is_df_agg=True
):
ErrorMessage.not_implemented(
f"Snowpark pandas aggregate does not yet support the aggregation {repr_aggregate_function(func, kwargs)} with the given arguments."
)
27 changes: 26 additions & 1 deletion tests/integ/modin/frame/test_aggregate.py
@@ -1141,12 +1141,37 @@ def test_named_agg_not_supported_axis_1(numeric_native_df):
),
param(
None,
{"x": ("A", np.exp), "y": pd.NamedAgg("C", sum)},
{
"x": ("A", np.exp),
"y": pd.NamedAgg("C", sum),
},
"Snowpark pandas aggregate does not yet support the aggregation "
+ "new_label=\\(label, np\\.exp\\), new_label=\\(label, <built-in function sum>\\)"
+ " with the given arguments",
id="named_agg",
),
param(
None,
{
"x": ("A", "first"),
"y": pd.NamedAgg("C", sum),
},
"Snowpark pandas aggregate does not yet support the aggregation "
+ "new_label=\\(label, 'first'\\), new_label=\\(label, <built-in function sum>\\)"
+ " with the given arguments",
id="named_agg",
),
param(
None,
{
"x": ("A", "last"),
"y": pd.NamedAgg("C", sum),
},
"Snowpark pandas aggregate does not yet support the aggregation "
+ "new_label=\\(label, 'last'\\), new_label=\\(label, <built-in function sum>\\)"
+ " with the given arguments",
id="named_agg",
),
],
)
@sql_count_checker(query_count=0)
32 changes: 32 additions & 0 deletions tests/integ/modin/groupby/test_groupby_basic_agg.py
@@ -1149,6 +1149,38 @@ def test_valid_func_valid_kwarg_should_work(basic_snowpark_pandas_df):
)


@pytest.mark.parametrize("skipna", [True, False])
@sql_count_checker(query_count=1)
def test_groupby_agg_first(skipna):
native_df = native_pd.DataFrame(
{"grp_col": ["A", "A", "B", "B", "A"], "float_col": [np.nan, 2, 3, np.nan, 4]}
)
snow_df = pd.DataFrame(native_df)
eval_snowpark_pandas_result(
snow_df,
native_df,
lambda df: df.groupby(by="grp_col").agg(
{"float_col": ["quantile", "first"]}, skipna=skipna
),
)


@pytest.mark.parametrize("skipna", [True, False])
@sql_count_checker(query_count=1)
def test_groupby_agg_last(skipna):
native_df = native_pd.DataFrame(
{"grp_col": ["A", "A", "B", "B", "A"], "float_col": [np.nan, 2, 3, np.nan, 4]}
)
snow_df = pd.DataFrame(native_df)
eval_snowpark_pandas_result(
snow_df,
native_df,
lambda df: df.groupby(by="grp_col").agg(
{"float_col": ["quantile", "last"]}, skipna=skipna
),
)
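For reference, the skipna semantics these tests exercise can be reproduced in native pandas alone (same fixture values as above):

import numpy as np
import pandas as native_pd

native_df = native_pd.DataFrame(
    {"grp_col": ["A", "A", "B", "B", "A"], "float_col": [np.nan, 2, 3, np.nan, 4]}
)
# skipna=True (the default): group "A" yields 2.0 because the leading NaN is
# skipped; with skipna=False it would yield NaN, the literal first row value.
print(native_df.groupby("grp_col")["float_col"].first())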


class TestTimedelta:
@sql_count_checker(query_count=1)
@pytest.mark.parametrize(