Add bloom filter might contain expression
Signed-off-by: Chen Dai <[email protected]>
dai-chen committed Feb 7, 2024
1 parent b1d132e commit 6f2aceb
Showing 4 changed files with 142 additions and 12 deletions.
@@ -132,16 +132,23 @@ public void writeTo(OutputStream out) throws IOException {
* @param in input stream
* @return bloom filter
*/
-  public static BloomFilter readFrom(InputStream in) throws IOException {
-    DataInputStream dis = new DataInputStream(in);
+  public static BloomFilter readFrom(InputStream in) {
+    try {
+      DataInputStream dis = new DataInputStream(in);

-    int version = dis.readInt();
-    if (version != Version.V1.getVersionNumber()) {
-      throw new IOException("Unexpected Bloom filter version number (" + version + ")");
-    }
+      // Check version compatibility
+      int version = dis.readInt();
+      if (version != Version.V1.getVersionNumber()) {
+        throw new IllegalStateException("Unexpected Bloom filter version number (" + version + ")");
+      }

-    int numHashFunctions = dis.readInt();
-    BitArray bits = BitArray.readFrom(dis);
-    return new ClassicBloomFilter(bits, numHashFunctions);
+      // Read bloom filter content
+      int numHashFunctions = dis.readInt();
+      BitArray bits = BitArray.readFrom(dis);
+      return new ClassicBloomFilter(bits, numHashFunctions);
+    } catch (IOException e) {
+      throw new RuntimeException(e);
+    }
  }

private static int optimalNumOfHashFunctions(long n, long m) {
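As an aside, a minimal usage sketch of the reworked deserialization path (not part of this commit): a filter is written with writeTo and restored with readFrom, which now wraps I/O failures in an unchecked RuntimeException instead of declaring IOException. The (expectedNumItems, fpp) constructor and the put(Long) call below are assumptions about ClassicBloomFilter's existing API.

import java.io.{ByteArrayInputStream, ByteArrayOutputStream}
import org.opensearch.flint.core.field.bloomfilter.classic.ClassicBloomFilter

// Assumed constructor: expected item count and false positive probability
val original = new ClassicBloomFilter(1000, 0.03)
original.put(42L) // assumed insertion method

// Serialize, then restore through the new unchecked readFrom
val out = new ByteArrayOutputStream()
original.writeTo(out)
val restored = ClassicBloomFilter.readFrom(new ByteArrayInputStream(out.toByteArray))

assert(restored.mightContain(42L)) // false positives are possible, false negatives are not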
@@ -0,0 +1,100 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.flint.spark.skipping.bloomfilter

import java.io.ByteArrayInputStream

import org.opensearch.flint.core.field.bloomfilter.classic.ClassicBloomFilter

import org.apache.spark.sql.Column
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.expressions.{BinaryComparison, Expression}
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
import org.apache.spark.sql.catalyst.expressions.codegen.Block.BlockHelper
import org.apache.spark.sql.functions.{col, lit}
import org.apache.spark.sql.types.{BooleanType, DataType}

/**
* Bloom filter function that returns the membership check result for values of `valueExpression`
* in the bloom filter represented by `bloomFilterExpression`.
*
* @param bloomFilterExpression
* binary expression that represents bloom filter data
* @param valueExpression
* Long value expression to be tested
*/
case class BloomFilterMightContain(bloomFilterExpression: Expression, valueExpression: Expression)
extends BinaryComparison {

override def nullable: Boolean = true

override def left: Expression = bloomFilterExpression

override def right: Expression = valueExpression

override def prettyName: String = "bloom_filter_might_contain"

override def dataType: DataType = BooleanType

override def symbol: String = "BLOOM_FILTER_MIGHT_CONTAIN"

override def checkInputDataTypes(): TypeCheckResult = TypeCheckResult.TypeCheckSuccess

override protected def withNewChildrenInternal(
newBloomFilterExpression: Expression,
newValueExpression: Expression): BloomFilterMightContain =
copy(bloomFilterExpression = newBloomFilterExpression, valueExpression = newValueExpression)

override def eval(input: InternalRow): Any = {
val value = valueExpression.eval(input)
if (value == null) {
null
} else {
val bytes = bloomFilterExpression.eval().asInstanceOf[Array[Byte]]
val bloomFilter = ClassicBloomFilter.readFrom(new ByteArrayInputStream(bytes))
bloomFilter.mightContain(value.asInstanceOf[Long])
}
}

override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
val leftGen = left.genCode(ctx)
val rightGen = right.genCode(ctx)
val bloomFilterEncoder = classOf[ClassicBloomFilter].getCanonicalName.stripSuffix("$")
val bf = s"$bloomFilterEncoder.readFrom(new java.io.ByteArrayInputStream(${leftGen.value}))"
val result = s"$bf.mightContain(${rightGen.value})"
val resultCode =
s"""
|if (!(${rightGen.isNull})) {
| ${leftGen.code}
| ${ev.isNull} = false;
| ${ev.value} = $result;
|}
""".stripMargin
ev.copy(code = code"""
${rightGen.code}
boolean ${ev.isNull} = true;
boolean ${ev.value} = false;
$resultCode""")
}
}

object BloomFilterMightContain {

/**
* Generate bloom filter might contain function given the bloom filter column and value.
*
* @param colName
* column name
* @param value
* value
* @return
* bloom filter might contain expression
*/
def bloom_filter_might_contain(colName: String, value: Any): Column = {
new Column(BloomFilterMightContain(col(colName).expr, lit(value).expr))
}
}
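A hedged usage sketch for the new helper (not part of this commit): the column name and DataFrame are illustrative, with the binary "age" column assumed to hold a filter serialized by ClassicBloomFilter.writeTo.

import org.apache.spark.sql.Column
import org.apache.spark.sql.functions.{lit, xxhash64}
import org.opensearch.flint.spark.skipping.bloomfilter.BloomFilterMightContain.bloom_filter_might_contain

// Probe the serialized filter in the "age" column with the xxhash64 of the literal,
// mirroring how the integration test below builds its expected index filter.
val mightMatch: Column = bloom_filter_might_contain("age", xxhash64(lit(50)))
// e.g. indexDf.filter(mightMatch) over an illustrative skipping index DataFrame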
@@ -9,7 +9,8 @@ import org.opensearch.flint.spark.skipping.FlintSparkSkippingStrategy
import org.opensearch.flint.spark.skipping.FlintSparkSkippingStrategy.SkippingKind.{BLOOM_FILTER, SkippingKind}
import org.opensearch.flint.spark.skipping.bloomfilter.BloomFilterSkippingStrategy.{CLASSIC_BLOOM_FILTER_FPP_KEY, CLASSIC_BLOOM_FILTER_NUM_ITEMS_KEY, DEFAULT_CLASSIC_BLOOM_FILTER_FPP, DEFAULT_CLASSIC_BLOOM_FILTER_NUM_ITEMS}

-import org.apache.spark.sql.catalyst.expressions.Expression
+import org.apache.spark.sql.Column
+import org.apache.spark.sql.catalyst.expressions.{AttributeReference, EqualTo, Expression, Literal}
import org.apache.spark.sql.functions.{col, xxhash64}

/**
@@ -37,7 +38,13 @@ case class BloomFilterSkippingStrategy(
) // TODO: use xxhash64() for now
}

-  override def rewritePredicate(predicate: Expression): Option[Expression] = None
+  override def rewritePredicate(predicate: Expression): Option[Expression] = {
+    predicate match {
+      case EqualTo(AttributeReference(`columnName`, _, _, _), value: Literal) =>
+        Some(BloomFilterMightContain(col(columnName).expr, xxhash64(new Column(value)).expr))
+      case _ => None
+    }
+  }

private def expectedNumItems: Int = {
params
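In effect, the new rule rewrites an equality predicate on the indexed column into a membership test over the xxhash64 of the literal. A small sketch of the expression produced for age = 50 (column name illustrative):

import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.functions.{col, lit, xxhash64}
import org.opensearch.flint.spark.skipping.bloomfilter.BloomFilterMightContain

// Equivalent of what rewritePredicate returns for EqualTo(age, Literal(50)):
// the "age" attribute refers to the index column holding the serialized filter.
val rewritten: Expression =
  BloomFilterMightContain(col("age").expr, xxhash64(lit(50)).expr)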
@@ -11,6 +11,7 @@ import org.opensearch.flint.core.FlintVersion.current
import org.opensearch.flint.spark.FlintSparkIndex.ID_COLUMN
import org.opensearch.flint.spark.skipping.FlintSparkSkippingFileIndex
import org.opensearch.flint.spark.skipping.FlintSparkSkippingIndex.getSkippingIndexName
import org.opensearch.flint.spark.skipping.bloomfilter.BloomFilterMightContain.bloom_filter_might_contain
import org.scalatest.matchers.{Matcher, MatchResult}
import org.scalatest.matchers.must.Matchers._
import org.scalatest.matchers.should.Matchers.convertToAnyShouldWrapper
@@ -19,7 +20,7 @@ import org.apache.spark.sql.{Column, Row}
import org.apache.spark.sql.execution.{FileSourceScanExec, SparkPlan}
import org.apache.spark.sql.execution.datasources.HadoopFsRelation
import org.apache.spark.sql.flint.config.FlintSparkConf._
-import org.apache.spark.sql.functions.{col, isnull}
+import org.apache.spark.sql.functions.{col, isnull, lit, xxhash64}

class FlintSparkSkippingIndexITSuite extends FlintSparkSuite {

@@ -337,7 +338,22 @@ class FlintSparkSkippingIndexITSuite extends FlintSparkSuite {
// Assert index data
flint.queryIndex(testIndex).collect() should have size 2

-    // TODO: Assert query rewrite result
+    // Assert query result and rewrite
+    def assertQueryRewrite(): Unit = {
+      val query = sql(s"SELECT name FROM $testTable WHERE age = 50")
+      checkAnswer(query, Row("Java"))
+      query.queryExecution.executedPlan should
+        useFlintSparkSkippingFileIndex(
+          hasIndexFilter(bloom_filter_might_contain("age", xxhash64(lit(50)))))
+    }
+
+    // Test by default whole stage codegen
+    assertQueryRewrite()
+
+    // Test by evaluation
+    withSQLConf("spark.sql.codegen.wholeStage" -> "false") {
+      assertQueryRewrite()
+    }
}

test("should rewrite applicable query with table name without database specified") {
