From 5a56c17283103821714ffaaf1c764e05d0ff6b58 Mon Sep 17 00:00:00 2001
From: Ruifeng Zheng
Date: Wed, 29 Mar 2023 09:58:52 +0900
Subject: [PATCH] [SPARK-42907][CONNECT][PYTHON] Implement Avro functions
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

### What changes were proposed in this pull request?
Implement the Avro functions (`from_avro`, `to_avro`) for the Spark Connect Python client.

### Why are the changes needed?
For function parity with the non-Connect PySpark API.

### Does this PR introduce _any_ user-facing change?
Yes, new APIs.

### How was this patch tested?
Added doctests and checked manually:

```
(spark_dev) ➜  spark git:(connect_avro_functions) ✗ bin/pyspark --remote "local[*]" --jars connector/avro/target/scala-2.12/spark-avro_2.12-3.5.0-SNAPSHOT.jar
Python 3.9.16 (main, Mar  8 2023, 04:29:24)
Type 'copyright', 'credits' or 'license' for more information
IPython 8.11.0 -- An enhanced Interactive Python. Type '?' for help.
23/03/23 16:28:50 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable
Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
Welcome to
      ____              __
     / __/__  ___ _____/ /__
    _\ \/ _ \/ _ `/ __/  '_/
   /__ / .__/\_,_/_/ /_/\_\   version 3.5.0.dev0
      /_/

Using Python version 3.9.16 (main, Mar  8 2023 04:29:24)
Client connected to the Spark Connect server at localhost
SparkSession available as 'spark'.

In [1]: >>> from pyspark.sql import Row
   ...: >>> from pyspark.sql.avro.functions import from_avro, to_avro
   ...: >>> data = [(1, Row(age=2, name='Alice'))]
   ...: >>> df = spark.createDataFrame(data, ("key", "value"))
   ...: >>> avroDf = df.select(to_avro(df.value).alias("avro"))

In [2]: avroDf.collect()
Out[2]: [Row(avro=bytearray(b'\x00\x00\x04\x00\nAlice'))]
```

Closes #40535 from zhengruifeng/connect_avro_functions.
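For completeness (not part of the patch itself), the bytes produced above can be decoded back with `from_avro`. A sketch that continues the same session; the reader schema below is illustrative and must be resolvable against the writer schema that `to_avro(df.value)` generated (which may, for example, wrap fields in nullable unions):

```
>>> from pyspark.sql.avro.functions import from_avro
>>> jsonFormatSchema = '''{"type": "record", "name": "value", "fields": [
...     {"name": "age", "type": "long"}, {"name": "name", "type": "string"}]}'''
>>> avroDf.select(from_avro(avroDf.avro, jsonFormatSchema).alias("value")).show()
```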
Authored-by: Ruifeng Zheng
Signed-off-by: Hyukjin Kwon
---
 assembly/pom.xml                              |   6 +
 .../spark/sql/avro/AvroDataToCatalyst.scala   |   2 +-
 .../spark/sql/avro/CatalystDataToAvro.scala   |   2 +-
 connector/connect/server/pom.xml              |   6 +
 .../connect/planner/SparkConnectPlanner.scala |  35 ++++++
 dev/sparktestsupport/modules.py               |   3 +-
 python/pyspark/sql/avro/functions.py          |  10 +-
 python/pyspark/sql/connect/avro/__init__.py   |  18 +++
 python/pyspark/sql/connect/avro/functions.py  | 114 ++++++++++++++++++
 python/pyspark/sql/utils.py                   |  16 +++
 10 files changed, 208 insertions(+), 4 deletions(-)
 create mode 100644 python/pyspark/sql/connect/avro/__init__.py
 create mode 100644 python/pyspark/sql/connect/avro/functions.py

diff --git a/assembly/pom.xml b/assembly/pom.xml
index 36cc607843831..09d6bd8a33f79 100644
--- a/assembly/pom.xml
+++ b/assembly/pom.xml
@@ -160,6 +160,12 @@
       <artifactId>spark-connect_${scala.binary.version}</artifactId>
       <version>${project.version}</version>
     </dependency>
+    <dependency>
+      <groupId>org.apache.spark</groupId>
+      <artifactId>spark-avro_${scala.binary.version}</artifactId>
+      <version>${project.version}</version>
+      <scope>provided</scope>
+    </dependency>
diff --git a/connector/avro/src/main/scala/org/apache/spark/sql/avro/AvroDataToCatalyst.scala b/connector/avro/src/main/scala/org/apache/spark/sql/avro/AvroDataToCatalyst.scala
index c4a4b16b05228..f8718edd97fdb 100644
--- a/connector/avro/src/main/scala/org/apache/spark/sql/avro/AvroDataToCatalyst.scala
+++ b/connector/avro/src/main/scala/org/apache/spark/sql/avro/AvroDataToCatalyst.scala
@@ -30,7 +30,7 @@ import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGe
 import org.apache.spark.sql.catalyst.util.{FailFastMode, ParseMode, PermissiveMode}
 import org.apache.spark.sql.types._
 
-private[avro] case class AvroDataToCatalyst(
+private[sql] case class AvroDataToCatalyst(
     child: Expression,
     jsonFormatSchema: String,
     options: Map[String, String])
diff --git a/connector/avro/src/main/scala/org/apache/spark/sql/avro/CatalystDataToAvro.scala b/connector/avro/src/main/scala/org/apache/spark/sql/avro/CatalystDataToAvro.scala
index 1e7e8600977e6..56ed117aef580 100644
--- a/connector/avro/src/main/scala/org/apache/spark/sql/avro/CatalystDataToAvro.scala
+++ b/connector/avro/src/main/scala/org/apache/spark/sql/avro/CatalystDataToAvro.scala
@@ -27,7 +27,7 @@ import org.apache.spark.sql.catalyst.expressions.{Expression, UnaryExpression}
 import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
 import org.apache.spark.sql.types.{BinaryType, DataType}
 
-private[avro] case class CatalystDataToAvro(
+private[sql] case class CatalystDataToAvro(
     child: Expression,
     jsonFormatSchema: Option[String]) extends UnaryExpression {
diff --git a/connector/connect/server/pom.xml b/connector/connect/server/pom.xml
index 838d7bf2bd38c..a62c420bcc0f4 100644
--- a/connector/connect/server/pom.xml
+++ b/connector/connect/server/pom.xml
@@ -105,6 +105,12 @@
     </dependency>
+    <dependency>
+      <groupId>org.apache.spark</groupId>
+      <artifactId>spark-avro_${scala.binary.version}</artifactId>
+      <version>${project.version}</version>
+      <scope>provided</scope>
+    </dependency>
     <dependency>
       <groupId>org.apache.spark</groupId>
       <artifactId>spark-catalyst_${scala.binary.version}</artifactId>
diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
index e7e88cab64378..d5baca9e17f80 100644
--- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
+++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
@@ -32,6 +32,7 @@ import org.apache.spark.connect.proto.ExecutePlanResponse.SqlCommandResult
 import org.apache.spark.connect.proto.Parse.ParseFormat
 import org.apache.spark.ml.{functions => MLFunctions}
 import org.apache.spark.sql.{Column, Dataset, Encoders, SparkSession}
+import org.apache.spark.sql.avro.{AvroDataToCatalyst, CatalystDataToAvro}
 import org.apache.spark.sql.catalyst.{expressions, AliasIdentifier, FunctionIdentifier}
 import org.apache.spark.sql.catalyst.analysis.{GlobalTempView, LocalTempView, MultiAlias, ParameterizedQuery, UnresolvedAlias, UnresolvedAttribute, UnresolvedExtractValue, UnresolvedFunction, UnresolvedRegex, UnresolvedRelation, UnresolvedStar}
 import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
@@ -1256,6 +1257,40 @@ class SparkConnectPlanner(val session: SparkSession) {
           None
         }
 
+      // Avro-specific functions
+      case "from_avro" if Seq(2, 3).contains(fun.getArgumentsCount) =>
+        val children = fun.getArgumentsList.asScala.toSeq.map(transformExpression)
+        val jsonFormatSchema = children(1) match {
+          case Literal(s, StringType) if s != null => s.toString
+          case other =>
+            throw InvalidPlanInput(
+              s"jsonFormatSchema in from_avro should be a literal string, but got $other")
+        }
+        var options = Map.empty[String, String]
+        if (fun.getArgumentsCount == 3) {
+          children(2) match {
+            case UnresolvedFunction(Seq("map"), arguments, _, _, _) =>
+              options = ExprUtils.convertToMapData(CreateMap(arguments))
+            case other =>
+              throw InvalidPlanInput(
+                s"Options in from_avro should be created by map, but got $other")
+          }
+        }
+        Some(AvroDataToCatalyst(children.head, jsonFormatSchema, options))
+
+      case "to_avro" if Seq(1, 2).contains(fun.getArgumentsCount) =>
+        val children = fun.getArgumentsList.asScala.toSeq.map(transformExpression)
+        var jsonFormatSchema = Option.empty[String]
+        if (fun.getArgumentsCount == 2) {
+          children(1) match {
+            case Literal(s, StringType) if s != null => jsonFormatSchema = Some(s.toString)
+            case other =>
+              throw InvalidPlanInput(
+                s"jsonFormatSchema in to_avro should be a literal string, but got $other")
+          }
+        }
+        Some(CatalystDataToAvro(children.head, jsonFormatSchema))
+
       // PS(Pandas API on Spark)-specific functions
       case "distributed_sequence_id" if fun.getArgumentsCount == 0 =>
         Some(DistributedSequenceID())
diff --git a/dev/sparktestsupport/modules.py b/dev/sparktestsupport/modules.py
index 11257841bce59..f65ef7e3ac0c2 100644
--- a/dev/sparktestsupport/modules.py
+++ b/dev/sparktestsupport/modules.py
@@ -273,7 +273,7 @@ def __hash__(self):
 
 connect = Module(
     name="connect",
-    dependencies=[hive],
+    dependencies=[hive, avro],
     source_file_regexes=[
         "connector/connect",
     ],
@@ -748,6 +748,7 @@ def __hash__(self):
         "pyspark.sql.connect.readwriter",
         "pyspark.sql.connect.dataframe",
         "pyspark.sql.connect.functions",
+        "pyspark.sql.connect.avro.functions",
         # sql unittests
         "pyspark.sql.tests.connect.test_client",
         "pyspark.sql.tests.connect.test_connect_plan",
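The planner changes above resolve `from_avro` and `to_avro` on the server by matching the unresolved function name: `jsonFormatSchema` has to arrive as a string literal, and the optional options argument as a `map(...)` expression, otherwise `InvalidPlanInput` is raised. A minimal client-side sketch of a call that satisfies this contract (assumptions: a Spark Connect session is active and `df` is a DataFrame with an Avro-encoded binary column named `payload`, e.g. read from Kafka; the schema string is illustrative):

```
from pyspark.sql.avro.functions import from_avro, to_avro

# Illustrative Avro schema; any valid Avro JSON schema string works here.
schema = """
    {"type": "record", "name": "event", "fields": [
        {"name": "id", "type": "long"},
        {"name": "name", "type": "string"}]}
"""

# The schema is a plain Python string (sent to the server as a string literal)
# and the options are a plain dict (sent as a map(...) expression), which is
# exactly the shape the planner's from_avro case accepts.
decoded = df.select(from_avro("payload", schema, {"mode": "FAILFAST"}).alias("event"))

# to_avro's schema argument is optional; without one, the server derives an
# Avro schema from the column's Catalyst type.
reencoded = decoded.select(to_avro("event").alias("payload"))
```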
diff --git a/python/pyspark/sql/avro/functions.py b/python/pyspark/sql/avro/functions.py
index 080e45934e65d..e49953e8953b8 100644
--- a/python/pyspark/sql/avro/functions.py
+++ b/python/pyspark/sql/avro/functions.py
@@ -25,13 +25,14 @@
 from py4j.java_gateway import JVMView
 
 from pyspark.sql.column import Column, _to_java_column
-from pyspark.sql.utils import get_active_spark_context
+from pyspark.sql.utils import get_active_spark_context, try_remote_avro_functions
 from pyspark.util import _print_missing_jar
 
 if TYPE_CHECKING:
     from pyspark.sql._typing import ColumnOrName
 
 
+@try_remote_avro_functions
 def from_avro(
     data: "ColumnOrName", jsonFormatSchema: str, options: Optional[Dict[str, str]] = None
 ) -> Column:
@@ -44,6 +45,9 @@ def from_avro(
 
     .. versionadded:: 3.0.0
 
+    .. versionchanged:: 3.5.0
+        Supports Spark Connect.
+
     Parameters
     ----------
     data : :class:`~pyspark.sql.Column` or str
@@ -88,12 +92,16 @@ def from_avro(
     return Column(jc)
 
 
+@try_remote_avro_functions
 def to_avro(data: "ColumnOrName", jsonFormatSchema: str = "") -> Column:
     """
     Converts a column into binary of avro format.
 
     .. versionadded:: 3.0.0
 
+    .. versionchanged:: 3.5.0
+        Supports Spark Connect.
+
     Parameters
     ----------
     data : :class:`~pyspark.sql.Column` or str
diff --git a/python/pyspark/sql/connect/avro/__init__.py b/python/pyspark/sql/connect/avro/__init__.py
new file mode 100644
index 0000000000000..6d29d44cb9cfa
--- /dev/null
+++ b/python/pyspark/sql/connect/avro/__init__.py
@@ -0,0 +1,18 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+"""Spark Connect Python Client - Avro Functions"""
diff --git a/python/pyspark/sql/connect/avro/functions.py b/python/pyspark/sql/connect/avro/functions.py
new file mode 100644
index 0000000000000..acd7fa6305438
--- /dev/null
+++ b/python/pyspark/sql/connect/avro/functions.py
@@ -0,0 +1,114 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+"""
+A collection of builtin avro functions
+"""
+
+from pyspark.sql.connect.utils import check_dependencies
+
+check_dependencies(__name__)
+
+from typing import Dict, Optional, TYPE_CHECKING
+
+from pyspark.sql.avro import functions as PyAvroFunctions
+
+from pyspark.sql.connect.column import Column
+from pyspark.sql.connect.functions import _invoke_function, _to_col, _options_to_col, lit
+
+if TYPE_CHECKING:
+    from pyspark.sql.connect._typing import ColumnOrName
+
+
+def from_avro(
+    data: "ColumnOrName", jsonFormatSchema: str, options: Optional[Dict[str, str]] = None
+) -> Column:
+    if options is None:
+        return _invoke_function("from_avro", _to_col(data), lit(jsonFormatSchema))
+    else:
+        return _invoke_function(
+            "from_avro", _to_col(data), lit(jsonFormatSchema), _options_to_col(options)
+        )
+
+
+from_avro.__doc__ = PyAvroFunctions.from_avro.__doc__
+
+
+def to_avro(data: "ColumnOrName", jsonFormatSchema: str = "") -> Column:
+    if jsonFormatSchema == "":
+        return _invoke_function("to_avro", _to_col(data))
+    else:
+        return _invoke_function("to_avro", _to_col(data), lit(jsonFormatSchema))
+
+
+to_avro.__doc__ = PyAvroFunctions.to_avro.__doc__
+
+
+def _test() -> None:
+    import os
+    import sys
+    from pyspark.testing.utils import search_jar
+
+    avro_jar = search_jar("connector/avro", "spark-avro", "spark-avro")
+
+    print()
+    print(avro_jar)
+    print(avro_jar)
+    print(avro_jar)
+    print()
+
+    if avro_jar is None:
+        print(
+            "Skipping all Avro Python tests as the optional Avro project was "
+            "not compiled into a JAR. To run these tests, "
+            "you need to build Spark with 'build/sbt -Pavro package' or "
+            "'build/mvn -Pavro package' before running this test."
+        )
+        sys.exit(0)
+    else:
+        existing_args = os.environ.get("PYSPARK_SUBMIT_ARGS", "pyspark-shell")
+        jars_args = "--jars %s" % avro_jar
+        os.environ["PYSPARK_SUBMIT_ARGS"] = " ".join([jars_args, existing_args])
+
+    import doctest
+    from pyspark.sql import SparkSession as PySparkSession
+    import pyspark.sql.connect.avro.functions
+
+    globs = pyspark.sql.connect.avro.functions.__dict__.copy()
+
+    globs["spark"] = (
+        PySparkSession.builder.appName("sql.connect.avro.functions tests")
+        .remote("local[4]")
+        .getOrCreate()
+    )
+
+    (failure_count, test_count) = doctest.testmod(
+        pyspark.sql.connect.avro.functions,
+        globs=globs,
+        optionflags=doctest.ELLIPSIS
+        | doctest.NORMALIZE_WHITESPACE
+        | doctest.IGNORE_EXCEPTION_DETAIL,
+    )
+
+    globs["spark"].stop()
+
+    if failure_count:
+        sys.exit(-1)
+
+
+if __name__ == "__main__":
+    _test()
diff --git a/python/pyspark/sql/utils.py b/python/pyspark/sql/utils.py
index b5d17e38b8734..6f75325e0d8bf 100644
--- a/python/pyspark/sql/utils.py
+++ b/python/pyspark/sql/utils.py
@@ -161,6 +161,22 @@ def wrapped(*args: Any, **kwargs: Any) -> Any:
     return cast(FuncT, wrapped)
 
 
+def try_remote_avro_functions(f: FuncT) -> FuncT:
+    """Mark API supported from Spark Connect."""
+
+    @functools.wraps(f)
+    def wrapped(*args: Any, **kwargs: Any) -> Any:
+
+        if is_remote() and "PYSPARK_NO_NAMESPACE_SHARE" not in os.environ:
+            from pyspark.sql.connect.avro import functions
+
+            return getattr(functions, f.__name__)(*args, **kwargs)
+        else:
+            return f(*args, **kwargs)
+
+    return cast(FuncT, wrapped)
+
+
 def try_remote_window(f: FuncT) -> FuncT:
     """Mark API supported from Spark Connect."""
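The `try_remote_avro_functions` decorator added to `python/pyspark/sql/utils.py` is what wires the two implementations together: when running under Spark Connect (and `PYSPARK_NO_NAMESPACE_SHARE` is not set), a call to `pyspark.sql.avro.functions.from_avro` or `to_avro` is forwarded to the function of the same name in `pyspark.sql.connect.avro.functions`; otherwise the existing py4j-based path is used. A quick, hypothetical way to observe the dispatch from a `pyspark --remote` shell like the one in the PR description:

```
# Run inside `bin/pyspark --remote "local[*]" --jars
# connector/avro/target/scala-2.12/spark-avro_2.12-3.5.0-SNAPSHOT.jar`,
# as in the session shown above.
from pyspark.sql.avro.functions import to_avro

col = to_avro("value")

# Because is_remote() is true in that shell, the decorated function forwards
# to pyspark.sql.connect.avro.functions.to_avro and returns a Connect Column
# rather than a py4j-backed one.
print(type(col))  # expected: <class 'pyspark.sql.connect.column.Column'>
```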