# [SPARK-41255][CONNECT] Rename RemoteSparkSession
### What changes were proposed in this pull request?
For better source compatibility, this PR renames `RemoteSparkSession` to `SparkSession` and adopts the same builder pattern as the existing PySpark `SparkSession`. Communication with the gRPC endpoint stays in `client.py` (as `SparkConnectClient`), whereas the public-facing Spark Session functionality is implemented by `SparkSession` in `session.py`.
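
Schematically, the split looks like the sketch below. `SparkSession`, `SparkConnectClient`, and the session's `client` attribute come from this change; the builder internals shown here are illustrative, not the exact source.

```python
# Rough sketch of the new session.py (illustrative; not the exact source).
from typing import Optional

from pyspark.sql.connect.client import SparkConnectClient


class SparkSession:
    class Builder:
        """Mirrors the builder pattern of the classic pyspark.sql.SparkSession."""

        def __init__(self) -> None:
            self._connection_string: Optional[str] = None

        def remote(self, location: str = "sc://localhost") -> "SparkSession.Builder":
            # Spark Connect connection string, e.g. "sc://endpoint/;config=abc".
            self._connection_string = location
            return self

        def getOrCreate(self) -> "SparkSession":
            return SparkSession(self._connection_string or "sc://localhost")

    builder = Builder()

    def __init__(self, connection_string: str):
        # All gRPC traffic is delegated to the client kept in client.py.
        self.client = SparkConnectClient(connection_string)
```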

The new class does not yet support the full behavior of the existing SparkSession. To connect to Spark Connect with the new code, use one of the following:

```python
# Connection to a remote endpoint
SparkSession.builder.remote("sc://endpoint/;config=abc").getOrCreate()
```

or

```python
# Local connection to a locally running server
SparkSession.builder.remote().getOrCreate()
```

or

```python
SparkSession.builder.conf("spark.connect.location", "sc://endpoint").getOrCreate()
```
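
Whichever form is used, the returned session is meant to be driven like a regular PySpark session for the parts of the API that are implemented. A hypothetical round trip, assuming a reachable local Spark Connect server and that `range` moved from the client onto the session as described above:

```python
from pyspark.sql.connect.session import SparkSession

spark = SparkSession.builder.remote("sc://localhost").getOrCreate()

# The plan is built locally and only executed on the server when collected.
df = spark.range(1, 10, step=2)
print(df.toPandas())
```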

### Why are the changes needed?
Compatibility.

### Does this PR introduce _any_ user-facing change?
No

### How was this patch tested?
Unit tests.

Closes apache#38792 from grundprinzip/SPARK-41255.

Lead-authored-by: Martin Grund <[email protected]>
Co-authored-by: Martin Grund <[email protected]>
Signed-off-by: Herman van Hovell <[email protected]>
2 people authored and hvanhovell committed Nov 25, 2022
1 parent da71626 commit 77e2d45
Showing 10 changed files with 340 additions and 128 deletions.
57 changes: 5 additions & 52 deletions python/pyspark/sql/connect/client.py
@@ -15,23 +15,19 @@
# limitations under the License.
#


import logging
import os
import urllib.parse
import uuid
+from typing import Iterable, Optional, Any, Union, List, Tuple, Dict

import grpc # type: ignore
-import pyarrow as pa
import pandas
+import pyarrow as pa

import pyspark.sql.connect.proto as pb2
import pyspark.sql.connect.proto.base_pb2_grpc as grpc_lib
import pyspark.sql.types
from pyspark import cloudpickle
from pyspark.sql.connect.dataframe import DataFrame
-from pyspark.sql.connect.readwriter import DataFrameReader
-from pyspark.sql.connect.plan import SQL, Range
from pyspark.sql.types import (
DataType,
ByteType,
@@ -56,10 +52,6 @@
NullType,
)

-from typing import Iterable, Optional, Any, Union, List, Tuple, Dict

-logging.basicConfig(level=logging.INFO)


class ChannelBuilder:
"""
@@ -294,12 +286,12 @@ def fromProto(cls, pb: Any) -> "AnalyzeResult":
)


-class RemoteSparkSession(object):
+class SparkConnectClient(object):
"""Conceptually the remote spark session that communicates with the server"""

-def __init__(self, connectionString: str = "sc://localhost", userId: Optional[str] = None):
+def __init__(self, connectionString: str, userId: Optional[str] = None):
"""
-Creates a new RemoteSparkSession for the Spark Connect interface.
+Creates a new SparkSession for the Spark Connect interface.
Parameters
----------
@@ -325,9 +317,6 @@ def __init__(self, connectionString: str = "sc://localhost", userId: Optional[st
self._channel = self._builder.toChannel()
self._stub = grpc_lib.SparkConnectServiceStub(self._channel)

-# Create the reader
-self.read = DataFrameReader(self)

def register_udf(
self, function: Any, return_type: Union[str, pyspark.sql.types.DataType]
) -> str:
@@ -355,42 +344,6 @@ def _build_metrics(self, metrics: "pb2.ExecutePlanResponse.Metrics") -> List[Pla
for x in metrics.metrics
]

-def sql(self, sql_string: str) -> "DataFrame":
-return DataFrame.withPlan(SQL(sql_string), self)
-
-def range(
-self,
-start: int,
-end: int,
-step: int = 1,
-numPartitions: Optional[int] = None,
-) -> DataFrame:
-"""
-Create a :class:`DataFrame` with column named ``id`` and typed Long,
-containing elements in a range from ``start`` to ``end`` (exclusive) with
-step value ``step``.
-.. versionadded:: 3.4.0
-Parameters
-----------
-start : int
-the start value
-end : int
-the end value (exclusive)
-step : int, optional
-the incremental step (default: 1)
-numPartitions : int, optional
-the number of partitions of the DataFrame
-Returns
--------
-:class:`DataFrame`
-"""
-return DataFrame.withPlan(
-Range(start=start, end=end, step=step, num_partitions=numPartitions), self
-)

def _to_pandas(self, plan: pb2.Plan) -> "pandas.DataFrame":
req = self._execute_plan_request_with_metadata()
req.plan.CopyFrom(plan)
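
The `sql` and `range` methods removed here do not disappear: per the description they belong on the public session. A sketch of the assumed shape in `session.py` (the PR's actual session file is not shown in this excerpt):

```python
# Sketch: the methods removed from client.py, assumed relocated onto the new
# SparkSession; the plan now receives the session instead of the raw client.
from typing import Optional

from pyspark.sql.connect.dataframe import DataFrame
from pyspark.sql.connect.plan import SQL, Range


class SparkSession:
    def sql(self, sql_string: str) -> DataFrame:
        return DataFrame.withPlan(SQL(sql_string), self)

    def range(
        self,
        start: int,
        end: int,
        step: int = 1,
        numPartitions: Optional[int] = None,
    ) -> DataFrame:
        return DataFrame.withPlan(
            Range(start=start, end=end, step=step, num_partitions=numPartitions), self
        )
```
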
16 changes: 8 additions & 8 deletions python/pyspark/sql/connect/column.py
@@ -24,7 +24,7 @@
import pyspark.sql.connect.proto as proto

if TYPE_CHECKING:
-from pyspark.sql.connect.client import RemoteSparkSession
+from pyspark.sql.connect.client import SparkConnectClient
import pyspark.sql.connect.proto as proto


@@ -80,7 +80,7 @@ def __eq__(self, other: Any) -> "Expression": # type: ignore[override]
def __init__(self) -> None:
pass

-def to_plan(self, session: "RemoteSparkSession") -> "proto.Expression":
+def to_plan(self, session: "SparkConnectClient") -> "proto.Expression":
...

def __str__(self) -> str:
@@ -131,7 +131,7 @@ def __init__(self, parent: Expression, alias: list[str], metadata: Any):
self._metadata = metadata
self._parent = parent

-def to_plan(self, session: "RemoteSparkSession") -> "proto.Expression":
+def to_plan(self, session: "SparkConnectClient") -> "proto.Expression":
if len(self._alias) == 1:
exp = proto.Expression()
exp.alias.name.append(self._alias[0])
@@ -162,7 +162,7 @@ def __init__(self, value: Any) -> None:
super().__init__()
self._value = value

-def to_plan(self, session: "RemoteSparkSession") -> "proto.Expression":
+def to_plan(self, session: "SparkConnectClient") -> "proto.Expression":
"""Converts the literal expression to the literal in proto.
TODO(SPARK-40533) This method always assumes the largest type and can thus
@@ -250,7 +250,7 @@ def name(self) -> str:
"""Returns the qualified name of the column reference."""
return self._unparsed_identifier

-def to_plan(self, session: "RemoteSparkSession") -> proto.Expression:
+def to_plan(self, session: "SparkConnectClient") -> proto.Expression:
"""Returns the Proto representation of the expression."""
expr = proto.Expression()
expr.unresolved_attribute.unparsed_identifier = self._unparsed_identifier
@@ -275,7 +275,7 @@ def __init__(self, expr: str) -> None:
super().__init__()
self._expr: str = expr

-def to_plan(self, session: "RemoteSparkSession") -> proto.Expression:
+def to_plan(self, session: "SparkConnectClient") -> proto.Expression:
"""Returns the Proto representation of the SQL expression."""
expr = proto.Expression()
expr.expression_string.expression = self._expr
@@ -292,7 +292,7 @@ def __init__(self, col: Column, ascending: bool = True, nullsLast: bool = True)
def __str__(self) -> str:
return str(self.ref) + " ASC" if self.ascending else " DESC"

-def to_plan(self, session: "RemoteSparkSession") -> proto.Expression:
+def to_plan(self, session: "SparkConnectClient") -> proto.Expression:
return self.ref.to_plan(session)


@@ -306,7 +306,7 @@ def __init__(
self._args = args
self._op = op

-def to_plan(self, session: "RemoteSparkSession") -> proto.Expression:
+def to_plan(self, session: "SparkConnectClient") -> proto.Expression:
fun = proto.Expression()
fun.unresolved_function.parts.append(self._op)
fun.unresolved_function.arguments.extend([x.to_plan(session) for x in self._args])
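
Every change in column.py is a type annotation: expressions are serialized to proto against the RPC client rather than the old remote session. The pattern those `to_plan` methods implement, distilled into a standalone sketch:

```python
# Sketch: an unresolved column reference rendered as a proto expression.
# Mirrors the column-reference to_plan above; the client argument is unused
# here, just as in the original method.
import pyspark.sql.connect.proto as proto
from pyspark.sql.connect.client import SparkConnectClient


def column_to_proto(name: str, client: SparkConnectClient) -> proto.Expression:
    expr = proto.Expression()
    expr.unresolved_attribute.unparsed_identifier = name
    return expr
```
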
52 changes: 26 additions & 26 deletions python/pyspark/sql/connect/dataframe.py
@@ -46,7 +46,7 @@

if TYPE_CHECKING:
from pyspark.sql.connect._typing import ColumnOrName, ExpressionOrString, LiteralType
-from pyspark.sql.connect.client import RemoteSparkSession
+from pyspark.sql.connect.session import SparkSession


class GroupedData(object):
@@ -97,20 +97,20 @@ class DataFrame(object):

def __init__(
self,
session: "RemoteSparkSession",
session: "SparkSession",
data: Optional[List[Any]] = None,
schema: Optional[StructType] = None,
):
"""Creates a new data frame"""
self._schema = schema
self._plan: Optional[plan.LogicalPlan] = None
self._session: "RemoteSparkSession" = session
self._session: "SparkSession" = session

def __repr__(self) -> str:
return "DataFrame[%s]" % (", ".join("%s: %s" % c for c in self.dtypes))

@classmethod
-def withPlan(cls, plan: plan.LogicalPlan, session: "RemoteSparkSession") -> "DataFrame":
+def withPlan(cls, plan: plan.LogicalPlan, session: "SparkSession") -> "DataFrame":
"""Main initialization method used to construct a new data frame with a child plan."""
new_frame = DataFrame(session=session)
new_frame._plan = plan
@@ -197,14 +197,14 @@ def columns(self) -> List[str]:

return self.schema.names

def sparkSession(self) -> "RemoteSparkSession":
def sparkSession(self) -> "SparkSession":
"""Returns Spark session that created this :class:`DataFrame`.
.. versionadded:: 3.4.0
Returns
-------
-:class:`RemoteSparkSession`
+:class:`SparkSession`
"""
return self._session

@@ -796,8 +796,8 @@ def toPandas(self) -> "pandas.DataFrame":
raise Exception("Cannot collect on empty plan.")
if self._session is None:
raise Exception("Cannot collect on empty session.")
-query = self._plan.to_proto(self._session)
-return self._session._to_pandas(query)
+query = self._plan.to_proto(self._session.client)
+return self._session.client._to_pandas(query)

@property
def schema(self) -> StructType:
@@ -811,10 +811,10 @@ def schema(self) -> StructType:
"""
if self._schema is None:
if self._plan is not None:
-query = self._plan.to_proto(self._session)
+query = self._plan.to_proto(self._session.client)
if self._session is None:
raise Exception("Cannot analyze without RemoteSparkSession.")
self._schema = self._session.schema(query)
raise Exception("Cannot analyze without SparkSession.")
self._schema = self._session.client.schema(query)
return self._schema
else:
raise Exception("Empty plan.")
@@ -834,8 +834,8 @@ def isLocal(self) -> bool:
"""
if self._plan is None:
raise Exception("Cannot analyze on empty plan.")
-query = self._plan.to_proto(self._session)
-return self._session._analyze(query).is_local
+query = self._plan.to_proto(self._session.client)
+return self._session.client._analyze(query).is_local

@property
def isStreaming(self) -> bool:
@@ -859,14 +859,14 @@ def isStreaming(self) -> bool:
"""
if self._plan is None:
raise Exception("Cannot analyze on empty plan.")
-query = self._plan.to_proto(self._session)
-return self._session._analyze(query).is_streaming
+query = self._plan.to_proto(self._session.client)
+return self._session.client._analyze(query).is_streaming

def _tree_string(self) -> str:
if self._plan is None:
raise Exception("Cannot analyze on empty plan.")
-query = self._plan.to_proto(self._session)
-return self._session._analyze(query).tree_string
+query = self._plan.to_proto(self._session.client)
+return self._session.client._analyze(query).tree_string

def printSchema(self) -> None:
"""Prints out the schema in the tree format.
Expand Down Expand Up @@ -895,8 +895,8 @@ def inputFiles(self) -> List[str]:
"""
if self._plan is None:
raise Exception("Cannot analyze on empty plan.")
-query = self._plan.to_proto(self._session)
-return self._session._analyze(query).input_files
+query = self._plan.to_proto(self._session.client)
+return self._session.client._analyze(query).input_files

def transform(self, func: Callable[..., "DataFrame"], *args: Any, **kwargs: Any) -> "DataFrame":
"""Returns a new :class:`DataFrame`. Concise syntax for chaining custom transformations.
@@ -1011,10 +1011,10 @@ def explain(
explain_mode = cast(str, extended)

if self._plan is not None:
-query = self._plan.to_proto(self._session)
+query = self._plan.to_proto(self._session.client)
if self._session is None:
raise Exception("Cannot analyze without RemoteSparkSession.")
return self._session.explain_string(query, explain_mode)
raise Exception("Cannot analyze without SparkSession.")
return self._session.client.explain_string(query, explain_mode)
else:
return ""

@@ -1032,8 +1032,8 @@ def createGlobalTempView(self, name: str) -> None:
"""
command = plan.CreateView(
child=self._plan, name=name, is_global=True, replace=False
-).command(session=self._session)
-self._session.execute_command(command)
+).command(session=self._session.client)
+self._session.client.execute_command(command)

def createOrReplaceGlobalTempView(self, name: str) -> None:
"""Creates or replaces a global temporary view using the given name.
@@ -1049,8 +1049,8 @@ def createOrReplaceGlobalTempView(self, name: str) -> None:
"""
command = plan.CreateView(
child=self._plan, name=name, is_global=True, replace=True
-).command(session=self._session)
-self._session.execute_command(command)
+).command(session=self._session.client)
+self._session.client.execute_command(command)

def rdd(self, *args: Any, **kwargs: Any) -> None:
raise NotImplementedError("RDD Support for Spark Connect is not implemented.")
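
All of the dataframe.py hunks are one mechanical change: `DataFrame` still holds the user-facing `SparkSession`, but analysis and execution now go through its `client` attribute. The recurring shape, schematically:

```python
# Sketch of the delegation pattern: logical plan -> proto (built against the
# client) -> RPC call on the client.
def analyze_dataframe(df: "DataFrame") -> "AnalyzeResult":
    if df._plan is None:
        raise Exception("Cannot analyze on empty plan.")
    query = df._plan.to_proto(df._session.client)
    return df._session.client._analyze(query)
```
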
4 changes: 2 additions & 2 deletions python/pyspark/sql/connect/function_builder.py
@@ -34,7 +34,7 @@
FunctionBuilderCallable,
UserDefinedFunctionCallable,
)
-from pyspark.sql.connect.client import RemoteSparkSession
+from pyspark.sql.connect.client import SparkConnectClient


def _build(name: str, *args: "ExpressionOrString") -> ScalarFunctionExpression:
@@ -91,7 +91,7 @@ def __init__(
self._args = []
self._func_name = None

-def to_plan(self, session: "RemoteSparkSession") -> proto.Expression:
+def to_plan(self, session: "SparkConnectClient") -> proto.Expression:
if session is None:
raise Exception("CAnnot create UDF without remote Session.")
# Needs to materialize the UDF to the server
