From b9659fde93dd253c5f153f930dc0b73d9e88abb5 Mon Sep 17 00:00:00 2001 From: Rehan Durrani Date: Thu, 9 Jan 2025 15:31:13 -0800 Subject: [PATCH 1/8] [SNOW-1754978]: Add support for `first` and `last` in `GroupBy.apply` --- docs/source/modin/supported/agg_supp.rst | 4 + .../plugin/_internal/aggregation_utils.py | 171 ++++++++++++++++-- .../compiler/snowflake_query_compiler.py | 16 +- tests/integ/modin/frame/test_aggregate.py | 27 ++- .../modin/groupby/test_groupby_basic_agg.py | 32 ++++ 5 files changed, 235 insertions(+), 15 deletions(-) diff --git a/docs/source/modin/supported/agg_supp.rst b/docs/source/modin/supported/agg_supp.rst index 5b3a2c174c0..1cc58e8d4e5 100644 --- a/docs/source/modin/supported/agg_supp.rst +++ b/docs/source/modin/supported/agg_supp.rst @@ -60,3 +60,7 @@ the aggregation is supported by ``SeriesGroupBy.agg``. +-----------------------------+-------------------------------------+----------------------------------+--------------------------------------------+-----------------------------------------+ | ``len`` | ``N`` | ``N`` | ``Y`` | ``Y`` | +-----------------------------+-------------------------------------+----------------------------------+--------------------------------------------+-----------------------------------------+ +| ``first`` | ``N`` | ``N`` | ``Y`` | ``Y`` | ++-----------------------------+-------------------------------------+----------------------------------+--------------------------------------------+-----------------------------------------+ +| ``last`` | ``N`` | ``N`` | ``Y`` | ``Y`` | ++-----------------------------+-------------------------------------+----------------------------------+--------------------------------------------+-----------------------------------------+ diff --git a/src/snowflake/snowpark/modin/plugin/_internal/aggregation_utils.py b/src/snowflake/snowpark/modin/plugin/_internal/aggregation_utils.py index 72978fc797c..f85987b3c5f 100644 --- 
a/src/snowflake/snowpark/modin/plugin/_internal/aggregation_utils.py +++ b/src/snowflake/snowpark/modin/plugin/_internal/aggregation_utils.py @@ -430,6 +430,70 @@ def _columns_coalescing_sum(*cols: SnowparkColumn) -> Callable: return sum(builtin("zeroifnull")(col) for col in cols) +def _column_first_value( + column: SnowparkColumn, + row_position_snowflake_quoted_identifier: str, + ignore_nulls: bool, +) -> SnowparkColumn: + """ + Returns the first value (ordered by `row_position_snowflake_quoted_identifier`) over the specified group. + + Parameters + ---------- + column: Snowpark Column + The Snowpark column to aggregate. + row_position_snowflake_quoted_identifier: str + The Snowflake quoted identifier of the column to order by. + ignore_nulls: bool + Whether or not to ignore nulls. + + Returns + ------- + The aggregated Snowpark Column. + """ + if ignore_nulls: + col_to_min_by = iff( + col(column).is_null(), + pandas_lit(None), + row_position_snowflake_quoted_identifier, + ) + else: + col_to_min_by = col(row_position_snowflake_quoted_identifier) + return builtin("min_by")(col(column), col_to_min_by) + + +def _column_last_value( + column: SnowparkColumn, + row_position_snowflake_quoted_identifier: str, + ignore_nulls: bool, +) -> SnowparkColumn: + """ + Returns the last value (ordered by `row_position_snowflake_quoted_identifier`) over the specified group. + + Parameters + ---------- + column: Snowpark Column + The Snowpark column to aggregate. + row_position_snowflake_quoted_identifier: str + The Snowflake quoted identifier of the column to order by. + ignore_nulls: bool + Whether or not to ignore nulls. + + Returns + ------- + The aggregated Snowpark Column. 
+ """ + if ignore_nulls: + col_to_max_by = iff( + col(column).is_null(), + pandas_lit(None), + row_position_snowflake_quoted_identifier, + ) + else: + col_to_max_by = col(row_position_snowflake_quoted_identifier) + return builtin("max_by")(col(column), col_to_max_by) + + def _create_pandas_to_snowpark_pandas_aggregation_map( pandas_functions: Iterable[AggFuncTypeBase], snowpark_pandas_aggregation: _SnowparkPandasAggregation, @@ -469,6 +533,18 @@ def _create_pandas_to_snowpark_pandas_aggregation_map( preserves_snowpark_pandas_types=False, ), ), + "first": _SnowparkPandasAggregation( + axis_0_aggregation=_column_first_value, + axis_1_aggregation_keepna=lambda *cols: cols[0], + axis_1_aggregation_skipna=lambda *cols: coalesce(*cols), + preserves_snowpark_pandas_types=True, + ), + "last": _SnowparkPandasAggregation( + axis_0_aggregation=_column_last_value, + axis_1_aggregation_keepna=lambda *cols: cols[-1], + axis_1_aggregation_skipna=lambda *cols: coalesce(*(cols[::-1])), + preserves_snowpark_pandas_types=True, + ), **_create_pandas_to_snowpark_pandas_aggregation_map( ("mean", np.mean), _SnowparkPandasAggregation( @@ -610,7 +686,10 @@ def is_snowflake_agg_func(agg_func: AggFuncTypeBase) -> bool: def get_snowflake_agg_func( - agg_func: AggFuncTypeBase, agg_kwargs: dict[str, Any], axis: Literal[0, 1] + agg_func: AggFuncTypeBase, + agg_kwargs: dict[str, Any], + axis: Literal[0, 1], + _is_df_agg: bool = False, ) -> Optional[SnowflakeAggFunc]: """ Get the corresponding Snowflake/Snowpark aggregation function for the given aggregation function. @@ -659,6 +738,23 @@ def get_snowflake_agg_func( def snowpark_aggregation(col: SnowparkColumn) -> SnowparkColumn: return column_quantile(col, interpolation, q) + elif ( + snowpark_aggregation == _column_first_value + or snowpark_aggregation == _column_last_value + ): + # First and last are not supported for df.agg. 
+ if _is_df_agg: + return None + ignore_nulls = agg_kwargs.get("skipna", True) + row_position_snowflake_quoted_identifier = agg_kwargs.get( + "_first_last_row_pos_col", None + ) + snowpark_aggregation = functools.partial( + snowpark_aggregation, + ignore_nulls=ignore_nulls, + row_position_snowflake_quoted_identifier=row_position_snowflake_quoted_identifier, + ) + assert ( snowpark_aggregation is not None ), "Internal error: Snowpark pandas should have identified a Snowpark aggregation." @@ -707,7 +803,10 @@ def snowpark_aggregation(*cols: SnowparkColumn) -> SnowparkColumn: def _is_supported_snowflake_agg_func( - agg_func: AggFuncTypeBase, agg_kwargs: dict[str, Any], axis: Literal[0, 1] + agg_func: AggFuncTypeBase, + agg_kwargs: dict[str, Any], + axis: Literal[0, 1], + _is_df_agg: bool = False, ) -> bool: """ check if the aggregation function is supported with snowflake. Current supported @@ -724,11 +823,14 @@ def _is_supported_snowflake_agg_func( # For named aggregations, like `df.agg(new_col=("old_col", "sum"))`, # take the second part of the named aggregation. agg_func = agg_func[0] - return get_snowflake_agg_func(agg_func, agg_kwargs, axis) is not None + return get_snowflake_agg_func(agg_func, agg_kwargs, axis, _is_df_agg) is not None def _are_all_agg_funcs_supported_by_snowflake( - agg_funcs: list[AggFuncTypeBase], agg_kwargs: dict[str, Any], axis: Literal[0, 1] + agg_funcs: list[AggFuncTypeBase], + agg_kwargs: dict[str, Any], + axis: Literal[0, 1], + _is_df_agg: bool = False, ) -> bool: """ Check if all aggregation functions in the given list are snowflake supported @@ -739,7 +841,8 @@ def _are_all_agg_funcs_supported_by_snowflake( return False. 
""" return all( - _is_supported_snowflake_agg_func(func, agg_kwargs, axis) for func in agg_funcs + _is_supported_snowflake_agg_func(func, agg_kwargs, axis, _is_df_agg) + for func in agg_funcs ) @@ -747,6 +850,7 @@ def check_is_aggregation_supported_in_snowflake( agg_func: AggFuncType, agg_kwargs: dict[str, Any], axis: Literal[0, 1], + _is_df_agg: bool = False, ) -> bool: """ check if distributed implementation with snowflake is available for the aggregation @@ -756,6 +860,8 @@ def check_is_aggregation_supported_in_snowflake( agg_func: the aggregation function to apply agg_kwargs: keyword argument passed for the aggregation function, such as ddof, min_count etc. The value can be different for different aggregation function. + _is_df_agg: whether or not this is being called by df.agg, since some functions are onlyu supported + for groupby_agg. Returns: bool Whether the aggregation operation can be executed with snowflake sql engine. @@ -765,15 +871,21 @@ def check_is_aggregation_supported_in_snowflake( if is_dict_like(agg_func): return all( ( - _are_all_agg_funcs_supported_by_snowflake(value, agg_kwargs, axis) + _are_all_agg_funcs_supported_by_snowflake( + value, agg_kwargs, axis, _is_df_agg + ) if is_list_like(value) and not is_named_tuple(value) - else _is_supported_snowflake_agg_func(value, agg_kwargs, axis) + else _is_supported_snowflake_agg_func( + value, agg_kwargs, axis, _is_df_agg + ) ) for value in agg_func.values() ) elif is_list_like(agg_func): - return _are_all_agg_funcs_supported_by_snowflake(agg_func, agg_kwargs, axis) - return _is_supported_snowflake_agg_func(agg_func, agg_kwargs, axis) + return _are_all_agg_funcs_supported_by_snowflake( + agg_func, agg_kwargs, axis, _is_df_agg + ) + return _is_supported_snowflake_agg_func(agg_func, agg_kwargs, axis, _is_df_agg) def _is_snowflake_numeric_type_required(snowflake_agg_func: Callable) -> bool: @@ -1372,10 +1484,16 @@ def repr_aggregate_function(agg_func: AggFuncType, agg_kwargs: Mapping) -> str: if 
using_named_aggregations_for_func(agg_func): # New axis labels are sensitive, so replace them with "new_label." # Existing axis labels are sensitive, so replace them with "label." - return ", ".join( - f"new_label=(label, {repr_aggregate_function(f, agg_kwargs)})" - for _, f in agg_kwargs.values() - ) + if is_list_like(list(agg_kwargs.values())[0]): + return ", ".join( + f"new_label=(label, {repr_aggregate_function(f, agg_kwargs)})" + for _, f in agg_kwargs.values() + ) + else: + return ", ".join( + f"new_label=(label, {repr_aggregate_function(f, agg_kwargs)})" + for f in agg_kwargs.values() + ) if isinstance(agg_func, str): # Strings functions represent names of pandas functions, e.g. # "sum" means to aggregate with pandas.Series.sum. string function @@ -1422,3 +1540,30 @@ def repr_aggregate_function(agg_func: AggFuncType, agg_kwargs: Mapping) -> str: # exposing sensitive user input in the NotImplemented error message and # thus in telemetry. return "Callable" + + +def is_first_last_in_agg_funcs( + column_to_agg_func: dict[str, Union[list[AggFuncInfo], AggFuncInfo]] +) -> bool: + """ + Helper function to check if the `first` or `last` aggregation functions have been specified. + + Parameters + ---------- + column_to_agg_func: dict[str, Union[list[AggFuncInfo], AggFuncInfo]] + The mapping of column name to aggregation function (or functions) to apply. + + Returns + ------- + bool + Whether any of the functions to apply are either `first` or `last`. 
+ """ + + def _is_first_last_agg_func(value: AggFuncInfo) -> bool: + return value.func in ["first", "last"] + + return any( + (isinstance(val, AggFuncInfo) and _is_first_last_agg_func(val)) + or (isinstance(val, list) and any(_is_first_last_agg_func(v) for v in val)) + for val in column_to_agg_func.values() + ) diff --git a/src/snowflake/snowpark/modin/plugin/compiler/snowflake_query_compiler.py b/src/snowflake/snowpark/modin/plugin/compiler/snowflake_query_compiler.py index c9171d9369c..673e6518475 100644 --- a/src/snowflake/snowpark/modin/plugin/compiler/snowflake_query_compiler.py +++ b/src/snowflake/snowpark/modin/plugin/compiler/snowflake_query_compiler.py @@ -173,6 +173,7 @@ get_agg_func_to_col_map, get_pandas_aggr_func_name, get_snowflake_agg_func, + is_first_last_in_agg_funcs, repr_aggregate_function, using_named_aggregations_for_func, ) @@ -3796,9 +3797,20 @@ def convert_func_to_agg_func_info( internal_frame.index_column_snowflake_quoted_identifiers ) + # We need to check if `first` or `last` are in the aggregation functions, + # as we need to ensure a row position column and pass it in as an agg_kwarg + # if it is (for the window function). + first_last_present = is_first_last_in_agg_funcs(column_to_agg_func) + if first_last_present: + internal_frame = internal_frame.ensure_row_position_column() + agg_kwargs[ + "_first_last_row_pos_col" + ] = internal_frame.row_position_snowflake_quoted_identifier agg_col_ops, new_data_column_index_names = generate_column_agg_info( internal_frame, column_to_agg_func, agg_kwargs, is_series_groupby ) + if first_last_present: + agg_kwargs.pop("_first_last_row_pos_col") # the pandas label and quoted identifier generated for each result column # after aggregation will be used as new pandas label and quoted identifiers. new_data_column_pandas_labels = [] @@ -6026,7 +6038,9 @@ def agg( # by snowflake engine. # If we are using Named Aggregations, we need to do our supported check slightly differently. 
uses_named_aggs = using_named_aggregations_for_func(func) - if not check_is_aggregation_supported_in_snowflake(func, kwargs, axis): + if not check_is_aggregation_supported_in_snowflake( + func, kwargs, axis, _is_df_agg=True + ): ErrorMessage.not_implemented( f"Snowpark pandas aggregate does not yet support the aggregation {repr_aggregate_function(func, kwargs)} with the given arguments." ) diff --git a/tests/integ/modin/frame/test_aggregate.py b/tests/integ/modin/frame/test_aggregate.py index 30e1bc876b7..bb934e1d4a2 100644 --- a/tests/integ/modin/frame/test_aggregate.py +++ b/tests/integ/modin/frame/test_aggregate.py @@ -1141,12 +1141,37 @@ def test_named_agg_not_supported_axis_1(numeric_native_df): ), param( None, - {"x": ("A", np.exp), "y": pd.NamedAgg("C", sum)}, + { + "x": ("A", np.exp), + "y": pd.NamedAgg("C", sum), + }, "Snowpark pandas aggregate does not yet support the aggregation " + "new_label=\\(label, np\\.exp\\), new_label=\\(label, \\)" + " with the given arguments", id="named_agg", ), + param( + None, + { + "x": ("A", "first"), + "y": pd.NamedAgg("C", sum), + }, + "Snowpark pandas aggregate does not yet support the aggregation " + + "new_label=\\(label, 'first'\\), new_label=\\(label, \\)" + + " with the given arguments", + id="named_agg", + ), + param( + None, + { + "x": ("A", "last"), + "y": pd.NamedAgg("C", sum), + }, + "Snowpark pandas aggregate does not yet support the aggregation " + + "new_label=\\(label, 'last'\\), new_label=\\(label, \\)" + + " with the given arguments", + id="named_agg", + ), ], ) @sql_count_checker(query_count=0) diff --git a/tests/integ/modin/groupby/test_groupby_basic_agg.py b/tests/integ/modin/groupby/test_groupby_basic_agg.py index ff4636a8bd9..4e3e874445d 100644 --- a/tests/integ/modin/groupby/test_groupby_basic_agg.py +++ b/tests/integ/modin/groupby/test_groupby_basic_agg.py @@ -1149,6 +1149,38 @@ def test_valid_func_valid_kwarg_should_work(basic_snowpark_pandas_df): ) +@pytest.mark.parametrize("skipna", [True, 
False]) +@sql_count_checker(query_count=1) +def test_groupby_agg_first(skipna): + native_df = native_pd.DataFrame( + {"grp_col": ["A", "A", "B", "B", "A"], "float_col": [np.nan, 2, 3, np.nan, 4]} + ) + snow_df = pd.DataFrame(native_df) + eval_snowpark_pandas_result( + snow_df, + native_df, + lambda df: df.groupby(by="grp_col").agg( + {"float_col": ["quantile", "first"]}, skipna=skipna + ), + ) + + +@pytest.mark.parametrize("skipna", [True, False]) +@sql_count_checker(query_count=1) +def test_groupby_agg_last(skipna): + native_df = native_pd.DataFrame( + {"grp_col": ["A", "A", "B", "B", "A"], "float_col": [np.nan, 2, 3, np.nan, 4]} + ) + snow_df = pd.DataFrame(native_df) + eval_snowpark_pandas_result( + snow_df, + native_df, + lambda df: df.groupby(by="grp_col").agg( + {"float_col": ["quantile", "last"]}, skipna=skipna + ), + ) + + class TestTimedelta: @sql_count_checker(query_count=1) @pytest.mark.parametrize( From 768c523ac4b60b76a7e787d93bdb3465064a5ce5 Mon Sep 17 00:00:00 2001 From: Rehan Durrani Date: Thu, 9 Jan 2025 15:33:49 -0800 Subject: [PATCH 2/8] Fix typo in comment --- .../snowpark/modin/plugin/_internal/aggregation_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/snowflake/snowpark/modin/plugin/_internal/aggregation_utils.py b/src/snowflake/snowpark/modin/plugin/_internal/aggregation_utils.py index f85987b3c5f..47638f5e7a8 100644 --- a/src/snowflake/snowpark/modin/plugin/_internal/aggregation_utils.py +++ b/src/snowflake/snowpark/modin/plugin/_internal/aggregation_utils.py @@ -860,7 +860,7 @@ def check_is_aggregation_supported_in_snowflake( agg_func: the aggregation function to apply agg_kwargs: keyword argument passed for the aggregation function, such as ddof, min_count etc. The value can be different for different aggregation function. 
- _is_df_agg: whether or not this is being called by df.agg, since some functions are onlyu supported + _is_df_agg: whether or not this is being called by df.agg, since some functions are only supported for groupby_agg. Returns: bool From 5a370bd32ddec4b98e1510e9f8765a5761cf04df Mon Sep 17 00:00:00 2001 From: Rehan Durrani Date: Thu, 9 Jan 2025 15:34:47 -0800 Subject: [PATCH 3/8] Fix typo in comment --- .../snowpark/modin/plugin/compiler/snowflake_query_compiler.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/snowflake/snowpark/modin/plugin/compiler/snowflake_query_compiler.py b/src/snowflake/snowpark/modin/plugin/compiler/snowflake_query_compiler.py index a8a0bb2ce61..3e4fd72512b 100644 --- a/src/snowflake/snowpark/modin/plugin/compiler/snowflake_query_compiler.py +++ b/src/snowflake/snowpark/modin/plugin/compiler/snowflake_query_compiler.py @@ -3799,7 +3799,7 @@ def convert_func_to_agg_func_info( # We need to check if `first` or `last` are in the aggregation functions, # as we need to ensure a row position column and pass it in as an agg_kwarg - # if it is (for the window function). + # if it is (for the min_by/max_by function). first_last_present = is_first_last_in_agg_funcs(column_to_agg_func) if first_last_present: internal_frame = internal_frame.ensure_row_position_column() From 08234ded9bd79d01444837778df2246eb5d0f57b Mon Sep 17 00:00:00 2001 From: Rehan Durrani Date: Thu, 9 Jan 2025 15:36:19 -0800 Subject: [PATCH 4/8] Add CHANGELOG --- CHANGELOG.md | 1 + 1 file changed, 1 insertion(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 7ebf01e80d1..5443c5dee4e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -49,6 +49,7 @@ - %X: Locale’s appropriate time representation. - %%: A literal '%' character. - Added support for `Series.between`. +- Added support for `first` and `last` in `DataFrameGroupBy.apply` and `SeriesGroupBy.apply`. 
#### Bug Fixes From 5ba270d1bc7f61cb7008ed21e45adc08c7089f98 Mon Sep 17 00:00:00 2001 From: Rehan Durrani Date: Thu, 9 Jan 2025 15:38:28 -0800 Subject: [PATCH 5/8] Make comment less confusing --- .../snowpark/modin/plugin/_internal/aggregation_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/snowflake/snowpark/modin/plugin/_internal/aggregation_utils.py b/src/snowflake/snowpark/modin/plugin/_internal/aggregation_utils.py index 47638f5e7a8..352c157471d 100644 --- a/src/snowflake/snowpark/modin/plugin/_internal/aggregation_utils.py +++ b/src/snowflake/snowpark/modin/plugin/_internal/aggregation_utils.py @@ -742,8 +742,8 @@ def snowpark_aggregation(col: SnowparkColumn) -> SnowparkColumn: snowpark_aggregation == _column_first_value or snowpark_aggregation == _column_last_value ): - # First and last are not supported for df.agg. if _is_df_agg: + # First and last are not supported for df.agg. return None ignore_nulls = agg_kwargs.get("skipna", True) row_position_snowflake_quoted_identifier = agg_kwargs.get( From 2810de9a40d925666686abe3abb04658e7947e03 Mon Sep 17 00:00:00 2001 From: Rehan Durrani Date: Thu, 9 Jan 2025 16:05:16 -0800 Subject: [PATCH 6/8] Add comment to make code less confusing --- .../snowpark/modin/plugin/_internal/aggregation_utils.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/src/snowflake/snowpark/modin/plugin/_internal/aggregation_utils.py b/src/snowflake/snowpark/modin/plugin/_internal/aggregation_utils.py index 352c157471d..46cd9c620ac 100644 --- a/src/snowflake/snowpark/modin/plugin/_internal/aggregation_utils.py +++ b/src/snowflake/snowpark/modin/plugin/_internal/aggregation_utils.py @@ -1484,6 +1484,11 @@ def repr_aggregate_function(agg_func: AggFuncType, agg_kwargs: Mapping) -> str: if using_named_aggregations_for_func(agg_func): # New axis labels are sensitive, so replace them with "new_label." # Existing axis labels are sensitive, so replace them with "label." 
+ # This is checking whether the named aggregations are for a DataFrame, + # in which case they are of the format new_col_name = (col_to_operate_on, + # function), or for a Series, in which case they are of the format + # new_col_name=function, in order to ensure we parse the functions out + # from the keyword args correctly. if is_list_like(list(agg_kwargs.values())[0]): return ", ".join( f"new_label=(label, {repr_aggregate_function(f, agg_kwargs)})" From ccfcec28c45d183778c6f7cf11c8fb1f62aa9af0 Mon Sep 17 00:00:00 2001 From: Rehan Durrani Date: Thu, 9 Jan 2025 16:52:31 -0800 Subject: [PATCH 7/8] Fix typo in CHANGELOG --- CHANGELOG.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 5443c5dee4e..082d6198f28 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -49,7 +49,7 @@ - %X: Locale’s appropriate time representation. - %%: A literal '%' character. - Added support for `Series.between`. -- Added support for `first` and `last` in `DataFrameGroupBy.apply` and `SeriesGroupBy.apply`. +- Added support for `first` and `last` in `DataFrameGroupBy.agg` and `SeriesGroupBy.agg`. 
#### Bug Fixes From e9efea462608751b892f91fad75f8c5533192cbf Mon Sep 17 00:00:00 2001 From: Jonathan Shi Date: Wed, 15 Jan 2025 15:07:43 -0800 Subject: [PATCH 8/8] fix missing column escapes --- .../plugin/_internal/aggregation_utils.py | 4 +-- .../modin/groupby/test_groupby_basic_agg.py | 34 ++++++++++--------- 2 files changed, 20 insertions(+), 18 deletions(-) diff --git a/src/snowflake/snowpark/modin/plugin/_internal/aggregation_utils.py b/src/snowflake/snowpark/modin/plugin/_internal/aggregation_utils.py index 02091401f9d..599c0233c68 100644 --- a/src/snowflake/snowpark/modin/plugin/_internal/aggregation_utils.py +++ b/src/snowflake/snowpark/modin/plugin/_internal/aggregation_utils.py @@ -456,7 +456,7 @@ def _column_first_value( col_to_min_by = iff( col(column).is_null(), pandas_lit(None), - row_position_snowflake_quoted_identifier, + col(row_position_snowflake_quoted_identifier), ) else: col_to_min_by = col(row_position_snowflake_quoted_identifier) @@ -488,7 +488,7 @@ def _column_last_value( col_to_max_by = iff( col(column).is_null(), pandas_lit(None), - row_position_snowflake_quoted_identifier, + col(row_position_snowflake_quoted_identifier), ) else: col_to_max_by = col(row_position_snowflake_quoted_identifier) diff --git a/tests/integ/modin/groupby/test_groupby_basic_agg.py b/tests/integ/modin/groupby/test_groupby_basic_agg.py index 0500e6f83ad..a1c234b1af4 100644 --- a/tests/integ/modin/groupby/test_groupby_basic_agg.py +++ b/tests/integ/modin/groupby/test_groupby_basic_agg.py @@ -1151,34 +1151,36 @@ def test_valid_func_valid_kwarg_should_work(basic_snowpark_pandas_df): @pytest.mark.parametrize("skipna", [True, False]) +@pytest.mark.parametrize("op", ["first", "last"]) @sql_count_checker(query_count=1) -def test_groupby_agg_first(skipna): +def test_groupby_agg_first_and_last(skipna, op): native_df = native_pd.DataFrame( {"grp_col": ["A", "A", "B", "B", "A"], "float_col": [np.nan, 2, 3, np.nan, 4]} ) snow_df = pd.DataFrame(native_df) - 
eval_snowpark_pandas_result( - snow_df, - native_df, - lambda df: df.groupby(by="grp_col").agg( - {"float_col": ["quantile", "first"]}, skipna=skipna - ), - ) + def comparator(snow_result, native_result): + # When passing a list of aggregations, native pandas does not respect kwargs, so skipna + # is always treated as the default value (true). + # Massage the native results to match the expected behavior. + if not skipna: + if op == "first": + assert native_result["float_col", "first"]["A"] == 2 + assert native_result["float_col", "first"]["B"] == 3 + native_result["float_col", "first"]["A"] = None + else: + assert native_result["float_col", "last"]["A"] == 4 + assert native_result["float_col", "last"]["B"] == 3 + native_result["float_col", "last"]["B"] = None + assert_snowpark_pandas_equal_to_pandas(snow_result, native_result) -@pytest.mark.parametrize("skipna", [True, False]) -@sql_count_checker(query_count=1) -def test_groupby_agg_last(skipna): - native_df = native_pd.DataFrame( - {"grp_col": ["A", "A", "B", "B", "A"], "float_col": [np.nan, 2, 3, np.nan, 4]} - ) - snow_df = pd.DataFrame(native_df) eval_snowpark_pandas_result( snow_df, native_df, lambda df: df.groupby(by="grp_col").agg( - {"float_col": ["quantile", "last"]}, skipna=skipna + {"float_col": ["quantile", op]}, skipna=skipna ), + comparator=comparator, )