From f2ba6b55c1d9ff2aea370e95db87bbcfc5164772 Mon Sep 17 00:00:00 2001 From: Ruifeng Zheng Date: Tue, 27 Sep 2022 14:40:41 +0800 Subject: [PATCH] [SPARK-40573][PS] Make `ddof` in `GroupBy.std`, `GroupBy.var` and `GroupBy.sem` accept arbitrary integers ### What changes were proposed in this pull request? Make `ddof` in `GroupBy.std`, `GroupBy.var` and `GroupBy.sem` accept arbitrary integers ### Why are the changes needed? for API coverage ### Does this PR introduce _any_ user-facing change? yes, `ddof` now accepts arbitrary integers (previously only 0 and 1 were accepted) ### How was this patch tested? added test suites Closes #38009 from zhengruifeng/ps_groupby_ddof. Authored-by: Ruifeng Zheng Signed-off-by: Ruifeng Zheng --- python/pyspark/pandas/groupby.py | 41 +++++++++++++-------- python/pyspark/pandas/tests/test_groupby.py | 2 +- 2 files changed, 27 insertions(+), 16 deletions(-) diff --git a/python/pyspark/pandas/groupby.py b/python/pyspark/pandas/groupby.py index 7085d2ec05914..9378e83af9087 100644 --- a/python/pyspark/pandas/groupby.py +++ b/python/pyspark/pandas/groupby.py @@ -722,12 +722,17 @@ def std(self, ddof: int = 1) -> FrameLike: """ Compute standard deviation of groups, excluding missing values. + .. versionadded:: 3.3.0 + Parameters ---------- ddof : int, default 1 Delta Degrees of Freedom. The divisor used in calculations is N - ddof, where N represents the number of elements. + .. versionchanged:: 3.4.0 + Supported including arbitrary integers. + Examples -------- >>> df = ps.DataFrame({"A": [1, 2, 1, 2], "B": [True, False, False, True], @@ -744,7 +749,8 @@ def std(self, ddof: int = 1) -> FrameLike: pyspark.pandas.Series.groupby pyspark.pandas.DataFrame.groupby """ - assert ddof in (0, 1) + if not isinstance(ddof, int): + raise TypeError("ddof must be integer") # Raise the TypeError when all aggregation columns are of unaccepted data types any_accepted = any( "Unaccepted data types of aggregation columns; numeric or bool expected." 
) + def std(col: Column) -> Column: + return SF.stddev(col, ddof) + return self._reduce_for_stat_function( - F.stddev_pop if ddof == 0 else F.stddev_samp, + std, accepted_spark_types=(NumericType,), bool_to_numeric=True, ) @@ -791,12 +800,17 @@ def var(self, ddof: int = 1) -> FrameLike: """ Compute variance of groups, excluding missing values. + .. versionadded:: 3.3.0 + Parameters ---------- ddof : int, default 1 Delta Degrees of Freedom. The divisor used in calculations is N - ddof, where N represents the number of elements. + .. versionchanged:: 3.4.0 + Supported including arbitrary integers. + Examples -------- >>> df = ps.DataFrame({"A": [1, 2, 1, 2], "B": [True, False, False, True], @@ -813,10 +827,14 @@ def var(self, ddof: int = 1) -> FrameLike: pyspark.pandas.Series.groupby pyspark.pandas.DataFrame.groupby """ - assert ddof in (0, 1) + if not isinstance(ddof, int): + raise TypeError("ddof must be integer") + + def var(col: Column) -> Column: + return SF.var(col, ddof) return self._reduce_for_stat_function( - F.var_pop if ddof == 0 else F.var_samp, + var, accepted_spark_types=(NumericType,), bool_to_numeric=True, ) @@ -963,8 +981,8 @@ def sem(self, ddof: int = 1) -> FrameLike: pyspark.pandas.Series.sem pyspark.pandas.DataFrame.sem """ - if ddof not in [0, 1]: - raise TypeError("ddof must be 0 or 1") + if not isinstance(ddof, int): + raise TypeError("ddof must be integer") # Raise the TypeError when all aggregation columns are of unaccepted data types any_accepted = any( "Unaccepted data types of aggregation columns; numeric or bool expected." 
) - if ddof == 0: - - def sem(col: Column) -> Column: - return F.stddev_pop(col) / F.sqrt(F.count(col)) - - else: - - def sem(col: Column) -> Column: - return F.stddev_samp(col) / F.sqrt(F.count(col)) + def sem(col: Column) -> Column: + return SF.stddev(col, ddof) / F.sqrt(F.count(col)) return self._reduce_for_stat_function( sem, diff --git a/python/pyspark/pandas/tests/test_groupby.py b/python/pyspark/pandas/tests/test_groupby.py index f0b3a04be177e..481a0f8cfac92 100644 --- a/python/pyspark/pandas/tests/test_groupby.py +++ b/python/pyspark/pandas/tests/test_groupby.py @@ -3111,7 +3111,7 @@ def test_ddof(self): ) psdf = ps.from_pandas(pdf) - for ddof in (0, 1): + for ddof in [-1, 0, 1, 2, 3]: # std self.assert_eq( pdf.groupby("a").std(ddof=ddof).sort_index(),