Skip to content

Commit

Permalink
[SPARK-42929] make mapInPandas / mapInArrow support "is_barrier"
Browse files Browse the repository at this point in the history
### What changes were proposed in this pull request?

make mapInPandas / mapInArrow support "is_barrier"

### Why are the changes needed?

feature parity.

### Does this PR introduce _any_ user-facing change?

Yes.

### How was this patch tested?

Manually:

`bin/pyspark --remote local`:

```
from pyspark.sql.functions import pandas_udf
df = spark.createDataFrame([(1, 21), (2, 30)], ("id", "age"))
def filter_func(iterator):
    for pdf in iterator:
        yield pdf[pdf.id == 1]
df.mapInPandas(filter_func, df.schema,  is_barrier=True).collect()

def filter_func(iterator):
    for batch in iterator:
        pdf = batch.to_pandas()
        yield pyarrow.RecordBatch.from_pandas(pdf[pdf.id == 1])

df.mapInArrow(filter_func, df.schema, is_barrier=True).collect()
```

Closes apache#40559 from WeichenXu123/spark-connect-barrier-mode.

Authored-by: Weichen Xu <[email protected]>
Signed-off-by: Weichen Xu <[email protected]>
  • Loading branch information
WeichenXu123 committed Mar 27, 2023
1 parent c55c7ea commit 2a1ac07
Show file tree
Hide file tree
Showing 6 changed files with 70 additions and 22 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -794,6 +794,9 @@ message MapPartitions {

// (Required) Input user-defined function.
CommonInlineUserDefinedFunction func = 2;

// (Optional) isBarrier.
optional bool is_barrier = 3;
}

message GroupMap {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -484,19 +484,20 @@ class SparkConnectPlanner(val session: SparkSession) {
private def transformMapPartitions(rel: proto.MapPartitions): LogicalPlan = {
val commonUdf = rel.getFunc
val pythonUdf = transformPythonUDF(commonUdf)
val isBarrier = if (rel.hasIsBarrier) rel.getIsBarrier else false
pythonUdf.evalType match {
case PythonEvalType.SQL_MAP_PANDAS_ITER_UDF =>
logical.MapInPandas(
pythonUdf,
pythonUdf.dataType.asInstanceOf[StructType].toAttributes,
transformRelation(rel.getInput),
false)
isBarrier)
case PythonEvalType.SQL_MAP_ARROW_ITER_UDF =>
logical.PythonMapInArrow(
pythonUdf,
pythonUdf.dataType.asInstanceOf[StructType].toAttributes,
transformRelation(rel.getInput),
false)
isBarrier)
case _ =>
throw InvalidPlanInput(s"Function with EvalType: ${pythonUdf.evalType} is not supported")
}
Expand Down
21 changes: 16 additions & 5 deletions python/pyspark/sql/connect/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -1623,6 +1623,7 @@ def _map_partitions(
func: "PandasMapIterFunction",
schema: Union[StructType, str],
evalType: int,
is_barrier: bool,
) -> "DataFrame":
from pyspark.sql.connect.udf import UserDefinedFunction

Expand All @@ -1636,21 +1637,31 @@ def _map_partitions(
)

return DataFrame.withPlan(
plan.MapPartitions(child=self._plan, function=udf_obj, cols=self.columns),
plan.MapPartitions(
child=self._plan, function=udf_obj, cols=self.columns, is_barrier=is_barrier
),
session=self._session,
)

def mapInPandas(
self, func: "PandasMapIterFunction", schema: Union[StructType, str]
self,
func: "PandasMapIterFunction",
schema: Union[StructType, str],
is_barrier: bool = False,
) -> "DataFrame":
return self._map_partitions(func, schema, PythonEvalType.SQL_MAP_PANDAS_ITER_UDF)
return self._map_partitions(
func, schema, PythonEvalType.SQL_MAP_PANDAS_ITER_UDF, is_barrier
)

mapInPandas.__doc__ = PySparkDataFrame.mapInPandas.__doc__

def mapInArrow(
self, func: "ArrowMapIterFunction", schema: Union[StructType, str]
self,
func: "ArrowMapIterFunction",
schema: Union[StructType, str],
is_barrier: bool = False,
) -> "DataFrame":
return self._map_partitions(func, schema, PythonEvalType.SQL_MAP_ARROW_ITER_UDF)
return self._map_partitions(func, schema, PythonEvalType.SQL_MAP_ARROW_ITER_UDF, is_barrier)

mapInArrow.__doc__ = PySparkDataFrame.mapInArrow.__doc__

Expand Down
8 changes: 7 additions & 1 deletion python/pyspark/sql/connect/plan.py
Original file line number Diff line number Diff line change
Expand Up @@ -1912,17 +1912,23 @@ class MapPartitions(LogicalPlan):
"""Logical plan object for a mapPartitions-equivalent API: mapInPandas, mapInArrow."""

def __init__(
self, child: Optional["LogicalPlan"], function: "UserDefinedFunction", cols: List[str]
self,
child: Optional["LogicalPlan"],
function: "UserDefinedFunction",
cols: List[str],
is_barrier: bool,
) -> None:
super().__init__(child)

self._func = function._build_common_inline_user_defined_function(*cols)
self._is_barrier = is_barrier

def plan(self, session: "SparkConnectClient") -> proto.Relation:
assert self._child is not None
plan = self._create_proto_relation()
plan.map_partitions.input.CopyFrom(self._child.plan(session))
plan.map_partitions.func.CopyFrom(self._func.to_plan_udf(session))
plan.map_partitions.is_barrier = self._is_barrier
return plan


Expand Down
Loading

0 comments on commit 2a1ac07

Please sign in to comment.