
Commit

Merge remote-tracking branch 'upstream/master' into scripting-for-loop
dusantism-db committed Nov 27, 2024
2 parents 9d1cf29 + 6edcf43 commit d4de13a
Showing 51 changed files with 2,830 additions and 614 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/build_python_connect.yml
@@ -93,7 +93,7 @@ jobs:
# Several tests related to the catalog require running them sequentially, e.g., writing a table in a listener.
./python/run-tests --parallelism=1 --python-executables=python3 --modules pyspark-connect,pyspark-ml-connect
# None of the tests in the Pandas API on Spark depend on each other, so run them in parallel
./python/run-tests --parallelism=2 --python-executables=python3 --modules pyspark-pandas-connect-part0,pyspark-pandas-connect-part1,pyspark-pandas-connect-part2,pyspark-pandas-connect-part3
./python/run-tests --parallelism=1 --python-executables=python3 --modules pyspark-pandas-connect-part0,pyspark-pandas-connect-part1,pyspark-pandas-connect-part2,pyspark-pandas-connect-part3
# Stop Spark Connect server.
./sbin/stop-connect-server.sh
5 changes: 5 additions & 0 deletions common/utils/src/main/resources/error/error-conditions.json
@@ -5381,6 +5381,11 @@
"SQL Scripting is under development and not all features are supported. SQL Scripting enables users to write procedural SQL including control flow and error handling. To enable existing features set <sqlScriptingEnabled> to `true`."
]
},
"SQL_SCRIPTING_WITH_POSITIONAL_PARAMETERS" : {
"message" : [
"Positional parameters are not supported with SQL Scripting."
]
},
"STATE_STORE_MULTIPLE_COLUMN_FAMILIES" : {
"message" : [
"Creating multiple column families with <stateStoreProvider> is not supported."
4 changes: 1 addition & 3 deletions core/src/main/scala/org/apache/spark/util/Utils.scala
@@ -2982,9 +2982,7 @@ private[spark] object Utils
if (props == null) {
return props
}
val resultProps = new Properties()
props.forEach((k, v) => resultProps.put(k, v))
resultProps
props.clone().asInstanceOf[Properties]
}

/**
5 changes: 5 additions & 0 deletions python/docs/source/user_guide/sql/python_data_source.rst
@@ -516,3 +516,8 @@ The following example demonstrates how to implement a basic Data Source using Ar
df = spark.read.format("arrowbatch").load()
df.show()
Usage Notes
-----------

- During Data Source resolution, built-in and Scala/Java Data Sources take precedence over Python Data Sources with the same name; to explicitly use a Python Data Source, make sure its name does not conflict with the other Data Sources.
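
For reference, a minimal sketch of a Python Data Source registered under a name chosen to avoid collisions with built-in or Scala/Java sources; the class and source name (SimpleRangeSource, "simple_range") are illustrative and not part of this commit:

from pyspark.sql import SparkSession
from pyspark.sql.datasource import DataSource, DataSourceReader

class SimpleRangeSource(DataSource):
    @classmethod
    def name(cls):
        # Pick a name that does not clash with a built-in or Scala/Java source
        # (e.g. "json", "parquet"); otherwise the non-Python source is resolved instead.
        return "simple_range"

    def schema(self):
        return "value INT"

    def reader(self, schema):
        return SimpleRangeReader()

class SimpleRangeReader(DataSourceReader):
    def read(self, partition):
        # Emit a few rows matching the declared schema.
        for i in range(5):
            yield (i,)

spark = SparkSession.builder.getOrCreate()
spark.dataSource.register(SimpleRangeSource)
spark.read.format("simple_range").load().show()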
22 changes: 14 additions & 8 deletions python/pyspark/sql/connect/logging.py
@@ -21,30 +21,36 @@
import os
from typing import Optional

__all__ = [
"getLogLevel",
]
__all__ = ["configureLogging", "getLogLevel"]


def _configure_logging() -> logging.Logger:
"""Configure logging for the Spark Connect clients."""
def configureLogging(level: Optional[str] = None) -> logging.Logger:
"""
Configure log level for Spark Connect components.
When not specified as a parameter, log level will be configured based on
the SPARK_CONNECT_LOG_LEVEL environment variable.
When both are absent, logging is disabled.
.. versionadded:: 4.0.0
"""
logger = PySparkLogger.getLogger(__name__)
handler = logging.StreamHandler()
handler.setFormatter(
logging.Formatter(fmt="%(asctime)s %(process)d %(levelname)s %(funcName)s %(message)s")
)
logger.addHandler(handler)

# Check the environment variables for log levels:
if "SPARK_CONNECT_LOG_LEVEL" in os.environ:
if level is not None:
logger.setLevel(level.upper())
elif "SPARK_CONNECT_LOG_LEVEL" in os.environ:
logger.setLevel(os.environ["SPARK_CONNECT_LOG_LEVEL"].upper())
else:
logger.disabled = True
return logger


# Instantiate the logger based on the environment configuration.
logger = _configure_logging()
logger = configureLogging()


def getLogLevel() -> Optional[int]:
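A usage sketch of the new configureLogging entry point, assuming SPARK_CONNECT_LOG_LEVEL is exported before the module is first imported; the chosen level is illustrative:

import os

# Export SPARK_CONNECT_LOG_LEVEL before pyspark.sql.connect.logging is first imported ...
os.environ["SPARK_CONNECT_LOG_LEVEL"] = "debug"

from pyspark.sql.connect.logging import configureLogging, getLogLevel

# ... or pass a level explicitly, e.g. configureLogging("warn"), which takes precedence
# over the environment variable; with neither, the Spark Connect logger stays disabled.
logger = configureLogging()
print(getLogLevel())  # numeric level (10 for DEBUG here), or None when logging is disabled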
2 changes: 2 additions & 0 deletions python/pyspark/sql/pandas/group_ops.py
@@ -374,6 +374,7 @@ def transformWithStateInPandas(
outputMode: str,
timeMode: str,
initialState: Optional["GroupedData"] = None,
eventTimeColumnName: str = "",
) -> DataFrame:
"""
Invokes methods defined in the stateful processor used in arbitrary state API v2. It
@@ -662,6 +663,7 @@ def transformWithStateWithInitStateUDF(
outputMode,
timeMode,
initial_state_java_obj,
eventTimeColumnName,
)
return DataFrame(jdf, self.session)

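A condensed sketch of how the new eventTimeColumnName argument is meant to be used to chain event-time operations after transformWithStateInPandas; `df` and MyProcessor are placeholders, and the full end-to-end version is the chaining-ops test added later in this commit:

from pyspark.sql.types import StructType, StructField, StringType, TimestampType

# Assumes `df` is a streaming DataFrame with a watermark on "eventTime"
# and MyProcessor implements StatefulProcessor, emitting an "outputTimestamp" column.
output_schema = StructType([
    StructField("id", StringType(), True),
    StructField("outputTimestamp", TimestampType(), True),
])

result = (
    df.groupBy("id")
    .transformWithStateInPandas(
        statefulProcessor=MyProcessor(),
        outputStructType=output_schema,
        outputMode="Append",
        timeMode="eventTime",
        eventTimeColumnName="outputTimestamp",  # re-assigns event time for downstream operators
    )
    .groupBy("outputTimestamp")
    .count()
)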
4 changes: 2 additions & 2 deletions python/pyspark/sql/tests/connect/test_parity_udf_profiler.py
@@ -21,9 +21,9 @@
from pyspark.sql.tests.test_udf_profiler import (
UDFProfiler2TestsMixin,
_do_computation,
has_flameprof,
)
from pyspark.testing.connectutils import ReusedConnectTestCase
from pyspark.testing.utils import have_flameprof


class UDFProfilerParityTests(UDFProfiler2TestsMixin, ReusedConnectTestCase):
@@ -65,7 +65,7 @@ def action(df):
io.getvalue(), f"10.*{os.path.basename(inspect.getfile(_do_computation))}"
)

if has_flameprof:
if have_flameprof:
self.assertIn("svg", self.spark.profile.render(id))


158 changes: 138 additions & 20 deletions python/pyspark/sql/tests/pandas/test_pandas_transform_with_state.py
@@ -27,13 +27,7 @@
from pyspark import SparkConf
from pyspark.errors import PySparkRuntimeError
from pyspark.sql.functions import split
from pyspark.sql.types import (
StringType,
StructType,
StructField,
Row,
IntegerType,
)
from pyspark.sql.types import StringType, StructType, StructField, Row, IntegerType, TimestampType
from pyspark.testing import assertDataFrameEqual
from pyspark.testing.sqlutils import (
ReusedSQLTestCase,
@@ -247,11 +241,15 @@ def check_results(batch_df, _):

# test list state with ttl has the same behavior as list state when state doesn't expire.
def test_transform_with_state_in_pandas_list_state_large_ttl(self):
def check_results(batch_df, _):
assert set(batch_df.sort("id").collect()) == {
Row(id="0", countAsString="2"),
Row(id="1", countAsString="2"),
}
def check_results(batch_df, batch_id):
if batch_id == 0:
assert set(batch_df.sort("id").collect()) == {
Row(id="0", countAsString="2"),
Row(id="1", countAsString="2"),
}
else:
for q in self.spark.streams.active:
q.stop()

self._test_transform_with_state_in_pandas_basic(
ListStateLargeTTLProcessor(), check_results, True, "processingTime"
@@ -268,11 +266,15 @@ def check_results(batch_df, _):

# test map state with ttl has the same behavior as map state when state doesn't expire.
def test_transform_with_state_in_pandas_map_state_large_ttl(self):
def check_results(batch_df, _):
assert set(batch_df.sort("id").collect()) == {
Row(id="0", countAsString="2"),
Row(id="1", countAsString="2"),
}
def check_results(batch_df, batch_id):
if batch_id == 0:
assert set(batch_df.sort("id").collect()) == {
Row(id="0", countAsString="2"),
Row(id="1", countAsString="2"),
}
else:
for q in self.spark.streams.active:
q.stop()

self._test_transform_with_state_in_pandas_basic(
MapStateLargeTTLProcessor(), check_results, True, "processingTime"
@@ -287,11 +289,14 @@ def check_results(batch_df, batch_id):
Row(id="0", countAsString="2"),
Row(id="1", countAsString="2"),
}
else:
elif batch_id == 1:
assert set(batch_df.sort("id").collect()) == {
Row(id="0", countAsString="3"),
Row(id="1", countAsString="2"),
}
else:
for q in self.spark.streams.active:
q.stop()

self._test_transform_with_state_in_pandas_basic(
SimpleTTLStatefulProcessor(), check_results, False, "processingTime"
@@ -348,6 +353,9 @@ def check_results(batch_df, batch_id):
Row(id="ttl-map-state-count-1", count=3),
],
)
else:
for q in self.spark.streams.active:
q.stop()
if batch_id == 0 or batch_id == 1:
time.sleep(6)

@@ -466,7 +474,7 @@ def check_results(batch_df, batch_id):
).first()["timeValues"]
check_timestamp(batch_df)

else:
elif batch_id == 2:
assert set(batch_df.sort("id").select("id", "countAsString").collect()) == {
Row(id="0", countAsString="3"),
Row(id="0", countAsString="-1"),
@@ -480,6 +488,10 @@ def check_results(batch_df, batch_id):
).first()["timeValues"]
assert current_batch_expired_timestamp > self.first_expired_timestamp

else:
for q in self.spark.streams.active:
q.stop()

self._test_transform_with_state_in_pandas_proc_timer(
ProcTimeStatefulProcessor(), check_results
)
@@ -552,12 +564,15 @@ def check_results(batch_df, batch_id):
Row(id="a", timestamp="20"),
Row(id="a-expired", timestamp="0"),
}
else:
elif batch_id == 2:
# verify that rows and expired timer produce the expected result
assert set(batch_df.sort("id").collect()) == {
Row(id="a", timestamp="15"),
Row(id="a-expired", timestamp="10000"),
}
else:
for q in self.spark.streams.active:
q.stop()

self._test_transform_with_state_in_pandas_event_time(
EventTimeStatefulProcessor(), check_results
@@ -679,6 +694,9 @@ def check_results(batch_df, batch_id):
Row(id1="0", id2="1", value=str(123 + 46)),
Row(id1="1", id2="2", value=str(146 + 346)),
}
else:
for q in self.spark.streams.active:
q.stop()

self._test_transform_with_state_non_contiguous_grouping_cols(
SimpleStatefulProcessorWithInitialState(), check_results
@@ -692,6 +710,9 @@ def check_results(batch_df, batch_id):
Row(id1="0", id2="1", value=str(789 + 123 + 46)),
Row(id1="1", id2="2", value=str(146 + 346)),
}
else:
for q in self.spark.streams.active:
q.stop()

# the grouping key of the initial state also does not start at the beginning of the attributes
data = [(789, "0", "1"), (987, "3", "2")]
@@ -703,6 +724,88 @@ def check_results(batch_df, batch_id):
SimpleStatefulProcessorWithInitialState(), check_results, initial_state
)

def _test_transform_with_state_in_pandas_chaining_ops(
self, stateful_processor, check_results, timeMode="None", grouping_cols=["outputTimestamp"]
):
import pyspark.sql.functions as f

input_path = tempfile.mkdtemp()
self._prepare_input_data(input_path + "/text-test3.txt", ["a", "b"], [10, 15])
time.sleep(2)
self._prepare_input_data(input_path + "/text-test4.txt", ["a", "c"], [11, 25])
time.sleep(2)
self._prepare_input_data(input_path + "/text-test1.txt", ["a"], [5])

df = self._build_test_df(input_path)
df = df.select(
"id", f.from_unixtime(f.col("temperature")).alias("eventTime").cast("timestamp")
).withWatermark("eventTime", "5 seconds")

for q in self.spark.streams.active:
q.stop()
self.assertTrue(df.isStreaming)

output_schema = StructType(
[
StructField("id", StringType(), True),
StructField("outputTimestamp", TimestampType(), True),
]
)

q = (
df.groupBy("id")
.transformWithStateInPandas(
statefulProcessor=stateful_processor,
outputStructType=output_schema,
outputMode="Append",
timeMode=timeMode,
eventTimeColumnName="outputTimestamp",
)
.groupBy(grouping_cols)
.count()
.writeStream.queryName("chaining_ops_query")
.foreachBatch(check_results)
.outputMode("append")
.start()
)

self.assertEqual(q.name, "chaining_ops_query")
self.assertTrue(q.isActive)
q.processAllAvailable()
q.awaitTermination(10)

def test_transform_with_state_in_pandas_chaining_ops(self):
def check_results(batch_df, batch_id):
import datetime

if batch_id == 0:
assert batch_df.isEmpty()
elif batch_id == 1:
# eviction watermark = 15 - 5 = 10 (max event time from batch 0),
# late event watermark = 0 (eviction event time from batch 0)
assert set(
batch_df.sort("outputTimestamp").select("outputTimestamp", "count").collect()
) == {
Row(outputTimestamp=datetime.datetime(1970, 1, 1, 0, 0, 10), count=1),
}
elif batch_id == 2:
# eviction watermark = 25 - 5 = 20, late event watermark = 10;
# row with watermark=5<10 is dropped so it does not show up in the results;
# row with eventTime<=20 are finalized and emitted
assert set(
batch_df.sort("outputTimestamp").select("outputTimestamp", "count").collect()
) == {
Row(outputTimestamp=datetime.datetime(1970, 1, 1, 0, 0, 11), count=1),
Row(outputTimestamp=datetime.datetime(1970, 1, 1, 0, 0, 15), count=1),
}

self._test_transform_with_state_in_pandas_chaining_ops(
StatefulProcessorChainingOps(), check_results, "eventTime"
)
self._test_transform_with_state_in_pandas_chaining_ops(
StatefulProcessorChainingOps(), check_results, "eventTime", ["outputTimestamp", "id"]
)


class SimpleStatefulProcessorWithInitialState(StatefulProcessor):
# this dict is the same as input initial state dataframe
@@ -888,6 +991,21 @@ def close(self) -> None:
pass


class StatefulProcessorChainingOps(StatefulProcessor):
def init(self, handle: StatefulProcessorHandle) -> None:
pass

def handleInputRows(
self, key, rows, timer_values, expired_timer_info
) -> Iterator[pd.DataFrame]:
for pdf in rows:
timestamp_list = pdf["eventTime"].tolist()
yield pd.DataFrame({"id": key, "outputTimestamp": timestamp_list[0]})

def close(self) -> None:
pass


# A stateful processor that inherits all behavior of SimpleStatefulProcessor except that it uses
# ttl state with a large timeout.
class SimpleTTLStatefulProcessor(SimpleStatefulProcessor, unittest.TestCase):