Add bloom filter, skipping strategy and aggregator
Signed-off-by: Chen Dai <[email protected]>
dai-chen committed Feb 2, 2024
1 parent 6f70884 commit 8c6d35a
Showing 8 changed files with 336 additions and 2 deletions.
@@ -17,7 +17,8 @@ import org.opensearch.flint.spark.mv.FlintSparkMaterializedView.MV_INDEX_TYPE
import org.opensearch.flint.spark.skipping.FlintSparkSkippingIndex
import org.opensearch.flint.spark.skipping.FlintSparkSkippingIndex.SKIPPING_INDEX_TYPE
import org.opensearch.flint.spark.skipping.FlintSparkSkippingStrategy.SkippingKind
import org.opensearch.flint.spark.skipping.FlintSparkSkippingStrategy.SkippingKind.{MIN_MAX, PARTITION, VALUE_SET}
import org.opensearch.flint.spark.skipping.FlintSparkSkippingStrategy.SkippingKind.{BLOOM_FILTER, MIN_MAX, PARTITION, VALUE_SET}
import org.opensearch.flint.spark.skipping.bloomfilter.BloomFilterSkippingStrategy
import org.opensearch.flint.spark.skipping.minmax.MinMaxSkippingStrategy
import org.opensearch.flint.spark.skipping.partition.PartitionSkippingStrategy
import org.opensearch.flint.spark.skipping.valueset.ValueSetSkippingStrategy
@@ -59,6 +60,11 @@ object FlintSparkIndexFactory {
params = parameters)
case MIN_MAX =>
MinMaxSkippingStrategy(columnName = columnName, columnType = columnType)
case BLOOM_FILTER =>
BloomFilterSkippingStrategy(
columnName = columnName,
columnType = columnType,
params = parameters)
case other =>
throw new IllegalStateException(s"Unknown skipping strategy: $other")
}
@@ -12,6 +12,7 @@ import org.opensearch.flint.spark._
import org.opensearch.flint.spark.FlintSparkIndex._
import org.opensearch.flint.spark.FlintSparkIndexOptions.empty
import org.opensearch.flint.spark.skipping.FlintSparkSkippingIndex.{getSkippingIndexName, FILE_PATH_COLUMN, SKIPPING_INDEX_TYPE}
import org.opensearch.flint.spark.skipping.bloomfilter.BloomFilterSkippingStrategy
import org.opensearch.flint.spark.skipping.minmax.MinMaxSkippingStrategy
import org.opensearch.flint.spark.skipping.partition.PartitionSkippingStrategy
import org.opensearch.flint.spark.skipping.valueset.ValueSetSkippingStrategy
@@ -188,6 +189,25 @@ object FlintSparkSkippingIndex {
this
}

/**
* Add a bloom filter skipping index column.
*
* @param colName
* indexed column name
* @param params
* bloom filter parameters
* @return
* index builder
*/
def addBloomFilter(colName: String, params: Map[String, String] = Map.empty): Builder = {
val col = findColumn(colName)
indexedColumns = indexedColumns :+ BloomFilterSkippingStrategy(
columnName = col.name,
columnType = col.dataType,
params = params)
this
}

override def buildIndex(): FlintSparkIndex =
new FlintSparkSkippingIndex(tableName, indexedColumns, indexOptions)

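For context, a minimal usage sketch of the new addBloomFilter builder method, assuming a FlintSpark instance named flint and the existing skippingIndex()/onTable() builder entry points; the table and column names are illustrative only:

// Illustrative only: add a bloom filter skipping column through the builder
flint
  .skippingIndex()
  .onTable("spark_catalog.default.http_logs")   // hypothetical table
  .addBloomFilter("client_ip", Map("num_items" -> "10000", "fpp" -> "0.01"))
  .create()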
@@ -69,7 +69,7 @@ object FlintSparkSkippingStrategy {
type SkippingKind = Value

// Use Value[s]Set because ValueSet already exists in Enumeration
val PARTITION, VALUE_SET, MIN_MAX = Value
val PARTITION, VALUE_SET, MIN_MAX, BLOOM_FILTER = Value
}

/** json4s doesn't serialize Enum by default */
@@ -0,0 +1,63 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.flint.spark.skipping.bloomfilter

import java.io.{InputStream, OutputStream}
import java.util.Locale

import org.opensearch.flint.spark.skipping.bloomfilter.BloomFilter.Algorithm.{Algorithm, CLASSIC}

trait BloomFilter {

def algorithm: Algorithm

def bitSize(): Long

def put(item: Long): Boolean

def merge(bloomFilter: BloomFilter): BloomFilter

def mightContain(item: Long): Boolean

def writeTo(out: OutputStream): Unit
}

object BloomFilter {

object Algorithm extends Enumeration {
type Algorithm = Value
val CLASSIC = Value
}

val BLOOM_FILTER_ALGORITHM_KEY = "algorithm"
val DEFAULT_BLOOM_FILTER_ALGORITHM = CLASSIC.toString

class BloomFilterFactory(params: Map[String, String]) extends Serializable {

private val algorithm: Algorithm = {
val param = params.getOrElse(BLOOM_FILTER_ALGORITHM_KEY, DEFAULT_BLOOM_FILTER_ALGORITHM)
Algorithm.withName(param.toUpperCase(Locale.ROOT))
}

def parameters: Map[String, String] = {
algorithm match {
case CLASSIC => ClassicBloomFilter.getParameters(params) // TODO: add algorithm param
}
}

def create(): BloomFilter = {
algorithm match {
case CLASSIC => new ClassicBloomFilter(parameters)
}
}

def deserialize(in: InputStream): BloomFilter = {
algorithm match {
case CLASSIC => ClassicBloomFilter.deserialize(in)
}
}
}
}
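A hedged sketch of how BloomFilterFactory resolves the algorithm and parameters and produces a filter; the parameter map and values below are illustrative:

// Illustrative only: create a classic bloom filter through the factory and round-trip it
import java.io.{ByteArrayInputStream, ByteArrayOutputStream}

val factory = new BloomFilter.BloomFilterFactory(Map("algorithm" -> "classic"))
val filter = factory.create()          // classic filter with default num_items and fpp
filter.put(42L)
assert(filter.mightContain(42L))

val out = new ByteArrayOutputStream()
filter.writeTo(out)
val restored = factory.deserialize(new ByteArrayInputStream(out.toByteArray))
assert(restored.mightContain(42L))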
@@ -0,0 +1,98 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.flint.spark.skipping.bloomfilter

import java.io.{ByteArrayInputStream, ByteArrayOutputStream}

import org.opensearch.flint.spark.skipping.bloomfilter.BloomFilter.BloomFilterFactory

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.catalyst.expressions.aggregate.{ImperativeAggregate, TypedImperativeAggregate}
import org.apache.spark.sql.types.{BinaryType, DataType}

/**
* Aggregate function that builds a bloom filter and serializes it to binary as the result. Copied
* from Spark's built-in BloomFilterAggregate because the built-in version: 1) accepts the number
* of bits as an argument instead of FPP; 2) calls the static method BloomFilter.create and thus
* cannot be switched to another implementation; 3) is a Scala case class that cannot be extended
* or overridden.
*
* @param child
*   child expression that produces the values added to the bloom filter
* @param bloomFilterFactory
*   factory that creates the underlying bloom filter implementation
* @param mutableAggBufferOffset
*   offset of the mutable aggregation buffer
* @param inputAggBufferOffset
*   offset of the input aggregation buffer
*/
case class BloomFilterAgg(
child: Expression,
bloomFilterFactory: BloomFilterFactory,
override val mutableAggBufferOffset: Int,
override val inputAggBufferOffset: Int)
extends TypedImperativeAggregate[BloomFilter] {

def this(child: Expression, bloomFilterFactory: BloomFilterFactory) = {
this(child, bloomFilterFactory, 0, 0)
}

override def nullable: Boolean = true

override def dataType: DataType = BinaryType

override def children: Seq[Expression] = Seq(child)

override def createAggregationBuffer(): BloomFilter = bloomFilterFactory.create()

override def update(buffer: BloomFilter, inputRow: InternalRow): BloomFilter = {
val value = child.eval(inputRow)
// Ignore null values.
if (value == null) {
return buffer
}
buffer.put(value.asInstanceOf[Long])
buffer
}

override def merge(buffer: BloomFilter, input: BloomFilter): BloomFilter = {
buffer.merge(input)
buffer
}

override def eval(buffer: BloomFilter): Any = {
if (buffer.bitSize() == 0) {
// There's no set bit in the bloom filter, hence no non-null value has been processed.
return null
}
serialize(buffer)
}

override def serialize(buffer: BloomFilter): Array[Byte] = {
// BloomFilterImpl.writeTo() writes 2 integers (version number and num hash functions), hence
// the +8
val size = (buffer.bitSize() / 8) + 8
require(size <= Integer.MAX_VALUE, s"serialized bloom filter size is too large: $size bytes")
val out = new ByteArrayOutputStream(size.intValue())
buffer.writeTo(out)
out.close()
out.toByteArray
}

override def deserialize(bytes: Array[Byte]): BloomFilter = {
val in = new ByteArrayInputStream(bytes)
val bloomFilter = bloomFilterFactory.deserialize(in)
in.close()
bloomFilter
}

override protected def withNewChildrenInternal(
newChildren: IndexedSeq[Expression]): Expression =
copy(child = newChildren.head)

override def withNewMutableAggBufferOffset(newOffset: Int): ImperativeAggregate =
copy(mutableAggBufferOffset = newOffset)

override def withNewInputAggBufferOffset(newOffset: Int): ImperativeAggregate =
copy(inputAggBufferOffset = newOffset)
}
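For illustration, a sketch of how the aggregate could be wired up outside the skipping strategy; the DataFrame and column names are hypothetical, and wrapping the expression in Column is just one way to expose it:

// Illustrative only: build one serialized bloom filter per source file from a hashed column
import org.apache.spark.sql.Column
import org.apache.spark.sql.functions.{col, xxhash64}

val factory = new BloomFilter.BloomFilterFactory(Map.empty)   // classic defaults
val bloomAgg = new Column(
  new BloomFilterAgg(xxhash64(col("client_ip")).expr, factory).toAggregateExpression())

// df.groupBy(col("file_path")).agg(bloomAgg.as("client_ip"))   // hypothetical DataFrame df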
@@ -0,0 +1,38 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.flint.spark.skipping.bloomfilter

import org.opensearch.flint.spark.skipping.FlintSparkSkippingStrategy
import org.opensearch.flint.spark.skipping.FlintSparkSkippingStrategy.SkippingKind.{BLOOM_FILTER, SkippingKind}
import org.opensearch.flint.spark.skipping.bloomfilter.BloomFilter.BloomFilterFactory

import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.functions.{col, xxhash64}

/**
* Skipping strategy based on approximate data structure bloom filter.
*/
case class BloomFilterSkippingStrategy(
override val kind: SkippingKind = BLOOM_FILTER,
override val columnName: String,
override val columnType: String,
params: Map[String, String] = Map.empty)
extends FlintSparkSkippingStrategy {

private val bloomFilterFactory = new BloomFilterFactory(params)

override val parameters: Map[String, String] = bloomFilterFactory.parameters

override def outputSchema(): Map[String, String] = Map(columnName -> "binary") // TODO: binary?

override def getAggregators: Seq[Expression] = {
Seq(
new BloomFilterAgg(xxhash64(col(columnName)).expr, bloomFilterFactory)
.toAggregateExpression())
}

override def rewritePredicate(predicate: Expression): Option[Expression] = None
}
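A small sketch of what the strategy emits, based on the defaults introduced in this change; the column name is illustrative:

// Illustrative only: inspect the strategy output for a hypothetical string column
val strategy = BloomFilterSkippingStrategy(columnName = "client_ip", columnType = "string")
strategy.outputSchema()   // Map("client_ip" -> "binary")
strategy.parameters       // Map("num_items" -> "10000", "fpp" -> "0.01") with classic defaults
// rewritePredicate always returns None for now, so queries are not rewritten yet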
@@ -0,0 +1,65 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.flint.spark.skipping.bloomfilter
import java.io.{InputStream, OutputStream}

import org.opensearch.flint.spark.skipping.bloomfilter.BloomFilter.Algorithm.{Algorithm, CLASSIC}
import org.opensearch.flint.spark.skipping.bloomfilter.ClassicBloomFilter.{CLASSIC_BLOOM_FILTER_FPP_KEY, CLASSIC_BLOOM_FILTER_NUM_ITEMS_KEY}

class ClassicBloomFilter(val delegate: org.apache.spark.util.sketch.BloomFilter)
extends BloomFilter
with Serializable {

def this(params: Map[String, String]) = {
this(
org.apache.spark.util.sketch.BloomFilter
.create(
params(CLASSIC_BLOOM_FILTER_NUM_ITEMS_KEY).toLong,
params(CLASSIC_BLOOM_FILTER_FPP_KEY).toDouble))
}

override def algorithm: Algorithm = CLASSIC

override def bitSize(): Long = delegate.bitSize()

override def put(item: Long): Boolean = delegate.putLong(item)

override def merge(bloomFilter: BloomFilter): BloomFilter = {
delegate.mergeInPlace(bloomFilter.asInstanceOf[ClassicBloomFilter].delegate)
this
}

override def mightContain(item: Long): Boolean = delegate.mightContainLong(item)

override def writeTo(out: OutputStream): Unit = delegate.writeTo(out)
}

object ClassicBloomFilter {

val CLASSIC_BLOOM_FILTER_NUM_ITEMS_KEY = "num_items"
val DEFAULT_CLASSIC_BLOOM_FILTER_NUM_ITEMS = 10000

val CLASSIC_BLOOM_FILTER_FPP_KEY = "fpp"
val DEFAULT_CLASSIC_BLOOM_FILTER_FPP = 0.01

def getParameters(params: Map[String, String]): Map[String, String] = {
val map = Map.newBuilder[String, String]
map ++= params

if (!params.contains(CLASSIC_BLOOM_FILTER_NUM_ITEMS_KEY)) {
map += (CLASSIC_BLOOM_FILTER_NUM_ITEMS_KEY -> DEFAULT_CLASSIC_BLOOM_FILTER_NUM_ITEMS.toString)
}
if (!params.contains(CLASSIC_BLOOM_FILTER_FPP_KEY)) {
map += (CLASSIC_BLOOM_FILTER_FPP_KEY -> DEFAULT_CLASSIC_BLOOM_FILTER_FPP.toString)
}
map.result()
}

def deserialize(in: InputStream): BloomFilter = {
val delegate = org.apache.spark.util.sketch.BloomFilter.readFrom(in)
new ClassicBloomFilter(delegate)
}
}
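Finally, a short sketch of the default handling in getParameters; the values are the defaults defined above:

// Illustrative only: any missing classic bloom filter parameter falls back to its default
ClassicBloomFilter.getParameters(Map("fpp" -> "0.05"))
// => Map("fpp" -> "0.05", "num_items" -> "10000")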