diff --git a/python/pyspark/sql/streaming/stateful_processor.py b/python/pyspark/sql/streaming/stateful_processor.py index 0011b62132ade..6b8de0f8ac4ec 100644 --- a/python/pyspark/sql/streaming/stateful_processor.py +++ b/python/pyspark/sql/streaming/stateful_processor.py @@ -56,7 +56,7 @@ def get(self) -> Optional[Tuple]: """ return self._value_state_client.get(self._state_name) - def update(self, new_value: Any) -> None: + def update(self, new_value: Tuple) -> None: """ Update the value of the state. """ @@ -156,7 +156,9 @@ def getValueState( self.stateful_processor_api_client.get_value_state(state_name, schema, ttl_duration_ms) return ValueState(ValueStateClient(self.stateful_processor_api_client), state_name, schema) - def getListState(self, state_name: str, schema: Union[StructType, str]) -> ListState: + def getListState( + self, state_name: str, schema: Union[StructType, str], ttl_duration_ms: Optional[int] = None + ) -> ListState: """ Function to create new or return existing single value state variable of given type. The user must ensure to call this function only within the `init()` method of the @@ -169,8 +171,13 @@ def getListState(self, state_name: str, schema: Union[StructType, str]) -> ListS schema : :class:`pyspark.sql.types.DataType` or str The schema of the state variable. The value can be either a :class:`pyspark.sql.types.DataType` object or a DDL-formatted type string. + ttlDurationMs: int + Time to live duration of the state in milliseconds. State values will not be returned + past ttlDuration and will be eventually removed from the state store. Any state update + resets the expiration time to current processing time plus ttlDuration. + If ttl is not specified the state will never expire. """ - self.stateful_processor_api_client.get_list_state(state_name, schema) + self.stateful_processor_api_client.get_list_state(state_name, schema, ttl_duration_ms) return ListState(ListStateClient(self.stateful_processor_api_client), state_name, schema) diff --git a/python/pyspark/sql/streaming/stateful_processor_api_client.py b/python/pyspark/sql/streaming/stateful_processor_api_client.py index 2a5e55159e766..449d5a2ad55dc 100644 --- a/python/pyspark/sql/streaming/stateful_processor_api_client.py +++ b/python/pyspark/sql/streaming/stateful_processor_api_client.py @@ -131,7 +131,9 @@ def get_value_state( # TODO(SPARK-49233): Classify user facing errors. raise PySparkRuntimeError(f"Error initializing value state: " f"{response_message[1]}") - def get_list_state(self, state_name: str, schema: Union[StructType, str]) -> None: + def get_list_state( + self, state_name: str, schema: Union[StructType, str], ttl_duration_ms: Optional[int] + ) -> None: import pyspark.sql.streaming.StateMessage_pb2 as stateMessage if isinstance(schema, str): @@ -140,6 +142,8 @@ def get_list_state(self, state_name: str, schema: Union[StructType, str]) -> Non state_call_command = stateMessage.StateCallCommand() state_call_command.stateName = state_name state_call_command.schema = schema.json() + if ttl_duration_ms is not None: + state_call_command.ttl.durationMs = ttl_duration_ms call = stateMessage.StatefulProcessorCall(getListState=state_call_command) message = stateMessage.StateRequest(statefulProcessorCall=call) diff --git a/python/pyspark/sql/tests/pandas/test_pandas_transform_with_state.py b/python/pyspark/sql/tests/pandas/test_pandas_transform_with_state.py index 99333ae6f5c26..01cd441941d93 100644 --- a/python/pyspark/sql/tests/pandas/test_pandas_transform_with_state.py +++ b/python/pyspark/sql/tests/pandas/test_pandas_transform_with_state.py @@ -221,6 +221,18 @@ def check_results(batch_df, _): self._test_transform_with_state_in_pandas_basic(ListStateProcessor(), check_results, True) + # test list state with ttl has the same behavior as list state when state doesn't expire. + def test_transform_with_state_in_pandas_list_state_large_ttl(self): + def check_results(batch_df, _): + assert set(batch_df.sort("id").collect()) == { + Row(id="0", countAsString="2"), + Row(id="1", countAsString="2"), + } + + self._test_transform_with_state_in_pandas_basic( + ListStateLargeTTLProcessor(), check_results, True, "processingTime" + ) + # test value state with ttl has the same behavior as value state when # state doesn't expire. def test_value_state_ttl_basic(self): @@ -248,8 +260,10 @@ def check_results(batch_df, batch_id): [ Row(id="ttl-count-0", count=1), Row(id="count-0", count=1), + Row(id="ttl-list-state-count-0", count=1), Row(id="ttl-count-1", count=1), Row(id="count-1", count=1), + Row(id="ttl-list-state-count-1", count=1), ], ) elif batch_id == 1: @@ -258,21 +272,29 @@ def check_results(batch_df, batch_id): [ Row(id="ttl-count-0", count=2), Row(id="count-0", count=2), + Row(id="ttl-list-state-count-0", count=3), Row(id="ttl-count-1", count=2), Row(id="count-1", count=2), + Row(id="ttl-list-state-count-1", count=3), ], ) elif batch_id == 2: # ttl-count-0 expire and restart from count 0. - # ttl-count-1 get reset in batch 1 and keep the state + # The TTL for value state ttl_count_state gets reset in batch 1 because of the + # update operation and ttl-count-1 keeps the state. + # ttl-list-state-count-0 expire and restart from count 0. + # The TTL for list state ttl_list_state gets reset in batch 1 because of the + # put operation and ttl-list-state-count-1 keeps the state. # non-ttl state never expires assertDataFrameEqual( batch_df, [ Row(id="ttl-count-0", count=1), Row(id="count-0", count=3), + Row(id="ttl-list-state-count-0", count=1), Row(id="ttl-count-1", count=3), Row(id="count-1", count=3), + Row(id="ttl-list-state-count-1", count=7), ], ) if batch_id == 0 or batch_id == 1: @@ -362,25 +384,38 @@ def init(self, handle: StatefulProcessorHandle) -> None: state_schema = StructType([StructField("value", IntegerType(), True)]) self.ttl_count_state = handle.getValueState("ttl-state", state_schema, 10000) self.count_state = handle.getValueState("state", state_schema) + self.ttl_list_state = handle.getListState("ttl-list-state", state_schema, 10000) def handleInputRows(self, key, rows) -> Iterator[pd.DataFrame]: count = 0 ttl_count = 0 + ttl_list_state_count = 0 id = key[0] if self.count_state.exists(): count = self.count_state.get()[0] if self.ttl_count_state.exists(): ttl_count = self.ttl_count_state.get()[0] + if self.ttl_list_state.exists(): + iter = self.ttl_list_state.get() + for s in iter: + ttl_list_state_count += s[0] for pdf in rows: pdf_count = pdf.count().get("temperature") count += pdf_count ttl_count += pdf_count + ttl_list_state_count += pdf_count self.count_state.update((count,)) # skip updating state for the 2nd batch so that ttl state expire if not (ttl_count == 2 and id == "0"): self.ttl_count_state.update((ttl_count,)) - yield pd.DataFrame({"id": [f"ttl-count-{id}", f"count-{id}"], "count": [ttl_count, count]}) + self.ttl_list_state.put([(ttl_list_state_count,), (ttl_list_state_count,)]) + yield pd.DataFrame( + { + "id": [f"ttl-count-{id}", f"count-{id}", f"ttl-list-state-count-{id}"], + "count": [ttl_count, count, ttl_list_state_count], + } + ) def close(self) -> None: pass @@ -457,6 +492,15 @@ def close(self) -> None: pass +# A stateful processor that inherit all behavior of ListStateProcessor except that it use +# ttl state with a large timeout. +class ListStateLargeTTLProcessor(ListStateProcessor): + def init(self, handle: StatefulProcessorHandle) -> None: + state_schema = StructType([StructField("temperature", IntegerType(), True)]) + self.list_state1 = handle.getListState("listState1", state_schema, 30000) + self.list_state2 = handle.getListState("listState2", state_schema, 30000) + + class TransformWithStateInPandasTests(TransformWithStateInPandasTestsMixin, ReusedSQLTestCase): pass diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/TransformWithStateInPandasStateServer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/TransformWithStateInPandasStateServer.scala index d293e7a4a5bb2..fed1843acfa56 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/TransformWithStateInPandasStateServer.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/TransformWithStateInPandasStateServer.scala @@ -189,8 +189,12 @@ class TransformWithStateInPandasStateServer( case StatefulProcessorCall.MethodCase.GETLISTSTATE => val stateName = message.getGetListState.getStateName val schema = message.getGetListState.getSchema - // TODO(SPARK-49744): Add ttl support for list state. - initializeStateVariable(stateName, schema, StateVariableType.ListState, None) + val ttlDurationMs = if (message.getGetListState.hasTtl) { + Some(message.getGetListState.getTtl.getDurationMs) + } else { + None + } + initializeStateVariable(stateName, schema, StateVariableType.ListState, ttlDurationMs) case _ => throw new IllegalArgumentException("Invalid method call") } @@ -372,10 +376,14 @@ class TransformWithStateInPandasStateServer( sendResponse(1, s"Value state $stateName already exists") } case StateVariableType.ListState => if (!listStates.contains(stateName)) { - // TODO(SPARK-49744): Add ttl support for list state. + val state = if (ttlDurationMs.isEmpty) { + statefulProcessorHandle.getListState[Row](stateName, Encoders.row(schema)) + } else { + statefulProcessorHandle.getListState( + stateName, Encoders.row(schema), TTLConfig(Duration.ofMillis(ttlDurationMs.get))) + } listStates.put(stateName, - ListStateInfo(statefulProcessorHandle.getListState[Row](stateName, - Encoders.row(schema)), schema, expressionEncoder.createDeserializer(), + ListStateInfo(state, schema, expressionEncoder.createDeserializer(), expressionEncoder.createSerializer())) sendResponse(0) } else { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/TransformWithStateInPandasStateServerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/TransformWithStateInPandasStateServerSuite.scala index 137e2531f4f46..776772bb51ca8 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/TransformWithStateInPandasStateServerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/TransformWithStateInPandasStateServerSuite.scala @@ -118,6 +118,29 @@ class TransformWithStateInPandasStateServerSuite extends SparkFunSuite with Befo } } + Seq(true, false).foreach { useTTL => + test(s"get list state, useTTL=$useTTL") { + val stateCallCommandBuilder = StateCallCommand.newBuilder() + .setStateName("newName") + .setSchema("StructType(List(StructField(value,IntegerType,true)))") + if (useTTL) { + stateCallCommandBuilder.setTtl(StateMessage.TTLConfig.newBuilder().setDurationMs(1000)) + } + val message = StatefulProcessorCall + .newBuilder() + .setGetListState(stateCallCommandBuilder.build()) + .build() + stateServer.handleStatefulProcessorCall(message) + if (useTTL) { + verify(statefulProcessorHandle) + .getListState[Row](any[String], any[Encoder[Row]], any[TTLConfig]) + } else { + verify(statefulProcessorHandle).getListState[Row](any[String], any[Encoder[Row]]) + } + verify(outputStream).writeInt(0) + } + } + test("value state exists") { val message = ValueStateCall.newBuilder().setStateName(stateName) .setExists(Exists.newBuilder().build()).build()