diff --git a/common/utils/src/main/resources/error/error-conditions.json b/common/utils/src/main/resources/error/error-conditions.json index 3db2813b63040..c7c3180b88514 100644 --- a/common/utils/src/main/resources/error/error-conditions.json +++ b/common/utils/src/main/resources/error/error-conditions.json @@ -3919,6 +3919,12 @@ ], "sqlState" : "42802" }, + "STATEFUL_PROCESSOR_UNKNOWN_TIME_MODE" : { + "message" : [ + "Unknown time mode . Accepted timeMode modes are 'none', 'processingTime', 'eventTime'" + ], + "sqlState" : "42802" + }, "STATE_STORE_CANNOT_CREATE_COLUMN_FAMILY_WITH_RESERVED_CHARS" : { "message" : [ "Failed to create column family with unsupported starting character and name=." diff --git a/common/utils/src/main/scala/org/apache/spark/internal/LogKey.scala b/common/utils/src/main/scala/org/apache/spark/internal/LogKey.scala index b84f7d839c2aa..a7e4f186000b5 100644 --- a/common/utils/src/main/scala/org/apache/spark/internal/LogKey.scala +++ b/common/utils/src/main/scala/org/apache/spark/internal/LogKey.scala @@ -749,6 +749,7 @@ private[spark] object LogKeys { case object START_INDEX extends LogKey case object START_TIME extends LogKey case object STATEMENT_ID extends LogKey + case object STATE_NAME extends LogKey case object STATE_STORE_ID extends LogKey case object STATE_STORE_PROVIDER extends LogKey case object STATE_STORE_VERSION extends LogKey diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala index 6a67587fbd80c..fbddd7a50dc67 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala @@ -61,6 +61,7 @@ private[spark] object PythonEvalType { val SQL_GROUPED_MAP_PANDAS_UDF_WITH_STATE = 208 val SQL_GROUPED_MAP_ARROW_UDF = 209 val SQL_COGROUPED_MAP_ARROW_UDF = 210 + val SQL_TRANSFORM_WITH_STATE_PANDAS_UDF = 211 val SQL_TABLE_UDF = 300 val SQL_ARROW_TABLE_UDF = 301 @@ -82,6 +83,7 @@ private[spark] object PythonEvalType { case SQL_COGROUPED_MAP_ARROW_UDF => "SQL_COGROUPED_MAP_ARROW_UDF" case SQL_TABLE_UDF => "SQL_TABLE_UDF" case SQL_ARROW_TABLE_UDF => "SQL_ARROW_TABLE_UDF" + case SQL_TRANSFORM_WITH_STATE_PANDAS_UDF => "SQL_TRANSFORM_WITH_STATE_PANDAS_UDF" } } diff --git a/dev/checkstyle-suppressions.xml b/dev/checkstyle-suppressions.xml index 677381704427c..9925ae406dbd9 100644 --- a/dev/checkstyle-suppressions.xml +++ b/dev/checkstyle-suppressions.xml @@ -68,4 +68,6 @@ files="src/main/java/org/apache/spark/network/util/LimitedInputStream.java" /> + diff --git a/python/docs/source/reference/pyspark.sql/index.rst b/python/docs/source/reference/pyspark.sql/index.rst index 93901ab7ce12e..36618af2de2c2 100644 --- a/python/docs/source/reference/pyspark.sql/index.rst +++ b/python/docs/source/reference/pyspark.sql/index.rst @@ -44,3 +44,4 @@ This page gives an overview of all public Spark SQL API. variant_val protobuf datasource + stateful_processor diff --git a/python/docs/source/reference/pyspark.sql/stateful_processor.rst b/python/docs/source/reference/pyspark.sql/stateful_processor.rst new file mode 100644 index 0000000000000..f97754b35d7c4 --- /dev/null +++ b/python/docs/source/reference/pyspark.sql/stateful_processor.rst @@ -0,0 +1,29 @@ +.. Licensed to the Apache Software Foundation (ASF) under one + or more contributor license agreements. See the NOTICE file + distributed with this work for additional information + regarding copyright ownership. 
The ASF licenses this file + to you under the Apache License, Version 2.0 (the + "License"); you may not use this file except in compliance + with the License. You may obtain a copy of the License at + +.. http://www.apache.org/licenses/LICENSE-2.0 + +.. Unless required by applicable law or agreed to in writing, + software distributed under the License is distributed on an + "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + KIND, either express or implied. See the License for the + specific language governing permissions and limitations + under the License. + + +================== +Stateful Processor +================== +.. currentmodule:: pyspark.sql.streaming + +.. autosummary:: + :toctree: api/ + + StatefulProcessor.init + StatefulProcessor.handleInputRows + StatefulProcessor.close \ No newline at end of file diff --git a/python/pyspark/sql/pandas/_typing/__init__.pyi b/python/pyspark/sql/pandas/_typing/__init__.pyi index 0838f446279b9..f2e1af2c2ae43 100644 --- a/python/pyspark/sql/pandas/_typing/__init__.pyi +++ b/python/pyspark/sql/pandas/_typing/__init__.pyi @@ -55,6 +55,7 @@ ArrowMapIterUDFType = Literal[207] PandasGroupedMapUDFWithStateType = Literal[208] ArrowGroupedMapUDFType = Literal[209] ArrowCogroupedMapUDFType = Literal[210] +PandasGroupedMapUDFTransformWithStateType = Literal[211] class PandasVariadicScalarToScalarFunction(Protocol): def __call__(self, *_: DataFrameOrSeriesLike_) -> DataFrameOrSeriesLike_: ... diff --git a/python/pyspark/sql/pandas/functions.py b/python/pyspark/sql/pandas/functions.py index fdb81f571a85c..cdb5663182748 100644 --- a/python/pyspark/sql/pandas/functions.py +++ b/python/pyspark/sql/pandas/functions.py @@ -413,6 +413,7 @@ def calculate(iterator: Iterator[pd.Series]) -> Iterator[pd.Series]: PythonEvalType.SQL_MAP_ARROW_ITER_UDF, PythonEvalType.SQL_COGROUPED_MAP_PANDAS_UDF, PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF_WITH_STATE, + PythonEvalType.SQL_TRANSFORM_WITH_STATE_PANDAS_UDF, PythonEvalType.SQL_GROUPED_MAP_ARROW_UDF, PythonEvalType.SQL_COGROUPED_MAP_ARROW_UDF, None, @@ -453,6 +454,7 @@ def _validate_pandas_udf(f, evalType) -> int: PythonEvalType.SQL_MAP_ARROW_ITER_UDF, PythonEvalType.SQL_COGROUPED_MAP_PANDAS_UDF, PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF_WITH_STATE, + PythonEvalType.SQL_TRANSFORM_WITH_STATE_PANDAS_UDF, PythonEvalType.SQL_GROUPED_MAP_ARROW_UDF, PythonEvalType.SQL_COGROUPED_MAP_ARROW_UDF, PythonEvalType.SQL_ARROW_BATCHED_UDF, diff --git a/python/pyspark/sql/pandas/group_ops.py b/python/pyspark/sql/pandas/group_ops.py index 26b50e4c6c186..b93723fbc6254 100644 --- a/python/pyspark/sql/pandas/group_ops.py +++ b/python/pyspark/sql/pandas/group_ops.py @@ -15,7 +15,7 @@ # limitations under the License. 
# import sys -from typing import List, Union, TYPE_CHECKING, cast +from typing import Any, Iterator, List, Union, TYPE_CHECKING, cast import warnings from pyspark.errors import PySparkTypeError @@ -23,6 +23,11 @@ from pyspark.sql.column import Column from pyspark.sql.dataframe import DataFrame from pyspark.sql.streaming.state import GroupStateTimeout +from pyspark.sql.streaming.stateful_processor_api_client import ( + StatefulProcessorApiClient, + StatefulProcessorHandleState, +) +from pyspark.sql.streaming.stateful_processor import StatefulProcessor, StatefulProcessorHandle from pyspark.sql.types import StructType, _parse_datatype_string if TYPE_CHECKING: @@ -33,6 +38,7 @@ PandasCogroupedMapFunction, ArrowGroupedMapFunction, ArrowCogroupedMapFunction, + DataFrameLike as PandasDataFrameLike, ) from pyspark.sql.group import GroupedData @@ -358,6 +364,172 @@ def applyInPandasWithState( ) return DataFrame(jdf, self.session) + def transformWithStateInPandas( + self, + statefulProcessor: StatefulProcessor, + outputStructType: Union[StructType, str], + outputMode: str, + timeMode: str, + ) -> DataFrame: + """ + Invokes methods defined in the stateful processor used in arbitrary state API v2. It + requires protobuf, pandas and pyarrow as dependencies to process input/state data. We + allow the user to act on per-group set of input rows along with keyed state and the user + can choose to output/return 0 or more rows. + + For a streaming dataframe, we will repeatedly invoke the interface methods for new rows + in each trigger and the user's state/state variables will be stored persistently across + invocations. + + The `statefulProcessor` should be a Python class that implements the interface defined in + :class:`StatefulProcessor`. + + The `outputStructType` should be a :class:`StructType` describing the schema of all + elements in the returned value, `pandas.DataFrame`. The column labels of all elements in + returned `pandas.DataFrame` must either match the field names in the defined schema if + specified as strings, or match the field data types by position if not strings, + e.g. integer indices. + + The size of each `pandas.DataFrame` in both the input and output can be arbitrary. The + number of `pandas.DataFrame` in both the input and output can also be arbitrary. + + .. versionadded:: 4.0.0 + + Parameters + ---------- + statefulProcessor : :class:`pyspark.sql.streaming.stateful_processor.StatefulProcessor` + Instance of StatefulProcessor whose functions will be invoked by the operator. + outputStructType : :class:`pyspark.sql.types.DataType` or str + The type of the output records. The value can be either a + :class:`pyspark.sql.types.DataType` object or a DDL-formatted type string. + outputMode : str + The output mode of the stateful processor. + timeMode : str + The time mode semantics of the stateful processor for timers and TTL. + + Examples + -------- + >>> from typing import Iterator + ... + >>> import pandas as pd # doctest: +SKIP + ... + >>> from pyspark.sql import Row + >>> from pyspark.sql.functions import col, split + >>> from pyspark.sql.streaming import StatefulProcessor, StatefulProcessorHandle + >>> from pyspark.sql.types import IntegerType, LongType, StringType, StructField, StructType + ... + >>> spark.conf.set("spark.sql.streaming.stateStore.providerClass", + ... "org.apache.spark.sql.execution.streaming.state.RocksDBStateStoreProvider") + ... # Below is a simple example to find erroneous sensors from temperature sensor data. The + ... 
# processor returns a count of total readings, while keeping erroneous reading counts + ... # in streaming state. A violation is defined when the temperature is above 100. + ... # The input data is a DataFrame with the following schema: + ... # `id: string, temperature: long`. + ... # The output schema and state schema are defined as below. + >>> output_schema = StructType([ + ... StructField("id", StringType(), True), + ... StructField("count", IntegerType(), True) + ... ]) + >>> state_schema = StructType([ + ... StructField("value", IntegerType(), True) + ... ]) + >>> class SimpleStatefulProcessor(StatefulProcessor): + ... def init(self, handle: StatefulProcessorHandle): + ... self.num_violations_state = handle.getValueState("numViolations", state_schema) + ... + ... def handleInputRows(self, key, rows): + ... new_violations = 0 + ... count = 0 + ... exists = self.num_violations_state.exists() + ... if exists: + ... existing_violations_row = self.num_violations_state.get() + ... existing_violations = existing_violations_row[0] + ... else: + ... existing_violations = 0 + ... for pdf in rows: + ... pdf_count = pdf.count() + ... count += pdf_count.get('temperature') + ... violations_pdf = pdf.loc[pdf['temperature'] > 100] + ... new_violations += violations_pdf.count().get('temperature') + ... updated_violations = new_violations + existing_violations + ... self.num_violations_state.update((updated_violations,)) + ... yield pd.DataFrame({'id': key, 'count': count}) + ... + ... def close(self) -> None: + ... pass + + Input DataFrame: + +---+-----------+ + | id|temperature| + +---+-----------+ + | 0| 123| + | 0| 23| + | 1| 33| + | 1| 188| + | 1| 88| + +---+-----------+ + + >>> df.groupBy("value").transformWithStateInPandas(statefulProcessor = + ... SimpleStatefulProcessor(), outputStructType=output_schema, outputMode="Update", + ... timeMode="None") # doctest: +SKIP + + Output DataFrame: + +---+-----+ + | id|count| + +---+-----+ + | 0| 2| + | 1| 3| + +---+-----+ + + Notes + ----- + This function requires a full shuffle. + + This API is experimental. 
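A complete streaming query built around the example above follows the usual readStream/writeStream pattern. The sketch below reuses `output_schema` and `SimpleStatefulProcessor` from the example and assumes the RocksDB state store provider configured above; the rate source, the modulo-derived columns, and the console sink are illustrative choices rather than part of this API.

    # Reuses output_schema and SimpleStatefulProcessor from the example above.
    # The rate source and derived columns below are illustrative only.
    sensor_readings = (
        spark.readStream.format("rate").load()
        .selectExpr(
            "CAST(value % 2 AS STRING) AS id",
            "CAST(value % 200 AS INT) AS temperature",
        )
    )

    query = (
        sensor_readings.groupBy("id")
        .transformWithStateInPandas(
            statefulProcessor=SimpleStatefulProcessor(),
            outputStructType=output_schema,
            outputMode="Update",
            timeMode="None",
        )
        .writeStream.outputMode("update")
        .format("console")
        .start()
    )

Adding `.option("checkpointLocation", ...)` to the writer is what lets the keyed state survive query restarts, as exercised by the tests added later in this patch.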
+ """ + + from pyspark.sql import GroupedData + from pyspark.sql.functions import pandas_udf + + assert isinstance(self, GroupedData) + + def transformWithStateUDF( + statefulProcessorApiClient: StatefulProcessorApiClient, + key: Any, + inputRows: Iterator["PandasDataFrameLike"], + ) -> Iterator["PandasDataFrameLike"]: + handle = StatefulProcessorHandle(statefulProcessorApiClient) + + if statefulProcessorApiClient.handle_state == StatefulProcessorHandleState.CREATED: + statefulProcessor.init(handle) + statefulProcessorApiClient.set_handle_state( + StatefulProcessorHandleState.INITIALIZED + ) + + statefulProcessorApiClient.set_implicit_key(key) + result = statefulProcessor.handleInputRows(key, inputRows) + + return result + + if isinstance(outputStructType, str): + outputStructType = cast(StructType, _parse_datatype_string(outputStructType)) + + udf = pandas_udf( + transformWithStateUDF, # type: ignore + returnType=outputStructType, + functionType=PythonEvalType.SQL_TRANSFORM_WITH_STATE_PANDAS_UDF, + ) + df = self._df + udf_column = udf(*[df[col] for col in df.columns]) + + jdf = self._jgd.transformWithStateInPandas( + udf_column._jc.expr(), + self.session._jsparkSession.parseDataType(outputStructType.json()), + outputMode, + timeMode, + ) + return DataFrame(jdf, self.session) + def applyInArrow( self, func: "ArrowGroupedMapFunction", schema: Union[StructType, str] ) -> "DataFrame": diff --git a/python/pyspark/sql/pandas/serializers.py b/python/pyspark/sql/pandas/serializers.py index 36a1d1f3543d4..6203d4d19d866 100644 --- a/python/pyspark/sql/pandas/serializers.py +++ b/python/pyspark/sql/pandas/serializers.py @@ -19,9 +19,16 @@ Serializers for PyArrow and pandas conversions. See `pyspark.serializers` for more details. """ +from itertools import groupby from pyspark.errors import PySparkRuntimeError, PySparkTypeError, PySparkValueError from pyspark.loose_version import LooseVersion -from pyspark.serializers import Serializer, read_int, write_int, UTF8Deserializer, CPickleSerializer +from pyspark.serializers import ( + Serializer, + read_int, + write_int, + UTF8Deserializer, + CPickleSerializer, +) from pyspark.sql.pandas.types import ( from_arrow_type, to_arrow_type, @@ -1116,3 +1123,70 @@ def init_stream_yield_batches(batches): batches_to_write = init_stream_yield_batches(serialize_batches()) return ArrowStreamSerializer.dump_stream(self, batches_to_write, stream) + + +class TransformWithStateInPandasSerializer(ArrowStreamPandasUDFSerializer): + """ + Serializer used by Python worker to evaluate UDF for + :meth:`pyspark.sql.GroupedData.transformWithStateInPandasSerializer`. + + Parameters + ---------- + timezone : str + A timezone to respect when handling timestamp values + safecheck : bool + If True, conversion from Arrow to Pandas checks for overflow/truncation + assign_cols_by_name : bool + If True, then Pandas DataFrames will get columns by name + arrow_max_records_per_batch : int + Limit of the number of records that can be written to a single ArrowRecordBatch in memory. + """ + + def __init__(self, timezone, safecheck, assign_cols_by_name, arrow_max_records_per_batch): + super(TransformWithStateInPandasSerializer, self).__init__( + timezone, safecheck, assign_cols_by_name + ) + self.arrow_max_records_per_batch = arrow_max_records_per_batch + self.key_offsets = None + + def load_stream(self, stream): + """ + Read ArrowRecordBatches from stream, deserialize them to populate a list of data chunk, and + convert the data into a list of pandas.Series. 
+ + Please refer the doc of inner function `generate_data_batches` for more details how + this function works in overall. + """ + import pyarrow as pa + + def generate_data_batches(batches): + """ + Deserialize ArrowRecordBatches and return a generator of pandas.Series list. + + The deserialization logic assumes that Arrow RecordBatches contain the data with the + ordering that data chunks for same grouping key will appear sequentially. + + This function must avoid materializing multiple Arrow RecordBatches into memory at the + same time. And data chunks from the same grouping key should appear sequentially. + """ + for batch in batches: + data_pandas = [ + self.arrow_to_pandas(c) for c in pa.Table.from_batches([batch]).itercolumns() + ] + key_series = [data_pandas[o] for o in self.key_offsets] + batch_key = tuple(s[0] for s in key_series) + yield (batch_key, data_pandas) + + _batches = super(ArrowStreamPandasSerializer, self).load_stream(stream) + data_batches = generate_data_batches(_batches) + + for k, g in groupby(data_batches, key=lambda x: x[0]): + yield (k, g) + + def dump_stream(self, iterator, stream): + """ + Read through an iterator of (iterator of pandas DataFrame), serialize them to Arrow + RecordBatches, and write batches to stream. + """ + result = [(b, t) for x in iterator for y, t in x for b in y] + super().dump_stream(result, stream) diff --git a/python/pyspark/sql/streaming/StateMessage_pb2.py b/python/pyspark/sql/streaming/StateMessage_pb2.py new file mode 100644 index 0000000000000..0f096e16d47ad --- /dev/null +++ b/python/pyspark/sql/streaming/StateMessage_pb2.py @@ -0,0 +1,72 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# -*- coding: utf-8 -*- +# Generated by the protocol buffer compiler. DO NOT EDIT! 
+# NO CHECKED-IN PROTOBUF GENCODE +# source: StateMessage.proto +# Protobuf Python Version: 5.27.1 +"""Generated protocol buffer code.""" +from google.protobuf import descriptor as _descriptor +from google.protobuf import descriptor_pool as _descriptor_pool +from google.protobuf import symbol_database as _symbol_database +from google.protobuf.internal import builder as _builder + +# @@protoc_insertion_point(imports) + +_sym_db = _symbol_database.Default() + + +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile( + b'\n\x12StateMessage.proto\x12.org.apache.spark.sql.execution.streaming.state"\xe9\x02\n\x0cStateRequest\x12\x0f\n\x07version\x18\x01 \x01(\x05\x12\x66\n\x15statefulProcessorCall\x18\x02 \x01(\x0b\x32\x45.org.apache.spark.sql.execution.streaming.state.StatefulProcessorCallH\x00\x12\x64\n\x14stateVariableRequest\x18\x03 \x01(\x0b\x32\x44.org.apache.spark.sql.execution.streaming.state.StateVariableRequestH\x00\x12p\n\x1aimplicitGroupingKeyRequest\x18\x04 \x01(\x0b\x32J.org.apache.spark.sql.execution.streaming.state.ImplicitGroupingKeyRequestH\x00\x42\x08\n\x06method"H\n\rStateResponse\x12\x12\n\nstatusCode\x18\x01 \x01(\x05\x12\x14\n\x0c\x65rrorMessage\x18\x02 \x01(\t\x12\r\n\x05value\x18\x03 \x01(\x0c"\x89\x03\n\x15StatefulProcessorCall\x12X\n\x0esetHandleState\x18\x01 \x01(\x0b\x32>.org.apache.spark.sql.execution.streaming.state.SetHandleStateH\x00\x12Y\n\rgetValueState\x18\x02 \x01(\x0b\x32@.org.apache.spark.sql.execution.streaming.state.StateCallCommandH\x00\x12X\n\x0cgetListState\x18\x03 \x01(\x0b\x32@.org.apache.spark.sql.execution.streaming.state.StateCallCommandH\x00\x12W\n\x0bgetMapState\x18\x04 \x01(\x0b\x32@.org.apache.spark.sql.execution.streaming.state.StateCallCommandH\x00\x42\x08\n\x06method"z\n\x14StateVariableRequest\x12X\n\x0evalueStateCall\x18\x01 \x01(\x0b\x32>.org.apache.spark.sql.execution.streaming.state.ValueStateCallH\x00\x42\x08\n\x06method"\xe0\x01\n\x1aImplicitGroupingKeyRequest\x12X\n\x0esetImplicitKey\x18\x01 \x01(\x0b\x32>.org.apache.spark.sql.execution.streaming.state.SetImplicitKeyH\x00\x12^\n\x11removeImplicitKey\x18\x02 \x01(\x0b\x32\x41.org.apache.spark.sql.execution.streaming.state.RemoveImplicitKeyH\x00\x42\x08\n\x06method"5\n\x10StateCallCommand\x12\x11\n\tstateName\x18\x01 \x01(\t\x12\x0e\n\x06schema\x18\x02 \x01(\t"\xe1\x02\n\x0eValueStateCall\x12\x11\n\tstateName\x18\x01 \x01(\t\x12H\n\x06\x65xists\x18\x02 \x01(\x0b\x32\x36.org.apache.spark.sql.execution.streaming.state.ExistsH\x00\x12\x42\n\x03get\x18\x03 \x01(\x0b\x32\x33.org.apache.spark.sql.execution.streaming.state.GetH\x00\x12\\\n\x10valueStateUpdate\x18\x04 \x01(\x0b\x32@.org.apache.spark.sql.execution.streaming.state.ValueStateUpdateH\x00\x12\x46\n\x05\x63lear\x18\x05 \x01(\x0b\x32\x35.org.apache.spark.sql.execution.streaming.state.ClearH\x00\x42\x08\n\x06method"\x1d\n\x0eSetImplicitKey\x12\x0b\n\x03key\x18\x01 \x01(\x0c"\x13\n\x11RemoveImplicitKey"\x08\n\x06\x45xists"\x05\n\x03Get"!\n\x10ValueStateUpdate\x12\r\n\x05value\x18\x01 \x01(\x0c"\x07\n\x05\x43lear"\\\n\x0eSetHandleState\x12J\n\x05state\x18\x01 \x01(\x0e\x32;.org.apache.spark.sql.execution.streaming.state.HandleState*K\n\x0bHandleState\x12\x0b\n\x07\x43REATED\x10\x00\x12\x0f\n\x0bINITIALIZED\x10\x01\x12\x12\n\x0e\x44\x41TA_PROCESSED\x10\x02\x12\n\n\x06\x43LOSED\x10\x03\x62\x06proto3' # noqa: E501 +) + +_globals = globals() +_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) +_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, "StateMessage_pb2", _globals) +if not _descriptor._USE_C_DESCRIPTORS: + 
DESCRIPTOR._loaded_options = None + _globals["_HANDLESTATE"]._serialized_start = 1873 + _globals["_HANDLESTATE"]._serialized_end = 1948 + _globals["_STATEREQUEST"]._serialized_start = 71 + _globals["_STATEREQUEST"]._serialized_end = 432 + _globals["_STATERESPONSE"]._serialized_start = 434 + _globals["_STATERESPONSE"]._serialized_end = 506 + _globals["_STATEFULPROCESSORCALL"]._serialized_start = 509 + _globals["_STATEFULPROCESSORCALL"]._serialized_end = 902 + _globals["_STATEVARIABLEREQUEST"]._serialized_start = 904 + _globals["_STATEVARIABLEREQUEST"]._serialized_end = 1026 + _globals["_IMPLICITGROUPINGKEYREQUEST"]._serialized_start = 1029 + _globals["_IMPLICITGROUPINGKEYREQUEST"]._serialized_end = 1253 + _globals["_STATECALLCOMMAND"]._serialized_start = 1255 + _globals["_STATECALLCOMMAND"]._serialized_end = 1308 + _globals["_VALUESTATECALL"]._serialized_start = 1311 + _globals["_VALUESTATECALL"]._serialized_end = 1664 + _globals["_SETIMPLICITKEY"]._serialized_start = 1666 + _globals["_SETIMPLICITKEY"]._serialized_end = 1695 + _globals["_REMOVEIMPLICITKEY"]._serialized_start = 1697 + _globals["_REMOVEIMPLICITKEY"]._serialized_end = 1716 + _globals["_EXISTS"]._serialized_start = 1718 + _globals["_EXISTS"]._serialized_end = 1726 + _globals["_GET"]._serialized_start = 1728 + _globals["_GET"]._serialized_end = 1733 + _globals["_VALUESTATEUPDATE"]._serialized_start = 1735 + _globals["_VALUESTATEUPDATE"]._serialized_end = 1768 + _globals["_CLEAR"]._serialized_start = 1770 + _globals["_CLEAR"]._serialized_end = 1777 + _globals["_SETHANDLESTATE"]._serialized_start = 1779 + _globals["_SETHANDLESTATE"]._serialized_end = 1871 +# @@protoc_insertion_point(module_scope) diff --git a/python/pyspark/sql/streaming/StateMessage_pb2.pyi b/python/pyspark/sql/streaming/StateMessage_pb2.pyi new file mode 100644 index 0000000000000..0e6f1fb065881 --- /dev/null +++ b/python/pyspark/sql/streaming/StateMessage_pb2.pyi @@ -0,0 +1,175 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from google.protobuf.internal import enum_type_wrapper as _enum_type_wrapper +from google.protobuf import descriptor as _descriptor +from google.protobuf import message as _message +from typing import ( + ClassVar as _ClassVar, + Mapping as _Mapping, + Optional as _Optional, + Union as _Union, +) + +DESCRIPTOR: _descriptor.FileDescriptor + +class HandleState(int, metaclass=_enum_type_wrapper.EnumTypeWrapper): + __slots__ = () + CREATED: _ClassVar[HandleState] + INITIALIZED: _ClassVar[HandleState] + DATA_PROCESSED: _ClassVar[HandleState] + CLOSED: _ClassVar[HandleState] + +CREATED: HandleState +INITIALIZED: HandleState +DATA_PROCESSED: HandleState +CLOSED: HandleState + +class StateRequest(_message.Message): + __slots__ = ( + "version", + "statefulProcessorCall", + "stateVariableRequest", + "implicitGroupingKeyRequest", + ) + VERSION_FIELD_NUMBER: _ClassVar[int] + STATEFULPROCESSORCALL_FIELD_NUMBER: _ClassVar[int] + STATEVARIABLEREQUEST_FIELD_NUMBER: _ClassVar[int] + IMPLICITGROUPINGKEYREQUEST_FIELD_NUMBER: _ClassVar[int] + version: int + statefulProcessorCall: StatefulProcessorCall + stateVariableRequest: StateVariableRequest + implicitGroupingKeyRequest: ImplicitGroupingKeyRequest + def __init__( + self, + version: _Optional[int] = ..., + statefulProcessorCall: _Optional[_Union[StatefulProcessorCall, _Mapping]] = ..., + stateVariableRequest: _Optional[_Union[StateVariableRequest, _Mapping]] = ..., + implicitGroupingKeyRequest: _Optional[_Union[ImplicitGroupingKeyRequest, _Mapping]] = ..., + ) -> None: ... + +class StateResponse(_message.Message): + __slots__ = ("statusCode", "errorMessage", "value") + STATUSCODE_FIELD_NUMBER: _ClassVar[int] + ERRORMESSAGE_FIELD_NUMBER: _ClassVar[int] + VALUE_FIELD_NUMBER: _ClassVar[int] + statusCode: int + errorMessage: str + value: bytes + def __init__( + self, statusCode: _Optional[int] = ..., errorMessage: _Optional[str] = ... + ) -> None: ... + +class StatefulProcessorCall(_message.Message): + __slots__ = ("setHandleState", "getValueState", "getListState", "getMapState") + SETHANDLESTATE_FIELD_NUMBER: _ClassVar[int] + GETVALUESTATE_FIELD_NUMBER: _ClassVar[int] + GETLISTSTATE_FIELD_NUMBER: _ClassVar[int] + GETMAPSTATE_FIELD_NUMBER: _ClassVar[int] + setHandleState: SetHandleState + getValueState: StateCallCommand + getListState: StateCallCommand + getMapState: StateCallCommand + def __init__( + self, + setHandleState: _Optional[_Union[SetHandleState, _Mapping]] = ..., + getValueState: _Optional[_Union[StateCallCommand, _Mapping]] = ..., + getListState: _Optional[_Union[StateCallCommand, _Mapping]] = ..., + getMapState: _Optional[_Union[StateCallCommand, _Mapping]] = ..., + ) -> None: ... + +class StateVariableRequest(_message.Message): + __slots__ = ("valueStateCall",) + VALUESTATECALL_FIELD_NUMBER: _ClassVar[int] + valueStateCall: ValueStateCall + def __init__( + self, valueStateCall: _Optional[_Union[ValueStateCall, _Mapping]] = ... + ) -> None: ... + +class ImplicitGroupingKeyRequest(_message.Message): + __slots__ = ("setImplicitKey", "removeImplicitKey") + SETIMPLICITKEY_FIELD_NUMBER: _ClassVar[int] + REMOVEIMPLICITKEY_FIELD_NUMBER: _ClassVar[int] + setImplicitKey: SetImplicitKey + removeImplicitKey: RemoveImplicitKey + def __init__( + self, + setImplicitKey: _Optional[_Union[SetImplicitKey, _Mapping]] = ..., + removeImplicitKey: _Optional[_Union[RemoveImplicitKey, _Mapping]] = ..., + ) -> None: ... 
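The stubs above describe the request/response envelope that the Python worker exchanges with the JVM state server. A minimal, self-contained round-trip over the generated messages looks like the following; the state name and schema are illustrative, and in the real client the serialized request is additionally length-framed before being written to the socket.

    import pyspark.sql.streaming.StateMessage_pb2 as stateMessage
    from pyspark.sql.types import IntegerType, StructField, StructType

    # Illustrative state variable; the schema travels as its JSON representation.
    state_schema = StructType([StructField("value", IntegerType(), True)])
    command = stateMessage.StateCallCommand(stateName="numViolations", schema=state_schema.json())
    call = stateMessage.StatefulProcessorCall(getValueState=command)
    request = stateMessage.StateRequest(statefulProcessorCall=call)
    assert request.WhichOneof("method") == "statefulProcessorCall"
    wire_bytes = request.SerializeToString()

    # Responses carry a status code (0 on success) plus optional error message and payload.
    response = stateMessage.StateResponse()
    response.ParseFromString(stateMessage.StateResponse(statusCode=0).SerializeToString())
    assert response.statusCode == 0 and response.errorMessage == ""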
+ +class StateCallCommand(_message.Message): + __slots__ = ("stateName", "schema") + STATENAME_FIELD_NUMBER: _ClassVar[int] + SCHEMA_FIELD_NUMBER: _ClassVar[int] + stateName: str + schema: str + def __init__(self, stateName: _Optional[str] = ..., schema: _Optional[str] = ...) -> None: ... + +class ValueStateCall(_message.Message): + __slots__ = ("stateName", "exists", "get", "valueStateUpdate", "clear") + STATENAME_FIELD_NUMBER: _ClassVar[int] + EXISTS_FIELD_NUMBER: _ClassVar[int] + GET_FIELD_NUMBER: _ClassVar[int] + VALUESTATEUPDATE_FIELD_NUMBER: _ClassVar[int] + CLEAR_FIELD_NUMBER: _ClassVar[int] + stateName: str + exists: Exists + get: Get + valueStateUpdate: ValueStateUpdate + clear: Clear + def __init__( + self, + stateName: _Optional[str] = ..., + exists: _Optional[_Union[Exists, _Mapping]] = ..., + get: _Optional[_Union[Get, _Mapping]] = ..., + valueStateUpdate: _Optional[_Union[ValueStateUpdate, _Mapping]] = ..., + clear: _Optional[_Union[Clear, _Mapping]] = ..., + ) -> None: ... + +class SetImplicitKey(_message.Message): + __slots__ = ("key",) + KEY_FIELD_NUMBER: _ClassVar[int] + key: bytes + def __init__(self, key: _Optional[bytes] = ...) -> None: ... + +class RemoveImplicitKey(_message.Message): + __slots__ = () + def __init__(self) -> None: ... + +class Exists(_message.Message): + __slots__ = () + def __init__(self) -> None: ... + +class Get(_message.Message): + __slots__ = () + def __init__(self) -> None: ... + +class ValueStateUpdate(_message.Message): + __slots__ = ("value",) + VALUE_FIELD_NUMBER: _ClassVar[int] + value: bytes + def __init__(self, value: _Optional[bytes] = ...) -> None: ... + +class Clear(_message.Message): + __slots__ = () + def __init__(self) -> None: ... + +class SetHandleState(_message.Message): + __slots__ = ("state",) + STATE_FIELD_NUMBER: _ClassVar[int] + state: HandleState + def __init__(self, state: _Optional[_Union[HandleState, str]] = ...) -> None: ... diff --git a/python/pyspark/sql/streaming/__init__.py b/python/pyspark/sql/streaming/__init__.py index 6b5723d3a3d01..e3c6ca519ad02 100644 --- a/python/pyspark/sql/streaming/__init__.py +++ b/python/pyspark/sql/streaming/__init__.py @@ -18,4 +18,8 @@ from pyspark.sql.streaming.query import StreamingQuery, StreamingQueryManager # noqa: F401 from pyspark.sql.streaming.readwriter import DataStreamReader, DataStreamWriter # noqa: F401 from pyspark.sql.streaming.listener import StreamingQueryListener # noqa: F401 +from pyspark.sql.streaming.stateful_processor import ( # noqa: F401 + StatefulProcessor, # noqa: F401 + StatefulProcessorHandle, # noqa: F401 +) # noqa: F401 from pyspark.errors import StreamingQueryException # noqa: F401 diff --git a/python/pyspark/sql/streaming/stateful_processor.py b/python/pyspark/sql/streaming/stateful_processor.py new file mode 100644 index 0000000000000..a378eec2b6175 --- /dev/null +++ b/python/pyspark/sql/streaming/stateful_processor.py @@ -0,0 +1,161 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from abc import ABC, abstractmethod +from typing import Any, TYPE_CHECKING, Iterator, Optional, Union, cast + +from pyspark.sql import Row +from pyspark.sql.streaming.stateful_processor_api_client import StatefulProcessorApiClient +from pyspark.sql.streaming.value_state_client import ValueStateClient +from pyspark.sql.types import StructType, _create_row, _parse_datatype_string + +if TYPE_CHECKING: + from pyspark.sql.pandas._typing import DataFrameLike as PandasDataFrameLike + +__all__ = ["StatefulProcessor", "StatefulProcessorHandle"] + + +class ValueState: + """ + Class used for arbitrary stateful operations with transformWithState to capture single value + state. + + .. versionadded:: 4.0.0 + """ + + def __init__( + self, value_state_client: ValueStateClient, state_name: str, schema: Union[StructType, str] + ) -> None: + self._value_state_client = value_state_client + self._state_name = state_name + self.schema = schema + + def exists(self) -> bool: + """ + Whether state exists or not. + """ + return self._value_state_client.exists(self._state_name) + + def get(self) -> Optional[Row]: + """ + Get the state value if it exists. Returns None if the state variable does not have a value. + """ + value = self._value_state_client.get(self._state_name) + if value is None: + return None + schema = self.schema + if isinstance(schema, str): + schema = cast(StructType, _parse_datatype_string(schema)) + # Create the Row using the values and schema fields + row = _create_row(schema.fieldNames(), value) + return row + + def update(self, new_value: Any) -> None: + """ + Update the value of the state. + """ + self._value_state_client.update(self._state_name, self.schema, new_value) + + def clear(self) -> None: + """ + Remove this state. + """ + self._value_state_client.clear(self._state_name) + + +class StatefulProcessorHandle: + """ + Represents the operation handle provided to the stateful processor used in transformWithState + API. + + .. versionadded:: 4.0.0 + """ + + def __init__(self, stateful_processor_api_client: StatefulProcessorApiClient) -> None: + self.stateful_processor_api_client = stateful_processor_api_client + + def getValueState(self, state_name: str, schema: Union[StructType, str]) -> ValueState: + """ + Function to create new or return existing single value state variable of given type. + The user must ensure to call this function only within the `init()` method of the + :class:`StatefulProcessor`. + + Parameters + ---------- + state_name : str + name of the state variable + schema : :class:`pyspark.sql.types.DataType` or str + The schema of the state variable. The value can be either a + :class:`pyspark.sql.types.DataType` object or a DDL-formatted type string. + """ + self.stateful_processor_api_client.get_value_state(state_name, schema) + return ValueState(ValueStateClient(self.stateful_processor_api_client), state_name, schema) + + +class StatefulProcessor(ABC): + """ + Class that represents the arbitrary stateful logic that needs to be provided by the user to + perform stateful manipulations on keyed streams. + + .. 
versionadded:: 4.0.0 + """ + + @abstractmethod + def init(self, handle: StatefulProcessorHandle) -> None: + """ + Function that will be invoked as the first method that allows for users to initialize all + their state variables and perform other init actions before handling data. + + Parameters + ---------- + handle : :class:`pyspark.sql.streaming.stateful_processor.StatefulProcessorHandle` + Handle to the stateful processor that provides access to the state store and other + stateful processing related APIs. + """ + ... + + @abstractmethod + def handleInputRows( + self, key: Any, rows: Iterator["PandasDataFrameLike"] + ) -> Iterator["PandasDataFrameLike"]: + """ + Function that will allow users to interact with input data rows along with the grouping key. + It should take parameters (key, Iterator[`pandas.DataFrame`]) and return another + Iterator[`pandas.DataFrame`]. For each group, all columns are passed together as + `pandas.DataFrame` to the function, and the returned `pandas.DataFrame` across all + invocations are combined as a :class:`DataFrame`. Note that the function should not make a + guess of the number of elements in the iterator. To process all data, the `handleInputRows` + function needs to iterate all elements and process them. On the other hand, the + `handleInputRows` function is not strictly required to iterate through all elements in the + iterator if it intends to read a part of data. + + Parameters + ---------- + key : Any + grouping key. + rows : iterable of :class:`pandas.DataFrame` + iterator of input rows associated with grouping key + """ + ... + + @abstractmethod + def close(self) -> None: + """ + Function called as the last method that allows for users to perform any cleanup or teardown + operations. + """ + ... diff --git a/python/pyspark/sql/streaming/stateful_processor_api_client.py b/python/pyspark/sql/streaming/stateful_processor_api_client.py new file mode 100644 index 0000000000000..080d7739992ec --- /dev/null +++ b/python/pyspark/sql/streaming/stateful_processor_api_client.py @@ -0,0 +1,166 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
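A minimal concrete implementation of the `StatefulProcessor` interface above might look like the following running-count processor; the state variable name, the DDL-formatted state schema string, and the assumed output schema of `id STRING, count INT` are illustrative.

    import pandas as pd
    from pyspark.sql.streaming import StatefulProcessor, StatefulProcessorHandle

    class RunningCountProcessor(StatefulProcessor):
        def init(self, handle: StatefulProcessorHandle) -> None:
            # The state schema may be a DDL string as well as a StructType (illustrative).
            self.count_state = handle.getValueState("count", "value INT")

        def handleInputRows(self, key, rows):
            count = self.count_state.get()[0] if self.count_state.exists() else 0
            for pdf in rows:
                count += len(pdf)
            self.count_state.update((count,))
            yield pd.DataFrame({"id": key, "count": count})

        def close(self) -> None:
            pass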
+# +from enum import Enum +import os +import socket +from typing import Any, Union, cast, Tuple + +from pyspark.serializers import write_int, read_int, UTF8Deserializer +from pyspark.sql.types import StructType, _parse_datatype_string, Row +from pyspark.sql.utils import has_numpy +from pyspark.serializers import CPickleSerializer +from pyspark.errors import PySparkRuntimeError + +__all__ = ["StatefulProcessorApiClient", "StatefulProcessorHandleState"] + + +class StatefulProcessorHandleState(Enum): + CREATED = 1 + INITIALIZED = 2 + DATA_PROCESSED = 3 + CLOSED = 4 + + +class StatefulProcessorApiClient: + def __init__(self, state_server_port: int, key_schema: StructType) -> None: + self.key_schema = key_schema + self._client_socket = socket.socket() + self._client_socket.connect(("localhost", state_server_port)) + self.sockfile = self._client_socket.makefile( + "rwb", int(os.environ.get("SPARK_BUFFER_SIZE", 65536)) + ) + self.handle_state = StatefulProcessorHandleState.CREATED + self.utf8_deserializer = UTF8Deserializer() + self.pickleSer = CPickleSerializer() + + def set_handle_state(self, state: StatefulProcessorHandleState) -> None: + import pyspark.sql.streaming.StateMessage_pb2 as stateMessage + + if state == StatefulProcessorHandleState.CREATED: + proto_state = stateMessage.CREATED + elif state == StatefulProcessorHandleState.INITIALIZED: + proto_state = stateMessage.INITIALIZED + elif state == StatefulProcessorHandleState.DATA_PROCESSED: + proto_state = stateMessage.DATA_PROCESSED + else: + proto_state = stateMessage.CLOSED + set_handle_state = stateMessage.SetHandleState(state=proto_state) + handle_call = stateMessage.StatefulProcessorCall(setHandleState=set_handle_state) + message = stateMessage.StateRequest(statefulProcessorCall=handle_call) + + self._send_proto_message(message.SerializeToString()) + + response_message = self._receive_proto_message() + status = response_message[0] + if status == 0: + self.handle_state = state + else: + # TODO(SPARK-49233): Classify errors thrown by internal methods. + raise PySparkRuntimeError(f"Error setting handle state: " f"{response_message[1]}") + + def set_implicit_key(self, key: Tuple) -> None: + import pyspark.sql.streaming.StateMessage_pb2 as stateMessage + + key_bytes = self._serialize_to_bytes(self.key_schema, key) + set_implicit_key = stateMessage.SetImplicitKey(key=key_bytes) + request = stateMessage.ImplicitGroupingKeyRequest(setImplicitKey=set_implicit_key) + message = stateMessage.StateRequest(implicitGroupingKeyRequest=request) + + self._send_proto_message(message.SerializeToString()) + response_message = self._receive_proto_message() + status = response_message[0] + if status != 0: + # TODO(SPARK-49233): Classify errors thrown by internal methods. + raise PySparkRuntimeError(f"Error setting implicit key: " f"{response_message[1]}") + + def remove_implicit_key(self) -> None: + import pyspark.sql.streaming.StateMessage_pb2 as stateMessage + + remove_implicit_key = stateMessage.RemoveImplicitKey() + request = stateMessage.ImplicitGroupingKeyRequest(removeImplicitKey=remove_implicit_key) + message = stateMessage.StateRequest(implicitGroupingKeyRequest=request) + + self._send_proto_message(message.SerializeToString()) + response_message = self._receive_proto_message() + status = response_message[0] + if status != 0: + # TODO(SPARK-49233): Classify errors thrown by internal methods. 
+ raise PySparkRuntimeError(f"Error removing implicit key: " f"{response_message[1]}") + + def get_value_state(self, state_name: str, schema: Union[StructType, str]) -> None: + import pyspark.sql.streaming.StateMessage_pb2 as stateMessage + + if isinstance(schema, str): + schema = cast(StructType, _parse_datatype_string(schema)) + + state_call_command = stateMessage.StateCallCommand() + state_call_command.stateName = state_name + state_call_command.schema = schema.json() + call = stateMessage.StatefulProcessorCall(getValueState=state_call_command) + message = stateMessage.StateRequest(statefulProcessorCall=call) + + self._send_proto_message(message.SerializeToString()) + response_message = self._receive_proto_message() + status = response_message[0] + if status != 0: + # TODO(SPARK-49233): Classify user facing errors. + raise PySparkRuntimeError(f"Error initializing value state: " f"{response_message[1]}") + + def _send_proto_message(self, message: bytes) -> None: + # Writing zero here to indicate message version. This allows us to evolve the message + # format or even changing the message protocol in the future. + write_int(0, self.sockfile) + write_int(len(message), self.sockfile) + self.sockfile.write(message) + self.sockfile.flush() + + def _receive_proto_message(self) -> Tuple[int, str, bytes]: + import pyspark.sql.streaming.StateMessage_pb2 as stateMessage + + length = read_int(self.sockfile) + bytes = self.sockfile.read(length) + message = stateMessage.StateResponse() + message.ParseFromString(bytes) + return message.statusCode, message.errorMessage, message.value + + def _receive_str(self) -> str: + return self.utf8_deserializer.loads(self.sockfile) + + def _serialize_to_bytes(self, schema: StructType, data: Tuple) -> bytes: + converted = [] + if has_numpy: + import numpy as np + + # In order to convert NumPy types to Python primitive types. + for v in data: + if isinstance(v, np.generic): + converted.append(v.tolist()) + # Address a couple of pandas dtypes too. + elif hasattr(v, "to_pytimedelta"): + converted.append(v.to_pytimedelta()) + elif hasattr(v, "to_pydatetime"): + converted.append(v.to_pydatetime()) + else: + converted.append(v) + else: + converted = list(data) + + row_value = Row(*converted) + return self.pickleSer.dumps(schema.toInternal(row_value)) + + def _deserialize_from_bytes(self, value: bytes) -> Any: + return self.pickleSer.loads(value) diff --git a/python/pyspark/sql/streaming/value_state_client.py b/python/pyspark/sql/streaming/value_state_client.py new file mode 100644 index 0000000000000..e902f70cb40a5 --- /dev/null +++ b/python/pyspark/sql/streaming/value_state_client.py @@ -0,0 +1,105 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
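Every request sent by `_send_proto_message` above is framed as a version marker of 0, a 4-byte length prefix, and then the serialized protobuf; responses are read back as a length prefix followed by a `StateResponse`. An isolated sketch of that framing against an in-memory buffer (the real client reads and writes the socket file opened in `__init__`):

    import io

    import pyspark.sql.streaming.StateMessage_pb2 as stateMessage
    from pyspark.serializers import read_int, write_int

    buf = io.BytesIO()
    payload = stateMessage.StateRequest(version=1).SerializeToString()
    write_int(0, buf)             # message version marker
    write_int(len(payload), buf)  # length prefix
    buf.write(payload)

    buf.seek(0)
    assert read_int(buf) == 0                 # version marker written above
    frame = buf.read(read_int(buf))           # length-prefixed body
    parsed = stateMessage.StateRequest()
    parsed.ParseFromString(frame)
    assert parsed.version == 1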
+# +from typing import Any, Union, cast, Tuple + +from pyspark.sql.streaming.stateful_processor_api_client import StatefulProcessorApiClient +from pyspark.sql.types import StructType, _parse_datatype_string +from pyspark.errors import PySparkRuntimeError + +__all__ = ["ValueStateClient"] + + +class ValueStateClient: + def __init__(self, stateful_processor_api_client: StatefulProcessorApiClient) -> None: + self._stateful_processor_api_client = stateful_processor_api_client + + def exists(self, state_name: str) -> bool: + import pyspark.sql.streaming.StateMessage_pb2 as stateMessage + + exists_call = stateMessage.Exists() + value_state_call = stateMessage.ValueStateCall(stateName=state_name, exists=exists_call) + state_variable_request = stateMessage.StateVariableRequest(valueStateCall=value_state_call) + message = stateMessage.StateRequest(stateVariableRequest=state_variable_request) + + self._stateful_processor_api_client._send_proto_message(message.SerializeToString()) + response_message = self._stateful_processor_api_client._receive_proto_message() + status = response_message[0] + if status == 0: + return True + elif status == 2: + # Expect status code is 2 when state variable doesn't have a value. + return False + else: + # TODO(SPARK-49233): Classify user facing errors. + raise PySparkRuntimeError( + f"Error checking value state exists: " f"{response_message[1]}" + ) + + def get(self, state_name: str) -> Any: + import pyspark.sql.streaming.StateMessage_pb2 as stateMessage + + get_call = stateMessage.Get() + value_state_call = stateMessage.ValueStateCall(stateName=state_name, get=get_call) + state_variable_request = stateMessage.StateVariableRequest(valueStateCall=value_state_call) + message = stateMessage.StateRequest(stateVariableRequest=state_variable_request) + + self._stateful_processor_api_client._send_proto_message(message.SerializeToString()) + response_message = self._stateful_processor_api_client._receive_proto_message() + status = response_message[0] + if status == 0: + if len(response_message[2]) == 0: + return None + row = self._stateful_processor_api_client._deserialize_from_bytes(response_message[2]) + return row + else: + # TODO(SPARK-49233): Classify user facing errors. + raise PySparkRuntimeError(f"Error getting value state: " f"{response_message[1]}") + + def update(self, state_name: str, schema: Union[StructType, str], value: Tuple) -> None: + import pyspark.sql.streaming.StateMessage_pb2 as stateMessage + + if isinstance(schema, str): + schema = cast(StructType, _parse_datatype_string(schema)) + bytes = self._stateful_processor_api_client._serialize_to_bytes(schema, value) + update_call = stateMessage.ValueStateUpdate(value=bytes) + value_state_call = stateMessage.ValueStateCall( + stateName=state_name, valueStateUpdate=update_call + ) + state_variable_request = stateMessage.StateVariableRequest(valueStateCall=value_state_call) + message = stateMessage.StateRequest(stateVariableRequest=state_variable_request) + + self._stateful_processor_api_client._send_proto_message(message.SerializeToString()) + response_message = self._stateful_processor_api_client._receive_proto_message() + status = response_message[0] + if status != 0: + # TODO(SPARK-49233): Classify user facing errors. 
+ raise PySparkRuntimeError(f"Error updating value state: " f"{response_message[1]}") + + def clear(self, state_name: str) -> None: + import pyspark.sql.streaming.StateMessage_pb2 as stateMessage + + clear_call = stateMessage.Clear() + value_state_call = stateMessage.ValueStateCall(stateName=state_name, clear=clear_call) + state_variable_request = stateMessage.StateVariableRequest(valueStateCall=value_state_call) + message = stateMessage.StateRequest(stateVariableRequest=state_variable_request) + + self._stateful_processor_api_client._send_proto_message(message.SerializeToString()) + response_message = self._stateful_processor_api_client._receive_proto_message() + status = response_message[0] + if status != 0: + # TODO(SPARK-49233): Classify user facing errors. + raise PySparkRuntimeError(f"Error clearing value state: " f"{response_message[1]}") diff --git a/python/pyspark/sql/tests/pandas/test_pandas_transform_with_state.py b/python/pyspark/sql/tests/pandas/test_pandas_transform_with_state.py new file mode 100644 index 0000000000000..f05a601094a5d --- /dev/null +++ b/python/pyspark/sql/tests/pandas/test_pandas_transform_with_state.py @@ -0,0 +1,280 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +import os +import tempfile +from pyspark.sql.streaming import StatefulProcessor, StatefulProcessorHandle +from typing import Iterator + +import unittest +from typing import cast + +from pyspark import SparkConf +from pyspark.sql.functions import split +from pyspark.sql.types import ( + StringType, + StructType, + StructField, + Row, + IntegerType, +) +from pyspark.testing.sqlutils import ( + ReusedSQLTestCase, + have_pandas, + have_pyarrow, + pandas_requirement_message, + pyarrow_requirement_message, +) + +if have_pandas: + import pandas as pd + + +@unittest.skipIf( + not have_pandas or not have_pyarrow, + cast(str, pandas_requirement_message or pyarrow_requirement_message), +) +class TransformWithStateInPandasTestsMixin: + @classmethod + def conf(cls): + cfg = SparkConf() + cfg.set("spark.sql.shuffle.partitions", "5") + cfg.set( + "spark.sql.streaming.stateStore.providerClass", + "org.apache.spark.sql.execution.streaming.state.RocksDBStateStoreProvider", + ) + return cfg + + def _prepare_test_resource1(self, input_path): + with open(input_path + "/text-test1.txt", "w") as fw: + fw.write("0, 123\n") + fw.write("0, 46\n") + fw.write("1, 146\n") + fw.write("1, 346\n") + + def _prepare_test_resource2(self, input_path): + with open(input_path + "/text-test2.txt", "w") as fw: + fw.write("0, 123\n") + fw.write("0, 223\n") + fw.write("0, 323\n") + fw.write("1, 246\n") + fw.write("1, 6\n") + + def _build_test_df(self, input_path): + df = self.spark.readStream.format("text").option("maxFilesPerTrigger", 1).load(input_path) + df_split = df.withColumn("split_values", split(df["value"], ",")) + df_final = df_split.select( + df_split.split_values.getItem(0).alias("id").cast("string"), + df_split.split_values.getItem(1).alias("temperature").cast("int"), + ) + return df_final + + def _test_transform_with_state_in_pandas_basic( + self, stateful_processor, check_results, single_batch=False + ): + input_path = tempfile.mkdtemp() + self._prepare_test_resource1(input_path) + if not single_batch: + self._prepare_test_resource2(input_path) + + df = self._build_test_df(input_path) + + for q in self.spark.streams.active: + q.stop() + self.assertTrue(df.isStreaming) + + output_schema = StructType( + [ + StructField("id", StringType(), True), + StructField("countAsString", StringType(), True), + ] + ) + + q = ( + df.groupBy("id") + .transformWithStateInPandas( + statefulProcessor=stateful_processor, + outputStructType=output_schema, + outputMode="Update", + timeMode="None", + ) + .writeStream.queryName("this_query") + .foreachBatch(check_results) + .outputMode("update") + .start() + ) + + self.assertEqual(q.name, "this_query") + self.assertTrue(q.isActive) + q.processAllAvailable() + q.awaitTermination(10) + self.assertTrue(q.exception() is None) + + def test_transform_with_state_in_pandas_basic(self): + def check_results(batch_df, batch_id): + if batch_id == 0: + assert set(batch_df.sort("id").collect()) == { + Row(id="0", countAsString="2"), + Row(id="1", countAsString="2"), + } + else: + assert set(batch_df.sort("id").collect()) == { + Row(id="0", countAsString="3"), + Row(id="1", countAsString="2"), + } + + self._test_transform_with_state_in_pandas_basic(SimpleStatefulProcessor(), check_results) + + def test_transform_with_state_in_pandas_non_exist_value_state(self): + def check_results(batch_df, _): + assert set(batch_df.sort("id").collect()) == { + Row(id="0", countAsString="0"), + Row(id="1", countAsString="0"), + } + + self._test_transform_with_state_in_pandas_basic( + 
InvalidSimpleStatefulProcessor(), check_results, True + ) + + def test_transform_with_state_in_pandas_query_restarts(self): + root_path = tempfile.mkdtemp() + input_path = root_path + "/input" + os.makedirs(input_path, exist_ok=True) + checkpoint_path = root_path + "/checkpoint" + output_path = root_path + "/output" + + self._prepare_test_resource1(input_path) + + df = self._build_test_df(input_path) + + for q in self.spark.streams.active: + q.stop() + self.assertTrue(df.isStreaming) + + output_schema = StructType( + [ + StructField("id", StringType(), True), + StructField("countAsString", StringType(), True), + ] + ) + + base_query = ( + df.groupBy("id") + .transformWithStateInPandas( + statefulProcessor=SimpleStatefulProcessor(), + outputStructType=output_schema, + outputMode="Update", + timeMode="None", + ) + .writeStream.queryName("this_query") + .format("parquet") + .outputMode("append") + .option("checkpointLocation", checkpoint_path) + .option("path", output_path) + ) + q = base_query.start() + self.assertEqual(q.name, "this_query") + self.assertTrue(q.isActive) + q.processAllAvailable() + q.awaitTermination(10) + self.assertTrue(q.exception() is None) + + q.stop() + + self._prepare_test_resource2(input_path) + + q = base_query.start() + self.assertEqual(q.name, "this_query") + self.assertTrue(q.isActive) + q.processAllAvailable() + q.awaitTermination(10) + self.assertTrue(q.exception() is None) + result_df = self.spark.read.parquet(output_path) + assert set(result_df.sort("id").collect()) == { + Row(id="0", countAsString="2"), + Row(id="0", countAsString="3"), + Row(id="1", countAsString="2"), + Row(id="1", countAsString="2"), + } + + +class SimpleStatefulProcessor(StatefulProcessor): + dict = {0: {"0": 1, "1": 2}, 1: {"0": 4, "1": 3}} + batch_id = 0 + + def init(self, handle: StatefulProcessorHandle) -> None: + state_schema = StructType([StructField("value", IntegerType(), True)]) + self.num_violations_state = handle.getValueState("numViolations", state_schema) + + def handleInputRows(self, key, rows) -> Iterator[pd.DataFrame]: + new_violations = 0 + count = 0 + key_str = key[0] + exists = self.num_violations_state.exists() + if exists: + existing_violations_row = self.num_violations_state.get() + existing_violations = existing_violations_row[0] + assert existing_violations == self.dict[0][key_str] + self.batch_id = 1 + else: + existing_violations = 0 + for pdf in rows: + pdf_count = pdf.count() + count += pdf_count.get("temperature") + violations_pdf = pdf.loc[pdf["temperature"] > 100] + new_violations += violations_pdf.count().get("temperature") + updated_violations = new_violations + existing_violations + assert updated_violations == self.dict[self.batch_id][key_str] + self.num_violations_state.update((updated_violations,)) + yield pd.DataFrame({"id": key, "countAsString": str(count)}) + + def close(self) -> None: + pass + + +class InvalidSimpleStatefulProcessor(StatefulProcessor): + def init(self, handle: StatefulProcessorHandle) -> None: + state_schema = StructType([StructField("value", IntegerType(), True)]) + self.num_violations_state = handle.getValueState("numViolations", state_schema) + + def handleInputRows(self, key, rows) -> Iterator[pd.DataFrame]: + count = 0 + exists = self.num_violations_state.exists() + assert not exists + # try to get a state variable with no value + assert self.num_violations_state.get() is None + self.num_violations_state.clear() + yield pd.DataFrame({"id": key, "countAsString": str(count)}) + + def close(self) -> None: + pass + + +class 
TransformWithStateInPandasTests(TransformWithStateInPandasTestsMixin, ReusedSQLTestCase): + pass + + +if __name__ == "__main__": + from pyspark.sql.tests.pandas.test_pandas_transform_with_state import * # noqa: F401 + + try: + import xmlrunner + + testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2) + except ImportError: + testRunner = None + unittest.main(testRunner=testRunner, verbosity=2) diff --git a/python/pyspark/util.py b/python/pyspark/util.py index 3b2340e405869..205e3d957a415 100644 --- a/python/pyspark/util.py +++ b/python/pyspark/util.py @@ -60,6 +60,7 @@ PandasGroupedMapUDFWithStateType, ArrowGroupedMapUDFType, ArrowCogroupedMapUDFType, + PandasGroupedMapUDFTransformWithStateType, ) from pyspark.sql._typing import ( SQLArrowBatchedUDFType, @@ -585,6 +586,7 @@ class PythonEvalType: SQL_GROUPED_MAP_PANDAS_UDF_WITH_STATE: "PandasGroupedMapUDFWithStateType" = 208 SQL_GROUPED_MAP_ARROW_UDF: "ArrowGroupedMapUDFType" = 209 SQL_COGROUPED_MAP_ARROW_UDF: "ArrowCogroupedMapUDFType" = 210 + SQL_TRANSFORM_WITH_STATE_PANDAS_UDF: "PandasGroupedMapUDFTransformWithStateType" = 211 SQL_TABLE_UDF: "SQLTableUDFType" = 300 SQL_ARROW_TABLE_UDF: "SQLArrowTableUDFType" = 301 diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py index a4668ae475bd1..b8263769c28a9 100644 --- a/python/pyspark/worker.py +++ b/python/pyspark/worker.py @@ -32,6 +32,7 @@ _accumulatorRegistry, _deserialize_accumulator, ) +from pyspark.sql.streaming.stateful_processor_api_client import StatefulProcessorApiClient from pyspark.taskcontext import BarrierTaskContext, TaskContext from pyspark.resource import ResourceInformation from pyspark.util import PythonEvalType, local_connect_and_auth @@ -55,6 +56,7 @@ ArrowStreamUDFSerializer, ArrowStreamGroupUDFSerializer, ApplyInPandasWithStateSerializer, + TransformWithStateInPandasSerializer, ) from pyspark.sql.pandas.types import to_arrow_type from pyspark.sql.types import ( @@ -488,6 +490,21 @@ def wrapped(key_series, value_series): return lambda k, v: [(wrapped(k, v), to_arrow_type(return_type))] +def wrap_grouped_transform_with_state_pandas_udf(f, return_type, runner_conf): + def wrapped(stateful_processor_api_client, key, value_series_gen): + import pandas as pd + + values = (pd.concat(x, axis=1) for x in value_series_gen) + result_iter = f(stateful_processor_api_client, key, values) + + # TODO(SPARK-49100): add verification that elements in result_iter are + # indeed of type pd.DataFrame and confirm to assigned cols + + return result_iter + + return lambda p, k, v: [(wrapped(p, k, v), to_arrow_type(return_type))] + + def wrap_grouped_map_pandas_udf_with_state(f, return_type): """ Provides a new lambda instance wrapping user function of applyInPandasWithState. 
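In `wrap_grouped_transform_with_state_pandas_udf` above, each element yielded by the serializer is a list of per-column `pandas.Series` for one Arrow batch, and `pd.concat(..., axis=1)` stitches them back into the chunked DataFrames handed to `handleInputRows`. A standalone illustration with made-up column names and values:

    import pandas as pd

    # One Arrow batch arrives as a list of column Series (values are illustrative).
    id_col = pd.Series(["0", "0", "1"], name="id")
    temperature_col = pd.Series([123, 23, 188], name="temperature")

    chunk = pd.concat([id_col, temperature_col], axis=1)
    assert list(chunk.columns) == ["id", "temperature"]
    assert chunk.shape == (3, 2)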
@@ -832,6 +849,10 @@ def read_single_udf(pickleSer, infile, eval_type, runner_conf, udf_index, profil
         return args_offsets, wrap_grouped_map_arrow_udf(func, return_type, argspec, runner_conf)
     elif eval_type == PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF_WITH_STATE:
         return args_offsets, wrap_grouped_map_pandas_udf_with_state(func, return_type)
+    elif eval_type == PythonEvalType.SQL_TRANSFORM_WITH_STATE_PANDAS_UDF:
+        return args_offsets, wrap_grouped_transform_with_state_pandas_udf(
+            func, return_type, runner_conf
+        )
     elif eval_type == PythonEvalType.SQL_COGROUPED_MAP_PANDAS_UDF:
         argspec = inspect.getfullargspec(chained_func)  # signature was lost when wrapping it
         return args_offsets, wrap_cogrouped_map_pandas_udf(func, return_type, argspec, runner_conf)
@@ -1404,6 +1425,8 @@ def mapper(_, it):

 def read_udfs(pickleSer, infile, eval_type):
     runner_conf = {}
+    state_server_port = None
+    key_schema = None
     if eval_type in (
         PythonEvalType.SQL_ARROW_BATCHED_UDF,
         PythonEvalType.SQL_SCALAR_PANDAS_UDF,
@@ -1417,6 +1440,7 @@ def read_udfs(pickleSer, infile, eval_type):
         PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF_WITH_STATE,
         PythonEvalType.SQL_GROUPED_MAP_ARROW_UDF,
         PythonEvalType.SQL_COGROUPED_MAP_ARROW_UDF,
+        PythonEvalType.SQL_TRANSFORM_WITH_STATE_PANDAS_UDF,
     ):
         # Load conf used for pandas_udf evaluation
         num_conf = read_int(infile)
@@ -1428,6 +1452,9 @@ def read_udfs(pickleSer, infile, eval_type):
     state_object_schema = None
     if eval_type == PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF_WITH_STATE:
         state_object_schema = StructType.fromJson(json.loads(utf8_deserializer.loads(infile)))
+    elif eval_type == PythonEvalType.SQL_TRANSFORM_WITH_STATE_PANDAS_UDF:
+        state_server_port = read_int(infile)
+        key_schema = StructType.fromJson(json.loads(utf8_deserializer.loads(infile)))

     # NOTE: if timezone is set here, that implies respectSessionTimeZone is True
     timezone = runner_conf.get("spark.sql.session.timeZone", None)
@@ -1454,6 +1481,16 @@ def read_udfs(pickleSer, infile, eval_type):
             state_object_schema,
             arrow_max_records_per_batch,
         )
+    elif eval_type == PythonEvalType.SQL_TRANSFORM_WITH_STATE_PANDAS_UDF:
+        arrow_max_records_per_batch = runner_conf.get(
+            "spark.sql.execution.arrow.maxRecordsPerBatch", 10000
+        )
+        arrow_max_records_per_batch = int(arrow_max_records_per_batch)
+
+        ser = TransformWithStateInPandasSerializer(
+            timezone, safecheck, _assign_cols_by_name, arrow_max_records_per_batch
+        )
+
     elif eval_type == PythonEvalType.SQL_MAP_ARROW_ITER_UDF:
         ser = ArrowStreamUDFSerializer()
     elif eval_type == PythonEvalType.SQL_GROUPED_MAP_ARROW_UDF:
@@ -1609,6 +1646,33 @@ def mapper(a):
             vals = [a[o] for o in parsed_offsets[0][1]]
             return f(keys, vals)

+    elif eval_type == PythonEvalType.SQL_TRANSFORM_WITH_STATE_PANDAS_UDF:
+        # We assume there is only one UDF here because grouped map doesn't
+        # support combining multiple UDFs.
+        assert num_udfs == 1
+
+        # See TransformWithStateInPandasExec for how arg_offsets are used to
+        # distinguish between grouping attributes and data attributes
+        arg_offsets, f = read_single_udf(
+            pickleSer, infile, eval_type, runner_conf, udf_index=0, profiler=profiler
+        )
+        parsed_offsets = extract_key_value_indexes(arg_offsets)
+        ser.key_offsets = parsed_offsets[0][0]
+        stateful_processor_api_client = StatefulProcessorApiClient(state_server_port, key_schema)
+
+        # Create function like this:
+        #   mapper a: f([a[0]], [a[0], a[1]])
+        def mapper(a):
+            key = a[0]
+
+            def values_gen():
+                for x in a[1]:
+                    retVal = [x[1][o] for o in parsed_offsets[0][1]]
+                    yield retVal
+
+            # This must be generator comprehension - do not materialize.
+            return f(stateful_processor_api_client, key, values_gen())
+
     elif eval_type == PythonEvalType.SQL_GROUPED_MAP_ARROW_UDF:

         import pyarrow as pa
diff --git a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/TimeMode.scala b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/TimeMode.scala
index b6248e97aa3da..4c7c87504ffc4 100644
--- a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/TimeMode.scala
+++ b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/TimeMode.scala
@@ -16,6 +16,9 @@
  */
 package org.apache.spark.sql.catalyst.plans.logical

+import java.util.Locale
+
+import org.apache.spark.SparkIllegalArgumentException
 import org.apache.spark.sql.streaming.TimeMode

 /** TimeMode types used in transformWithState operator */
@@ -24,3 +27,19 @@ case object NoTime extends TimeMode
 case object ProcessingTime extends TimeMode

 case object EventTime extends TimeMode
+
+object TimeModes {
+  def apply(timeMode: String): TimeMode = {
+    timeMode.toLowerCase(Locale.ROOT) match {
+      case "none" =>
+        NoTime
+      case "processingtime" =>
+        ProcessingTime
+      case "eventtime" =>
+        EventTime
+      case _ => throw new SparkIllegalArgumentException(
+        errorClass = "STATEFUL_PROCESSOR_UNKNOWN_TIME_MODE",
+        messageParameters = Map("timeMode" -> timeMode))
+    }
+  }
+}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/pythonLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/pythonLogicalOperators.scala
index 01d5a1bdea6a4..a662fb4eec962 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/pythonLogicalOperators.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/pythonLogicalOperators.scala
@@ -21,7 +21,7 @@ import org.apache.spark.resource.ResourceProfile
 import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeSet, Expression, PythonUDF, PythonUDTF}
 import org.apache.spark.sql.catalyst.trees.TreePattern._
 import org.apache.spark.sql.catalyst.util.truncatedString
-import org.apache.spark.sql.streaming.{GroupStateTimeout, OutputMode}
+import org.apache.spark.sql.streaming.{GroupStateTimeout, OutputMode, TimeMode}
 import org.apache.spark.sql.types.StructType

 /**
@@ -161,6 +161,35 @@ case class FlatMapGroupsInPandasWithState(
     newChild: LogicalPlan): FlatMapGroupsInPandasWithState = copy(child = newChild)
 }

+/**
+ * Invokes methods defined in the stateful processor used in arbitrary state API v2. We allow the
+ * user to act on per-group set of input rows along with keyed state and the user can choose to
+ * output/return 0 or more rows. For a streaming dataframe, we will repeatedly invoke the interface
+ * methods for new rows in each trigger and the user's state/state variables will be stored
+ * persistently across invocations.
+ * @param functionExpr function called on each group
+ * @param groupingAttributes used to group the data
+ * @param outputAttrs used to define the output rows
+ * @param outputMode defines the output mode for the statefulProcessor
+ * @param timeMode the time mode semantics of the stateful processor for timers and TTL.
+ * @param child logical plan of the underlying data
+ */
+case class TransformWithStateInPandas(
+    functionExpr: Expression,
+    groupingAttributes: Seq[Attribute],
+    outputAttrs: Seq[Attribute],
+    outputMode: OutputMode,
+    timeMode: TimeMode,
+    child: LogicalPlan) extends UnaryNode {
+
+  override def output: Seq[Attribute] = outputAttrs
+
+  override def producedAttributes: AttributeSet = AttributeSet(outputAttrs)
+
+  override protected def withNewChildInternal(
+      newChild: LogicalPlan): TransformWithStateInPandas = copy(child = newChild)
+}
+
 /**
  * Flatmap cogroups using a udf: iter(pyarrow.RecordBatch) -> iter(pyarrow.RecordBatch)
  * This is used by DataFrame.groupby().cogroup().applyInArrow().
diff --git a/sql/core/pom.xml b/sql/core/pom.xml
index c891763eb4e1a..6ecc11745249a 100644
--- a/sql/core/pom.xml
+++ b/sql/core/pom.xml
@@ -157,6 +157,7 @@
       com.google.protobuf
       protobuf-java
+      compile
       org.scalacheck
@@ -245,6 +246,32 @@
         target/scala-${scala.binary.version}/classes
         target/scala-${scala.binary.version}/test-classes
+
+        org.apache.maven.plugins
+        maven-shade-plugin
+
+          false
+          true
+
+
+            com.google.protobuf:*
+
+
+
+
+            com.google
+            ${spark.shade.packageName}.com.google
+
+
+              com.google.common.**
+
+
+
+
+
+
+
+
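Note on usage: the test file above exercises the new API end to end. The sketch below is a minimal, self-contained usage example that assumes only what this diff shows (transformWithStateInPandas on a grouped streaming DataFrame, a StatefulProcessor with init/handleInputRows/close, and a ValueState obtained from the handle). The CountProcessor name, the rate-source input, the column names, and the checkpoint path are illustrative and not part of this change.

# Minimal usage sketch of transformWithStateInPandas, modeled on the test above.
# CountProcessor, the rate-source input, and the paths below are illustrative only.
from typing import Iterator

import pandas as pd

from pyspark.sql import SparkSession
from pyspark.sql.streaming import StatefulProcessor, StatefulProcessorHandle
from pyspark.sql.types import IntegerType, StringType, StructField, StructType


class CountProcessor(StatefulProcessor):
    def init(self, handle: StatefulProcessorHandle) -> None:
        # One ValueState cell per grouping key, holding a running row count.
        state_schema = StructType([StructField("value", IntegerType(), True)])
        self.count_state = handle.getValueState("count", state_schema)

    def handleInputRows(self, key, rows) -> Iterator[pd.DataFrame]:
        existing = self.count_state.get()[0] if self.count_state.exists() else 0
        batch_count = sum(len(pdf) for pdf in rows)
        total = existing + batch_count
        self.count_state.update((total,))
        yield pd.DataFrame({"id": key, "countAsString": str(total)})

    def close(self) -> None:
        pass


spark = SparkSession.builder.getOrCreate()
df = spark.readStream.format("rate").load().selectExpr("CAST(value % 10 AS STRING) AS id")

output_schema = StructType(
    [
        StructField("id", StringType(), True),
        StructField("countAsString", StringType(), True),
    ]
)

query = (
    df.groupBy("id")
    .transformWithStateInPandas(
        statefulProcessor=CountProcessor(),
        outputStructType=output_schema,
        outputMode="Update",
        timeMode="None",
    )
    .writeStream.format("console")
    .outputMode("update")
    .option("checkpointLocation", "/tmp/tws_checkpoint")  # illustrative path
    .start()
)

On each micro-batch, handleInputRows sees only the rows for its key, and the count survives across batches (and across query restarts, as the restart test above checks) because it lives in the state store rather than in the Python process.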
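Note on time modes: TimeModes.apply above matches the user-supplied string case-insensitively and rejects anything other than 'none', 'processingTime', or 'eventTime' with the STATEFUL_PROCESSOR_UNKNOWN_TIME_MODE error class. The snippet below is only a rough Python rendering of that Scala matching logic, for illustration; the function name and exception type are not part of the change.

# Illustration only: approximate Python equivalent of TimeModes.apply in TimeMode.scala.
def parse_time_mode(time_mode: str) -> str:
    normalized = time_mode.lower()
    if normalized == "none":
        return "NoTime"
    if normalized == "processingtime":
        return "ProcessingTime"
    if normalized == "eventtime":
        return "EventTime"
    # The Scala code raises SparkIllegalArgumentException with error class
    # STATEFUL_PROCESSOR_UNKNOWN_TIME_MODE at this point.
    raise ValueError(f"Unknown time mode: {time_mode}")


assert parse_time_mode("None") == "NoTime"  # matching is case-insensitive
assert parse_time_mode("processingTime") == "ProcessingTime"
assert parse_time_mode("EVENTTIME") == "EventTime"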
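Note on the worker-side wrapping: wrap_grouped_transform_with_state_pandas_udf receives, for each key, a generator whose elements are lists of pandas Series (one Series per data column, as selected by the mapper's values_gen above), and it lazily reassembles each list into a DataFrame with pd.concat(..., axis=1) before calling the user's function. The standalone snippet below imitates that reshaping with made-up Series; the data and column names are not from the diff.

# Standalone illustration of the pd.concat(axis=1) reshaping done in
# wrap_grouped_transform_with_state_pandas_udf; the Series below are made up.
import pandas as pd

# Two fake per-batch column lists for a single grouping key. In the real worker these
# come off the Arrow stream via TransformWithStateInPandasSerializer, not from literals.
value_series_gen = iter(
    [
        [pd.Series(["0", "0"], name="id"), pd.Series([57, 120], name="temperature")],
        [pd.Series(["0"], name="id"), pd.Series([129], name="temperature")],
    ]
)

# Same shape as wrapped() in the diff: a generator expression, so each batch is
# reassembled into a DataFrame only when the user's handleInputRows pulls it.
values = (pd.concat(columns, axis=1) for columns in value_series_gen)

for pdf in values:
    print(list(pdf.columns), len(pdf))  # ['id', 'temperature'] 2, then ['id', 'temperature'] 1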