Skip to content

Commit

Permalink
[SPARK-50238][PYTHON] Add Variant Support in PySpark UDFs/UDTFs/UDAFs
Browse files Browse the repository at this point in the history
### What changes were proposed in this pull request?

This PR adds support for the Variant type in PySpark UDFs/UDTFs/UDAFs. Support is added in both modes - arrow and pickle - and support is also added in pandas UDFs.

### Why are the changes needed?

After this change, users will be able to use the new Variant data type with UDFs, which is currently prohibited.

### Does this PR introduce _any_ user-facing change?

Yes, users should now be able to use Variants with Python UDFs.

### How was this patch tested?

Unit tests in all scenarios - arrow, pickle and pandas

### Was this patch authored or co-authored using generative AI tooling?

No.

Closes #48770 from harshmotw-db/harsh-motwani_data/variant_udf_3.

Authored-by: Harsh Motwani <[email protected]>
Signed-off-by: Hyukjin Kwon <[email protected]>
  • Loading branch information
harshmotw-db authored and HyukjinKwon committed Nov 13, 2024
1 parent 4f4eb22 commit 4002a53
Show file tree
Hide file tree
Showing 16 changed files with 595 additions and 279 deletions.
10 changes: 0 additions & 10 deletions common/utils/src/main/resources/error/error-conditions.json
Original file line number Diff line number Diff line change
Expand Up @@ -1045,16 +1045,6 @@
"The input of <functionName> can't be <dataType> type data."
]
},
"UNSUPPORTED_UDF_INPUT_TYPE" : {
"message" : [
"UDFs do not support '<dataType>' as an input data type."
]
},
"UNSUPPORTED_UDF_OUTPUT_TYPE" : {
"message" : [
"UDFs do not support '<dataType>' as an output data type."
]
},
"VALUE_OUT_OF_RANGE" : {
"message" : [
"The <exprName> must be between <valueRange> (current value = <currentValue>)."
Expand Down
20 changes: 18 additions & 2 deletions python/pyspark/sql/pandas/serializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
)
from pyspark.sql.pandas.types import (
from_arrow_type,
is_variant,
to_arrow_type,
_create_converter_from_pandas,
_create_converter_to_pandas,
Expand Down Expand Up @@ -420,7 +421,14 @@ def __init__(
def arrow_to_pandas(self, arrow_column):
import pyarrow.types as types

if self._df_for_struct and types.is_struct(arrow_column.type):
# If the arrow type is struct, return a pandas dataframe where the fields of the struct
# correspond to columns in the DataFrame. However, if the arrow struct is actually a
# Variant, which is an atomic type, treat it as a non-struct arrow type.
if (
self._df_for_struct
and types.is_struct(arrow_column.type)
and not is_variant(arrow_column.type)
):
import pandas as pd

series = [
Expand Down Expand Up @@ -505,7 +513,15 @@ def _create_batch(self, series):

arrs = []
for s, t in series:
if self._struct_in_pandas == "dict" and t is not None and pa.types.is_struct(t):
# Variants are represented in arrow as structs with additional metadata (checked by
# is_variant). If the data type is Variant, return a VariantVal atomic type instead of
# a dict of two binary values.
if (
self._struct_in_pandas == "dict"
and t is not None
and pa.types.is_struct(t)
and not is_variant(t)
):
# A pandas UDF should return pd.DataFrame when the return type is a struct type.
# If it returns a pd.Series, it should throw an error.
if not isinstance(s, pd.DataFrame):
Expand Down
30 changes: 29 additions & 1 deletion python/pyspark/sql/pandas/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,9 @@ def to_arrow_type(
elif type(dt) == VariantType:
fields = [
pa.field("value", pa.binary(), nullable=False),
pa.field("metadata", pa.binary(), nullable=False),
# The metadata field is tagged so we can identify that the arrow struct actually
# represents a variant.
pa.field("metadata", pa.binary(), nullable=False, metadata={b"variant": b"true"}),
]
arrow_type = pa.struct(fields)
else:
Expand Down Expand Up @@ -221,6 +223,22 @@ def to_arrow_schema(
return pa.schema(fields)


def is_variant(at: "pa.DataType") -> bool:
"""Check if a PyArrow struct data type represents a variant"""
import pyarrow.types as types

assert types.is_struct(at)

return any(
(
field.name == "metadata"
and b"variant" in field.metadata
and field.metadata[b"variant"] == b"true"
)
for field in at
) and any(field.name == "value" for field in at)


def from_arrow_type(at: "pa.DataType", prefer_timestamp_ntz: bool = False) -> DataType:
"""Convert pyarrow type to Spark data type."""
import pyarrow.types as types
Expand Down Expand Up @@ -280,6 +298,8 @@ def from_arrow_type(at: "pa.DataType", prefer_timestamp_ntz: bool = False) -> Da
from_arrow_type(at.item_type, prefer_timestamp_ntz),
)
elif types.is_struct(at):
if is_variant(at):
return VariantType()
return StructType(
[
StructField(
Expand Down Expand Up @@ -1295,6 +1315,14 @@ def convert_udt(value: Any) -> Any:

return convert_udt

elif isinstance(dt, VariantType):

def convert_variant(variant: Any) -> Any:
assert isinstance(variant, VariantVal)
return {"value": variant.value, "metadata": variant.metadata}

return convert_variant

return None

conv = _converter(data_type)
Expand Down
41 changes: 40 additions & 1 deletion python/pyspark/sql/tests/pandas/test_pandas_udf.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,14 @@
from typing import cast

from pyspark.sql.functions import udf, pandas_udf, PandasUDFType, assert_true, lit
from pyspark.sql.types import DoubleType, StructType, StructField, LongType, DayTimeIntervalType
from pyspark.sql.types import (
DoubleType,
StructType,
StructField,
LongType,
DayTimeIntervalType,
VariantType,
)
from pyspark.errors import ParseException, PythonException, PySparkTypeError
from pyspark.util import PythonEvalType
from pyspark.testing.sqlutils import (
Expand All @@ -42,33 +49,65 @@ def test_pandas_udf_basic(self):
self.assertEqual(udf.returnType, DoubleType())
self.assertEqual(udf.evalType, PythonEvalType.SQL_SCALAR_PANDAS_UDF)

udf = pandas_udf(lambda x: x, VariantType())
self.assertEqual(udf.returnType, VariantType())
self.assertEqual(udf.evalType, PythonEvalType.SQL_SCALAR_PANDAS_UDF)

udf = pandas_udf(lambda x: x, DoubleType(), PandasUDFType.SCALAR)
self.assertEqual(udf.returnType, DoubleType())
self.assertEqual(udf.evalType, PythonEvalType.SQL_SCALAR_PANDAS_UDF)

udf = pandas_udf(lambda x: x, VariantType(), PandasUDFType.SCALAR)
self.assertEqual(udf.returnType, VariantType())
self.assertEqual(udf.evalType, PythonEvalType.SQL_SCALAR_PANDAS_UDF)

udf = pandas_udf(
lambda x: x, StructType([StructField("v", DoubleType())]), PandasUDFType.GROUPED_MAP
)
self.assertEqual(udf.returnType, StructType([StructField("v", DoubleType())]))
self.assertEqual(udf.evalType, PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF)

udf = pandas_udf(
lambda x: x, StructType([StructField("v", VariantType())]), PandasUDFType.GROUPED_MAP
)
self.assertEqual(udf.returnType, StructType([StructField("v", VariantType())]))
self.assertEqual(udf.evalType, PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF)

def test_pandas_udf_basic_with_return_type_string(self):
udf = pandas_udf(lambda x: x, "double", PandasUDFType.SCALAR)
self.assertEqual(udf.returnType, DoubleType())
self.assertEqual(udf.evalType, PythonEvalType.SQL_SCALAR_PANDAS_UDF)

udf = pandas_udf(lambda x: x, "variant", PandasUDFType.SCALAR)
self.assertEqual(udf.returnType, VariantType())
self.assertEqual(udf.evalType, PythonEvalType.SQL_SCALAR_PANDAS_UDF)

udf = pandas_udf(lambda x: x, "v double", PandasUDFType.GROUPED_MAP)
self.assertEqual(udf.returnType, StructType([StructField("v", DoubleType())]))
self.assertEqual(udf.evalType, PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF)

udf = pandas_udf(lambda x: x, "v variant", PandasUDFType.GROUPED_MAP)
self.assertEqual(udf.returnType, StructType([StructField("v", VariantType())]))
self.assertEqual(udf.evalType, PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF)

udf = pandas_udf(lambda x: x, "v double", functionType=PandasUDFType.GROUPED_MAP)
self.assertEqual(udf.returnType, StructType([StructField("v", DoubleType())]))
self.assertEqual(udf.evalType, PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF)

udf = pandas_udf(lambda x: x, "v variant", functionType=PandasUDFType.GROUPED_MAP)
self.assertEqual(udf.returnType, StructType([StructField("v", VariantType())]))
self.assertEqual(udf.evalType, PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF)

udf = pandas_udf(lambda x: x, returnType="v double", functionType=PandasUDFType.GROUPED_MAP)
self.assertEqual(udf.returnType, StructType([StructField("v", DoubleType())]))
self.assertEqual(udf.evalType, PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF)

udf = pandas_udf(
lambda x: x, returnType="v variant", functionType=PandasUDFType.GROUPED_MAP
)
self.assertEqual(udf.returnType, StructType([StructField("v", VariantType())]))
self.assertEqual(udf.evalType, PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF)

def test_pandas_udf_decorator(self):
@pandas_udf(DoubleType())
def foo(x):
Expand Down
Loading

0 comments on commit 4002a53

Please sign in to comment.