Add bloom filter, skipping strategy and aggregator
Signed-off-by: Chen Dai <[email protected]>
Showing 8 changed files with 336 additions and 2 deletions.
...egration/src/main/scala/org/opensearch/flint/spark/skipping/bloomfilter/BloomFilter.scala (63 additions, 0 deletions)
/*
 * Copyright OpenSearch Contributors
 * SPDX-License-Identifier: Apache-2.0
 */

package org.opensearch.flint.spark.skipping.bloomfilter

import java.io.{InputStream, OutputStream}
import java.util.Locale

import org.opensearch.flint.spark.skipping.bloomfilter.BloomFilter.Algorithm.{Algorithm, CLASSIC}

/**
 * Bloom filter abstraction that hides the concrete implementation behind a common interface.
 */
trait BloomFilter {

  /** @return the bloom filter algorithm kind */
  def algorithm: Algorithm

  /** @return the number of bits in the underlying bit array */
  def bitSize(): Long

  /**
   * Puts an item into this bloom filter.
   *
   * @param item
   *   item to insert
   * @return true if the bit array changed as a result of this operation
   */
  def put(item: Long): Boolean

  /**
   * Merges another bloom filter into this one.
   *
   * @param bloomFilter
   *   other bloom filter, expected to be of the same algorithm and configuration
   * @return the merged bloom filter
   */
  def merge(bloomFilter: BloomFilter): BloomFilter

  /**
   * Tests whether an item may be in the set.
   *
   * @param item
   *   item to test
   * @return false if the item is definitely not present; true if it might be
   */
  def mightContain(item: Long): Boolean

  /**
   * Serializes this bloom filter to the given output stream.
   *
   * @param out
   *   output stream to write to
   */
  def writeTo(out: OutputStream): Unit
}

object BloomFilter {

  /** Supported bloom filter algorithms */
  object Algorithm extends Enumeration {
    type Algorithm = Value
    val CLASSIC = Value
  }

  /** Parameter key and default value for the bloom filter algorithm */
  val BLOOM_FILTER_ALGORITHM_KEY = "algorithm"
  val DEFAULT_BLOOM_FILTER_ALGORITHM = CLASSIC.toString

  /**
   * Factory that creates and deserializes the bloom filter implementation selected by the
   * given parameters.
   */
  class BloomFilterFactory(params: Map[String, String]) extends Serializable {

    private val algorithm: Algorithm = {
      val param = params.getOrElse(BLOOM_FILTER_ALGORITHM_KEY, DEFAULT_BLOOM_FILTER_ALGORITHM)
      Algorithm.withName(param.toUpperCase(Locale.ROOT))
    }

    /** @return all parameters, with defaults filled in by the chosen algorithm */
    def parameters: Map[String, String] = {
      algorithm match {
        case CLASSIC => ClassicBloomFilter.getParameters(params) // TODO: add algorithm param
      }
    }

    def create(): BloomFilter = {
      algorithm match {
        case CLASSIC => new ClassicBloomFilter(parameters)
      }
    }

    def deserialize(in: InputStream): BloomFilter = {
      algorithm match {
        case CLASSIC => ClassicBloomFilter.deserialize(in)
      }
    }
  }
}
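As a quick illustration of the factory API above, here is a minimal usage sketch; the demo object, the assert, and the concrete parameter values are assumptions for illustration, not part of the commit.

import org.opensearch.flint.spark.skipping.bloomfilter.BloomFilter.BloomFilterFactory

object BloomFilterExample { // hypothetical demo object
  def main(args: Array[String]): Unit = {
    // The keys are defined by this commit; the values here are illustrative.
    val factory = new BloomFilterFactory(
      Map("algorithm" -> "classic", "num_items" -> "10000", "fpp" -> "0.01"))

    val filter = factory.create()
    filter.put(42L)

    assert(filter.mightContain(42L)) // always true for inserted items
    // For items never inserted, mightContain may still return true with
    // probability roughly bounded by the configured FPP (0.01 here).
    println(filter.mightContain(7L))
  }
}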
...ation/src/main/scala/org/opensearch/flint/spark/skipping/bloomfilter/BloomFilterAgg.scala (98 additions, 0 deletions)
/*
 * Copyright OpenSearch Contributors
 * SPDX-License-Identifier: Apache-2.0
 */

package org.opensearch.flint.spark.skipping.bloomfilter

import java.io.{ByteArrayInputStream, ByteArrayOutputStream}

import org.opensearch.flint.spark.skipping.bloomfilter.BloomFilter.BloomFilterFactory

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.catalyst.expressions.aggregate.{ImperativeAggregate, TypedImperativeAggregate}
import org.apache.spark.sql.types.{BinaryType, DataType}

/**
 * Aggregate function that builds a bloom filter and serializes it to binary as the result.
 * Copied from Spark's built-in BloomFilterAggregate rather than reused, because the built-in
 * version: 1) accepts the number of bits as an argument instead of FPP; 2) calls the static
 * method BloomFilter.create and thus cannot be switched to another implementation; 3) is a
 * Scala case class that cannot be extended and overridden.
 *
 * @param child
 *   child expression that produces the values added to the bloom filter
 * @param bloomFilterFactory
 *   factory that creates the configured bloom filter implementation
 * @param mutableAggBufferOffset
 *   offset of this function's buffer in the mutable aggregation buffer
 * @param inputAggBufferOffset
 *   offset of this function's buffer in the input aggregation buffer
 */
case class BloomFilterAgg(
    child: Expression,
    bloomFilterFactory: BloomFilterFactory,
    override val mutableAggBufferOffset: Int,
    override val inputAggBufferOffset: Int)
    extends TypedImperativeAggregate[BloomFilter] {

  def this(child: Expression, bloomFilterFactory: BloomFilterFactory) = {
    this(child, bloomFilterFactory, 0, 0)
  }

  override def nullable: Boolean = true

  override def dataType: DataType = BinaryType

  override def children: Seq[Expression] = Seq(child)

  override def createAggregationBuffer(): BloomFilter = bloomFilterFactory.create()

  override def update(buffer: BloomFilter, inputRow: InternalRow): BloomFilter = {
    val value = child.eval(inputRow)
    // Ignore null values.
    if (value == null) {
      return buffer
    }
    buffer.put(value.asInstanceOf[Long])
    buffer
  }

  override def merge(buffer: BloomFilter, input: BloomFilter): BloomFilter = {
    buffer.merge(input)
    buffer
  }

  override def eval(buffer: BloomFilter): Any = {
    if (buffer.bitSize() == 0) {
      // No bit is set in the bloom filter, hence no non-null value has been processed.
      return null
    }
    serialize(buffer)
  }

  override def serialize(buffer: BloomFilter): Array[Byte] = {
    // BloomFilterImpl.writeTo() writes 2 integers (version number and num hash functions),
    // hence the +8
    val size = (buffer.bitSize() / 8) + 8
    require(size <= Integer.MAX_VALUE, s"actual number of bits is too large $size")
    val out = new ByteArrayOutputStream(size.intValue())
    buffer.writeTo(out)
    out.close()
    out.toByteArray
  }

  override def deserialize(bytes: Array[Byte]): BloomFilter = {
    val in = new ByteArrayInputStream(bytes)
    val bloomFilter = bloomFilterFactory.deserialize(in)
    in.close()
    bloomFilter
  }

  override protected def withNewChildrenInternal(
      newChildren: IndexedSeq[Expression]): Expression =
    copy(child = newChildren.head)

  override def withNewMutableAggBufferOffset(newOffset: Int): ImperativeAggregate =
    copy(mutableAggBufferOffset = newOffset)

  override def withNewInputAggBufferOffset(newOffset: Int): ImperativeAggregate =
    copy(inputAggBufferOffset = newOffset)
}
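To see how this aggregate plugs into a query, here is a hedged sketch of building the serialized filter over a DataFrame column, mirroring how BloomFilterSkippingStrategy (next file) constructs the aggregate; the local SparkSession, the id column, and the demo object are illustrative assumptions.

import org.apache.spark.sql.{Column, SparkSession}
import org.apache.spark.sql.functions.{col, xxhash64}
import org.opensearch.flint.spark.skipping.bloomfilter.BloomFilterAgg
import org.opensearch.flint.spark.skipping.bloomfilter.BloomFilter.BloomFilterFactory

object BloomFilterAggExample { // hypothetical demo object
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[1]").appName("bf-agg-example").getOrCreate()
    import spark.implicits._

    val df = Seq(1L, 2L, 3L).toDF("id") // "id" is an illustrative column
    val factory = new BloomFilterFactory(Map.empty) // defaults: classic, num_items=10000, fpp=0.01

    // Hash values to Long first, as the skipping strategy does, because
    // BloomFilterAgg#update casts the child's value to Long.
    val agg = new Column(
      new BloomFilterAgg(xxhash64(col("id")).expr, factory).toAggregateExpression())

    // The result is the serialized bloom filter as a binary column.
    val bytes = df.agg(agg.as("bf")).head().getAs[Array[Byte]]("bf")
    println(s"serialized bloom filter size: ${bytes.length} bytes")

    spark.stop()
  }
}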
...n/scala/org/opensearch/flint/spark/skipping/bloomfilter/BloomFilterSkippingStrategy.scala (38 additions, 0 deletions)
/*
 * Copyright OpenSearch Contributors
 * SPDX-License-Identifier: Apache-2.0
 */

package org.opensearch.flint.spark.skipping.bloomfilter

import org.opensearch.flint.spark.skipping.FlintSparkSkippingStrategy
import org.opensearch.flint.spark.skipping.FlintSparkSkippingStrategy.SkippingKind.{BLOOM_FILTER, SkippingKind}
import org.opensearch.flint.spark.skipping.bloomfilter.BloomFilter.BloomFilterFactory

import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.functions.{col, xxhash64}

/**
 * Skipping strategy based on the bloom filter approximate data structure.
 */
case class BloomFilterSkippingStrategy(
    override val kind: SkippingKind = BLOOM_FILTER,
    override val columnName: String,
    override val columnType: String,
    params: Map[String, String] = Map.empty)
    extends FlintSparkSkippingStrategy {

  private val bloomFilterFactory = new BloomFilterFactory(params)

  override val parameters: Map[String, String] = bloomFilterFactory.parameters

  override def outputSchema(): Map[String, String] = Map(columnName -> "binary") // TODO: binary?

  override def getAggregators: Seq[Expression] = {
    Seq(
      new BloomFilterAgg(xxhash64(col(columnName)).expr, bloomFilterFactory)
        .toAggregateExpression())
  }

  override def rewritePredicate(predicate: Expression): Option[Expression] = None
}
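A short sketch of what this strategy exposes to an index builder; the client_ip column, its type, and the parameter override are illustrative assumptions.

import org.opensearch.flint.spark.skipping.bloomfilter.BloomFilterSkippingStrategy

object SkippingStrategyExample { // hypothetical demo object
  def main(args: Array[String]): Unit = {
    val strategy = BloomFilterSkippingStrategy(
      columnName = "client_ip",
      columnType = "string",
      params = Map("num_items" -> "50000"))

    println(strategy.parameters)     // override merged with defaults (fpp=0.01)
    println(strategy.outputSchema()) // Map(client_ip -> binary)
    println(strategy.getAggregators) // single BloomFilterAgg over xxhash64(client_ip)
  }
}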
...n/src/main/scala/org/opensearch/flint/spark/skipping/bloomfilter/ClassicBloomFilter.scala (65 additions, 0 deletions)
/*
 * Copyright OpenSearch Contributors
 * SPDX-License-Identifier: Apache-2.0
 */

package org.opensearch.flint.spark.skipping.bloomfilter

import java.io.{InputStream, OutputStream}

import org.opensearch.flint.spark.skipping.bloomfilter.BloomFilter.Algorithm.{Algorithm, CLASSIC}
import org.opensearch.flint.spark.skipping.bloomfilter.ClassicBloomFilter.{CLASSIC_BLOOM_FILTER_FPP_KEY, CLASSIC_BLOOM_FILTER_NUM_ITEMS_KEY}

/**
 * Classic bloom filter implementation that delegates to Spark's built-in
 * org.apache.spark.util.sketch.BloomFilter.
 */
class ClassicBloomFilter(val delegate: org.apache.spark.util.sketch.BloomFilter)
    extends BloomFilter
    with Serializable {

  def this(params: Map[String, String]) = {
    this(
      org.apache.spark.util.sketch.BloomFilter
        .create(
          params(CLASSIC_BLOOM_FILTER_NUM_ITEMS_KEY).toLong,
          params(CLASSIC_BLOOM_FILTER_FPP_KEY).toDouble))
  }

  override def algorithm: Algorithm = CLASSIC

  override def bitSize(): Long = delegate.bitSize()

  override def put(item: Long): Boolean = delegate.putLong(item)

  override def merge(bloomFilter: BloomFilter): BloomFilter = {
    delegate.mergeInPlace(bloomFilter.asInstanceOf[ClassicBloomFilter].delegate)
    this
  }

  override def mightContain(item: Long): Boolean = delegate.mightContainLong(item)

  override def writeTo(out: OutputStream): Unit = delegate.writeTo(out)
}

object ClassicBloomFilter {

  /** Expected number of distinct items and its default value */
  val CLASSIC_BLOOM_FILTER_NUM_ITEMS_KEY = "num_items"
  val DEFAULT_CLASSIC_BLOOM_FILTER_NUM_ITEMS = 10000

  /** False positive probability (FPP) and its default value */
  val CLASSIC_BLOOM_FILTER_FPP_KEY = "fpp"
  val DEFAULT_CLASSIC_BLOOM_FILTER_FPP = 0.01

  /**
   * @param params
   *   given parameters
   * @return parameters with defaults filled in for any missing key
   */
  def getParameters(params: Map[String, String]): Map[String, String] = {
    val map = Map.newBuilder[String, String]
    map ++= params

    if (!params.contains(CLASSIC_BLOOM_FILTER_NUM_ITEMS_KEY)) {
      map += (CLASSIC_BLOOM_FILTER_NUM_ITEMS_KEY -> DEFAULT_CLASSIC_BLOOM_FILTER_NUM_ITEMS.toString)
    }
    if (!params.contains(CLASSIC_BLOOM_FILTER_FPP_KEY)) {
      map += (CLASSIC_BLOOM_FILTER_FPP_KEY -> DEFAULT_CLASSIC_BLOOM_FILTER_FPP.toString)
    }
    map.result()
  }

  def deserialize(in: InputStream): BloomFilter = {
    val delegate = org.apache.spark.util.sketch.BloomFilter.readFrom(in)
    new ClassicBloomFilter(delegate)
  }
}
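Finally, a small round-trip sketch showing that writeTo and ClassicBloomFilter.deserialize preserve membership; the demo object and parameter values are assumptions for illustration.

import java.io.{ByteArrayInputStream, ByteArrayOutputStream}

import org.opensearch.flint.spark.skipping.bloomfilter.ClassicBloomFilter

object ClassicBloomFilterExample { // hypothetical demo object
  def main(args: Array[String]): Unit = {
    // Both keys must be present because the constructor reads them directly.
    val filter = new ClassicBloomFilter(Map("num_items" -> "1000", "fpp" -> "0.01"))
    filter.put(123L)

    val out = new ByteArrayOutputStream()
    filter.writeTo(out) // Spark sketch wire format

    val in = new ByteArrayInputStream(out.toByteArray)
    val restored = ClassicBloomFilter.deserialize(in)
    assert(restored.mightContain(123L)) // membership survives the round trip
  }
}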