From b4800d3d57af6b077577af59be9cd9324bd3f56f Mon Sep 17 00:00:00 2001
From: Feng Zhang
Date: Fri, 13 Sep 2024 17:28:39 -0700
Subject: [PATCH] [SEDONA-648] Throw unsupported operation exception when ST_KNN is used as UDF (#1587)

---
 .../org/apache/sedona/common/Predicates.java |  5 ++-
 docs/api/sql/NearestNeighbourSearching.md    | 36 +++++++++++++++++--
 .../strategy/join/JoinQueryDetector.scala    | 22 ++++++++++--
 .../org/apache/sedona/sql/KnnJoinSuite.scala | 20 +++++++----
 4 files changed, 69 insertions(+), 14 deletions(-)

diff --git a/common/src/main/java/org/apache/sedona/common/Predicates.java b/common/src/main/java/org/apache/sedona/common/Predicates.java
index 9d27a184db..1db1f92828 100644
--- a/common/src/main/java/org/apache/sedona/common/Predicates.java
+++ b/common/src/main/java/org/apache/sedona/common/Predicates.java
@@ -96,12 +96,11 @@ public static boolean relateMatch(String matrix1, String matrix2) {
   }
 
   public static boolean knn(Geometry leftGeometry, Geometry rightGeometry, int k) {
-    return knn(leftGeometry, rightGeometry, k, false);
+    throw new UnsupportedOperationException("KNN predicate is not supported");
   }
 
   public static boolean knn(
       Geometry leftGeometry, Geometry rightGeometry, int k, boolean useSpheroid) {
-    // This should only be used as a test predicate used with extra join condition
-    return true;
+    throw new UnsupportedOperationException("KNN predicate is not supported");
   }
 }
diff --git a/docs/api/sql/NearestNeighbourSearching.md b/docs/api/sql/NearestNeighbourSearching.md
index bc65777cbd..cf1fce91d3 100644
--- a/docs/api/sql/NearestNeighbourSearching.md
+++ b/docs/api/sql/NearestNeighbourSearching.md
@@ -19,7 +19,7 @@ In case there are ties in the distance, the result will include all the tied geo
 spark.sedona.join.knn.includeTieBreakers=true
 ```
 
-Filter Pushdown Considerations:
+### Filter Pushdown Considerations:
 
 When using ST_KNN with filters applied to the resulting DataFrame, some of these filters may be pushed down to the object side of the kNN join. This means the filters will be applied to the object side reader before the kNN join is executed. If you want the filters to be applied after the kNN join, ensure that you first materialize the kNN join results and then apply the filters.
 
@@ -43,7 +43,39 @@ CACHE TABLE knnResult;
 SELECT * FROM knnResult WHERE condition;
 ```
 
-SQL Example
+### Handling SQL-Defined Tables in ST_KNN Joins:
+
+When you create DataFrames from hard-coded SQL SELECT statements in Sedona and later use them in `ST_KNN` joins, Sedona may optimize the query in a way that bypasses the intended kNN join logic. For example, with DataFrames created from literal SELECT statements such as:
+
+```scala
+val df1 = sedona.sql("SELECT ST_Point(0.0, 0.0) as geom1")
+val df2 = sedona.sql("SELECT ST_Point(0.0, 0.0) as geom2")
+
+val df = df1.join(df2, expr("ST_KNN(geom1, geom2, 1)"))
+```
+
+Sedona may rewrite the join into a form like this:
+
+```sql
+SELECT ST_KNN(ST_Point(0.0, 0.0), ST_Point(0.0, 0.0), 1)
+```
+
+As a result, the ST_KNN function is handled as a User-Defined Function (UDF) instead of a proper join operation, which prevents Sedona from taking the kNN join execution path. Unlike typical UDFs, `ST_KNN` operates on multiple rows across DataFrames rather than on individual rows, so when this happens the query fails with an `UnsupportedOperationException` indicating that the KNN predicate is not supported.
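+
+A quick way to confirm which path a query takes (a sketch; the exact operator names in the plan output vary by Sedona and Spark version) is to inspect the physical plan of the joined DataFrame `df` from the example above before running an action:
+
+```scala
+// If the kNN join path is used, the physical plan contains a dedicated
+// Sedona KNN-join operator. If ST_KNN instead shows up inside an ordinary
+// Filter or Project, it will be evaluated as a row-level predicate and fail
+// at runtime with UnsupportedOperationException.
+df.explain()
+```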
+
+Workaround:
+
+To prevent Spark's optimizer from bypassing the kNN join logic, materialize the DataFrames created from hard-coded SQL SELECT statements before performing the join. Caching the DataFrames keeps Spark from collapsing the query in this way:
+
+```scala
+val df1 = sedona.sql("SELECT ST_Point(0.0, 0.0) as geom1").cache()
+val df2 = sedona.sql("SELECT ST_Point(0.0, 0.0) as geom2").cache()
+
+val df = df1.join(df2, expr("ST_KNN(geom1, geom2, 1)"))
+```
+
+Materializing the DataFrames with `.cache()` ensures that the kNN join path is chosen in the Spark logical plan and prevents the optimization that would treat `ST_KNN` as a simple UDF.
+
+### SQL Example
 
 Suppose we have two tables `QUERIES` and `OBJECTS` with the following data:
diff --git a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/JoinQueryDetector.scala b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/JoinQueryDetector.scala
index 6fe4d5838b..825855b88c 100644
--- a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/JoinQueryDetector.scala
+++ b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/JoinQueryDetector.scala
@@ -28,7 +28,7 @@ import org.apache.spark.sql.execution.SparkPlan
 import org.apache.spark.sql.execution.datasources.LogicalRelation
 import org.apache.spark.sql.execution.datasources.v2.DataSourceV2ScanRelation
 import org.apache.spark.sql.sedona_sql.UDT.RasterUDT
-import org.apache.spark.sql.sedona_sql.expressions._
+import org.apache.spark.sql.sedona_sql.expressions.{ST_KNN, _}
 import org.apache.spark.sql.sedona_sql.expressions.raster._
 import org.apache.spark.sql.sedona_sql.optimization.ExpressionUtils.splitConjunctivePredicates
 import org.apache.spark.sql.{SparkSession, Strategy}
@@ -602,7 +602,7 @@ class JoinQueryDetector(sparkSession: SparkSession) extends Strategy {
         spatialPredicate = null,
         isGeography,
         condition,
-        extraCondition) :: Nil
+        extractExtraKNNJoinCondition(condition)) :: Nil
     }
 
   private def planDistanceJoin(
@@ -664,6 +664,24 @@ class JoinQueryDetector(sparkSession: SparkSession) extends Strategy {
     }
   }
 
+  private def extractExtraKNNJoinCondition(condition: Expression): Option[Expression] = {
+    condition match {
+      case and: And =>
+        // If either side of the conjunction is the ST_KNN predicate, return the other side as the extra join condition
+        if (and.left.isInstanceOf[ST_KNN]) {
+          Some(and.right)
+        } else if (and.right.isInstanceOf[ST_KNN]) {
+          Some(and.left)
+        } else {
+          None
+        }
+      case _: ST_KNN =>
+        None
+      case _ =>
+        Some(condition)
+    }
+  }
+
   private def planBroadcastJoin(
       left: LogicalPlan,
       right: LogicalPlan,
diff --git a/spark/common/src/test/scala/org/apache/sedona/sql/KnnJoinSuite.scala b/spark/common/src/test/scala/org/apache/sedona/sql/KnnJoinSuite.scala
index 53e57b9ed4..1d6119d02d 100644
--- a/spark/common/src/test/scala/org/apache/sedona/sql/KnnJoinSuite.scala
+++ b/spark/common/src/test/scala/org/apache/sedona/sql/KnnJoinSuite.scala
@@ -70,7 +70,7 @@ class KnnJoinSuite extends TestBaseScala with TableDrivenPropertyChecks {
         df,
         numNeighbors = 3,
         useApproximate = false,
-        expressionSize = 5,
+        expressionSize = 4,
         isGeography = true,
         mustInclude = "")
     }
@@ -83,7 +83,7 @@ class KnnJoinSuite extends TestBaseScala with TableDrivenPropertyChecks {
         df,
         numNeighbors = 3,
         useApproximate = true,
-        expressionSize = 5,
+        expressionSize = 4,
         isGeography = false,
         mustInclude = "")
     }
@@ -98,7 +98,7 @@ class KnnJoinSuite extends TestBaseScala with TableDrivenPropertyChecks {
         df,
         numNeighbors = 3,
         useApproximate = 
true, - expressionSize = 5, + expressionSize = 4, isGeography = false, mustInclude = "") } @@ -112,7 +112,7 @@ class KnnJoinSuite extends TestBaseScala with TableDrivenPropertyChecks { df, numNeighbors = 3, useApproximate = true, - expressionSize = 5, + expressionSize = 4, isGeography = false, mustInclude = "as int) <= 88))") } @@ -124,7 +124,7 @@ class KnnJoinSuite extends TestBaseScala with TableDrivenPropertyChecks { df, numNeighbors = 3, useApproximate = true, - expressionSize = 5, + expressionSize = 4, isGeography = false, mustInclude = "= point))") } @@ -136,7 +136,7 @@ class KnnJoinSuite extends TestBaseScala with TableDrivenPropertyChecks { df, numNeighbors = 3, useApproximate = true, - expressionSize = 5, + expressionSize = 4, isGeography = false, mustInclude = "= point))") } @@ -148,7 +148,7 @@ class KnnJoinSuite extends TestBaseScala with TableDrivenPropertyChecks { df, numNeighbors = 3, useApproximate = true, - expressionSize = 5, + expressionSize = 4, isGeography = false, mustInclude = "") } @@ -216,6 +216,12 @@ class KnnJoinSuite extends TestBaseScala with TableDrivenPropertyChecks { resultAll.length should be(8) // 2 queries (filtered out 1) and 4 neighbors each resultAll.mkString should be("[2,1][2,5][2,11][2,15][3,3][3,9][3,13][3,19]") } + + it("Should throw KNN predicate is not supported exception") { + intercept[Exception] { + sparkSession.sql("SELECT ST_KNN(ST_Point(0.0, 0.0), ST_Point(0.0, 0.0), 1)").show() + } + } } describe("KNN spatial join SQLs should be executed correctly with complex join conditions") {