From bef11d89cab7364e8d9e65fffd96f163b9f4e1c3 Mon Sep 17 00:00:00 2001
From: Xianyang Liu
Date: Fri, 29 Sep 2023 16:36:47 -0700
Subject: [PATCH] [SPARK-44913][SQL] DS V2 supports push down V2 UDF that has magic method

### What changes were proposed in this pull request?

Right now we only support pushing down a V2 UDF that does not have a magic method, because such a UDF is analyzed into an `ApplyFunctionExpression`, which can be translated into a V2 expression and pushed down. However, a V2 UDF that has a magic method is analyzed into `StaticInvoke` or `Invoke`, which cannot be translated into a V2 expression and therefore cannot be pushed down to the data source, even though the magic method is the recommended way to implement a V2 UDF.

### Why are the changes needed?

This PR adds support for pushing down a V2 UDF that has a magic method.

### Does this PR introduce _any_ user-facing change?

Yes, a V2 UDF with a magic method can now be pushed down.

### How was this patch tested?

New UTs.

### Was this patch authored or co-authored using generative AI tooling?

No

Closes #42612 from ConeyLiu/push-down-udf-with-magic.

Lead-authored-by: Xianyang Liu
Co-authored-by: xianyangliu
Signed-off-by: Chao Sun
---
 .../expressions/V2ExpressionUtils.scala     |  2 +-
 .../expressions/objects/objects.scala       | 14 +++-
 .../expressions/stringExpressions.scala     |  3 +-
 .../catalyst/util/V2ExpressionBuilder.scala | 23 ++++++
 .../catalog/functions/JavaStrLen.java       |  5 ++
 .../apache/spark/sql/jdbc/JDBCV2Suite.scala | 74 +++++++++++++++++++
 6 files changed, 118 insertions(+), 3 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/V2ExpressionUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/V2ExpressionUtils.scala
index 1d65d49443596..d23ba3867dfe8 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/V2ExpressionUtils.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/V2ExpressionUtils.scala
@@ -159,7 +159,7 @@ object V2ExpressionUtils extends SQLConfHelper with Logging {
         StaticInvoke(scalarFunc.getClass, scalarFunc.resultType(), MAGIC_METHOD_NAME,
           arguments, inputTypes = declaredInputTypes,
           propagateNull = false, returnNullable = scalarFunc.isResultNullable,
-          isDeterministic = scalarFunc.isDeterministic)
+          isDeterministic = scalarFunc.isDeterministic, scalarFunction = Some(scalarFunc))
       case Some(_) =>
         val caller = Literal.create(scalarFunc, ObjectType(scalarFunc.getClass))
         Invoke(caller, MAGIC_METHOD_NAME, scalarFunc.resultType(),
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
index ca7185aa428da..4cfc40f44795c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
@@ -39,6 +39,7 @@ import org.apache.spark.sql.catalyst.expressions.codegen.Block._
 import org.apache.spark.sql.catalyst.trees.TernaryLike
 import org.apache.spark.sql.catalyst.trees.TreePattern._
 import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, GenericArrayData, MapData}
+import org.apache.spark.sql.connector.catalog.functions.ScalarFunction
 import org.apache.spark.sql.errors.QueryExecutionErrors
 import org.apache.spark.sql.types._
 import org.apache.spark.util.Utils
