diff --git a/sql/core/src/main/scala/org/apache/spark/sql/api/python/PythonSQLUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/api/python/PythonSQLUtils.scala
index 2b74bcc38501a..c495b145dc678 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/api/python/PythonSQLUtils.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/api/python/PythonSQLUtils.scala
@@ -22,14 +22,15 @@ import java.net.Socket
 import java.nio.channels.Channels
 import java.util.Locale
 
-import net.razorvine.pickle.Pickler
+import net.razorvine.pickle.{Pickler, Unpickler}
 
 import org.apache.spark.api.python.DechunkedInputStream
 import org.apache.spark.internal.Logging
 import org.apache.spark.security.SocketAuthServer
 import org.apache.spark.sql.{Column, DataFrame, Row, SparkSession}
-import org.apache.spark.sql.catalyst.CatalystTypeConverters
+import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
 import org.apache.spark.sql.catalyst.analysis.FunctionRegistry
+import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.aggregate._
 import org.apache.spark.sql.catalyst.parser.CatalystSqlParser
@@ -37,12 +38,29 @@ import org.apache.spark.sql.execution.{ExplainMode, QueryExecution}
 import org.apache.spark.sql.execution.arrow.ArrowConverters
 import org.apache.spark.sql.execution.python.EvaluatePython
 import org.apache.spark.sql.internal.SQLConf
-import org.apache.spark.sql.types.DataType
+import org.apache.spark.sql.types.{DataType, StructType}
 
 private[sql] object PythonSQLUtils extends Logging {
-  private lazy val internalRowPickler = {
+  private def withInternalRowPickler(f: Pickler => Array[Byte]): Array[Byte] = {
     EvaluatePython.registerPicklers()
-    new Pickler(true, false)
+    val pickler = new Pickler(true, false)
+    val ret = try {
+      f(pickler)
+    } finally {
+      pickler.close()
+    }
+    ret
+  }
+
+  private def withInternalRowUnpickler(f: Unpickler => Any): Any = {
+    EvaluatePython.registerPicklers()
+    val unpickler = new Unpickler
+    val ret = try {
+      f(unpickler)
+    } finally {
+      unpickler.close()
+    }
+    ret
   }
 
   def parseDataType(typeText: String): DataType = CatalystSqlParser.parseDataType(typeText)
@@ -94,8 +112,18 @@ private[sql] object PythonSQLUtils extends Logging {
 
   def toPyRow(row: Row): Array[Byte] = {
     assert(row.isInstanceOf[GenericRowWithSchema])
-    internalRowPickler.dumps(EvaluatePython.toJava(
-      CatalystTypeConverters.convertToCatalyst(row), row.schema))
+    withInternalRowPickler(_.dumps(EvaluatePython.toJava(
+      CatalystTypeConverters.convertToCatalyst(row), row.schema)))
+  }
+
+  def toJVMRow(
+      arr: Array[Byte],
+      returnType: StructType,
+      deserializer: ExpressionEncoder.Deserializer[Row]): Row = {
+    val fromJava = EvaluatePython.makeFromJava(returnType)
+    val internalRow =
+      fromJava(withInternalRowUnpickler(_.loads(arr))).asInstanceOf[InternalRow]
+    deserializer(internalRow)
+  }
 
   def castTimestampNTZToLong(c: Column): Column = Column(CastTimestampNTZToLong(c.expr))
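
For reference, a minimal sketch of how the toPyRow/toJVMRow pair added above could be exercised from the JVM side. This is not part of the change: the query, the local SparkSession, and the RowEncoder-based deserializer are illustrative assumptions, and since PythonSQLUtils is private[sql] the sketch would have to live under the org.apache.spark.sql package to compile.

package org.apache.spark.sql

import org.apache.spark.sql.api.python.PythonSQLUtils
import org.apache.spark.sql.catalyst.encoders.{ExpressionEncoder, RowEncoder}

object RowRoundTripSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[1]").getOrCreate()

    // collect() yields GenericRowWithSchema, which toPyRow asserts on.
    val row: Row = spark.sql("SELECT 1 AS id, 'a' AS name").collect().head
    val schema = row.schema

    // JVM Row -> pickled bytes (the format PySpark's unpickler consumes).
    val bytes: Array[Byte] = PythonSQLUtils.toPyRow(row)

    // Pickled bytes -> JVM Row via the new helper. The deserializer must be
    // resolved and bound against the schema before use.
    val deserializer: ExpressionEncoder.Deserializer[Row] =
      RowEncoder(schema).resolveAndBind().createDeserializer()
    val back: Row = PythonSQLUtils.toJVMRow(bytes, schema, deserializer)

    assert(back == row)  // values should survive the round trip
    spark.stop()
  }
}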