diff --git a/connector/connect/common/src/main/protobuf/spark/connect/expressions.proto b/connector/connect/common/src/main/protobuf/spark/connect/expressions.proto
index 4a86e6f9d57bf..6224f79b4335d 100644
--- a/connector/connect/common/src/main/protobuf/spark/connect/expressions.proto
+++ b/connector/connect/common/src/main/protobuf/spark/connect/expressions.proto
@@ -38,6 +38,7 @@ message Expression {
     Cast cast = 7;
     UnresolvedRegex unresolved_regex = 8;
     SortOrder sort_order = 9;
+    LambdaFunction lambda_function = 10;
   }

   // SortOrder is used to specify the data ordering, it is normally used in Sort and Window.
@@ -191,4 +192,15 @@ message Expression {
     // (Optional) Alias metadata expressed as a JSON map.
     optional string metadata = 3;
   }
+
+  message LambdaFunction {
+    // (Required) The lambda function.
+    //
+    // The function body should use 'UnresolvedAttribute' as arguments; the server side will
+    // replace 'UnresolvedAttribute' with 'UnresolvedNamedLambdaVariable'.
+    Expression function = 1;
+
+    // (Required) Function variable names. Must contain 1 to 3 variables.
+    repeated string arguments = 2;
+  }
 }
diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
index cad6c3c5c613d..e281e2ea3f46a 100644
--- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
+++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
@@ -501,6 +501,8 @@ class SparkConnectPlanner(session: SparkSession) {
       case proto.Expression.ExprTypeCase.UNRESOLVED_REGEX =>
         transformUnresolvedRegex(exp.getUnresolvedRegex)
       case proto.Expression.ExprTypeCase.SORT_ORDER => transformSortOrder(exp.getSortOrder)
+      case proto.Expression.ExprTypeCase.LAMBDA_FUNCTION =>
+        transformLambdaFunction(exp.getLambdaFunction)
       case _ =>
         throw InvalidPlanInput(
           s"Expression with ID: ${exp.getExprTypeCase.getNumber} is not supported")
@@ -558,6 +560,38 @@ class SparkConnectPlanner(session: SparkSession) {
     }
   }

+  /**
+   * Translates a LambdaFunction from proto to the Catalyst expression.
+   */
+  private def transformLambdaFunction(lambda: proto.Expression.LambdaFunction): LambdaFunction = {
+    if (lambda.getArgumentsCount == 0 || lambda.getArgumentsCount > 3) {
+      throw InvalidPlanInput(
+        "LambdaFunction requires 1 to 3 arguments, " +
+          s"but got ${lambda.getArgumentsCount}.")
+    }
+
+    val variableNames = lambda.getArgumentsList.asScala.toSeq
+
+    // Generate unique variable names, e.g. Map(x -> x_0, y -> y_1).
+    val newVariables = variableNames.map { name =>
+      val uniqueName = UnresolvedNamedLambdaVariable.freshVarName(name)
+      (name, UnresolvedNamedLambdaVariable(Seq(uniqueName)))
+    }.toMap
+
+    val function = transformExpression(lambda.getFunction)
+
+    // Rewrite the function body by replacing each matching UnresolvedAttribute
+    // with its UnresolvedNamedLambdaVariable.
+    val newFunction = function transform {
+      case variable: UnresolvedAttribute
+          if variable.nameParts.length == 1 &&
+            newVariables.contains(variable.nameParts.head) =>
+        newVariables(variable.nameParts.head)
+    }
+
+    // e.g. LambdaFunction["x_0, y_1 -> x_0 < y_1", ["x_0", "y_1"]]
+    LambdaFunction(function = newFunction, arguments = variableNames.map(newVariables))
+  }
+
   /**
    * For some reason, not all functions are registered in 'FunctionRegistry'. For an unregistered
    * function, we can still wrap it under the proto 'UnresolvedFunction', and then resolve it in
diff --git a/python/pyspark/sql/connect/column.py b/python/pyspark/sql/connect/column.py
index 82c5d4c34630c..adce180eaa3b3 100644
--- a/python/pyspark/sql/connect/column.py
+++ b/python/pyspark/sql/connect/column.py
@@ -539,6 +539,35 @@ def __repr__(self) -> str:
         return f"({self._col} ({self._data_type}))"


+class LambdaFunction(Expression):
+    def __init__(
+        self,
+        function: Expression,
+        arguments: Sequence[str],
+    ) -> None:
+        super().__init__()
+
+        assert isinstance(function, Expression)
+
+        assert (
+            isinstance(arguments, list)
+            and len(arguments) > 0
+            and all(isinstance(arg, str) for arg in arguments)
+        )
+
+        self._function = function
+        self._arguments = arguments
+
+    def to_plan(self, session: "SparkConnectClient") -> proto.Expression:
+        fun = proto.Expression()
+        fun.lambda_function.function.CopyFrom(self._function.to_plan(session))
+        fun.lambda_function.arguments.extend(self._arguments)
+        return fun
+
+    def __repr__(self) -> str:
+        return f"(LambdaFunction({str(self._function)}, {', '.join(self._arguments)}))"
+
+
 class Column:
     """
     A column in a DataFrame. Column can refer to different things based on the
diff --git a/python/pyspark/sql/connect/functions.py b/python/pyspark/sql/connect/functions.py
index 8e3e7692df711..b6683a373e002 100644
--- a/python/pyspark/sql/connect/functions.py
+++ b/python/pyspark/sql/connect/functions.py
@@ -14,6 +14,9 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
+
+import inspect
+
 from pyspark.sql.connect.column import (
     Column,
     CaseWhen,
@@ -22,9 +25,10 @@
     ColumnReference,
     UnresolvedFunction,
     SQLExpression,
+    LambdaFunction,
 )

-from typing import Any, TYPE_CHECKING, Union, List, overload, Optional, Tuple
+from typing import Any, TYPE_CHECKING, Union, List, overload, Optional, Tuple, Callable, ValuesView

 if TYPE_CHECKING:
     from pyspark.sql.connect._typing import ColumnOrName
@@ -80,6 +84,78 @@ def _invoke_binary_math_function(name: str, col1: Any, col2: Any) -> Column:
     return _invoke_function(name, *_cols)


+def _get_lambda_parameters(f: Callable) -> ValuesView[inspect.Parameter]:
+    signature = inspect.signature(f)
+    parameters = signature.parameters.values()
+
+    # We should exclude functions that use variable args or keyword args,
+    # as well as keyword-only args.
+    supported_parameter_types = {
+        inspect.Parameter.POSITIONAL_OR_KEYWORD,
+        inspect.Parameter.POSITIONAL_ONLY,
+    }
+
+    # Validate that the function arity is between 1 and 3.
+    if not (1 <= len(parameters) <= 3):
+        raise ValueError(
+            "f should take between 1 and 3 arguments, but provided function takes {}".format(
+                len(parameters)
+            )
+        )
+
+    # Verify that all arguments can be used as positional arguments.
+    if not all(p.kind in supported_parameter_types for p in parameters):
+        raise ValueError("All arguments of f must be usable as POSITIONAL arguments")
+
+    return parameters
+
+
+def _create_lambda(f: Callable) -> LambdaFunction:
+    """
+    Create `o.a.s.sql.expressions.LambdaFunction` corresponding
+    to the transformation described by f.
+
+    :param f: A Python function of one of the following forms:
+            - (Column) -> Column: ...
+            - (Column, Column) -> Column: ...
+            - (Column, Column, Column) -> Column: ...
+ """ + parameters = _get_lambda_parameters(f) + + arg_names = ["x", "y", "z"][: len(parameters)] + arg_cols = [column(arg) for arg in arg_names] + + result = f(*arg_cols) + + if not isinstance(result, Column): + raise ValueError(f"Callable {f} should return Column, got {type(result)}") + + return LambdaFunction(result._expr, arg_names) + + +def _invoke_higher_order_function( + name: str, + cols: List["ColumnOrName"], + funs: List[Callable], +) -> Column: + """ + Invokes expression identified by name, + (relative to ```org.apache.spark.sql.catalyst.expressions``) + and wraps the result with Column (first Scala one, then Python). + + :param name: Name of the expression + :param cols: a list of columns + :param funs: a list of((*Column) -> Column functions. + + :return: a Column + """ + assert len(funs) == 1 + _cols = [_to_col(c) for c in cols] + _funs = [_create_lambda(f) for f in funs] + + return _invoke_function(name, *_cols, *_funs) + + # Normal Functions @@ -3862,42 +3938,41 @@ def element_at(col: "ColumnOrName", extraction: Any) -> Column: return _invoke_function("element_at", _to_col(col), lit(extraction)) -# TODO(SPARK-41434): need to support LambdaFunction Expression first -# def exists(col: "ColumnOrName", f: Callable[[Column], Column]) -> Column: -# """ -# Returns whether a predicate holds for one or more elements in the array. -# -# .. versionadded:: 3.1.0 -# -# Parameters -# ---------- -# col : :class:`~pyspark.sql.Column` or str -# name of column or expression -# f : function -# ``(x: Column) -> Column: ...`` returning the Boolean expression. -# Can use methods of :class:`~pyspark.sql.Column`, functions defined in -# :py:mod:`pyspark.sql.functions` and Scala ``UserDefinedFunctions``. -# Python ``UserDefinedFunctions`` are not supported -# (`SPARK-27052 `__). -# -# Returns -# ------- -# :class:`~pyspark.sql.Column` -# True if "any" element of an array evaluates to True when passed as an argument to -# given function and False otherwise. -# -# Examples -# -------- -# >>> df = spark.createDataFrame([(1, [1, 2, 3, 4]), (2, [3, -1, 0])],("key", "values")) -# >>> df.select(exists("values", lambda x: x < 0).alias("any_negative")).show() -# +------------+ -# |any_negative| -# +------------+ -# | false| -# | true| -# +------------+ -# """ -# return _invoke_higher_order_function("ArrayExists", [col], [f]) +def exists(col: "ColumnOrName", f: Callable[[Column], Column]) -> Column: + """ + Returns whether a predicate holds for one or more elements in the array. + + .. versionadded:: 3.4.0 + + Parameters + ---------- + col : :class:`~pyspark.sql.Column` or str + name of column or expression + f : function + ``(x: Column) -> Column: ...`` returning the Boolean expression. + Can use methods of :class:`~pyspark.sql.Column`, functions defined in + :py:mod:`pyspark.sql.functions` and Scala ``UserDefinedFunctions``. + Python ``UserDefinedFunctions`` are not supported + (`SPARK-27052 `__). + + Returns + ------- + :class:`~pyspark.sql.Column` + True if "any" element of an array evaluates to True when passed as an argument to + given function and False otherwise. 
+ + Examples + -------- + >>> df = spark.createDataFrame([(1, [1, 2, 3, 4]), (2, [3, -1, 0])],("key", "values")) + >>> df.select(exists("values", lambda x: x < 0).alias("any_negative")).show() + +------------+ + |any_negative| + +------------+ + | false| + | true| + +------------+ + """ + return _invoke_higher_order_function("exists", [col], [f]) def explode(col: "ColumnOrName") -> Column: diff --git a/python/pyspark/sql/connect/proto/expressions_pb2.py b/python/pyspark/sql/connect/proto/expressions_pb2.py index a9d1e96808cbf..2d6d17049042e 100644 --- a/python/pyspark/sql/connect/proto/expressions_pb2.py +++ b/python/pyspark/sql/connect/proto/expressions_pb2.py @@ -33,7 +33,7 @@ DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile( - b'\n\x1fspark/connect/expressions.proto\x12\rspark.connect\x1a\x19spark/connect/types.proto"\xe1\x15\n\nExpression\x12=\n\x07literal\x18\x01 \x01(\x0b\x32!.spark.connect.Expression.LiteralH\x00R\x07literal\x12\x62\n\x14unresolved_attribute\x18\x02 \x01(\x0b\x32-.spark.connect.Expression.UnresolvedAttributeH\x00R\x13unresolvedAttribute\x12_\n\x13unresolved_function\x18\x03 \x01(\x0b\x32,.spark.connect.Expression.UnresolvedFunctionH\x00R\x12unresolvedFunction\x12Y\n\x11\x65xpression_string\x18\x04 \x01(\x0b\x32*.spark.connect.Expression.ExpressionStringH\x00R\x10\x65xpressionString\x12S\n\x0funresolved_star\x18\x05 \x01(\x0b\x32(.spark.connect.Expression.UnresolvedStarH\x00R\x0eunresolvedStar\x12\x37\n\x05\x61lias\x18\x06 \x01(\x0b\x32\x1f.spark.connect.Expression.AliasH\x00R\x05\x61lias\x12\x34\n\x04\x63\x61st\x18\x07 \x01(\x0b\x32\x1e.spark.connect.Expression.CastH\x00R\x04\x63\x61st\x12V\n\x10unresolved_regex\x18\x08 \x01(\x0b\x32).spark.connect.Expression.UnresolvedRegexH\x00R\x0funresolvedRegex\x12\x44\n\nsort_order\x18\t \x01(\x0b\x32#.spark.connect.Expression.SortOrderH\x00R\tsortOrder\x1a\xed\x02\n\tSortOrder\x12/\n\x05\x63hild\x18\x01 \x01(\x0b\x32\x19.spark.connect.ExpressionR\x05\x63hild\x12O\n\tdirection\x18\x02 \x01(\x0e\x32\x31.spark.connect.Expression.SortOrder.SortDirectionR\tdirection\x12U\n\rnull_ordering\x18\x03 \x01(\x0e\x32\x30.spark.connect.Expression.SortOrder.NullOrderingR\x0cnullOrdering"L\n\rSortDirection\x12\x1c\n\x18SORT_DIRECTION_ASCENDING\x10\x00\x12\x1d\n\x19SORT_DIRECTION_DESCENDING\x10\x01"9\n\x0cNullOrdering\x12\x14\n\x10SORT_NULLS_FIRST\x10\x00\x12\x13\n\x0fSORT_NULLS_LAST\x10\x01\x1a\x91\x01\n\x04\x43\x61st\x12-\n\x04\x65xpr\x18\x01 \x01(\x0b\x32\x19.spark.connect.ExpressionR\x04\x65xpr\x12-\n\x04type\x18\x02 \x01(\x0b\x32\x17.spark.connect.DataTypeH\x00R\x04type\x12\x1b\n\x08type_str\x18\x03 \x01(\tH\x00R\x07typeStrB\x0e\n\x0c\x63\x61st_to_type\x1a\xe3\x07\n\x07Literal\x12\x14\n\x04null\x18\x01 \x01(\x08H\x00R\x04null\x12\x18\n\x06\x62inary\x18\x02 \x01(\x0cH\x00R\x06\x62inary\x12\x1a\n\x07\x62oolean\x18\x03 \x01(\x08H\x00R\x07\x62oolean\x12\x14\n\x04\x62yte\x18\x04 \x01(\x05H\x00R\x04\x62yte\x12\x16\n\x05short\x18\x05 \x01(\x05H\x00R\x05short\x12\x1a\n\x07integer\x18\x06 \x01(\x05H\x00R\x07integer\x12\x14\n\x04long\x18\x07 \x01(\x03H\x00R\x04long\x12\x16\n\x05\x66loat\x18\n \x01(\x02H\x00R\x05\x66loat\x12\x18\n\x06\x64ouble\x18\x0b \x01(\x01H\x00R\x06\x64ouble\x12\x45\n\x07\x64\x65\x63imal\x18\x0c \x01(\x0b\x32).spark.connect.Expression.Literal.DecimalH\x00R\x07\x64\x65\x63imal\x12\x18\n\x06string\x18\r \x01(\tH\x00R\x06string\x12\x14\n\x04\x64\x61te\x18\x10 \x01(\x05H\x00R\x04\x64\x61te\x12\x1e\n\ttimestamp\x18\x11 \x01(\x03H\x00R\ttimestamp\x12%\n\rtimestamp_ntz\x18\x12 
\x01(\x03H\x00R\x0ctimestampNtz\x12\x61\n\x11\x63\x61lendar_interval\x18\x13 \x01(\x0b\x32\x32.spark.connect.Expression.Literal.CalendarIntervalH\x00R\x10\x63\x61lendarInterval\x12\x30\n\x13year_month_interval\x18\x14 \x01(\x05H\x00R\x11yearMonthInterval\x12,\n\x11\x64\x61y_time_interval\x18\x15 \x01(\x03H\x00R\x0f\x64\x61yTimeInterval\x12\x38\n\ntyped_null\x18\x16 \x01(\x0b\x32\x17.spark.connect.DataTypeH\x00R\ttypedNull\x12\x1a\n\x08nullable\x18\x32 \x01(\x08R\x08nullable\x12\x38\n\x18type_variation_reference\x18\x33 \x01(\rR\x16typeVariationReference\x1au\n\x07\x44\x65\x63imal\x12\x14\n\x05value\x18\x01 \x01(\tR\x05value\x12!\n\tprecision\x18\x02 \x01(\x05H\x00R\tprecision\x88\x01\x01\x12\x19\n\x05scale\x18\x03 \x01(\x05H\x01R\x05scale\x88\x01\x01\x42\x0c\n\n_precisionB\x08\n\x06_scale\x1a\x62\n\x10\x43\x61lendarInterval\x12\x16\n\x06months\x18\x01 \x01(\x05R\x06months\x12\x12\n\x04\x64\x61ys\x18\x02 \x01(\x05R\x04\x64\x61ys\x12"\n\x0cmicroseconds\x18\x03 \x01(\x03R\x0cmicrosecondsB\x0e\n\x0cliteral_type\x1a\x46\n\x13UnresolvedAttribute\x12/\n\x13unparsed_identifier\x18\x01 \x01(\tR\x12unparsedIdentifier\x1a\xcc\x01\n\x12UnresolvedFunction\x12#\n\rfunction_name\x18\x01 \x01(\tR\x0c\x66unctionName\x12\x37\n\targuments\x18\x02 \x03(\x0b\x32\x19.spark.connect.ExpressionR\targuments\x12\x1f\n\x0bis_distinct\x18\x03 \x01(\x08R\nisDistinct\x12\x37\n\x18is_user_defined_function\x18\x04 \x01(\x08R\x15isUserDefinedFunction\x1a\x32\n\x10\x45xpressionString\x12\x1e\n\nexpression\x18\x01 \x01(\tR\nexpression\x1a(\n\x0eUnresolvedStar\x12\x16\n\x06target\x18\x01 \x03(\tR\x06target\x1a,\n\x0fUnresolvedRegex\x12\x19\n\x08\x63ol_name\x18\x01 \x01(\tR\x07\x63olName\x1ax\n\x05\x41lias\x12-\n\x04\x65xpr\x18\x01 \x01(\x0b\x32\x19.spark.connect.ExpressionR\x04\x65xpr\x12\x12\n\x04name\x18\x02 \x03(\tR\x04name\x12\x1f\n\x08metadata\x18\x03 \x01(\tH\x00R\x08metadata\x88\x01\x01\x42\x0b\n\t_metadataB\x0b\n\texpr_typeB"\n\x1eorg.apache.spark.connect.protoP\x01\x62\x06proto3' + b'\n\x1fspark/connect/expressions.proto\x12\rspark.connect\x1a\x19spark/connect/types.proto"\x9d\x17\n\nExpression\x12=\n\x07literal\x18\x01 \x01(\x0b\x32!.spark.connect.Expression.LiteralH\x00R\x07literal\x12\x62\n\x14unresolved_attribute\x18\x02 \x01(\x0b\x32-.spark.connect.Expression.UnresolvedAttributeH\x00R\x13unresolvedAttribute\x12_\n\x13unresolved_function\x18\x03 \x01(\x0b\x32,.spark.connect.Expression.UnresolvedFunctionH\x00R\x12unresolvedFunction\x12Y\n\x11\x65xpression_string\x18\x04 \x01(\x0b\x32*.spark.connect.Expression.ExpressionStringH\x00R\x10\x65xpressionString\x12S\n\x0funresolved_star\x18\x05 \x01(\x0b\x32(.spark.connect.Expression.UnresolvedStarH\x00R\x0eunresolvedStar\x12\x37\n\x05\x61lias\x18\x06 \x01(\x0b\x32\x1f.spark.connect.Expression.AliasH\x00R\x05\x61lias\x12\x34\n\x04\x63\x61st\x18\x07 \x01(\x0b\x32\x1e.spark.connect.Expression.CastH\x00R\x04\x63\x61st\x12V\n\x10unresolved_regex\x18\x08 \x01(\x0b\x32).spark.connect.Expression.UnresolvedRegexH\x00R\x0funresolvedRegex\x12\x44\n\nsort_order\x18\t \x01(\x0b\x32#.spark.connect.Expression.SortOrderH\x00R\tsortOrder\x12S\n\x0flambda_function\x18\n \x01(\x0b\x32(.spark.connect.Expression.LambdaFunctionH\x00R\x0elambdaFunction\x1a\xed\x02\n\tSortOrder\x12/\n\x05\x63hild\x18\x01 \x01(\x0b\x32\x19.spark.connect.ExpressionR\x05\x63hild\x12O\n\tdirection\x18\x02 \x01(\x0e\x32\x31.spark.connect.Expression.SortOrder.SortDirectionR\tdirection\x12U\n\rnull_ordering\x18\x03 
\x01(\x0e\x32\x30.spark.connect.Expression.SortOrder.NullOrderingR\x0cnullOrdering"L\n\rSortDirection\x12\x1c\n\x18SORT_DIRECTION_ASCENDING\x10\x00\x12\x1d\n\x19SORT_DIRECTION_DESCENDING\x10\x01"9\n\x0cNullOrdering\x12\x14\n\x10SORT_NULLS_FIRST\x10\x00\x12\x13\n\x0fSORT_NULLS_LAST\x10\x01\x1a\x91\x01\n\x04\x43\x61st\x12-\n\x04\x65xpr\x18\x01 \x01(\x0b\x32\x19.spark.connect.ExpressionR\x04\x65xpr\x12-\n\x04type\x18\x02 \x01(\x0b\x32\x17.spark.connect.DataTypeH\x00R\x04type\x12\x1b\n\x08type_str\x18\x03 \x01(\tH\x00R\x07typeStrB\x0e\n\x0c\x63\x61st_to_type\x1a\xe3\x07\n\x07Literal\x12\x14\n\x04null\x18\x01 \x01(\x08H\x00R\x04null\x12\x18\n\x06\x62inary\x18\x02 \x01(\x0cH\x00R\x06\x62inary\x12\x1a\n\x07\x62oolean\x18\x03 \x01(\x08H\x00R\x07\x62oolean\x12\x14\n\x04\x62yte\x18\x04 \x01(\x05H\x00R\x04\x62yte\x12\x16\n\x05short\x18\x05 \x01(\x05H\x00R\x05short\x12\x1a\n\x07integer\x18\x06 \x01(\x05H\x00R\x07integer\x12\x14\n\x04long\x18\x07 \x01(\x03H\x00R\x04long\x12\x16\n\x05\x66loat\x18\n \x01(\x02H\x00R\x05\x66loat\x12\x18\n\x06\x64ouble\x18\x0b \x01(\x01H\x00R\x06\x64ouble\x12\x45\n\x07\x64\x65\x63imal\x18\x0c \x01(\x0b\x32).spark.connect.Expression.Literal.DecimalH\x00R\x07\x64\x65\x63imal\x12\x18\n\x06string\x18\r \x01(\tH\x00R\x06string\x12\x14\n\x04\x64\x61te\x18\x10 \x01(\x05H\x00R\x04\x64\x61te\x12\x1e\n\ttimestamp\x18\x11 \x01(\x03H\x00R\ttimestamp\x12%\n\rtimestamp_ntz\x18\x12 \x01(\x03H\x00R\x0ctimestampNtz\x12\x61\n\x11\x63\x61lendar_interval\x18\x13 \x01(\x0b\x32\x32.spark.connect.Expression.Literal.CalendarIntervalH\x00R\x10\x63\x61lendarInterval\x12\x30\n\x13year_month_interval\x18\x14 \x01(\x05H\x00R\x11yearMonthInterval\x12,\n\x11\x64\x61y_time_interval\x18\x15 \x01(\x03H\x00R\x0f\x64\x61yTimeInterval\x12\x38\n\ntyped_null\x18\x16 \x01(\x0b\x32\x17.spark.connect.DataTypeH\x00R\ttypedNull\x12\x1a\n\x08nullable\x18\x32 \x01(\x08R\x08nullable\x12\x38\n\x18type_variation_reference\x18\x33 \x01(\rR\x16typeVariationReference\x1au\n\x07\x44\x65\x63imal\x12\x14\n\x05value\x18\x01 \x01(\tR\x05value\x12!\n\tprecision\x18\x02 \x01(\x05H\x00R\tprecision\x88\x01\x01\x12\x19\n\x05scale\x18\x03 \x01(\x05H\x01R\x05scale\x88\x01\x01\x42\x0c\n\n_precisionB\x08\n\x06_scale\x1a\x62\n\x10\x43\x61lendarInterval\x12\x16\n\x06months\x18\x01 \x01(\x05R\x06months\x12\x12\n\x04\x64\x61ys\x18\x02 \x01(\x05R\x04\x64\x61ys\x12"\n\x0cmicroseconds\x18\x03 \x01(\x03R\x0cmicrosecondsB\x0e\n\x0cliteral_type\x1a\x46\n\x13UnresolvedAttribute\x12/\n\x13unparsed_identifier\x18\x01 \x01(\tR\x12unparsedIdentifier\x1a\xcc\x01\n\x12UnresolvedFunction\x12#\n\rfunction_name\x18\x01 \x01(\tR\x0c\x66unctionName\x12\x37\n\targuments\x18\x02 \x03(\x0b\x32\x19.spark.connect.ExpressionR\targuments\x12\x1f\n\x0bis_distinct\x18\x03 \x01(\x08R\nisDistinct\x12\x37\n\x18is_user_defined_function\x18\x04 \x01(\x08R\x15isUserDefinedFunction\x1a\x32\n\x10\x45xpressionString\x12\x1e\n\nexpression\x18\x01 \x01(\tR\nexpression\x1a(\n\x0eUnresolvedStar\x12\x16\n\x06target\x18\x01 \x03(\tR\x06target\x1a,\n\x0fUnresolvedRegex\x12\x19\n\x08\x63ol_name\x18\x01 \x01(\tR\x07\x63olName\x1ax\n\x05\x41lias\x12-\n\x04\x65xpr\x18\x01 \x01(\x0b\x32\x19.spark.connect.ExpressionR\x04\x65xpr\x12\x12\n\x04name\x18\x02 \x03(\tR\x04name\x12\x1f\n\x08metadata\x18\x03 \x01(\tH\x00R\x08metadata\x88\x01\x01\x42\x0b\n\t_metadata\x1a\x65\n\x0eLambdaFunction\x12\x35\n\x08\x66unction\x18\x01 \x01(\x0b\x32\x19.spark.connect.ExpressionR\x08\x66unction\x12\x1c\n\targuments\x18\x02 
\x03(\tR\targumentsB\x0b\n\texpr_typeB"\n\x1eorg.apache.spark.connect.protoP\x01\x62\x06proto3' ) @@ -49,6 +49,7 @@ _EXPRESSION_UNRESOLVEDSTAR = _EXPRESSION.nested_types_by_name["UnresolvedStar"] _EXPRESSION_UNRESOLVEDREGEX = _EXPRESSION.nested_types_by_name["UnresolvedRegex"] _EXPRESSION_ALIAS = _EXPRESSION.nested_types_by_name["Alias"] +_EXPRESSION_LAMBDAFUNCTION = _EXPRESSION.nested_types_by_name["LambdaFunction"] _EXPRESSION_SORTORDER_SORTDIRECTION = _EXPRESSION_SORTORDER.enum_types_by_name["SortDirection"] _EXPRESSION_SORTORDER_NULLORDERING = _EXPRESSION_SORTORDER.enum_types_by_name["NullOrdering"] Expression = _reflection.GeneratedProtocolMessageType( @@ -154,6 +155,15 @@ # @@protoc_insertion_point(class_scope:spark.connect.Expression.Alias) }, ), + "LambdaFunction": _reflection.GeneratedProtocolMessageType( + "LambdaFunction", + (_message.Message,), + { + "DESCRIPTOR": _EXPRESSION_LAMBDAFUNCTION, + "__module__": "spark.connect.expressions_pb2" + # @@protoc_insertion_point(class_scope:spark.connect.Expression.LambdaFunction) + }, + ), "DESCRIPTOR": _EXPRESSION, "__module__": "spark.connect.expressions_pb2" # @@protoc_insertion_point(class_scope:spark.connect.Expression) @@ -171,37 +181,40 @@ _sym_db.RegisterMessage(Expression.UnresolvedStar) _sym_db.RegisterMessage(Expression.UnresolvedRegex) _sym_db.RegisterMessage(Expression.Alias) +_sym_db.RegisterMessage(Expression.LambdaFunction) if _descriptor._USE_C_DESCRIPTORS == False: DESCRIPTOR._options = None DESCRIPTOR._serialized_options = b"\n\036org.apache.spark.connect.protoP\001" _EXPRESSION._serialized_start = 78 - _EXPRESSION._serialized_end = 2863 - _EXPRESSION_SORTORDER._serialized_start = 798 - _EXPRESSION_SORTORDER._serialized_end = 1163 - _EXPRESSION_SORTORDER_SORTDIRECTION._serialized_start = 1028 - _EXPRESSION_SORTORDER_SORTDIRECTION._serialized_end = 1104 - _EXPRESSION_SORTORDER_NULLORDERING._serialized_start = 1106 - _EXPRESSION_SORTORDER_NULLORDERING._serialized_end = 1163 - _EXPRESSION_CAST._serialized_start = 1166 - _EXPRESSION_CAST._serialized_end = 1311 - _EXPRESSION_LITERAL._serialized_start = 1314 - _EXPRESSION_LITERAL._serialized_end = 2309 - _EXPRESSION_LITERAL_DECIMAL._serialized_start = 2076 - _EXPRESSION_LITERAL_DECIMAL._serialized_end = 2193 - _EXPRESSION_LITERAL_CALENDARINTERVAL._serialized_start = 2195 - _EXPRESSION_LITERAL_CALENDARINTERVAL._serialized_end = 2293 - _EXPRESSION_UNRESOLVEDATTRIBUTE._serialized_start = 2311 - _EXPRESSION_UNRESOLVEDATTRIBUTE._serialized_end = 2381 - _EXPRESSION_UNRESOLVEDFUNCTION._serialized_start = 2384 - _EXPRESSION_UNRESOLVEDFUNCTION._serialized_end = 2588 - _EXPRESSION_EXPRESSIONSTRING._serialized_start = 2590 - _EXPRESSION_EXPRESSIONSTRING._serialized_end = 2640 - _EXPRESSION_UNRESOLVEDSTAR._serialized_start = 2642 - _EXPRESSION_UNRESOLVEDSTAR._serialized_end = 2682 - _EXPRESSION_UNRESOLVEDREGEX._serialized_start = 2684 - _EXPRESSION_UNRESOLVEDREGEX._serialized_end = 2728 - _EXPRESSION_ALIAS._serialized_start = 2730 - _EXPRESSION_ALIAS._serialized_end = 2850 + _EXPRESSION._serialized_end = 3051 + _EXPRESSION_SORTORDER._serialized_start = 883 + _EXPRESSION_SORTORDER._serialized_end = 1248 + _EXPRESSION_SORTORDER_SORTDIRECTION._serialized_start = 1113 + _EXPRESSION_SORTORDER_SORTDIRECTION._serialized_end = 1189 + _EXPRESSION_SORTORDER_NULLORDERING._serialized_start = 1191 + _EXPRESSION_SORTORDER_NULLORDERING._serialized_end = 1248 + _EXPRESSION_CAST._serialized_start = 1251 + _EXPRESSION_CAST._serialized_end = 1396 + _EXPRESSION_LITERAL._serialized_start = 1399 + 
_EXPRESSION_LITERAL._serialized_end = 2394
+    _EXPRESSION_LITERAL_DECIMAL._serialized_start = 2161
+    _EXPRESSION_LITERAL_DECIMAL._serialized_end = 2278
+    _EXPRESSION_LITERAL_CALENDARINTERVAL._serialized_start = 2280
+    _EXPRESSION_LITERAL_CALENDARINTERVAL._serialized_end = 2378
+    _EXPRESSION_UNRESOLVEDATTRIBUTE._serialized_start = 2396
+    _EXPRESSION_UNRESOLVEDATTRIBUTE._serialized_end = 2466
+    _EXPRESSION_UNRESOLVEDFUNCTION._serialized_start = 2469
+    _EXPRESSION_UNRESOLVEDFUNCTION._serialized_end = 2673
+    _EXPRESSION_EXPRESSIONSTRING._serialized_start = 2675
+    _EXPRESSION_EXPRESSIONSTRING._serialized_end = 2725
+    _EXPRESSION_UNRESOLVEDSTAR._serialized_start = 2727
+    _EXPRESSION_UNRESOLVEDSTAR._serialized_end = 2767
+    _EXPRESSION_UNRESOLVEDREGEX._serialized_start = 2769
+    _EXPRESSION_UNRESOLVEDREGEX._serialized_end = 2813
+    _EXPRESSION_ALIAS._serialized_start = 2815
+    _EXPRESSION_ALIAS._serialized_end = 2935
+    _EXPRESSION_LAMBDAFUNCTION._serialized_start = 2937
+    _EXPRESSION_LAMBDAFUNCTION._serialized_end = 3038
 # @@protoc_insertion_point(module_scope)
diff --git a/python/pyspark/sql/connect/proto/expressions_pb2.pyi b/python/pyspark/sql/connect/proto/expressions_pb2.pyi
index 30c8f8c8c4184..c75a09ebeb367 100644
--- a/python/pyspark/sql/connect/proto/expressions_pb2.pyi
+++ b/python/pyspark/sql/connect/proto/expressions_pb2.pyi
@@ -624,6 +624,39 @@ class Expression(google.protobuf.message.Message):
             self, oneof_group: typing_extensions.Literal["_metadata", b"_metadata"]
         ) -> typing_extensions.Literal["metadata"] | None: ...

+    class LambdaFunction(google.protobuf.message.Message):
+        DESCRIPTOR: google.protobuf.descriptor.Descriptor
+
+        FUNCTION_FIELD_NUMBER: builtins.int
+        ARGUMENTS_FIELD_NUMBER: builtins.int
+        @property
+        def function(self) -> global___Expression:
+            """(Required) The lambda function.
+
+            The function body should use 'UnresolvedAttribute' as arguments; the server side will
+            replace 'UnresolvedAttribute' with 'UnresolvedNamedLambdaVariable'.
+            """
+        @property
+        def arguments(
+            self,
+        ) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[builtins.str]:
+            """(Required) Function variable names. Must contain 1 to 3 variables."""
+        def __init__(
+            self,
+            *,
+            function: global___Expression | None = ...,
+            arguments: collections.abc.Iterable[builtins.str] | None = ...,
+        ) -> None: ...
+        def HasField(
+            self, field_name: typing_extensions.Literal["function", b"function"]
+        ) -> builtins.bool: ...
+        def ClearField(
+            self,
+            field_name: typing_extensions.Literal[
+                "arguments", b"arguments", "function", b"function"
+            ],
+        ) -> None: ...
+
     LITERAL_FIELD_NUMBER: builtins.int
     UNRESOLVED_ATTRIBUTE_FIELD_NUMBER: builtins.int
     UNRESOLVED_FUNCTION_FIELD_NUMBER: builtins.int
@@ -633,6 +666,7 @@ class Expression(google.protobuf.message.Message):
     CAST_FIELD_NUMBER: builtins.int
     UNRESOLVED_REGEX_FIELD_NUMBER: builtins.int
     SORT_ORDER_FIELD_NUMBER: builtins.int
+    LAMBDA_FUNCTION_FIELD_NUMBER: builtins.int
     @property
     def literal(self) -> global___Expression.Literal: ...
     @property
@@ -651,6 +685,8 @@ class Expression(google.protobuf.message.Message):
     def unresolved_regex(self) -> global___Expression.UnresolvedRegex: ...
     @property
     def sort_order(self) -> global___Expression.SortOrder: ...
+    @property
+    def lambda_function(self) -> global___Expression.LambdaFunction: ...
def __init__( self, *, @@ -663,6 +699,7 @@ class Expression(google.protobuf.message.Message): cast: global___Expression.Cast | None = ..., unresolved_regex: global___Expression.UnresolvedRegex | None = ..., sort_order: global___Expression.SortOrder | None = ..., + lambda_function: global___Expression.LambdaFunction | None = ..., ) -> None: ... def HasField( self, @@ -675,6 +712,8 @@ class Expression(google.protobuf.message.Message): b"expr_type", "expression_string", b"expression_string", + "lambda_function", + b"lambda_function", "literal", b"literal", "sort_order", @@ -700,6 +739,8 @@ class Expression(google.protobuf.message.Message): b"expr_type", "expression_string", b"expression_string", + "lambda_function", + b"lambda_function", "literal", b"literal", "sort_order", @@ -726,6 +767,7 @@ class Expression(google.protobuf.message.Message): "cast", "unresolved_regex", "sort_order", + "lambda_function", ] | None: ... global___Expression = Expression diff --git a/python/pyspark/sql/tests/connect/test_connect_function.py b/python/pyspark/sql/tests/connect/test_connect_function.py index e290b07b207ff..ca366473ce04d 100644 --- a/python/pyspark/sql/tests/connect/test_connect_function.py +++ b/python/pyspark/sql/tests/connect/test_connect_function.py @@ -916,6 +916,38 @@ def test_generator_functions(self): sdf.select(SF.posexplode_outer("d"), "c").toPandas(), ) + def test_lambda_functions(self): + from pyspark.sql import functions as SF + from pyspark.sql.connect import functions as CF + + query = """ + SELECT * FROM VALUES + (ARRAY('a', 'ab'), ARRAY(1, 2, 3), ARRAY(1, NULL, 3), 1, 2, 'a'), + (ARRAY('x', NULL), NULL, ARRAY(1, 3), 3, 4, 'x'), + (NULL, ARRAY(-1, -2, -3), Array(), 5, 6, NULL) + AS tab(a, b, c, d, e, f) + """ + # +---------+------------+------------+---+---+----+ + # | a| b| c| d| e| f| + # +---------+------------+------------+---+---+----+ + # | [a, ab]| [1, 2, 3]|[1, null, 3]| 1| 2| a| + # |[x, null]| null| [1, 3]| 3| 4| x| + # | null|[-1, -2, -3]| []| 5| 6|null| + # +---------+------------+------------+---+---+----+ + + cdf = self.connect.sql(query) + sdf = self.spark.sql(query) + + # test exists + self.assert_eq( + cdf.select(CF.exists(cdf.b, lambda x: x < 0)).toPandas(), + sdf.select(SF.exists(sdf.b, lambda x: x < 0)).toPandas(), + ) + self.assert_eq( + cdf.select(CF.exists("a", lambda x: CF.isnull(x))).toPandas(), + sdf.select(SF.exists("a", lambda x: SF.isnull(x))).toPandas(), + ) + def test_csv_functions(self): from pyspark.sql import functions as SF from pyspark.sql.connect import functions as CF
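
Note on the end-to-end flow (not part of the patch): the client validates the user's lambda and binds its parameters to the fixed names x, y, z (`_get_lambda_parameters` / `_create_lambda`); those names travel as the repeated `arguments` field of the `LambdaFunction` proto; and `transformLambdaFunction` on the server renames them to fresh, unique variables (e.g. x -> x_0) before constructing the Catalyst `LambdaFunction`. The sketch below mirrors only the client-side validation and name binding, using nothing but the standard library; `_lambda_arg_names` is a hypothetical helper written for illustration, not a function added by this change.

    import inspect
    from typing import Callable, List


    def _lambda_arg_names(f: Callable) -> List[str]:
        # Mirrors _get_lambda_parameters: only plain positional parameters
        # are accepted, and the arity must be between 1 and 3.
        parameters = list(inspect.signature(f).parameters.values())
        supported = {
            inspect.Parameter.POSITIONAL_OR_KEYWORD,
            inspect.Parameter.POSITIONAL_ONLY,
        }
        if not (1 <= len(parameters) <= 3):
            raise ValueError(
                f"f should take between 1 and 3 arguments, but provided function takes {len(parameters)}"
            )
        if not all(p.kind in supported for p in parameters):
            raise ValueError("All arguments of f must be usable as POSITIONAL arguments")
        # Mirrors _create_lambda: bind the parameters to the fixed names x, y, z,
        # which become the repeated 'arguments' field of the LambdaFunction proto.
        return ["x", "y", "z"][: len(parameters)]


    print(_lambda_arg_names(lambda x: x))      # ['x']
    print(_lambda_arg_names(lambda x, y: x))   # ['x', 'y']
    # _lambda_arg_names(lambda *xs: xs)        # raises ValueError

For `exists("values", lambda x: x < 0)`, the proto thus carries `arguments: ["x"]`, and the server rewrites the body's single-part `UnresolvedAttribute("x")` into the uniquely named `UnresolvedNamedLambdaVariable("x_0")` before analysis.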