Skip to content

Commit

Permalink
Add UT
Browse files Browse the repository at this point in the history
Signed-off-by: Chen Dai <[email protected]>
  • Loading branch information
dai-chen committed Feb 9, 2024
1 parent d7565f5 commit 1cc7cf1
Show file tree
Hide file tree
Showing 2 changed files with 58 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ import org.apache.spark.sql.catalyst.expressions.{BinaryComparison, Expression}
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
import org.apache.spark.sql.catalyst.expressions.codegen.Block.BlockHelper
import org.apache.spark.sql.functions.{col, lit}
import org.apache.spark.sql.types.{BooleanType, DataType}
import org.apache.spark.sql.types._

/**
* Bloom filter function that returns the membership check result for values of `valueExpression`
Expand All @@ -42,7 +42,18 @@ case class BloomFilterMightContain(bloomFilterExpression: Expression, valueExpre

override def symbol: String = "BLOOM_FILTER_MIGHT_CONTAIN"

override def checkInputDataTypes(): TypeCheckResult = TypeCheckResult.TypeCheckSuccess
override def checkInputDataTypes(): TypeCheckResult = {
    // The bloom filter expression must evaluate to a serialized filter (BinaryType)
    // and the probed value to its hash (LongType); a NullType literal is tolerated
    // on either side so that null arguments pass analysis and evaluate to null.
    val filterTypeOk = left.dataType == BinaryType || left.dataType == NullType
    val valueTypeOk = right.dataType == LongType || right.dataType == NullType
    if (filterTypeOk && valueTypeOk) {
      TypeCheckResult.TypeCheckSuccess
    } else {
      TypeCheckResult.TypeCheckFailure(s"""
         | Input to function $prettyName should be Binary expression followed by a Long value,
         | but it's [${left.dataType.catalogString}, ${right.dataType.catalogString}].
         | """.stripMargin)
    }
  }

override protected def withNewChildrenInternal(
newBloomFilterExpression: Expression,
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.flint.spark.skipping.bloomfilter

import org.apache.spark.FlintSuite
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult._
import org.apache.spark.sql.catalyst.expressions.Literal
import org.apache.spark.sql.types.{BinaryType, DoubleType, LongType, NullType, StringType}
import org.apache.spark.unsafe.types.UTF8String

class BloomFilterMightContainSuite extends FlintSuite {

  test("checkInputDataTypes should succeed for valid input types") {
    val binaryExpression = Literal(Array[Byte](1, 2, 3), BinaryType)
    val longExpression = Literal(42L, LongType)

    val bloomFilterMightContain = BloomFilterMightContain(binaryExpression, longExpression)
    assert(bloomFilterMightContain.checkInputDataTypes() == TypeCheckSuccess)
  }

  test("checkInputDataTypes should succeed for valid input types with nulls") {
    // Use NullType literals here: Literal.create(null, BinaryType) still reports
    // BinaryType as its dataType, so it would only re-cover the (Binary, Long)
    // branch. These literals exercise the (Binary, Null), (Null, Long) and
    // (Null, Null) branches of checkInputDataTypes.
    val binaryExpression = Literal(Array[Byte](1, 2, 3), BinaryType)
    val longExpression = Literal(42L, LongType)
    val nullExpression = Literal.create(null, NullType)

    assert(
      BloomFilterMightContain(binaryExpression, nullExpression)
        .checkInputDataTypes() == TypeCheckSuccess)
    assert(
      BloomFilterMightContain(nullExpression, longExpression)
        .checkInputDataTypes() == TypeCheckSuccess)
    assert(
      BloomFilterMightContain(nullExpression, nullExpression)
        .checkInputDataTypes() == TypeCheckSuccess)
  }

  test("checkInputDataTypes should fail for invalid input types") {
    val stringExpression = Literal(UTF8String.fromString("invalid"), StringType)
    val doubleExpression = Literal(3.14, DoubleType)

    val bloomFilterMightContain = BloomFilterMightContain(stringExpression, doubleExpression)
    // Expected message must match checkInputDataTypes byte-for-byte, including
    // the stripMargin formatting.
    val expectedErrorMsg =
      s"""
         | Input to function bloom_filter_might_contain should be Binary expression followed by a Long value,
         | but it's [${stringExpression.dataType.catalogString}, ${doubleExpression.dataType.catalogString}].
         | """.stripMargin

    assert(bloomFilterMightContain.checkInputDataTypes() == TypeCheckFailure(expectedErrorMsg))
  }
}

0 comments on commit 1cc7cf1

Please sign in to comment.