[SPARK-40434][SS][PYTHON] Implement applyInPandasWithState in PySpark
### What changes were proposed in this pull request?

This PR proposes to introduce the new API `applyInPandasWithState` in PySpark, which provides the functionality to perform arbitrary stateful processing in Structured Streaming.

This will be a pair API with applyInPandas: applyInPandas in PySpark covers the use case of flatMapGroups in the Scala/Java API, while applyInPandasWithState covers the use case of flatMapGroupsWithState in the Scala/Java API.

The signature of the API is as follows:

```
# call this function after groupBy
def applyInPandasWithState(
    self,
    func: "PandasGroupedMapFunctionWithState",
    outputStructType: Union[StructType, str],
    stateStructType: Union[StructType, str],
    outputMode: str,
    timeoutConf: str,
) -> DataFrame
```

and the signature of the user function is as follows:

```
def func(
    key: Tuple,
    pdf_iter: Iterator[pandas.DataFrame],
    state: GroupStateImpl
) -> Iterator[pandas.DataFrame]
```
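
For illustration, here is a minimal sketch of how the two signatures fit together. It assumes a streaming DataFrame `df` with columns `id` and `value`; the column names and output schema are made up for the example, and in a real query the resulting DataFrame would still be written out via `writeStream`.

```
import pandas as pd
from pyspark.sql.streaming.state import GroupStateTimeout

def count_fn(key, pdf_iter, state):
    # Restore the previously stored count for this key, if any.
    (count,) = state.get if state.exists else (0,)
    for pdf in pdf_iter:  # each pdf is one chunk of the group's data
        count += len(pdf)
    state.update((count,))
    yield pd.DataFrame({"id": [key[0]], "count": [count]})

result = df.groupBy("id").applyInPandasWithState(
    count_fn,
    outputStructType="id long, count long",
    stateStructType="count long",
    outputMode="Update",
    timeoutConf=GroupStateTimeout.NoTimeout,
)
```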

(Please refer to the code diff for the function documentation of the new API.)

Major design choices which differ from existing APIs:

1. The new API is untyped, while flatMapGroupsWithState is a typed API.

This follows from the nature of the Python language: it is duck-typed, and type definitions are only hints. We do not have a typed API implementation for the PySpark DataFrame.

This leads us to design the API as untyped, meaning all types for (input, state, output) must be Row-compatible. While we do not require end users to deal with `Row` directly, the model they use for state and output must be convertible to Row with the default encoder. If they want a Python type for the state which is not Row-compatible (e.g. a custom class), they need to pickle it and use BinaryType to store it.
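
As a sketch of the pickle-plus-BinaryType route mentioned above (the `SessionAgg` class, the function, and the column names are hypothetical, not part of this PR):

```
import pickle
import pandas as pd

class SessionAgg:
    """A custom Python object that is not Row-compatible by itself."""
    def __init__(self):
        self.count = 0

def agg_fn(key, pdf_iter, state):
    # The state schema has a single binary column; unpickle the stored bytes if present.
    agg = pickle.loads(state.get[0]) if state.exists else SessionAgg()
    for pdf in pdf_iter:
        agg.count += len(pdf)
    # Store the object back as a single pickled BinaryType value.
    state.update((pickle.dumps(agg),))
    yield pd.DataFrame({"id": [key[0]], "count": [agg.count]})
```

The corresponding call site would declare the state schema as a single binary column, e.g. `stateStructType="payload binary"`.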

This requires end users to specify the types of the state and the output via Spark SQL schemas in the method.

Note that this helps to ensure compatibility of state data across Spark versions, as long as the encoders for 1) Python type -> Python Row and 2) Python Row -> UnsafeRow are not changed. We won't change the underlying data layout of UnsafeRow, as that would break all existing stateful queries.

2. The new API provides pandas DataFrames to the user function, while flatMapGroupsWithState provides an iterator of rows.

We decided to follow the user experience applyInPandas provides, for both consistency and performance (Arrow batching, vectorization, etc.). This leads us to design the user function around pandas DataFrames rather than an iterator of rows. While this makes the UX inconsistent with the Scala/Java API, we don't expect this to be a problem, since pandas is considered the de facto standard for Python data scientists.

3. The new API provides an iterator of pandas DataFrames to the user function and also requires it to return an iterator of pandas DataFrames, to address scalability.

applyInPandas has a known limitation: scalability. It requires the data for a specific group to fit into memory. During the design phase of the new API, we decided to address scalability rather than inherit the limitation.

To address scalability, we change the user function to receive an iterator (generator) of pandas DataFrames instead of a single pandas DataFrame, and also to return an iterator (generator) of pandas DataFrames. We think this does not hurt the UX too much, as a for loop and yield are enough to deal with the iterator.

From an implementation perspective, we split the data for a specific group into multiple chunks, where each chunk is stored and sent as "an" Arrow RecordBatch and finally materialized as "a" pandas DataFrame. This way, as long as end users do not materialize many pandas DataFrames from the iterator at the same time, only one chunk is materialized in memory, which is scalable. Similar logic applies to the output of the user function, so it is scalable as well.
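
To make the chunk-by-chunk contract concrete, here is a sketch (the `value` column and `runningSum` output column are assumptions) in which the user function consumes one chunk at a time and yields output per chunk, so only a single input chunk needs to be materialized at any moment:

```
import pandas as pd

def running_sum_fn(key, pdf_iter, state):
    # Start from the previously stored sum, if any.
    (acc,) = state.get if state.exists else (0,)
    for pdf in pdf_iter:  # one chunk (Arrow RecordBatch worth of data) at a time
        acc += int(pdf["value"].sum())
        # Yield per chunk instead of collecting everything first,
        # so the output side stays chunked as well.
        yield pd.DataFrame({"id": [key[0]], "runningSum": [acc]})
    state.update((acc,))
```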

4. The new API also bin-packs the data of multiple groups into a single Arrow RecordBatch.

Given that the API is mainly used for streaming workloads, it is quite likely that the volume of data in a specific group is not large enough to benefit from Arrow columnar batching, which would hurt performance. To address this, we also do the opposite of what we do for scalability: bin-packing. That is, an Arrow RecordBatch can contain data for multiple groups, as well as a part of the data for a specific group. This addresses both concerns, scalability and performance, together.

Note that we are not implementing all of the features the Scala/Java API provides in this initial phase; e.g., support for batch queries and support for initial state are left as TODOs.

### Why are the changes needed?

PySpark users don't have a way to perform arbitrary stateful processing in Structured Streaming and are forced to use either Java or Scala, which is unacceptable for users in many cases. This PR enables PySpark users to deal with it without moving to the Java/Scala world.

### Does this PR introduce _any_ user-facing change?

Yes. We are exposing a new public API in PySpark which performs arbitrary stateful processing.

### How was this patch tested?

N/A. We will make sure test suites are constructed in an E2E manner under [SPARK-40431](https://issues.apache.org/jira/browse/SPARK-40431) - apache#37894.

Closes apache#37893 from HeartSaVioR/SPARK-40434-on-top-of-SPARK-40433-SPARK-40432.

Lead-authored-by: Jungtaek Lim <[email protected]>
Co-authored-by: Hyukjin Kwon <[email protected]>
Signed-off-by: Hyukjin Kwon <[email protected]>
HeartSaVioR and HyukjinKwon committed Sep 22, 2022
1 parent db51ec6 commit 603dc50
Showing 23 changed files with 1,599 additions and 18 deletions.
@@ -53,6 +53,7 @@ private[spark] object PythonEvalType {
val SQL_MAP_PANDAS_ITER_UDF = 205
val SQL_COGROUPED_MAP_PANDAS_UDF = 206
val SQL_MAP_ARROW_ITER_UDF = 207
val SQL_GROUPED_MAP_PANDAS_UDF_WITH_STATE = 208

def toString(pythonEvalType: Int): String = pythonEvalType match {
case NON_UDF => "NON_UDF"
@@ -65,6 +66,7 @@ private[spark] object PythonEvalType {
case SQL_MAP_PANDAS_ITER_UDF => "SQL_MAP_PANDAS_ITER_UDF"
case SQL_COGROUPED_MAP_PANDAS_UDF => "SQL_COGROUPED_MAP_PANDAS_UDF"
case SQL_MAP_ARROW_ITER_UDF => "SQL_MAP_ARROW_ITER_UDF"
case SQL_GROUPED_MAP_PANDAS_UDF_WITH_STATE => "SQL_GROUPED_MAP_PANDAS_UDF_WITH_STATE"
}
}

2 changes: 2 additions & 0 deletions python/pyspark/rdd.py
@@ -105,6 +105,7 @@
PandasMapIterUDFType,
PandasCogroupedMapUDFType,
ArrowMapIterUDFType,
PandasGroupedMapUDFWithStateType,
)
from pyspark.sql.dataframe import DataFrame
from pyspark.sql.types import AtomicType, StructType
@@ -147,6 +148,7 @@ class PythonEvalType:
SQL_MAP_PANDAS_ITER_UDF: "PandasMapIterUDFType" = 205
SQL_COGROUPED_MAP_PANDAS_UDF: "PandasCogroupedMapUDFType" = 206
SQL_MAP_ARROW_ITER_UDF: "ArrowMapIterUDFType" = 207
SQL_GROUPED_MAP_PANDAS_UDF_WITH_STATE: "PandasGroupedMapUDFWithStateType" = 208


def portable_hash(x: Hashable) -> int:
6 changes: 6 additions & 0 deletions python/pyspark/sql/pandas/_typing/__init__.pyi
@@ -30,6 +30,7 @@ from typing_extensions import Protocol, Literal
from types import FunctionType

from pyspark.sql._typing import LiteralType
from pyspark.sql.streaming.state import GroupState
from pandas.core.frame import DataFrame as PandasDataFrame
from pandas.core.series import Series as PandasSeries
from numpy import ndarray as NDArray
@@ -51,6 +52,7 @@ PandasScalarIterUDFType = Literal[204]
PandasMapIterUDFType = Literal[205]
PandasCogroupedMapUDFType = Literal[206]
ArrowMapIterUDFType = Literal[207]
PandasGroupedMapUDFWithStateType = Literal[208]

class PandasVariadicScalarToScalarFunction(Protocol):
def __call__(self, *_: DataFrameOrSeriesLike_) -> DataFrameOrSeriesLike_: ...
@@ -256,6 +258,10 @@ PandasGroupedMapFunction = Union[
Callable[[Any, DataFrameLike], DataFrameLike],
]

PandasGroupedMapFunctionWithState = Callable[
[Any, Iterable[DataFrameLike], GroupState], Iterable[DataFrameLike]
]

class PandasVariadicGroupedAggFunction(Protocol):
def __call__(self, *_: SeriesLike) -> LiteralType: ...

2 changes: 2 additions & 0 deletions python/pyspark/sql/pandas/functions.py
@@ -369,6 +369,7 @@ def calculate(iterator: Iterator[pd.Series]) -> Iterator[pd.Series]:
PythonEvalType.SQL_MAP_PANDAS_ITER_UDF,
PythonEvalType.SQL_MAP_ARROW_ITER_UDF,
PythonEvalType.SQL_COGROUPED_MAP_PANDAS_UDF,
PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF_WITH_STATE,
None,
]: # None means it should infer the type from type hints.

@@ -402,6 +403,7 @@ def _create_pandas_udf(f, returnType, evalType):
PythonEvalType.SQL_MAP_PANDAS_ITER_UDF,
PythonEvalType.SQL_MAP_ARROW_ITER_UDF,
PythonEvalType.SQL_COGROUPED_MAP_PANDAS_UDF,
PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF_WITH_STATE,
]:
# In case of 'SQL_GROUPED_MAP_PANDAS_UDF', deprecation warning is being triggered
# at `apply` instead.
125 changes: 123 additions & 2 deletions python/pyspark/sql/pandas/group_ops.py
@@ -15,18 +15,20 @@
# limitations under the License.
#
import sys
from typing import List, Union, TYPE_CHECKING
from typing import List, Union, TYPE_CHECKING, cast
import warnings

from pyspark.rdd import PythonEvalType
from pyspark.sql.column import Column
from pyspark.sql.dataframe import DataFrame
from pyspark.sql.types import StructType
from pyspark.sql.streaming.state import GroupStateTimeout
from pyspark.sql.types import StructType, _parse_datatype_string

if TYPE_CHECKING:
from pyspark.sql.pandas._typing import (
GroupedMapPandasUserDefinedFunction,
PandasGroupedMapFunction,
PandasGroupedMapFunctionWithState,
PandasCogroupedMapFunction,
)
from pyspark.sql.group import GroupedData
@@ -216,6 +218,125 @@ def applyInPandas(
jdf = self._jgd.flatMapGroupsInPandas(udf_column._jc.expr())
return DataFrame(jdf, self.session)

def applyInPandasWithState(
self,
func: "PandasGroupedMapFunctionWithState",
outputStructType: Union[StructType, str],
stateStructType: Union[StructType, str],
outputMode: str,
timeoutConf: str,
) -> DataFrame:
"""
Applies the given function to each group of data, while maintaining a user-defined
per-group state. The result Dataset will represent the flattened record returned by the
function.
For a streaming Dataset, the function will be invoked first for all input groups and then
for all timed out states where the input data is set to be empty. Updates to each group's
state will be saved across invocations.
The function should take parameters (key, Iterator[`pandas.DataFrame`], state) and
return another Iterator[`pandas.DataFrame`]. The grouping key(s) will be passed as a tuple
of numpy data types, e.g., `numpy.int32` and `numpy.float64`. The state will be passed as
:class:`pyspark.sql.streaming.state.GroupState`.
For each group, all columns are passed together as `pandas.DataFrame` to the user-function,
and the returned `pandas.DataFrame` across all invocations are combined as a
:class:`DataFrame`. Note that the user function should not make a guess of the number of
elements in the iterator. To process all data, the user function needs to iterate all
elements and process them. On the other hand, the user function is not strictly required to
iterate through all elements in the iterator if it intends to read a part of data.
The `outputStructType` should be a :class:`StructType` describing the schema of all
elements in the returned value, `pandas.DataFrame`. The column labels of all elements in
returned `pandas.DataFrame` must either match the field names in the defined schema if
specified as strings, or match the field data types by position if not strings,
e.g. integer indices.
The `stateStructType` should be :class:`StructType` describing the schema of the
user-defined state. The value of the state will be presented as a tuple, as well as the
update should be performed with the tuple. The corresponding Python types for
:class:DataType are supported. Please refer to the page
https://spark.apache.org/docs/latest/sql-ref-datatypes.html (python tab).
The size of each DataFrame in both the input and output can be arbitrary. The number of
DataFrames in both the input and output can also be arbitrary.

.. versionadded:: 3.4.0

Parameters
----------
func : function
a Python native function to be called on every group. It should take parameters
(key, Iterator[`pandas.DataFrame`], state) and return Iterator[`pandas.DataFrame`].
Note that the type of the key is tuple and the type of the state is
:class:`pyspark.sql.streaming.state.GroupState`.
outputStructType : :class:`pyspark.sql.types.DataType` or str
the type of the output records. The value can be either a
:class:`pyspark.sql.types.DataType` object or a DDL-formatted type string.
stateStructType : :class:`pyspark.sql.types.DataType` or str
the type of the user-defined state. The value can be either a
:class:`pyspark.sql.types.DataType` object or a DDL-formatted type string.
outputMode : str
the output mode of the function.
timeoutConf : str
timeout configuration for groups that do not receive data for a while. valid values
are defined in :class:`pyspark.sql.streaming.state.GroupStateTimeout`.

Examples
--------
>>> import pandas as pd # doctest: +SKIP
>>> from pyspark.sql.streaming.state import GroupStateTimeout
>>> def count_fn(key, pdf_iter, state):
... assert isinstance(state, GroupStateImpl)
... total_len = 0
... for pdf in pdf_iter:
... total_len += len(pdf)
... state.update((total_len,))
... yield pd.DataFrame({"id": [key[0]], "countAsString": [str(total_len)]})
>>> df.groupby("id").applyInPandasWithState(
... count_fn, outputStructType="id long, countAsString string",
... stateStructType="len long", outputMode="Update",
... timeoutConf=GroupStateTimeout.NoTimeout) # doctest: +SKIP

Notes
-----
This function requires a full shuffle.
This API is experimental.
"""

from pyspark.sql import GroupedData
from pyspark.sql.functions import pandas_udf

assert isinstance(self, GroupedData)
assert timeoutConf in [
GroupStateTimeout.NoTimeout,
GroupStateTimeout.ProcessingTimeTimeout,
GroupStateTimeout.EventTimeTimeout,
]

if isinstance(outputStructType, str):
outputStructType = cast(StructType, _parse_datatype_string(outputStructType))
if isinstance(stateStructType, str):
stateStructType = cast(StructType, _parse_datatype_string(stateStructType))

udf = pandas_udf(
func, # type: ignore[call-overload]
returnType=outputStructType,
functionType=PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF_WITH_STATE,
)
df = self._df
udf_column = udf(*[df[col] for col in df.columns])
jdf = self._jgd.applyInPandasWithState(
udf_column._jc.expr(),
self.session._jsparkSession.parseDataType(outputStructType.json()),
self.session._jsparkSession.parseDataType(stateStructType.json()),
outputMode,
timeoutConf,
)
return DataFrame(jdf, self.session)

def cogroup(self, other: "GroupedData") -> "PandasCogroupedOps":
"""
Cogroups this group with another group so that we can run cogrouped operations.