[SPARK-44913][SQL] DS V2 supports push down V2 UDF that has magic method
### What changes were proposed in this pull request?

Right now we only support pushing down a V2 UDF that does *not* have a magic method: such a UDF is analyzed into an `ApplyFunctionExpression`, which can be translated into a V2 expression and pushed down. A V2 UDF that defines the magic method, however, is analyzed into a `StaticInvoke` or `Invoke`, which cannot be translated into a V2 expression and therefore cannot be pushed down to the data source. This is unfortunate, because the magic method is the recommended way to implement a V2 UDF.
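
For reference, a minimal sketch of a V2 UDF that defines the magic method, modeled on the `CharLengthWithMagicMethod` test function added in this PR (`invoke` is the method name defined by `ScalarFunction.MAGIC_METHOD_NAME`):

```scala
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.connector.catalog.functions.ScalarFunction
import org.apache.spark.sql.types.{DataType, IntegerType, StringType}
import org.apache.spark.unsafe.types.UTF8String

case object CharLengthWithMagicMethod extends ScalarFunction[Int] {
  // The magic method: the analyzer binds calls to it via StaticInvoke/Invoke
  // rather than ApplyFunctionExpression.
  def invoke(str: UTF8String): Int = str.toString.length

  override def inputTypes(): Array[DataType] = Array(StringType)
  override def resultType(): DataType = IntegerType
  override def name(): String = "CHAR_LENGTH_MAGIC"
  override def canonicalName(): String = "h2.char_length_magic"

  // Row-based fallback, used when the magic method cannot be bound.
  override def produceResult(input: InternalRow): Int = input.getString(0).length
}
```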

### Why are the changes needed?

This PR adds support for pushing down V2 UDFs that have a magic method.

### Does this PR introduce _any_ user-facing change?

Yes. A V2 UDF with a magic method can now be pushed down.
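
For example (adapted from the new `JDBCV2Suite` test in this PR), once the UDF is registered on the H2 dialect as `my_strlen_magic`, a filter over it is compiled into the data source's SQL and reported in the pushed filters:

```scala
val df = sql("SELECT * FROM h2.test.people WHERE h2.my_strlen_magic(name) > 2")
checkFiltersRemoved(df)  // the Spark-side Filter node has been removed
checkPushedInfo(df, "PushedFilters: [CHAR_LENGTH_MAGIC(NAME) > 2],")
```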

### How was this patch tested?

New unit tests, covering both the instance magic method and the static magic method cases.

### Was this patch authored or co-authored using generative AI tooling?

No

Closes #42612 from ConeyLiu/push-down-udf-with-magic.

Lead-authored-by: Xianyang Liu <[email protected]>
Co-authored-by: xianyangliu <[email protected]>
Signed-off-by: Chao Sun <[email protected]>
2 people authored and sunchao committed Sep 29, 2023
1 parent 4863dec commit bef11d8
Showing 6 changed files with 118 additions and 3 deletions.
@@ -159,7 +159,7 @@ object V2ExpressionUtils extends SQLConfHelper with Logging {
         StaticInvoke(scalarFunc.getClass, scalarFunc.resultType(),
           MAGIC_METHOD_NAME, arguments, inputTypes = declaredInputTypes,
           propagateNull = false, returnNullable = scalarFunc.isResultNullable,
-          isDeterministic = scalarFunc.isDeterministic)
+          isDeterministic = scalarFunc.isDeterministic, scalarFunction = Some(scalarFunc))
       case Some(_) =>
         val caller = Literal.create(scalarFunc, ObjectType(scalarFunc.getClass))
         Invoke(caller, MAGIC_METHOD_NAME, scalarFunc.resultType(),
@@ -39,6 +39,7 @@ import org.apache.spark.sql.catalyst.expressions.codegen.Block._
 import org.apache.spark.sql.catalyst.trees.TernaryLike
 import org.apache.spark.sql.catalyst.trees.TreePattern._
 import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, GenericArrayData, MapData}
+import org.apache.spark.sql.connector.catalog.functions.ScalarFunction
 import org.apache.spark.sql.errors.QueryExecutionErrors
 import org.apache.spark.sql.types._
 import org.apache.spark.util.Utils
@@ -270,6 +271,8 @@ object SerializerSupport {
  *                        non-null value.
  * @param isDeterministic Whether the method invocation is deterministic or not. If false, Spark
  *                        will not apply certain optimizations such as constant folding.
+ * @param scalarFunction the [[ScalarFunction]] object if this is calling the magic method of the
+ *                       [[ScalarFunction]] otherwise is unset.
  */
@@ -279,7 +282,8 @@ case class StaticInvoke(
     inputTypes: Seq[AbstractDataType] = Nil,
     propagateNull: Boolean = true,
     returnNullable: Boolean = true,
-    isDeterministic: Boolean = true) extends InvokeLike {
+    isDeterministic: Boolean = true,
+    scalarFunction: Option[ScalarFunction[_]] = None) extends InvokeLike {
 
   val objectName = staticObject.getName.stripSuffix("$")
   val cls = if (staticObject.getName == objectName) {
@@ -346,6 +350,14 @@ case class StaticInvoke(
 
   override protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): Expression =
     copy(arguments = newChildren)
+
+  override protected def stringArgs: Iterator[Any] = {
+    if (scalarFunction.nonEmpty) {
+      super.stringArgs
+    } else {
+      super.stringArgs.take(8)
+    }
+  }
 }
 
 /**
@@ -523,7 +523,8 @@ trait StringBinaryPredicateExpressionBuilderBase extends ExpressionBuilder {
 
 object BinaryPredicate {
   def unapply(expr: Expression): Option[StaticInvoke] = expr match {
-    case s @ StaticInvoke(clz, _, "contains" | "startsWith" | "endsWith", Seq(_, _), _, _, _, _)
+    case s @ StaticInvoke(
+        clz, _, "contains" | "startsWith" | "endsWith", Seq(_, _), _, _, _, _, _)
       if clz == classOf[ByteArrayMethods] => Some(s)
     case _ => None
   }
@@ -19,6 +19,8 @@ package org.apache.spark.sql.catalyst.util
 
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, AggregateFunction, Complete}
+import org.apache.spark.sql.catalyst.expressions.objects.{Invoke, StaticInvoke}
+import org.apache.spark.sql.connector.catalog.functions.ScalarFunction
 import org.apache.spark.sql.connector.expressions.{Cast => V2Cast, Expression => V2Expression, Extract => V2Extract, FieldReference, GeneralScalarExpression, LiteralValue, UserDefinedScalarFunc}
 import org.apache.spark.sql.connector.expressions.aggregate.{AggregateFunc, Avg, Count, CountStar, GeneralAggregateFunc, Max, Min, Sum, UserDefinedAggregateFunc}
 import org.apache.spark.sql.connector.expressions.filter.{AlwaysFalse, AlwaysTrue, And => V2And, Not => V2Not, Or => V2Or, Predicate => V2Predicate}
@@ -283,6 +285,27 @@ class V2ExpressionBuilder(e: Expression, isPredicate: Boolean = false) {
       } else {
         None
       }
+    case Invoke(Literal(obj, _), functionName, _, arguments, _, _, _, _) =>
+      obj match {
+        case function: ScalarFunction[_] if ScalarFunction.MAGIC_METHOD_NAME == functionName =>
+          val argumentExpressions = arguments.flatMap(generateExpression(_))
+          if (argumentExpressions.length == arguments.length) {
+            Some(new UserDefinedScalarFunc(
+              function.name(), function.canonicalName(), argumentExpressions.toArray[V2Expression]))
+          } else {
+            None
+          }
+        case _ =>
+          None
+      }
+    case StaticInvoke(_, _, _, arguments, _, _, _, _, Some(scalarFunc)) =>
+      val argumentExpressions = arguments.flatMap(generateExpression(_))
+      if (argumentExpressions.length == arguments.length) {
+        Some(new UserDefinedScalarFunc(
+          scalarFunc.name(), scalarFunc.canonicalName(), argumentExpressions.toArray[V2Expression]))
+      } else {
+        None
+      }
     case _ => None
   }
@@ -73,6 +73,11 @@ public DataType resultType() {
     public String name() {
       return "strlen";
     }
+
+    @Override
+    public String canonicalName() {
+      return name();
+    }
   }
 
   public static class JavaStrLenDefault extends JavaStrLenBase {
@@ -22,6 +22,8 @@ import java.util.Properties
 
 import scala.util.control.NonFatal
 
+import test.org.apache.spark.sql.connector.catalog.functions.JavaStrLen.JavaStrLenStaticMagic
+
 import org.apache.spark.{SparkConf, SparkException}
 import org.apache.spark.sql.{AnalysisException, DataFrame, ExplainSuiteHelper, QueryTest, Row}
 import org.apache.spark.sql.catalyst.InternalRow
@@ -38,6 +40,7 @@ import org.apache.spark.sql.functions.{abs, acos, asin, atan, atan2, avg, ceil,
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.test.SharedSparkSession
 import org.apache.spark.sql.types.{DataType, IntegerType, StringType}
+import org.apache.spark.unsafe.types.UTF8String
 import org.apache.spark.util.Utils

@@ -61,6 +64,10 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHelper {
         canonicalName match {
           case "h2.char_length" =>
             s"$funcName(${inputs.mkString(", ")})"
+          case "h2.char_length_magic" =>
+            s"CHAR_LENGTH(${inputs.mkString(", ")})"
+          case "strlen" =>
+            s"CHAR_LENGTH(${inputs.mkString(", ")})"
           case _ => super.visitUserDefinedScalarFunction(funcName, canonicalName, inputs)
         }
       }
@@ -109,6 +116,18 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHelper {
     }
   }
 
+  case object CharLengthWithMagicMethod extends ScalarFunction[Int] {
+    def invoke(str: UTF8String): Int = str.toString.length
+    override def inputTypes(): Array[DataType] = Array(StringType)
+    override def resultType(): DataType = IntegerType
+    override def name(): String = "CHAR_LENGTH_MAGIC"
+    override def canonicalName(): String = "h2.char_length_magic"
+    override def produceResult(input: InternalRow): Int = {
+      val s = input.getString(0)
+      s.length
+    }
+  }
+
   override def sparkConf: SparkConf = super.sparkConf
     .set("spark.sql.catalog.h2", classOf[JDBCTableCatalog].getName)
     .set("spark.sql.catalog.h2.url", url)
@@ -194,6 +213,9 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHelper {
     }
     H2Dialect.registerFunction("my_avg", IntegralAverage)
     H2Dialect.registerFunction("my_strlen", StrLen(CharLength))
+    H2Dialect.registerFunction("my_strlen_magic", StrLen(CharLengthWithMagicMethod))
+    H2Dialect.registerFunction(
+      "my_strlen_static_magic", StrLen(new JavaStrLenStaticMagic()))
   }

@@ -1489,6 +1511,58 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHelper {
     }
   }
 
+  test("scan with filter push-down with UDF that has magic method") {
+    JdbcDialects.unregisterDialect(H2Dialect)
+    try {
+      JdbcDialects.registerDialect(testH2Dialect)
+      val df1 = sql("SELECT * FROM h2.test.people where h2.my_strlen_magic(name) > 2")
+      checkFiltersRemoved(df1)
+      checkPushedInfo(df1, "PushedFilters: [CHAR_LENGTH_MAGIC(NAME) > 2],")
+      checkAnswer(df1, Seq(Row("fred", 1), Row("mary", 2)))
+
+      val df2 = sql(
+        """
+          |SELECT *
+          |FROM h2.test.people
+          |WHERE h2.my_strlen_magic(CASE WHEN NAME = 'fred' THEN NAME ELSE "abc" END) > 2
+        """.stripMargin)
+      checkFiltersRemoved(df2)
+      checkPushedInfo(df2,
+        "PushedFilters: [CHAR_LENGTH_MAGIC(CASE WHEN NAME = 'fred' " +
+          "THEN NAME ELSE 'abc' END) > 2],")
+      checkAnswer(df2, Seq(Row("fred", 1), Row("mary", 2)))
+    } finally {
+      JdbcDialects.unregisterDialect(testH2Dialect)
+      JdbcDialects.registerDialect(H2Dialect)
+    }
+  }
+
+  test("scan with filter push-down with UDF that has static magic method") {
+    JdbcDialects.unregisterDialect(H2Dialect)
+    try {
+      JdbcDialects.registerDialect(testH2Dialect)
+      val df1 = sql("SELECT * FROM h2.test.people where h2.my_strlen_static_magic(name) > 2")
+      checkFiltersRemoved(df1)
+      checkPushedInfo(df1, "PushedFilters: [strlen(NAME) > 2],")
+      checkAnswer(df1, Seq(Row("fred", 1), Row("mary", 2)))
+
+      val df2 = sql(
+        """
+          |SELECT *
+          |FROM h2.test.people
+          |WHERE h2.my_strlen_static_magic(CASE WHEN NAME = 'fred' THEN NAME ELSE "abc" END) > 2
+        """.stripMargin)
+      checkFiltersRemoved(df2)
+      checkPushedInfo(df2,
+        "PushedFilters: [strlen(CASE WHEN NAME = 'fred' " +
+          "THEN NAME ELSE 'abc' END) > 2],")
+      checkAnswer(df2, Seq(Row("fred", 1), Row("mary", 2)))
+    } finally {
+      JdbcDialects.unregisterDialect(testH2Dialect)
+      JdbcDialects.registerDialect(H2Dialect)
+    }
+  }
+
   test("scan with column pruning") {
     val df = spark.table("h2.test.people").select("id")
     checkSchemaNames(df, Seq("ID"))
