Skip to content

Commit

Permalink
[SNOW-1754978]: Add support for first and last in GroupBy.agg (#2847)
Browse files Browse the repository at this point in the history
<!---
Please answer these questions before creating your pull request. Thanks!
--->

1. Which Jira issue is this PR addressing? Make sure that there is an
accompanying issue to your PR.

   <!---
   In this section, please add a Snowflake Jira issue number.

Note that if a corresponding GitHub issue exists, you should still
include
   the Snowflake Jira issue number. For example, for GitHub issue
#1400, you should
   add "SNOW-1335071" here.
    --->

   Fixes SNOW-1754978

2. Fill out the following pre-review checklist:

- [x] I am adding a new automated test(s) to verify correctness of my
new code
- [ ] If this test skips Local Testing mode, I'm requesting review from
@snowflakedb/local-testing
   - [ ] I am adding new logging messages
   - [ ] I am adding a new telemetry message
   - [ ] I am adding new credentials
   - [ ] I am adding a new dependency
- [ ] If this is a new feature/behavior, I'm adding the Local Testing
parity changes.
- [ ] I acknowledge that I have ensured my changes to be thread-safe.
Follow the link for more information: [Thread-safe Developer
Guidelines](https://github.com/snowflakedb/snowpark-python/blob/main/CONTRIBUTING.md#thread-safe-development)

3. Please describe how your code solves the related issue.

Adds support for `first` and `last` in groupby.agg (when used with other
functions)

---------

Co-authored-by: Jonathan Shi <[email protected]>
Co-authored-by: Jonathan Shi <[email protected]>
  • Loading branch information
3 people authored Jan 16, 2025
1 parent 8b230dc commit eea9995
Show file tree
Hide file tree
Showing 6 changed files with 243 additions and 15 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,7 @@
- Added support for `include_groups=False` in `DataFrameGroupBy.apply`.
- Added support for `expand=True` in `Series.str.split`.
- Added support for `DataFrame.pop` and `Series.pop`.
- Added support for `first` and `last` in `DataFrameGroupBy.agg` and `SeriesGroupBy.agg`.

#### Bug Fixes

Expand Down
4 changes: 4 additions & 0 deletions docs/source/modin/supported/agg_supp.rst
Original file line number Diff line number Diff line change
Expand Up @@ -60,3 +60,7 @@ the aggregation is supported by ``SeriesGroupBy.agg``.
+-----------------------------+-------------------------------------+----------------------------------+--------------------------------------------+-----------------------------------------+
| ``len`` | ``N`` | ``N`` | ``Y`` | ``Y`` |
+-----------------------------+-------------------------------------+----------------------------------+--------------------------------------------+-----------------------------------------+
| ``first`` | ``N`` | ``N`` | ``Y`` | ``Y`` |
+-----------------------------+-------------------------------------+----------------------------------+--------------------------------------------+-----------------------------------------+
| ``last`` | ``N`` | ``N`` | ``Y`` | ``Y`` |
+-----------------------------+-------------------------------------+----------------------------------+--------------------------------------------+-----------------------------------------+
176 changes: 163 additions & 13 deletions src/snowflake/snowpark/modin/plugin/_internal/aggregation_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -431,6 +431,70 @@ def _columns_coalescing_sum(*cols: SnowparkColumn) -> Callable:
return sum(builtin("zeroifnull")(col) for col in cols)


def _column_first_value(
    column: SnowparkColumn,
    row_position_snowflake_quoted_identifier: str,
    ignore_nulls: bool,
) -> SnowparkColumn:
    """
    Return the first value of `column` (ordered by `row_position_snowflake_quoted_identifier`)
    over the specified group.

    Parameters
    ----------
    column : SnowparkColumn
        The Snowpark column to aggregate.
        NOTE(review): this value is passed to ``col(...)``, which expects a column
        name/identifier rather than a Column object — confirm the annotation.
    row_position_snowflake_quoted_identifier : str
        The Snowflake quoted identifier of the row position column to order by.
    ignore_nulls : bool
        Whether or not to ignore nulls.

    Returns
    -------
    The aggregated Snowpark Column.
    """
    if ignore_nulls:
        # Null out the ordering key on rows where the data column is null:
        # MIN_BY skips rows whose ordering key is NULL, so those rows can never
        # be selected as the "first" value.
        col_to_min_by = iff(
            col(column).is_null(),
            pandas_lit(None),
            col(row_position_snowflake_quoted_identifier),
        )
    else:
        col_to_min_by = col(row_position_snowflake_quoted_identifier)
    # MIN_BY(value, key) returns the value from the row with the smallest key,
    # i.e. the earliest row position within the group.
    return builtin("min_by")(col(column), col_to_min_by)


def _column_last_value(
    column: SnowparkColumn,
    row_position_snowflake_quoted_identifier: str,
    ignore_nulls: bool,
) -> SnowparkColumn:
    """
    Return the last value of `column` (ordered by `row_position_snowflake_quoted_identifier`)
    over the specified group.

    Parameters
    ----------
    column : SnowparkColumn
        The Snowpark column to aggregate.
        NOTE(review): this value is passed to ``col(...)``, which expects a column
        name/identifier rather than a Column object — confirm the annotation.
    row_position_snowflake_quoted_identifier : str
        The Snowflake quoted identifier of the row position column to order by.
    ignore_nulls : bool
        Whether or not to ignore nulls.

    Returns
    -------
    The aggregated Snowpark Column.
    """
    if ignore_nulls:
        # Null out the ordering key on rows where the data column is null:
        # MAX_BY skips rows whose ordering key is NULL, so those rows can never
        # be selected as the "last" value.
        col_to_max_by = iff(
            col(column).is_null(),
            pandas_lit(None),
            col(row_position_snowflake_quoted_identifier),
        )
    else:
        col_to_max_by = col(row_position_snowflake_quoted_identifier)
    # MAX_BY(value, key) returns the value from the row with the largest key,
    # i.e. the latest row position within the group.
    return builtin("max_by")(col(column), col_to_max_by)


def _create_pandas_to_snowpark_pandas_aggregation_map(
pandas_functions: Iterable[AggFuncTypeBase],
snowpark_pandas_aggregation: _SnowparkPandasAggregation,
Expand Down Expand Up @@ -470,6 +534,18 @@ def _create_pandas_to_snowpark_pandas_aggregation_map(
preserves_snowpark_pandas_types=False,
),
),
"first": _SnowparkPandasAggregation(
axis_0_aggregation=_column_first_value,
axis_1_aggregation_keepna=lambda *cols: cols[0],
axis_1_aggregation_skipna=lambda *cols: coalesce(*cols),
preserves_snowpark_pandas_types=True,
),
"last": _SnowparkPandasAggregation(
axis_0_aggregation=_column_last_value,
axis_1_aggregation_keepna=lambda *cols: cols[-1],
axis_1_aggregation_skipna=lambda *cols: coalesce(*(cols[::-1])),
preserves_snowpark_pandas_types=True,
),
**_create_pandas_to_snowpark_pandas_aggregation_map(
("mean", np.mean),
_SnowparkPandasAggregation(
Expand Down Expand Up @@ -611,7 +687,10 @@ def is_snowflake_agg_func(agg_func: AggFuncTypeBase) -> bool:


def get_snowflake_agg_func(
agg_func: AggFuncTypeBase, agg_kwargs: dict[str, Any], axis: Literal[0, 1]
agg_func: AggFuncTypeBase,
agg_kwargs: dict[str, Any],
axis: Literal[0, 1],
_is_df_agg: bool = False,
) -> Optional[SnowflakeAggFunc]:
"""
Get the corresponding Snowflake/Snowpark aggregation function for the given aggregation function.
Expand Down Expand Up @@ -660,6 +739,23 @@ def get_snowflake_agg_func(
def snowpark_aggregation(col: SnowparkColumn) -> SnowparkColumn:
return column_quantile(col, interpolation, q)

elif (
snowpark_aggregation == _column_first_value
or snowpark_aggregation == _column_last_value
):
if _is_df_agg:
# First and last are not supported for df.agg.
return None
ignore_nulls = agg_kwargs.get("skipna", True)
row_position_snowflake_quoted_identifier = agg_kwargs.get(
"_first_last_row_pos_col", None
)
snowpark_aggregation = functools.partial(
snowpark_aggregation,
ignore_nulls=ignore_nulls,
row_position_snowflake_quoted_identifier=row_position_snowflake_quoted_identifier,
)

assert (
snowpark_aggregation is not None
), "Internal error: Snowpark pandas should have identified a Snowpark aggregation."
Expand Down Expand Up @@ -708,7 +804,10 @@ def snowpark_aggregation(*cols: SnowparkColumn) -> SnowparkColumn:


def _is_supported_snowflake_agg_func(
agg_func: AggFuncTypeBase, agg_kwargs: dict[str, Any], axis: Literal[0, 1]
agg_func: AggFuncTypeBase,
agg_kwargs: dict[str, Any],
axis: Literal[0, 1],
_is_df_agg: bool = False,
) -> bool:
"""
check if the aggregation function is supported with snowflake. Current supported
Expand All @@ -725,11 +824,14 @@ def _is_supported_snowflake_agg_func(
# For named aggregations, like `df.agg(new_col=("old_col", "sum"))`,
# take the second part of the named aggregation.
agg_func = agg_func[0]
return get_snowflake_agg_func(agg_func, agg_kwargs, axis) is not None
return get_snowflake_agg_func(agg_func, agg_kwargs, axis, _is_df_agg) is not None


def _are_all_agg_funcs_supported_by_snowflake(
agg_funcs: list[AggFuncTypeBase], agg_kwargs: dict[str, Any], axis: Literal[0, 1]
agg_funcs: list[AggFuncTypeBase],
agg_kwargs: dict[str, Any],
axis: Literal[0, 1],
_is_df_agg: bool = False,
) -> bool:
"""
Check if all aggregation functions in the given list are snowflake supported
Expand All @@ -740,14 +842,16 @@ def _are_all_agg_funcs_supported_by_snowflake(
return False.
"""
return all(
_is_supported_snowflake_agg_func(func, agg_kwargs, axis) for func in agg_funcs
_is_supported_snowflake_agg_func(func, agg_kwargs, axis, _is_df_agg)
for func in agg_funcs
)


def check_is_aggregation_supported_in_snowflake(
agg_func: AggFuncType,
agg_kwargs: dict[str, Any],
axis: Literal[0, 1],
_is_df_agg: bool = False,
) -> bool:
"""
check if distributed implementation with snowflake is available for the aggregation
Expand All @@ -757,6 +861,8 @@ def check_is_aggregation_supported_in_snowflake(
agg_func: the aggregation function to apply
agg_kwargs: keyword argument passed for the aggregation function, such as ddof, min_count etc.
The value can be different for different aggregation function.
_is_df_agg: whether or not this is being called by df.agg, since some functions are only supported
for groupby_agg.
Returns:
bool
Whether the aggregation operation can be executed with snowflake sql engine.
Expand All @@ -766,15 +872,21 @@ def check_is_aggregation_supported_in_snowflake(
if is_dict_like(agg_func):
return all(
(
_are_all_agg_funcs_supported_by_snowflake(value, agg_kwargs, axis)
_are_all_agg_funcs_supported_by_snowflake(
value, agg_kwargs, axis, _is_df_agg
)
if is_list_like(value) and not is_named_tuple(value)
else _is_supported_snowflake_agg_func(value, agg_kwargs, axis)
else _is_supported_snowflake_agg_func(
value, agg_kwargs, axis, _is_df_agg
)
)
for value in agg_func.values()
)
elif is_list_like(agg_func):
return _are_all_agg_funcs_supported_by_snowflake(agg_func, agg_kwargs, axis)
return _is_supported_snowflake_agg_func(agg_func, agg_kwargs, axis)
return _are_all_agg_funcs_supported_by_snowflake(
agg_func, agg_kwargs, axis, _is_df_agg
)
return _is_supported_snowflake_agg_func(agg_func, agg_kwargs, axis, _is_df_agg)


def _is_snowflake_numeric_type_required(snowflake_agg_func: Callable) -> bool:
Expand Down Expand Up @@ -1373,10 +1485,21 @@ def repr_aggregate_function(agg_func: AggFuncType, agg_kwargs: Mapping) -> str:
if using_named_aggregations_for_func(agg_func):
# New axis labels are sensitive, so replace them with "new_label."
# Existing axis labels are sensitive, so replace them with "label."
return ", ".join(
f"new_label=(label, {repr_aggregate_function(f, agg_kwargs)})"
for _, f in agg_kwargs.values()
)
# This is checking whether the named aggregations are for a DataFrame,
# in which case they are of the format new_col_name = (col_to_operate_on,
# function), or for a Series, in which case they are of the format
# new_col_name=function, in order to ensure we parse the functions out
# from the keyword args correctly.
if is_list_like(list(agg_kwargs.values())[0]):
return ", ".join(
f"new_label=(label, {repr_aggregate_function(f, agg_kwargs)})"
for _, f in agg_kwargs.values()
)
else:
return ", ".join(
f"new_label=(label, {repr_aggregate_function(f, agg_kwargs)})"
for f in agg_kwargs.values()
)
if isinstance(agg_func, str):
# Strings functions represent names of pandas functions, e.g.
# "sum" means to aggregate with pandas.Series.sum. string function
Expand Down Expand Up @@ -1423,3 +1546,30 @@ def repr_aggregate_function(agg_func: AggFuncType, agg_kwargs: Mapping) -> str:
# exposing sensitive user input in the NotImplemented error message and
# thus in telemetry.
return "Callable"


def is_first_last_in_agg_funcs(
    column_to_agg_func: dict[str, Union[list[AggFuncInfo], AggFuncInfo]]
) -> bool:
    """
    Check whether `first` or `last` appears among the aggregations to apply.

    Parameters
    ----------
    column_to_agg_func : dict[str, Union[list[AggFuncInfo], AggFuncInfo]]
        Mapping from column name to the aggregation function (or list of
        functions) to apply to that column.

    Returns
    -------
    bool
        True if any mapped aggregation is `first` or `last`, else False.
    """

    def _targets_first_or_last(info: AggFuncInfo) -> bool:
        # The func attribute holds the aggregation identifier (e.g. "first").
        return info.func in ["first", "last"]

    for mapped in column_to_agg_func.values():
        if isinstance(mapped, AggFuncInfo):
            if _targets_first_or_last(mapped):
                return True
        elif isinstance(mapped, list):
            if any(_targets_first_or_last(entry) for entry in mapped):
                return True
    return False
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,7 @@
get_agg_func_to_col_map,
get_pandas_aggr_func_name,
get_snowflake_agg_func,
is_first_last_in_agg_funcs,
repr_aggregate_function,
using_named_aggregations_for_func,
)
Expand Down Expand Up @@ -3801,9 +3802,20 @@ def convert_func_to_agg_func_info(
internal_frame.index_column_snowflake_quoted_identifiers
)

# We need to check if `first` or `last` are in the aggregation functions,
# as we need to ensure a row position column and pass it in as an agg_kwarg
# if it is (for the min_by/max_by function).
first_last_present = is_first_last_in_agg_funcs(column_to_agg_func)
if first_last_present:
internal_frame = internal_frame.ensure_row_position_column()
agg_kwargs[
"_first_last_row_pos_col"
] = internal_frame.row_position_snowflake_quoted_identifier
agg_col_ops, new_data_column_index_names = generate_column_agg_info(
internal_frame, column_to_agg_func, agg_kwargs, is_series_groupby
)
if first_last_present:
agg_kwargs.pop("_first_last_row_pos_col")
# the pandas label and quoted identifier generated for each result column
# after aggregation will be used as new pandas label and quoted identifiers.
new_data_column_pandas_labels = []
Expand Down Expand Up @@ -6048,7 +6060,9 @@ def agg(
# by snowflake engine.
# If we are using Named Aggregations, we need to do our supported check slightly differently.
uses_named_aggs = using_named_aggregations_for_func(func)
if not check_is_aggregation_supported_in_snowflake(func, kwargs, axis):
if not check_is_aggregation_supported_in_snowflake(
func, kwargs, axis, _is_df_agg=True
):
ErrorMessage.not_implemented(
f"Snowpark pandas aggregate does not yet support the aggregation {repr_aggregate_function(func, kwargs)} with the given arguments."
)
Expand Down
27 changes: 26 additions & 1 deletion tests/integ/modin/frame/test_aggregate.py
Original file line number Diff line number Diff line change
Expand Up @@ -1142,12 +1142,37 @@ def test_named_agg_not_supported_axis_1(numeric_native_df):
),
param(
None,
{"x": ("A", np.exp), "y": pd.NamedAgg("C", sum)},
{
"x": ("A", np.exp),
"y": pd.NamedAgg("C", sum),
},
"Snowpark pandas aggregate does not yet support the aggregation "
+ "new_label=\\(label, np\\.exp\\), new_label=\\(label, <built-in function sum>\\)"
+ " with the given arguments",
id="named_agg",
),
param(
None,
{
"x": ("A", "first"),
"y": pd.NamedAgg("C", sum),
},
"Snowpark pandas aggregate does not yet support the aggregation "
+ "new_label=\\(label, 'first'\\), new_label=\\(label, <built-in function sum>\\)"
+ " with the given arguments",
id="named_agg",
),
param(
None,
{
"x": ("A", "last"),
"y": pd.NamedAgg("C", sum),
},
"Snowpark pandas aggregate does not yet support the aggregation "
+ "new_label=\\(label, 'last'\\), new_label=\\(label, <built-in function sum>\\)"
+ " with the given arguments",
id="named_agg",
),
],
)
@sql_count_checker(query_count=0)
Expand Down
Loading

0 comments on commit eea9995

Please sign in to comment.