From 77e2d453b14eca1ab6740e6c532394fc908050f4 Mon Sep 17 00:00:00 2001 From: Martin Grund Date: Fri, 25 Nov 2022 18:12:37 -0400 Subject: [PATCH] [SPARK-41255][CONNECT] Rename RemoteSparkSession ### What changes were proposed in this pull request? For better source compatibility, this PR changes the type name of RemoteSparkSession to SparkSession and follows the same builder pattern. The communication with the GRPC endpoint is kept in `client.py` whereas the public facing Spark Session related functionality is implemented in `SparkSession` in `session.py`. The new class does not support the full behavior of the existing Spark Session. To connect to Spark Connect using the new code use the following example: ``` # Connection to a remote endpoint SparkSession.builder.remote("sc://endpoint/;config=abc").getOrCreate() ``` or ``` # Local connection to a locally running server SparkSession.builder.remote().getOrCreate() ``` or ``` SparkSession.builder.conf("spark.connect.location", "sc://endpoint").getOrCreate() ``` ### Why are the changes needed? Compatibility. ### Does this PR introduce _any_ user-facing change? No ### How was this patch tested? UT Closes #38792 from grundprinzip/SPARK-41255. Lead-authored-by: Martin Grund Co-authored-by: Martin Grund Signed-off-by: Herman van Hovell --- python/pyspark/sql/connect/client.py | 57 +--- python/pyspark/sql/connect/column.py | 16 +- python/pyspark/sql/connect/dataframe.py | 52 ++-- .../pyspark/sql/connect/function_builder.py | 4 +- python/pyspark/sql/connect/plan.py | 64 ++--- python/pyspark/sql/connect/readwriter.py | 4 +- python/pyspark/sql/connect/session.py | 258 ++++++++++++++++++ .../sql/tests/connect/test_connect_basic.py | 5 +- .../test_connect_column_expressions.py | 2 +- python/pyspark/testing/connectutils.py | 6 +- 10 files changed, 340 insertions(+), 128 deletions(-) create mode 100644 python/pyspark/sql/connect/session.py diff --git a/python/pyspark/sql/connect/client.py b/python/pyspark/sql/connect/client.py index b41df12c357c5..a2a0797c49fa2 100644 --- a/python/pyspark/sql/connect/client.py +++ b/python/pyspark/sql/connect/client.py @@ -15,23 +15,19 @@ # limitations under the License. # - -import logging import os import urllib.parse import uuid +from typing import Iterable, Optional, Any, Union, List, Tuple, Dict import grpc # type: ignore -import pyarrow as pa import pandas +import pyarrow as pa import pyspark.sql.connect.proto as pb2 import pyspark.sql.connect.proto.base_pb2_grpc as grpc_lib import pyspark.sql.types from pyspark import cloudpickle -from pyspark.sql.connect.dataframe import DataFrame -from pyspark.sql.connect.readwriter import DataFrameReader -from pyspark.sql.connect.plan import SQL, Range from pyspark.sql.types import ( DataType, ByteType, @@ -56,10 +52,6 @@ NullType, ) -from typing import Iterable, Optional, Any, Union, List, Tuple, Dict - -logging.basicConfig(level=logging.INFO) - class ChannelBuilder: """ @@ -294,12 +286,12 @@ def fromProto(cls, pb: Any) -> "AnalyzeResult": ) -class RemoteSparkSession(object): +class SparkConnectClient(object): """Conceptually the remote spark session that communicates with the server""" - def __init__(self, connectionString: str = "sc://localhost", userId: Optional[str] = None): + def __init__(self, connectionString: str, userId: Optional[str] = None): """ - Creates a new RemoteSparkSession for the Spark Connect interface. + Creates a new SparkSession for the Spark Connect interface. 
Parameters ---------- @@ -325,9 +317,6 @@ def __init__(self, connectionString: str = "sc://localhost", userId: Optional[st self._channel = self._builder.toChannel() self._stub = grpc_lib.SparkConnectServiceStub(self._channel) - # Create the reader - self.read = DataFrameReader(self) - def register_udf( self, function: Any, return_type: Union[str, pyspark.sql.types.DataType] ) -> str: @@ -355,42 +344,6 @@ def _build_metrics(self, metrics: "pb2.ExecutePlanResponse.Metrics") -> List[Pla for x in metrics.metrics ] - def sql(self, sql_string: str) -> "DataFrame": - return DataFrame.withPlan(SQL(sql_string), self) - - def range( - self, - start: int, - end: int, - step: int = 1, - numPartitions: Optional[int] = None, - ) -> DataFrame: - """ - Create a :class:`DataFrame` with column named ``id`` and typed Long, - containing elements in a range from ``start`` to ``end`` (exclusive) with - step value ``step``. - - .. versionadded:: 3.4.0 - - Parameters - ---------- - start : int - the start value - end : int - the end value (exclusive) - step : int, optional - the incremental step (default: 1) - numPartitions : int, optional - the number of partitions of the DataFrame - - Returns - ------- - :class:`DataFrame` - """ - return DataFrame.withPlan( - Range(start=start, end=end, step=step, num_partitions=numPartitions), self - ) - def _to_pandas(self, plan: pb2.Plan) -> "pandas.DataFrame": req = self._execute_plan_request_with_metadata() req.plan.CopyFrom(plan) diff --git a/python/pyspark/sql/connect/column.py b/python/pyspark/sql/connect/column.py index 36f38e0ded286..69f9fa72db6db 100644 --- a/python/pyspark/sql/connect/column.py +++ b/python/pyspark/sql/connect/column.py @@ -24,7 +24,7 @@ import pyspark.sql.connect.proto as proto if TYPE_CHECKING: - from pyspark.sql.connect.client import RemoteSparkSession + from pyspark.sql.connect.client import SparkConnectClient import pyspark.sql.connect.proto as proto @@ -80,7 +80,7 @@ def __eq__(self, other: Any) -> "Expression": # type: ignore[override] def __init__(self) -> None: pass - def to_plan(self, session: "RemoteSparkSession") -> "proto.Expression": + def to_plan(self, session: "SparkConnectClient") -> "proto.Expression": ... def __str__(self) -> str: @@ -131,7 +131,7 @@ def __init__(self, parent: Expression, alias: list[str], metadata: Any): self._metadata = metadata self._parent = parent - def to_plan(self, session: "RemoteSparkSession") -> "proto.Expression": + def to_plan(self, session: "SparkConnectClient") -> "proto.Expression": if len(self._alias) == 1: exp = proto.Expression() exp.alias.name.append(self._alias[0]) @@ -162,7 +162,7 @@ def __init__(self, value: Any) -> None: super().__init__() self._value = value - def to_plan(self, session: "RemoteSparkSession") -> "proto.Expression": + def to_plan(self, session: "SparkConnectClient") -> "proto.Expression": """Converts the literal expression to the literal in proto. 
TODO(SPARK-40533) This method always assumes the largest type and can thus @@ -250,7 +250,7 @@ def name(self) -> str: """Returns the qualified name of the column reference.""" return self._unparsed_identifier - def to_plan(self, session: "RemoteSparkSession") -> proto.Expression: + def to_plan(self, session: "SparkConnectClient") -> proto.Expression: """Returns the Proto representation of the expression.""" expr = proto.Expression() expr.unresolved_attribute.unparsed_identifier = self._unparsed_identifier @@ -275,7 +275,7 @@ def __init__(self, expr: str) -> None: super().__init__() self._expr: str = expr - def to_plan(self, session: "RemoteSparkSession") -> proto.Expression: + def to_plan(self, session: "SparkConnectClient") -> proto.Expression: """Returns the Proto representation of the SQL expression.""" expr = proto.Expression() expr.expression_string.expression = self._expr @@ -292,7 +292,7 @@ def __init__(self, col: Column, ascending: bool = True, nullsLast: bool = True) def __str__(self) -> str: return str(self.ref) + " ASC" if self.ascending else " DESC" - def to_plan(self, session: "RemoteSparkSession") -> proto.Expression: + def to_plan(self, session: "SparkConnectClient") -> proto.Expression: return self.ref.to_plan(session) @@ -306,7 +306,7 @@ def __init__( self._args = args self._op = op - def to_plan(self, session: "RemoteSparkSession") -> proto.Expression: + def to_plan(self, session: "SparkConnectClient") -> proto.Expression: fun = proto.Expression() fun.unresolved_function.parts.append(self._op) fun.unresolved_function.arguments.extend([x.to_plan(session) for x in self._args]) diff --git a/python/pyspark/sql/connect/dataframe.py b/python/pyspark/sql/connect/dataframe.py index 23340e461658e..6fabab69cf53c 100644 --- a/python/pyspark/sql/connect/dataframe.py +++ b/python/pyspark/sql/connect/dataframe.py @@ -46,7 +46,7 @@ if TYPE_CHECKING: from pyspark.sql.connect._typing import ColumnOrName, ExpressionOrString, LiteralType - from pyspark.sql.connect.client import RemoteSparkSession + from pyspark.sql.connect.session import SparkSession class GroupedData(object): @@ -97,20 +97,20 @@ class DataFrame(object): def __init__( self, - session: "RemoteSparkSession", + session: "SparkSession", data: Optional[List[Any]] = None, schema: Optional[StructType] = None, ): """Creates a new data frame""" self._schema = schema self._plan: Optional[plan.LogicalPlan] = None - self._session: "RemoteSparkSession" = session + self._session: "SparkSession" = session def __repr__(self) -> str: return "DataFrame[%s]" % (", ".join("%s: %s" % c for c in self.dtypes)) @classmethod - def withPlan(cls, plan: plan.LogicalPlan, session: "RemoteSparkSession") -> "DataFrame": + def withPlan(cls, plan: plan.LogicalPlan, session: "SparkSession") -> "DataFrame": """Main initialization method used to construct a new data frame with a child plan.""" new_frame = DataFrame(session=session) new_frame._plan = plan @@ -197,14 +197,14 @@ def columns(self) -> List[str]: return self.schema.names - def sparkSession(self) -> "RemoteSparkSession": + def sparkSession(self) -> "SparkSession": """Returns Spark session that created this :class:`DataFrame`. .. 
versionadded:: 3.4.0 Returns ------- - :class:`RemoteSparkSession` + :class:`SparkSession` """ return self._session @@ -796,8 +796,8 @@ def toPandas(self) -> "pandas.DataFrame": raise Exception("Cannot collect on empty plan.") if self._session is None: raise Exception("Cannot collect on empty session.") - query = self._plan.to_proto(self._session) - return self._session._to_pandas(query) + query = self._plan.to_proto(self._session.client) + return self._session.client._to_pandas(query) @property def schema(self) -> StructType: @@ -811,10 +811,10 @@ def schema(self) -> StructType: """ if self._schema is None: if self._plan is not None: - query = self._plan.to_proto(self._session) + query = self._plan.to_proto(self._session.client) if self._session is None: - raise Exception("Cannot analyze without RemoteSparkSession.") - self._schema = self._session.schema(query) + raise Exception("Cannot analyze without SparkSession.") + self._schema = self._session.client.schema(query) return self._schema else: raise Exception("Empty plan.") @@ -834,8 +834,8 @@ def isLocal(self) -> bool: """ if self._plan is None: raise Exception("Cannot analyze on empty plan.") - query = self._plan.to_proto(self._session) - return self._session._analyze(query).is_local + query = self._plan.to_proto(self._session.client) + return self._session.client._analyze(query).is_local @property def isStreaming(self) -> bool: @@ -859,14 +859,14 @@ def isStreaming(self) -> bool: """ if self._plan is None: raise Exception("Cannot analyze on empty plan.") - query = self._plan.to_proto(self._session) - return self._session._analyze(query).is_streaming + query = self._plan.to_proto(self._session.client) + return self._session.client._analyze(query).is_streaming def _tree_string(self) -> str: if self._plan is None: raise Exception("Cannot analyze on empty plan.") - query = self._plan.to_proto(self._session) - return self._session._analyze(query).tree_string + query = self._plan.to_proto(self._session.client) + return self._session.client._analyze(query).tree_string def printSchema(self) -> None: """Prints out the schema in the tree format. @@ -895,8 +895,8 @@ def inputFiles(self) -> List[str]: """ if self._plan is None: raise Exception("Cannot analyze on empty plan.") - query = self._plan.to_proto(self._session) - return self._session._analyze(query).input_files + query = self._plan.to_proto(self._session.client) + return self._session.client._analyze(query).input_files def transform(self, func: Callable[..., "DataFrame"], *args: Any, **kwargs: Any) -> "DataFrame": """Returns a new :class:`DataFrame`. Concise syntax for chaining custom transformations. 
@@ -1011,10 +1011,10 @@ def explain( explain_mode = cast(str, extended) if self._plan is not None: - query = self._plan.to_proto(self._session) + query = self._plan.to_proto(self._session.client) if self._session is None: - raise Exception("Cannot analyze without RemoteSparkSession.") - return self._session.explain_string(query, explain_mode) + raise Exception("Cannot analyze without SparkSession.") + return self._session.client.explain_string(query, explain_mode) else: return "" @@ -1032,8 +1032,8 @@ def createGlobalTempView(self, name: str) -> None: """ command = plan.CreateView( child=self._plan, name=name, is_global=True, replace=False - ).command(session=self._session) - self._session.execute_command(command) + ).command(session=self._session.client) + self._session.client.execute_command(command) def createOrReplaceGlobalTempView(self, name: str) -> None: """Creates or replaces a global temporary view using the given name. @@ -1049,8 +1049,8 @@ def createOrReplaceGlobalTempView(self, name: str) -> None: """ command = plan.CreateView( child=self._plan, name=name, is_global=True, replace=True - ).command(session=self._session) - self._session.execute_command(command) + ).command(session=self._session.client) + self._session.client.execute_command(command) def rdd(self, *args: Any, **kwargs: Any) -> None: raise NotImplementedError("RDD Support for Spark Connect is not implemented.") diff --git a/python/pyspark/sql/connect/function_builder.py b/python/pyspark/sql/connect/function_builder.py index 4a2688d6a0daf..8df5e56b452a6 100644 --- a/python/pyspark/sql/connect/function_builder.py +++ b/python/pyspark/sql/connect/function_builder.py @@ -34,7 +34,7 @@ FunctionBuilderCallable, UserDefinedFunctionCallable, ) - from pyspark.sql.connect.client import RemoteSparkSession + from pyspark.sql.connect.client import SparkConnectClient def _build(name: str, *args: "ExpressionOrString") -> ScalarFunctionExpression: @@ -91,7 +91,7 @@ def __init__( self._args = [] self._func_name = None - def to_plan(self, session: "RemoteSparkSession") -> proto.Expression: + def to_plan(self, session: "SparkConnectClient") -> proto.Expression: if session is None: raise Exception("CAnnot create UDF without remote Session.") # Needs to materialize the UDF to the server diff --git a/python/pyspark/sql/connect/plan.py b/python/pyspark/sql/connect/plan.py index 853b1a6dc0e1a..9a22d6ea38ecc 100644 --- a/python/pyspark/sql/connect/plan.py +++ b/python/pyspark/sql/connect/plan.py @@ -36,7 +36,7 @@ if TYPE_CHECKING: from pyspark.sql.connect._typing import ColumnOrName, ExpressionOrString - from pyspark.sql.connect.client import RemoteSparkSession + from pyspark.sql.connect.client import SparkConnectClient class InputValidationError(Exception): @@ -57,7 +57,7 @@ def unresolved_attr(self, colName: str) -> proto.Expression: return exp def to_attr_or_expression( - self, col: "ColumnOrName", session: "RemoteSparkSession" + self, col: "ColumnOrName", session: "SparkConnectClient" ) -> proto.Expression: """Returns either an instance of an unresolved attribute or the serialized expression value of the column.""" @@ -66,13 +66,13 @@ def to_attr_or_expression( else: return cast(Column, col).to_plan(session) - def plan(self, session: "RemoteSparkSession") -> proto.Relation: + def plan(self, session: "SparkConnectClient") -> proto.Relation: ... - def command(self, session: "RemoteSparkSession") -> proto.Command: + def command(self, session: "SparkConnectClient") -> proto.Command: ... 
- def _verify(self, session: "RemoteSparkSession") -> bool: + def _verify(self, session: "SparkConnectClient") -> bool: """This method is used to verify that the current logical plan can be serialized to Proto and back and afterwards is identical.""" plan = proto.Plan() @@ -84,13 +84,13 @@ def _verify(self, session: "RemoteSparkSession") -> bool: return test_plan == plan - def to_proto(self, session: "RemoteSparkSession", debug: bool = False) -> proto.Plan: + def to_proto(self, session: "SparkConnectClient", debug: bool = False) -> proto.Plan: """ Generates connect proto plan based on this LogicalPlan. Parameters ---------- - session : :class:`RemoteSparkSession`, optional. + session : :class:`SparkConnectClient`, optional. a session that connects remote spark cluster. debug: bool if enabled, the proto plan will be printed. @@ -127,7 +127,7 @@ def __init__( self.schema = schema self.options = options - def plan(self, session: "RemoteSparkSession") -> proto.Relation: + def plan(self, session: "SparkConnectClient") -> proto.Relation: plan = proto.Relation() if self.format is not None: plan.read.data_source.format = self.format @@ -158,7 +158,7 @@ def __init__(self, table_name: str) -> None: super().__init__(None) self.table_name = table_name - def plan(self, session: "RemoteSparkSession") -> proto.Relation: + def plan(self, session: "SparkConnectClient") -> proto.Relation: plan = proto.Relation() plan.read.named_table.unparsed_identifier = self.table_name return plan @@ -186,7 +186,7 @@ def __init__( self.truncate = truncate self.vertical = vertical - def plan(self, session: "RemoteSparkSession") -> proto.Relation: + def plan(self, session: "SparkConnectClient") -> proto.Relation: assert self._child is not None plan = proto.Relation() plan.show_string.input.CopyFrom(self._child.plan(session)) @@ -242,7 +242,7 @@ def _verify_expressions(self) -> None: f"Only Expressions or String can be used for projections: '{c}'." 
) - def plan(self, session: "RemoteSparkSession") -> proto.Relation: + def plan(self, session: "SparkConnectClient") -> proto.Relation: assert self._child is not None proj_exprs = [] for c in self._raw_columns: @@ -281,7 +281,7 @@ def __init__(self, child: Optional["LogicalPlan"], filter: Expression) -> None: super().__init__(child) self.filter = filter - def plan(self, session: "RemoteSparkSession") -> proto.Relation: + def plan(self, session: "SparkConnectClient") -> proto.Relation: assert self._child is not None plan = proto.Relation() plan.filter.input.CopyFrom(self._child.plan(session)) @@ -309,7 +309,7 @@ def __init__(self, child: Optional["LogicalPlan"], limit: int) -> None: super().__init__(child) self.limit = limit - def plan(self, session: "RemoteSparkSession") -> proto.Relation: + def plan(self, session: "SparkConnectClient") -> proto.Relation: assert self._child is not None plan = proto.Relation() plan.limit.input.CopyFrom(self._child.plan(session)) @@ -337,7 +337,7 @@ def __init__(self, child: Optional["LogicalPlan"], offset: int = 0) -> None: super().__init__(child) self.offset = offset - def plan(self, session: "RemoteSparkSession") -> proto.Relation: + def plan(self, session: "SparkConnectClient") -> proto.Relation: assert self._child is not None plan = proto.Relation() plan.offset.input.CopyFrom(self._child.plan(session)) @@ -371,7 +371,7 @@ def __init__( self.all_columns_as_keys = all_columns_as_keys self.column_names = column_names - def plan(self, session: "RemoteSparkSession") -> proto.Relation: + def plan(self, session: "SparkConnectClient") -> proto.Relation: assert self._child is not None plan = proto.Relation() plan.deduplicate.all_columns_as_keys = self.all_columns_as_keys @@ -411,7 +411,7 @@ def __init__( self.is_global = is_global def col_to_sort_field( - self, col: Union[SortOrder, Column, str], session: "RemoteSparkSession" + self, col: Union[SortOrder, Column, str], session: "SparkConnectClient" ) -> proto.Sort.SortField: if isinstance(col, SortOrder): sf = proto.Sort.SortField() @@ -438,7 +438,7 @@ def col_to_sort_field( sf.nulls = proto.Sort.SortNulls.SORT_NULLS_LAST return sf - def plan(self, session: "RemoteSparkSession") -> proto.Relation: + def plan(self, session: "SparkConnectClient") -> proto.Relation: assert self._child is not None plan = proto.Relation() plan.sort.input.CopyFrom(self._child.plan(session)) @@ -474,7 +474,7 @@ def __init__( self.columns = columns def _convert_to_expr( - self, col: Union[Column, str], session: "RemoteSparkSession" + self, col: Union[Column, str], session: "SparkConnectClient" ) -> proto.Expression: expr = proto.Expression() if isinstance(col, Column): @@ -483,7 +483,7 @@ def _convert_to_expr( expr.CopyFrom(self.unresolved_attr(col)) return expr - def plan(self, session: "RemoteSparkSession") -> proto.Relation: + def plan(self, session: "SparkConnectClient") -> proto.Relation: assert self._child is not None plan = proto.Relation() plan.drop.input.CopyFrom(self._child.plan(session)) @@ -521,7 +521,7 @@ def __init__( self.with_replacement = with_replacement self.seed = seed - def plan(self, session: "RemoteSparkSession") -> proto.Relation: + def plan(self, session: "SparkConnectClient") -> proto.Relation: assert self._child is not None plan = proto.Relation() plan.sample.input.CopyFrom(self._child.plan(session)) @@ -567,12 +567,12 @@ def __init__( self.grouping_cols = grouping_cols self.measures = measures - def _convert_measure(self, m: Expression, session: "RemoteSparkSession") -> proto.Expression: + def 
_convert_measure(self, m: Expression, session: "SparkConnectClient") -> proto.Expression: proto_expr = proto.Expression() proto_expr.CopyFrom(m.to_plan(session)) return proto_expr - def plan(self, session: "RemoteSparkSession") -> proto.Relation: + def plan(self, session: "SparkConnectClient") -> proto.Relation: assert self._child is not None groupings = [x.to_plan(session) for x in self.grouping_cols] @@ -642,7 +642,7 @@ def __init__( ) self.how = join_type - def plan(self, session: "RemoteSparkSession") -> proto.Relation: + def plan(self, session: "SparkConnectClient") -> proto.Relation: rel = proto.Relation() rel.join.left.CopyFrom(self.left.plan(session)) rel.join.right.CopyFrom(self.right.plan(session)) @@ -693,7 +693,7 @@ def __init__( self.is_all = is_all self.set_op = set_op - def plan(self, session: "RemoteSparkSession") -> proto.Relation: + def plan(self, session: "SparkConnectClient") -> proto.Relation: assert self._child is not None rel = proto.Relation() if self._child is not None: @@ -753,7 +753,7 @@ def __init__(self, child: Optional["LogicalPlan"], num_partitions: int, shuffle: self._num_partitions = num_partitions self._shuffle = shuffle - def plan(self, session: "RemoteSparkSession") -> proto.Relation: + def plan(self, session: "SparkConnectClient") -> proto.Relation: rel = proto.Relation() if self._child is not None: rel.repartition.input.CopyFrom(self._child.plan(session)) @@ -786,7 +786,7 @@ def __init__(self, child: Optional["LogicalPlan"], alias: str) -> None: super().__init__(child) self._alias = alias - def plan(self, session: "RemoteSparkSession") -> proto.Relation: + def plan(self, session: "SparkConnectClient") -> proto.Relation: rel = proto.Relation() if self._child is not None: rel.subquery_alias.input.CopyFrom(self._child.plan(session)) @@ -814,7 +814,7 @@ def __init__(self, query: str) -> None: super().__init__(None) self._query = query - def plan(self, session: "RemoteSparkSession") -> proto.Relation: + def plan(self, session: "SparkConnectClient") -> proto.Relation: rel = proto.Relation() rel.sql.query = self._query return rel @@ -849,7 +849,7 @@ def __init__( self._step = step self._num_partitions = num_partitions - def plan(self, session: "RemoteSparkSession") -> proto.Relation: + def plan(self, session: "SparkConnectClient") -> proto.Relation: rel = proto.Relation() rel.range.start = self._start rel.range.end = self._end @@ -912,7 +912,7 @@ def _convert_value(self, v: Any) -> proto.Expression.Literal: value.string = v return value - def plan(self, session: "RemoteSparkSession") -> proto.Relation: + def plan(self, session: "SparkConnectClient") -> proto.Relation: assert self._child is not None plan = proto.Relation() plan.fill_na.input.CopyFrom(self._child.plan(session)) @@ -942,7 +942,7 @@ def __init__(self, child: Optional["LogicalPlan"], statistics: List[str]) -> Non super().__init__(child) self.statistics = statistics - def plan(self, session: "RemoteSparkSession") -> proto.Relation: + def plan(self, session: "SparkConnectClient") -> proto.Relation: assert self._child is not None plan = proto.Relation() plan.summary.input.CopyFrom(self._child.plan(session)) @@ -971,7 +971,7 @@ def __init__(self, child: Optional["LogicalPlan"], col1: str, col2: str) -> None self.col1 = col1 self.col2 = col2 - def plan(self, session: "RemoteSparkSession") -> proto.Relation: + def plan(self, session: "SparkConnectClient") -> proto.Relation: assert self._child is not None plan = proto.Relation() @@ -1006,7 +1006,7 @@ def __init__( self._is_gloal = is_global self._replace 
= replace - def command(self, session: "RemoteSparkSession") -> proto.Command: + def command(self, session: "SparkConnectClient") -> proto.Command: assert self._child is not None plan = proto.Command() diff --git a/python/pyspark/sql/connect/readwriter.py b/python/pyspark/sql/connect/readwriter.py index 27aa023ae474f..ead027c206bdf 100644 --- a/python/pyspark/sql/connect/readwriter.py +++ b/python/pyspark/sql/connect/readwriter.py @@ -26,7 +26,7 @@ if TYPE_CHECKING: from pyspark.sql.connect._typing import OptionalPrimitiveType - from pyspark.sql.connect.client import RemoteSparkSession + from pyspark.sql.connect.session import SparkSession class DataFrameReader: @@ -34,7 +34,7 @@ class DataFrameReader: TODO(SPARK-40539) Achieve parity with PySpark. """ - def __init__(self, client: "RemoteSparkSession"): + def __init__(self, client: "SparkSession"): self._client = client self._format = "" self._schema = "" diff --git a/python/pyspark/sql/connect/session.py b/python/pyspark/sql/connect/session.py new file mode 100644 index 0000000000000..92f58140eaccd --- /dev/null +++ b/python/pyspark/sql/connect/session.py @@ -0,0 +1,258 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from threading import RLock +from typing import Optional, Any, Union, Dict, cast, overload + +import pyspark.sql.types +from pyspark.sql.connect.client import SparkConnectClient +from pyspark.sql.connect.dataframe import DataFrame +from pyspark.sql.connect.plan import SQL, Range +from pyspark.sql.connect.readwriter import DataFrameReader +from pyspark.sql.utils import to_str +from ._typing import OptionalPrimitiveType + + +# TODO(SPARK-38912): This method can be dropped once support for Python 3.8 is dropped +# In Python 3.9, the @property decorator has been made compatible with the +# @classmethod decorator (https://docs.python.org/3.9/library/functions.html#classmethod) +# +# @classmethod + @property is also affected by a bug in Python's docstring which was backported +# to Python 3.9.6 (https://github.com/python/cpython/pull/28838) +class classproperty(property): + """Same as Python's @property decorator, but for class attributes. + + Examples + -------- + >>> class Builder: + ... def build(self): + ... return MyClass() + ... + >>> class MyClass: + ... @classproperty + ... def builder(cls): + ... print("instantiating new builder") + ... return Builder() + ... 
+ >>> c1 = MyClass.builder + instantiating new builder + >>> c2 = MyClass.builder + instantiating new builder + >>> c1 == c2 + False + >>> isinstance(c1.build(), MyClass) + True + """ + + def __get__(self, instance: Any, owner: Any = None) -> "SparkSession.Builder": + # The "type: ignore" below silences the following error from mypy: + # error: Argument 1 to "classmethod" has incompatible + # type "Optional[Callable[[Any], Any]]"; + # expected "Callable[..., Any]" [arg-type] + return classmethod(self.fget).__get__(None, owner)() # type: ignore + + +class SparkSession(object): + """Conceptually the remote spark session that communicates with the server""" + + class Builder: + """Builder for :class:`SparkSession`.""" + + _lock = RLock() + + def __init__(self) -> None: + self._options: Dict[str, Any] = {} + + @overload + def config(self, key: str, value: Any) -> "SparkSession.Builder": + ... + + @overload + def config(self, *, map: Dict[str, "OptionalPrimitiveType"]) -> "SparkSession.Builder": + ... + + def config( + self, + key: Optional[str] = None, + value: Optional[Any] = None, + *, + map: Optional[Dict[str, "OptionalPrimitiveType"]] = None, + ) -> "SparkSession.Builder": + """Sets a config option. Options set using this method are automatically propagated to + both :class:`SparkConf` and :class:`SparkSession`'s own configuration. + + .. versionadded:: 2.0.0 + + Parameters + ---------- + key : str, optional + a key name string for configuration property + value : str, optional + a value for configuration property + map: dictionary, optional + a dictionary of configurations to set + + .. versionadded:: 3.4.0 + + Returns + ------- + :class:`SparkSession.Builder` + + Examples + -------- + For a (key, value) pair, you can omit parameter names. + + >>> SparkSession.builder.config("spark.some.config.option", "some-value") + >> SparkSession.builder.config( + ... map={"spark.some.config.number": 123, "spark.some.config.float": 0.123}) + "SparkSession.Builder": + return self + + def appName(self, name: str) -> "SparkSession.Builder": + """Sets a name for the application, which will be shown in the Spark web UI. + + If no application name is set, a randomly generated name will be used. + + .. versionadded:: 2.0.0 + + Parameters + ---------- + name : str + an application name + + Returns + ------- + :class:`SparkSession.Builder` + + Examples + -------- + >>> SparkSession.builder.appName("My app") + "SparkSession.Builder": + return self.config("spark.connect.location", location) + + def enableHiveSupport(self) -> "SparkSession.Builder": + raise NotImplementedError("enableHiveSupport not implemented for Spark Connect") + + def getOrCreate(self) -> "SparkSession": + """Creates a new instance.""" + return SparkSession(connectionString=self._options["spark.connect.location"]) + + _client: SparkConnectClient + + # TODO(SPARK-38912): Replace @classproperty with @classmethod + @property once support for + # Python 3.8 is dropped. 
+ # + # In Python 3.9, the @property decorator has been made compatible with the + # @classmethod decorator (https://docs.python.org/3.9/library/functions.html#classmethod) + # + # @classmethod + @property is also affected by a bug in Python's docstring which was backported + # to Python 3.9.6 (https://github.com/python/cpython/pull/28838) + @classproperty + def builder(cls) -> Builder: + """Creates a :class:`Builder` for constructing a :class:`SparkSession`.""" + return cls.Builder() + + def __init__(self, connectionString: str, userId: Optional[str] = None): + """ + Creates a new SparkSession for the Spark Connect interface. + + Parameters + ---------- + connectionString: Optional[str] + Connection string that is used to extract the connection parameters and configure + the GRPC connection. Defaults to `sc://localhost`. + userId : Optional[str] + Optional unique user ID that is used to differentiate multiple users and + isolate their Spark Sessions. If the `user_id` is not set, will default to + the $USER environment. Defining the user ID as part of the connection string + takes precedence. + """ + # Parse the connection string. + self._client = SparkConnectClient(connectionString) + + # Create the reader + self.read = DataFrameReader(self) + + @property + def client(self) -> "SparkConnectClient": + """ + Gives access to the Spark Connect client. In normal cases this is not necessary to be used + and only relevant for testing. + Returns + ------- + :class:`SparkConnectClient` + """ + return self._client + + def register_udf( + self, function: Any, return_type: Union[str, pyspark.sql.types.DataType] + ) -> str: + return self._client.register_udf(function, return_type) + + def sql(self, sql_string: str) -> "DataFrame": + return DataFrame.withPlan(SQL(sql_string), self) + + def range( + self, + start: int, + end: int, + step: int = 1, + numPartitions: Optional[int] = None, + ) -> DataFrame: + """ + Create a :class:`DataFrame` with column named ``id`` and typed Long, + containing elements in a range from ``start`` to ``end`` (exclusive) with + step value ``step``. + + .. 
versionadded:: 3.4.0 + + Parameters + ---------- + start : int + the start value + end : int + the end value (exclusive) + step : int, optional + the incremental step (default: 1) + numPartitions : int, optional + the number of partitions of the DataFrame + + Returns + ------- + :class:`DataFrame` + """ + return DataFrame.withPlan( + Range(start=start, end=end, step=step, num_partitions=numPartitions), self + ) diff --git a/python/pyspark/sql/tests/connect/test_connect_basic.py b/python/pyspark/sql/tests/connect/test_connect_basic.py index 845d6ead567e9..150bbdb65ef15 100644 --- a/python/pyspark/sql/tests/connect/test_connect_basic.py +++ b/python/pyspark/sql/tests/connect/test_connect_basic.py @@ -30,7 +30,8 @@ from pyspark.sql.types import StructType, StructField, LongType, StringType if have_pandas: - from pyspark.sql.connect.client import RemoteSparkSession, ChannelBuilder + from pyspark.sql.connect.session import SparkSession as RemoteSparkSession + from pyspark.sql.connect.client import ChannelBuilder from pyspark.sql.connect.function_builder import udf from pyspark.sql.connect.functions import lit, col from pyspark.sql.dataframe import DataFrame @@ -79,7 +80,7 @@ def tearDownClass(cls: Any) -> None: @classmethod def spark_connect_load_test_data(cls: Any): # Setup Remote Spark Session - cls.connect = RemoteSparkSession(userId="test_user") + cls.connect = RemoteSparkSession.builder.remote().getOrCreate() df = cls.spark.createDataFrame([(x, f"{x}") for x in range(100)], ["id", "name"]) # Since we might create multiple Spark sessions, we need to create global temporary view # that is specifically maintained in the "global_temp" schema. diff --git a/python/pyspark/sql/tests/connect/test_connect_column_expressions.py b/python/pyspark/sql/tests/connect/test_connect_column_expressions.py index 99b63482a2438..03966bd28df6d 100644 --- a/python/pyspark/sql/tests/connect/test_connect_column_expressions.py +++ b/python/pyspark/sql/tests/connect/test_connect_column_expressions.py @@ -140,7 +140,7 @@ def test_column_alias(self) -> None: self.assertEqual("Alias(Column(a), (martin))", str(col0)) col0 = fun.col("a").alias("martin", metadata={"pii": True}) - plan = col0.to_plan(self.session) + plan = col0.to_plan(self.session.client) self.assertIsNotNone(plan) self.assertEqual(plan.alias.metadata, '{"pii": true}') diff --git a/python/pyspark/testing/connectutils.py b/python/pyspark/testing/connectutils.py index f98a67b9964b4..feca9e9f82559 100644 --- a/python/pyspark/testing/connectutils.py +++ b/python/pyspark/testing/connectutils.py @@ -26,7 +26,7 @@ from pyspark.sql.connect.plan import Read, Range, SQL from pyspark.testing.utils import search_jar from pyspark.sql.connect.plan import LogicalPlan - from pyspark.sql.connect.client import RemoteSparkSession + from pyspark.sql.connect.session import SparkSession connect_jar = search_jar("connector/connect", "spark-connect-assembly-", "spark-connect") else: @@ -69,7 +69,7 @@ def __getattr__(self, item: str) -> Any: class PlanOnlyTestFixture(unittest.TestCase): connect: "MockRemoteSession" - session: RemoteSparkSession + session: SparkSession @classmethod def _read_table(cls, table_name: str) -> "DataFrame": @@ -102,7 +102,7 @@ def _with_plan(cls, plan: LogicalPlan) -> "DataFrame": @classmethod def setUpClass(cls: Any) -> None: cls.connect = MockRemoteSession() - cls.session = RemoteSparkSession() + cls.session = SparkSession.builder.remote().getOrCreate() cls.tbl_name = "test_connect_plan_only_table_1" cls.connect.set_hook("register_udf", cls._udf_mock)
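
---

A minimal end-to-end sketch of the renamed API, stitched together from the examples in the description above and the new `session.py` added by this patch. It assumes a Spark Connect server is reachable at `sc://localhost`; the sample queries (`range(...)`, `SELECT 1 AS id`) are illustrative and not part of the patch itself:

```
from pyspark.sql.connect.session import SparkSession

# Build a session against a Spark Connect endpoint; remote() records the
# connection string under "spark.connect.location", which getOrCreate() reads.
spark = SparkSession.builder.remote("sc://localhost").getOrCreate()

# DataFrame operations are declared locally and executed through the
# underlying SparkConnectClient (exposed as spark.client, mainly for testing).
df = spark.range(start=0, end=10, step=2)
print(df.toPandas())

# The sql() entry point now lives on SparkSession instead of the GRPC client.
print(spark.sql("SELECT 1 AS id").toPandas())
```

As noted in the description, the new class intentionally covers only a subset of the classic `SparkSession` surface, so functionality beyond these basics may not be available yet.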