@@ -270,6 +271,8 @@ object SerializerSupport {
  *                       non-null value.
 * @param isDeterministic Whether the method invocation is deterministic or not. If false, Spark
 *                        will not apply certain optimizations such as constant folding.
+ * @param scalarFunction the [[ScalarFunction]] object if this is calling the magic method of the
+ *                       [[ScalarFunction]] otherwise is unset.
  */
 case class StaticInvoke(
     staticObject: Class[_],
@@ -279,7 +282,8 @@ case class StaticInvoke(
     inputTypes: Seq[AbstractDataType] = Nil,
     propagateNull: Boolean = true,
     returnNullable: Boolean = true,
-    isDeterministic: Boolean = true) extends InvokeLike {
+    isDeterministic: Boolean = true,
+    scalarFunction: Option[ScalarFunction[_]] = None) extends InvokeLike {
 
   val objectName = staticObject.getName.stripSuffix("$")
   val cls = if (staticObject.getName == objectName) {
@@ -346,6 +350,14 @@ case class StaticInvoke(
 
   override protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): Expression =
     copy(arguments = newChildren)
+
+  override protected def stringArgs: Iterator[Any] = {
+    if (scalarFunction.nonEmpty) {
+      super.stringArgs
+    } else {
+      super.stringArgs.take(8)
+    }
+  }
 }
 
 /**
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala
index 44ec403bf19af..6aa949b3344c3 100755
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala
@@ -523,7 +523,8 @@ trait StringBinaryPredicateExpressionBuilderBase extends ExpressionBuilder {
 
 object BinaryPredicate {
   def unapply(expr: Expression): Option[StaticInvoke] = expr match {
-    case s @ StaticInvoke(clz, _, "contains" | "startsWith" | "endsWith", Seq(_, _), _, _, _, _)
+    case s @ StaticInvoke(
+        clz, _, "contains" | "startsWith" | "endsWith", Seq(_, _), _, _, _, _, _)
       if clz == classOf[ByteArrayMethods] => Some(s)
     case _ => None
   }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/catalyst/util/V2ExpressionBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/catalyst/util/V2ExpressionBuilder.scala
index 947a5e9f383f9..4a8965a6413fc 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/catalyst/util/V2ExpressionBuilder.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/catalyst/util/V2ExpressionBuilder.scala
@@ -19,6 +19,8 @@ package org.apache.spark.sql.catalyst.util
 
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, AggregateFunction, Complete}
+import org.apache.spark.sql.catalyst.expressions.objects.{Invoke, StaticInvoke}
+import org.apache.spark.sql.connector.catalog.functions.ScalarFunction
 import org.apache.spark.sql.connector.expressions.{Cast => V2Cast, Expression => V2Expression, Extract => V2Extract, FieldReference, GeneralScalarExpression, LiteralValue, UserDefinedScalarFunc}
 import org.apache.spark.sql.connector.expressions.aggregate.{AggregateFunc, Avg, Count, CountStar, GeneralAggregateFunc, Max, Min, Sum, UserDefinedAggregateFunc}
 import org.apache.spark.sql.connector.expressions.filter.{AlwaysFalse, AlwaysTrue, And => V2And, Not => V2Not, Or => V2Or, Predicate => V2Predicate}
@@ -283,6 +285,27 @@ class V2ExpressionBuilder(e: Expression, isPredicate: Boolean = false) {
       } else {
         None
       }
+    case Invoke(Literal(obj, _), functionName, _, arguments, _, _, _, _) =>
+      obj match {
+        case function: ScalarFunction[_] if ScalarFunction.MAGIC_METHOD_NAME == functionName =>
+          val argumentExpressions = arguments.flatMap(generateExpression(_))
+          if (argumentExpressions.length == arguments.length) {
+            Some(new UserDefinedScalarFunc(
+              function.name(), function.canonicalName(), argumentExpressions.toArray[V2Expression]))
+          } else {
+            None
+          }
+        case _ =>
+          None
+      }
+    case StaticInvoke(_, _, _, arguments, _, _, _, _, Some(scalarFunc)) =>
+      val argumentExpressions = arguments.flatMap(generateExpression(_))
+      if (argumentExpressions.length == arguments.length) {
+        Some(new UserDefinedScalarFunc(
+          scalarFunc.name(), scalarFunc.canonicalName(), argumentExpressions.toArray[V2Expression]))
+      } else {
+        None
+      }
     case _ => None
   }
 
diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/connector/catalog/functions/JavaStrLen.java b/sql/core/src/test/java/test/org/apache/spark/sql/connector/catalog/functions/JavaStrLen.java
index dade2a113ef45..ad9746f820e40 100644
--- a/sql/core/src/test/java/test/org/apache/spark/sql/connector/catalog/functions/JavaStrLen.java
+++ b/sql/core/src/test/java/test/org/apache/spark/sql/connector/catalog/functions/JavaStrLen.java
@@ -73,6 +73,11 @@ public DataType resultType() {
     public String name() {
       return "strlen";
     }
+
+    @Override
+    public String canonicalName() {
+      return name();
+    }
   }
 
   public static class JavaStrLenDefault extends JavaStrLenBase {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala
index ae0cfe17b11f5..bcb366bbdda11 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala
@@ -22,6 +22,8 @@ import java.util.Properties
 
 import scala.util.control.NonFatal
 
+import test.org.apache.spark.sql.connector.catalog.functions.JavaStrLen.JavaStrLenStaticMagic
+
 import org.apache.spark.{SparkConf, SparkException}
 import org.apache.spark.sql.{AnalysisException, DataFrame, ExplainSuiteHelper, QueryTest, Row}
 import org.apache.spark.sql.catalyst.InternalRow
@@ -38,6 +40,7 @@ import org.apache.spark.sql.functions.{abs, acos, asin, atan, atan2, avg, ceil,
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.test.SharedSparkSession
 import org.apache.spark.sql.types.{DataType, IntegerType, StringType}
+import org.apache.spark.unsafe.types.UTF8String
 import org.apache.spark.util.Utils
 
 class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHelper {
@@ -61,6 +64,10 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel
       canonicalName match {
         case "h2.char_length" =>
           s"$funcName(${inputs.mkString(", ")})"
+        case "h2.char_length_magic" =>
+          s"CHAR_LENGTH(${inputs.mkString(", ")})"
+        case "strlen" =>
+          s"CHAR_LENGTH(${inputs.mkString(", ")})"
         case _ => super.visitUserDefinedScalarFunction(funcName, canonicalName, inputs)
       }
     }
@@ -109,6 +116,18 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel
     }
   }
 
+  case object CharLengthWithMagicMethod extends ScalarFunction[Int] {
+    def invoke(str: UTF8String): Int = str.toString.length
+    override def inputTypes(): Array[DataType] = Array(StringType)
+    override def resultType(): DataType = IntegerType
+    override def name(): String = "CHAR_LENGTH_MAGIC"
+    override def canonicalName(): String = "h2.char_length_magic"
+    override def produceResult(input: InternalRow): Int = {
+      val s = input.getString(0)
+      s.length
+    }
+  }
+
   override def sparkConf: SparkConf = super.sparkConf
     .set("spark.sql.catalog.h2", classOf[JDBCTableCatalog].getName)
     .set("spark.sql.catalog.h2.url", url)
@@ -194,6 +213,9 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel
     }
     H2Dialect.registerFunction("my_avg", IntegralAverage)
     H2Dialect.registerFunction("my_strlen", StrLen(CharLength))
+    H2Dialect.registerFunction("my_strlen_magic", StrLen(CharLengthWithMagicMethod))
+    H2Dialect.registerFunction(
+      "my_strlen_static_magic", StrLen(new JavaStrLenStaticMagic()))
   }
 
   override def afterAll(): Unit = {
@@ -1489,6 +1511,58 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel
     }
   }
 
+  test("scan with filter push-down with UDF that has magic method") {
+    JdbcDialects.unregisterDialect(H2Dialect)
+    try {
+      JdbcDialects.registerDialect(testH2Dialect)
+      val df1 = sql("SELECT * FROM h2.test.people where h2.my_strlen_magic(name) > 2")
+      checkFiltersRemoved(df1)
+      checkPushedInfo(df1, "PushedFilters: [CHAR_LENGTH_MAGIC(NAME) > 2],")
+      checkAnswer(df1, Seq(Row("fred", 1), Row("mary", 2)))
+
+      val df2 = sql(
+        """
+          |SELECT *
+          |FROM h2.test.people
+          |WHERE h2.my_strlen_magic(CASE WHEN NAME = 'fred' THEN NAME ELSE "abc" END) > 2
+        """.stripMargin)
+      checkFiltersRemoved(df2)
+      checkPushedInfo(df2,
+        "PushedFilters: [CHAR_LENGTH_MAGIC(CASE WHEN NAME = 'fred' " +
+          "THEN NAME ELSE 'abc' END) > 2],")
+      checkAnswer(df2, Seq(Row("fred", 1), Row("mary", 2)))
+    } finally {
+      JdbcDialects.unregisterDialect(testH2Dialect)
+      JdbcDialects.registerDialect(H2Dialect)
+    }
+  }
+
+  test("scan with filter push-down with UDF that has static magic method") {
+    JdbcDialects.unregisterDialect(H2Dialect)
+    try {
+      JdbcDialects.registerDialect(testH2Dialect)
+      val df1 = sql("SELECT * FROM h2.test.people where h2.my_strlen_static_magic(name) > 2")
+      checkFiltersRemoved(df1)
+      checkPushedInfo(df1, "PushedFilters: [strlen(NAME) > 2],")
+      checkAnswer(df1, Seq(Row("fred", 1), Row("mary", 2)))
+
+      val df2 = sql(
+        """
+          |SELECT *
+          |FROM h2.test.people
+          |WHERE h2.my_strlen_static_magic(CASE WHEN NAME = 'fred' THEN NAME ELSE "abc" END) > 2
+        """.stripMargin)
+      checkFiltersRemoved(df2)
+      checkPushedInfo(df2,
+        "PushedFilters: [strlen(CASE WHEN NAME = 'fred' " +
+          "THEN NAME ELSE 'abc' END) > 2],")
+      checkAnswer(df2, Seq(Row("fred", 1), Row("mary", 2)))
+    } finally {
+      JdbcDialects.unregisterDialect(testH2Dialect)
+      JdbcDialects.registerDialect(H2Dialect)
+    }
+  }
+
   test("scan with column pruning") {
     val df = spark.table("h2.test.people").select("id")
     checkSchemaNames(df, Seq("ID"))