[SPARK-42907][CONNECT][PYTHON] Implement Avro functions
### What changes were proposed in this pull request?
Implement the Avro functions `from_avro` and `to_avro` for the Spark Connect Python client.

### Why are the changes needed?
For function parity with vanilla PySpark.

### Does this PR introduce _any_ user-facing change?
Yes, new APIs: the Avro functions `from_avro` and `to_avro` are now available under Spark Connect.

### How was this patch tested?
Added doctests and checked manually:
```
(spark_dev) ➜  spark git:(connect_avro_functions) ✗ bin/pyspark --remote "local[*]" --jars connector/avro/target/scala-2.12/spark-avro_2.12-3.5.0-SNAPSHOT.jar
Python 3.9.16 (main, Mar  8 2023, 04:29:24)
Type 'copyright', 'credits' or 'license' for more information
IPython 8.11.0 -- An enhanced Interactive Python. Type '?' for help.
23/03/23 16:28:50 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable
Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
Welcome to
      ____              __
     / __/__  ___ _____/ /__
    _\ \/ _ \/ _ `/ __/  '_/
   /___/ .__/\_,_/_/ /_/\_\   version 3.5.0.dev0
      /_/

Using Python version 3.9.16 (main, Mar  8 2023 04:29:24)
Client connected to the Spark Connect server at localhost
SparkSession available as 'spark'.

In [1]:     >>> from pyspark.sql import Row
   ...:     >>> from pyspark.sql.avro.functions import from_avro, to_avro
   ...:     >>> data = [(1, Row(age=2, name='Alice'))]
   ...:     >>> df = spark.createDataFrame(data, ("key", "value"))
   ...:     >>> avroDf = df.select(to_avro(df.value).alias("avro"))

In [2]: avroDf.collect()
Out[2]: [Row(avro=bytearray(b'\x00\x00\x04\x00\nAlice'))]
```
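For reference, the binary produced above also round-trips through `from_avro`. A minimal sketch in the same session, hedged: the `jsonFormatSchema` string is an illustrative assumption written to match the derived Avro schema of `df.value`, not something taken from this PR.

```python
# Hedged sketch: decode the Avro binary produced by to_avro above.
# The schema is an assumption matching struct<age:bigint,name:string>.
jsonFormatSchema = '''
{"type": "record", "name": "value", "fields": [
    {"name": "age", "type": ["long", "null"]},
    {"name": "name", "type": ["string", "null"]}]}
'''
avroDf.select(from_avro("avro", jsonFormatSchema).alias("value")).collect()
# expected: [Row(value=Row(age=2, name='Alice'))]
```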

Closes apache#40535 from zhengruifeng/connect_avro_functions.

Authored-by: Ruifeng Zheng <[email protected]>
Signed-off-by: Hyukjin Kwon <[email protected]>
zhengruifeng authored and HyukjinKwon committed Mar 29, 2023
1 parent aacac46 commit 5a56c17
Showing 10 changed files with 208 additions and 4 deletions.
6 changes: 6 additions & 0 deletions assembly/pom.xml
```diff
@@ -160,6 +160,12 @@
         <artifactId>spark-connect_${scala.binary.version}</artifactId>
         <version>${project.version}</version>
       </dependency>
+      <dependency>
+        <groupId>org.apache.spark</groupId>
+        <artifactId>spark-avro_${scala.binary.version}</artifactId>
+        <version>${project.version}</version>
+        <scope>provided</scope>
+      </dependency>
     </dependencies>
   </profile>
   <profile>
```
2 changes: 1 addition & 1 deletion connector/avro/src/main/scala/org/apache/spark/sql/avro/AvroDataToCatalyst.scala
```diff
@@ -30,7 +30,7 @@
 import org.apache.spark.sql.catalyst.util.{FailFastMode, ParseMode, PermissiveMode}
 import org.apache.spark.sql.types._
 
-private[avro] case class AvroDataToCatalyst(
+private[sql] case class AvroDataToCatalyst(
     child: Expression,
     jsonFormatSchema: String,
     options: Map[String, String])
```
2 changes: 1 addition & 1 deletion connector/avro/src/main/scala/org/apache/spark/sql/avro/CatalystDataToAvro.scala
```diff
@@ -27,7 +27,7 @@ import org.apache.spark.sql.catalyst.expressions.{Expression, UnaryExpression}
 import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
 import org.apache.spark.sql.types.{BinaryType, DataType}
 
-private[avro] case class CatalystDataToAvro(
+private[sql] case class CatalystDataToAvro(
     child: Expression,
     jsonFormatSchema: Option[String]) extends UnaryExpression {
```
6 changes: 6 additions & 0 deletions connector/connect/server/pom.xml
```diff
@@ -105,6 +105,12 @@
         </exclusion>
       </exclusions>
     </dependency>
+    <dependency>
+      <groupId>org.apache.spark</groupId>
+      <artifactId>spark-avro_${scala.binary.version}</artifactId>
+      <version>${project.version}</version>
+      <scope>provided</scope>
+    </dependency>
    <dependency>
      <groupId>org.apache.spark</groupId>
      <artifactId>spark-catalyst_${scala.binary.version}</artifactId>
```
35 changes: 35 additions & 0 deletions connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
```diff
@@ -32,6 +32,7 @@ import org.apache.spark.connect.proto.ExecutePlanResponse.SqlCommandResult
 import org.apache.spark.connect.proto.Parse.ParseFormat
 import org.apache.spark.ml.{functions => MLFunctions}
 import org.apache.spark.sql.{Column, Dataset, Encoders, SparkSession}
+import org.apache.spark.sql.avro.{AvroDataToCatalyst, CatalystDataToAvro}
 import org.apache.spark.sql.catalyst.{expressions, AliasIdentifier, FunctionIdentifier}
 import org.apache.spark.sql.catalyst.analysis.{GlobalTempView, LocalTempView, MultiAlias, ParameterizedQuery, UnresolvedAlias, UnresolvedAttribute, UnresolvedExtractValue, UnresolvedFunction, UnresolvedRegex, UnresolvedRelation, UnresolvedStar}
 import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
@@ -1256,6 +1257,40 @@ class SparkConnectPlanner(val session: SparkSession) {
         None
       }
 
+      // Avro-specific functions
+      case "from_avro" if Seq(2, 3).contains(fun.getArgumentsCount) =>
+        val children = fun.getArgumentsList.asScala.toSeq.map(transformExpression)
+        val jsonFormatSchema = children(1) match {
+          case Literal(s, StringType) if s != null => s.toString
+          case other =>
+            throw InvalidPlanInput(
+              s"jsonFormatSchema in from_avro should be a literal string, but got $other")
+        }
+        var options = Map.empty[String, String]
+        if (fun.getArgumentsCount == 3) {
+          children(2) match {
+            case UnresolvedFunction(Seq("map"), arguments, _, _, _) =>
+              options = ExprUtils.convertToMapData(CreateMap(arguments))
+            case other =>
+              throw InvalidPlanInput(
+                s"Options in from_avro should be created by map, but got $other")
+          }
+        }
+        Some(AvroDataToCatalyst(children.head, jsonFormatSchema, options))
+
+      case "to_avro" if Seq(1, 2).contains(fun.getArgumentsCount) =>
+        val children = fun.getArgumentsList.asScala.toSeq.map(transformExpression)
+        var jsonFormatSchema = Option.empty[String]
+        if (fun.getArgumentsCount == 2) {
+          children(1) match {
+            case Literal(s, StringType) if s != null => jsonFormatSchema = Some(s.toString)
+            case other =>
+              throw InvalidPlanInput(
+                s"jsonFormatSchema in to_avro should be a literal string, but got $other")
+          }
+        }
+        Some(CatalystDataToAvro(children.head, jsonFormatSchema))
+
       // PS(Pandas API on Spark)-specific functions
       case "distributed_sequence_id" if fun.getArgumentsCount == 0 =>
         Some(DistributedSequenceID())
```
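The three-argument `from_avro` case above only accepts options that arrive as an unresolved `map` call, which is what the client-side `_options_to_col` helper (used in `pyspark/sql/connect/avro/functions.py` below) emits. A hedged end-to-end sketch, reusing `avroDf` and `jsonFormatSchema` from the example in the PR description; the options dict is illustrative (`mode` is an Avro datasource option, per the `ParseMode` imports in `AvroDataToCatalyst`):

```python
# Hedged sketch: exercising the 3-argument from_avro path from Python.
# avroDf and jsonFormatSchema are from the PR description's example;
# the options dict is illustrative and reaches AvroDataToCatalyst as a map.
from pyspark.sql.avro.functions import from_avro

decoded = avroDf.select(
    from_avro("avro", jsonFormatSchema, {"mode": "PERMISSIVE"}).alias("value")
)
decoded.collect()
```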
3 changes: 2 additions & 1 deletion dev/sparktestsupport/modules.py
```diff
@@ -273,7 +273,7 @@ def __hash__(self):
 
 connect = Module(
     name="connect",
-    dependencies=[hive],
+    dependencies=[hive, avro],
     source_file_regexes=[
         "connector/connect",
     ],
@@ -748,6 +748,7 @@ def __hash__(self):
         "pyspark.sql.connect.readwriter",
         "pyspark.sql.connect.dataframe",
         "pyspark.sql.connect.functions",
+        "pyspark.sql.connect.avro.functions",
         # sql unittests
         "pyspark.sql.tests.connect.test_client",
         "pyspark.sql.tests.connect.test_connect_plan",
```
10 changes: 9 additions & 1 deletion python/pyspark/sql/avro/functions.py
```diff
@@ -25,13 +25,14 @@
 from py4j.java_gateway import JVMView
 
 from pyspark.sql.column import Column, _to_java_column
-from pyspark.sql.utils import get_active_spark_context
+from pyspark.sql.utils import get_active_spark_context, try_remote_avro_functions
 from pyspark.util import _print_missing_jar
 
 if TYPE_CHECKING:
     from pyspark.sql._typing import ColumnOrName
 
 
+@try_remote_avro_functions
 def from_avro(
     data: "ColumnOrName", jsonFormatSchema: str, options: Optional[Dict[str, str]] = None
 ) -> Column:
@@ -44,6 +45,9 @@ def from_avro(
 
     .. versionadded:: 3.0.0
 
+    .. versionchanged:: 3.5.0
+        Supports Spark Connect.
+
     Parameters
     ----------
     data : :class:`~pyspark.sql.Column` or str
@@ -88,12 +92,16 @@ def from_avro(
     return Column(jc)
 
 
+@try_remote_avro_functions
 def to_avro(data: "ColumnOrName", jsonFormatSchema: str = "") -> Column:
     """
     Converts a column into binary of avro format.
 
     .. versionadded:: 3.0.0
 
+    .. versionchanged:: 3.5.0
+        Supports Spark Connect.
+
     Parameters
     ----------
     data : :class:`~pyspark.sql.Column` or str
```
18 changes: 18 additions & 0 deletions python/pyspark/sql/connect/avro/__init__.py
```diff
@@ -0,0 +1,18 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+"""Spark Connect Python Client - Avro Functions"""
```
114 changes: 114 additions & 0 deletions python/pyspark/sql/connect/avro/functions.py
```diff
@@ -0,0 +1,114 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+"""
+A collection of builtin Avro functions
+"""
+
+from pyspark.sql.connect.utils import check_dependencies
+
+check_dependencies(__name__)
+
+from typing import Dict, Optional, TYPE_CHECKING
+
+from pyspark.sql.avro import functions as PyAvroFunctions
+
+from pyspark.sql.connect.column import Column
+from pyspark.sql.connect.functions import _invoke_function, _to_col, _options_to_col, lit
+
+if TYPE_CHECKING:
+    from pyspark.sql.connect._typing import ColumnOrName
+
+
+def from_avro(
+    data: "ColumnOrName", jsonFormatSchema: str, options: Optional[Dict[str, str]] = None
+) -> Column:
+    if options is None:
+        return _invoke_function("from_avro", _to_col(data), lit(jsonFormatSchema))
+    else:
+        return _invoke_function(
+            "from_avro", _to_col(data), lit(jsonFormatSchema), _options_to_col(options)
+        )
+
+
+from_avro.__doc__ = PyAvroFunctions.from_avro.__doc__
+
+
+def to_avro(data: "ColumnOrName", jsonFormatSchema: str = "") -> Column:
+    if jsonFormatSchema == "":
+        return _invoke_function("to_avro", _to_col(data))
+    else:
+        return _invoke_function("to_avro", _to_col(data), lit(jsonFormatSchema))
+
+
+to_avro.__doc__ = PyAvroFunctions.to_avro.__doc__
+
+
+def _test() -> None:
+    import os
+    import sys
+    from pyspark.testing.utils import search_jar
+
+    avro_jar = search_jar("connector/avro", "spark-avro", "spark-avro")
+
+    if avro_jar is None:
+        print(
+            "Skipping all Avro Python tests as the optional Avro project was "
+            "not compiled into a JAR. To run these tests, "
+            "you need to build Spark with 'build/sbt -Pavro package' or "
+            "'build/mvn -Pavro package' before running this test."
+        )
+        sys.exit(0)
+    else:
+        existing_args = os.environ.get("PYSPARK_SUBMIT_ARGS", "pyspark-shell")
+        jars_args = "--jars %s" % avro_jar
+        os.environ["PYSPARK_SUBMIT_ARGS"] = " ".join([jars_args, existing_args])
+
+    import doctest
+    from pyspark.sql import SparkSession as PySparkSession
+    import pyspark.sql.connect.avro.functions
+
+    globs = pyspark.sql.connect.avro.functions.__dict__.copy()
+
+    globs["spark"] = (
+        PySparkSession.builder.appName("sql.connect.avro.functions tests")
+        .remote("local[4]")
+        .getOrCreate()
+    )
+
+    (failure_count, test_count) = doctest.testmod(
+        pyspark.sql.connect.avro.functions,
+        globs=globs,
+        optionflags=doctest.ELLIPSIS
+        | doctest.NORMALIZE_WHITESPACE
+        | doctest.IGNORE_EXCEPTION_DETAIL,
+    )
+
+    globs["spark"].stop()
+
+    if failure_count:
+        sys.exit(-1)
+
+
+if __name__ == "__main__":
+    _test()
```
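The doctest in the PR description only exercises the one-argument `to_avro`; the two-argument branch (explicit `jsonFormatSchema`) can be checked with a sketch like the following, hedged: the enum schema and expected bytes mirror the upstream `to_avro` doctest and assume an active `spark` session.

```python
# Hedged sketch: to_avro with an explicit jsonFormatSchema (2-argument path).
from pyspark.sql.avro.functions import to_avro

jsonFormatSchema = '''
{"type": "enum", "name": "Suite",
 "symbols": ["SPADES", "HEARTS", "DIAMONDS", "CLUBS"]}
'''
df2 = spark.createDataFrame(["SPADES"], "string")  # single column named "value"
df2.select(to_avro(df2.value, jsonFormatSchema).alias("suite")).collect()
# expected: [Row(suite=bytearray(b'\x00'))]  (SPADES is symbol index 0)
```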
16 changes: 16 additions & 0 deletions python/pyspark/sql/utils.py
```diff
@@ -161,6 +161,22 @@ def wrapped(*args: Any, **kwargs: Any) -> Any:
     return cast(FuncT, wrapped)
 
 
+def try_remote_avro_functions(f: FuncT) -> FuncT:
+    """Mark API supported from Spark Connect."""
+
+    @functools.wraps(f)
+    def wrapped(*args: Any, **kwargs: Any) -> Any:
+
+        if is_remote() and "PYSPARK_NO_NAMESPACE_SHARE" not in os.environ:
+            from pyspark.sql.connect.avro import functions
+
+            return getattr(functions, f.__name__)(*args, **kwargs)
+        else:
+            return f(*args, **kwargs)
+
+    return cast(FuncT, wrapped)
+
+
 def try_remote_window(f: FuncT) -> FuncT:
     """Mark API supported from Spark Connect."""
```
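Stripped of the PySpark specifics, `try_remote_avro_functions` follows a simple dispatch pattern; a minimal self-contained sketch, with generic stand-in names (`is_remote`, `CONNECT_IMPLS`) rather than Spark's internals:

```python
# Minimal sketch of the classic-vs-Connect dispatch used above.
# is_remote() and CONNECT_IMPLS are illustrative stand-ins, not Spark APIs.
import functools
from typing import Any, Callable, Dict

CONNECT_IMPLS: Dict[str, Callable[..., Any]] = {}

def is_remote() -> bool:
    return False  # stand-in: True when a Spark Connect session is active

def try_remote(f: Callable[..., Any]) -> Callable[..., Any]:
    @functools.wraps(f)
    def wrapped(*args: Any, **kwargs: Any) -> Any:
        if is_remote():
            # Route to the Connect implementation registered under the same name.
            return CONNECT_IMPLS[f.__name__](*args, **kwargs)
        return f(*args, **kwargs)
    return wrapped
```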
