Skip to content

Commit

Permalink
Progress towards #288
Browse files Browse the repository at this point in the history
- Added support for registering UDTs to Plugin API
- Internal plugins are now correctly loaded, even if no external plugins are defined
- Sedona Plugin now handles initializing spark, PNG support, and UDT initialization (moved from InitSpark, SparkPrimitive, and SparkSchema)
- SparkPrimitive no longer handles geometry/raster codecs, instead directing codec requests for unknown UDTs to Plugin API
- SparkSchema no longer decodes geometry/raster UDTs, instead hooking the Plugin API
- SharedTestResources now correctly loads internal plugins
  • Loading branch information
okennedy committed Aug 8, 2024
1 parent 45e3312 commit 4d7a49d
Show file tree
Hide file tree
Showing 7 changed files with 123 additions and 68 deletions.
35 changes: 34 additions & 1 deletion vizier/backend/src/info/vizierdb/Plugin.scala
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import java.io.FileNotFoundException
import com.typesafe.scalalogging.LazyLogging
import scala.collection.mutable
import info.vizierdb.util.ClassLoaderUtils
import org.apache.spark.sql.types.UserDefinedType

case class Plugin(
name: String,
Expand All @@ -37,9 +38,16 @@ object Plugin
extends LazyLogging
{

  /**
   * Registration record for a user-defined type exposed through the plugin API.
   *
   * @param shortName  Stable textual identifier for the type (used as the
   *                   registry key in [[Plugin.registerUDT]] and matched by
   *                   [[PluginUDTByName]]).
   * @param dataType   The Spark [[UserDefinedType]] instance this entry
   *                   describes (compared with `equals` by [[PluginUDTByType]]).
   * @param encode     Converts a runtime value of this type to JSON.
   *                   NOTE(review): takes `Any` — callers presumably pass either
   *                   the UDT's user type or its serialized form; confirm at
   *                   call sites before tightening.
   * @param decode     Converts a JSON value back to a runtime value of this type.
   */
  case class VizierUDT[T >: Null](
    shortName: String,
    dataType: UserDefinedType[T],
    encode: Any => JsValue,
    decode: JsValue => Any
  )

val loaded = mutable.Map[String, Plugin]()
val jars = mutable.Buffer[URL]()

private val udt = mutable.Map[String, VizierUDT[_ >: Null]]()

implicit val pluginFormat: Format[Plugin] = Json.format

Expand Down Expand Up @@ -109,4 +117,29 @@ object Plugin
}

def loadedJars = jars.toSeq


def registerUDT(shortName: String, dataType: UserDefinedType[_ >: Null], encode: Any => JsValue, decode: JsValue => Any) =
{
val spec = VizierUDT(shortName, dataType, encode, decode)
udt(shortName) = spec
}

object PluginUDTByType
{
def unapply(base: UserDefinedType[_]): Option[VizierUDT[_]] =
{
for(p <- udt.values){
if(p.dataType.equals(base)) { return Some(p) }
}
return None
}
}
object PluginUDTByName
{
def unapply(base: String): Option[VizierUDT[_]] =
{
udt.get(base)
}
}
}
4 changes: 2 additions & 2 deletions vizier/backend/src/info/vizierdb/Vizier.scala
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ object Vizier
var mainClassLoader: ClassLoader =
Thread.currentThread().getContextClassLoader()
val internalPlugins = Seq[Plugin](
VizierSedona.Plugin
VizierSedona.plugin
)

def initSQLite(db: String = "Vizier.db") =
Expand Down Expand Up @@ -234,9 +234,9 @@ object Vizier
)

// Set up plugins
loadInternalPlugins()
if(!config.plugins.isEmpty){
println("Loading plugins...")
loadInternalPlugins()
loadPlugins(config.plugins)
}

