Add UT
Signed-off-by: Chen Dai <[email protected]>
dai-chen committed Feb 2, 2024
1 parent 8c6d35a commit 4c5bffe
Showing 6 changed files with 195 additions and 3 deletions.
@@ -10,50 +10,126 @@ import java.util.Locale

import org.opensearch.flint.spark.skipping.bloomfilter.BloomFilter.Algorithm.{Algorithm, CLASSIC}

/**
* Bloom filter interface inspired by [[org.apache.spark.util.sketch.BloomFilter]] but adapted to
* the Flint skipping index use case, with unnecessary APIs removed for now.
*/
trait BloomFilter {

/**
* @return
* algorithm kind
*/
def algorithm: Algorithm

/**
* @return
* the number of bits in the underlying bit array.
*/
def bitSize(): Long

/**
* Put an item into this bloom filter.
*
* @param item
* Long value item to insert
* @return
* true if the bits changed, which means the item was definitely added for the first time.
* Otherwise, the item may or may not have been added before.
*/
def put(item: Long): Boolean

/**
* Merge this bloom filter with another bloom filter.
*
* @param bloomFilter
* bloom filter to merge
* @return
* the merged bloom filter
*/
def merge(bloomFilter: BloomFilter): BloomFilter

/**
* @param item
* Long value item to check
* @return
* true if the item may exist in this bloom filter. Otherwise, it definitely does not exist.
*/
def mightContain(item: Long): Boolean

/**
* Serialize this bloom filter and write it to an output stream.
*
* @param out
* output stream to write
*/
def writeTo(out: OutputStream): Unit
}
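
For illustration, a minimal sketch of the contract above (hypothetical usage, not part of this commit; it obtains an instance from the BloomFilterFactory defined below, relying on default parameters):

// Illustrative sketch of the put/mightContain contract; the inserted values are arbitrary.
val filter: BloomFilter = new BloomFilter.BloomFilterFactory(Map.empty).create()
val changed = filter.put(42L) // true: bits changed, so 42L was definitely added for the first time
filter.put(42L)               // false: bits unchanged, the item may have been added before
filter.mightContain(42L)      // true, though a true result can also be a false positive
filter.mightContain(7L)       // a false result means 7L was definitely never added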

object BloomFilter {

/**
* Bloom filter algorithm.
*/
object Algorithm extends Enumeration {
type Algorithm = Value
val CLASSIC = Value
}

/**
* Bloom filter algorithm parameter name and default value if not present.
*/
val BLOOM_FILTER_ALGORITHM_KEY = "algorithm"
val DEFAULT_BLOOM_FILTER_ALGORITHM = CLASSIC.toString

/**
* Bloom filter factory that instantiates a concrete bloom filter implementation.
*
* @param params
* bloom filter algorithm parameters
*/
class BloomFilterFactory(params: Map[String, String]) extends Serializable {

/**
* Bloom filter algorithm specified in parameters.
*/
private val algorithm: Algorithm = {
val param = params.getOrElse(BLOOM_FILTER_ALGORITHM_KEY, DEFAULT_BLOOM_FILTER_ALGORITHM)
Algorithm.withName(param.toUpperCase(Locale.ROOT))
}

/**
* Get all bloom filter parameters used to store in index metadata.
*
* @return
* all bloom filter algorithm parameters, including defaults for those not present.
*/
def parameters: Map[String, String] = {
algorithm match {
case CLASSIC => ClassicBloomFilter.getParameters(params) // TODO: add algorithm param
}
}

/**
* Create a concrete bloom filter according to the parameters.
*
* @return
* bloom filter instance
*/
def create(): BloomFilter = {
algorithm match {
case CLASSIC => new ClassicBloomFilter(parameters)
}
}
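
As a sketch of how the factory is meant to be wired up (the keys are the constants defined above; the values here are illustrative, and the algorithm name is case-insensitive thanks to the toUpperCase(Locale.ROOT) normalization):

// Illustrative only: explicit parameters for a classic bloom filter.
val factory = new BloomFilterFactory(
  Map("algorithm" -> "classic", "num_items" -> "20000", "fpp" -> "0.001"))
val filter = factory.create() // a ClassicBloomFilter sized for ~20,000 items at 0.1% FPP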

/**
* Deserialize to create the bloom filter.
*
* @param in
* input stream to read from
* @return
* bloom filter instance
*/
def deserialize(in: InputStream): BloomFilter = {
algorithm match {
case CLASSIC => ClassicBloomFilter.deserialize(in)
@@ -22,6 +22,7 @@ case class BloomFilterSkippingStrategy(
params: Map[String, String] = Map.empty)
extends FlintSparkSkippingStrategy {

/** Bloom filter factory */
private val bloomFilterFactory = new BloomFilterFactory(params)

override val parameters: Map[String, String] = bloomFilterFactory.parameters
@@ -7,9 +7,16 @@ package org.opensearch.flint.spark.skipping.bloomfilter
import java.io.{InputStream, OutputStream}

import org.opensearch.flint.spark.skipping.bloomfilter.BloomFilter.Algorithm.{Algorithm, CLASSIC}
import org.opensearch.flint.spark.skipping.bloomfilter.BloomFilter.BLOOM_FILTER_ALGORITHM_KEY
import org.opensearch.flint.spark.skipping.bloomfilter.ClassicBloomFilter.{CLASSIC_BLOOM_FILTER_FPP_KEY, CLASSIC_BLOOM_FILTER_NUM_ITEMS_KEY}

class ClassicBloomFilter(val delegate: org.apache.spark.util.sketch.BloomFilter)
/**
* Classic bloom filter implementation that reuses Spark's built-in bloom filter.
*
* @param delegate
* Spark bloom filter instance
*/
case class ClassicBloomFilter(delegate: org.apache.spark.util.sketch.BloomFilter)
extends BloomFilter
with Serializable {

@@ -39,16 +46,31 @@ class ClassicBloomFilter(val delegate: org.apache.spark.util.sketch.BloomFilter)

object ClassicBloomFilter {

/**
* Expected number of unique items key and default value.
*/
val CLASSIC_BLOOM_FILTER_NUM_ITEMS_KEY = "num_items"
val DEFAULT_CLASSIC_BLOOM_FILTER_NUM_ITEMS = 10000

/**
* False positive probability (FPP) key and default value.
*/
val CLASSIC_BLOOM_FILTER_FPP_KEY = "fpp"
val DEFAULT_CLASSIC_BLOOM_FILTER_FPP = 0.01
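
These two settings jointly determine the size of the underlying bit array. For a classic bloom filter, the textbook optimal bit count is m = -n * ln(p) / (ln 2)^2, so the defaults above (n = 10,000 items, p = 0.01) work out to roughly 95,851 bits, about 12 KB per filter. A quick sanity check (this is the standard formula; Spark's sketch library may round slightly differently):

// Standard bloom filter sizing: optimal number of bits for n items at false positive rate p.
def optimalNumOfBits(n: Long, p: Double): Long =
  math.ceil(-n * math.log(p) / (math.log(2) * math.log(2))).toLong

optimalNumOfBits(10000, 0.01) // 95851 bits, i.e. ~11.7 KB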

/**
* @param params
* given parameters
* @return
* all parameters, including defaults for those not present
*/
def getParameters(params: Map[String, String]): Map[String, String] = {
val map = Map.newBuilder[String, String]
map ++= params

if (!params.contains(BLOOM_FILTER_ALGORITHM_KEY)) {
map += (BLOOM_FILTER_ALGORITHM_KEY -> CLASSIC.toString)
}
if (!params.contains(CLASSIC_BLOOM_FILTER_NUM_ITEMS_KEY)) {
map += (CLASSIC_BLOOM_FILTER_NUM_ITEMS_KEY -> DEFAULT_CLASSIC_BLOOM_FILTER_NUM_ITEMS.toString)
}
@@ -58,6 +80,14 @@ object ClassicBloomFilter {
map.result()
}
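
For example, given only a num_items override, the returned map is filled out with the remaining defaults (the expected values follow from the constants above and are asserted by the test suites below):

// Hypothetical call; the fpp branch is elided from this diff but behaves the same way.
ClassicBloomFilter.getParameters(Map("num_items" -> "50000"))
// => Map("num_items" -> "50000", "algorithm" -> "CLASSIC", "fpp" -> "0.01")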

/**
* Deserialize and instantiate a classic bloom filter instance.
*
* @param in
* input stream to read from
* @return
* classic bloom filter instance
*/
def deserialize(in: InputStream): BloomFilter = {
val delegate = org.apache.spark.util.sketch.BloomFilter.readFrom(in)
new ClassicBloomFilter(delegate)
@@ -0,0 +1,30 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.flint.spark.skipping.bloomfilter

import org.opensearch.flint.spark.skipping.{FlintSparkSkippingStrategy, FlintSparkSkippingStrategySuite}
import org.opensearch.flint.spark.skipping.bloomfilter.BloomFilter.Algorithm.CLASSIC
import org.opensearch.flint.spark.skipping.bloomfilter.BloomFilter.BLOOM_FILTER_ALGORITHM_KEY
import org.opensearch.flint.spark.skipping.bloomfilter.ClassicBloomFilter.{CLASSIC_BLOOM_FILTER_FPP_KEY, CLASSIC_BLOOM_FILTER_NUM_ITEMS_KEY, DEFAULT_CLASSIC_BLOOM_FILTER_FPP, DEFAULT_CLASSIC_BLOOM_FILTER_NUM_ITEMS}
import org.scalatest.matchers.should.Matchers

import org.apache.spark.FlintSuite

class BloomFilterSkippingStrategySuite
extends FlintSuite
with FlintSparkSkippingStrategySuite
with Matchers {

/** Strategy instance under test, initialized by this subclass */
override val strategy: FlintSparkSkippingStrategy =
BloomFilterSkippingStrategy(columnName = "name", columnType = "string")

test("parameters") {
strategy.parameters should contain allOf (BLOOM_FILTER_ALGORITHM_KEY -> CLASSIC.toString,
CLASSIC_BLOOM_FILTER_NUM_ITEMS_KEY -> DEFAULT_CLASSIC_BLOOM_FILTER_NUM_ITEMS.toString,
CLASSIC_BLOOM_FILTER_FPP_KEY -> DEFAULT_CLASSIC_BLOOM_FILTER_FPP.toString)
}
}
@@ -0,0 +1,55 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.flint.spark.skipping.bloomfilter

import java.io.{ByteArrayInputStream, ByteArrayOutputStream}

import org.opensearch.flint.spark.skipping.bloomfilter.BloomFilter.{BLOOM_FILTER_ALGORITHM_KEY, BloomFilterFactory}
import org.opensearch.flint.spark.skipping.bloomfilter.BloomFilter.Algorithm.CLASSIC
import org.opensearch.flint.spark.skipping.bloomfilter.ClassicBloomFilter.{CLASSIC_BLOOM_FILTER_FPP_KEY, CLASSIC_BLOOM_FILTER_NUM_ITEMS_KEY, DEFAULT_CLASSIC_BLOOM_FILTER_FPP, DEFAULT_CLASSIC_BLOOM_FILTER_NUM_ITEMS}
import org.scalatest.matchers.should.Matchers

import org.apache.spark.FlintSuite

class ClassicBloomFilterSuite extends FlintSuite with Matchers {

test("parameters should return all parameters including defaults") {
val factory = new BloomFilterFactory(Map(BLOOM_FILTER_ALGORITHM_KEY -> CLASSIC.toString))

factory.parameters should contain allOf (BLOOM_FILTER_ALGORITHM_KEY -> CLASSIC.toString,
CLASSIC_BLOOM_FILTER_NUM_ITEMS_KEY -> DEFAULT_CLASSIC_BLOOM_FILTER_NUM_ITEMS.toString,
CLASSIC_BLOOM_FILTER_FPP_KEY -> DEFAULT_CLASSIC_BLOOM_FILTER_FPP.toString)
}

test("parameters should return all specified parameters") {
val expectedNumItems = 50000
val fpp = 0.001
val factory = new BloomFilterFactory(
Map(
BLOOM_FILTER_ALGORITHM_KEY -> CLASSIC.toString,
CLASSIC_BLOOM_FILTER_NUM_ITEMS_KEY -> expectedNumItems.toString,
CLASSIC_BLOOM_FILTER_FPP_KEY -> fpp.toString))

factory.parameters should contain allOf (BLOOM_FILTER_ALGORITHM_KEY -> CLASSIC.toString,
CLASSIC_BLOOM_FILTER_NUM_ITEMS_KEY -> expectedNumItems.toString,
CLASSIC_BLOOM_FILTER_FPP_KEY -> fpp.toString)
}

test("serialize and deserialize") {
val factory = new BloomFilterFactory(Map(BLOOM_FILTER_ALGORITHM_KEY -> CLASSIC.toString))
val bloomFilter = factory.create()
bloomFilter.put(1L)
bloomFilter.put(2L)
bloomFilter.put(3L)

// Serializing and then deserializing should yield an equal bloom filter
val out = new ByteArrayOutputStream()
bloomFilter.writeTo(out)
val in = new ByteArrayInputStream(out.toByteArray)
val newBloomFilter = factory.deserialize(in)
bloomFilter shouldBe newBloomFilter
}
}
@@ -89,8 +89,8 @@ class FlintSparkSkippingIndexITSuite extends FlintSparkSuite {
| "num_items": "10000",
| "fpp": "0.01"
| },
| "columnName": "age",
| "columnType": "binary"
| "columnName": "name",
| "columnType": "string"
| }],
| "source": "spark_catalog.default.test",
| "options": { "auto_refresh": "false" },