diff --git a/flint-core/src/main/java/org/opensearch/flint/core/field/bloomfilter/classic/ClassicBloomFilter.java b/flint-core/src/main/java/org/opensearch/flint/core/field/bloomfilter/classic/ClassicBloomFilter.java
index f6444fb09..c81d5fb78 100644
--- a/flint-core/src/main/java/org/opensearch/flint/core/field/bloomfilter/classic/ClassicBloomFilter.java
+++ b/flint-core/src/main/java/org/opensearch/flint/core/field/bloomfilter/classic/ClassicBloomFilter.java
@@ -132,16 +132,23 @@ public void writeTo(OutputStream out) throws IOException {
    * @param in input stream
    * @return bloom filter
    */
-  public static BloomFilter readFrom(InputStream in) throws IOException {
-    DataInputStream dis = new DataInputStream(in);
+  public static BloomFilter readFrom(InputStream in) {
+    try {
+      DataInputStream dis = new DataInputStream(in);
+
+      // Check version compatibility
+      int version = dis.readInt();
+      if (version != Version.V1.getVersionNumber()) {
+        throw new IllegalStateException("Unexpected Bloom filter version number (" + version + ")");
+      }
 
-    int version = dis.readInt();
-    if (version != Version.V1.getVersionNumber()) {
-      throw new IOException("Unexpected Bloom filter version number (" + version + ")");
+      // Read bloom filter content
+      int numHashFunctions = dis.readInt();
+      BitArray bits = BitArray.readFrom(dis);
+      return new ClassicBloomFilter(bits, numHashFunctions);
+    } catch (IOException e) {
+      throw new RuntimeException(e);
     }
-    int numHashFunctions = dis.readInt();
-    BitArray bits = BitArray.readFrom(dis);
-    return new ClassicBloomFilter(bits, numHashFunctions);
   }
 
   private static int optimalNumOfHashFunctions(long n, long m) {
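Note: with readFrom now throwing only unchecked exceptions, callers no longer need a try/catch, which is what lets the generated Java in BloomFilterMightContain.doGenCode (below) call it inline. A minimal round-trip sketch in Scala; the (expectedNumItems, fpp) constructor and put(long) are assumed from the existing flint-core API, neither is shown in this diff:

import java.io.{ByteArrayInputStream, ByteArrayOutputStream}

import org.opensearch.flint.core.field.bloomfilter.classic.ClassicBloomFilter

object ReadFromRoundTrip extends App {
  // Assumed constructor: 1000 expected items at 3% FPP (illustrative values only).
  val original = new ClassicBloomFilter(1000, 0.03)
  original.put(42L) // assumed API for inserting an item

  val out = new ByteArrayOutputStream()
  original.writeTo(out)

  // No checked IOException to handle anymore: stream failures surface as
  // RuntimeException, a version mismatch as IllegalStateException.
  val restored = ClassicBloomFilter.readFrom(new ByteArrayInputStream(out.toByteArray))
  assert(restored.mightContain(42L))
}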
diff --git a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/skipping/bloomfilter/BloomFilterMightContain.scala b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/skipping/bloomfilter/BloomFilterMightContain.scala
new file mode 100644
index 000000000..45062b29f
--- /dev/null
+++ b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/skipping/bloomfilter/BloomFilterMightContain.scala
@@ -0,0 +1,100 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+package org.opensearch.flint.spark.skipping.bloomfilter
+
+import java.io.ByteArrayInputStream
+
+import org.opensearch.flint.core.field.bloomfilter.classic.ClassicBloomFilter
+
+import org.apache.spark.sql.Column
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
+import org.apache.spark.sql.catalyst.expressions.{BinaryComparison, Expression}
+import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
+import org.apache.spark.sql.catalyst.expressions.codegen.Block.BlockHelper
+import org.apache.spark.sql.functions.{col, lit}
+import org.apache.spark.sql.types.{BooleanType, DataType}
+
+/**
+ * Bloom filter function that returns the membership check result for values of `valueExpression`
+ * in the bloom filter represented by `bloomFilterExpression`.
+ *
+ * @param bloomFilterExpression
+ *   binary expression that represents bloom filter data
+ * @param valueExpression
+ *   Long value expression to be tested
+ */
+case class BloomFilterMightContain(bloomFilterExpression: Expression, valueExpression: Expression)
+    extends BinaryComparison {
+
+  override def nullable: Boolean = true
+
+  override def left: Expression = bloomFilterExpression
+
+  override def right: Expression = valueExpression
+
+  override def prettyName: String = "bloom_filter_might_contain"
+
+  override def dataType: DataType = BooleanType
+
+  override def symbol: String = "BLOOM_FILTER_MIGHT_CONTAIN"
+
+  override def checkInputDataTypes(): TypeCheckResult = TypeCheckResult.TypeCheckSuccess
+
+  override protected def withNewChildrenInternal(
+      newBloomFilterExpression: Expression,
+      newValueExpression: Expression): BloomFilterMightContain =
+    copy(bloomFilterExpression = newBloomFilterExpression, valueExpression = newValueExpression)
+
+  override def eval(input: InternalRow): Any = {
+    val value = valueExpression.eval(input)
+    if (value == null) {
+      null
+    } else {
+      val bytes = bloomFilterExpression.eval().asInstanceOf[Array[Byte]]
+      val bloomFilter = ClassicBloomFilter.readFrom(new ByteArrayInputStream(bytes))
+      bloomFilter.mightContain(value.asInstanceOf[Long])
+    }
+  }
+
+  override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
+    val leftGen = left.genCode(ctx)
+    val rightGen = right.genCode(ctx)
+    val bloomFilterEncoder = classOf[ClassicBloomFilter].getCanonicalName.stripSuffix("$")
+    val bf = s"$bloomFilterEncoder.readFrom(new java.io.ByteArrayInputStream(${leftGen.value}))"
+    val result = s"$bf.mightContain(${rightGen.value})"
+    val resultCode =
+      s"""
+         |if (!(${rightGen.isNull})) {
+         |  ${leftGen.code}
+         |  ${ev.isNull} = false;
+         |  ${ev.value} = $result;
+         |}
+       """.stripMargin
+    ev.copy(code = code"""
+      ${rightGen.code}
+      boolean ${ev.isNull} = true;
+      boolean ${ev.value} = false;
+      $resultCode""")
+  }
+}
+
+object BloomFilterMightContain {
+
+  /**
+   * Generate a bloom filter might-contain expression given the bloom filter column and value.
+   *
+   * @param colName
+   *   column name
+   * @param value
+   *   value
+   * @return
+   *   bloom filter might contain expression
+   */
+  def bloom_filter_might_contain(colName: String, value: Any): Column = {
+    new Column(BloomFilterMightContain(col(colName).expr, lit(value).expr))
+  }
+}
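A usage sketch for the new expression and its DSL helper. Here indexDf stands for any DataFrame whose binary column age holds bloom filter bytes written by ClassicBloomFilter.writeTo; the index layout itself is assumed, not defined in this file:

import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.functions.{lit, xxhash64}
import org.opensearch.flint.spark.skipping.bloomfilter.BloomFilterMightContain.bloom_filter_might_contain

// Probe values are hashed with xxhash64 to match how the index side ingests
// them (see BloomFilterSkippingStrategy below). A null probe yields null,
// not false, per the nullable/eval contract above.
def pruneCandidates(indexDf: DataFrame): DataFrame =
  indexDf.where(bloom_filter_might_contain("age", xxhash64(lit(50))))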
diff --git a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/skipping/bloomfilter/BloomFilterSkippingStrategy.scala b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/skipping/bloomfilter/BloomFilterSkippingStrategy.scala
index 73b03ef0f..2f583813e 100644
--- a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/skipping/bloomfilter/BloomFilterSkippingStrategy.scala
+++ b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/skipping/bloomfilter/BloomFilterSkippingStrategy.scala
@@ -9,7 +9,8 @@ import org.opensearch.flint.spark.skipping.FlintSparkSkippingStrategy
 import org.opensearch.flint.spark.skipping.FlintSparkSkippingStrategy.SkippingKind.{BLOOM_FILTER, SkippingKind}
 import org.opensearch.flint.spark.skipping.bloomfilter.BloomFilterSkippingStrategy.{CLASSIC_BLOOM_FILTER_FPP_KEY, CLASSIC_BLOOM_FILTER_NUM_ITEMS_KEY, DEFAULT_CLASSIC_BLOOM_FILTER_FPP, DEFAULT_CLASSIC_BLOOM_FILTER_NUM_ITEMS}
 
-import org.apache.spark.sql.catalyst.expressions.Expression
+import org.apache.spark.sql.Column
+import org.apache.spark.sql.catalyst.expressions.{AttributeReference, EqualTo, Expression, Literal}
 import org.apache.spark.sql.functions.{col, xxhash64}
 
 /**
@@ -37,7 +38,13 @@ case class BloomFilterSkippingStrategy(
     ) // TODO: use xxhash64() for now
   }
 
-  override def rewritePredicate(predicate: Expression): Option[Expression] = None
+  override def rewritePredicate(predicate: Expression): Option[Expression] = {
+    predicate match {
+      case EqualTo(AttributeReference(`columnName`, _, _, _), value: Literal) =>
+        Some(BloomFilterMightContain(col(columnName).expr, xxhash64(new Column(value)).expr))
+      case _ => None
+    }
+  }
 
   private def expectedNumItems: Int = {
     params
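To illustrate the rewrite contract, a sketch assuming columnName and columnType are the strategy's required case-class fields (with defaults for the rest, as in the other skipping strategies; the full constructor is not shown in this hunk):

import org.apache.spark.sql.catalyst.expressions.{AttributeReference, EqualTo, GreaterThan, Literal}
import org.apache.spark.sql.types.IntegerType
import org.opensearch.flint.spark.skipping.bloomfilter.BloomFilterSkippingStrategy

val strategy = BloomFilterSkippingStrategy(columnName = "age", columnType = "int")
val age = AttributeReference("age", IntegerType)()

// Equality against a literal on the indexed column is rewritten into a bloom
// filter probe on xxhash64 of the literal...
val rewritten = strategy.rewritePredicate(EqualTo(age, Literal(50))) // Some(BloomFilterMightContain(...))

// ...while any other predicate shape is passed over, i.e. no file skipping.
val untouched = strategy.rewritePredicate(GreaterThan(age, Literal(50))) // None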
diff --git a/integ-test/src/test/scala/org/opensearch/flint/spark/FlintSparkSkippingIndexITSuite.scala b/integ-test/src/test/scala/org/opensearch/flint/spark/FlintSparkSkippingIndexITSuite.scala
index e68efdb7e..789b07c0c 100644
--- a/integ-test/src/test/scala/org/opensearch/flint/spark/FlintSparkSkippingIndexITSuite.scala
+++ b/integ-test/src/test/scala/org/opensearch/flint/spark/FlintSparkSkippingIndexITSuite.scala
@@ -11,6 +11,7 @@ import org.opensearch.flint.core.FlintVersion.current
 import org.opensearch.flint.spark.FlintSparkIndex.ID_COLUMN
 import org.opensearch.flint.spark.skipping.FlintSparkSkippingFileIndex
 import org.opensearch.flint.spark.skipping.FlintSparkSkippingIndex.getSkippingIndexName
+import org.opensearch.flint.spark.skipping.bloomfilter.BloomFilterMightContain.bloom_filter_might_contain
 import org.scalatest.matchers.{Matcher, MatchResult}
 import org.scalatest.matchers.must.Matchers._
 import org.scalatest.matchers.should.Matchers.convertToAnyShouldWrapper
@@ -19,7 +20,7 @@ import org.apache.spark.sql.{Column, Row}
 import org.apache.spark.sql.execution.{FileSourceScanExec, SparkPlan}
 import org.apache.spark.sql.execution.datasources.HadoopFsRelation
 import org.apache.spark.sql.flint.config.FlintSparkConf._
-import org.apache.spark.sql.functions.{col, isnull}
+import org.apache.spark.sql.functions.{col, isnull, lit, xxhash64}
 
 class FlintSparkSkippingIndexITSuite extends FlintSparkSuite {
 
@@ -337,7 +338,22 @@ class FlintSparkSkippingIndexITSuite extends FlintSparkSuite {
     // Assert index data
     flint.queryIndex(testIndex).collect() should have size 2
 
-    // TODO: Assert query rewrite result
+    // Assert both the query result and the rewritten plan
+    def assertQueryRewrite(): Unit = {
+      val query = sql(s"SELECT name FROM $testTable WHERE age = 50")
+      checkAnswer(query, Row("Java"))
+      query.queryExecution.executedPlan should
+        useFlintSparkSkippingFileIndex(
+          hasIndexFilter(bloom_filter_might_contain("age", xxhash64(lit(50)))))
+    }
+
+    // Test with whole-stage codegen enabled (the default)
+    assertQueryRewrite()
+
+    // Test with interpreted evaluation (codegen disabled)
+    withSQLConf("spark.sql.codegen.wholeStage" -> "false") {
+      assertQueryRewrite()
+    }
   }
 
   test("should rewrite applicable query with table name without database specified") {
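The same check can be reproduced interactively; a hedged spark-shell sketch with an illustrative table name:

// Disabling whole-stage codegen forces BloomFilterMightContain.eval instead
// of doGenCode; the suite above asserts both paths return the same answer.
spark.conf.set("spark.sql.codegen.wholeStage", "false")
spark.sql("SELECT name FROM sample_table WHERE age = 50").explain(true)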