Expand Down
22 changes: 1 addition & 21 deletions vizier/backend/src/info/vizierdb/spark/InitSpark.scala
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@ import org.apache.spark.ml.linalg.{SparseVector, Vector, Vectors}
import org.apache.spark.mllib.linalg.{Vector => OldVector}
import info.vizierdb.Vizier
import java.io.File
import org.apache.sedona.spark.SedonaContext
import org.apache.spark.sql.catalyst.FunctionIdentifier
import com.typesafe.scalalogging.LazyLogging

Expand Down Expand Up @@ -113,26 +112,7 @@ object InitSpark

spark.udf.register("vector_to_array", vectorToArrayUdf)
spark.udf.register("array_to_vector", arrayToVectorUdf)
spark = SedonaContext.create(spark)

// Rejigger Sedona's AsPNG (if present) to dump out ImageUDT-typed data
{
val registry =
sparkSession.sessionState
.functionRegistry
val as_png = FunctionIdentifier("RS_AsPNG")
( registry.lookupFunction(as_png),
registry.lookupFunctionBuilder(as_png)
) match {
case (Some(info), Some(builder)) =>
registry.dropFunction(as_png)
registry.registerFunction(as_png, info,
(args) => SedonaPNGWrapper(builder(args))
)
case (_,_) =>
logger.warn("Can not override Sedona PNG class; Sedona's RS_AsPNG's output will not display properly in spreadsheets")
}
}

return spark
}

Expand Down
40 changes: 7 additions & 33 deletions vizier/backend/src/info/vizierdb/spark/SparkPrimitive.scala
Original file line number Diff line number Diff line change
Expand Up @@ -24,22 +24,13 @@ import java.sql.{ Date, Timestamp }
import scala.util.matching.Regex
import scala.collection.mutable.ArraySeq
import org.apache.spark.sql.types.UDTRegistration
import org.apache.spark.sql.sedona_sql.UDT.GeometryUDT
import org.apache.spark.sql.sedona_sql.UDT.RasterUDT
import org.locationtech.jts.geom.Geometry
import org.apache.spark.sql.catalyst.util.GenericArrayData
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.Row
import org.apache.sedona.core.formatMapper.FormatMapper
import org.apache.sedona.common.enums.FileDataSplitter
import org.apache.sedona.common.raster.RasterOutputs
import org.apache.sedona.common.raster.RasterConstructors
import org.apache.sedona.common.raster.{ Serde => SedonaRasterSerde }
import org.geotools.coverage.grid.GridCoverage2D
import java.awt.image.BufferedImage
import java.nio.charset.StandardCharsets
import org.apache.spark.sql.catalyst.util.ArrayData
import org.apache.sedona.sql.utils.GeometrySerializer
import com.typesafe.scalalogging.LazyLogging
import org.apache.spark.sql.types.{ UserDefinedType, BinaryType }
import info.vizierdb.VizierException
Expand All @@ -48,6 +39,7 @@ import info.vizierdb.serialized.MLVector
import info.vizierdb.serializers.mlvectorFormat
import info.vizierdb.util.JsonUtils
import org.apache.spark.sql.catalyst.expressions.UnsafeArrayData
import info.vizierdb.Plugin.PluginUDTByType

object SparkPrimitive
extends Object
Expand Down Expand Up @@ -120,9 +112,6 @@ object SparkPrimitive
case _ => throw new IllegalArgumentException(s"Invalid Timestamp: '$timestamp'")
}

lazy val geometryFormatMapper =
new FormatMapper(FileDataSplitter.WKT, false)

