Skip to content

Commit

Permalink
[SPARK-49626][PYTHON][CONNECT] Support horizontal and vertical bar plots
Browse files Browse the repository at this point in the history
### What changes were proposed in this pull request?
Support horizontal and vertical bar plots with the plotly backend on both Spark Connect and Spark classic.

### Why are the changes needed?
While Pandas on Spark supports plotting, PySpark currently lacks this feature. The proposed API will enable users to generate visualizations. This will provide users with an intuitive, interactive way to explore and understand large datasets directly from PySpark DataFrames, streamlining the data analysis workflow in distributed environments.

For more details, see the in-progress [PySpark Plotting API Specification](https://docs.google.com/document/d/1IjOEzC8zcetG86WDvqkereQPj_NGLNW7Bdu910g30Dg/edit?usp=sharing).

Part of https://issues.apache.org/jira/browse/SPARK-49530.

### Does this PR introduce _any_ user-facing change?
Yes.

```python
>>> data = [("A", 10, 1.5), ("B", 30, 2.5), ("C", 20, 3.5)]
>>> columns = ["category", "int_val", "float_val"]
>>> sdf = spark.createDataFrame(data, columns)
>>> sdf.show()
+--------+-------+---------+
|category|int_val|float_val|
+--------+-------+---------+
|       A|     10|      1.5|
|       B|     30|      2.5|
|       C|     20|      3.5|
+--------+-------+---------+

>>> f = sdf.plot(kind="bar", x="category", y=["int_val", "float_val"])
>>> f.show()  # see below
>>> g = sdf.plot.barh(x=["int_val", "float_val"], y="category")
>>> g.show()  # see below
```
`f.show()`:
![newplot (4)](https://github.com/user-attachments/assets/0df9ee86-fb48-4796-b6c3-aaf2879217aa)

`g.show()`:
![newplot (3)](https://github.com/user-attachments/assets/f39b01c3-66e6-464b-b2e8-badebb39bc67)

### How was this patch tested?
Unit tests.

### Was this patch authored or co-authored using generative AI tooling?
No.

Closes #48100 from xinrong-meng/plot_bar.

Authored-by: Xinrong Meng <[email protected]>
Signed-off-by: Xinrong Meng <[email protected]>
  • Loading branch information
xinrong-meng committed Sep 23, 2024
1 parent d2e8c1c commit 44ec70f
Show file tree
Hide file tree
Showing 2 changed files with 117 additions and 6 deletions.
79 changes: 79 additions & 0 deletions python/pyspark/sql/plot/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,8 @@ def get_sampled(self, sdf: "DataFrame") -> "pd.DataFrame":

class PySparkPlotAccessor:
plot_data_map = {
"bar": PySparkTopNPlotBase().get_top_n,
"barh": PySparkTopNPlotBase().get_top_n,
"line": PySparkSampledPlotBase().get_sampled,
}
_backends = {} # type: ignore[var-annotated]
Expand Down Expand Up @@ -133,3 +135,80 @@ def line(self, x: str, y: Union[str, list[str]], **kwargs: Any) -> "Figure":
>>> df.plot.line(x="category", y=["int_val", "float_val"]) # doctest: +SKIP
"""
return self(kind="line", x=x, y=y, **kwargs)

def bar(self, x: str, y: Union[str, list[str]], **kwargs: Any) -> "Figure":
    """
    Vertical bar plot.

    A bar plot presents categorical data with rectangular bars whose lengths
    are proportional to the values they represent, which makes it suited to
    comparing discrete categories. One axis lists the categories being
    compared and the other axis carries the measured value.

    Parameters
    ----------
    x : str
        Name of column to use for the horizontal axis.
    y : str or list of str
        Name(s) of the column(s) to use for the vertical axis.
        Multiple columns can be plotted.
    **kwargs : optional
        Additional keyword arguments.

    Returns
    -------
    :class:`plotly.graph_objs.Figure`

    Examples
    --------
    >>> data = [("A", 10, 1.5), ("B", 30, 2.5), ("C", 20, 3.5)]
    >>> columns = ["category", "int_val", "float_val"]
    >>> df = spark.createDataFrame(data, columns)
    >>> df.plot.bar(x="category", y="int_val")  # doctest: +SKIP
    >>> df.plot.bar(x="category", y=["int_val", "float_val"])  # doctest: +SKIP
    """
    # Delegate to the accessor's own call entry point, tagging the request as "bar".
    return self(
        kind="bar",
        x=x,
        y=y,
        **kwargs,
    )

def barh(
    self, x: Union[str, list[str]], y: Union[str, list[str]], **kwargs: Any
) -> "Figure":
    """
    Make a horizontal bar plot.

    A horizontal bar plot is a plot that presents quantitative data with
    rectangular bars with lengths proportional to the values that they
    represent. A bar plot shows comparisons among discrete categories. One
    axis of the plot shows the specific categories being compared, and the
    other axis represents a measured value.

    Parameters
    ----------
    x : str or list of str
        Name(s) of the column(s) to use for the horizontal axis.
        Multiple columns can be plotted.
    y : str or list of str
        Name(s) of the column(s) to use for the vertical axis.
        Multiple columns can be plotted.
    **kwargs : optional
        Additional keyword arguments.

    Returns
    -------
    :class:`plotly.graph_objs.Figure`

    Notes
    -----
    In Plotly and Matplotlib, the interpretation of `x` and `y` for `barh` plots differs.
    In Plotly, `x` refers to the values and `y` refers to the categories.
    In Matplotlib, `x` refers to the categories and `y` refers to the values.
    Ensure correct axis labeling based on the backend used.

    Examples
    --------
    >>> data = [("A", 10, 1.5), ("B", 30, 2.5), ("C", 20, 3.5)]
    >>> columns = ["category", "int_val", "float_val"]
    >>> df = spark.createDataFrame(data, columns)
    >>> df.plot.barh(x="int_val", y="category")  # doctest: +SKIP
    >>> df.plot.barh(
    ...     x=["int_val", "float_val"], y="category"
    ... )  # doctest: +SKIP
    """
    # `x` is annotated as str-or-list so the signature agrees with the documented
    # contract and with the multi-column example above.
    return self(kind="barh", x=x, y=y, **kwargs)
44 changes: 38 additions & 6 deletions python/pyspark/sql/tests/plot/test_frame_plot_plotly.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,16 @@ def sdf(self):
columns = ["category", "int_val", "float_val"]
return self.spark.createDataFrame(data, columns)

def _check_fig_data(self, fig_data, expected_x, expected_y, expected_name=""):
self.assertEqual(fig_data["mode"], "lines")
self.assertEqual(fig_data["type"], "scatter")
def _check_fig_data(self, kind, fig_data, expected_x, expected_y, expected_name=""):
if kind == "line":
self.assertEqual(fig_data["mode"], "lines")
self.assertEqual(fig_data["type"], "scatter")
elif kind == "bar":
self.assertEqual(fig_data["type"], "bar")
elif kind == "barh":
self.assertEqual(fig_data["type"], "bar")
self.assertEqual(fig_data["orientation"], "h")

self.assertEqual(fig_data["xaxis"], "x")
self.assertEqual(list(fig_data["x"]), expected_x)
self.assertEqual(fig_data["yaxis"], "y")
Expand All @@ -40,12 +47,37 @@ def _check_fig_data(self, fig_data, expected_x, expected_y, expected_name=""):
def test_line_plot(self):
# single column as vertical axis
fig = self.sdf.plot(kind="line", x="category", y="int_val")
self._check_fig_data(fig["data"][0], ["A", "B", "C"], [10, 30, 20])
self._check_fig_data("line", fig["data"][0], ["A", "B", "C"], [10, 30, 20])

# multiple columns as vertical axis
fig = self.sdf.plot.line(x="category", y=["int_val", "float_val"])
self._check_fig_data(fig["data"][0], ["A", "B", "C"], [10, 30, 20], "int_val")
self._check_fig_data(fig["data"][1], ["A", "B", "C"], [1.5, 2.5, 3.5], "float_val")
self._check_fig_data("line", fig["data"][0], ["A", "B", "C"], [10, 30, 20], "int_val")
self._check_fig_data("line", fig["data"][1], ["A", "B", "C"], [1.5, 2.5, 3.5], "float_val")

def test_bar_plot(self):
    """Exercise vertical bar plots via ``plot(kind="bar")`` and ``plot.bar``."""
    # One value column on the vertical axis, requested through the generic call form.
    single = self.sdf.plot(kind="bar", x="category", y="int_val")
    self._check_fig_data("bar", single["data"][0], ["A", "B", "C"], [10, 30, 20])

    # Several value columns on the vertical axis: trace i is checked against
    # the i-th requested column, using the column name as the expected name.
    multi = self.sdf.plot.bar(x="category", y=["int_val", "float_val"])
    expected_traces = [
        ("int_val", [10, 30, 20]),
        ("float_val", [1.5, 2.5, 3.5]),
    ]
    for idx, (column, values) in enumerate(expected_traces):
        self._check_fig_data("bar", multi["data"][idx], ["A", "B", "C"], values, column)

def test_barh_plot(self):
    """Exercise horizontal bar plots via ``plot(kind="barh")`` and ``plot.barh``."""
    value_columns = [
        ("int_val", [10, 30, 20]),
        ("float_val", [1.5, 2.5, 3.5]),
    ]

    # Single column passed as `y`, requested through the generic call form.
    single = self.sdf.plot(kind="barh", x="category", y="int_val")
    self._check_fig_data("barh", single["data"][0], ["A", "B", "C"], [10, 30, 20])

    # Multiple columns passed as `y`: one trace per column, in request order.
    by_y = self.sdf.plot.barh(x="category", y=["int_val", "float_val"])
    for idx, (column, values) in enumerate(value_columns):
        self._check_fig_data("barh", by_y["data"][idx], ["A", "B", "C"], values, column)

    # Multiple columns passed as `x`: the values are now expected on the
    # x side of each trace and the categories on the y side.
    by_x = self.sdf.plot.barh(x=["int_val", "float_val"], y="category")
    for idx, (column, values) in enumerate(value_columns):
        self._check_fig_data("barh", by_x["data"][idx], values, ["A", "B", "C"], column)


class DataFramePlotPlotlyTests(DataFramePlotPlotlyTestsMixin, ReusedSQLTestCase):
Expand Down

0 comments on commit 44ec70f

Please sign in to comment.