diff --git a/python/pyspark/sql/plot/core.py b/python/pyspark/sql/plot/core.py index 392ef73b38845..ed22d02370ca6 100644 --- a/python/pyspark/sql/plot/core.py +++ b/python/pyspark/sql/plot/core.py @@ -75,6 +75,8 @@ def get_sampled(self, sdf: "DataFrame") -> "pd.DataFrame": class PySparkPlotAccessor: plot_data_map = { + "bar": PySparkTopNPlotBase().get_top_n, + "barh": PySparkTopNPlotBase().get_top_n, "line": PySparkSampledPlotBase().get_sampled, } _backends = {} # type: ignore[var-annotated] @@ -133,3 +135,80 @@ def line(self, x: str, y: Union[str, list[str]], **kwargs: Any) -> "Figure": >>> df.plot.line(x="category", y=["int_val", "float_val"]) # doctest: +SKIP """ return self(kind="line", x=x, y=y, **kwargs) + + def bar(self, x: str, y: Union[str, list[str]], **kwargs: Any) -> "Figure": + """ + Vertical bar plot. + + A bar plot is a plot that presents categorical data with rectangular bars with lengths + proportional to the values that they represent. A bar plot shows comparisons among + discrete categories. One axis of the plot shows the specific categories being compared, + and the other axis represents a measured value. + + Parameters + ---------- + x : str + Name of column to use for the horizontal axis. + y : str or list of str + Name(s) of the column(s) to use for the vertical axis. + Multiple columns can be plotted. + **kwargs : optional + Additional keyword arguments. + + Returns + ------- + :class:`plotly.graph_objs.Figure` + + Examples + -------- + >>> data = [("A", 10, 1.5), ("B", 30, 2.5), ("C", 20, 3.5)] + >>> columns = ["category", "int_val", "float_val"] + >>> df = spark.createDataFrame(data, columns) + >>> df.plot.bar(x="category", y="int_val") # doctest: +SKIP + >>> df.plot.bar(x="category", y=["int_val", "float_val"]) # doctest: +SKIP + """ + return self(kind="bar", x=x, y=y, **kwargs) + + def barh(self, x: str, y: Union[str, list[str]], **kwargs: Any) -> "Figure": + """ + Make a horizontal bar plot. + + A horizontal bar plot is a plot that presents quantitative data with + rectangular bars with lengths proportional to the values that they + represent. A bar plot shows comparisons among discrete categories. One + axis of the plot shows the specific categories being compared, and the + other axis represents a measured value. + + Parameters + ---------- + x : str or list of str + Name(s) of the column(s) to use for the horizontal axis. + Multiple columns can be plotted. + y : str or list of str + Name(s) of the column(s) to use for the vertical axis. + Multiple columns can be plotted. + **kwargs : optional + Additional keyword arguments. + + Returns + ------- + :class:`plotly.graph_objs.Figure` + + Notes + ----- + In Plotly and Matplotlib, the interpretation of `x` and `y` for `barh` plots differs. + In Plotly, `x` refers to the values and `y` refers to the categories. + In Matplotlib, `x` refers to the categories and `y` refers to the values. + Ensure correct axis labeling based on the backend used. + + Examples + -------- + >>> data = [("A", 10, 1.5), ("B", 30, 2.5), ("C", 20, 3.5)] + >>> columns = ["category", "int_val", "float_val"] + >>> df = spark.createDataFrame(data, columns) + >>> df.plot.barh(x="int_val", y="category") # doctest: +SKIP + >>> df.plot.barh( + ... x=["int_val", "float_val"], y="category" + ... ) # doctest: +SKIP + """ + return self(kind="barh", x=x, y=y, **kwargs) diff --git a/python/pyspark/sql/tests/plot/test_frame_plot_plotly.py b/python/pyspark/sql/tests/plot/test_frame_plot_plotly.py index 72a3ed267d192..1c52c93a23d3a 100644 --- a/python/pyspark/sql/tests/plot/test_frame_plot_plotly.py +++ b/python/pyspark/sql/tests/plot/test_frame_plot_plotly.py @@ -28,9 +28,16 @@ def sdf(self): columns = ["category", "int_val", "float_val"] return self.spark.createDataFrame(data, columns) - def _check_fig_data(self, fig_data, expected_x, expected_y, expected_name=""): - self.assertEqual(fig_data["mode"], "lines") - self.assertEqual(fig_data["type"], "scatter") + def _check_fig_data(self, kind, fig_data, expected_x, expected_y, expected_name=""): + if kind == "line": + self.assertEqual(fig_data["mode"], "lines") + self.assertEqual(fig_data["type"], "scatter") + elif kind == "bar": + self.assertEqual(fig_data["type"], "bar") + elif kind == "barh": + self.assertEqual(fig_data["type"], "bar") + self.assertEqual(fig_data["orientation"], "h") + self.assertEqual(fig_data["xaxis"], "x") self.assertEqual(list(fig_data["x"]), expected_x) self.assertEqual(fig_data["yaxis"], "y") @@ -40,12 +47,37 @@ def _check_fig_data(self, fig_data, expected_x, expected_y, expected_name=""): def test_line_plot(self): # single column as vertical axis fig = self.sdf.plot(kind="line", x="category", y="int_val") - self._check_fig_data(fig["data"][0], ["A", "B", "C"], [10, 30, 20]) + self._check_fig_data("line", fig["data"][0], ["A", "B", "C"], [10, 30, 20]) # multiple columns as vertical axis fig = self.sdf.plot.line(x="category", y=["int_val", "float_val"]) - self._check_fig_data(fig["data"][0], ["A", "B", "C"], [10, 30, 20], "int_val") - self._check_fig_data(fig["data"][1], ["A", "B", "C"], [1.5, 2.5, 3.5], "float_val") + self._check_fig_data("line", fig["data"][0], ["A", "B", "C"], [10, 30, 20], "int_val") + self._check_fig_data("line", fig["data"][1], ["A", "B", "C"], [1.5, 2.5, 3.5], "float_val") + + def test_bar_plot(self): + # single column as vertical axis + fig = self.sdf.plot(kind="bar", x="category", y="int_val") + self._check_fig_data("bar", fig["data"][0], ["A", "B", "C"], [10, 30, 20]) + + # multiple columns as vertical axis + fig = self.sdf.plot.bar(x="category", y=["int_val", "float_val"]) + self._check_fig_data("bar", fig["data"][0], ["A", "B", "C"], [10, 30, 20], "int_val") + self._check_fig_data("bar", fig["data"][1], ["A", "B", "C"], [1.5, 2.5, 3.5], "float_val") + + def test_barh_plot(self): + # single column as vertical axis + fig = self.sdf.plot(kind="barh", x="category", y="int_val") + self._check_fig_data("barh", fig["data"][0], ["A", "B", "C"], [10, 30, 20]) + + # multiple columns as vertical axis + fig = self.sdf.plot.barh(x="category", y=["int_val", "float_val"]) + self._check_fig_data("barh", fig["data"][0], ["A", "B", "C"], [10, 30, 20], "int_val") + self._check_fig_data("barh", fig["data"][1], ["A", "B", "C"], [1.5, 2.5, 3.5], "float_val") + + # multiple columns as horizontal axis + fig = self.sdf.plot.barh(x=["int_val", "float_val"], y="category") + self._check_fig_data("barh", fig["data"][0], [10, 30, 20], ["A", "B", "C"], "int_val") + self._check_fig_data("barh", fig["data"][1], [1.5, 2.5, 3.5], ["A", "B", "C"], "float_val") class DataFramePlotPlotlyTests(DataFramePlotPlotlyTestsMixin, ReusedSQLTestCase):