def encodeStruct(k: Any, t: StructType): JsObject =
{
JsObject(
Expand Down Expand Up @@ -152,19 +141,10 @@ object SparkPrimitive
case (_, null) => JsNull
case (StringType, _) => JsString(k.toString)
case (BinaryType, _) => JsString(base64Encode(k.asInstanceOf[Array[Byte]]))
case (ut:UserDefinedType[_], _) =>
case (PluginUDTByType(p), _) => p.encode(k)
case (ut:UserDefinedType[_], _) =>
{
// GeometryUDT is broken: https://issues.apache.org/jira/browse/SEDONA-89?filter=-2
// so we need to do a manual comparison here.

if(t.isInstanceOf[GeometryUDT]){
k match {
case geom:Geometry => JsString(geom.toText)
case enc:ArrayData => JsString(GeometrySerializer.deserialize(enc.toByteArray()).toText)
}
} else if(t.isInstanceOf[RasterUDT]){
JsString(base64Encode(SedonaRasterSerde.serialize(k.asInstanceOf[GridCoverage2D])))
} else if(t.isInstanceOf[ImageUDT]){
if(t.isInstanceOf[ImageUDT]){
JsString(base64Encode(ImageUDT.serialize(k.asInstanceOf[BufferedImage]).asInstanceOf[Array[Byte]]))
} else if(t.getClass().getName() == "org.apache.spark.ml.linalg.VectorUDT"
|| t.getClass().getName() == "org.apache.spark.mllib.linalg.VectorUDT") {
Expand Down Expand Up @@ -268,16 +248,10 @@ object SparkPrimitive
case (_, DateType) => decodeDate(k.as[String])
case (_, TimestampType) => decodeTimestamp(k.as[String])
case (_, BinaryType) => base64Decode(k.as[String])
case (_, ut:UserDefinedType[_]) =>
case (_, PluginUDTByType(p)) => p.decode(k)
case (_, ut:UserDefinedType[_]) =>
{
// GeometryUDT is broken: https://issues.apache.org/jira/browse/SEDONA-89?filter=-2
// so we need to do a manual comparison here.

if(t.isInstanceOf[GeometryUDT]){
geometryFormatMapper.readGeometry(k.as[String]) // parse as WKT
} else if(t.isInstanceOf[RasterUDT]){
SedonaRasterSerde.deserialize(base64Decode(k.as[String]))
} else if(t.isInstanceOf[ImageUDT]){
if(t.isInstanceOf[ImageUDT]){
ImageUDT.deserialize(base64Decode(k.as[String]))
} else if(t.getClass().getName() == "org.apache.spark.ml.linalg.VectorUDT"
|| t.getClass().getName() == "org.apache.spark.mllib.linalg.VectorUDT") {
Expand Down
13 changes: 6 additions & 7 deletions vizier/backend/src/info/vizierdb/spark/SparkSchema.scala
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,14 @@ package info.vizierdb.spark
import play.api.libs.json._
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.types._
import org.apache.spark.sql.sedona_sql.UDT.GeometryUDT
import org.apache.spark.sql.sedona_sql.UDT.RasterUDT
import org.apache.spark.sql.types.UDTRegistration
import info.vizierdb.spark.udt.ImageUDT
import org.apache.spark.mllib.linalg.VectorUDT
import info.vizierdb.util.StringUtils
import info.vizierdb.Vizier
import scala.collection.mutable
import info.vizierdb.Plugin.PluginUDTByName
import info.vizierdb.Plugin.PluginUDTByType

object SparkSchema {
def apply(df: DataFrame): Seq[StructField] =
Expand Down Expand Up @@ -86,8 +87,6 @@ object SparkSchema {
case "varchar" => StringType
case "int" => IntegerType
case "real" => DoubleType
case "geometry" => GeometryUDT
case "raster" => RasterUDT
case "vector" => vectorSingleton
case "binary" => BinaryType
case "image/png" => ImageUDT
Expand All @@ -106,6 +105,7 @@ object SparkSchema {
(map \ "nulls").as[Boolean]
)
}
case PluginUDTByName(p) => p.dataType
case _ =>
DataType.fromJson("\""+t+"\"")
}
Expand All @@ -121,11 +121,10 @@ object SparkSchema {
// Something changed in a recent version of spark/scala and now
// any subclass of UserDefinedType seems to match any other
// subclass of the same. Need to use an explicit isInstanceOf
case _ if t.isInstanceOf[GeometryUDT] => "geometry"
case _ if t.isInstanceOf[RasterUDT] => "raster"
case _ if t.isInstanceOf[VectorUDT] => "vector"
case _ if t.isInstanceOf[ImageUDT] => "image/png"
case _ if t.isInstanceOf[UserDefinedType[_]] =>
case PluginUDTByType(p) => p.shortName
case udt:UserDefinedType[_] =>
{
// TODO: We need cleaner UDT handling. Convention in most UDT-based systems is to
// adopt a UDT object with the same name as the actual UDT. Drop down to the base
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ object SharedTestResources
// Normal initialization
Vizier.initSQLite()
Vizier.initSpark()
Vizier.loadInternalPlugins()
Geocode.init(
geocoders = Seq(
TestCaseGeocoder("GOOGLE"),
Expand Down
Original file line number Diff line number Diff line change
@@ -1,15 +1,83 @@
package info.vizierdb.plugin.sedona

import org.apache.spark.sql.SparkSession
import com.typesafe.scalalogging.LazyLogging
import org.apache.spark.sql.catalyst.FunctionIdentifier
import info.vizierdb.spark.SedonaPNGWrapper
import org.apache.sedona.spark.SedonaContext
import info.vizierdb.spark.SparkSchema
import org.apache.spark.sql.sedona_sql.UDT.GeometryUDT
import play.api.libs.json.JsString
import org.apache.sedona.sql.utils.GeometrySerializer
import org.apache.spark.sql.catalyst.util.ArrayData
import info.vizierdb.Plugin
import org.locationtech.jts.geom.Geometry
import org.apache.sedona.core.formatMapper.FormatMapper
import org.apache.sedona.common.enums.FileDataSplitter
import org.apache.sedona.common.raster.RasterOutputs
import org.apache.sedona.common.raster.RasterConstructors
import org.apache.sedona.common.raster.{ Serde => SedonaRasterSerde }
import org.apache.spark.sql.sedona_sql.UDT.RasterUDT
import java.util.Base64
import org.geotools.coverage.grid.GridCoverage2D

object VizierSedona
extends LazyLogging
{
  // WKT (well-known text) parser used by the geometry UDT's decode hook to
  // turn JSON strings back into geometries (see registerUDT("geometry", ...)).
  // Lazy so the Sedona FormatMapper is only built on first use.
  lazy val geometryFormatMapper =
    new FormatMapper(FileDataSplitter.WKT, false)

def base64Encode(b: Array[Byte]): String =
Base64.getEncoder().encodeToString(b)

def base64Decode(b: String): Array[Byte] =
Base64.getDecoder().decode(b)

object VizierSedona {

def init(spark: SparkSession): Unit =
{
println("Sedona Initializing")
// Sedona setup hooks
SedonaContext.create(spark)

// Sedona UDTs
Plugin.registerUDT("geometry", GeometryUDT,
{
case v: Geometry => JsString(v.toText)
case v: ArrayData => JsString(GeometrySerializer.deserialize(v.toByteArray).toText),
},
{
j => geometryFormatMapper.readGeometry(j.as[String])
}
)
Plugin.registerUDT("raster", RasterUDT,
{
case k: GridCoverage2D => JsString(base64Encode(SedonaRasterSerde.serialize(k)))
},
{
k => SedonaRasterSerde.deserialize(base64Decode(k.as[String]))
}
)

// Rejigger Sedona's AsPNG (if present) to dump out ImageUDT-typed data
{
val registry =
spark.sessionState
.functionRegistry
val as_png = FunctionIdentifier("RS_AsPNG")
( registry.lookupFunction(as_png),
registry.lookupFunctionBuilder(as_png)
) match {
case (Some(info), Some(builder)) =>
registry.dropFunction(as_png)
registry.registerFunction(as_png, info,
(args) => SedonaPNGWrapper(builder(args))
)
case (_,_) =>
logger.warn("Can not override Sedona PNG class; Sedona's RS_AsPNG's output will not display properly in spreadsheets")
}
}
}

object Plugin extends info.vizierdb.Plugin(
object plugin extends Plugin(
name = "Sedona",
schema_version = 1,
plugin_class = VizierSedona.getClass.getName(),
Expand Down

0 comments on commit 4d7a49d

Please sign in to comment.