From 5496d99241f1063766cf5954f754e870fbabcbe7 Mon Sep 17 00:00:00 2001 From: Jungtaek Lim Date: Thu, 15 Sep 2022 21:49:56 +0900 Subject: [PATCH] [SPARK-40433][SS][PYTHON] Add toJVMRow in PythonSQLUtils to convert pickled PySpark Row to JVM Row ### What changes were proposed in this pull request? This PR adds toJVMRow in PythonSQLUtils to convert pickled PySpark Row to JVM Row. Co-authored with HyukjinKwon. This is a breakdown PR of https://github.com/apache/spark/pull/37863. ### Why are the changes needed? This change will be leveraged in [SPARK-40434](https://issues.apache.org/jira/browse/SPARK-40434). ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? N/A. We will make sure test suites are constructed in an E2E manner under [SPARK-40431](https://issues.apache.org/jira/browse/SPARK-40431). Closes #37891 from HeartSaVioR/SPARK-40433. Lead-authored-by: Jungtaek Lim Co-authored-by: Hyukjin Kwon Signed-off-by: Jungtaek Lim --- .../spark/sql/api/python/PythonSQLUtils.scala | 42 +++++++++++++++---- 1 file changed, 35 insertions(+), 7 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/api/python/PythonSQLUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/api/python/PythonSQLUtils.scala index 2b74bcc38501a..c495b145dc678 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/api/python/PythonSQLUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/api/python/PythonSQLUtils.scala @@ -22,14 +22,15 @@ import java.net.Socket import java.nio.channels.Channels import java.util.Locale -import net.razorvine.pickle.Pickler +import net.razorvine.pickle.{Pickler, Unpickler} import org.apache.spark.api.python.DechunkedInputStream import org.apache.spark.internal.Logging import org.apache.spark.security.SocketAuthServer import org.apache.spark.sql.{Column, DataFrame, Row, SparkSession} -import org.apache.spark.sql.catalyst.CatalystTypeConverters +import org.apache.spark.sql.catalyst.{CatalystTypeConverters, 
InternalRow} import org.apache.spark.sql.catalyst.analysis.FunctionRegistry +import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.parser.CatalystSqlParser @@ -37,12 +38,29 @@ import org.apache.spark.sql.execution.{ExplainMode, QueryExecution} import org.apache.spark.sql.execution.arrow.ArrowConverters import org.apache.spark.sql.execution.python.EvaluatePython import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.types.DataType +import org.apache.spark.sql.types.{DataType, StructType} private[sql] object PythonSQLUtils extends Logging { - private lazy val internalRowPickler = { + private def withInternalRowPickler(f: Pickler => Array[Byte]): Array[Byte] = { EvaluatePython.registerPicklers() - new Pickler(true, false) + val pickler = new Pickler(true, false) + val ret = try { + f(pickler) + } finally { + pickler.close() + } + ret + } + + private def withInternalRowUnpickler(f: Unpickler => Any): Any = { + EvaluatePython.registerPicklers() + val unpickler = new Unpickler + val ret = try { + f(unpickler) + } finally { + unpickler.close() + } + ret } def parseDataType(typeText: String): DataType = CatalystSqlParser.parseDataType(typeText) @@ -94,8 +112,18 @@ private[sql] object PythonSQLUtils extends Logging { def toPyRow(row: Row): Array[Byte] = { assert(row.isInstanceOf[GenericRowWithSchema]) - internalRowPickler.dumps(EvaluatePython.toJava( - CatalystTypeConverters.convertToCatalyst(row), row.schema)) + withInternalRowPickler(_.dumps(EvaluatePython.toJava( + CatalystTypeConverters.convertToCatalyst(row), row.schema))) + } + + def toJVMRow( + arr: Array[Byte], + returnType: StructType, + deserializer: ExpressionEncoder.Deserializer[Row]): Row = { + val fromJava = EvaluatePython.makeFromJava(returnType) + val internalRow = + 
fromJava(withInternalRowUnpickler(_.loads(arr))).asInstanceOf[InternalRow] + deserializer(internalRow) } def castTimestampNTZToLong(c: Column): Column = Column(CastTimestampNTZToLong(c.expr))