[SPARK-41379][SS][PYTHON] Provide cloned spark session in DataFrame in user function for foreachBatch sink in PySpark

### What changes were proposed in this pull request?

This PR proposes to provide the cloned Spark session in the DataFrame given to the user function for the foreachBatch sink in PySpark.

### Why are the changes needed?

It's arguably a bug: previously, the given DataFrame was associated with two different SparkSessions, 1) the one which runs the streaming query (accessed via `df.sparkSession`) and 2) the one which microbatch execution "cloned" (accessed via `df._jdf.sparkSession()`). If users pick 1), it defeats the purpose of cloning the Spark session, e.g. disabling AQE. Also, which session is picked up depends on the underlying implementation of each method in DataFrame, which leads to inconsistency.

The following is a problematic example:

```
def user_func(batch_df, batch_id):
  batch_df.createOrReplaceTempView("updates")
  ... # what is the right way to refer to the temp view "updates"?
```

Before this PR, the only way to refer to the temp view "updates" was through the "internal" field of DataFrame, `_jdf`: only a new query run via `batch_df._jdf.sparkSession()` can see the temp view defined in the user function. We would like to make this possible without forcing end users to access an "internal" field.

After this PR, they can (and should) use `batch_df.sparkSession` instead.
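
For illustration, here is the user function from the example above rewritten to use the public API after this PR; the aggregation query and the `show()` call are only illustrative, not part of this change:

```
def user_func(batch_df, batch_id):
  batch_df.createOrReplaceTempView("updates")
  # batch_df.sparkSession is the session cloned by microbatch execution,
  # so queries run through it can see the temp view registered above.
  batch_df.sparkSession.sql("SELECT COUNT(*) AS cnt FROM updates").show()
```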

### Does this PR introduce _any_ user-facing change?

Yes, this PR makes it consistent which Spark session is used. Users can use `df.sparkSession` to access the cloned Spark session, which is the same session the methods in DataFrame use.

### How was this patch tested?

A new test case which fails with the current master branch.

Closes apache#38906 from HeartSaVioR/SPARK-41379.

Authored-by: Jungtaek Lim <[email protected]>
Signed-off-by: Jungtaek Lim <[email protected]>
HeartSaVioR committed Dec 5, 2022
1 parent 1d16591 commit f4ec6f2
Showing 2 changed files with 26 additions and 1 deletion.
21 changes: 21 additions & 0 deletions python/pyspark/sql/tests/streaming/test_streaming.py
@@ -573,6 +573,27 @@ def collectBatch(batch_df, batch_id):
         if q:
             q.stop()
 
+    def test_streaming_foreachBatch_tempview(self):
+        q = None
+        collected = dict()
+
+        def collectBatch(batch_df, batch_id):
+            batch_df.createOrReplaceTempView("updates")
+            # it should use the spark session within the given DataFrame, as microbatch
+            # execution will clone the session, which is no longer the same as the session
+            # used to start the streaming query
+            collected[batch_id] = batch_df.sparkSession.sql("SELECT * FROM updates").collect()
+
+        try:
+            df = self.spark.readStream.format("text").load("python/test_support/sql/streaming")
+            q = df.writeStream.foreachBatch(collectBatch).start()
+            q.processAllAvailable()
+            self.assertTrue(0 in collected)
+            self.assertEqual(len(collected[0]), 2)
+        finally:
+            if q:
+                q.stop()
+
     def test_streaming_foreachBatch_propagates_python_errors(self):
         from pyspark.sql.utils import StreamingQueryException

6 changes: 5 additions & 1 deletion python/pyspark/sql/utils.py
@@ -273,9 +273,13 @@ def __init__(self, session: "SparkSession", func: Callable[["DataFrame", int], None]):
     def call(self, jdf: JavaObject, batch_id: int) -> None:
         from pyspark.sql.dataframe import DataFrame
+        from pyspark.sql.session import SparkSession
 
         try:
-            self.func(DataFrame(jdf, self.session), batch_id)
+            session_jdf = jdf.sparkSession()
+            # assuming that spark context is still the same between JVM and PySpark
+            wrapped_session_jdf = SparkSession(self.session.sparkContext, session_jdf)
+            self.func(DataFrame(jdf, wrapped_session_jdf), batch_id)
         except Exception as e:
             self.error = e
             raise e
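
As a side note on how the fix works: a PySpark SparkSession is a thin wrapper around a JVM SparkSession plus the shared SparkContext, which is why the cloned JVM session can simply be re-wrapped on the Python side. Below is a minimal standalone sketch of the same wrapping pattern; all variable names are illustrative, and unlike in foreachBatch, the DataFrame's JVM session here is the same as `spark`'s rather than a clone:

```
from pyspark.sql import SparkSession

spark = SparkSession.builder.master("local[2]").getOrCreate()
df = spark.range(3)

# JVM-side session attached to the DataFrame (the same call the fix uses)
jvm_session = df._jdf.sparkSession()

# Re-wrap it as a Python SparkSession; the SparkContext is shared between
# the original session and the wrapper, only session state is per-session.
wrapped = SparkSession(spark.sparkContext, jvm_session)

df.createOrReplaceTempView("t")
# The wrapper sees state on the wrapped JVM session, e.g. this temp view.
wrapped.sql("SELECT * FROM t").show()
```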
