[SPARK-48561][PS][CONNECT] Throw PandasNotImplementedError for unsupported plotting functions

### What changes were proposed in this pull request?
Throw `PandasNotImplementedError` for unsupported plotting functions (a minimal sketch of the guard follows this list):
- {Frame, Series}.plot.hist
- {Frame, Series}.plot.kde
- {Frame, Series}.plot.density
- {Frame, Series}.plot(kind="hist", ...)
- {Frame, Series}.plot(kind="kde", ...)
- {Frame, Series}.plot(kind="density", ...)

### Why are the changes needed?
The previous error message is confusing:
```
In [3]: psdf.plot.hist()
/Users/ruifeng.zheng/Dev/spark/python/pyspark/pandas/utils.py:1017: PandasAPIOnSparkAdviceWarning: The config 'spark.sql.ansi.enabled' is set to True. This can cause unexpected behavior from pandas API on Spark since pandas API on Spark follows the behavior of pandas, not SQL.
  warnings.warn(message, PandasAPIOnSparkAdviceWarning)
[*********************************************-----------------------------------] 57.14% Complete (0 Tasks running, 1s, Scanned
---------------------------------------------------------------------------
PySparkAttributeError                     Traceback (most recent call last)
Cell In[3], line 1
----> 1 psdf.plot.hist()

File ~/Dev/spark/python/pyspark/pandas/plot/core.py:951, in PandasOnSparkPlotAccessor.hist(self, bins, **kwds)
    903 def hist(self, bins=10, **kwds):
    904     """
    905     Draw one histogram of the DataFrame’s columns.
    906     A `histogram`_ is a representation of the distribution of data.
   (...)
    949         >>> df.plot.hist(bins=12, alpha=0.5)  # doctest: +SKIP
    950     """
--> 951     return self(kind="hist", bins=bins, **kwds)

File ~/Dev/spark/python/pyspark/pandas/plot/core.py:580, in PandasOnSparkPlotAccessor.__call__(self, kind, backend, **kwargs)
    577 kind = {"density": "kde"}.get(kind, kind)
    578 if hasattr(plot_backend, "plot_pandas_on_spark"):
    579     # use if there's pandas-on-Spark specific method.
--> 580     return plot_backend.plot_pandas_on_spark(plot_data, kind=kind, **kwargs)
    581 else:
    582     # fallback to use pandas'
    583     if not PandasOnSparkPlotAccessor.pandas_plot_data_map[kind]:

File ~/Dev/spark/python/pyspark/pandas/plot/plotly.py:41, in plot_pandas_on_spark(data, kind, **kwargs)
     39     return plot_pie(data, **kwargs)
     40 if kind == "hist":
---> 41     return plot_histogram(data, **kwargs)
     42 if kind == "box":
     43     return plot_box(data, **kwargs)

File ~/Dev/spark/python/pyspark/pandas/plot/plotly.py:87, in plot_histogram(data, **kwargs)
     85 psdf, bins = HistogramPlotBase.prepare_hist_data(data, bins)
     86 assert len(bins) > 2, "the number of buckets must be higher than 2."
---> 87 output_series = HistogramPlotBase.compute_hist(psdf, bins)
     88 prev = float("%.9f" % bins[0])  # to make it prettier, truncate.
     89 text_bins = []

File ~/Dev/spark/python/pyspark/pandas/plot/core.py:189, in HistogramPlotBase.compute_hist(psdf, bins)
    183 for group_id, (colname, bucket_name) in enumerate(zip(colnames, bucket_names)):
    184     # creates a Bucketizer to get corresponding bin of each value
    185     bucketizer = Bucketizer(
    186         splits=bins, inputCol=colname, outputCol=bucket_name, handleInvalid="skip"
    187     )
--> 189     bucket_df = bucketizer.transform(sdf)
    191     if output_df is None:
    192         output_df = bucket_df.select(
    193             F.lit(group_id).alias("__group_id"), F.col(bucket_name).alias("__bucket")
    194         )

File ~/Dev/spark/python/pyspark/ml/base.py:260, in Transformer.transform(self, dataset, params)
    258         return self.copy(params)._transform(dataset)
    259     else:
--> 260         return self._transform(dataset)
    261 else:
    262     raise TypeError("Params must be a param map but got %s." % type(params))

File ~/Dev/spark/python/pyspark/ml/wrapper.py:412, in JavaTransformer._transform(self, dataset)
    409 assert self._java_obj is not None
    411 self._transfer_params_to_java()
--> 412 return DataFrame(self._java_obj.transform(dataset._jdf), dataset.sparkSession)

File ~/Dev/spark/python/pyspark/sql/connect/dataframe.py:1696, in DataFrame.__getattr__(self, name)
   1694 def __getattr__(self, name: str) -> "Column":
   1695     if name in ["_jseq", "_jdf", "_jmap", "_jcols", "rdd", "toJSON"]:
-> 1696         raise PySparkAttributeError(
   1697             error_class="JVM_ATTRIBUTE_NOT_SUPPORTED", message_parameters={"attr_name": name}
   1698         )
   1700     if name not in self.columns:
   1701         raise PySparkAttributeError(
   1702             error_class="ATTRIBUTE_NOT_SUPPORTED", message_parameters={"attr_name": name}
   1703         )

PySparkAttributeError: [JVM_ATTRIBUTE_NOT_SUPPORTED] Attribute `_jdf` is not supported in Spark Connect as it depends on the JVM. If you need to use this attribute, do not use Spark Connect when creating your session. Visit https://spark.apache.org/docs/latest/sql-getting-started.html#starting-point-sparksession for creating regular Spark Session in detail.
```

after this PR:
```
In [3]: psdf.plot.hist()
---------------------------------------------------------------------------
PandasNotImplementedError                 Traceback (most recent call last)
Cell In[3], line 1
----> 1 psdf.plot.hist()

File ~/Dev/spark/python/pyspark/pandas/plot/core.py:957, in PandasOnSparkPlotAccessor.hist(self, bins, **kwds)
    909 """
    910 Draw one histogram of the DataFrame’s columns.
    911 A `histogram`_ is a representation of the distribution of data.
   (...)
    954     >>> df.plot.hist(bins=12, alpha=0.5)  # doctest: +SKIP
    955 """
    956 if is_remote():
--> 957     return unsupported_function(class_name="pd.DataFrame", method_name="hist")()
    959 return self(kind="hist", bins=bins, **kwds)

File ~/Dev/spark/python/pyspark/pandas/missing/__init__.py:23, in unsupported_function.<locals>.unsupported_function(*args, **kwargs)
     22 def unsupported_function(*args, **kwargs):
---> 23     raise PandasNotImplementedError(
     24         class_name=class_name, method_name=method_name, reason=reason
     25     )

PandasNotImplementedError: The method `pd.DataFrame.hist()` is not implemented yet.
```
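
As a side note (not part of this PR), one possible workaround under Spark Connect is to collect the data and plot with plain pandas, assuming the frame is small enough to fit in the driver:
```
# Hypothetical workaround: collect to pandas and plot locally (small data only).
pdf = psdf.to_pandas()
pdf.plot.hist(bins=12)
```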

### Does this PR introduce _any_ user-facing change?
Yes, the error message for these plotting functions under Spark Connect is improved.

### How was this patch tested?
CI
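In addition to CI, the new test module added below can presumably be run locally with the standard dev helper, e.g. `python/run-tests --testnames pyspark.pandas.tests.connect.test_connect_plotting` (command assumed from the usual Spark development workflow).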

### Was this patch authored or co-authored using generative AI tooling?
No

Closes #46911 from zhengruifeng/ps_plotting_unsupported.

Authored-by: Ruifeng Zheng <[email protected]>
Signed-off-by: Ruifeng Zheng <[email protected]>
zhengruifeng committed Jun 7, 2024
1 parent b7d9c31 commit 87b0f59
Showing 4 changed files with 142 additions and 1 deletion.
2 changes: 2 additions & 0 deletions dev/sparktestsupport/modules.py
@@ -1102,6 +1102,8 @@ def __hash__(self):
"python/pyspark/pandas",
],
python_test_goals=[
# unittests dedicated for Spark Connect
"pyspark.pandas.tests.connect.test_connect_plotting",
# pandas-on-Spark unittests
"pyspark.pandas.tests.connect.test_parity_categorical",
"pyspark.pandas.tests.connect.test_parity_config",
13 changes: 12 additions & 1 deletion python/pyspark/pandas/plot/core.py
@@ -23,6 +23,7 @@
from pandas.core.dtypes.inference import is_integer

from pyspark.sql import functions as F
from pyspark.sql.utils import is_remote
from pyspark.pandas.missing import unsupported_function
from pyspark.pandas.config import get_option
from pyspark.pandas.utils import name_like_string
@@ -571,10 +572,14 @@ def _get_plot_backend(backend=None):
return module

def __call__(self, kind="line", backend=None, **kwargs):
kind = {"density": "kde"}.get(kind, kind)

if is_remote() and kind in ["hist", "kde"]:
return unsupported_function(class_name="pd.DataFrame", method_name=kind)()

plot_backend = PandasOnSparkPlotAccessor._get_plot_backend(backend)
plot_data = self.data

kind = {"density": "kde"}.get(kind, kind)
if hasattr(plot_backend, "plot_pandas_on_spark"):
# use if there's pandas-on-Spark specific method.
return plot_backend.plot_pandas_on_spark(plot_data, kind=kind, **kwargs)
@@ -948,6 +953,9 @@ def hist(self, bins=10, **kwds):
>>> df = ps.from_pandas(df)
>>> df.plot.hist(bins=12, alpha=0.5) # doctest: +SKIP
"""
if is_remote():
return unsupported_function(class_name="pd.DataFrame", method_name="hist")()

return self(kind="hist", bins=bins, **kwds)

def kde(self, bw_method=None, ind=None, **kwargs):
@@ -1023,6 +1031,9 @@ def kde(self, bw_method=None, ind=None, **kwargs):
... })
>>> df.plot.kde(ind=[1, 2, 3, 4, 5, 6], bw_method=0.3) # doctest: +SKIP
"""
if is_remote():
return unsupported_function(class_name="pd.DataFrame", method_name="kde")()

return self(kind="kde", bw_method=bw_method, ind=ind, **kwargs)

density = kde
@@ -24,6 +24,10 @@
class SeriesPlotMatplotlibParityTests(
SeriesPlotMatplotlibTestsMixin, PandasOnSparkTestUtils, TestUtils, ReusedConnectTestCase
):
@unittest.skip("Test depends on Spark ML which is not supported from Spark Connect.")
def test_empty_hist(self):
super().test_empty_hist()

@unittest.skip("Test depends on Spark ML which is not supported from Spark Connect.")
def test_hist(self):
super().test_hist()
124 changes: 124 additions & 0 deletions python/pyspark/pandas/tests/connect/test_connect_plotting.py
@@ -0,0 +1,124 @@
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import unittest

import pandas as pd

from pyspark import pandas as ps
from pyspark.pandas.exceptions import PandasNotImplementedError
from pyspark.testing.connectutils import ReusedConnectTestCase
from pyspark.testing.pandasutils import PandasOnSparkTestUtils, TestUtils


class ConnectPlottingTests(PandasOnSparkTestUtils, TestUtils, ReusedConnectTestCase):
@property
def pdf1(self):
return pd.DataFrame(
[[1, 2], [4, 5], [7, 8]],
index=["cobra", "viper", None],
columns=["max_speed", "shield"],
)

@property
def psdf1(self):
return ps.from_pandas(self.pdf1)

def test_unsupported_functions(self):
with self.assertRaises(PandasNotImplementedError):
self.psdf1.plot.hist()

with self.assertRaises(PandasNotImplementedError):
self.psdf1.plot.hist(bins=3)

with self.assertRaises(PandasNotImplementedError):
self.psdf1.plot.kde()

with self.assertRaises(PandasNotImplementedError):
self.psdf1.plot.kde(bw_method=3)

with self.assertRaises(PandasNotImplementedError):
self.psdf1.plot.density()

with self.assertRaises(PandasNotImplementedError):
self.psdf1.plot.density(bw_method=3)

with self.assertRaises(PandasNotImplementedError):
self.psdf1.shield.plot.hist()

with self.assertRaises(PandasNotImplementedError):
self.psdf1.shield.plot.hist(bins=3)

with self.assertRaises(PandasNotImplementedError):
self.psdf1.shield.plot.kde()

with self.assertRaises(PandasNotImplementedError):
self.psdf1.shield.plot.kde(bw_method=3)

with self.assertRaises(PandasNotImplementedError):
self.psdf1.shield.plot.density()

with self.assertRaises(PandasNotImplementedError):
self.psdf1.shield.plot.density(bw_method=3)

def test_unsupported_kinds(self):
with self.assertRaises(PandasNotImplementedError):
self.psdf1.plot(kind="hist")

with self.assertRaises(PandasNotImplementedError):
self.psdf1.plot(kind="hist", bins=3)

with self.assertRaises(PandasNotImplementedError):
self.psdf1.plot(kind="kde")

with self.assertRaises(PandasNotImplementedError):
self.psdf1.plot(kind="kde", bw_method=3)

with self.assertRaises(PandasNotImplementedError):
self.psdf1.plot(kind="density")

with self.assertRaises(PandasNotImplementedError):
self.psdf1.plot(kind="density", bw_method=3)

with self.assertRaises(PandasNotImplementedError):
self.psdf1.shield.plot(kind="hist")

with self.assertRaises(PandasNotImplementedError):
self.psdf1.shield.plot(kind="hist", bins=3)

with self.assertRaises(PandasNotImplementedError):
self.psdf1.shield.plot(kind="kde")

with self.assertRaises(PandasNotImplementedError):
self.psdf1.shield.plot(kind="kde", bw_method=3)

with self.assertRaises(PandasNotImplementedError):
self.psdf1.shield.plot(kind="density")

with self.assertRaises(PandasNotImplementedError):
self.psdf1.shield.plot(kind="density", bw_method=3)


if __name__ == "__main__":
from pyspark.pandas.tests.connect.test_connect_plotting import * # noqa: F401

try:
import xmlrunner # type: ignore[import]

testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
except ImportError:
testRunner = None
unittest.main(testRunner=testRunner, verbosity=2)
