diff --git a/python/pyspark/sql/connect/client.py b/python/pyspark/sql/connect/client.py index b41df12c357c5..a2a0797c49fa2 100644 --- a/python/pyspark/sql/connect/client.py +++ b/python/pyspark/sql/connect/client.py @@ -15,23 +15,19 @@ # limitations under the License. # - -import logging import os import urllib.parse import uuid +from typing import Iterable, Optional, Any, Union, List, Tuple, Dict import grpc # type: ignore -import pyarrow as pa import pandas +import pyarrow as pa import pyspark.sql.connect.proto as pb2 import pyspark.sql.connect.proto.base_pb2_grpc as grpc_lib import pyspark.sql.types from pyspark import cloudpickle -from pyspark.sql.connect.dataframe import DataFrame -from pyspark.sql.connect.readwriter import DataFrameReader -from pyspark.sql.connect.plan import SQL, Range from pyspark.sql.types import ( DataType, ByteType, @@ -56,10 +52,6 @@ NullType, ) -from typing import Iterable, Optional, Any, Union, List, Tuple, Dict - -logging.basicConfig(level=logging.INFO) - class ChannelBuilder: """ @@ -294,12 +286,12 @@ def fromProto(cls, pb: Any) -> "AnalyzeResult": ) -class RemoteSparkSession(object): +class SparkConnectClient(object): """Conceptually the remote spark session that communicates with the server""" - def __init__(self, connectionString: str = "sc://localhost", userId: Optional[str] = None): + def __init__(self, connectionString: str, userId: Optional[str] = None): """ - Creates a new RemoteSparkSession for the Spark Connect interface. + Creates a new SparkSession for the Spark Connect interface. Parameters ---------- @@ -325,9 +317,6 @@ def __init__(self, connectionString: str = "sc://localhost", userId: Optional[st self._channel = self._builder.toChannel() self._stub = grpc_lib.SparkConnectServiceStub(self._channel) - # Create the reader - self.read = DataFrameReader(self) - def register_udf( self, function: Any, return_type: Union[str, pyspark.sql.types.DataType] ) -> str: @@ -355,42 +344,6 @@ def _build_metrics(self, metrics: "pb2.ExecutePlanResponse.Metrics") -> List[Pla for x in metrics.metrics ] - def sql(self, sql_string: str) -> "DataFrame": - return DataFrame.withPlan(SQL(sql_string), self) - - def range( - self, - start: int, - end: int, - step: int = 1, - numPartitions: Optional[int] = None, - ) -> DataFrame: - """ - Create a :class:`DataFrame` with column named ``id`` and typed Long, - containing elements in a range from ``start`` to ``end`` (exclusive) with - step value ``step``. - - .. 
versionadded:: 3.4.0 - - Parameters - ---------- - start : int - the start value - end : int - the end value (exclusive) - step : int, optional - the incremental step (default: 1) - numPartitions : int, optional - the number of partitions of the DataFrame - - Returns - ------- - :class:`DataFrame` - """ - return DataFrame.withPlan( - Range(start=start, end=end, step=step, num_partitions=numPartitions), self - ) - def _to_pandas(self, plan: pb2.Plan) -> "pandas.DataFrame": req = self._execute_plan_request_with_metadata() req.plan.CopyFrom(plan) diff --git a/python/pyspark/sql/connect/column.py b/python/pyspark/sql/connect/column.py index 36f38e0ded286..69f9fa72db6db 100644 --- a/python/pyspark/sql/connect/column.py +++ b/python/pyspark/sql/connect/column.py @@ -24,7 +24,7 @@ import pyspark.sql.connect.proto as proto if TYPE_CHECKING: - from pyspark.sql.connect.client import RemoteSparkSession + from pyspark.sql.connect.client import SparkConnectClient import pyspark.sql.connect.proto as proto @@ -80,7 +80,7 @@ def __eq__(self, other: Any) -> "Expression": # type: ignore[override] def __init__(self) -> None: pass - def to_plan(self, session: "RemoteSparkSession") -> "proto.Expression": + def to_plan(self, session: "SparkConnectClient") -> "proto.Expression": ... def __str__(self) -> str: @@ -131,7 +131,7 @@ def __init__(self, parent: Expression, alias: list[str], metadata: Any): self._metadata = metadata self._parent = parent - def to_plan(self, session: "RemoteSparkSession") -> "proto.Expression": + def to_plan(self, session: "SparkConnectClient") -> "proto.Expression": if len(self._alias) == 1: exp = proto.Expression() exp.alias.name.append(self._alias[0]) @@ -162,7 +162,7 @@ def __init__(self, value: Any) -> None: super().__init__() self._value = value - def to_plan(self, session: "RemoteSparkSession") -> "proto.Expression": + def to_plan(self, session: "SparkConnectClient") -> "proto.Expression": """Converts the literal expression to the literal in proto. 
TODO(SPARK-40533) This method always assumes the largest type and can thus @@ -250,7 +250,7 @@ def name(self) -> str: """Returns the qualified name of the column reference.""" return self._unparsed_identifier - def to_plan(self, session: "RemoteSparkSession") -> proto.Expression: + def to_plan(self, session: "SparkConnectClient") -> proto.Expression: """Returns the Proto representation of the expression.""" expr = proto.Expression() expr.unresolved_attribute.unparsed_identifier = self._unparsed_identifier @@ -275,7 +275,7 @@ def __init__(self, expr: str) -> None: super().__init__() self._expr: str = expr - def to_plan(self, session: "RemoteSparkSession") -> proto.Expression: + def to_plan(self, session: "SparkConnectClient") -> proto.Expression: """Returns the Proto representation of the SQL expression.""" expr = proto.Expression() expr.expression_string.expression = self._expr @@ -292,7 +292,7 @@ def __init__(self, col: Column, ascending: bool = True, nullsLast: bool = True) def __str__(self) -> str: return str(self.ref) + " ASC" if self.ascending else " DESC" - def to_plan(self, session: "RemoteSparkSession") -> proto.Expression: + def to_plan(self, session: "SparkConnectClient") -> proto.Expression: return self.ref.to_plan(session) @@ -306,7 +306,7 @@ def __init__( self._args = args self._op = op - def to_plan(self, session: "RemoteSparkSession") -> proto.Expression: + def to_plan(self, session: "SparkConnectClient") -> proto.Expression: fun = proto.Expression() fun.unresolved_function.parts.append(self._op) fun.unresolved_function.arguments.extend([x.to_plan(session) for x in self._args]) diff --git a/python/pyspark/sql/connect/dataframe.py b/python/pyspark/sql/connect/dataframe.py index 23340e461658e..6fabab69cf53c 100644 --- a/python/pyspark/sql/connect/dataframe.py +++ b/python/pyspark/sql/connect/dataframe.py @@ -46,7 +46,7 @@ if TYPE_CHECKING: from pyspark.sql.connect._typing import ColumnOrName, ExpressionOrString, LiteralType - from pyspark.sql.connect.client import RemoteSparkSession + from pyspark.sql.connect.session import SparkSession class GroupedData(object): @@ -97,20 +97,20 @@ class DataFrame(object): def __init__( self, - session: "RemoteSparkSession", + session: "SparkSession", data: Optional[List[Any]] = None, schema: Optional[StructType] = None, ): """Creates a new data frame""" self._schema = schema self._plan: Optional[plan.LogicalPlan] = None - self._session: "RemoteSparkSession" = session + self._session: "SparkSession" = session def __repr__(self) -> str: return "DataFrame[%s]" % (", ".join("%s: %s" % c for c in self.dtypes)) @classmethod - def withPlan(cls, plan: plan.LogicalPlan, session: "RemoteSparkSession") -> "DataFrame": + def withPlan(cls, plan: plan.LogicalPlan, session: "SparkSession") -> "DataFrame": """Main initialization method used to construct a new data frame with a child plan.""" new_frame = DataFrame(session=session) new_frame._plan = plan @@ -197,14 +197,14 @@ def columns(self) -> List[str]: return self.schema.names - def sparkSession(self) -> "RemoteSparkSession": + def sparkSession(self) -> "SparkSession": """Returns Spark session that created this :class:`DataFrame`. .. 
versionadded:: 3.4.0 Returns ------- - :class:`RemoteSparkSession` + :class:`SparkSession` """ return self._session @@ -796,8 +796,8 @@ def toPandas(self) -> "pandas.DataFrame": raise Exception("Cannot collect on empty plan.") if self._session is None: raise Exception("Cannot collect on empty session.") - query = self._plan.to_proto(self._session) - return self._session._to_pandas(query) + query = self._plan.to_proto(self._session.client) + return self._session.client._to_pandas(query) @property def schema(self) -> StructType: @@ -811,10 +811,10 @@ def schema(self) -> StructType: """ if self._schema is None: if self._plan is not None: - query = self._plan.to_proto(self._session) + query = self._plan.to_proto(self._session.client) if self._session is None: - raise Exception("Cannot analyze without RemoteSparkSession.") - self._schema = self._session.schema(query) + raise Exception("Cannot analyze without SparkSession.") + self._schema = self._session.client.schema(query) return self._schema else: raise Exception("Empty plan.") @@ -834,8 +834,8 @@ def isLocal(self) -> bool: """ if self._plan is None: raise Exception("Cannot analyze on empty plan.") - query = self._plan.to_proto(self._session) - return self._session._analyze(query).is_local + query = self._plan.to_proto(self._session.client) + return self._session.client._analyze(query).is_local @property def isStreaming(self) -> bool: @@ -859,14 +859,14 @@ def isStreaming(self) -> bool: """ if self._plan is None: raise Exception("Cannot analyze on empty plan.") - query = self._plan.to_proto(self._session) - return self._session._analyze(query).is_streaming + query = self._plan.to_proto(self._session.client) + return self._session.client._analyze(query).is_streaming def _tree_string(self) -> str: if self._plan is None: raise Exception("Cannot analyze on empty plan.") - query = self._plan.to_proto(self._session) - return self._session._analyze(query).tree_string + query = self._plan.to_proto(self._session.client) + return self._session.client._analyze(query).tree_string def printSchema(self) -> None: """Prints out the schema in the tree format. @@ -895,8 +895,8 @@ def inputFiles(self) -> List[str]: """ if self._plan is None: raise Exception("Cannot analyze on empty plan.") - query = self._plan.to_proto(self._session) - return self._session._analyze(query).input_files + query = self._plan.to_proto(self._session.client) + return self._session.client._analyze(query).input_files def transform(self, func: Callable[..., "DataFrame"], *args: Any, **kwargs: Any) -> "DataFrame": """Returns a new :class:`DataFrame`. Concise syntax for chaining custom transformations. 
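[Reviewer note] Every hunk in dataframe.py follows the same mechanical pattern: `DataFrame` still holds a session, but the session is now the new `SparkSession` facade, and all RPC-bound work (plan serialization, analysis, command execution) is reached through its `client` property, e.g. `self._plan.to_proto(self._session.client)` and `self._session.client._analyze(query)`. A minimal sketch of that layering, with stand-in classes and simplified signatures rather than the real implementations:

```python
# Illustrative stand-ins only -- the real classes live in client.py, session.py and
# dataframe.py. The sketch shows the call chain this rename introduces:
# DataFrame -> SparkSession.client -> SparkConnectClient (owner of the gRPC stub).
from typing import Optional


class SparkConnectClient:
    """Stand-in for the gRPC-backed client (formerly RemoteSparkSession)."""

    def __init__(self, connection_string: str) -> None:
        self._connection_string = connection_string

    def _analyze(self, query: str) -> str:
        # The real method issues an AnalyzePlan RPC; here we only echo the request.
        return f"AnalyzePlan({query}) via {self._connection_string}"


class SparkSession:
    """Stand-in for the new user-facing session that wraps the client."""

    def __init__(self, connection_string: str) -> None:
        self._client = SparkConnectClient(connection_string)

    @property
    def client(self) -> SparkConnectClient:
        return self._client


class DataFrame:
    def __init__(self, session: SparkSession, plan: Optional[str] = None) -> None:
        self._session = session
        self._plan = plan

    def tree_string(self) -> str:
        if self._plan is None:
            raise Exception("Cannot analyze on empty plan.")
        # Before this patch: self._session._analyze(query)
        # After this patch:  self._session.client._analyze(query)
        return self._session.client._analyze(self._plan)


print(DataFrame(SparkSession("sc://localhost"), plan="range(0, 10)").tree_string())
```

The indirection keeps `SparkSession` a thin user-facing facade while `SparkConnectClient` owns the gRPC channel and stub.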
@@ -1011,10 +1011,10 @@ def explain( explain_mode = cast(str, extended) if self._plan is not None: - query = self._plan.to_proto(self._session) + query = self._plan.to_proto(self._session.client) if self._session is None: - raise Exception("Cannot analyze without RemoteSparkSession.") - return self._session.explain_string(query, explain_mode) + raise Exception("Cannot analyze without SparkSession.") + return self._session.client.explain_string(query, explain_mode) else: return "" @@ -1032,8 +1032,8 @@ def createGlobalTempView(self, name: str) -> None: """ command = plan.CreateView( child=self._plan, name=name, is_global=True, replace=False - ).command(session=self._session) - self._session.execute_command(command) + ).command(session=self._session.client) + self._session.client.execute_command(command) def createOrReplaceGlobalTempView(self, name: str) -> None: """Creates or replaces a global temporary view using the given name. @@ -1049,8 +1049,8 @@ def createOrReplaceGlobalTempView(self, name: str) -> None: """ command = plan.CreateView( child=self._plan, name=name, is_global=True, replace=True - ).command(session=self._session) - self._session.execute_command(command) + ).command(session=self._session.client) + self._session.client.execute_command(command) def rdd(self, *args: Any, **kwargs: Any) -> None: raise NotImplementedError("RDD Support for Spark Connect is not implemented.") diff --git a/python/pyspark/sql/connect/function_builder.py b/python/pyspark/sql/connect/function_builder.py index 4a2688d6a0daf..8df5e56b452a6 100644 --- a/python/pyspark/sql/connect/function_builder.py +++ b/python/pyspark/sql/connect/function_builder.py @@ -34,7 +34,7 @@ FunctionBuilderCallable, UserDefinedFunctionCallable, ) - from pyspark.sql.connect.client import RemoteSparkSession + from pyspark.sql.connect.client import SparkConnectClient def _build(name: str, *args: "ExpressionOrString") -> ScalarFunctionExpression: @@ -91,7 +91,7 @@ def __init__( self._args = [] self._func_name = None - def to_plan(self, session: "RemoteSparkSession") -> proto.Expression: + def to_plan(self, session: "SparkConnectClient") -> proto.Expression: if session is None: raise Exception("CAnnot create UDF without remote Session.") # Needs to materialize the UDF to the server diff --git a/python/pyspark/sql/connect/plan.py b/python/pyspark/sql/connect/plan.py index 853b1a6dc0e1a..9a22d6ea38ecc 100644 --- a/python/pyspark/sql/connect/plan.py +++ b/python/pyspark/sql/connect/plan.py @@ -36,7 +36,7 @@ if TYPE_CHECKING: from pyspark.sql.connect._typing import ColumnOrName, ExpressionOrString - from pyspark.sql.connect.client import RemoteSparkSession + from pyspark.sql.connect.client import SparkConnectClient class InputValidationError(Exception): @@ -57,7 +57,7 @@ def unresolved_attr(self, colName: str) -> proto.Expression: return exp def to_attr_or_expression( - self, col: "ColumnOrName", session: "RemoteSparkSession" + self, col: "ColumnOrName", session: "SparkConnectClient" ) -> proto.Expression: """Returns either an instance of an unresolved attribute or the serialized expression value of the column.""" @@ -66,13 +66,13 @@ def to_attr_or_expression( else: return cast(Column, col).to_plan(session) - def plan(self, session: "RemoteSparkSession") -> proto.Relation: + def plan(self, session: "SparkConnectClient") -> proto.Relation: ... - def command(self, session: "RemoteSparkSession") -> proto.Command: + def command(self, session: "SparkConnectClient") -> proto.Command: ... 
- def _verify(self, session: "RemoteSparkSession") -> bool: + def _verify(self, session: "SparkConnectClient") -> bool: """This method is used to verify that the current logical plan can be serialized to Proto and back and afterwards is identical.""" plan = proto.Plan() @@ -84,13 +84,13 @@ def _verify(self, session: "RemoteSparkSession") -> bool: return test_plan == plan - def to_proto(self, session: "RemoteSparkSession", debug: bool = False) -> proto.Plan: + def to_proto(self, session: "SparkConnectClient", debug: bool = False) -> proto.Plan: """ Generates connect proto plan based on this LogicalPlan. Parameters ---------- - session : :class:`RemoteSparkSession`, optional. + session : :class:`SparkConnectClient`, optional. a session that connects remote spark cluster. debug: bool if enabled, the proto plan will be printed. @@ -127,7 +127,7 @@ def __init__( self.schema = schema self.options = options - def plan(self, session: "RemoteSparkSession") -> proto.Relation: + def plan(self, session: "SparkConnectClient") -> proto.Relation: plan = proto.Relation() if self.format is not None: plan.read.data_source.format = self.format @@ -158,7 +158,7 @@ def __init__(self, table_name: str) -> None: super().__init__(None) self.table_name = table_name - def plan(self, session: "RemoteSparkSession") -> proto.Relation: + def plan(self, session: "SparkConnectClient") -> proto.Relation: plan = proto.Relation() plan.read.named_table.unparsed_identifier = self.table_name return plan @@ -186,7 +186,7 @@ def __init__( self.truncate = truncate self.vertical = vertical - def plan(self, session: "RemoteSparkSession") -> proto.Relation: + def plan(self, session: "SparkConnectClient") -> proto.Relation: assert self._child is not None plan = proto.Relation() plan.show_string.input.CopyFrom(self._child.plan(session)) @@ -242,7 +242,7 @@ def _verify_expressions(self) -> None: f"Only Expressions or String can be used for projections: '{c}'." 
) - def plan(self, session: "RemoteSparkSession") -> proto.Relation: + def plan(self, session: "SparkConnectClient") -> proto.Relation: assert self._child is not None proj_exprs = [] for c in self._raw_columns: @@ -281,7 +281,7 @@ def __init__(self, child: Optional["LogicalPlan"], filter: Expression) -> None: super().__init__(child) self.filter = filter - def plan(self, session: "RemoteSparkSession") -> proto.Relation: + def plan(self, session: "SparkConnectClient") -> proto.Relation: assert self._child is not None plan = proto.Relation() plan.filter.input.CopyFrom(self._child.plan(session)) @@ -309,7 +309,7 @@ def __init__(self, child: Optional["LogicalPlan"], limit: int) -> None: super().__init__(child) self.limit = limit - def plan(self, session: "RemoteSparkSession") -> proto.Relation: + def plan(self, session: "SparkConnectClient") -> proto.Relation: assert self._child is not None plan = proto.Relation() plan.limit.input.CopyFrom(self._child.plan(session)) @@ -337,7 +337,7 @@ def __init__(self, child: Optional["LogicalPlan"], offset: int = 0) -> None: super().__init__(child) self.offset = offset - def plan(self, session: "RemoteSparkSession") -> proto.Relation: + def plan(self, session: "SparkConnectClient") -> proto.Relation: assert self._child is not None plan = proto.Relation() plan.offset.input.CopyFrom(self._child.plan(session)) @@ -371,7 +371,7 @@ def __init__( self.all_columns_as_keys = all_columns_as_keys self.column_names = column_names - def plan(self, session: "RemoteSparkSession") -> proto.Relation: + def plan(self, session: "SparkConnectClient") -> proto.Relation: assert self._child is not None plan = proto.Relation() plan.deduplicate.all_columns_as_keys = self.all_columns_as_keys @@ -411,7 +411,7 @@ def __init__( self.is_global = is_global def col_to_sort_field( - self, col: Union[SortOrder, Column, str], session: "RemoteSparkSession" + self, col: Union[SortOrder, Column, str], session: "SparkConnectClient" ) -> proto.Sort.SortField: if isinstance(col, SortOrder): sf = proto.Sort.SortField() @@ -438,7 +438,7 @@ def col_to_sort_field( sf.nulls = proto.Sort.SortNulls.SORT_NULLS_LAST return sf - def plan(self, session: "RemoteSparkSession") -> proto.Relation: + def plan(self, session: "SparkConnectClient") -> proto.Relation: assert self._child is not None plan = proto.Relation() plan.sort.input.CopyFrom(self._child.plan(session)) @@ -474,7 +474,7 @@ def __init__( self.columns = columns def _convert_to_expr( - self, col: Union[Column, str], session: "RemoteSparkSession" + self, col: Union[Column, str], session: "SparkConnectClient" ) -> proto.Expression: expr = proto.Expression() if isinstance(col, Column): @@ -483,7 +483,7 @@ def _convert_to_expr( expr.CopyFrom(self.unresolved_attr(col)) return expr - def plan(self, session: "RemoteSparkSession") -> proto.Relation: + def plan(self, session: "SparkConnectClient") -> proto.Relation: assert self._child is not None plan = proto.Relation() plan.drop.input.CopyFrom(self._child.plan(session)) @@ -521,7 +521,7 @@ def __init__( self.with_replacement = with_replacement self.seed = seed - def plan(self, session: "RemoteSparkSession") -> proto.Relation: + def plan(self, session: "SparkConnectClient") -> proto.Relation: assert self._child is not None plan = proto.Relation() plan.sample.input.CopyFrom(self._child.plan(session)) @@ -567,12 +567,12 @@ def __init__( self.grouping_cols = grouping_cols self.measures = measures - def _convert_measure(self, m: Expression, session: "RemoteSparkSession") -> proto.Expression: + def 
_convert_measure(self, m: Expression, session: "SparkConnectClient") -> proto.Expression: proto_expr = proto.Expression() proto_expr.CopyFrom(m.to_plan(session)) return proto_expr - def plan(self, session: "RemoteSparkSession") -> proto.Relation: + def plan(self, session: "SparkConnectClient") -> proto.Relation: assert self._child is not None groupings = [x.to_plan(session) for x in self.grouping_cols] @@ -642,7 +642,7 @@ def __init__( ) self.how = join_type - def plan(self, session: "RemoteSparkSession") -> proto.Relation: + def plan(self, session: "SparkConnectClient") -> proto.Relation: rel = proto.Relation() rel.join.left.CopyFrom(self.left.plan(session)) rel.join.right.CopyFrom(self.right.plan(session)) @@ -693,7 +693,7 @@ def __init__( self.is_all = is_all self.set_op = set_op - def plan(self, session: "RemoteSparkSession") -> proto.Relation: + def plan(self, session: "SparkConnectClient") -> proto.Relation: assert self._child is not None rel = proto.Relation() if self._child is not None: @@ -753,7 +753,7 @@ def __init__(self, child: Optional["LogicalPlan"], num_partitions: int, shuffle: self._num_partitions = num_partitions self._shuffle = shuffle - def plan(self, session: "RemoteSparkSession") -> proto.Relation: + def plan(self, session: "SparkConnectClient") -> proto.Relation: rel = proto.Relation() if self._child is not None: rel.repartition.input.CopyFrom(self._child.plan(session)) @@ -786,7 +786,7 @@ def __init__(self, child: Optional["LogicalPlan"], alias: str) -> None: super().__init__(child) self._alias = alias - def plan(self, session: "RemoteSparkSession") -> proto.Relation: + def plan(self, session: "SparkConnectClient") -> proto.Relation: rel = proto.Relation() if self._child is not None: rel.subquery_alias.input.CopyFrom(self._child.plan(session)) @@ -814,7 +814,7 @@ def __init__(self, query: str) -> None: super().__init__(None) self._query = query - def plan(self, session: "RemoteSparkSession") -> proto.Relation: + def plan(self, session: "SparkConnectClient") -> proto.Relation: rel = proto.Relation() rel.sql.query = self._query return rel @@ -849,7 +849,7 @@ def __init__( self._step = step self._num_partitions = num_partitions - def plan(self, session: "RemoteSparkSession") -> proto.Relation: + def plan(self, session: "SparkConnectClient") -> proto.Relation: rel = proto.Relation() rel.range.start = self._start rel.range.end = self._end @@ -912,7 +912,7 @@ def _convert_value(self, v: Any) -> proto.Expression.Literal: value.string = v return value - def plan(self, session: "RemoteSparkSession") -> proto.Relation: + def plan(self, session: "SparkConnectClient") -> proto.Relation: assert self._child is not None plan = proto.Relation() plan.fill_na.input.CopyFrom(self._child.plan(session)) @@ -942,7 +942,7 @@ def __init__(self, child: Optional["LogicalPlan"], statistics: List[str]) -> Non super().__init__(child) self.statistics = statistics - def plan(self, session: "RemoteSparkSession") -> proto.Relation: + def plan(self, session: "SparkConnectClient") -> proto.Relation: assert self._child is not None plan = proto.Relation() plan.summary.input.CopyFrom(self._child.plan(session)) @@ -971,7 +971,7 @@ def __init__(self, child: Optional["LogicalPlan"], col1: str, col2: str) -> None self.col1 = col1 self.col2 = col2 - def plan(self, session: "RemoteSparkSession") -> proto.Relation: + def plan(self, session: "SparkConnectClient") -> proto.Relation: assert self._child is not None plan = proto.Relation() @@ -1006,7 +1006,7 @@ def __init__( self._is_gloal = is_global self._replace 
= replace - def command(self, session: "RemoteSparkSession") -> proto.Command: + def command(self, session: "SparkConnectClient") -> proto.Command: assert self._child is not None plan = proto.Command() diff --git a/python/pyspark/sql/connect/readwriter.py b/python/pyspark/sql/connect/readwriter.py index 27aa023ae474f..ead027c206bdf 100644 --- a/python/pyspark/sql/connect/readwriter.py +++ b/python/pyspark/sql/connect/readwriter.py @@ -26,7 +26,7 @@ if TYPE_CHECKING: from pyspark.sql.connect._typing import OptionalPrimitiveType - from pyspark.sql.connect.client import RemoteSparkSession + from pyspark.sql.connect.session import SparkSession class DataFrameReader: @@ -34,7 +34,7 @@ class DataFrameReader: TODO(SPARK-40539) Achieve parity with PySpark. """ - def __init__(self, client: "RemoteSparkSession"): + def __init__(self, client: "SparkSession"): self._client = client self._format = "" self._schema = "" diff --git a/python/pyspark/sql/connect/session.py b/python/pyspark/sql/connect/session.py new file mode 100644 index 0000000000000..92f58140eaccd --- /dev/null +++ b/python/pyspark/sql/connect/session.py @@ -0,0 +1,258 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from threading import RLock +from typing import Optional, Any, Union, Dict, cast, overload + +import pyspark.sql.types +from pyspark.sql.connect.client import SparkConnectClient +from pyspark.sql.connect.dataframe import DataFrame +from pyspark.sql.connect.plan import SQL, Range +from pyspark.sql.connect.readwriter import DataFrameReader +from pyspark.sql.utils import to_str +from ._typing import OptionalPrimitiveType + + +# TODO(SPARK-38912): This method can be dropped once support for Python 3.8 is dropped +# In Python 3.9, the @property decorator has been made compatible with the +# @classmethod decorator (https://docs.python.org/3.9/library/functions.html#classmethod) +# +# @classmethod + @property is also affected by a bug in Python's docstring which was backported +# to Python 3.9.6 (https://github.com/python/cpython/pull/28838) +class classproperty(property): + """Same as Python's @property decorator, but for class attributes. + + Examples + -------- + >>> class Builder: + ... def build(self): + ... return MyClass() + ... + >>> class MyClass: + ... @classproperty + ... def builder(cls): + ... print("instantiating new builder") + ... return Builder() + ... 
+    >>> c1 = MyClass.builder
+    instantiating new builder
+    >>> c2 = MyClass.builder
+    instantiating new builder
+    >>> c1 == c2
+    False
+    >>> isinstance(c1.build(), MyClass)
+    True
+    """
+
+    def __get__(self, instance: Any, owner: Any = None) -> "SparkSession.Builder":
+        # The "type: ignore" below silences the following error from mypy:
+        # error: Argument 1 to "classmethod" has incompatible
+        # type "Optional[Callable[[Any], Any]]";
+        # expected "Callable[..., Any]" [arg-type]
+        return classmethod(self.fget).__get__(None, owner)()  # type: ignore
+
+
+class SparkSession(object):
+    """Conceptually the remote spark session that communicates with the server"""
+
+    class Builder:
+        """Builder for :class:`SparkSession`."""
+
+        _lock = RLock()
+
+        def __init__(self) -> None:
+            self._options: Dict[str, Any] = {}
+
+        @overload
+        def config(self, key: str, value: Any) -> "SparkSession.Builder":
+            ...
+
+        @overload
+        def config(self, *, map: Dict[str, "OptionalPrimitiveType"]) -> "SparkSession.Builder":
+            ...
+
+        def config(
+            self,
+            key: Optional[str] = None,
+            value: Optional[Any] = None,
+            *,
+            map: Optional[Dict[str, "OptionalPrimitiveType"]] = None,
+        ) -> "SparkSession.Builder":
+            """Sets a config option. Options set using this method are automatically propagated to
+            both :class:`SparkConf` and :class:`SparkSession`'s own configuration.
+
+            .. versionadded:: 2.0.0
+
+            Parameters
+            ----------
+            key : str, optional
+                a key name string for configuration property
+            value : str, optional
+                a value for configuration property
+            map: dictionary, optional
+                a dictionary of configurations to set
+
+                .. versionadded:: 3.4.0
+
+            Returns
+            -------
+            :class:`SparkSession.Builder`
+
+            Examples
+            --------
+            For a (key, value) pair, you can omit parameter names.
+
+            >>> SparkSession.builder.config("spark.some.config.option", "some-value")
+            <pyspark.sql.connect.session.SparkSession.Builder object ...>
+
+            Additionally, you can pass a dictionary of configurations to set.
+
+            >>> SparkSession.builder.config(
+            ...     map={"spark.some.config.number": 123, "spark.some.config.float": 0.123})
+            <pyspark.sql.connect.session.SparkSession.Builder object ...>
+            """
+            with self._lock:
+                if map is not None:
+                    for k, v in map.items():
+                        self._options[k] = to_str(v)
+                else:
+                    self._options[cast(str, key)] = to_str(value)
+                return self
+
+        def master(self, master: str) -> "SparkSession.Builder":
+            return self
+
+        def appName(self, name: str) -> "SparkSession.Builder":
+            """Sets a name for the application, which will be shown in the Spark web UI.
+
+            If no application name is set, a randomly generated name will be used.
+
+            .. versionadded:: 2.0.0
+
+            Parameters
+            ----------
+            name : str
+                an application name
+
+            Returns
+            -------
+            :class:`SparkSession.Builder`
+
+            Examples
+            --------
+            >>> SparkSession.builder.appName("My app")
+            <pyspark.sql.connect.session.SparkSession.Builder object ...>
+            """
+            return self.config("spark.app.name", name)
+
+        def remote(self, location: str = "sc://localhost") -> "SparkSession.Builder":
+            return self.config("spark.connect.location", location)
+
+        def enableHiveSupport(self) -> "SparkSession.Builder":
+            raise NotImplementedError("enableHiveSupport not implemented for Spark Connect")
+
+        def getOrCreate(self) -> "SparkSession":
+            """Creates a new instance."""
+            return SparkSession(connectionString=self._options["spark.connect.location"])
+
+    _client: SparkConnectClient
+
+    # TODO(SPARK-38912): Replace @classproperty with @classmethod + @property once support for
+    # Python 3.8 is dropped.
+ # + # In Python 3.9, the @property decorator has been made compatible with the + # @classmethod decorator (https://docs.python.org/3.9/library/functions.html#classmethod) + # + # @classmethod + @property is also affected by a bug in Python's docstring which was backported + # to Python 3.9.6 (https://github.com/python/cpython/pull/28838) + @classproperty + def builder(cls) -> Builder: + """Creates a :class:`Builder` for constructing a :class:`SparkSession`.""" + return cls.Builder() + + def __init__(self, connectionString: str, userId: Optional[str] = None): + """ + Creates a new SparkSession for the Spark Connect interface. + + Parameters + ---------- + connectionString: Optional[str] + Connection string that is used to extract the connection parameters and configure + the GRPC connection. Defaults to `sc://localhost`. + userId : Optional[str] + Optional unique user ID that is used to differentiate multiple users and + isolate their Spark Sessions. If the `user_id` is not set, will default to + the $USER environment. Defining the user ID as part of the connection string + takes precedence. + """ + # Parse the connection string. + self._client = SparkConnectClient(connectionString) + + # Create the reader + self.read = DataFrameReader(self) + + @property + def client(self) -> "SparkConnectClient": + """ + Gives access to the Spark Connect client. In normal cases this is not necessary to be used + and only relevant for testing. + Returns + ------- + :class:`SparkConnectClient` + """ + return self._client + + def register_udf( + self, function: Any, return_type: Union[str, pyspark.sql.types.DataType] + ) -> str: + return self._client.register_udf(function, return_type) + + def sql(self, sql_string: str) -> "DataFrame": + return DataFrame.withPlan(SQL(sql_string), self) + + def range( + self, + start: int, + end: int, + step: int = 1, + numPartitions: Optional[int] = None, + ) -> DataFrame: + """ + Create a :class:`DataFrame` with column named ``id`` and typed Long, + containing elements in a range from ``start`` to ``end`` (exclusive) with + step value ``step``. + + .. 
versionadded:: 3.4.0 + + Parameters + ---------- + start : int + the start value + end : int + the end value (exclusive) + step : int, optional + the incremental step (default: 1) + numPartitions : int, optional + the number of partitions of the DataFrame + + Returns + ------- + :class:`DataFrame` + """ + return DataFrame.withPlan( + Range(start=start, end=end, step=step, num_partitions=numPartitions), self + ) diff --git a/python/pyspark/sql/tests/connect/test_connect_basic.py b/python/pyspark/sql/tests/connect/test_connect_basic.py index 845d6ead567e9..150bbdb65ef15 100644 --- a/python/pyspark/sql/tests/connect/test_connect_basic.py +++ b/python/pyspark/sql/tests/connect/test_connect_basic.py @@ -30,7 +30,8 @@ from pyspark.sql.types import StructType, StructField, LongType, StringType if have_pandas: - from pyspark.sql.connect.client import RemoteSparkSession, ChannelBuilder + from pyspark.sql.connect.session import SparkSession as RemoteSparkSession + from pyspark.sql.connect.client import ChannelBuilder from pyspark.sql.connect.function_builder import udf from pyspark.sql.connect.functions import lit, col from pyspark.sql.dataframe import DataFrame @@ -79,7 +80,7 @@ def tearDownClass(cls: Any) -> None: @classmethod def spark_connect_load_test_data(cls: Any): # Setup Remote Spark Session - cls.connect = RemoteSparkSession(userId="test_user") + cls.connect = RemoteSparkSession.builder.remote().getOrCreate() df = cls.spark.createDataFrame([(x, f"{x}") for x in range(100)], ["id", "name"]) # Since we might create multiple Spark sessions, we need to create global temporary view # that is specifically maintained in the "global_temp" schema. diff --git a/python/pyspark/sql/tests/connect/test_connect_column_expressions.py b/python/pyspark/sql/tests/connect/test_connect_column_expressions.py index 99b63482a2438..03966bd28df6d 100644 --- a/python/pyspark/sql/tests/connect/test_connect_column_expressions.py +++ b/python/pyspark/sql/tests/connect/test_connect_column_expressions.py @@ -140,7 +140,7 @@ def test_column_alias(self) -> None: self.assertEqual("Alias(Column(a), (martin))", str(col0)) col0 = fun.col("a").alias("martin", metadata={"pii": True}) - plan = col0.to_plan(self.session) + plan = col0.to_plan(self.session.client) self.assertIsNotNone(plan) self.assertEqual(plan.alias.metadata, '{"pii": true}') diff --git a/python/pyspark/testing/connectutils.py b/python/pyspark/testing/connectutils.py index f98a67b9964b4..feca9e9f82559 100644 --- a/python/pyspark/testing/connectutils.py +++ b/python/pyspark/testing/connectutils.py @@ -26,7 +26,7 @@ from pyspark.sql.connect.plan import Read, Range, SQL from pyspark.testing.utils import search_jar from pyspark.sql.connect.plan import LogicalPlan - from pyspark.sql.connect.client import RemoteSparkSession + from pyspark.sql.connect.session import SparkSession connect_jar = search_jar("connector/connect", "spark-connect-assembly-", "spark-connect") else: @@ -69,7 +69,7 @@ def __getattr__(self, item: str) -> Any: class PlanOnlyTestFixture(unittest.TestCase): connect: "MockRemoteSession" - session: RemoteSparkSession + session: SparkSession @classmethod def _read_table(cls, table_name: str) -> "DataFrame": @@ -102,7 +102,7 @@ def _with_plan(cls, plan: LogicalPlan) -> "DataFrame": @classmethod def setUpClass(cls: Any) -> None: cls.connect = MockRemoteSession() - cls.session = RemoteSparkSession() + cls.session = SparkSession.builder.remote().getOrCreate() cls.tbl_name = "test_connect_plan_only_table_1" cls.connect.set_hook("register_udf", cls._udf_mock)
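[Reviewer note] Putting the pieces together, the user-facing entry point after this patch looks roughly like the sketch below. It is illustrative only: it assumes a Spark Connect server is reachable at `sc://localhost` (the default mentioned in the builder and client docstrings) and that the grpcio, pandas, and pyarrow dependencies are installed.

```python
# Usage sketch based on the APIs added/renamed in this diff; assumes a running
# Spark Connect server at sc://localhost.
from pyspark.sql.connect.session import SparkSession

# remote() records "spark.connect.location"; getOrCreate() builds the session,
# which in turn constructs the SparkConnectClient from that connection string.
spark = SparkSession.builder.remote("sc://localhost").getOrCreate()

# sql() and range() moved from the old RemoteSparkSession onto SparkSession.
df = spark.range(start=0, end=10, step=2)
print(df.toPandas())

# The gRPC-facing client stays reachable, mainly for tests, via the new property.
print(type(spark.client).__name__)  # SparkConnectClient
```

`spark.client` is exposed deliberately (see the new `client` property) so tests such as test_connect_column_expressions.py can hand the raw `SparkConnectClient` to `Expression.to_plan`.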