[SPARK-41434][CONNECT][PYTHON] Initial LambdaFunction implementation
### What changes were proposed in this pull request?
There are 11 functions that take a lambda; this PR adds basic support for the `LambdaFunction` expression and adds the `exists` function.
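
For example (mirroring the doctest added in `functions.py`):

    df = spark.createDataFrame([(1, [1, 2, 3, 4]), (2, [3, -1, 0])], ("key", "values"))
    df.select(exists("values", lambda x: x < 0).alias("any_negative")).show()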

### Why are the changes needed?
For API coverage.

### Does this PR introduce _any_ user-facing change?
Yes, this adds a new API.

### How was this patch tested?
Added unit tests.

Closes apache#39068 from zhengruifeng/connect_function_lambda.

Lead-authored-by: Ruifeng Zheng <[email protected]>
Co-authored-by: Ruifeng Zheng <[email protected]>
Signed-off-by: Hyukjin Kwon <[email protected]>
2 people authored and HyukjinKwon committed Dec 21, 2022
1 parent 801e079 commit e23983a
Showing 7 changed files with 302 additions and 65 deletions.
spark/connect/expressions.proto
@@ -38,6 +38,7 @@ message Expression {
    Cast cast = 7;
    UnresolvedRegex unresolved_regex = 8;
    SortOrder sort_order = 9;
    LambdaFunction lambda_function = 10;
  }

  // SortOrder is used to specify the data ordering; it is normally used in Sort and Window.
@@ -191,4 +192,15 @@ message Expression {
    // (Optional) Alias metadata expressed as a JSON map.
    optional string metadata = 3;
  }

  message LambdaFunction {
    // (Required) The lambda function.
    //
    // The function body should use 'UnresolvedAttribute' as arguments; the server side will
    // replace 'UnresolvedAttribute' with 'UnresolvedNamedLambdaVariable'.
    Expression function = 1;

    // (Required) Function variable names. Must contain 1 to 3 variables.
    repeated string arguments = 2;
  }
}
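
For illustration, a minimal Python sketch of how a client populates this message for the one-argument lambda `x -> x < 0` (here `proto` is the generated binding module and `body_expr` is a hypothetical pre-built Expression for the body):

    expr = proto.Expression()
    # The body refers to the variable via UnresolvedAttribute("x"); the server
    # side substitutes a fresh UnresolvedNamedLambdaVariable for it.
    expr.lambda_function.function.CopyFrom(body_expr)
    expr.lambda_function.arguments.append("x")  # 1 to 3 variable names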
SparkConnectPlanner.scala
@@ -501,6 +501,8 @@ class SparkConnectPlanner(session: SparkSession) {
      case proto.Expression.ExprTypeCase.UNRESOLVED_REGEX =>
        transformUnresolvedRegex(exp.getUnresolvedRegex)
      case proto.Expression.ExprTypeCase.SORT_ORDER => transformSortOrder(exp.getSortOrder)
      case proto.Expression.ExprTypeCase.LAMBDA_FUNCTION =>
        transformLambdaFunction(exp.getLambdaFunction)
      case _ =>
        throw InvalidPlanInput(
          s"Expression with ID: ${exp.getExprTypeCase.getNumber} is not supported")
@@ -558,6 +560,38 @@
  }
}

  /**
   * Translates a LambdaFunction from proto to the Catalyst expression.
   */
  private def transformLambdaFunction(lambda: proto.Expression.LambdaFunction): LambdaFunction = {
    if (lambda.getArgumentsCount == 0 || lambda.getArgumentsCount > 3) {
      throw InvalidPlanInput(
        "LambdaFunction requires 1 to 3 arguments, " +
          s"but got ${lambda.getArgumentsCount}!")
    }

    val variableNames = lambda.getArgumentsList.asScala.toSeq

    // Generate unique variable names, e.g. Map(x -> x_0, y -> y_1).
    val newVariables = variableNames.map { name =>
      val uniqueName = UnresolvedNamedLambdaVariable.freshVarName(name)
      (name, UnresolvedNamedLambdaVariable(Seq(uniqueName)))
    }.toMap

    val function = transformExpression(lambda.getFunction)

    // Rewrite the function body by replacing each single-part UnresolvedAttribute that matches
    // a lambda variable name with the corresponding UnresolvedNamedLambdaVariable.
    val newFunction = function transform {
      case variable: UnresolvedAttribute
          if variable.nameParts.length == 1 &&
            newVariables.contains(variable.nameParts.head) =>
        newVariables(variable.nameParts.head)
    }

    // e.g. LambdaFunction["x_0, y_1 -> x_0 < y_1", ["x_0", "y_1"]]
    LambdaFunction(function = newFunction, arguments = variableNames.map(newVariables))
  }

  /**
   * For some reason, not all functions are registered in 'FunctionRegistry'. For an unregistered
   * function, we can still wrap it under the proto 'UnresolvedFunction', and then resolve it in
29 changes: 29 additions & 0 deletions python/pyspark/sql/connect/column.py
@@ -539,6 +539,35 @@ def __repr__(self) -> str:
        return f"({self._col} ({self._data_type}))"


class LambdaFunction(Expression):
    def __init__(
        self,
        function: Expression,
        arguments: Sequence[str],
    ) -> None:
        super().__init__()

        assert isinstance(function, Expression)

        assert (
            isinstance(arguments, list)
            and len(arguments) > 0
            and all(isinstance(arg, str) for arg in arguments)
        )

        self._function = function
        self._arguments = arguments

    def to_plan(self, session: "SparkConnectClient") -> proto.Expression:
        fun = proto.Expression()
        fun.lambda_function.function.CopyFrom(self._function.to_plan(session))
        fun.lambda_function.arguments.extend(self._arguments)
        return fun

    def __repr__(self) -> str:
        return f"(LambdaFunction({str(self._function)}, {', '.join(self._arguments)}))"


class Column:
    """
    A column in a DataFrame. Column can refer to different things based on the
149 changes: 112 additions & 37 deletions python/pyspark/sql/connect/functions.py
@@ -14,6 +14,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#

import inspect

from pyspark.sql.connect.column import (
    Column,
    CaseWhen,
@@ -22,9 +25,10 @@
    ColumnReference,
    UnresolvedFunction,
    SQLExpression,
    LambdaFunction,
)

from typing import Any, TYPE_CHECKING, Union, List, overload, Optional, Tuple
from typing import Any, TYPE_CHECKING, Union, List, overload, Optional, Tuple, Callable, ValuesView

if TYPE_CHECKING:
    from pyspark.sql.connect._typing import ColumnOrName
@@ -80,6 +84,78 @@ def _invoke_binary_math_function(name: str, col1: Any, col2: Any) -> Column:
    return _invoke_function(name, *_cols)


def _get_lambda_parameters(f: Callable) -> ValuesView[inspect.Parameter]:
    signature = inspect.signature(f)
    parameters = signature.parameters.values()

    # We should exclude functions that use variable args and keyword arguments,
    # as well as keyword-only args.
    supported_parameter_types = {
        inspect.Parameter.POSITIONAL_OR_KEYWORD,
        inspect.Parameter.POSITIONAL_ONLY,
    }

    # Validate that the function arity is between 1 and 3.
    if not (1 <= len(parameters) <= 3):
        raise ValueError(
            "f should take between 1 and 3 arguments, but provided function takes {}".format(
                len(parameters)
            )
        )

    # Verify that all arguments can be used as positional arguments.
    if not all(p.kind in supported_parameter_types for p in parameters):
        raise ValueError("All arguments of f must be usable as POSITIONAL arguments")

    return parameters
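
A quick illustration of the checks above (a sketch; assumes the helper is in scope):

    _get_lambda_parameters(lambda x, y: x + y)  # ok: two positional parameters
    _get_lambda_parameters(lambda: 1)           # ValueError: arity must be between 1 and 3
    _get_lambda_parameters(lambda *xs: xs[0])   # ValueError: VAR_POSITIONAL is not usable positionally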


def _create_lambda(f: Callable) -> LambdaFunction:
    """
    Create `o.a.s.sql.expressions.LambdaFunction` corresponding
    to the transformation described by ``f``.

    :param f: A Python function of one of the following forms:
        - (Column) -> Column: ...
        - (Column, Column) -> Column: ...
        - (Column, Column, Column) -> Column: ...
    """
    parameters = _get_lambda_parameters(f)

    arg_names = ["x", "y", "z"][: len(parameters)]
    arg_cols = [column(arg) for arg in arg_names]

    result = f(*arg_cols)

    if not isinstance(result, Column):
        raise ValueError(f"Callable {f} should return Column, got {type(result)}")

    return LambdaFunction(result._expr, arg_names)
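
A sketch of what this produces:

    fn = _create_lambda(lambda v: v < 0)
    # fn._arguments == ["x"]: the Python parameter name "v" is discarded, and the
    # body is built against column("x"), which the server-side planner later
    # rebinds to an UnresolvedNamedLambdaVariable.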


def _invoke_higher_order_function(
    name: str,
    cols: List["ColumnOrName"],
    funs: List[Callable],
) -> Column:
    """
    Invokes an expression identified by name
    (relative to ``org.apache.spark.sql.catalyst.expressions``)
    and wraps the result with Column.

    :param name: Name of the expression
    :param cols: a list of columns
    :param funs: a list of ``(*Column) -> Column`` functions
    :return: a Column
    """
    assert len(funs) == 1
    _cols = [_to_col(c) for c in cols]
    _funs = [_create_lambda(f) for f in funs]

    return _invoke_function(name, *_cols, *_funs)
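
As used by `exists` below, the call boils down to (a sketch):

    _invoke_higher_order_function("exists", ["values"], [lambda x: x < 0])

which builds `UnresolvedFunction("exists", [column("values"), LambdaFunction(...)])` for the server to resolve against its FunctionRegistry.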


# Normal Functions


@@ -3862,42 +3938,41 @@ def element_at(col: "ColumnOrName", extraction: Any) -> Column:
    return _invoke_function("element_at", _to_col(col), lit(extraction))


# TODO(SPARK-41434): need to support LambdaFunction Expression first
# def exists(col: "ColumnOrName", f: Callable[[Column], Column]) -> Column:
#     """
#     Returns whether a predicate holds for one or more elements in the array.
#
#     .. versionadded:: 3.1.0
#
#     Parameters
#     ----------
#     col : :class:`~pyspark.sql.Column` or str
#         name of column or expression
#     f : function
#         ``(x: Column) -> Column: ...`` returning the Boolean expression.
#         Can use methods of :class:`~pyspark.sql.Column`, functions defined in
#         :py:mod:`pyspark.sql.functions` and Scala ``UserDefinedFunctions``.
#         Python ``UserDefinedFunctions`` are not supported
#         (`SPARK-27052 <https://issues.apache.org/jira/browse/SPARK-27052>`__).
#
#     Returns
#     -------
#     :class:`~pyspark.sql.Column`
#         True if "any" element of an array evaluates to True when passed as an argument to
#         given function and False otherwise.
#
#     Examples
#     --------
#     >>> df = spark.createDataFrame([(1, [1, 2, 3, 4]), (2, [3, -1, 0])], ("key", "values"))
#     >>> df.select(exists("values", lambda x: x < 0).alias("any_negative")).show()
#     +------------+
#     |any_negative|
#     +------------+
#     |       false|
#     |        true|
#     +------------+
#     """
#     return _invoke_higher_order_function("ArrayExists", [col], [f])
def exists(col: "ColumnOrName", f: Callable[[Column], Column]) -> Column:
    """
    Returns whether a predicate holds for one or more elements in the array.

    .. versionadded:: 3.4.0

    Parameters
    ----------
    col : :class:`~pyspark.sql.Column` or str
        name of column or expression
    f : function
        ``(x: Column) -> Column: ...`` returning the Boolean expression.
        Can use methods of :class:`~pyspark.sql.Column`, functions defined in
        :py:mod:`pyspark.sql.functions` and Scala ``UserDefinedFunctions``.
        Python ``UserDefinedFunctions`` are not supported
        (`SPARK-27052 <https://issues.apache.org/jira/browse/SPARK-27052>`__).

    Returns
    -------
    :class:`~pyspark.sql.Column`
        True if "any" element of an array evaluates to True when passed as an argument to
        given function and False otherwise.

    Examples
    --------
    >>> df = spark.createDataFrame([(1, [1, 2, 3, 4]), (2, [3, -1, 0])], ("key", "values"))
    >>> df.select(exists("values", lambda x: x < 0).alias("any_negative")).show()
    +------------+
    |any_negative|
    +------------+
    |       false|
    |        true|
    +------------+
    """
    return _invoke_higher_order_function("exists", [col], [f])


def explode(col: "ColumnOrName") -> Column:
(Diff for the remaining 3 changed files not shown.)
