diff --git a/core/src/main/scala/org/apache/spark/sql/hive/HivelessInternals.scala b/core/src/main/scala/org/apache/spark/sql/hive/HivelessInternals.scala
index acc2679..a853a97 100644
--- a/core/src/main/scala/org/apache/spark/sql/hive/HivelessInternals.scala
+++ b/core/src/main/scala/org/apache/spark/sql/hive/HivelessInternals.scala
@@ -21,6 +21,8 @@ import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectIn
 import org.apache.spark.sql.types._
 
 object HivelessInternals extends HiveInspectors with Serializable {
+  type GenericUDF = org.apache.spark.sql.hive.HiveGenericUDF
+
   def toWritableInspector(dataType: DataType): ObjectInspector = dataType match {
     case ArrayType(tpe, _) =>
       ObjectInspectorFactory.getStandardListObjectInspector(toWritableInspector(tpe))
diff --git a/core/src/main/scala/org/apache/spark/sql/hive/rules/syntax/package.scala b/core/src/main/scala/org/apache/spark/sql/hive/rules/syntax/package.scala
new file mode 100644
index 0000000..994fe0e
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/sql/hive/rules/syntax/package.scala
@@ -0,0 +1,30 @@
+/*
+ * Copyright 2022 Azavea
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.hive.rules
+
+import org.apache.spark.sql.hive.HivelessInternals.GenericUDF
+import org.apache.spark.sql.catalyst.expressions.{And, Expression}
+
+import scala.reflect.{classTag, ClassTag}
+
+package object syntax extends Serializable {
+  implicit class HiveGenericUDFOps(val self: GenericUDF) extends AnyVal {
+    def of[T: ClassTag]: Boolean = self.funcWrapper.functionClassName == classTag[T].toString
+  }
+
+  def AndList(list: List[Expression]): Expression = list.reduce(And)
+}
diff --git a/spatial-index/src/main/scala/org/apache/spark/sql/hive/hiveless/spatial/rules/SpatialFilterPushdownRules.scala b/spatial-index/src/main/scala/com/azavea/hiveless/spark/spatial/rules/SpatialFilterPushdownRules.scala
similarity index 70%
rename from spatial-index/src/main/scala/org/apache/spark/sql/hive/hiveless/spatial/rules/SpatialFilterPushdownRules.scala
rename to spatial-index/src/main/scala/com/azavea/hiveless/spark/spatial/rules/SpatialFilterPushdownRules.scala
index 95c222e..c3b4021 100644
--- a/spatial-index/src/main/scala/org/apache/spark/sql/hive/hiveless/spatial/rules/SpatialFilterPushdownRules.scala
+++ b/spatial-index/src/main/scala/com/azavea/hiveless/spark/spatial/rules/SpatialFilterPushdownRules.scala
@@ -14,7 +14,7 @@
  * limitations under the License.
  */
 
-package org.apache.spark.sql.hive.hiveless.spatial.rules
+package com.azavea.hiveless.spark.spatial.rules
 
 import com.azavea.hiveless.spatial._
 import com.azavea.hiveless.spatial.index.ST_IntersectsExtent
@@ -22,31 +22,34 @@ import com.azavea.hiveless.serializers.syntax._
 import org.locationtech.jts.geom.Geometry
 import geotrellis.vector._
 import cats.syntax.option._
+import org.apache.spark.sql.hive.HivelessInternals.GenericUDF
+import org.apache.spark.sql.hive.rules.syntax._
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.SQLContext
 import org.apache.spark.sql.catalyst.plans.logical.{Filter, LogicalPlan}
 import org.apache.spark.sql.catalyst.rules.Rule
-import org.apache.spark.sql.hive.HiveGenericUDF
 
 object SpatialFilterPushdownRules extends Rule[LogicalPlan] {
   def apply(plan: LogicalPlan): LogicalPlan =
     plan.transformDown {
       // HiveGenericUDF is a private[hive] case class
-      case Filter(condition: HiveGenericUDF, plan) if condition.of[ST_IntersectsExtent] =>
+      case Filter(condition: GenericUDF, plan) if condition.of[ST_IntersectsExtent] =>
         // extract bbox, snd
         val Seq(bboxExpr, geometryExpr) = condition.children
         // extract extent from the right
         val extent = geometryExpr.eval(null).convert[Geometry].extent
 
         // transform expression
-        val expr = List(
-          IsNotNull(bboxExpr),
-          GreaterThanOrEqual(GetStructField(bboxExpr, 0, "xmin".some), Literal(extent.xmin)),
-          GreaterThanOrEqual(GetStructField(bboxExpr, 1, "ymin".some), Literal(extent.ymin)),
-          LessThanOrEqual(GetStructField(bboxExpr, 2, "xmax".some), Literal(extent.xmax)),
-          LessThanOrEqual(GetStructField(bboxExpr, 3, "ymax".some), Literal(extent.ymax))
-        ).and
+        val expr = AndList(
+          List(
+            IsNotNull(bboxExpr),
+            GreaterThanOrEqual(GetStructField(bboxExpr, 0, "xmin".some), Literal(extent.xmin)),
+            GreaterThanOrEqual(GetStructField(bboxExpr, 1, "ymin".some), Literal(extent.ymin)),
+            LessThanOrEqual(GetStructField(bboxExpr, 2, "xmax".some), Literal(extent.xmax)),
+            LessThanOrEqual(GetStructField(bboxExpr, 3, "ymax".some), Literal(extent.ymax))
+          )
+        )
 
         Filter(expr, plan)
     }
diff --git a/spatial-index/src/main/scala/org/apache/spark/sql/hive/hiveless/spatial/rules/package.scala b/spatial-index/src/main/scala/org/apache/spark/sql/hive/hiveless/spatial/rules/package.scala
deleted file mode 100644
index 3d32d73..0000000
--- a/spatial-index/src/main/scala/org/apache/spark/sql/hive/hiveless/spatial/rules/package.scala
+++ /dev/null
@@ -1,48 +0,0 @@
-/*
- * Copyright 2022 Azavea
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.hive.hiveless.spatial
-
-/*
- * Copyright 2022 Azavea
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-import org.apache.spark.sql.catalyst.expressions.{And, Expression}
-import org.apache.spark.sql.hive.HiveGenericUDF
-
-import scala.reflect.{classTag, ClassTag}
-
-package object rules extends Serializable {
-  implicit class HiveGenericUDFOps(val self: HiveGenericUDF) extends AnyVal {
-    def of[T: ClassTag]: Boolean = self.funcWrapper.functionClassName == classTag[T].toString
-  }
-
-  implicit class ListExpressionsOps(val self: List[Expression]) extends AnyVal {
-    def and: Expression = self.reduce(And)
-  }
-}
diff --git a/spatial-index/src/test/scala/com/azavea/hiveless/SpatialIndexHiveTestEnvironment.scala b/spatial-index/src/test/scala/com/azavea/hiveless/SpatialIndexHiveTestEnvironment.scala
index c1af264..028fdda 100644
--- a/spatial-index/src/test/scala/com/azavea/hiveless/SpatialIndexHiveTestEnvironment.scala
+++ b/spatial-index/src/test/scala/com/azavea/hiveless/SpatialIndexHiveTestEnvironment.scala
@@ -16,7 +16,7 @@
 
 package com.azavea.hiveless
 
-import org.apache.spark.sql.hive.hiveless.spatial.rules.SpatialFilterPushdownRules
+import com.azavea.hiveless.spark.spatial.rules.SpatialFilterPushdownRules
 import org.apache.spark.sql.{SQLContext, SparkSession}
 import org.scalatest.{BeforeAndAfterAll, Suite}
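For reference, a minimal sketch (not part of this diff) of how the relocated SpatialFilterPushdownRules rule could be attached to a Spark session, as the updated test environment's import suggests. The object name and the local SparkSession are illustrative assumptions only; the registration uses Spark's standard experimental optimizer hook for Rule[LogicalPlan] instances.

    import org.apache.spark.sql.SparkSession
    import com.azavea.hiveless.spark.spatial.rules.SpatialFilterPushdownRules

    // Hypothetical example object; not present in the repository.
    object RegisterPushdownExample {
      def main(args: Array[String]): Unit = {
        // Assumption: a local session for illustration; real deployments configure Hive support etc.
        val spark = SparkSession.builder().master("local[*]").getOrCreate()

        // Append the rule to the experimental optimizer rules. Once registered,
        // Filter nodes wrapping ST_IntersectsExtent are rewritten into bbox comparisons.
        spark.experimental.extraOptimizations ++= Seq(SpatialFilterPushdownRules)
      }
    }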