Expose HiveGenericUDF through HivelessInternals (#10)
pomadchin authored Apr 9, 2022
1 parent 0eb3dab commit 059761d
Showing 5 changed files with 46 additions and 59 deletions.
@@ -21,6 +21,8 @@ import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectIn
import org.apache.spark.sql.types._

object HivelessInternals extends HiveInspectors with Serializable {
type GenericUDF = org.apache.spark.sql.hive.HiveGenericUDF

def toWritableInspector(dataType: DataType): ObjectInspector = dataType match {
case ArrayType(tpe, _) =>
ObjectInspectorFactory.getStandardListObjectInspector(toWritableInspector(tpe))
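
A note on why this one-line alias matters: org.apache.spark.sql.hive.HiveGenericUDF is a private[hive] case class, so code outside the org.apache.spark.sql.hive package cannot refer to it by name. Because HivelessInternals lives inside that package, re-exporting the class as a public type alias lets downstream code match on Hive UDF calls. A minimal sketch of the idea (the helper isHiveUdfCall is illustrative and not part of hiveless):

    import org.apache.spark.sql.catalyst.expressions.Expression
    import org.apache.spark.sql.hive.HivelessInternals.GenericUDF

    object UdfMatch {
      // The alias is public, so this compiles from any package even though the
      // underlying HiveGenericUDF class is private[hive].
      def isHiveUdfCall(e: Expression): Boolean = e match {
        case _: GenericUDF => true
        case _             => false
      }
    }
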
@@ -0,0 +1,30 @@
/*
* Copyright 2022 Azavea
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.hive.rules

import org.apache.spark.sql.hive.HivelessInternals.GenericUDF
import org.apache.spark.sql.catalyst.expressions.{And, Expression}

import scala.reflect.{classTag, ClassTag}

package object syntax extends Serializable {
implicit class HiveGenericUDFOps(val self: GenericUDF) extends AnyVal {
def of[T: ClassTag]: Boolean = self.funcWrapper.functionClassName == classTag[T].toString
}

def AndList(list: List[Expression]): Expression = list.reduce(And)
}
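
For orientation, of[T] answers whether a GenericUDF call wraps the UDF implementation T by comparing the wrapped function's class name with T's ClassTag (the pushdown rule below uses it to recognize ST_IntersectsExtent), while AndList folds a non-empty list of predicates into a single left-nested conjunction. A small sketch of AndList, assuming only the reduce(And) behaviour shown above:

    import org.apache.spark.sql.catalyst.expressions.{Expression, Literal}
    import org.apache.spark.sql.hive.rules.syntax.AndList

    object AndListExample {
      val p1, p2, p3 = Literal(true)
      // reduce(And) nests to the left, producing And(And(p1, p2), p3);
      // an empty list would throw, so callers pass at least one predicate.
      val combined: Expression = AndList(List(p1, p2, p3))
    }
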
@@ -14,39 +14,42 @@
* limitations under the License.
*/

package org.apache.spark.sql.hive.hiveless.spatial.rules
package com.azavea.hiveless.spark.spatial.rules

import com.azavea.hiveless.spatial._
import com.azavea.hiveless.spatial.index.ST_IntersectsExtent
import com.azavea.hiveless.serializers.syntax._
import org.locationtech.jts.geom.Geometry
import geotrellis.vector._
import cats.syntax.option._
import org.apache.spark.sql.hive.HivelessInternals.GenericUDF
import org.apache.spark.sql.hive.rules.syntax._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.SQLContext
import org.apache.spark.sql.catalyst.plans.logical.{Filter, LogicalPlan}
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.hive.HiveGenericUDF

object SpatialFilterPushdownRules extends Rule[LogicalPlan] {

def apply(plan: LogicalPlan): LogicalPlan =
plan.transformDown {
// HiveGenericUDF is a private[hive] case class
case Filter(condition: HiveGenericUDF, plan) if condition.of[ST_IntersectsExtent] =>
case Filter(condition: GenericUDF, plan) if condition.of[ST_IntersectsExtent] =>
// extract the bbox and geometry child expressions
val Seq(bboxExpr, geometryExpr) = condition.children
// evaluate the right (geometry) argument and take its extent
val extent = geometryExpr.eval(null).convert[Geometry].extent

// transform expression
val expr = List(
IsNotNull(bboxExpr),
GreaterThanOrEqual(GetStructField(bboxExpr, 0, "xmin".some), Literal(extent.xmin)),
GreaterThanOrEqual(GetStructField(bboxExpr, 1, "ymin".some), Literal(extent.ymin)),
LessThanOrEqual(GetStructField(bboxExpr, 2, "xmax".some), Literal(extent.xmax)),
LessThanOrEqual(GetStructField(bboxExpr, 3, "ymax".some), Literal(extent.ymax))
).and
val expr = AndList(
List(
IsNotNull(bboxExpr),
GreaterThanOrEqual(GetStructField(bboxExpr, 0, "xmin".some), Literal(extent.xmin)),
GreaterThanOrEqual(GetStructField(bboxExpr, 1, "ymin".some), Literal(extent.ymin)),
LessThanOrEqual(GetStructField(bboxExpr, 2, "xmax".some), Literal(extent.xmax)),
LessThanOrEqual(GetStructField(bboxExpr, 3, "ymax".some), Literal(extent.ymax))
)
)

Filter(expr, plan)
}
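
For the rule to fire it still has to be registered with the session's optimizer; the diff does not show that wiring, so the snippet below is only one assumed way to attach it, via spark.experimental.extraOptimizations (SparkSessionExtensions.injectOptimizerRule would be an alternative hook):

    import org.apache.spark.sql.SparkSession
    import com.azavea.hiveless.spark.spatial.rules.SpatialFilterPushdownRules

    object RegisterRule {
      val spark: SparkSession = SparkSession.builder()
        .master("local[*]")
        .enableHiveSupport()
        .getOrCreate()

      // Appends the pushdown rule to the optimizer's user-supplied batch.
      spark.experimental.extraOptimizations ++= Seq(SpatialFilterPushdownRules)
    }
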

This file was deleted.

@@ -16,7 +16,7 @@

package com.azavea.hiveless

import org.apache.spark.sql.hive.hiveless.spatial.rules.SpatialFilterPushdownRules
import com.azavea.hiveless.spark.spatial.rules.SpatialFilterPushdownRules
import org.apache.spark.sql.{SQLContext, SparkSession}
import org.scalatest.{BeforeAndAfterAll, Suite}

