[SPARK-42859][CONNECT][PS] Basic support for pandas API on Spark Connect
### What changes were proposed in this pull request?

This PR proposes basic support for the pandas API on Spark when running on Spark Connect. It includes the minimal changes needed for core functionality, and sets up a testing environment in `pyspark/pandas/tests/connect` that reuses all existing pandas-API-on-Spark test bases to exercise the API against a remote Spark session.

Here is a summary of the key tasks:
1. All pandas-on-Spark tests under `python/pyspark/pandas/tests/` can now be run against Spark Connect through corresponding tests added under `python/pyspark/pandas/tests/connect/`.
2. Unlike with Spark SQL, we did not create a separate package directory (such as `python/pyspark/sql/connect`) for Spark Connect, so I modified the existing `pyspark.pandas` files. This lets users run their existing pandas-on-Spark code as-is on Spark Connect.
3. Because of 2, I added two type aliases to `python/pyspark/pandas/_typing.py` so that a single code path can address both the PySpark and Spark Connect classes (see the sketch after this list):
   - Added `GenericColumn` for typing both PySpark Column and Spark Connect Column.
   - Added `GenericDataFrame` for typing both PySpark DataFrame and Spark Connect DataFrame.
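
   As a minimal sketch of what these aliases enable (the alias definition matches the `_typing.py` hunk in the diff below; the `scaled` helper is hypothetical, for illustration only):

   ```python
   from typing import Union

   from pyspark.sql.column import Column as PySparkColumn
   from pyspark.sql.connect.column import Column as ConnectColumn

   # The alias added to python/pyspark/pandas/_typing.py (see the diff below).
   GenericColumn = Union[PySparkColumn, ConnectColumn]


   def scaled(col: GenericColumn, factor: float) -> GenericColumn:
       # A hypothetical helper: one annotation covers both the classic Column
       # and the Spark Connect Column, so pandas-on-Spark internals need no
       # per-backend overloads.
       return col * factor
   ```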

### Why are the changes needed?

Supporting the pandas API on Spark Connect significantly improves usability for existing PySpark and pandas users.

### Does this PR introduce _any_ user-facing change?

No. The design lets code written for a regular Spark session run without user-facing changes, other than switching the regular Spark session to a remote one. However, since some features of the existing pandas API on Spark are not yet fully supported on Spark Connect, some functionality may be limited.
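
For example, a minimal sketch of the only change a user would make (the connection string below is a placeholder endpoint, not one mandated by this PR):

```python
from pyspark.sql import SparkSession
import pyspark.pandas as ps

# Before: a regular (classic) Spark session.
# spark = SparkSession.builder.getOrCreate()

# After: a remote Spark Connect session. "sc://localhost:15002" is a
# placeholder address for a running Spark Connect server.
spark = SparkSession.builder.remote("sc://localhost:15002").getOrCreate()

# Existing pandas-on-Spark code runs unchanged against the remote session.
psdf = ps.DataFrame({"a": [1, 2, 3], "b": [4.0, 5.0, 6.0]})
print(psdf.sum())
```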

### How was this patch tested?

A test bed has been set up that reproduces all existing pandas-on-Spark tests for Spark Connect, ensuring the existing suite can be replicated against a remote session.
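
Each parity test subclasses the corresponding existing test case and swaps in a remote (Spark Connect) session, roughly as in the sketch below (a hedged illustration; the exact base-class and utility names may differ from those in the PR):

```python
import unittest

# Reuse the existing pandas-on-Spark CSV tests, but run them through a
# remote session. The class names here are illustrative.
from pyspark.pandas.tests.test_csv import CsvTest
from pyspark.testing.connectutils import ReusedConnectTestCase


class CsvParityTests(CsvTest, ReusedConnectTestCase):
    """All inherited CSV tests execute against Spark Connect."""


if __name__ == "__main__":
    unittest.main()
```

The current results for all tests are as follows: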

| Test file                                           | Tests | Passed | Pass rate |
| --------------------------------------------------- | ----- | ------ | --------- |
| test_parity_dataframe.py                            | 105        | 85          | 80.95%   |
| test_parity_dataframe_slow.py                       | 66         | 48          | 72.73%   |
| test_parity_dataframe_conversion.py                 | 11         | 11          | 100.00%  |
| test_parity_dataframe_spark_io.py                   | 8          | 7           | 87.50%   |
| test_parity_ops_on_diff_frames.py                   | 75         | 75          | 100.00%  |
| test_parity_series.py                               | 131        | 104         | 79.39%   |
| test_parity_series_datetime.py                      | 41         | 34          | 82.93%   |
| test_parity_categorical.py                          | 29         | 22          | 75.86%   |
| test_parity_config.py                               | 7          | 7           | 100.00%  |
| test_parity_csv.py                                  | 18         | 18          | 100.00%  |
| test_parity_default_index.py                        | 4          | 1           | 25.00%   |
| test_parity_ewm.py                                  | 3          | 1           | 33.33%   |
| test_parity_expanding.py                            | 22         | 2           | 9.09%    |
| test_parity_extension.py                            | 7          | 7           | 100.00%  |
| test_parity_frame_spark.py                          | 6          | 2           | 33.33%   |
| test_parity_generic_functions.py                    | 4          | 1           | 25.00%   |
| test_parity_groupby.py                              | 49         | 36          | 73.47%   |
| test_parity_groupby_slow.py                         | 205        | 147         | 71.71%   |
| test_parity_indexing.py                             | 3          | 3           | 100.00%  |
| test_parity_indexops_spark.py                       | 3          | 3           | 100.00%  |
| test_parity_internal.py                             | 1          | 0           | 0.00%    |
| test_parity_namespace.py                            | 29         | 26          | 89.66%   |
| test_parity_numpy_compat.py                         | 6          | 4           | 66.67%   |
| test_parity_ops_on_diff_frames_groupby.py           | 22         | 13          | 59.09%   |
| test_parity_ops_on_diff_frames_groupby_expanding.py | 7          | 0           | 0.00%    |
| test_parity_ops_on_diff_frames_groupby_rolling.py   | 7          | 0           | 0.00%    |
| test_parity_ops_on_diff_frames_slow.py              | 22         | 15          | 68.18%   |
| test_parity_repr.py                                 | 5          | 5           | 100.00%  |
| test_parity_resample.py                             | 5          | 3           | 60.00%   |
| test_parity_reshape.py                              | 10         | 8           | 80.00%   |
| test_parity_rolling.py                              | 21         | 1           | 4.76%    |
| test_parity_scalars.py                              | 1          | 1           | 100.00%  |
| test_parity_series_conversion.py                    | 2          | 2           | 100.00%  |
| test_parity_series_string.py                        | 56         | 55          | 98.21%   |
| test_parity_spark_functions.py                      | 1          | 1           | 100.00%  |
| test_parity_sql.py                                  | 7          | 4           | 57.14%   |
| test_parity_stats.py                                | 15         | 7           | 46.67%   |
| test_parity_typedef.py                              | 10         | 10          | 100.00%  |
| test_parity_utils.py                                | 5          | 5           | 100.00%  |
| test_parity_window.py                               | 2          | 2           | 100.00%  |
| test_parity_frame_plot.py                           | 7          | 5           | 71.43%   |
| plot/test_parity_frame_plot_matplotlib.py           | 13         | 11          | 84.62%   |
| plot/test_parity_frame_plot_plotly.py               | 12         | 9           | 75.00%   |
| plot/test_parity_series_plot.py                     | 3          | 3           | 100.00%  |
| plot/test_parity_series_plot_matplotlib.py          | 14         | 8           | 57.14%   |
| plot/test_parity_series_plot_plotly.py              | 9          | 7           | 77.78%   |
| indexes/test_parity_base.py                         | 144        | 75          | 52.08%   |
| indexes/test_parity_category.py                     | 16         | 7           | 43.75%   |
| indexes/test_parity_datetime.py                     | 13         | 11          | 84.62%   |
| indexes/test_parity_timedelta.py                    | 2          | 1           | 50.00%   |
| data_type_ops/test_parity_base.py                   | 2          | 2           | 100.00%  |
| data_type_ops/test_parity_binary_ops.py             | 30         | 25          | 83.33%   |
| data_type_ops/test_parity_boolean_ops.py            | 31         | 26          | 83.87%   |
| data_type_ops/test_parity_categorical_ops.py        | 30         | 23          | 76.67%   |
| data_type_ops/test_parity_complex_ops.py            | 30         | 30          | 100.00%  |
| data_type_ops/test_parity_date_ops.py               | 30         | 25          | 83.33%   |
| data_type_ops/test_parity_datetime_ops.py           | 30         | 25          | 83.33%   |
| data_type_ops/test_parity_null_ops.py               | 26         | 19          | 73.08%   |
| data_type_ops/test_parity_num_ops.py                | 33         | 25          | 75.76%   |
| data_type_ops/test_parity_string_ops.py             | 30         | 23          | 76.67%   |
| data_type_ops/test_parity_timedelta_ops.py          | 26         | 19          | 73.08%   |
| data_type_ops/test_parity_udt_ops.py                | 26         | 18          | 69.23%   |
| Total                                               | 1588       | 1173        | 73.87%   |

Closes apache#40525 from itholic/initial_pandas_connect.

Authored-by: itholic <[email protected]>
Signed-off-by: Hyukjin Kwon <[email protected]>
itholic authored and HyukjinKwon committed Apr 8, 2023
1 parent 0542b94 commit f3edc0c
Showing 152 changed files with 4,919 additions and 389 deletions.
62 changes: 62 additions & 0 deletions dev/sparktestsupport/modules.py
@@ -778,6 +778,68 @@ def __hash__(self):
# ml unittests
"pyspark.ml.tests.connect.test_connect_function",
"pyspark.ml.tests.connect.test_parity_torch_distributor",
# pandas-on-Spark unittests
"pyspark.pandas.tests.connect.data_type_ops.test_parity_base",
"pyspark.pandas.tests.connect.data_type_ops.test_parity_binary_ops",
"pyspark.pandas.tests.connect.data_type_ops.test_parity_boolean_ops",
"pyspark.pandas.tests.connect.data_type_ops.test_parity_categorical_ops",
"pyspark.pandas.tests.connect.data_type_ops.test_parity_complex_ops",
"pyspark.pandas.tests.connect.data_type_ops.test_parity_date_ops",
"pyspark.pandas.tests.connect.data_type_ops.test_parity_datetime_ops",
"pyspark.pandas.tests.connect.data_type_ops.test_parity_null_ops",
"pyspark.pandas.tests.connect.data_type_ops.test_parity_num_ops",
"pyspark.pandas.tests.connect.data_type_ops.test_parity_string_ops",
"pyspark.pandas.tests.connect.data_type_ops.test_parity_udt_ops",
"pyspark.pandas.tests.connect.data_type_ops.test_parity_timedelta_ops",
"pyspark.pandas.tests.connect.indexes.test_parity_category",
"pyspark.pandas.tests.connect.indexes.test_parity_timedelta",
"pyspark.pandas.tests.connect.plot.test_parity_frame_plot",
"pyspark.pandas.tests.connect.plot.test_parity_frame_plot_matplotlib",
"pyspark.pandas.tests.connect.plot.test_parity_frame_plot_plotly",
"pyspark.pandas.tests.connect.plot.test_parity_series_plot",
"pyspark.pandas.tests.connect.plot.test_parity_series_plot_matplotlib",
"pyspark.pandas.tests.connect.plot.test_parity_series_plot_plotly",
"pyspark.pandas.tests.connect.test_parity_categorical",
"pyspark.pandas.tests.connect.test_parity_config",
"pyspark.pandas.tests.connect.test_parity_csv",
"pyspark.pandas.tests.connect.test_parity_dataframe_conversion",
"pyspark.pandas.tests.connect.test_parity_dataframe_spark_io",
"pyspark.pandas.tests.connect.test_parity_default_index",
"pyspark.pandas.tests.connect.test_parity_expanding",
"pyspark.pandas.tests.connect.test_parity_extension",
"pyspark.pandas.tests.connect.test_parity_frame_spark",
"pyspark.pandas.tests.connect.test_parity_generic_functions",
"pyspark.pandas.tests.connect.test_parity_indexops_spark",
"pyspark.pandas.tests.connect.test_parity_internal",
"pyspark.pandas.tests.connect.test_parity_namespace",
"pyspark.pandas.tests.connect.test_parity_numpy_compat",
"pyspark.pandas.tests.connect.test_parity_ops_on_diff_frames_groupby_expanding",
"pyspark.pandas.tests.connect.test_parity_ops_on_diff_frames_groupby_rolling",
"pyspark.pandas.tests.connect.test_parity_repr",
"pyspark.pandas.tests.connect.test_parity_resample",
"pyspark.pandas.tests.connect.test_parity_reshape",
"pyspark.pandas.tests.connect.test_parity_rolling",
"pyspark.pandas.tests.connect.test_parity_scalars",
"pyspark.pandas.tests.connect.test_parity_series_conversion",
"pyspark.pandas.tests.connect.test_parity_series_datetime",
"pyspark.pandas.tests.connect.test_parity_series_string",
"pyspark.pandas.tests.connect.test_parity_spark_functions",
"pyspark.pandas.tests.connect.test_parity_sql",
"pyspark.pandas.tests.connect.test_parity_typedef",
"pyspark.pandas.tests.connect.test_parity_utils",
"pyspark.pandas.tests.connect.test_parity_window",
"pyspark.pandas.tests.connect.indexes.test_parity_base",
"pyspark.pandas.tests.connect.indexes.test_parity_datetime",
"pyspark.pandas.tests.connect.test_parity_dataframe",
"pyspark.pandas.tests.connect.test_parity_dataframe_slow",
"pyspark.pandas.tests.connect.test_parity_groupby",
"pyspark.pandas.tests.connect.test_parity_groupby_slow",
"pyspark.pandas.tests.connect.test_parity_indexing",
"pyspark.pandas.tests.connect.test_parity_ops_on_diff_frames",
"pyspark.pandas.tests.connect.test_parity_ops_on_diff_frames_slow",
"pyspark.pandas.tests.connect.test_parity_ops_on_diff_frames_groupby",
"pyspark.pandas.tests.connect.test_parity_series",
"pyspark.pandas.tests.connect.test_parity_stats",
],
excluded_python_implementations=[
"PyPy" # Skip these tests under PyPy since they require numpy, pandas, and pyarrow and
1 change: 1 addition & 0 deletions dev/tox.ini
@@ -36,6 +36,7 @@ per-file-ignores =
python/pyspark/ml/tests/*.py: F403,
python/pyspark/mllib/tests/*.py: F403,
python/pyspark/pandas/tests/*.py: F401 F403,
python/pyspark/pandas/tests/connect/*.py: F401 F403,
python/pyspark/resource/tests/*.py: F403,
python/pyspark/sql/tests/*.py: F403,
python/pyspark/streaming/tests/*.py: F403,
10 changes: 10 additions & 0 deletions python/pyspark/pandas/_typing.py
@@ -21,6 +21,12 @@
import numpy as np
from pandas.api.extensions import ExtensionDtype

from pyspark.sql.column import Column as PySparkColumn
from pyspark.sql.connect.column import Column as ConnectColumn
from pyspark.sql.dataframe import DataFrame as PySparkDataFrame
from pyspark.sql.connect.dataframe import DataFrame as ConnectDataFrame


if TYPE_CHECKING:
from pyspark.pandas.base import IndexOpsMixin
from pyspark.pandas.frame import DataFrame
@@ -49,3 +55,7 @@

DataFrameOrSeries = Union["DataFrame", "Series"]
SeriesOrIndex = Union["Series", "Index"]

# For Spark Connect compatibility.
GenericColumn = Union[PySparkColumn, ConnectColumn]
GenericDataFrame = Union[PySparkDataFrame, ConnectDataFrame]
2 changes: 1 addition & 1 deletion python/pyspark/pandas/accessors.py
@@ -171,7 +171,7 @@ def attach_id_column(self, id_type: str, column: Name) -> "DataFrame":
for scol, label in zip(internal.data_spark_columns, internal.column_labels)
]
)
sdf = attach_func(sdf, name_like_string(column))
sdf = attach_func(sdf, name_like_string(column)) # type: ignore[assignment]

return DataFrame(
InternalFrame(
16 changes: 8 additions & 8 deletions python/pyspark/pandas/base.py
@@ -31,7 +31,7 @@
from pyspark.sql.types import LongType, BooleanType, NumericType

from pyspark import pandas as ps # For running doctests and reference resolution in PyCharm.
from pyspark.pandas._typing import Axis, Dtype, IndexOpsLike, Label, SeriesOrIndex
from pyspark.pandas._typing import Axis, Dtype, IndexOpsLike, Label, SeriesOrIndex, GenericColumn
from pyspark.pandas.config import get_option, option_context
from pyspark.pandas.internal import (
InternalField,
@@ -67,7 +67,7 @@ def should_alignment_for_column_op(self: SeriesOrIndex, other: SeriesOrIndex) ->


def align_diff_index_ops(
func: Callable[..., Column], this_index_ops: SeriesOrIndex, *args: Any
func: Callable[..., GenericColumn], this_index_ops: SeriesOrIndex, *args: Any
) -> SeriesOrIndex:
"""
Align the `IndexOpsMixin` objects and apply the function.
@@ -178,7 +178,7 @@ def align_diff_index_ops(
).rename(that_series.name)


def booleanize_null(scol: Column, f: Callable[..., Column]) -> Column:
def booleanize_null(scol: GenericColumn, f: Callable[..., GenericColumn]) -> GenericColumn:
"""
Booleanize Null in Spark Column
"""
@@ -190,12 +190,12 @@ def booleanize_null(scol: Column, f: Callable[..., Column]) -> Column:
if f in comp_ops:
# if `f` is "!=", fill null with True otherwise False
filler = f == Column.__ne__
scol = F.when(scol.isNull(), filler).otherwise(scol)
scol = F.when(scol.isNull(), filler).otherwise(scol) # type: ignore[arg-type]

return scol


def column_op(f: Callable[..., Column]) -> Callable[..., SeriesOrIndex]:
def column_op(f: Callable[..., GenericColumn]) -> Callable[..., SeriesOrIndex]:
"""
A decorator that wraps APIs taking/returning Spark Column so that pandas-on-Spark Series can be
supported too. If this decorator is used for the `f` function that takes Spark Column and
@@ -225,7 +225,7 @@ def wrapper(self: SeriesOrIndex, *args: Any) -> SeriesOrIndex:
)

field = InternalField.from_struct_field(
self._internal.spark_frame.select(scol).schema[0],
self._internal.spark_frame.select(scol).schema[0], # type: ignore[arg-type]
use_extension_dtypes=any(
isinstance(col.dtype, extension_dtypes) for col in [self] + cols
),
@@ -252,7 +252,7 @@ def wrapper(self: SeriesOrIndex, *args: Any) -> SeriesOrIndex:
return wrapper


def numpy_column_op(f: Callable[..., Column]) -> Callable[..., SeriesOrIndex]:
def numpy_column_op(f: Callable[..., GenericColumn]) -> Callable[..., SeriesOrIndex]:
@wraps(f)
def wrapper(self: SeriesOrIndex, *args: Any) -> SeriesOrIndex:
# PySpark does not support NumPy type out of the box. For now, we convert NumPy types
@@ -287,7 +287,7 @@ def _psdf(self) -> DataFrame:

@abstractmethod
def _with_new_scol(
self: IndexOpsLike, scol: Column, *, field: Optional[InternalField] = None
self: IndexOpsLike, scol: GenericColumn, *, field: Optional[InternalField] = None
) -> IndexOpsLike:
pass

6 changes: 4 additions & 2 deletions python/pyspark/pandas/config.py
@@ -365,8 +365,9 @@ def get_option(key: str, default: Union[Any, _NoValueType] = _NoValue) -> Any:
if default is _NoValue:
default = _options_dict[key].default
_options_dict[key].validate(default)
spark_session = default_session()

return json.loads(default_session().conf.get(_key_format(key), default=json.dumps(default)))
return json.loads(spark_session.conf.get(_key_format(key), default=json.dumps(default)))


def set_option(key: str, value: Any) -> None:
@@ -386,8 +387,9 @@ def set_option(key: str, value: Any) -> None:
"""
_check_option(key)
_options_dict[key].validate(value)
spark_session = default_session()

default_session().conf.set(_key_format(key), json.dumps(value))
spark_session.conf.set(_key_format(key), json.dumps(value))


def reset_option(key: str) -> None:
16 changes: 11 additions & 5 deletions python/pyspark/pandas/data_type_ops/base.py
@@ -18,13 +18,13 @@
import numbers
from abc import ABCMeta
from itertools import chain
from typing import Any, Optional, Union
from typing import cast, Callable, Any, Optional, Union

import numpy as np
import pandas as pd
from pandas.api.types import CategoricalDtype

from pyspark.sql import functions as F, Column
from pyspark.sql import functions as F, Column as PySparkColumn
from pyspark.sql.types import (
ArrayType,
BinaryType,
@@ -44,7 +44,7 @@
TimestampNTZType,
UserDefinedType,
)
from pyspark.pandas._typing import Dtype, IndexOpsLike, SeriesOrIndex
from pyspark.pandas._typing import Dtype, IndexOpsLike, SeriesOrIndex, GenericColumn
from pyspark.pandas.typedef import extension_dtypes
from pyspark.pandas.typedef.typehints import (
extension_dtypes_available,
@@ -53,6 +53,10 @@
spark_type_to_pandas_dtype,
)

# For supporting Spark Connect
from pyspark.sql.connect.column import Column as ConnectColumn
from pyspark.sql.utils import is_remote

if extension_dtypes_available:
from pandas import Int8Dtype, Int16Dtype, Int32Dtype, Int64Dtype

@@ -470,14 +474,16 @@ def eq(self, left: IndexOpsLike, right: Any) -> SeriesOrIndex:
else:
from pyspark.pandas.base import column_op

return column_op(Column.__eq__)(left, right)
Column = ConnectColumn if is_remote() else PySparkColumn
return column_op(cast(Callable[..., GenericColumn], Column.__eq__))(left, right)

def ne(self, left: IndexOpsLike, right: Any) -> SeriesOrIndex:
from pyspark.pandas.base import column_op

_sanitize_list_like(right)

return column_op(Column.__ne__)(left, right)
Column = ConnectColumn if is_remote() else PySparkColumn
return column_op(cast(Callable[..., GenericColumn], Column.__ne__))(left, right)

def invert(self, operand: IndexOpsLike) -> IndexOpsLike:
raise TypeError("Unary ~ can not be applied to %s." % self.pretty_name)
16 changes: 8 additions & 8 deletions python/pyspark/pandas/data_type_ops/binary_ops.py
@@ -15,12 +15,12 @@
# limitations under the License.
#

from typing import Any, Union, cast
from typing import Any, Union, cast, Callable

from pandas.api.types import CategoricalDtype

from pyspark.pandas.base import column_op, IndexOpsMixin
from pyspark.pandas._typing import Dtype, IndexOpsLike, SeriesOrIndex
from pyspark.pandas._typing import Dtype, IndexOpsLike, SeriesOrIndex, GenericColumn
from pyspark.pandas.data_type_ops.base import (
DataTypeOps,
_as_categorical_type,
@@ -46,9 +46,9 @@ def add(self, left: IndexOpsLike, right: Any) -> SeriesOrIndex:
_sanitize_list_like(right)

if isinstance(right, IndexOpsMixin) and isinstance(right.spark.data_type, BinaryType):
return column_op(F.concat)(left, right)
return column_op(cast(Callable[..., GenericColumn], F.concat))(left, right)
elif isinstance(right, bytes):
return column_op(F.concat)(left, F.lit(right))
return column_op(cast(Callable[..., GenericColumn], F.concat))(left, F.lit(right))
else:
raise TypeError(
"Concatenation can not be applied to %s and the given type." % self.pretty_name
@@ -71,26 +71,26 @@ def lt(self, left: IndexOpsLike, right: Any) -> SeriesOrIndex:

_sanitize_list_like(right)

return column_op(Column.__lt__)(left, right)
return column_op(cast(Callable[..., GenericColumn], Column.__lt__))(left, right)

def le(self, left: IndexOpsLike, right: Any) -> SeriesOrIndex:
from pyspark.pandas.base import column_op

_sanitize_list_like(right)

return column_op(Column.__le__)(left, right)
return column_op(cast(Callable[..., GenericColumn], Column.__le__))(left, right)

def ge(self, left: IndexOpsLike, right: Any) -> SeriesOrIndex:
from pyspark.pandas.base import column_op

_sanitize_list_like(right)
return column_op(Column.__ge__)(left, right)
return column_op(cast(Callable[..., GenericColumn], Column.__ge__))(left, right)

def gt(self, left: IndexOpsLike, right: Any) -> SeriesOrIndex:
from pyspark.pandas.base import column_op

_sanitize_list_like(right)
return column_op(Column.__gt__)(left, right)
return column_op(cast(Callable[..., GenericColumn], Column.__gt__))(left, right)

def astype(self, index_ops: IndexOpsLike, dtype: Union[str, type, Dtype]) -> IndexOpsLike:
dtype, spark_type = pandas_on_spark_type(dtype)